compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +137 -9
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/config/base.py +6 -1
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +63 -9
- compressed_tensors/quantization/lifecycle/calibration.py +7 -7
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +126 -44
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -20
- compressed_tensors/quantization/lifecycle/initialize.py +138 -55
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +102 -24
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +12 -13
- compressed_tensors/quantization/utils/helpers.py +44 -19
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +30 -1
- compressed_tensors/utils/offload.py +14 -2
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
- compressed_tensors-0.7.0.dist-info/RECORD +59 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/pack_quantized.py +0 -219
- compressed_tensors-0.5.0.dist-info/RECORD +0 -48
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
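The listing above shows the compressor implementations being split out of a flat `compressors/` package into `model_compressors/`, `quantized_compressors/`, `sparse_compressors/`, and `sparse_quantized_compressors/` subpackages, alongside new modules such as `compressed_tensors/linear/compressed_linear.py`, `compressed_tensors/quantization/cache.py`, and `compressed_tensors/utils/permute.py`. The line-level excerpts that follow cover only the quantization lifecycle modules (`compressed.py`, `forward.py`, `frozen.py`, `helpers.py`). For orientation, a minimal usage sketch of the top-level entry point, assuming the public re-export of `ModelCompressor` from `compressed_tensors.compressors` is unchanged by the package reorganization and using a placeholder checkpoint path:

```python
from compressed_tensors.compressors import ModelCompressor

# hypothetical path to a checkpoint saved with a compressed-tensors config
ckpt = "/path/to/compressed-checkpoint"

# from_pretrained reads the compression config embedded in the checkpoint's
# model config and returns None when no such config is present
compressor = ModelCompressor.from_pretrained(ckpt)
if compressor is not None:
    print(compressor.quantization_config)
    print(compressor.sparsity_config)
```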
compressed_tensors/quantization/lifecycle/compressed.py:

```diff
@@ -49,8 +49,9 @@ def compress_quantized_weights(module: Module):
     weight = getattr(module, "weight", None)
     scale = getattr(module, "weight_scale", None)
     zero_point = getattr(module, "weight_zero_point", None)
+    g_idx = getattr(module, "weight_g_idx", None)
 
-    if weight is None or scale is None
+    if weight is None or scale is None:
         # no weight, scale, or ZP, nothing to do
 
         # mark as compressed here to maintain consistent status throughout the model
@@ -62,6 +63,7 @@ def compress_quantized_weights(module: Module):
         x=weight,
         scale=scale,
         zero_point=zero_point,
+        g_idx=g_idx,
         args=scheme.weights,
         dtype=torch.int8,
     )
```
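The functional change in `compress_quantized_weights` is that activation-ordering group indices (`weight_g_idx`), when present on the module, are now threaded through to `quantize`. Below is a minimal sketch of that call on a toy tensor; the import paths match the modules shown in this diff, but the `QuantizationArgs` field values and the scale/zero-point construction are illustrative assumptions, not taken from the package:

```python
import torch

from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs

# toy 2x8 weight quantized in column groups of 4 -> 2 groups per output row
weight = torch.randn(2, 8)
scale = weight.abs().reshape(2, 2, 4).amax(dim=-1) / 7.0  # [rows, groups], int4 max = 7
zero_point = torch.zeros_like(scale, dtype=torch.int8)    # symmetric -> zero offsets
g_idx = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1])            # column -> group mapping

args = QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=4)

weight_q = quantize(
    x=weight,
    scale=scale,
    zero_point=zero_point,
    args=args,
    dtype=torch.int8,
    g_idx=g_idx,  # omit (or pass None) for plain column-order grouping
)
```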
compressed_tensors/quantization/lifecycle/forward.py:

```diff
@@ -14,9 +14,10 @@
 
 from functools import wraps
 from math import ceil
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
+from compressed_tensors.quantization.cache import QuantizedKVParameterCache
 from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
@@ -25,7 +26,7 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import update_parameter_data
+from compressed_tensors.utils import safe_permute, update_parameter_data
 from torch.nn import Module
 
 
```
```diff
@@ -45,6 +46,7 @@ def quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -58,15 +60,9 @@ def quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x.device != scale.device:
-        scale = scale.to(x.device)
-    if x.device != zero_point.device:
-        zero_point = zero_point.to(x.device)
 
     return _process_quantization(
         x=x,
@@ -76,6 +72,7 @@ def quantize(
         dtype=dtype,
         do_quantize=True,
         do_dequantize=False,
+        g_idx=g_idx,
     )
 
 
```
```diff
@@ -86,6 +83,7 @@ def dequantize(
     zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -96,6 +94,7 @@ def dequantize(
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
     :param dtype: optional dtype to cast the dequantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: dequantized float tensor
     """
     if args is None:
@@ -126,6 +125,7 @@ def dequantize(
         do_quantize=False,
         do_dequantize=True,
         dtype=dtype,
+        g_idx=g_idx,
     )
 
 
```
```diff
@@ -135,6 +135,7 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -147,6 +148,7 @@ def fake_quantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -156,6 +158,7 @@ def fake_quantize(
         args=args,
         do_quantize=True,
         do_dequantize=True,
+        g_idx=g_idx,
     )
 
 
```
```diff
@@ -165,20 +168,18 @@ def _process_quantization(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
     if args.strategy == QuantizationStrategy.GROUP:
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-
-        # TODO: vectorize the for loop
-        # TODO: fix genetric assumption about the tensor size for computing group
+        columns = output.shape[1]
 
         # TODO: make validation step for inputs
 
@@ -187,23 +188,38 @@ def _process_quantization(
             scale = scale.unsqueeze(1)
             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
-        columns = x.shape[1]
         if columns >= group_size:
             if columns % group_size != 0:
                 raise ValueError(
-                    "
+                    "tensor column shape must be divisble "
                     f"by the given group_size {group_size}"
                 )
-        for i in range(ceil(columns / group_size)):
-            # scale.shape should be [nchan, ndim]
-            # sc.shape should be [nchan, 1] after unsqueeze
-            sc = scale[:, i].view(-1, 1)
-            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
 
-
+        # support column-order (default) quantization as well as other orderings
+        # such as activation ordering. Below checks if g_idx has been initialized
+        is_column_order = g_idx is None or -1 in g_idx
+        if is_column_order:
+            num_groups = int(ceil(columns / group_size))
+            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+
+        else:
+            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+            group_sizes = group_sizes[torch.argsort(group_indices)]
+
+            perm = torch.argsort(g_idx)
+            x = safe_permute(x, perm, dim=1)
+
+        # TODO: experiment with vectorizing for loop for performance
+        end = 0
+        for index, group_count in enumerate(group_sizes):
+            sc = scale[:, index].view(-1, 1)
+            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
+
+            start = end
+            end = start + group_count
             if do_quantize:
-                output[:,
-                x[:,
+                output[:, start:end] = _quantize(
+                    x[:, start:end],
                     sc,
                     zp,
                     q_min,
@@ -211,13 +227,13 @@ def _process_quantization(
                     args,
                     dtype=dtype,
                 )
+
             if do_dequantize:
-                input =
-
-
-
-
-                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+                input = output[:, start:end] if do_quantize else x[:, start:end]
+                output[:, start:end] = _dequantize(input, sc, zp)
+
+        if not is_column_order:
+            output = safe_permute(output, torch.argsort(perm), dim=1)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
```
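The group loop no longer assumes contiguous, equally sized `i * group_size` column blocks: group sizes are derived from `g_idx` when activation ordering is in use, and columns are permuted so each group becomes contiguous before slicing. A standalone sketch of that bookkeeping in plain `torch`, using direct indexing in place of the library's `safe_permute` helper (tensor values are made up):

```python
from math import ceil

import torch

x = torch.randn(2, 8)                           # [rows, columns]
group_size = 4
g_idx = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1])  # column -> group mapping

perm = None
if g_idx is None or -1 in g_idx:
    # column order: contiguous blocks of equal size
    num_groups = ceil(x.shape[1] / group_size)
    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
else:
    # activation ordering: group sizes come from counting each group id,
    # and columns are permuted so every group is a contiguous slice
    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
    group_sizes = group_sizes[torch.argsort(group_indices)]
    perm = torch.argsort(g_idx)
    x = x[:, perm]  # stand-in for safe_permute(x, perm, dim=1)

end = 0
for index, group_count in enumerate(group_sizes):
    start, end = end, end + int(group_count)
    print(f"group {index}: permuted columns {start}..{end - 1}")

if perm is not None:
    # the inverse permutation restores the original column order afterwards
    x = x[:, torch.argsort(perm)]
```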
```diff
@@ -253,13 +269,15 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
         input_ = args[0]
 
+        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
+
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
             input_ = maybe_calibrate_or_quantize(
                 module, input_, "input", scheme.input_activations
             )
 
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = maybe_calibrate_or_quantize(
@@ -270,15 +288,17 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-
         if scheme.output_activations is not None:
+
             # calibrate and (fake) quantize output activations when applicable
+            # kv_cache scales updated on model self_attn forward call in
+            # wrap_module_forward_quantized_attn
             output = maybe_calibrate_or_quantize(
                 module, output, "output", scheme.output_activations
             )
 
         # restore back to unquantized_value
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
         return output
```
```diff
@@ -289,14 +309,63 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)
 
 
+def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
+    # expects a module already initialized and injected with the parameters in
+    # initialize_module_for_quantization
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
+
+    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+    def wrapped_forward(self, *args, **kwargs):
+
+        # kv cache stored under weights
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            quantization_args: QuantizationArgs = scheme.output_activations
+            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
+            kwargs["past_key_value"] = past_key_value
+
+            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
+            # does not store quantized_key_states and quantized_value_state
+            kwargs["use_cache"] = False
+
+            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
+
+            past_key_value.reset_states()
+
+            rtn = attn_forward(*args, **kwargs)
+
+            update_parameter_data(
+                module, past_key_value.k_scales[module.layer_idx], "k_scale"
+            )
+            update_parameter_data(
+                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+            )
+
+            return rtn
+
+        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
+    # bind wrapped forward to module class so reference to `self` is correct
+    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+    # set forward to wrapped forward
+    setattr(module, "forward", bound_wrapped_forward)
+
+
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    #
-    if module.quantization_status
-
-
-
+    # don't run quantization if we haven't entered calibration mode
+    if module.quantization_status == QuantizationStatus.INITIALIZED:
+        return value
+
+    # in compressed mode, the weight is already compressed and quantized so we don't
+    # need to run fake quantization
+    if (
+        module.quantization_status == QuantizationStatus.COMPRESSED
+        and base_name == "weight"
+    ):
         return value
 
     if value.numel() == 0:
```
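Both `wrap_module_forward_quantized` and the new `wrap_module_forward_quantized_attn` rely on the same instance-level patching trick: take the function behind the bound `forward` method, wrap it, and rebind the wrapper to the single module being modified rather than to its class. A minimal illustration of that pattern on a stock `torch.nn.Linear`, independent of the KV-cache logic above:

```python
from functools import wraps

import torch
from torch.nn import Linear

module = Linear(4, 4)

# the plain function behind the bound method, as in the wrappers above
forward_func_orig = module.forward.__func__

@wraps(forward_func_orig)  # keep the original name/docstring on the wrapper
def wrapped_forward(self, *args, **kwargs):
    print("before forward")   # e.g. calibrate/quantize inputs here
    output = forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)
    print("after forward")    # e.g. calibrate/quantize outputs here
    return output

# bind to this instance only; other Linear modules keep their original forward
module.forward = wrapped_forward.__get__(module, module.__class__)

module(torch.randn(1, 4))
```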
```diff
@@ -304,14 +373,16 @@ def maybe_calibrate_or_quantize(
         # skip quantization
         return value
 
+    g_idx = getattr(module, "weight_g_idx", None)
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value)
+        scale, zero_point = observer(value, g_idx=g_idx)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
-        zero_point = getattr(module, f"{base_name}_zero_point")
+        zero_point = getattr(module, f"{base_name}_zero_point", None)
 
     if (
         module.quantization_status == QuantizationStatus.CALIBRATION
@@ -320,13 +391,22 @@ def maybe_calibrate_or_quantize(
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")
 
-        updated_scale, updated_zero_point = observer(value)
+        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
 
         # update scale and zero point
         update_parameter_data(module, updated_scale, f"{base_name}_scale")
         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
-
+        scale = updated_scale
+        zero_point = updated_zero_point
+
+    return fake_quantize(
+        x=value,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        g_idx=g_idx,
+    )
 
 
 @torch.no_grad()
```
```diff
@@ -340,7 +420,9 @@ def _quantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
 
-    scaled = x / scale
+    scaled = x / scale
+    if zero_point is not None:
+        scaled += zero_point.to(x.dtype)
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -361,11 +443,11 @@ def _dequantize(
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    dequant_value = x_q.to(scale.dtype)
 
-    dequant_value = x_q
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
-    dequant_value = dequant_value
+    dequant_value = dequant_value * scale
 
     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
```
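The `_quantize`/`_dequantize` edits make the zero point optional and route the dequantized math through the scale's dtype; the arithmetic itself is the usual affine round trip. A small numeric check in plain `torch` with made-up int8 values, mirroring (but not reproducing) the helpers above:

```python
import torch

x = torch.tensor([0.05, -0.30, 1.20])
scale = torch.tensor(0.01)
zero_point = torch.tensor(3, dtype=torch.int8)
q_min, q_max = -128, 127

# quantize: x_q = clamp(round(x / scale + zero_point), q_min, q_max)
scaled = x / scale + zero_point.to(x.dtype)
x_q = torch.clamp(torch.round(scaled), q_min, q_max).to(torch.int8)

# dequantize: x_dq = (x_q - zero_point) * scale, computed in scale's dtype
x_dq = (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale

print(x_q)   # -> 8, -27, 123 (int8)
print(x_dq)  # -> 0.05, -0.30, 1.20 (up to float rounding)
```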
compressed_tensors/quantization/lifecycle/frozen.py:

```diff
@@ -14,6 +14,7 @@
 
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from torch.nn import Module
 
 
@@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module):
         delattr(module, "input_observer")
     if scheme.weights and not scheme.weights.dynamic:
         delattr(module, "weight_observer")
-    if
+    if (
+        scheme.output_activations
+        and not is_kv_cache_quant_scheme(scheme)
+        and not scheme.output_activations.dynamic
+    ):
         delattr(module, "output_observer")
 
     module.quantization_status = QuantizationStatus.FROZEN
```
compressed_tensors/quantization/lifecycle/helpers.py:

```diff
@@ -16,35 +16,15 @@
 Miscelaneous helpers for the quantization lifecycle
 """
 
-
 from torch.nn import Module
 
 
 __all__ = [
-    "update_layer_weight_quant_params",
     "enable_quantization",
     "disable_quantization",
 ]
 
 
-def update_layer_weight_quant_params(layer: Module):
-    weight = getattr(layer, "weight", None)
-    scale = getattr(layer, "weight_scale", None)
-    zero_point = getattr(layer, "weight_zero_point", None)
-    observer = getattr(layer, "weight_observer", None)
-
-    if weight is None or observer is None or scale is None or zero_point is None:
-        # scale, zp, or observer not calibratable or weight not available
-        return
-
-    updated_scale, updated_zero_point = observer(weight)
-
-    # update scale and zero point
-    device = next(layer.parameters()).device
-    scale.data = updated_scale.to(device)
-    zero_point.data = updated_zero_point.to(device)
-
-
 def enable_quantization(module: Module):
     module.quantization_enabled = True
 
```