compressed-tensors 0.7.1__tar.gz → 0.8.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/PKG-INFO +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/base.py +60 -2
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/__init__.py +0 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -2
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/apply.py +1 -16
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/forward.py +24 -87
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/initialize.py +21 -24
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/quant_args.py +11 -22
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/utils/helpers.py +125 -8
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/registry/registry.py +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/version.py +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/SOURCES.txt +0 -8
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/cache.py +0 -201
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/calibration.py +0 -70
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/frozen.py +0 -55
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/__init__.py +0 -21
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/base.py +0 -213
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/helpers.py +0 -149
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/min_max.py +0 -104
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/mse.py +0 -162
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/LICENSE +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/README.md +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/pyproject.toml +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/setup.cfg +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/setup.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/dense.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/quant_config.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/helpers.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/offload.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/top_level.txt +0 -0
@@ -238,7 +238,7 @@ def pack_scales_24(scales, quantization_args, w_shape):
|
|
238
238
|
_, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)
|
239
239
|
|
240
240
|
if (
|
241
|
-
quantization_args.strategy
|
241
|
+
quantization_args.strategy == QuantizationStrategy.GROUP
|
242
242
|
and quantization_args.group_size < size_k
|
243
243
|
):
|
244
244
|
scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
|
@@ -12,16 +12,17 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from enum import Enum
|
15
|
+
from enum import Enum, unique
|
16
16
|
from typing import List, Optional
|
17
17
|
|
18
18
|
from compressed_tensors.registry import RegistryMixin
|
19
19
|
from pydantic import BaseModel
|
20
20
|
|
21
21
|
|
22
|
-
__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
|
22
|
+
__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
|
23
23
|
|
24
24
|
|
25
|
+
@unique
|
25
26
|
class CompressionFormat(Enum):
|
26
27
|
dense = "dense"
|
27
28
|
sparse_bitmask = "sparse-bitmask"
|
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
|
|
32
33
|
marlin_24 = "marlin-24"
|
33
34
|
|
34
35
|
|
36
|
+
@unique
|
37
|
+
class SparsityStructure(Enum):
|
38
|
+
"""
|
39
|
+
An enumeration to represent different sparsity structures.
|
40
|
+
|
41
|
+
Attributes
|
42
|
+
----------
|
43
|
+
TWO_FOUR : str
|
44
|
+
Represents a 2:4 sparsity structure.
|
45
|
+
ZERO_ZERO : str
|
46
|
+
Represents a 0:0 sparsity structure.
|
47
|
+
UNSTRUCTURED : str
|
48
|
+
Represents an unstructured sparsity structure.
|
49
|
+
|
50
|
+
Examples
|
51
|
+
--------
|
52
|
+
>>> SparsityStructure('2:4')
|
53
|
+
<SparsityStructure.TWO_FOUR: '2:4'>
|
54
|
+
|
55
|
+
>>> SparsityStructure('unstructured')
|
56
|
+
<SparsityStructure.UNSTRUCTURED: 'unstructured'>
|
57
|
+
|
58
|
+
>>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
|
59
|
+
True
|
60
|
+
|
61
|
+
>>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
|
62
|
+
True
|
63
|
+
|
64
|
+
>>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
|
65
|
+
True
|
66
|
+
|
67
|
+
>>> SparsityStructure('invalid')
|
68
|
+
Traceback (most recent call last):
|
69
|
+
...
|
70
|
+
ValueError: invalid is not a valid SparsityStructure
|
71
|
+
"""
|
72
|
+
|
73
|
+
TWO_FOUR = "2:4"
|
74
|
+
UNSTRUCTURED = "unstructured"
|
75
|
+
ZERO_ZERO = "0:0"
|
76
|
+
|
77
|
+
def __new__(cls, value):
|
78
|
+
obj = object.__new__(cls)
|
79
|
+
obj._value_ = value.lower() if value is not None else value
|
80
|
+
return obj
|
81
|
+
|
82
|
+
@classmethod
|
83
|
+
def _missing_(cls, value):
|
84
|
+
# Handle None and case-insensitive values
|
85
|
+
if value is None:
|
86
|
+
return cls.UNSTRUCTURED
|
87
|
+
for member in cls:
|
88
|
+
if member.value == value.lower():
|
89
|
+
return member
|
90
|
+
raise ValueError(f"{value} is not a valid {cls.__name__}")
|
91
|
+
|
92
|
+
|
35
93
|
class SparsityCompressionConfig(RegistryMixin, BaseModel):
|
36
94
|
"""
|
37
95
|
Base data class for storing sparsity compression parameters
|
@@ -22,13 +22,9 @@ from typing import Union
|
|
22
22
|
|
23
23
|
import torch
|
24
24
|
from compressed_tensors.config import CompressionFormat
|
25
|
-
from compressed_tensors.quantization.lifecycle.calibration import (
|
26
|
-
set_module_for_calibration,
|
27
|
-
)
|
28
25
|
from compressed_tensors.quantization.lifecycle.compressed import (
|
29
26
|
compress_quantized_weights,
|
30
27
|
)
|
31
|
-
from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
|
32
28
|
from compressed_tensors.quantization.lifecycle.initialize import (
|
33
29
|
initialize_module_for_quantization,
|
34
30
|
)
|
@@ -233,6 +229,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
|
|
233
229
|
:param model: model to apply quantization to
|
234
230
|
:param status: status to update the module to
|
235
231
|
"""
|
232
|
+
|
236
233
|
current_status = infer_quantization_status(model)
|
237
234
|
|
238
235
|
if status >= QuantizationStatus.INITIALIZED > current_status:
|
@@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
|
|
243
240
|
)
|
244
241
|
)
|
245
242
|
|
246
|
-
if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
|
247
|
-
# only quantize weights up front when our end goal state is calibration,
|
248
|
-
# weight quantization parameters are already loaded for frozen/compressed
|
249
|
-
quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
|
250
|
-
model.apply(
|
251
|
-
lambda module: set_module_for_calibration(
|
252
|
-
module, quantize_weights_upfront=quantize_weights_upfront
|
253
|
-
)
|
254
|
-
)
|
255
|
-
if current_status < status >= QuantizationStatus.FROZEN > current_status:
|
256
|
-
model.apply(freeze_module_quantization)
|
257
|
-
|
258
243
|
if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
|
259
244
|
model.apply(compress_quantized_weights)
|
260
245
|
|
@@ -14,14 +14,9 @@
|
|
14
14
|
|
15
15
|
from functools import wraps
|
16
16
|
from math import ceil
|
17
|
-
from typing import
|
17
|
+
from typing import Optional
|
18
18
|
|
19
19
|
import torch
|
20
|
-
from compressed_tensors.quantization.cache import QuantizedKVParameterCache
|
21
|
-
from compressed_tensors.quantization.observers.helpers import (
|
22
|
-
calculate_range,
|
23
|
-
compute_dynamic_scales_and_zp,
|
24
|
-
)
|
25
20
|
from compressed_tensors.quantization.quant_args import (
|
26
21
|
QuantizationArgs,
|
27
22
|
QuantizationStrategy,
|
@@ -29,7 +24,11 @@ from compressed_tensors.quantization.quant_args import (
|
|
29
24
|
)
|
30
25
|
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
31
26
|
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
32
|
-
from compressed_tensors.utils import
|
27
|
+
from compressed_tensors.quantization.utils import (
|
28
|
+
calculate_range,
|
29
|
+
compute_dynamic_scales_and_zp,
|
30
|
+
)
|
31
|
+
from compressed_tensors.utils import safe_permute
|
33
32
|
from torch.nn import Module
|
34
33
|
|
35
34
|
|
@@ -38,7 +37,7 @@ __all__ = [
|
|
38
37
|
"dequantize",
|
39
38
|
"fake_quantize",
|
40
39
|
"wrap_module_forward_quantized",
|
41
|
-
"
|
40
|
+
"forward_quantize",
|
42
41
|
]
|
43
42
|
|
44
43
|
|
@@ -275,15 +274,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
|
|
275
274
|
compressed = module.quantization_status == QuantizationStatus.COMPRESSED
|
276
275
|
|
277
276
|
if scheme.input_activations is not None:
|
278
|
-
#
|
279
|
-
input_ =
|
280
|
-
module, input_, "input", scheme.input_activations
|
281
|
-
)
|
277
|
+
# prehook should calibrate activations before forward call
|
278
|
+
input_ = forward_quantize(module, input_, "input", scheme.input_activations)
|
282
279
|
|
283
280
|
if scheme.weights is not None and not compressed:
|
284
281
|
# calibrate and (fake) quantize weights when applicable
|
285
282
|
unquantized_weight = self.weight.data.clone()
|
286
|
-
self.weight.data =
|
283
|
+
self.weight.data = forward_quantize(
|
287
284
|
module, self.weight, "weight", scheme.weights
|
288
285
|
)
|
289
286
|
|
@@ -291,64 +288,23 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
|
|
291
288
|
output = forward_func_orig.__get__(module, module.__class__)(
|
292
289
|
input_, *args[1:], **kwargs
|
293
290
|
)
|
294
|
-
if scheme.output_activations is not None:
|
295
|
-
|
296
|
-
# calibrate and (fake) quantize output activations when applicable
|
297
|
-
# kv_cache scales updated on model self_attn forward call in
|
298
|
-
# wrap_module_forward_quantized_attn
|
299
|
-
output = maybe_calibrate_or_quantize(
|
300
|
-
module, output, "output", scheme.output_activations
|
301
|
-
)
|
302
291
|
|
303
292
|
# restore back to unquantized_value
|
304
293
|
if scheme.weights is not None and not compressed:
|
305
294
|
self.weight.data = unquantized_weight
|
306
295
|
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
# initialize_module_for_quantization
|
318
|
-
if hasattr(module.forward, "__func__"):
|
319
|
-
forward_func_orig = module.forward.__func__
|
320
|
-
else:
|
321
|
-
forward_func_orig = module.forward.func
|
322
|
-
|
323
|
-
@wraps(forward_func_orig) # ensures docstring, names, etc are propagated
|
324
|
-
def wrapped_forward(self, *args, **kwargs):
|
325
|
-
|
326
|
-
# kv cache stored under weights
|
327
|
-
if module.quantization_status == QuantizationStatus.CALIBRATION:
|
328
|
-
quantization_args: QuantizationArgs = scheme.output_activations
|
329
|
-
past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
|
330
|
-
kwargs["past_key_value"] = past_key_value
|
331
|
-
|
332
|
-
# QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
|
333
|
-
# does not store quantized_key_states and quantized_value_state
|
334
|
-
kwargs["use_cache"] = False
|
335
|
-
|
336
|
-
attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
|
337
|
-
|
338
|
-
past_key_value.reset_states()
|
339
|
-
|
340
|
-
rtn = attn_forward(*args, **kwargs)
|
341
|
-
|
342
|
-
update_parameter_data(
|
343
|
-
module, past_key_value.k_scales[module.layer_idx], "k_scale"
|
344
|
-
)
|
345
|
-
update_parameter_data(
|
346
|
-
module, past_key_value.v_scales[module.layer_idx], "v_scale"
|
296
|
+
if scheme.output_activations is not None:
|
297
|
+
# forward-hook should calibrate/forward_quantize
|
298
|
+
if (
|
299
|
+
module.quantization_status == QuantizationStatus.CALIBRATION
|
300
|
+
and not scheme.output_activations.dynamic
|
301
|
+
):
|
302
|
+
return output
|
303
|
+
|
304
|
+
output = forward_quantize(
|
305
|
+
module, output, "output", scheme.output_activations
|
347
306
|
)
|
348
|
-
|
349
|
-
return rtn
|
350
|
-
|
351
|
-
return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
|
307
|
+
return output
|
352
308
|
|
353
309
|
# bind wrapped forward to module class so reference to `self` is correct
|
354
310
|
bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
|
@@ -356,12 +312,9 @@ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationSchem
|
|
356
312
|
setattr(module, "forward", bound_wrapped_forward)
|
357
313
|
|
358
314
|
|
359
|
-
def
|
315
|
+
def forward_quantize(
|
360
316
|
module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
|
361
317
|
) -> torch.Tensor:
|
362
|
-
# don't run quantization if we haven't entered calibration mode
|
363
|
-
if module.quantization_status == QuantizationStatus.INITIALIZED:
|
364
|
-
return value
|
365
318
|
|
366
319
|
# in compressed mode, the weight is already compressed and quantized so we don't
|
367
320
|
# need to run fake quantization
|
@@ -379,29 +332,13 @@ def maybe_calibrate_or_quantize(
|
|
379
332
|
g_idx = getattr(module, "weight_g_idx", None)
|
380
333
|
|
381
334
|
if args.dynamic:
|
382
|
-
# dynamic quantization -
|
335
|
+
# dynamic quantization - determine the scale/zp on the fly
|
383
336
|
scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
|
384
337
|
else:
|
385
|
-
# static quantization - get
|
338
|
+
# static quantization - get scale and zero point from layer
|
386
339
|
scale = getattr(module, f"{base_name}_scale")
|
387
340
|
zero_point = getattr(module, f"{base_name}_zero_point", None)
|
388
341
|
|
389
|
-
if (
|
390
|
-
module.quantization_status == QuantizationStatus.CALIBRATION
|
391
|
-
and base_name != "weight"
|
392
|
-
):
|
393
|
-
# calibration mode - get new quant params from observer
|
394
|
-
observer = getattr(module, f"{base_name}_observer")
|
395
|
-
|
396
|
-
updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
|
397
|
-
|
398
|
-
# update scale and zero point
|
399
|
-
update_parameter_data(module, updated_scale, f"{base_name}_scale")
|
400
|
-
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
|
401
|
-
|
402
|
-
scale = updated_scale
|
403
|
-
zero_point = updated_zero_point
|
404
|
-
|
405
342
|
return fake_quantize(
|
406
343
|
x=value,
|
407
344
|
scale=scale,
|
@@ -14,13 +14,12 @@
|
|
14
14
|
|
15
15
|
|
16
16
|
import logging
|
17
|
+
from enum import Enum
|
17
18
|
from typing import Optional
|
18
19
|
|
19
20
|
import torch
|
20
|
-
from compressed_tensors.quantization.cache import KVCacheScaleType
|
21
21
|
from compressed_tensors.quantization.lifecycle.forward import (
|
22
22
|
wrap_module_forward_quantized,
|
23
|
-
wrap_module_forward_quantized_attn,
|
24
23
|
)
|
25
24
|
from compressed_tensors.quantization.quant_args import (
|
26
25
|
ActivationOrdering,
|
@@ -36,12 +35,19 @@ from torch.nn import Module, Parameter
|
|
36
35
|
|
37
36
|
__all__ = [
|
38
37
|
"initialize_module_for_quantization",
|
38
|
+
"is_attention_module",
|
39
|
+
"KVCacheScaleType",
|
39
40
|
]
|
40
41
|
|
41
42
|
|
42
43
|
_LOGGER = logging.getLogger(__name__)
|
43
44
|
|
44
45
|
|
46
|
+
class KVCacheScaleType(Enum):
|
47
|
+
KEY = "k_scale"
|
48
|
+
VALUE = "v_scale"
|
49
|
+
|
50
|
+
|
45
51
|
def initialize_module_for_quantization(
|
46
52
|
module: Module,
|
47
53
|
scheme: Optional[QuantizationScheme] = None,
|
@@ -66,15 +72,13 @@ def initialize_module_for_quantization(
|
|
66
72
|
return
|
67
73
|
|
68
74
|
if is_attention_module(module):
|
69
|
-
# wrap forward call of module to perform
|
70
75
|
# quantized actions based on calltime status
|
71
|
-
wrap_module_forward_quantized_attn(module, scheme)
|
72
76
|
_initialize_attn_scales(module)
|
73
77
|
|
74
78
|
else:
|
75
79
|
|
76
80
|
if scheme.input_activations is not None:
|
77
|
-
|
81
|
+
_initialize_scale_zero_point(
|
78
82
|
module,
|
79
83
|
"input",
|
80
84
|
scheme.input_activations,
|
@@ -85,7 +89,7 @@ def initialize_module_for_quantization(
|
|
85
89
|
weight_shape = None
|
86
90
|
if isinstance(module, torch.nn.Linear):
|
87
91
|
weight_shape = module.weight.shape
|
88
|
-
|
92
|
+
_initialize_scale_zero_point(
|
89
93
|
module,
|
90
94
|
"weight",
|
91
95
|
scheme.weights,
|
@@ -101,7 +105,7 @@ def initialize_module_for_quantization(
|
|
101
105
|
|
102
106
|
if scheme.output_activations is not None:
|
103
107
|
if not is_kv_cache_quant_scheme(scheme):
|
104
|
-
|
108
|
+
_initialize_scale_zero_point(
|
105
109
|
module, "output", scheme.output_activations
|
106
110
|
)
|
107
111
|
|
@@ -109,6 +113,7 @@ def initialize_module_for_quantization(
|
|
109
113
|
module.quantization_status = QuantizationStatus.INITIALIZED
|
110
114
|
|
111
115
|
offloaded = False
|
116
|
+
# What is this doing/why isn't this in the attn case?
|
112
117
|
if is_module_offloaded(module):
|
113
118
|
try:
|
114
119
|
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
@@ -146,21 +151,21 @@ def initialize_module_for_quantization(
|
|
146
151
|
module._hf_hook.weights_map = new_prefix_dict
|
147
152
|
|
148
153
|
|
149
|
-
def
|
154
|
+
def is_attention_module(module: Module):
|
155
|
+
return "attention" in module.__class__.__name__.lower() and (
|
156
|
+
hasattr(module, "k_proj")
|
157
|
+
or hasattr(module, "v_proj")
|
158
|
+
or hasattr(module, "qkv_proj")
|
159
|
+
)
|
160
|
+
|
161
|
+
|
162
|
+
def _initialize_scale_zero_point(
|
150
163
|
module: Module,
|
151
164
|
base_name: str,
|
152
165
|
quantization_args: QuantizationArgs,
|
153
166
|
weight_shape: Optional[torch.Size] = None,
|
154
167
|
force_zero_point: bool = True,
|
155
168
|
):
|
156
|
-
|
157
|
-
# initialize observer module and attach as submodule
|
158
|
-
observer = quantization_args.get_observer()
|
159
|
-
# no need to register an observer for dynamic quantization
|
160
|
-
if observer:
|
161
|
-
module.register_module(f"{base_name}_observer", observer)
|
162
|
-
|
163
|
-
# no need to register a scale and zero point for a dynamic quantization
|
164
169
|
if quantization_args.dynamic:
|
165
170
|
return
|
166
171
|
|
@@ -209,14 +214,6 @@ def _initialize_scale_zero_point_observer(
|
|
209
214
|
module.register_parameter(f"{base_name}_g_idx", init_g_idx)
|
210
215
|
|
211
216
|
|
212
|
-
def is_attention_module(module: Module):
|
213
|
-
return "attention" in module.__class__.__name__.lower() and (
|
214
|
-
hasattr(module, "k_proj")
|
215
|
-
or hasattr(module, "v_proj")
|
216
|
-
or hasattr(module, "qkv_proj")
|
217
|
-
)
|
218
|
-
|
219
|
-
|
220
217
|
def _initialize_attn_scales(module: Module) -> None:
|
221
218
|
"""Initlaize k_scale, v_scale for self_attn"""
|
222
219
|
|
@@ -114,20 +114,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
114
114
|
"""
|
115
115
|
:return: torch quantization FakeQuantize built based on these QuantizationArgs
|
116
116
|
"""
|
117
|
-
|
118
|
-
|
119
|
-
# No observer required for the dynamic case
|
120
|
-
if self.dynamic:
|
121
|
-
self.observer = None
|
122
|
-
return self.observer
|
123
|
-
|
124
|
-
return Observer.load_from_registry(self.observer, quantization_args=self)
|
125
|
-
|
126
|
-
def get_kv_cache(self):
|
127
|
-
"""Get the singleton KV Cache"""
|
128
|
-
from compressed_tensors.quantization.cache import QuantizedKVParameterCache
|
129
|
-
|
130
|
-
return QuantizedKVParameterCache(self)
|
117
|
+
return self.observer
|
131
118
|
|
132
119
|
@field_validator("type", mode="before")
|
133
120
|
def validate_type(cls, value) -> QuantizationType:
|
@@ -210,6 +197,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
210
197
|
"activation ordering"
|
211
198
|
)
|
212
199
|
|
200
|
+
# infer observer w.r.t. dynamic
|
213
201
|
if dynamic:
|
214
202
|
if strategy not in (
|
215
203
|
QuantizationStrategy.TOKEN,
|
@@ -221,18 +209,19 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
221
209
|
"quantization",
|
222
210
|
)
|
223
211
|
if observer is not None:
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
212
|
+
if observer != "memoryless": # avoid annoying users with old configs
|
213
|
+
warnings.warn(
|
214
|
+
"No observer is used for dynamic quantization, setting to None"
|
215
|
+
)
|
216
|
+
observer = None
|
228
217
|
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
model.observer = "minmax"
|
218
|
+
elif observer is None:
|
219
|
+
# default to minmax for non-dynamic cases
|
220
|
+
observer = "minmax"
|
233
221
|
|
234
222
|
# write back modified values
|
235
223
|
model.strategy = strategy
|
224
|
+
model.observer = observer
|
236
225
|
return model
|
237
226
|
|
238
227
|
def pytorch_dtype(self) -> torch.dtype:
|