compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in that registry.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +200 -8
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +101 -13
- compressed_tensors/compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/pack_quantized.py +128 -132
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +8 -1
- compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -44
- compressed_tensors/quantization/lifecycle/calibration.py +22 -2
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +139 -61
- compressed_tensors/quantization/lifecycle/helpers.py +80 -0
- compressed_tensors/quantization/lifecycle/initialize.py +77 -13
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +93 -14
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +139 -23
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +112 -13
- compressed_tensors/quantization/utils/helpers.py +68 -2
- compressed_tensors/utils/__init__.py +5 -0
- compressed_tensors/utils/helpers.py +44 -2
- compressed_tensors/utils/offload.py +116 -0
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
- compressed_tensors-0.6.0.dist-info/RECORD +52 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/int_quantized.py +0 -126
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/observers/base.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+from math import ceil
 from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
@@ -20,10 +22,14 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationStrategy,
 )
 from compressed_tensors.registry.registry import RegistryMixin
+from compressed_tensors.utils import safe_permute
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 __all__ = ["Observer"]
 
 
@@ -39,16 +45,21 @@ class Observer(Module, RegistryMixin):
         super().__init__()
         self._scale = None
         self._zero_point = None
+        self._num_observed_tokens = None
 
     @torch.no_grad()
-    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def forward(
+        self, observed: Tensor, g_idx: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
-        :param observed: optional observed tensor to calculate
-            quantization parameters from
+        :param observed: optional observed tensor from which to calculate
+            quantization parameters
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
-        return self.get_qparams(observed=observed)
+        self.record_observed_tokens(observed)
+        return self.get_qparams(observed=observed, g_idx=g_idx)
 
     def calculate_qparams(
         self,
@@ -71,7 +82,9 @@ class Observer(Module, RegistryMixin):
         ...
 
     def get_qparams(
-        self, observed: Optional[Tensor] = None
+        self,
+        observed: Optional[Tensor] = None,
+        g_idx: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -80,6 +93,7 @@ class Observer(Module, RegistryMixin):
 
         :param observed: optional observed tensor to calculate quantization parameters
             from
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
@@ -91,20 +105,42 @@ class Observer(Module, RegistryMixin):
                 self._scale, self._zero_point = self.calculate_qparams(observed)
 
             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                rows = observed.shape[0]
                 columns = observed.shape[1]
-                scales, zero_points = [], []
-                group_idxs = range(0, columns, group_size)
-                for group_id, group_idx in enumerate(group_idxs):
+                num_groups = int(ceil(columns / group_size))
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
+                zp_dtype = self.quantization_args.pytorch_dtype()
+                self._zero_point = torch.empty(
+                    (rows, num_groups), dtype=zp_dtype, device=observed.device
+                )
+
+                # support column-order (default) quantization as well as other orderings
+                # such as activation ordering. Below checks if g_idx has initialized
+                is_column_order = g_idx is None or -1 in g_idx
+                if is_column_order:
+                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+                else:
+                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+                    group_sizes = group_sizes[torch.argsort(group_indices)]
+
+                    perm = torch.argsort(g_idx)
+                    observed = safe_permute(observed, perm, dim=1)
+
+                # TODO: experiment with vectorizing for loop for performance
+                end = 0
+                for group_index, group_count in enumerate(group_sizes):
+                    start = end
+                    end = start + group_count
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, group_idx : (group_idx + group_size)],
+                        observed[:, start:end],
                         0,
-                        tensor_id=group_id,
+                        tensor_id=group_index,
                     )
-                    scales.append(scale)
-                    zero_points.append(zero_point)
 
-                self._scale = torch.cat(scales, dim=1)
-                self._zero_point = torch.cat(zero_points, dim=1)
+                    self._scale[:, group_index] = scale.squeeze(1)
+                    self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                 # assume observed is transposed, because its the output, hence use dim 0
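
The ordering logic in the group branch above is easiest to see on a toy tensor. A minimal sketch with a synthetic g_idx (illustrative values, not part of the diff):

import torch

g_idx = torch.tensor([1, 0, 0, 1, 2, 2])  # hypothetical column -> group mapping
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
group_sizes = group_sizes[torch.argsort(group_indices)]  # counts ordered by group id
perm = torch.argsort(g_idx)  # permutation making each group's columns contiguous

print(group_sizes.tolist())  # [2, 2, 2]
print(perm.tolist())         # e.g. [1, 2, 0, 3, 4, 5]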
@@ -126,9 +162,52 @@ class Observer(Module, RegistryMixin):
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
     ):
+        if isinstance(dim, int):
+            dim = [dim]
         dim = set(dim)
 
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def record_observed_tokens(self, batch_tensor: Tensor):
+        """
+        Counts the number of tokens observed during the
+        forward passes. The count is aggregated in the
+        _num_observed_tokens attribute of the class.
+
+        Note: The batch_tensor is expected to have two dimensions
+        (batch_size * sequence_length, num_features). This is the
+        general shape expected by the forward pass of the expert
+        layers in a MOE model. If the input tensor does not have
+        two dimensions, the _num_observed_tokens attribute will be set
+        to None.
+        """
+        if not isinstance(batch_tensor, Tensor):
+            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
+
+        if batch_tensor.ndim != 2:
+            _LOGGER.debug(
+                "The input tensor is expected to have two dimensions "
+                "(batch_size * sequence_length, num_features). "
+                f"The input tensor has {batch_tensor.ndim} dimensions."
+            )
+            return
+
+        if self._num_observed_tokens is None:
+            # initialize the count
+            self._num_observed_tokens = 0
+
+        # batch_tensor (batch_size * sequence_length, num_features)
+        # observed_tokens (batch_size * sequence_length)
+        observed_tokens, _ = batch_tensor.shape
+        self._num_observed_tokens += observed_tokens
+
+    def reset(self):
+        """
+        Reset the state of the observer
+        """
+        self._num_observed_tokens = None
+        self._scale = None
+        self._zero_point = None
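
Taken together, these hunks give Observer a two-argument forward plus running token statistics. A hedged usage sketch against the 0.6.0 API (shapes are illustrative; get_observer comes from QuantizationArgs, shown later in this diff):

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
observer = args.get_observer()  # resolves the default "minmax" observer

weight = torch.randn(512, 1024)
scale, zero_point = observer(weight)  # g_idx defaults to None (column order)
print(scale.shape)                    # (512, 8): one qparam pair per group
print(observer._num_observed_tokens)  # 512: rows of the 2d input are counted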
compressed_tensors/quantization/observers/helpers.py

@@ -12,23 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import Counter
 from typing import Tuple
 
 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+
+
+def get_observer_token_count(module: torch.nn.Module) -> Counter:
+    """
+    Parse the module and return the number of tokens observed by
+    each module's observer.
+
+    :param module: module to parse
+    :return: counter with the number of tokens observed by each observer
+    """
+    token_counts = Counter()
+    for name, module in module.named_modules():
+        if name.endswith(".input_observer"):
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = module._num_observed_tokens
+    return token_counts
 
 
 def calculate_qparams(
     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
 ) -> Tuple[FloatTensor, IntTensor]:
     """
     :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
         from
     :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
         from
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
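
A sketch of how the new helper might be called after calibration; `model` is a hypothetical torch.nn.Module whose quantized layers carry `.input_observer` submodules, which is the suffix the helper matches on:

from compressed_tensors.quantization.observers.helpers import get_observer_token_count

counts = get_observer_token_count(model)  # model: calibrated nn.Module (assumption)
print(counts.most_common(3))  # e.g. [('model.layers.0.self_attn.q_proj', 131072), ...]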
@@ -37,22 +59,53 @@ def calculate_qparams(
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     device = min_vals.device
 
-    bit_range = 2**quantization_args.num_bits - 1
-    bit_min = -(bit_range + 1) / 2
-    bit_max = bit_min + bit_range
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
     if quantization_args.symmetric:
-        max_val_pos = torch.max(-min_vals, max_vals)
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min - torch.round(min_vals / scales)
-        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)
 
     if scales.ndim == 0:
         scales = scales.reshape(1)
         zero_points = zero_points.reshape(1)
 
     return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
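
A quick numeric check of both branches of calculate_range (a sketch, not part of the diff; the fp8 endpoints come from torch.finfo on FP8_DTYPE):

import torch
from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.quant_args import QuantizationArgs

q_min, q_max = calculate_range(QuantizationArgs(num_bits=8), device="cpu")
print(q_min.item(), q_max.item())  # -128.0 127.0, since 2**8 / 2 = 128

q_min, q_max = calculate_range(QuantizationArgs(num_bits=8, type="float"), device="cpu")
print(q_min.item(), q_max.item())  # -448.0 448.0 for float8_e4m3fn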
compressed_tensors/quantization/observers/min_max.py

@@ -94,3 +94,11 @@ class MovingAverageMinMaxObserver(Observer):
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}
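
The reset hooks allow one observer to be reused across independent calibration passes; a short sketch under that assumption:

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

observer = QuantizationArgs(num_bits=8).get_observer()  # default "minmax"
scale_a, _ = observer(torch.randn(64, 128))

observer.reset()  # clears _scale, _zero_point, _num_observed_tokens, min/max dicts
scale_b, _ = observer(torch.randn(64, 128))  # fresh statistics, no moving-average carryover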
compressed_tensors/quantization/observers/mse.py (new file)

@@ -0,0 +1,162 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+import torch
+from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.observers.helpers import calculate_qparams
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import FloatTensor, IntTensor, Tensor
+
+
+__all__ = ["MovingAverageMSEObserver"]
+
+
+@Observer.register("mse")
+class MovingAverageMSEObserver(Observer):
+    """
+    Implements a dynamic quantization observer that sets the scale and
+    zero point based on a moving average of the mse-clipped min and max observed values
+    """
+
+    def __init__(
+        self,
+        quantization_args: QuantizationArgs,
+        averaging_constant: float = 0.01,
+        grid: float = 100.0,
+        maxshrink: float = 0.80,
+        norm: float = 2.4,
+    ):
+        super().__init__(quantization_args=quantization_args)
+
+        self.min_val = {}
+        self.max_val = {}
+        self.averaging_constant = averaging_constant
+        self.grid = grid
+        self.maxshrink = maxshrink
+        self.norm = norm
+
+    def calculate_mse_min_max(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ):
+        """
+        Computes the mse-clipped min and max values of the observed tensor by
+        optimizing for quantization error
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned values will be shaped (1,) along the reduced dimensions
+        :return: tuple of min and max values derived from the observed tensor
+        """
+        from compressed_tensors.quantization.lifecycle import fake_quantize
+
+        if not reduce_dims:
+            absolute_min_val, absolute_max_val = torch.aminmax(observed)
+        else:
+            absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        best = torch.full(absolute_min_val.shape, float("inf"))
+        min_val = torch.ones(absolute_min_val.shape)
+        max_val = torch.zeros(absolute_max_val.shape)
+        for i in range(int(self.maxshrink * self.grid)):
+            p = 1 - i / self.grid
+            shrinked_min_val = p * absolute_min_val
+            shrinked_max_val = p * absolute_max_val
+
+            candidate_scales, candidate_zero_points = calculate_qparams(
+                shrinked_min_val, shrinked_max_val, self.quantization_args
+            )
+            q = fake_quantize(
+                observed,
+                candidate_scales,
+                candidate_zero_points,
+                self.quantization_args,
+            )
+
+            q -= observed
+            q.abs_()
+            q.pow_(self.norm)
+            if not reduce_dims:
+                err = torch.sum(q)
+            else:
+                err = torch.sum(q, reduce_dims, keepdims=True)
+
+            tmp = err < best
+            if torch.any(tmp):
+                best[tmp] = err[tmp]
+                min_val[tmp] = shrinked_min_val[tmp]
+                max_val[tmp] = shrinked_max_val[tmp]
+        return min_val, max_val
+
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
+        """
+        Updates the mse-clipped min and max values of the observed tensor using
+        a moving average smoothed by the averaging_constant
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
+
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
+        else:
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
+            )
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
+            )
+
+        tensor_id = tensor_id or "default"
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}
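
A hedged sketch of opting into the new observer through the "mse" registry key registered above:

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, observer="mse")
observer = args.get_observer()  # MovingAverageMSEObserver via @Observer.register("mse")

x = torch.randn(512, 768)
scale, zero_point = observer(x)  # grid-searches shrink factors p = 1 - i / grid
scale, zero_point = observer(x)  # second pass smooths via averaging_constant=0.01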
compressed_tensors/quantization/quant_args.py

@@ -13,12 +13,22 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
-from pydantic import BaseModel, Field, validator
+import torch
+from pydantic import BaseModel, Field, field_validator, model_validator
 
 
-__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+    "ActivationOrdering",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn
 
 
 class QuantizationType(str, Enum):
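
FP8_DTYPE pins the float8 variant used throughout the package; a one-line sanity check (assumes a torch build with float8 support, 2.1 or later):

import torch

info = torch.finfo(torch.float8_e4m3fn)
print(info.min, info.max)  # -448.0 448.0: the endpoints calculate_range returns for fp8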
@@ -42,6 +52,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+        accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
@@ -59,15 +82,18 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         ranges will be observed with every sample. Defaults to False for static
         quantization. Note that enabling dynamic quantization will change the default
         observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
     """
 
     num_bits: int = 8
     type: QuantizationType = QuantizationType.INT
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
+    actorder: Union[ActivationOrdering, bool, None] = None
     observer: str = Field(
         default="minmax",
         description=(
@@ -89,37 +115,127 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
-        if self.observer == "minmax" and self.dynamic:
+        if self.dynamic:
             # override defualt observer for dynamic, you never want minmax which
             # keeps state across samples for dynamic
             self.observer = "memoryless"
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
 
-    @validator("strategy", pre=True, always=True)
-    def validate_strategy(cls, value, values):
-        group_size = values.get("group_size")
+    @field_validator("type", mode="before")
+    def validate_type(cls, value) -> QuantizationType:
+        if isinstance(value, str):
+            return QuantizationType(value.lower())
 
-        # use group_size to determine strategy if not given explicitly
-        if group_size is not None and value is None:
-            if group_size > 0:
-                return QuantizationStrategy.GROUP
+        return value
 
-            elif group_size == -1:
-                return QuantizationStrategy.CHANNEL
+    @field_validator("group_size", mode="before")
+    def validate_group(cls, value) -> Union[int, None]:
+        if value is None:
+            return value
+
+        if value < -1:
+            raise ValueError(
+                f"Invalid group size {value}. Use group_size > 0 for "
+                "strategy='group' and group_size = -1 for 'channel'"
+            )
+
+        return value
+
+    @field_validator("strategy", mode="before")
+    def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
+        if isinstance(value, str):
+            return QuantizationStrategy(value.lower())
+
+        return value
 
+    @field_validator("actorder", mode="before")
+    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+        if isinstance(value, bool):
+            return ActivationOrdering.GROUP if value else None
+
+        if isinstance(value, str):
+            return ActivationOrdering(value.lower())
+
+        return value
+
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
+        # extract user-passed values from dictionary
+        strategy = model.strategy
+        group_size = model.group_size
+        actorder = model.actorder
+
+        # infer strategy
+        if strategy is None:
+            if group_size is None:
+                strategy = QuantizationStrategy.TENSOR
+            elif group_size > 0:
+                strategy = QuantizationStrategy.GROUP
+            elif group_size == -1:
+                strategy = QuantizationStrategy.CHANNEL
             else:
                 raise ValueError(
-                    f"group_size={group_size} with strategy {value} is invalid. "
-                    "group_size > 0 for strategy='group' and "
-                    "group_size = -1 for 'channel'"
+                    f"Invalid group size {group_size}. Use group_size > 0 for "
+                    "strategy='group' and group_size = -1 for 'channel'"
                 )
 
-        if value == QuantizationStrategy.GROUP:
-            if group_size is None:
-                raise ValueError(f"strategy {value} requires group_size to be set.")
+        # validate strategy and group
+        if strategy == QuantizationStrategy.GROUP:
+            if group_size is None or group_size <= 0:
+                raise ValueError(
+                    f"strategy {strategy} requires group_size to be "
+                    "set to a positive value"
+                )
+        if (
+            group_size is not None
+            and group_size > 0
+            and strategy != QuantizationStrategy.GROUP
+        ):
+            raise ValueError("group_size requires strategy to be set to 'group'")
+
+        # validate activation ordering and strategy
+        if actorder is not None and strategy != QuantizationStrategy.GROUP:
+            raise ValueError(
+                "Must use group quantization strategy in order to apply "
+                "activation ordering"
+            )
+
+        # write back modified values
+        model.strategy = strategy
+        return model
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")
 
-        if value is None:
-            return QuantizationStrategy.TENSOR
 
-        return value
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)