compressed-tensors 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +2 -1
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +11 -54
- compressed_tensors/compressors/dense.py +4 -4
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/int_quantized.py +126 -0
- compressed_tensors/compressors/marlin_24.py +250 -0
- compressed_tensors/compressors/model_compressor.py +315 -0
- compressed_tensors/compressors/pack_quantized.py +212 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/compressors/utils/__init__.py +19 -0
- compressed_tensors/compressors/utils/helpers.py +43 -0
- compressed_tensors/compressors/utils/permutations_24.py +65 -0
- compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/config/base.py +7 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +75 -19
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +208 -22
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/initialize.py +33 -5
- compressed_tensors/quantization/observers/base.py +70 -5
- compressed_tensors/quantization/observers/helpers.py +6 -1
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +33 -4
- compressed_tensors/quantization/quant_config.py +69 -21
- compressed_tensors/quantization/quant_scheme.py +81 -1
- compressed_tensors/quantization/utils/helpers.py +77 -8
- compressed_tensors/utils/helpers.py +26 -122
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -9
- compressed_tensors-0.4.0.dist-info/RECORD +48 -0
- compressed_tensors-0.3.2.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py

@@ -13,15 +13,26 @@
 # limitations under the License.

 from functools import wraps
+from math import ceil
+from typing import Optional

 import torch
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module


-__all__ = [
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]


 @torch.no_grad()
@@ -29,15 +40,39 @@ def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-
-
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-
-
-
-
-
-
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the group_size must be divisible by the column size. The input scale
+    and zero_points are reshaped to support vectorization (Assumes 1 is the
+    channel dimension)
+
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: fake quantized tensor
+    """
+    # ensure all tensors are on the same device
+    # assumes that the target device is the input
+    # tensor's device
+    if x.device != scale.device:
+        scale = scale.to(x.device)
+    if x.device != zero_point.device:
+        zero_point = zero_point.to(x.device)
+
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
     )

@@ -46,8 +81,42 @@ def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    args: QuantizationArgs = None,
 ) -> torch.Tensor:
-
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
+    args is not provided, the strategy will be inferred.
+
+    :param x: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0 or scale.ndim == 1:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            if scale.shape[1] == 1:
+                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+            else:
+                group_size = int(x_q.shape[1] / scale.shape[1])
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
+        else:
+            raise ValueError(
+                f"Could not infer a quantization strategy from scale with {scale.ndim} "
+                "dimmensions. Expected 0-2 dimmensions."
+            )
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+    )


 @torch.no_grad()
@@ -56,19 +125,106 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+) -> torch.Tensor:
+    """
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the group_size must be divisible
+    by the column size. The input scale and zero_points are reshaped to support
+    vectorization (Assumes 1 is the channel dimension)
+
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :return: fake quantized tensor
+    """
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
+
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
 ) -> torch.Tensor:
     bit_range = 2**args.num_bits
-
-
-
-
-
+    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
+    q_min = torch.tensor(-bit_range / 2, device=x.device)
+    group_size = args.group_size
+
+    if args.strategy == QuantizationStrategy.GROUP:
+
+        if do_dequantize and not do_quantize:
+            # if dequantizing a quantized type infer the output type from the scale
+            output = torch.zeros_like(x, dtype=scale.dtype)
+        else:
+            output_dtype = dtype if dtype is not None else x.dtype
+            output = torch.zeros_like(x, dtype=output_dtype)
+
+        # TODO: vectorize the for loop
+        # TODO: fix genetric assumption about the tensor size for computing group
+
+        # TODO: make validation step for inputs
+
+        while scale.ndim < 2:
+            # pad scale and zero point dims for slicing
+            scale = scale.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1)
+
+        columns = x.shape[1]
+        if columns >= group_size:
+            if columns % group_size != 0:
+                raise ValueError(
+                    "tesnor column shape must be divisble "
+                    f"by the given group_size {group_size}"
+                )
+        for i in range(ceil(columns / group_size)):
+            # scale.shape should be [nchan, ndim]
+            # sc.shape should be [nchan, 1] after unsqueeze
+            sc = scale[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1)
+
+            idx = i * group_size
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+
+    else:  # covers channel, token and tensor strategies
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
+
+    return output


 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     # expects a module already initialized and injected with the parameters in
     # initialize_module_for_quantization
-
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func

     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
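Note on the new GROUP path in _process_quantization: the tensor is walked in column blocks of group_size and each block is quantized with its own scale/zero-point column. A minimal standalone sketch of that slicing, outside the diff (shapes and values below are made up for illustration):

import torch
from math import ceil

x = torch.randn(4, 8)                 # 4 channels x 8 columns
group_size = 4                        # -> 2 groups of columns
scale = torch.rand(4, 2) + 0.1        # one scale column per group
zero_point = torch.zeros(4, 2)

output = torch.zeros_like(x)
for i in range(ceil(x.shape[1] / group_size)):
    sc = scale[:, i].view(-1, 1)      # [nchan, 1], broadcasts over the block
    zp = zero_point[:, i].view(-1, 1)
    idx = i * group_size
    block = x[:, idx : idx + group_size]
    q = torch.clamp(torch.round(block / sc + zp), -128, 127)  # int8-style range
    output[:, idx : idx + group_size] = (q - zp) * sc         # fake-quantized block
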
@@ -76,14 +232,14 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):

         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
-            input_ =
+            input_ = maybe_calibrate_or_quantize(
                 module, input_, "input", scheme.input_activations
             )

         if scheme.weights is not None:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data =
+            self.weight.data = maybe_calibrate_or_quantize(
                 module, self.weight, "weight", scheme.weights
             )

@@ -94,7 +250,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):

         if scheme.output_activations is not None:
             # calibrate and (fake) quantize output activations when applicable
-            output =
+            output = maybe_calibrate_or_quantize(
                 module, output, "output", scheme.output_activations
             )

@@ -110,7 +266,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)


-def
+def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
     # only run quantized for the included stages
@@ -132,11 +288,41 @@ def _maybe_calibrate_or_quantize(
     if module.quantization_status == QuantizationStatus.CALIBRATION:
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")
+
         updated_scale, updated_zero_point = observer(value)

         # update scale and zero point
         device = next(module.parameters()).device
         scale.data = updated_scale.to(device)
         zero_point.data = updated_zero_point.to(device)
-
     return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    quantized_value = torch.clamp(
+        torch.round(x / scale + zero_point),
+        q_min,
+        q_max,
+    )
+
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
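Taken together, forward.py now exposes quantize, dequantize and fake_quantize built on the affine mapping q = clamp(round(x / scale + zero_point), q_min, q_max) and x ≈ (q - zero_point) * scale, and dequantize can infer the strategy from the scale's shape when args is omitted. A rough usage sketch against the new API (illustrative only; the constructor arguments for QuantizationArgs are an assumption and may differ):

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import dequantize, fake_quantize, quantize

weight = torch.randn(16, 64)
args = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.CHANNEL)

# per-channel scale/zero point, shaped [out_channels, 1]
scale = weight.abs().amax(dim=1, keepdim=True) / 127.0
zero_point = torch.zeros_like(scale, dtype=torch.int8)

w_q = quantize(weight, scale, zero_point, args, dtype=torch.int8)
w_dq = dequantize(w_q, scale, zero_point)   # strategy inferred as CHANNEL from scale shape
w_fake = fake_quantize(weight, scale, zero_point, args)
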
compressed_tensors/quantization/lifecycle/frozen.py

@@ -35,6 +35,10 @@ def freeze_module_quantization(module: Module):
         # no quantization scheme nothing to do
         return

+    if module.quantization_status == QuantizationStatus.FROZEN:
+        # nothing to do, already frozen
+        return
+
     # delete observers from module if not dynamic
     if scheme.input_activations and not scheme.input_activations.dynamic:
         delattr(module, "input_observer")
compressed_tensors/quantization/lifecycle/initialize.py

@@ -20,7 +20,10 @@ import torch
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@ def initialize_module_for_quantization(
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@ def initialize_module_for_quantization(


 def _initialize_scale_zero_point_observer(
-    module: Module,
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -89,11 +100,28 @@ def _initialize_scale_zero_point_observer(

     device = next(module.parameters()).device

+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
+
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
+    )
     module.register_parameter(f"{base_name}_scale", init_scale)

     init_zero_point = Parameter(
-        torch.empty(
+        torch.empty(expected_shape, device=device, dtype=int),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
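With the shape logic above, a Linear weight of shape [out_features, in_features] gets scale/zero-point parameters shaped (out_features, 1) for the CHANNEL strategy and (out_features, in_features // group_size) for GROUP; everything else stays per-tensor. A tiny sketch of just that shape computation (the helper name is hypothetical, not part of the package):

import torch

def expected_qparam_shape(weight_shape: torch.Size, strategy: str, group_size: int = None):
    # mirrors the branching added to _initialize_scale_zero_point_observer
    if strategy == "channel":
        return (weight_shape[0], 1)                               # one row per output channel
    if strategy == "group":
        return (weight_shape[0], weight_shape[1] // group_size)   # one column per group
    return (1,)                                                   # per-tensor default

print(expected_qparam_shape(torch.Size([768, 3072]), "group", group_size=128))  # (768, 24)
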
compressed_tensors/quantization/observers/base.py

@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union

-
+import torch
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.registry.registry import RegistryMixin
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
@@ -36,6 +40,7 @@ class Observer(Module, RegistryMixin):
         self._scale = None
         self._zero_point = None

+    @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -45,13 +50,26 @@ class Observer(Module, RegistryMixin):
         """
         return self.get_qparams(observed=observed)

-    def calculate_qparams(
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

+    def post_calculate_qparams(self) -> None:
+        """
+        Run any logic specific to its observers after running calculate_qparams
+        """
+        ...
+
     def get_qparams(
         self, observed: Optional[Tensor] = None
     ) -> Tuple[FloatTensor, IntTensor]:
@@ -59,11 +77,58 @@ class Observer(Module, RegistryMixin):
         Convenience function to wrap overwritten calculate_qparams
         adds support to make observed tensor optional and support for tracking latest
         calculated scale and zero point
+
         :param observed: optional observed tensor to calculate quantization parameters
             from
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-
-
+            group_size = self.quantization_args.group_size
+
+            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
+
+                # re-calculate scale and zero point, update the stored value
+                self._scale, self._zero_point = self.calculate_qparams(observed)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                columns = observed.shape[1]
+                scales, zero_points = [], []
+                group_idxs = range(0, columns, self.quantization_args.group_size)
+                for group_id, group_idx in enumerate(group_idxs):
+                    scale, zero_point = self.get_qparams_along_dim(
+                        observed[:, group_idx : (group_idx + group_size)],
+                        0,
+                        tensor_id=group_id,
+                    )
+                    scales.append(scale)
+                    zero_points.append(zero_point)
+
+                self._scale = torch.cat(scales, dim=1, out=self._scale)
+                self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+                # assume observed is transposed, because its the output, hence use dim 0
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
+                # use dim 1, assume the obsersed.shape = [batch, token, hidden]
+                # should be batch, token
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed,
+                    dim={0, 1},
+                )
+
         return self._scale, self._zero_point
+
+    def get_qparams_along_dim(
+        self,
+        observed,
+        dim: Union[int, Iterable[int]],
+        tensor_id: Optional[Any] = None,
+    ):
+        dim = set(dim)
+
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
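With the GROUP strategy, get_qparams now calls get_qparams_along_dim once per column block and concatenates the per-group results along dim 1, so the stored scale ends up shaped [rows, num_groups]. A standalone sketch of that accumulation (illustrative; a plain abs-max stands in for a real observer):

import torch

observed = torch.randn(16, 64)        # e.g. a weight seen during calibration
group_size = 32
scales, zero_points = [], []
for group_idx in range(0, observed.shape[1], group_size):
    block = observed[:, group_idx : group_idx + group_size]
    max_abs = block.abs().amax(dim=1, keepdim=True)               # reduce all dims except dim 0
    scales.append(torch.clamp(max_abs / 127.0, min=1e-8))
    zero_points.append(torch.zeros_like(scales[-1], dtype=torch.int8))

scale = torch.cat(scales, dim=1)       # [16, 2]: one column per group
zero_point = torch.cat(zero_points, dim=1)
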
compressed_tensors/quantization/observers/helpers.py

@@ -35,19 +35,24 @@ def calculate_qparams(
     """
     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    device = min_vals.device

     bit_range = 2**quantization_args.num_bits - 1
     bit_min = -(bit_range + 1) / 2
     bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        zero_points = torch.tensor(0).to(torch.int8)
         max_val_pos = torch.max(-min_vals, max_vals)
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = bit_min - torch.round(min_vals / scales)
         zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)

+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
     return scales, zero_points
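For reference, the symmetric branch above yields scale = max(|min|, max) / ((2**num_bits - 1) / 2) with a per-element zero point of 0, while the asymmetric branch yields scale = (max - min) / (2**num_bits - 1) and zero_point = bit_min - round(min / scale). A compact numeric sketch of both (input values made up):

import torch

min_vals = torch.tensor([-2.0, -0.5])
max_vals = torch.tensor([3.0, 1.0])
num_bits = 8
bit_range = 2**num_bits - 1           # 255
bit_min = -(bit_range + 1) / 2        # -128.0

# symmetric: zero point stays 0, scale covers the larger-magnitude side
sym_scales = torch.max(-min_vals, max_vals) / (bit_range / 2)
sym_zero_points = torch.zeros_like(sym_scales, dtype=torch.int8)

# asymmetric: scale spans the full observed range, zero point shifts it
asym_scales = (max_vals - min_vals) / bit_range
asym_zero_points = (bit_min - torch.round(min_vals / asym_scales)).to(torch.int8)
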
compressed_tensors/quantization/observers/memoryless.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
+from typing import Any, Optional, Tuple

 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -30,19 +30,27 @@ class MemorylessObserver(Observer):
     zero point based on the latest observed value without tracking state
     """

-    def calculate_qparams(
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        tensor_id: Optional[Any] = None,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
-        Returns the min and max values of observed
+        Returns the min and max values of observed tensor

         :param observed: observed tensor to calculate quantization parameters for
+        :param tensor_id: optional id for tensor; not used for memoryless
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        # per tensor
-        min_val, max_val = torch.aminmax(observed)

-
-
-
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

         return calculate_qparams(min_val, max_val, self.quantization_args)
compressed_tensors/quantization/observers/min_max.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
+from typing import Any, Optional, Tuple

 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,30 +36,61 @@ class MovingAverageMinMaxObserver(Observer):
     ):
         super().__init__(quantization_args=quantization_args)

-        self.min_val =
-        self.max_val =
+        self.min_val = {}
+        self.max_val = {}
         self.averaging_constant = averaging_constant

-    def calculate_qparams(
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant

         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
         :return: tuple of scale and zero point derived from the observed tensor
         """
+        tensor_id = tensor_id or "default"

-
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)

-        if
-
-
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
         else:
-
-            min_val -
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
             )
-
-            max_val -
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
             )

-
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
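The moving-average update above is a standard exponential smoothing of the running min/max, now keyed by tensor_id so each group slice tracks its own range: updated = running + averaging_constant * (observed - running). A few lines showing the effect of a small averaging constant (values made up):

import torch

averaging_constant = 0.01
running_max = torch.tensor(1.0)
for observed_max in (1.5, 2.0, 0.8):
    running_max = running_max + averaging_constant * (torch.tensor(observed_max) - running_max)

print(running_max)  # ~1.0128: the running max drifts only slightly toward new observations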