compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +200 -8
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +101 -13
- compressed_tensors/compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/pack_quantized.py +128 -132
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +8 -1
- compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -44
- compressed_tensors/quantization/lifecycle/calibration.py +22 -2
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +139 -61
- compressed_tensors/quantization/lifecycle/helpers.py +80 -0
- compressed_tensors/quantization/lifecycle/initialize.py +77 -13
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +93 -14
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +139 -23
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +112 -13
- compressed_tensors/quantization/utils/helpers.py +68 -2
- compressed_tensors/utils/__init__.py +5 -0
- compressed_tensors/utils/helpers.py +44 -2
- compressed_tensors/utils/offload.py +116 -0
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
- compressed_tensors-0.6.0.dist-info/RECORD +52 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/int_quantized.py +0 -126
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py

```diff
@@ -17,12 +17,15 @@ from math import ceil
 from typing import Optional
 
 import torch
+from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import safe_permute, update_parameter_data
 from torch.nn import Module
 
 
@@ -42,6 +45,7 @@ def quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -55,16 +59,9 @@ def quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x.device != scale.device:
-        scale = scale.to(x.device)
-    if x.device != zero_point.device:
-        zero_point = zero_point.to(x.device)
-
     return _process_quantization(
         x=x,
         scale=scale,
@@ -73,6 +70,7 @@ def quantize(
         dtype=dtype,
         do_quantize=True,
         do_dequantize=False,
+        g_idx=g_idx,
     )
 
 
```
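The new `g_idx` argument threads a per-column group index through `quantize` (and, below, `dequantize` and `fake_quantize`). A minimal torch-only sketch of what such a mapping encodes, assuming a hypothetical 8-column weight with `group_size=4`; this is illustrative only and not part of the package:

```python
import torch

# Hypothetical illustration: the column -> group mapping that g_idx encodes
# for an 8-column weight quantized with group_size=4.
columns, group_size = 8, 4
g_idx_column_order = torch.arange(columns) // group_size  # tensor([0, 0, 0, 0, 1, 1, 1, 1])

# A module initialized without activation ordering stores g_idx filled with -1,
# which the forward pass treats the same as "no reordering" (plain column order).
g_idx_uninitialized = torch.full((columns,), -1, dtype=torch.int)

# With activation ordering, consecutive columns may belong to different groups;
# the mapping is still one group id per column.
g_idx_actorder = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])
```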
```diff
@@ -80,8 +78,10 @@ def dequantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
+    dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -91,6 +91,8 @@ def dequantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
+    :param dtype: optional dtype to cast the dequantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: dequantized float tensor
     """
     if args is None:
@@ -107,8 +109,12 @@ def dequantize(
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "
-            "dimmensions. Expected 0
+            "dimmensions. Expected 0 or 2 dimmensions."
         )
+
+    if dtype is None:
+        dtype = scale.dtype
+
     return _process_quantization(
         x=x_q,
         scale=scale,
@@ -116,6 +122,8 @@ def dequantize(
         args=args,
         do_quantize=False,
         do_dequantize=True,
+        dtype=dtype,
+        g_idx=g_idx,
     )
 
 
```
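`dequantize` now accepts the symmetric case with no zero point and, when `dtype` is not given, falls back to the scale's dtype. A torch-only sketch of that fallback behavior (not the library call itself):

```python
import torch

def dequantize_sketch(x_q, scale, zero_point=None, dtype=None):
    # mirror of the new defaults: output dtype follows the scale when unspecified
    if dtype is None:
        dtype = scale.dtype
    x = x_q.to(scale.dtype)
    if zero_point is not None:  # symmetric quantization carries no zero point
        x = x - zero_point.to(scale.dtype)
    return (x * scale).to(dtype)

x_q = torch.tensor([[-12, 3, 7]], dtype=torch.int8)
scale = torch.tensor(0.05, dtype=torch.float16)
print(dequantize_sketch(x_q, scale).dtype)  # torch.float16, inherited from the scale
```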
```diff
@@ -125,6 +133,7 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -137,6 +146,7 @@ def fake_quantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -146,6 +156,7 @@ def fake_quantize(
         args=args,
         do_quantize=True,
         do_dequantize=True,
+        g_idx=g_idx,
     )
 
 
```
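`fake_quantize` remains a single `_process_quantization` call with both `do_quantize` and `do_dequantize` enabled, i.e. a quantize/dequantize round trip in the input's dtype. A per-tensor int8 sketch of the same round trip, written against plain torch rather than the package:

```python
import torch

def fake_quantize_sketch(x, scale, zero_point, q_min=-128, q_max=127):
    # quantize: scale, shift, clamp, round ...
    x_q = torch.round(torch.clamp(x / scale + zero_point, q_min, q_max))
    # ... then dequantize straight back to the input dtype
    return (x_q - zero_point) * scale

x = torch.randn(4, 8)
scale = x.abs().max() / 127
error = (x - fake_quantize_sketch(x, scale, torch.tensor(0.0))).abs().max()
print(error)  # bounded by half a quantization step
```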
```diff
@@ -154,64 +165,85 @@ def _process_quantization(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    g_idx: Optional[torch.Tensor],
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
-    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
-    q_min = torch.tensor(-bit_range / 2, device=x.device)
+    q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
     if args.strategy == QuantizationStrategy.GROUP:
-
-
-
-            output = torch.zeros_like(x, dtype=scale.dtype)
-        else:
-            output_dtype = dtype if dtype is not None else x.dtype
-            output = torch.zeros_like(x, dtype=output_dtype)
-
-        # TODO: vectorize the for loop
-        # TODO: fix genetric assumption about the tensor size for computing group
+        output_dtype = dtype if dtype is not None else x.dtype
+        output = torch.zeros_like(x).to(output_dtype)
+        columns = output.shape[1]
 
         # TODO: make validation step for inputs
 
         while scale.ndim < 2:
             # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
-            zero_point = zero_point.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
-        columns = x.shape[1]
         if columns >= group_size:
             if columns % group_size != 0:
                 raise ValueError(
-                    "
+                    "tensor column shape must be divisble "
                     f"by the given group_size {group_size}"
                 )
-        for i in range(ceil(columns / group_size)):
-            # scale.shape should be [nchan, ndim]
-            # sc.shape should be [nchan, 1] after unsqueeze
-            sc = scale[:, i].view(-1, 1)
-            zp = zero_point[:, i].view(-1, 1)
 
-
+        # support column-order (default) quantization as well as other orderings
+        # such as activation ordering. Below checks if g_idx has been initialized
+        is_column_order = g_idx is None or -1 in g_idx
+        if is_column_order:
+            num_groups = int(ceil(columns / group_size))
+            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+
+        else:
+            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+            group_sizes = group_sizes[torch.argsort(group_indices)]
+
+            perm = torch.argsort(g_idx)
+            x = safe_permute(x, perm, dim=1)
+
+        # TODO: experiment with vectorizing for loop for performance
+        end = 0
+        for index, group_count in enumerate(group_sizes):
+            sc = scale[:, index].view(-1, 1)
+            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
+
+            start = end
+            end = start + group_count
             if do_quantize:
-                output[:,
-                    x[:,
+                output[:, start:end] = _quantize(
+                    x[:, start:end],
+                    sc,
+                    zp,
+                    q_min,
+                    q_max,
+                    args,
+                    dtype=dtype,
                 )
+
             if do_dequantize:
-                input =
-
-
-
-
-                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+                input = output[:, start:end] if do_quantize else x[:, start:end]
+                output[:, start:end] = _dequantize(input, sc, zp)
+
+        if not is_column_order:
+            output = safe_permute(output, torch.argsort(perm), dim=1)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
-            output = _quantize(
+            output = _quantize(
+                x,
+                scale,
+                zero_point,
+                q_min,
+                q_max,
+                args,
+                dtype=dtype,
+            )
         if do_dequantize:
             output = _dequantize(output if do_quantize else x, scale, zero_point)
 
```
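The rewritten group loop no longer assumes fixed `group_size` slices: when a real `g_idx` is present, group sizes come from the index counts and columns are permuted into group order before slicing (and permuted back afterwards). A standalone sketch of that bookkeeping, using `torch.argsort` in place of the package's `safe_permute`:

```python
from math import ceil

import torch

def group_slices(columns, group_size, g_idx=None):
    """Yield (start, end) column slices the same way the new loop walks groups."""
    is_column_order = g_idx is None or -1 in g_idx
    if is_column_order:
        num_groups = int(ceil(columns / group_size))
        group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
        perm = None
    else:
        group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
        group_sizes = group_sizes[torch.argsort(group_indices)]
        perm = torch.argsort(g_idx)  # permutation that makes each group contiguous

    slices, end = [], 0
    for count in group_sizes:
        start, end = end, end + int(count)
        slices.append((start, end))
    return slices, perm

g_idx = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])
slices, perm = group_slices(columns=8, group_size=4, g_idx=g_idx)
print(slices)  # [(0, 4), (4, 8)] -- two groups of four columns each
```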
```diff
@@ -228,7 +260,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
+        if not getattr(module, "quantization_enabled", True):
+            # quantization is disabled on forward passes, return baseline
+            # forward call
+            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
         input_ = args[0]
+        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
@@ -236,7 +274,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
                 module, input_, "input", scheme.input_activations
             )
 
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = maybe_calibrate_or_quantize(
@@ -255,7 +293,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             )
 
         # restore back to unquantized_value
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
         return output
```
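The wrapped forward now short-circuits in two situations: a `quantization_enabled` attribute set to False skips quantization entirely, and modules already in the `COMPRESSED` status skip weight fake-quantization because the stored weight is already quantized. A minimal, package-independent sketch of the attribute-gate pattern used here:

```python
from functools import wraps

import torch

def wrap_forward_with_gate(module: torch.nn.Module):
    forward_func_orig = module.forward.__func__  # unbound original forward

    @wraps(forward_func_orig)
    def wrapped_forward(self, *args, **kwargs):
        if not getattr(module, "quantization_enabled", True):
            # gate disabled: fall straight through to the baseline forward
            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
        # ... a quantization-aware path would go here ...
        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)

    # bind the wrapper as this instance's forward
    module.forward = wrapped_forward.__get__(module, module.__class__)

linear = torch.nn.Linear(4, 4)
wrap_forward_with_gate(linear)
linear.quantization_enabled = False  # the flag disable_quantization() sets
_ = linear(torch.randn(2, 4))
```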
```diff
@@ -269,33 +307,57 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    #
-    if module.quantization_status
-        QuantizationStatus.CALIBRATION,
-        QuantizationStatus.FROZEN,
-    }:
+    # don't run quantization if we haven't entered calibration mode
+    if module.quantization_status == QuantizationStatus.INITIALIZED:
         return value
 
+    # in compressed mode, the weight is already compressed and quantized so we don't
+    # need to run fake quantization
+    if (
+        module.quantization_status == QuantizationStatus.COMPRESSED
+        and base_name == "weight"
+    ):
+        return value
+
+    if value.numel() == 0:
+        # if the tensor is empty,
+        # skip quantization
+        return value
+
+    g_idx = getattr(module, "weight_g_idx", None)
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value)
+        scale, zero_point = observer(value, g_idx=g_idx)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
-        zero_point = getattr(module, f"{base_name}_zero_point")
+        zero_point = getattr(module, f"{base_name}_zero_point", None)
 
-    if
+    if (
+        module.quantization_status == QuantizationStatus.CALIBRATION
+        and base_name != "weight"
+    ):
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")
 
-        updated_scale, updated_zero_point = observer(value)
+        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
 
         # update scale and zero point
-
-
-
-
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+
+        scale = updated_scale
+        zero_point = updated_zero_point
+
+    return fake_quantize(
+        x=value,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        g_idx=g_idx,
+    )
 
 
 @torch.no_grad()
```
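Calibration now only refreshes activation quantization parameters (`base_name != "weight"`) and pushes them back onto the module through `update_parameter_data`, imported from `compressed_tensors.utils` above, rather than assigning attributes directly. A hypothetical torch-only stand-in for the in-place update it performs on a non-offloaded module (the real helper also handles accelerate-offloaded weights):

```python
import torch

def update_parameter_data_sketch(module: torch.nn.Module, new_value: torch.Tensor, name: str):
    # hypothetical stand-in: copy new data into an already-registered parameter,
    # keeping its device placement and requires_grad setting intact
    param = getattr(module, name, None)
    if param is None:
        return
    param.data = new_value.detach().to(param.device)

layer = torch.nn.Linear(8, 8)
layer.register_parameter(
    "weight_scale", torch.nn.Parameter(torch.empty(1), requires_grad=False)
)
update_parameter_data_sketch(layer, torch.tensor([0.021]), "weight_scale")
```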
```diff
@@ -305,14 +367,20 @@ def _quantize(
     zero_point: torch.Tensor,
     q_min: torch.Tensor,
     q_max: torch.Tensor,
+    args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-
-
+
+    scaled = x / scale
+    if zero_point is not None:
+        scaled += zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
         q_min,
         q_max,
     )
-
+    quantized_value = round_to_quantized_type(clamped_value, args)
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
 
```
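`_quantize` now clamps to [q_min, q_max] before rounding and casting, since the final cast is not guaranteed to saturate for every target dtype (the inline comment calls out fp8). A small int8 sketch of why the order matters:

```python
import torch

x = torch.tensor([300.7, -250.2, 17.4])

# clamp-then-round-then-cast pins out-of-range values at the dtype limits
safe = torch.clamp(x, -128, 127).round().to(torch.int8)
print(safe)  # tensor([ 127, -128,   17], dtype=torch.int8)

# casting without clamping is not guaranteed to saturate out-of-range values
unsafe = x.round().to(torch.int8)
```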
```diff
@@ -323,6 +391,16 @@ def _quantize(
 def _dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-
+    dequant_value = x_q.to(scale.dtype)
+
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value * scale
+
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
+
+    return dequant_value
```
compressed_tensors/quantization/lifecycle/helpers.py (new file)

```diff
@@ -0,0 +1,80 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Miscelaneous helpers for the quantization lifecycle
+"""
+
+from typing import Optional
+
+import torch
+from torch.nn import Module
+
+
+__all__ = [
+    "update_layer_weight_quant_params",
+    "enable_quantization",
+    "disable_quantization",
+]
+
+
+def update_layer_weight_quant_params(
+    layer: Module,
+    weight: Optional[torch.Tensor] = None,
+    g_idx: Optional[torch.Tensor] = None,
+    reset_obs: bool = False,
+):
+    """
+    Update quantization parameters on layer
+
+    :param layer: input layer
+    :param weight: weight to update quant params with, defaults to layer weight
+    :param g_idx: optional mapping from column index to group index
+    :param reset_obs: reset the observer before calculating quant params,
+        defaults to False
+    """
+    attached_weight = getattr(layer, "weight", None)
+
+    if weight is None:
+        weight = attached_weight
+    scale = getattr(layer, "weight_scale", None)
+    zero_point = getattr(layer, "weight_zero_point", None)
+    if g_idx is None:
+        g_idx = getattr(layer, "weight_g_idx", None)
+    observer = getattr(layer, "weight_observer", None)
+
+    if weight is None or observer is None or scale is None or zero_point is None:
+        # scale, zp, or observer not calibratable or weight not available
+        return
+
+    if reset_obs:
+        observer.reset()
+
+    if attached_weight is not None:
+        weight = weight.to(attached_weight.dtype)
+
+    updated_scale, updated_zero_point = observer(weight)
+
+    # update scale and zero point
+    device = next(layer.parameters()).device
+    scale.data = updated_scale.to(device)
+    zero_point.data = updated_zero_point.to(device)
+
+
+def enable_quantization(module: Module):
+    module.quantization_enabled = True
+
+
+def disable_quantization(module: Module):
+    module.quantization_enabled = False
```
compressed_tensors/quantization/lifecycle/initialize.py

```diff
@@ -21,11 +21,13 @@ from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
+    ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter
 
 
@@ -40,6 +42,7 @@ _LOGGER = logging.getLogger(__name__)
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
+    force_zero_point: bool = True,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -51,6 +54,8 @@ def initialize_module_for_quantization(
     :param scheme: scheme to use for quantization. if None is provided,
         will attempt to use scheme stored in the module under `quantization_scheme`,
         if not provided, the layer will be skipped
+    :param force_zero_point: whether to force initialization of a zero point for
+        symmetric quantization
     """
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
@@ -58,14 +63,18 @@ def initialize_module_for_quantization(
         return
 
     if scheme.input_activations is not None:
-        _initialize_scale_zero_point_observer(
+        _initialize_scale_zero_point_observer(
+            module, "input", scheme.input_activations, force_zero_point=force_zero_point
+        )
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            weight_shape =
-            if isinstance(module, torch.nn.Linear):
-                weight_shape = module.weight.shape
+            weight_shape = module.weight.shape
             _initialize_scale_zero_point_observer(
-                module,
+                module,
+                "weight",
+                scheme.weights,
+                weight_shape=weight_shape,
+                force_zero_point=force_zero_point,
             )
         else:
             _LOGGER.warning(
@@ -75,21 +84,58 @@ def initialize_module_for_quantization(
             )
     if scheme.output_activations is not None:
         _initialize_scale_zero_point_observer(
-            module,
+            module,
+            "output",
+            scheme.output_activations,
+            force_zero_point=force_zero_point,
         )
 
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
 
+    offloaded = False
+    if is_module_offloaded(module):
+        try:
+            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+            from accelerate.utils import PrefixedDataset
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Offloaded model detected. To use CPU offloading with "
+                "compressed-tensors the `accelerate` package must be installed, "
+                "run `pip install compressed-tensors[accelerate]`"
+            )
+
+        offloaded = True
+        hook = module._hf_hook
+        prefix_dict = module._hf_hook.weights_map
+        new_prefix = {}
+
+        # recreate the prefix dict (since it is immutable)
+        # and add quantization parameters
+        for key, data in module.named_parameters():
+            if key not in prefix_dict:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+            else:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+        remove_hook_from_module(module)
+
     # wrap forward call of module to perform quantized actions based on calltime status
     wrap_module_forward_quantized(module, scheme)
 
+    if offloaded:
+        # we need to re-add the hook for offloading now that we've wrapped forward
+        add_hook_to_module(module, hook)
+        if prefix_dict is not None:
+            module._hf_hook.weights_map = new_prefix_dict
+
 
 def _initialize_scale_zero_point_observer(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
+    force_zero_point: bool = True,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
```
```diff
@@ -99,6 +145,8 @@ def _initialize_scale_zero_point_observer(
         return  # no need to register a scale and zero point for a dynamic observer
 
     device = next(module.parameters()).device
+    if is_module_offloaded(module):
+        device = get_execution_device(module)
 
     # infer expected scale/zero point shape
     expected_shape = 1  # per tensor
@@ -113,15 +161,31 @@ def _initialize_scale_zero_point_observer(
             weight_shape[1] // quantization_args.group_size,
         )
 
-
+    scale_dtype = module.weight.dtype
+    if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+        scale_dtype = torch.float16
+
+    # initializes empty scale, zero point, and g_idx parameters for the module
     init_scale = Parameter(
-        torch.empty(expected_shape, dtype=
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
-
-
-
-
-
+    if force_zero_point or not quantization_args.symmetric:
+        zp_dtype = quantization_args.pytorch_dtype()
+        init_zero_point = Parameter(
+            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+    # only grouped activation ordering has g_idx
+    if quantization_args.actorder == ActivationOrdering.GROUP:
+        g_idx_shape = (weight_shape[1],)
+        g_idx_dtype = torch.int
+        init_g_idx = Parameter(
+            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_g_idx", init_g_idx)
```