compressed-tensors-nightly 0.5.0.20240814__py3-none-any.whl → 0.5.0.20240830__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/base.py +198 -8
- compressed_tensors/compressors/model_compressor.py +65 -1
- compressed_tensors/compressors/naive_quantized.py +71 -75
- compressed_tensors/compressors/pack_quantized.py +83 -94
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/apply.py +36 -4
- compressed_tensors/quantization/lifecycle/calibration.py +3 -2
- compressed_tensors/quantization/lifecycle/compressed.py +1 -1
- compressed_tensors/quantization/lifecycle/forward.py +67 -43
- compressed_tensors/quantization/lifecycle/helpers.py +29 -2
- compressed_tensors/quantization/lifecycle/initialize.py +50 -16
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +48 -20
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +13 -0
- compressed_tensors/utils/offload.py +7 -1
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/METADATA +3 -2
- compressed_tensors_nightly-0.5.0.20240830.dist-info/RECORD +52 -0
- compressed_tensors_nightly-0.5.0.20240814.dist-info/RECORD +0 -48
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py

```diff
@@ -25,7 +25,7 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import update_parameter_data
+from compressed_tensors.utils import safe_permute, update_parameter_data
 from torch.nn import Module
 
 
@@ -45,6 +45,7 @@ def quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -58,16 +59,9 @@ def quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x.device != scale.device:
-        scale = scale.to(x.device)
-    if x.device != zero_point.device:
-        zero_point = zero_point.to(x.device)
-
     return _process_quantization(
         x=x,
         scale=scale,
@@ -76,6 +70,7 @@ def quantize(
         dtype=dtype,
         do_quantize=True,
         do_dequantize=False,
+        g_idx=g_idx,
     )
 
 
@@ -86,6 +81,7 @@ def dequantize(
     zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -96,6 +92,7 @@ def dequantize(
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
     :param dtype: optional dtype to cast the dequantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: dequantized float tensor
     """
     if args is None:
@@ -126,6 +123,7 @@ def dequantize(
         do_quantize=False,
         do_dequantize=True,
         dtype=dtype,
+        g_idx=g_idx,
     )
 
 
@@ -135,6 +133,7 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -147,6 +146,7 @@ def fake_quantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -156,6 +156,7 @@ def fake_quantize(
         args=args,
         do_quantize=True,
         do_dequantize=True,
+        g_idx=g_idx,
     )
```
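Every public entry point above now threads an optional `g_idx` through to `_process_quantization`; it maps each weight column to a group index, which is how activation ordering is represented. Below is a small, self-contained sketch (the `g_idx` values are invented for illustration) of how per-group column counts are recovered from such a mapping, matching the `torch.unique`/`torch.argsort` pattern used later in this diff:

```python
import torch

# Hypothetical g_idx for 8 columns and 3 groups: g_idx[i] is the group of column i.
# A g_idx of None (or one filled with -1) means plain column-order grouping.
g_idx = torch.tensor([1, 0, 0, 1, 2, 2, 0, 1])

# Recover the number of columns in each group, ordered by group index
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
group_sizes = group_sizes[torch.argsort(group_indices)]
print(group_sizes)  # tensor([3, 3, 2])

# Sorting the columns by group index gathers each group into a contiguous slice
perm = torch.argsort(g_idx)
print(g_idx[perm])  # tensor([0, 0, 0, 1, 1, 1, 2, 2])
```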
```diff
@@ -164,21 +165,19 @@ def _process_quantization(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    g_idx: Optional[torch.Tensor],
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
     if args.strategy == QuantizationStrategy.GROUP:
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-
-        # TODO: vectorize the for loop
-        # TODO: fix genetric assumption about the tensor size for computing group
+        columns = output.shape[1]
 
         # TODO: make validation step for inputs
 
@@ -187,23 +186,38 @@ def _process_quantization(
             scale = scale.unsqueeze(1)
             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
-        columns = x.shape[1]
         if columns >= group_size:
             if columns % group_size != 0:
                 raise ValueError(
-                    "
+                    "tensor column shape must be divisble "
                     f"by the given group_size {group_size}"
                 )
-        for i in range(ceil(columns / group_size)):
-            # scale.shape should be [nchan, ndim]
-            # sc.shape should be [nchan, 1] after unsqueeze
-            sc = scale[:, i].view(-1, 1)
-            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
 
-
+        # support column-order (default) quantization as well as other orderings
+        # such as activation ordering. Below checks if g_idx has been initialized
+        is_column_order = g_idx is None or -1 in g_idx
+        if is_column_order:
+            num_groups = int(ceil(columns / group_size))
+            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+
+        else:
+            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+            group_sizes = group_sizes[torch.argsort(group_indices)]
+
+            perm = torch.argsort(g_idx)
+            x = safe_permute(x, perm, dim=1)
+
+        # TODO: experiment with vectorizing for loop for performance
+        end = 0
+        for index, group_count in enumerate(group_sizes):
+            sc = scale[:, index].view(-1, 1)
+            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
+
+            start = end
+            end = start + group_count
             if do_quantize:
-                output[:,
-                x[:,
+                output[:, start:end] = _quantize(
+                    x[:, start:end],
                     sc,
                     zp,
                     q_min,
@@ -211,13 +225,13 @@ def _process_quantization(
                     args,
                     dtype=dtype,
                 )
+
             if do_dequantize:
-                input =
-
-
-
-
-                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+                input = output[:, start:end] if do_quantize else x[:, start:end]
+                output[:, start:end] = _dequantize(input, sc, zp)
+
+        if not is_column_order:
+            output = safe_permute(output, torch.argsort(perm), dim=1)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
```
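When a real activation ordering is supplied, the input is permuted so that each group's columns become one contiguous slice, and the result is permuted back at the end: `torch.argsort(perm)` is the inverse permutation. A standalone round-trip sketch of that idea, with plain tensor indexing standing in for `safe_permute`:

```python
import torch

x = torch.arange(8).reshape(1, 8).float()
g_idx = torch.tensor([1, 0, 0, 1, 2, 2, 0, 1])

perm = torch.argsort(g_idx)                     # gather each group's columns together
x_grouped = x[:, perm]                          # stand-in for safe_permute(x, perm, dim=1)
x_restored = x_grouped[:, torch.argsort(perm)]  # inverse permutation undoes the reorder

assert torch.equal(x, x_restored)
```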
```diff
@@ -252,6 +266,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
         input_ = args[0]
+        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
@@ -259,7 +274,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
                 module, input_, "input", scheme.input_activations
             )
 
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = maybe_calibrate_or_quantize(
@@ -278,7 +293,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             )
 
         # restore back to unquantized_value
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
         return output
@@ -292,11 +307,16 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    #
-    if module.quantization_status
-
-
-
+    # don't run quantization if we haven't entered calibration mode
+    if module.quantization_status == QuantizationStatus.INITIALIZED:
+        return value
+
+    # in compressed mode, the weight is already compressed and quantized so we don't
+    # need to run fake quantization
+    if (
+        module.quantization_status == QuantizationStatus.COMPRESSED
+        and base_name == "weight"
+    ):
         return value
 
     if value.numel() == 0:
@@ -304,14 +324,16 @@ def maybe_calibrate_or_quantize(
         # skip quantization
         return value
 
+    g_idx = getattr(module, "weight_g_idx", None)
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value)
+        scale, zero_point = observer(value, g_idx=g_idx)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
-        zero_point = getattr(module, f"{base_name}_zero_point")
+        zero_point = getattr(module, f"{base_name}_zero_point", None)
 
     if (
         module.quantization_status == QuantizationStatus.CALIBRATION
@@ -320,13 +342,13 @@ def maybe_calibrate_or_quantize(
         # calibration mode - get new quant params from observer
        observer = getattr(module, f"{base_name}_observer")
 
-        updated_scale, updated_zero_point = observer(value)
+        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
 
         # update scale and zero point
         update_parameter_data(module, updated_scale, f"{base_name}_scale")
         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
-    return fake_quantize(value, scale, zero_point, args)
+    return fake_quantize(value, scale, zero_point, args, g_idx=g_idx)
 
 
 @torch.no_grad()
@@ -340,7 +362,9 @@ def _quantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
 
-    scaled = x / scale
+    scaled = x / scale
+    if zero_point is not None:
+        scaled += zero_point.to(x.dtype)
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -361,11 +385,11 @@ def _dequantize(
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    dequant_value = x_q.to(scale.dtype)
 
-    dequant_value = x_q
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
-    dequant_value = dequant_value
+    dequant_value = dequant_value * scale
 
     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
```
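`_quantize` now only adds a zero point when one exists, and `_dequantize` casts to the scale dtype before subtracting the zero point and multiplying by the scale. A simplified, self-contained sketch of that arithmetic on a symmetric int4-style range (the tensors and range are made up; clamping and rounding are condensed relative to the library code):

```python
import torch

x = torch.tensor([[-1.0, -0.25, 0.5, 1.0]])
scale = torch.tensor([[1.0 / 7]])    # symmetric scale for a [-8, 7] integer range
zero_point = None                    # symmetric quantization may carry no zero point

# quantize: divide by scale, add the zero point only if one exists, clamp, round
scaled = x / scale
if zero_point is not None:
    scaled += zero_point.to(x.dtype)
q = torch.round(torch.clamp(scaled, -8, 7)).to(torch.int8)

# dequantize: cast to the scale dtype, subtract the zero point, multiply by scale
dq = q.to(scale.dtype)
if zero_point is not None:
    dq = dq - zero_point.to(scale.dtype)
dq = dq * scale

print(q)   # tensor([[-7, -2,  4,  7]], dtype=torch.int8)
print(dq)  # tensor([[-1.0000, -0.2857,  0.5714,  1.0000]])
```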
compressed_tensors/quantization/lifecycle/helpers.py

```diff
@@ -16,7 +16,9 @@
 Miscelaneous helpers for the quantization lifecycle
 """
 
+from typing import Optional
 
+import torch
 from torch.nn import Module
 
 
@@ -27,16 +29,41 @@ __all__ = [
 ]
 
 
-def update_layer_weight_quant_params(
-
+def update_layer_weight_quant_params(
+    layer: Module,
+    weight: Optional[torch.Tensor] = None,
+    g_idx: Optional[torch.Tensor] = None,
+    reset_obs: bool = False,
+):
+    """
+    Update quantization parameters on layer
+
+    :param layer: input layer
+    :param weight: weight to update quant params with, defaults to layer weight
+    :param g_idx: optional mapping from column index to group index
+    :param reset_obs: reset the observer before calculating quant params,
+        defaults to False
+    """
+    attached_weight = getattr(layer, "weight", None)
+
+    if weight is None:
+        weight = attached_weight
     scale = getattr(layer, "weight_scale", None)
     zero_point = getattr(layer, "weight_zero_point", None)
+    if g_idx is None:
+        g_idx = getattr(layer, "weight_g_idx", None)
     observer = getattr(layer, "weight_observer", None)
 
     if weight is None or observer is None or scale is None or zero_point is None:
         # scale, zp, or observer not calibratable or weight not available
         return
 
+    if reset_obs:
+        observer.reset()
+
+    if attached_weight is not None:
+        weight = weight.to(attached_weight.dtype)
+
     updated_scale, updated_zero_point = observer(weight)
 
     # update scale and zero point
```
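`update_layer_weight_quant_params` now accepts an explicit `weight`, a `g_idx`, and a `reset_obs` flag, and casts the provided weight to the attached weight's dtype before recomputing quant params. A hedged, self-contained sketch of that flow with a toy observer (`_ToyObserver` and `refresh_quant_params` are illustrative stand-ins, not part of the library):

```python
import torch

class _ToyObserver:
    """Stand-in with the same reset/__call__ surface as the library observers."""

    def reset(self):
        self.min_val, self.max_val = None, None

    def __call__(self, weight):
        self.min_val, self.max_val = weight.min(), weight.max()
        scale = (self.max_val - self.min_val) / 255
        zero_point = torch.zeros((), dtype=torch.int8)
        return scale, zero_point

def refresh_quant_params(observer, attached_weight, weight=None, reset_obs=False):
    # mirror of the new flow: fall back to the attached weight, optionally reset
    # the observer, and cast the provided weight to the attached weight's dtype
    if weight is None:
        weight = attached_weight
    if reset_obs:
        observer.reset()
    if attached_weight is not None:
        weight = weight.to(attached_weight.dtype)
    return observer(weight)

scale, zp = refresh_quant_params(
    _ToyObserver(), torch.randn(4, 4, dtype=torch.float16), reset_obs=True
)
print(scale.dtype)  # torch.float16
```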
compressed_tensors/quantization/lifecycle/initialize.py

```diff
@@ -17,8 +17,6 @@ import logging
 from typing import Optional
 
 import torch
-from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-from accelerate.utils import PrefixedDataset
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
@@ -43,6 +41,7 @@ _LOGGER = logging.getLogger(__name__)
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
+    force_zero_point: bool = True,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -54,6 +53,8 @@ def initialize_module_for_quantization(
     :param scheme: scheme to use for quantization. if None is provided,
         will attempt to use scheme stored in the module under `quantization_scheme`,
         if not provided, the layer will be skipped
+    :param force_zero_point: whether to force initialization of a zero point for
+        symmetric quantization
     """
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
@@ -61,14 +62,18 @@ def initialize_module_for_quantization(
         return
 
     if scheme.input_activations is not None:
-        _initialize_scale_zero_point_observer(
+        _initialize_scale_zero_point_observer(
+            module, "input", scheme.input_activations, force_zero_point=force_zero_point
+        )
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            weight_shape =
-            if isinstance(module, torch.nn.Linear):
-                weight_shape = module.weight.shape
+            weight_shape = module.weight.shape
             _initialize_scale_zero_point_observer(
-                module,
+                module,
+                "weight",
+                scheme.weights,
+                weight_shape=weight_shape,
+                force_zero_point=force_zero_point,
             )
         else:
             _LOGGER.warning(
@@ -78,7 +83,10 @@ def initialize_module_for_quantization(
             )
     if scheme.output_activations is not None:
         _initialize_scale_zero_point_observer(
-            module,
+            module,
+            "output",
+            scheme.output_activations,
+            force_zero_point=force_zero_point,
         )
 
     module.quantization_scheme = scheme
@@ -86,6 +94,16 @@
 
     offloaded = False
     if is_module_offloaded(module):
+        try:
+            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+            from accelerate.utils import PrefixedDataset
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Offloaded model detected. To use CPU offloading with "
+                "compressed-tensors the `accelerate` package must be installed, "
+                "run `pip install compressed-tensors[accelerate]`"
+            )
+
         offloaded = True
         hook = module._hf_hook
         prefix_dict = module._hf_hook.weights_map
@@ -116,6 +134,7 @@ def _initialize_scale_zero_point_observer(
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
+    force_zero_point: bool = True,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -141,16 +160,31 @@ def _initialize_scale_zero_point_observer(
             weight_shape[1] // quantization_args.group_size,
         )
 
-
+    scale_dtype = module.weight.dtype
+    if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+        scale_dtype = torch.float16
+
+    # initializes empty scale, zero point, and g_idx parameters for the module
     init_scale = Parameter(
-        torch.empty(expected_shape, dtype=
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
-
-
-
-
-
-
+    if force_zero_point or not quantization_args.symmetric:
+        zp_dtype = quantization_args.pytorch_dtype()
+        init_zero_point = Parameter(
+            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+    # initialize with empty for actorder, to be populated by GPTQ or state_dict
+    if quantization_args.actorder:
+        g_idx_shape = (weight_shape[1],)
+        g_idx_dtype = torch.int
+        init_g_idx = Parameter(
+            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_g_idx", init_g_idx)
```
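Two of the new initialization rules are easy to state on their own: the scale parameter keeps the weight dtype only when that dtype is a supported float type (otherwise it falls back to float16), and a zero point is only registered when quantization is asymmetric or `force_zero_point` is set. A standalone sketch of those two checks (the helper names are illustrative):

```python
import torch

def pick_scale_dtype(weight_dtype: torch.dtype) -> torch.dtype:
    # same fallback rule as above: non-supported weight dtypes get a float16 scale
    if weight_dtype in (torch.float16, torch.bfloat16, torch.float32):
        return weight_dtype
    return torch.float16

def needs_zero_point(symmetric: bool, force_zero_point: bool = True) -> bool:
    # a zero point is registered for asymmetric schemes, or when forced
    return force_zero_point or not symmetric

print(pick_scale_dtype(torch.float64))                           # torch.float16
print(needs_zero_point(symmetric=True, force_zero_point=False))  # False
```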
compressed_tensors/quantization/observers/base.py

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from math import ceil
 from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
@@ -21,6 +22,7 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationStrategy,
 )
 from compressed_tensors.registry.registry import RegistryMixin
+from compressed_tensors.utils import safe_permute
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
@@ -46,15 +48,18 @@ class Observer(Module, RegistryMixin):
         self._num_observed_tokens = None
 
     @torch.no_grad()
-    def forward(
+    def forward(
+        self, observed: Tensor, g_idx: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
-        :param observed: optional observed tensor to calculate
-
+        :param observed: optional observed tensor from which to calculate
+            quantization parameters
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed)
+        return self.get_qparams(observed=observed, g_idx=g_idx)
 
     def calculate_qparams(
         self,
@@ -77,7 +82,9 @@ class Observer(Module, RegistryMixin):
         ...
 
     def get_qparams(
-        self,
+        self,
+        observed: Optional[Tensor] = None,
+        g_idx: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -86,6 +93,7 @@ class Observer(Module, RegistryMixin):
 
         :param observed: optional observed tensor to calculate quantization parameters
             from
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
@@ -97,20 +105,42 @@ class Observer(Module, RegistryMixin):
                 self._scale, self._zero_point = self.calculate_qparams(observed)
 
             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                rows = observed.shape[0]
                 columns = observed.shape[1]
-
-
-
+                num_groups = int(ceil(columns / group_size))
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
+                zp_dtype = self.quantization_args.pytorch_dtype()
+                self._zero_point = torch.empty(
+                    (rows, num_groups), dtype=zp_dtype, device=observed.device
+                )
+
+                # support column-order (default) quantization as well as other orderings
+                # such as activation ordering. Below checks if g_idx has initialized
+                is_column_order = g_idx is None or -1 in g_idx
+                if is_column_order:
+                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+                else:
+                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+                    group_sizes = group_sizes[torch.argsort(group_indices)]
+
+                    perm = torch.argsort(g_idx)
+                    observed = safe_permute(observed, perm, dim=1)
+
+                # TODO: experiment with vectorizing for loop for performance
+                end = 0
+                for group_index, group_count in enumerate(group_sizes):
+                    start = end
+                    end = start + group_count
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:,
+                        observed[:, start:end],
                         0,
-                        tensor_id=
+                        tensor_id=group_index,
                     )
-                    scales.append(scale)
-                    zero_points.append(zero_point)
 
-
-
+                    self._scale[:, group_index] = scale.squeeze(1)
+                    self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                 # assume observed is transposed, because its the output, hence use dim 0
```
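For the GROUP strategy the observer now preallocates one scale and zero-point column per group rather than collecting per-group results into lists. A quick shape sketch of what gets allocated (the sizes and dtypes below are illustrative):

```python
import torch
from math import ceil

rows, columns, group_size = 4, 16, 8        # e.g. a 4x16 weight, groups of 8 columns
num_groups = int(ceil(columns / group_size))

scale = torch.empty((rows, num_groups), dtype=torch.float16)
zero_point = torch.empty((rows, num_groups), dtype=torch.int8)
print(scale.shape, zero_point.shape)        # torch.Size([4, 2]) torch.Size([4, 2])
```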
```diff
@@ -132,6 +162,8 @@ class Observer(Module, RegistryMixin):
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
     ):
+        if isinstance(dim, int):
+            dim = [dim]
         dim = set(dim)
 
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
@@ -171,3 +203,11 @@ class Observer(Module, RegistryMixin):
         # observed_tokens (batch_size * sequence_length)
         observed_tokens, _ = batch_tensor.shape
         self._num_observed_tokens += observed_tokens
+
+    def reset(self):
+        """
+        Reset the state of the observer
+        """
+        self._num_observed_tokens = None
+        self._scale = None
+        self._zero_point = None
```
compressed_tensors/quantization/observers/min_max.py

```diff
@@ -94,3 +94,11 @@ class MovingAverageMinMaxObserver(Observer):
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}
```