compressed-tensors 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
- compressed_tensors/config/base.py +60 -2
- compressed_tensors/quantization/__init__.py +0 -1
- compressed_tensors/quantization/lifecycle/__init__.py +0 -2
- compressed_tensors/quantization/lifecycle/apply.py +1 -16
- compressed_tensors/quantization/lifecycle/forward.py +25 -86
- compressed_tensors/quantization/lifecycle/initialize.py +23 -25
- compressed_tensors/quantization/quant_args.py +28 -15
- compressed_tensors/quantization/quant_scheme.py +3 -0
- compressed_tensors/quantization/utils/helpers.py +125 -8
- compressed_tensors/registry/registry.py +1 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/METADATA +1 -1
- {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/RECORD +17 -26
- {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/WHEEL +1 -1
- compressed_tensors/quantization/cache.py +0 -201
- compressed_tensors/quantization/lifecycle/calibration.py +0 -70
- compressed_tensors/quantization/lifecycle/frozen.py +0 -55
- compressed_tensors/quantization/observers/__init__.py +0 -22
- compressed_tensors/quantization/observers/base.py +0 -213
- compressed_tensors/quantization/observers/helpers.py +0 -111
- compressed_tensors/quantization/observers/memoryless.py +0 -56
- compressed_tensors/quantization/observers/min_max.py +0 -104
- compressed_tensors/quantization/observers/mse.py +0 -162
- {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/top_level.txt +0 -0

compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py

@@ -238,7 +238,7 @@ def pack_scales_24(scales, quantization_args, w_shape):
     _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)
 
     if (
-        quantization_args.strategy
+        quantization_args.strategy == QuantizationStrategy.GROUP
         and quantization_args.group_size < size_k
     ):
         scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]

compressed_tensors/config/base.py

@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from enum import Enum
+from enum import Enum, unique
 from typing import List, Optional
 
 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
 
 
-__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
+__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
 
 
+@unique
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"
 
 
+@unique
+class SparsityStructure(Enum):
+    """
+    An enumeration to represent different sparsity structures.
+
+    Attributes
+    ----------
+    TWO_FOUR : str
+        Represents a 2:4 sparsity structure.
+    ZERO_ZERO : str
+        Represents a 0:0 sparsity structure.
+    UNSTRUCTURED : str
+        Represents an unstructured sparsity structure.
+
+    Examples
+    --------
+    >>> SparsityStructure('2:4')
+    <SparsityStructure.TWO_FOUR: '2:4'>
+
+    >>> SparsityStructure('unstructured')
+    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
+
+    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
+    True
+
+    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure('invalid')
+    Traceback (most recent call last):
+        ...
+    ValueError: invalid is not a valid SparsityStructure
+    """
+
+    TWO_FOUR = "2:4"
+    UNSTRUCTURED = "unstructured"
+    ZERO_ZERO = "0:0"
+
+    def __new__(cls, value):
+        obj = object.__new__(cls)
+        obj._value_ = value.lower() if value is not None else value
+        return obj
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle None and case-insensitive values
+        if value is None:
+            return cls.UNSTRUCTURED
+        for member in cls:
+            if member.value == value.lower():
+                return member
+        raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
 class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
     Base data class for storing sparsity compression parameters
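
A quick way to read the new SparsityStructure enum: lookups are case-insensitive and None falls back to UNSTRUCTURED, so raw strings taken from serialized configs can be compared directly against members. A minimal sketch, assuming SparsityStructure is re-exported from compressed_tensors.config alongside the other base-config names:

from compressed_tensors.config import SparsityStructure

# strings as they might appear in a serialized sparsity config
for raw in ("2:4", "UNSTRUCTURED", None):
    # __new__ lower-cases member values; _missing_ maps None and mixed case
    print(raw, "->", SparsityStructure(raw))

# 2:4 -> SparsityStructure.TWO_FOUR
# UNSTRUCTURED -> SparsityStructure.UNSTRUCTURED
# None -> SparsityStructure.UNSTRUCTURED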

compressed_tensors/quantization/lifecycle/apply.py

@@ -22,13 +22,9 @@ from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.calibration import (
-    set_module_for_calibration,
-)
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
 )
-from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -233,6 +229,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
+
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
@@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )
 
-    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        # only quantize weights up front when our end goal state is calibration,
-        # weight quantization parameters are already loaded for frozen/compressed
-        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
-        model.apply(
-            lambda module: set_module_for_calibration(
-                module, quantize_weights_upfront=quantize_weights_upfront
-            )
-        )
-    if current_status < status >= QuantizationStatus.FROZEN > current_status:
-        model.apply(freeze_module_quantization)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)
 

compressed_tensors/quantization/lifecycle/forward.py

@@ -14,11 +14,9 @@
 
 from functools import wraps
 from math import ceil
-from typing import
+from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -26,7 +24,11 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import
+from compressed_tensors.quantization.utils import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
+from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
 
@@ -35,7 +37,7 @@ __all__ = [
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
-    "
+    "forward_quantize",
 ]
 
 
@@ -272,15 +274,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
-            #
-            input_ =
-                module, input_, "input", scheme.input_activations
-            )
+            # prehook should calibrate activations before forward call
+            input_ = forward_quantize(module, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data =
+            self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )
 
@@ -288,64 +288,23 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        if scheme.output_activations is not None:
-
-            # calibrate and (fake) quantize output activations when applicable
-            # kv_cache scales updated on model self_attn forward call in
-            # wrap_module_forward_quantized_attn
-            output = maybe_calibrate_or_quantize(
-                module, output, "output", scheme.output_activations
-            )
 
         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
-
-
-
-
-
-
-
-
-
-
-    # initialize_module_for_quantization
-    if hasattr(module.forward, "__func__"):
-        forward_func_orig = module.forward.__func__
-    else:
-        forward_func_orig = module.forward.func
-
-    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
-    def wrapped_forward(self, *args, **kwargs):
-
-        # kv cache stored under weights
-        if module.quantization_status == QuantizationStatus.CALIBRATION:
-            quantization_args: QuantizationArgs = scheme.output_activations
-            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
-            kwargs["past_key_value"] = past_key_value
-
-            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
-            # does not store quantized_key_states and quantized_value_state
-            kwargs["use_cache"] = False
-
-            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
-
-            past_key_value.reset_states()
-
-            rtn = attn_forward(*args, **kwargs)
-
-            update_parameter_data(
-                module, past_key_value.k_scales[module.layer_idx], "k_scale"
-            )
-            update_parameter_data(
-                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+        if scheme.output_activations is not None:
+            # forward-hook should calibrate/forward_quantize
+            if (
+                module.quantization_status == QuantizationStatus.CALIBRATION
+                and not scheme.output_activations.dynamic
+            ):
+                return output
+
+            output = forward_quantize(
+                module, output, "output", scheme.output_activations
             )
-
-            return rtn
-
-        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+        return output
 
     # bind wrapped forward to module class so reference to `self` is correct
     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
@@ -353,12 +312,9 @@ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)
 
 
-def
+def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # don't run quantization if we haven't entered calibration mode
-    if module.quantization_status == QuantizationStatus.INITIALIZED:
-        return value
 
     # in compressed mode, the weight is already compressed and quantized so we don't
     # need to run fake quantization
@@ -376,30 +332,13 @@ def maybe_calibrate_or_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization -
-
-        scale, zero_point = observer(value, g_idx=g_idx)
+        # dynamic quantization - determine the scale/zp on the fly
+        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
-        # static quantization - get
+        # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point", None)
 
-    if (
-        module.quantization_status == QuantizationStatus.CALIBRATION
-        and base_name != "weight"
-    ):
-        # calibration mode - get new quant params from observer
-        observer = getattr(module, f"{base_name}_observer")
-
-        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
-
-        # update scale and zero point
-        update_parameter_data(module, updated_scale, f"{base_name}_scale")
-        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
-
-        scale = updated_scale
-        zero_point = updated_zero_point
-
     return fake_quantize(
         x=value,
         scale=scale,
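
With the observer modules removed from this package, forward.py now pulls calculate_range and compute_dynamic_scales_and_zp from compressed_tensors.quantization.utils, and the renamed forward_quantize computes dynamic scales on the fly instead of consulting an observer. The snippet below is only an illustrative stand-in for what a symmetric per-token scale computation of that kind looks like; the helper name is made up and this is not the library's implementation:

import torch

def per_token_scales(value: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Illustrative symmetric per-token scales (not the library code)."""
    q_max = 2 ** (num_bits - 1) - 1                    # 127 for int8
    abs_max = value.abs().amax(dim=-1, keepdim=True)   # one max per token
    return abs_max.clamp(min=torch.finfo(value.dtype).eps) / q_max

activations = torch.randn(2, 4, 16)                    # (batch, tokens, hidden)
print(per_token_scales(activations).shape)             # torch.Size([2, 4, 1])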

compressed_tensors/quantization/lifecycle/initialize.py

@@ -14,13 +14,12 @@
 
 
 import logging
+from enum import Enum
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
-    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -36,12 +35,19 @@ from torch.nn import Module, Parameter
 
 __all__ = [
     "initialize_module_for_quantization",
+    "is_attention_module",
+    "KVCacheScaleType",
 ]
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
@@ -66,15 +72,13 @@ def initialize_module_for_quantization(
         return
 
     if is_attention_module(module):
-        # wrap forward call of module to perform
         # quantized actions based on calltime status
-        wrap_module_forward_quantized_attn(module, scheme)
         _initialize_attn_scales(module)
 
     else:
 
         if scheme.input_activations is not None:
-
+            _initialize_scale_zero_point(
                 module,
                 "input",
                 scheme.input_activations,
@@ -85,7 +89,7 @@ def initialize_module_for_quantization(
             weight_shape = None
             if isinstance(module, torch.nn.Linear):
                 weight_shape = module.weight.shape
-
+            _initialize_scale_zero_point(
                 module,
                 "weight",
                 scheme.weights,
@@ -101,7 +105,7 @@ def initialize_module_for_quantization(
 
         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
-
+                _initialize_scale_zero_point(
                     module, "output", scheme.output_activations
                 )
 
@@ -109,6 +113,7 @@ def initialize_module_for_quantization(
     module.quantization_status = QuantizationStatus.INITIALIZED
 
     offloaded = False
+    # What is this doing/why isn't this in the attn case?
    if is_module_offloaded(module):
        try:
            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
@@ -146,19 +151,23 @@ def initialize_module_for_quantization(
         module._hf_hook.weights_map = new_prefix_dict
 
 
-def
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
-
     if quantization_args.dynamic:
-        return
+        return
 
     device = next(module.parameters()).device
     if is_module_offloaded(module):
@@ -173,10 +182,7 @@ def _initialize_scale_zero_point_observer(
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy == QuantizationStrategy.GROUP:
             num_groups = weight_shape[1] // quantization_args.group_size
-            expected_shape = (
-                weight_shape[0],
-                max(num_groups, 1)
-            )
+            expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
@@ -208,14 +214,6 @@ def _initialize_scale_zero_point_observer(
         module.register_parameter(f"{base_name}_g_idx", init_g_idx)
 
 
-def is_attention_module(module: Module):
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj")
-        or hasattr(module, "v_proj")
-        or hasattr(module, "qkv_proj")
-    )
-
-
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
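
is_attention_module and the KVCacheScaleType enum now live in (and are exported from) lifecycle/initialize.py rather than the removed cache module. A minimal sketch of what the heuristic matches; FakeAttention below is a hypothetical stand-in for a transformers-style attention block:

import torch.nn as nn
from compressed_tensors.quantization.lifecycle.initialize import (
    KVCacheScaleType,
    is_attention_module,
)

class FakeAttention(nn.Module):
    """Hypothetical stand-in: class name contains "attention" and exposes k_proj/v_proj."""
    def __init__(self):
        super().__init__()
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)

print(is_attention_module(FakeAttention()))  # True
print(is_attention_module(nn.Linear(8, 8)))  # False
print(KVCacheScaleType.KEY.value)            # k_scale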

compressed_tensors/quantization/quant_args.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from enum import Enum
 from typing import Any, Dict, Optional, Union
 
@@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[str] = None
     dynamic: bool = False
     actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
+    observer: Optional[str] = Field(
         default="minmax",
         description=(
             "The class to use to compute the quantization param - "
@@ -113,20 +114,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         :return: torch quantization FakeQuantize built based on these QuantizationArgs
         """
-
-
-        if self.dynamic:
-            # override defualt observer for dynamic, you never want minmax which
-            # keeps state across samples for dynamic
-            self.observer = "memoryless"
-
-        return Observer.load_from_registry(self.observer, quantization_args=self)
-
-    def get_kv_cache(self):
-        """Get the singleton KV Cache"""
-        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-
-        return QuantizedKVParameterCache(self)
+        return self.observer
 
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
@@ -171,6 +159,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         strategy = model.strategy
         group_size = model.group_size
         actorder = model.actorder
+        dynamic = model.dynamic
+        observer = model.observer
 
         # infer strategy
         if strategy is None:
@@ -207,8 +197,31 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "activation ordering"
             )
 
+        # infer observer w.r.t. dynamic
+        if dynamic:
+            if strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+            ):
+                raise ValueError(
+                    f"One of {QuantizationStrategy.TOKEN} or "
+                    f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
+                    "quantization",
+                )
+            if observer is not None:
+                if observer != "memoryless":  # avoid annoying users with old configs
+                    warnings.warn(
+                        "No observer is used for dynamic quantization, setting to None"
+                    )
+                observer = None
+
+        elif observer is None:
+            # default to minmax for non-dynamic cases
+            observer = "minmax"
+
         # write back modified values
         model.strategy = strategy
+        model.observer = observer
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
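
Taken together, the quant_args.py changes mean a dynamic config must use the TOKEN or TENSOR strategy and has its observer cleared to None, while non-dynamic configs keep defaulting to the minmax observer. A small sketch, assuming QuantizationArgs is still re-exported from compressed_tensors.quantization as in earlier releases:

from compressed_tensors.quantization import QuantizationArgs

static_args = QuantizationArgs(num_bits=8, strategy="channel")
print(static_args.observer)   # minmax - static quantization keeps the default observer

dynamic_args = QuantizationArgs(num_bits=8, strategy="token", dynamic=True)
print(dynamic_args.observer)  # None (a warning notes the unused observer was dropped)

try:
    QuantizationArgs(num_bits=8, strategy="group", group_size=128, dynamic=True)
except ValueError as err:     # pydantic surfaces the validator's ValueError
    print("rejected:", err)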

compressed_tensors/quantization/quant_scheme.py

@@ -122,6 +122,7 @@ INT8_W8A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -164,6 +165,7 @@ INT8_W4A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -200,6 +202,7 @@ FP8_DYNAMIC = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 