compressed-tensors 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +2 -1
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +11 -54
- compressed_tensors/compressors/dense.py +4 -4
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/int_quantized.py +126 -0
- compressed_tensors/compressors/marlin_24.py +250 -0
- compressed_tensors/compressors/model_compressor.py +315 -0
- compressed_tensors/compressors/pack_quantized.py +212 -0
- compressed_tensors/compressors/sparse_bitmask.py +3 -3
- compressed_tensors/compressors/utils/__init__.py +19 -0
- compressed_tensors/compressors/utils/helpers.py +43 -0
- compressed_tensors/compressors/utils/permutations_24.py +65 -0
- compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/config/base.py +7 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +62 -11
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +161 -54
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/initialize.py +33 -5
- compressed_tensors/quantization/observers/base.py +31 -27
- compressed_tensors/quantization/observers/helpers.py +6 -1
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +2 -2
- compressed_tensors/quantization/quant_config.py +69 -21
- compressed_tensors/quantization/quant_scheme.py +81 -1
- compressed_tensors/quantization/utils/helpers.py +76 -8
- compressed_tensors/utils/helpers.py +24 -6
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -8
- compressed_tensors-0.4.0.dist-info/RECORD +48 -0
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@
 
 from functools import wraps
 from math import ceil
+from typing import Optional
 
 import torch
 from compressed_tensors.quantization.quant_args import (
@@ -25,7 +26,13 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 
 
-__all__ = [
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]
 
 
 @torch.no_grad()
@@ -33,14 +40,39 @@ def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-
-
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the group_size must be divisible by the column size. The input scale
+    and zero_points are reshaped to support vectorization (Assumes 1 is the
+    channel dimension)
 
-
-
-
-
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: fake quantized tensor
+    """
+    # ensure all tensors are on the same device
+    # assumes that the target device is the input
+    # tensor's device
+    if x.device != scale.device:
+        scale = scale.to(x.device)
+    if x.device != zero_point.device:
+        zero_point = zero_point.to(x.device)
+
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
     )
 
 
@@ -49,8 +81,42 @@ def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    args: QuantizationArgs = None,
 ) -> torch.Tensor:
-
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
+    args is not provided, the strategy will be inferred.
+
+    :param x: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0 or scale.ndim == 1:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            if scale.shape[1] == 1:
+                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+            else:
+                group_size = int(x_q.shape[1] / scale.shape[1])
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
+        else:
+            raise ValueError(
+                f"Could not infer a quantization strategy from scale with {scale.ndim} "
+                "dimmensions. Expected 0-2 dimmensions."
+            )
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+    )
 
 
 @torch.no_grad()
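When no args are passed, the new dequantize infers the quantization strategy from the shape of the scale tensor: a 0-d or 1-d scale means per-tensor, a [rows, 1] scale means per-channel, and any wider 2-d scale implies group quantization with group_size = columns / number of scale columns. A minimal standalone sketch of that inference rule in plain PyTorch (illustrative helper, not the package's code):

```python
import torch

def infer_strategy(x_q: torch.Tensor, scale: torch.Tensor) -> str:
    # 0-d or 1-d scale -> a single scale for the whole tensor
    if scale.ndim <= 1:
        return "tensor"
    if scale.ndim == 2:
        # one scale per output channel -> channel strategy
        if scale.shape[1] == 1:
            return "channel"
        # otherwise the number of scale columns implies the group size
        group_size = x_q.shape[1] // scale.shape[1]
        return f"group(group_size={group_size})"
    raise ValueError(f"cannot infer a strategy from a {scale.ndim}-d scale")

# 128 columns split into 4 groups of 32 -> scale has shape [rows, 4]
print(infer_strategy(torch.zeros(8, 128), torch.ones(8, 4)))  # group(group_size=32)
```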
@@ -61,30 +127,51 @@ def fake_quantize(
     args: QuantizationArgs,
 ) -> torch.Tensor:
     """
-    Fake quantize the input tensor x
-
-
-
-
-    the channel dimension)
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the group_size must be divisible
+    by the column size. The input scale and zero_points are reshaped to support
+    vectorization (Assumes 1 is the channel dimension)
 
     :param x: Input tensor
     :param scale: scale tensor
     :param zero_point: zero point tensor
-    :param args: quantization args
+    :param args: quantization args dictating how to quantize x
     :return: fake quantized tensor
-
     """
-
-
-
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
 
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
+) -> torch.Tensor:
+    bit_range = 2**args.num_bits
+    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
+    q_min = torch.tensor(-bit_range / 2, device=x.device)
     group_size = args.group_size
 
-    # group
     if args.strategy == QuantizationStrategy.GROUP:
 
-
+        if do_dequantize and not do_quantize:
+            # if dequantizing a quantized type infer the output type from the scale
+            output = torch.zeros_like(x, dtype=scale.dtype)
+        else:
+            output_dtype = dtype if dtype is not None else x.dtype
+            output = torch.zeros_like(x, dtype=output_dtype)
 
         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
@@ -106,48 +193,38 @@ def fake_quantize
         for i in range(ceil(columns / group_size)):
             # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
-
-
-            zp = zero_point[:, i].unsqueeze(1)
+            sc = scale[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1)
 
             idx = i * group_size
-
-
-
-
-
-
-
-
-
-
-
-            DQ = dequantize(Q, scale, zero_point)
-
-    # per-token
-    elif args.strategy == QuantizationStrategy.TOKEN:
-        # before: scale shape = [num_tokens]
-        # after: scale shape = [num_tokens, 1]
-        # x.shape = 1, num_tokens, 1]
-        # scale gets broadcasted as expected withput having [1, num_tokens, 1] shape
-
-        scale = scale.unsqueeze(1)
-        zero_point = zero_point.unsqueeze(1)
-
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
 
-    else:
-
-
+    else:  # covers channel, token and tensor strategies
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
 
-    return
+    return output
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     # expects a module already initialized and injected with the parameters in
     # initialize_module_for_quantization
-
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
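The rewritten group branch writes into a preallocated output tensor and quantizes and/or dequantizes one column group at a time, picking the matching scale and zero-point column per group. A simplified standalone round-trip sketch of the same idea, assuming a symmetric signed integer range (not the package's _process_quantization):

```python
import torch
from math import ceil

def group_fake_quantize(x, scale, zero_point, group_size, num_bits=8):
    # integer grid limits for a signed num_bits quantizer
    q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    output = torch.zeros_like(x)
    columns = x.shape[1]
    for i in range(ceil(columns / group_size)):
        sc = scale[:, i].view(-1, 1)       # [rows, 1] scale for this group
        zp = zero_point[:, i].view(-1, 1)  # [rows, 1] zero point for this group
        idx = i * group_size
        block = x[:, idx : idx + group_size]
        q = torch.clamp(torch.round(block / sc + zp), q_min, q_max)
        output[:, idx : idx + group_size] = (q - zp) * sc  # quantize then dequantize
    return output

x = torch.randn(4, 64)
# one scale per 16-column group, clamped away from zero
scale = (x.abs().reshape(4, 4, 16).amax(dim=-1) / 127).clamp(min=1e-12)
zero_point = torch.zeros_like(scale)
print(group_fake_quantize(x, scale, zero_point, group_size=16).shape)  # torch.Size([4, 64])
```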
@@ -219,3 +296,33 @@ def maybe_calibrate_or_quantize(
     scale.data = updated_scale.to(device)
     zero_point.data = updated_zero_point.to(device)
     return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    quantized_value = torch.clamp(
+        torch.round(x / scale + zero_point),
+        q_min,
+        q_max,
+    )
+
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
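The new _quantize and _dequantize helpers implement the standard affine mapping, q = clamp(round(x / scale + zero_point), q_min, q_max) and x ≈ (q - zero_point) * scale. A minimal standalone round trip showing the error this introduces (values are illustrative, not taken from the package):

```python
import torch

x = torch.tensor([[-1.5, -0.3, 0.0, 0.7, 2.1]])
scale = torch.tensor(2.1 / 127)   # symmetric 8-bit scale from the max magnitude
zero_point = torch.tensor(0)

q = torch.clamp(torch.round(x / scale + zero_point), -128, 127)  # values on the int8 grid
x_hat = (q - zero_point) * scale                                 # dequantized approximation

print(q)                        # tensor([[-91., -18.,   0.,  42., 127.]])
print((x - x_hat).abs().max())  # worst-case round-trip error, at most about scale / 2
```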
@@ -35,6 +35,10 @@ def freeze_module_quantization(module: Module):
         # no quantization scheme nothing to do
         return
 
+    if module.quantization_status == QuantizationStatus.FROZEN:
+        # nothing to do, already frozen
+        return
+
     # delete observers from module if not dynamic
     if scheme.input_activations and not scheme.input_activations.dynamic:
         delattr(module, "input_observer")
@@ -20,7 +20,10 @@ import torch
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@ def initialize_module_for_quantization(
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@ def initialize_module_for_quantization(
 
 
 def _initialize_scale_zero_point_observer(
-    module: Module,
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -89,11 +100,28 @@ def _initialize_scale_zero_point_observer(
 
     device = next(module.parameters()).device
 
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
+
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
+    )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
     init_zero_point = Parameter(
-        torch.empty(
+        torch.empty(expected_shape, device=device, dtype=int),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
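With this change the scale and zero point are pre-allocated in the shape the strategy implies for a Linear weight: a single value per tensor, one value per output channel, or one column per weight group. A standalone sketch of that shape rule (hypothetical helper, not the package function):

```python
from typing import Optional

import torch

def expected_qparam_shape(weight_shape: torch.Size, strategy: str, group_size: Optional[int] = None):
    out_features, in_features = weight_shape
    if strategy == "tensor":
        return (1,)                                       # one scale for the whole weight
    if strategy == "channel":
        return (out_features, 1)                          # one scale per output channel
    if strategy == "group":
        return (out_features, in_features // group_size)  # one scale per group of columns
    raise ValueError(f"unknown strategy {strategy}")

w = torch.nn.Linear(in_features=512, out_features=256).weight
print(expected_qparam_shape(w.shape, "channel"))                # (256, 1)
print(expected_qparam_shape(w.shape, "group", group_size=128))  # (256, 4)
```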
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization.quant_args import (
@@ -40,6 +40,7 @@ class Observer(Module, RegistryMixin):
         self._scale = None
         self._zero_point = None
 
+    @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -49,9 +50,16 @@ class Observer(Module, RegistryMixin):
         """
         return self.get_qparams(observed=observed)
 
-    def calculate_qparams(
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
@@ -69,6 +77,7 @@ class Observer(Module, RegistryMixin):
         Convenience function to wrap overwritten calculate_qparams
         adds support to make observed tensor optional and support for tracking latest
         calculated scale and zero point
+
         :param observed: optional observed tensor to calculate quantization parameters
             from
         :return: tuple of scale and zero point based on last observed value
@@ -84,47 +93,42 @@ class Observer(Module, RegistryMixin):
         elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
             columns = observed.shape[1]
             scales, zero_points = [], []
-
+            group_idxs = range(0, columns, self.quantization_args.group_size)
+            for group_id, group_idx in enumerate(group_idxs):
                 scale, zero_point = self.get_qparams_along_dim(
-                    observed[:,
+                    observed[:, group_idx : (group_idx + group_size)],
                     0,
+                    tensor_id=group_id,
                 )
                 scales.append(scale)
                 zero_points.append(zero_point)
 
-            self._scale = torch.
-            self._zero_point = torch.
+            self._scale = torch.cat(scales, dim=1, out=self._scale)
+            self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # assume observed is transposed, because its the output, hence use dim 0
             self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
 
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-
             # use dim 1, assume the obsersed.shape = [batch, token, hidden]
             # should be batch, token
-
             self._scale, self._zero_point = self.get_qparams_along_dim(
-                observed,
+                observed,
+                dim={0, 1},
             )
 
         return self._scale, self._zero_point
 
-    def get_qparams_along_dim(
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-        scales.append(scale)
-        zero_points.append(zero_point)
-        # breakpoint()
-        return torch.stack(scales), torch.stack(zero_points)
+    def get_qparams_along_dim(
+        self,
+        observed,
+        dim: Union[int, Iterable[int]],
+        tensor_id: Optional[Any] = None,
+    ):
+        dim = set(dim)
+
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
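get_qparams_along_dim now accepts one or several dims to keep and reduces over all the others, which is how per-channel, per-token and per-group statistics fall out of a single amin/amax call. A standalone sketch of the reduction (plain PyTorch; names are illustrative):

```python
import torch

def min_max_along_dims(observed: torch.Tensor, keep_dims):
    keep = {keep_dims} if isinstance(keep_dims, int) else set(keep_dims)
    # reduce over every dimension that is not kept
    reduce_dims = tuple(i for i in range(observed.ndim) if i not in keep)
    min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
    max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)
    return min_val, max_val

x = torch.randn(2, 16, 64)                        # [batch, tokens, hidden]
mn, mx = min_max_along_dims(x, keep_dims={0, 1})  # per-token statistics
print(mn.shape, mx.shape)                         # torch.Size([2, 16, 1]) twice
```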
@@ -35,19 +35,24 @@ def calculate_qparams(
     """
     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    device = min_vals.device
 
     bit_range = 2**quantization_args.num_bits - 1
     bit_min = -(bit_range + 1) / 2
     bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        zero_points = torch.tensor(0).to(torch.int8)
         max_val_pos = torch.max(-min_vals, max_vals)
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = bit_min - torch.round(min_vals / scales)
         zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
 
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
     return scales, zero_points
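calculate_qparams maps observed min/max values to an 8-bit scale and zero point, with a symmetric path (zero point pinned at 0, now shaped like the scale) and an asymmetric path. A standalone sketch of both paths that mirrors the formulas in this hunk (illustrative, not the package function):

```python
import torch

def qparams_from_min_max(min_val: torch.Tensor, max_val: torch.Tensor, symmetric: bool):
    # include zero in the range so 0.0 stays exactly representable
    min_val = torch.min(min_val, torch.zeros_like(min_val))
    max_val = torch.max(max_val, torch.zeros_like(max_val))
    bit_range = 2**8 - 1  # 255 usable steps for 8 bits
    q_min, q_max = -128, 127
    if symmetric:
        scale = torch.max(-min_val, max_val) / (bit_range / 2)
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
    else:
        scale = (max_val - min_val) / bit_range
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
        zero_point = q_min - torch.round(min_val / scale)
        zero_point = torch.clamp(zero_point, q_min, q_max).to(torch.int8)
    return scale, zero_point

s, zp = qparams_from_min_max(torch.tensor(-0.25), torch.tensor(1.0), symmetric=False)
print(s, zp)  # scale ≈ 0.0049, zero point -77
```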
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -30,19 +30,27 @@ class MemorylessObserver(Observer):
     zero point based on the latest observed value without tracking state
     """
 
-    def calculate_qparams(
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        tensor_id: Optional[Any] = None,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
-        Returns the min and max values of observed
+        Returns the min and max values of observed tensor
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param tensor_id: optional id for tensor; not used for memoryless
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        # per tensor
-        min_val, max_val = torch.aminmax(observed)
 
-
-
-
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
         return calculate_qparams(min_val, max_val, self.quantization_args)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,30 +36,61 @@ class MovingAverageMinMaxObserver(Observer):
     ):
         super().__init__(quantization_args=quantization_args)
 
-        self.min_val =
-        self.max_val =
+        self.min_val = {}
+        self.max_val = {}
         self.averaging_constant = averaging_constant
 
-    def calculate_qparams(
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
         :return: tuple of scale and zero point derived from the observed tensor
         """
+        tensor_id = tensor_id or "default"
 
-
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
 
-        if
-
-
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
         else:
-
-                min_val -
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
             )
-
-                max_val -
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
             )
 
-
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
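The observer now keeps one running min/max per tensor_id, so each weight group tracked under the group strategy gets its own moving average. A standalone sketch of the exponential-moving-average update it applies (illustrative class, not the package observer):

```python
import torch

class RunningMinMax:
    """Tracks per-id running min/max with an exponential moving average."""

    def __init__(self, averaging_constant: float = 0.01):
        self.averaging_constant = averaging_constant
        self.min_val, self.max_val = {}, {}

    def update(self, observed: torch.Tensor, tensor_id="default"):
        min_val, max_val = torch.aminmax(observed)
        if tensor_id not in self.min_val:
            # first observation for this id: adopt the values directly
            self.min_val[tensor_id], self.max_val[tensor_id] = min_val, max_val
        else:
            c = self.averaging_constant
            self.min_val[tensor_id] = self.min_val[tensor_id] + c * (min_val - self.min_val[tensor_id])
            self.max_val[tensor_id] = self.max_val[tensor_id] + c * (max_val - self.max_val[tensor_id])
        return self.min_val[tensor_id], self.max_val[tensor_id]

tracker = RunningMinMax(averaging_constant=0.1)
for _ in range(5):
    tracker.update(torch.randn(16, 16), tensor_id=0)  # e.g. one id per weight group
print(tracker.update(torch.randn(16, 16), tensor_id=0))
```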
@@ -42,7 +42,7 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
-class QuantizationArgs(BaseModel):
+class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
     activations
@@ -62,7 +62,7 @@ class QuantizationArgs(BaseModel):
     """
 
     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
+    type: QuantizationType = QuantizationType.INT.value
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None