compressed-tensors 0.7.1__tar.gz → 0.8.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/PKG-INFO +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +17 -5
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +4 -2
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/config/base.py +60 -2
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/linear/compressed_linear.py +3 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/__init__.py +0 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -2
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/lifecycle/apply.py +3 -17
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/lifecycle/forward.py +24 -87
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/lifecycle/initialize.py +21 -24
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/quant_args.py +27 -25
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/quant_config.py +2 -2
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/quant_scheme.py +17 -24
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/utils/helpers.py +125 -8
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/registry/registry.py +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/utils/helpers.py +33 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/version.py +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors.egg-info/SOURCES.txt +0 -8
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/cache.py +0 -201
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/calibration.py +0 -70
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/frozen.py +0 -55
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/__init__.py +0 -21
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/base.py +0 -213
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/helpers.py +0 -149
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/min_max.py +0 -104
- compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/mse.py +0 -162
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/LICENSE +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/README.md +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/pyproject.toml +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/setup.cfg +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/setup.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/config/dense.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/utils/__init__.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/utils/offload.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed-tensors-0.7.1 → compressed-tensors-0.8.1}/src/compressed_tensors.egg-info/top_level.txt +0 -0
--- compressed-tensors-0.7.1/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -24,7 +24,6 @@ import compressed_tensors
 import torch
 import transformers
 from compressed_tensors.base import (
-    COMPRESSION_CONFIG_NAME,
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
@@ -39,6 +38,7 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -103,12 +103,14 @@ class ModelCompressor:
         :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        compression_config = getattr(config,
+        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+
         return cls.from_compression_config(compression_config)

     @classmethod
     def from_compression_config(
-        cls,
+        cls,
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ):
         """
         :param compression_config:
@@ -265,7 +267,11 @@ class ModelCompressor:
         state_dict = model.state_dict()

         compressed_state_dict = state_dict
-
+
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
+
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
@@ -369,7 +375,13 @@ class ModelCompressor:
             update_parameter_data(module, data, param_name)


-def map_modules_to_quant_args(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+    """
+    Given a pytorch model, map out the submodule name (usually linear layers)
+    to the QuantizationArgs
+
+    :param model: pytorch model
+    """
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
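The mapping built by map_modules_to_quant_args is what compress() now hands to the quantization compressor as names_to_scheme (see the +267 hunk above). The snippet below is a minimal sketch of that idea rather than the library implementation: the function name sketch_map_modules_to_quant_args is made up for illustration, and it relies on the quantization_scheme attribute that initialized modules carry, whereas the real helper uses iter_named_leaf_modules and is_module_quantized from compressed_tensors.quantization.utils.

from typing import Dict

from torch.nn import Module

from compressed_tensors.quantization import QuantizationArgs


def sketch_map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
    # Simplified stand-in: walk submodules and record the weight QuantizationArgs
    # for every module that carries a quantization scheme.
    mapping: Dict[str, QuantizationArgs] = {}
    for name, submodule in model.named_modules():
        scheme = getattr(submodule, "quantization_scheme", None)
        if scheme is not None and scheme.weights is not None:
            mapping[name] = scheme.weights
    return mapping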
--- compressed-tensors-0.7.1/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
@@ -93,9 +93,11 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
                 args=quantization_args,
                 dtype=quantization_args.pytorch_dtype(),
             )
+        else:
+            quantized_weight = weight

-
-
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)

         return {"weight": quantized_weight}

--- compressed-tensors-0.7.1/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -94,6 +94,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
                 args=quantization_args,
                 dtype=torch.int8,
             )
+        else:
+            quantized_weight = weight

         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
         weight_shape = torch.tensor(weight.shape)
--- compressed-tensors-0.7.1/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
@@ -238,7 +238,7 @@ def pack_scales_24(scales, quantization_args, w_shape):
    _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)

    if (
-        quantization_args.strategy
+        quantization_args.strategy == QuantizationStrategy.GROUP
        and quantization_args.group_size < size_k
    ):
        scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
--- compressed-tensors-0.7.1/src/compressed_tensors/config/base.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/config/base.py
@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from enum import Enum
+from enum import Enum, unique
 from typing import List, Optional

 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel


-__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
+__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]


+@unique
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"


+@unique
+class SparsityStructure(Enum):
+    """
+    An enumeration to represent different sparsity structures.
+
+    Attributes
+    ----------
+    TWO_FOUR : str
+        Represents a 2:4 sparsity structure.
+    ZERO_ZERO : str
+        Represents a 0:0 sparsity structure.
+    UNSTRUCTURED : str
+        Represents an unstructured sparsity structure.
+
+    Examples
+    --------
+    >>> SparsityStructure('2:4')
+    <SparsityStructure.TWO_FOUR: '2:4'>
+
+    >>> SparsityStructure('unstructured')
+    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
+
+    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
+    True
+
+    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure('invalid')
+    Traceback (most recent call last):
+        ...
+    ValueError: invalid is not a valid SparsityStructure
+    """
+
+    TWO_FOUR = "2:4"
+    UNSTRUCTURED = "unstructured"
+    ZERO_ZERO = "0:0"
+
+    def __new__(cls, value):
+        obj = object.__new__(cls)
+        obj._value_ = value.lower() if value is not None else value
+        return obj
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle None and case-insensitive values
+        if value is None:
+            return cls.UNSTRUCTURED
+        for member in cls:
+            if member.value == value.lower():
+                return member
+        raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
 class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
     Base data class for storing sparsity compression parameters
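The __new__ override lowercases values at member creation, and _missing_ covers None and mixed-case lookups, so sparsity structures read from configs normalize to the same members. A short usage sketch, assuming SparsityStructure is re-exported from compressed_tensors.config as the updated __all__ suggests:

from compressed_tensors.config import SparsityStructure

# case-insensitive lookup plus a None fallback, mirroring the doctests above
assert SparsityStructure("2:4") is SparsityStructure.TWO_FOUR
assert SparsityStructure("UNSTRUCTURED") is SparsityStructure.UNSTRUCTURED
assert SparsityStructure(None) is SparsityStructure.UNSTRUCTURED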
--- compressed-tensors-0.7.1/src/compressed_tensors/linear/compressed_linear.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/linear/compressed_linear.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict, Tuple
+
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ class CompressedLinear(Linear):
         )

         # get the shape and dtype of compressed parameters
-        compression_params = module.compressor.compression_param_info(
+        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
             module.weight.shape, quantization_scheme.weights
         )

--- compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/apply.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -22,13 +22,9 @@ from typing import Union

 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.calibration import (
-    set_module_for_calibration,
-)
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
 )
-from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -110,7 +106,8 @@ def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
 ) -> OrderedDict:
     """
-    Initializes the model for quantization in-place based on the given config
+    Initializes the model for quantization in-place based on the given config.
+    Optionally coverts quantizable modules to compressed_linear modules

     :param model: model to apply quantization config to
     :param config: quantization config
@@ -233,6 +230,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
+
     current_status = infer_quantization_status(model)

     if status >= QuantizationStatus.INITIALIZED > current_status:
@@ -243,18 +241,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )

-    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        # only quantize weights up front when our end goal state is calibration,
-        # weight quantization parameters are already loaded for frozen/compressed
-        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
-        model.apply(
-            lambda module: set_module_for_calibration(
-                module, quantize_weights_upfront=quantize_weights_upfront
-            )
-        )
-    if current_status < status >= QuantizationStatus.FROZEN > current_status:
-        model.apply(freeze_module_quantization)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)

--- compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/forward.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -14,14 +14,9 @@

 from functools import wraps
 from math import ceil
-from typing import
+from typing import Optional

 import torch
-from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import (
-    calculate_range,
-    compute_dynamic_scales_and_zp,
-)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -29,7 +24,11 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import
+from compressed_tensors.quantization.utils import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
+from compressed_tensors.utils import safe_permute
 from torch.nn import Module


@@ -38,7 +37,7 @@ __all__ = [
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
-    "
+    "forward_quantize",
 ]


@@ -275,15 +274,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED

         if scheme.input_activations is not None:
-            #
-            input_ =
-                module, input_, "input", scheme.input_activations
-            )
+            # prehook should calibrate activations before forward call
+            input_ = forward_quantize(module, input_, "input", scheme.input_activations)

         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data =
+            self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )

@@ -291,64 +288,23 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        if scheme.output_activations is not None:
-
-            # calibrate and (fake) quantize output activations when applicable
-            # kv_cache scales updated on model self_attn forward call in
-            # wrap_module_forward_quantized_attn
-            output = maybe_calibrate_or_quantize(
-                module, output, "output", scheme.output_activations
-            )

         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight

-
-
-
-
-
-
-
-
-
-
-    # initialize_module_for_quantization
-    if hasattr(module.forward, "__func__"):
-        forward_func_orig = module.forward.__func__
-    else:
-        forward_func_orig = module.forward.func
-
-    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
-    def wrapped_forward(self, *args, **kwargs):
-
-        # kv cache stored under weights
-        if module.quantization_status == QuantizationStatus.CALIBRATION:
-            quantization_args: QuantizationArgs = scheme.output_activations
-            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
-            kwargs["past_key_value"] = past_key_value
-
-            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
-            # does not store quantized_key_states and quantized_value_state
-            kwargs["use_cache"] = False
-
-            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
-
-            past_key_value.reset_states()
-
-            rtn = attn_forward(*args, **kwargs)
-
-            update_parameter_data(
-                module, past_key_value.k_scales[module.layer_idx], "k_scale"
-            )
-            update_parameter_data(
-                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+        if scheme.output_activations is not None:
+            # forward-hook should calibrate/forward_quantize
+            if (
+                module.quantization_status == QuantizationStatus.CALIBRATION
+                and not scheme.output_activations.dynamic
+            ):
+                return output
+
+            output = forward_quantize(
+                module, output, "output", scheme.output_activations
             )
-
-        return rtn
-
-        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+        return output

     # bind wrapped forward to module class so reference to `self` is correct
     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
@@ -356,12 +312,9 @@ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationSchem
     setattr(module, "forward", bound_wrapped_forward)


-def
+def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # don't run quantization if we haven't entered calibration mode
-    if module.quantization_status == QuantizationStatus.INITIALIZED:
-        return value

     # in compressed mode, the weight is already compressed and quantized so we don't
     # need to run fake quantization
@@ -379,29 +332,13 @@ def maybe_calibrate_or_quantize(
     g_idx = getattr(module, "weight_g_idx", None)

     if args.dynamic:
-        # dynamic quantization -
+        # dynamic quantization - determine the scale/zp on the fly
         scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
-        # static quantization - get
+        # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point", None)

-    if (
-        module.quantization_status == QuantizationStatus.CALIBRATION
-        and base_name != "weight"
-    ):
-        # calibration mode - get new quant params from observer
-        observer = getattr(module, f"{base_name}_observer")
-
-        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
-
-        # update scale and zero point
-        update_parameter_data(module, updated_scale, f"{base_name}_scale")
-        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
-
-        scale = updated_scale
-        zero_point = updated_zero_point
-
     return fake_quantize(
         x=value,
         scale=scale,
--- compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,13 +14,12 @@


 import logging
+from enum import Enum
 from typing import Optional

 import torch
-from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
-    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -36,12 +35,19 @@ from torch.nn import Module, Parameter

 __all__ = [
     "initialize_module_for_quantization",
+    "is_attention_module",
+    "KVCacheScaleType",
 ]


 _LOGGER = logging.getLogger(__name__)


+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
@@ -66,15 +72,13 @@ def initialize_module_for_quantization(
         return

     if is_attention_module(module):
-        # wrap forward call of module to perform
         # quantized actions based on calltime status
-        wrap_module_forward_quantized_attn(module, scheme)
         _initialize_attn_scales(module)

     else:

         if scheme.input_activations is not None:
-
+            _initialize_scale_zero_point(
                 module,
                 "input",
                 scheme.input_activations,
@@ -85,7 +89,7 @@ def initialize_module_for_quantization(
                 weight_shape = None
                 if isinstance(module, torch.nn.Linear):
                     weight_shape = module.weight.shape
-
+                _initialize_scale_zero_point(
                     module,
                     "weight",
                     scheme.weights,
@@ -101,7 +105,7 @@ def initialize_module_for_quantization(

         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
-
+                _initialize_scale_zero_point(
                     module, "output", scheme.output_activations
                 )

@@ -109,6 +113,7 @@ def initialize_module_for_quantization(
     module.quantization_status = QuantizationStatus.INITIALIZED

     offloaded = False
+    # What is this doing/why isn't this in the attn case?
     if is_module_offloaded(module):
         try:
             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
@@ -146,21 +151,21 @@ def initialize_module_for_quantization(
         module._hf_hook.weights_map = new_prefix_dict


-def
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    # no need to register an observer for dynamic quantization
-    if observer:
-        module.register_module(f"{base_name}_observer", observer)
-
-    # no need to register a scale and zero point for a dynamic quantization
     if quantization_args.dynamic:
         return

@@ -209,14 +214,6 @@ def _initialize_scale_zero_point_observer(
         module.register_parameter(f"{base_name}_g_idx", init_g_idx)


-def is_attention_module(module: Module):
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj")
-        or hasattr(module, "v_proj")
-        or hasattr(module, "qkv_proj")
-    )
-
-
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""

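is_attention_module is now part of the public __all__ and keys purely off the class name and the presence of projection attributes. A small self-contained sketch of that predicate; TinySelfAttention is a made-up stand-in module, and is_attention_module_sketch simply mirrors the function added above:

from torch.nn import Linear, Module


class TinySelfAttention(Module):
    # hypothetical attention block; only the class name and attribute names matter
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.k_proj = Linear(hidden, hidden)
        self.v_proj = Linear(hidden, hidden)


def is_attention_module_sketch(module: Module) -> bool:
    # mirrors the predicate moved into initialize.py above
    return "attention" in module.__class__.__name__.lower() and (
        hasattr(module, "k_proj")
        or hasattr(module, "v_proj")
        or hasattr(module, "qkv_proj")
    )


print(is_attention_module_sketch(TinySelfAttention()))  # True
print(is_attention_module_sketch(Linear(8, 8)))         # False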
--- compressed-tensors-0.7.1/src/compressed_tensors/quantization/quant_args.py
+++ compressed-tensors-0.8.1/src/compressed_tensors/quantization/quant_args.py
@@ -17,6 +17,7 @@ from enum import Enum
 from typing import Any, Dict, Optional, Union

 import torch
+from compressed_tensors.utils import Aliasable
 from pydantic import BaseModel, Field, field_validator, model_validator


@@ -53,17 +54,29 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"


-class ActivationOrdering(str, Enum):
+class ActivationOrdering(Aliasable, str, Enum):
     """
     Enum storing strategies for activation ordering

     Group: reorder groups and weight\n
-    Weight: only reorder weight, not groups. Slightly lower
-
+    Weight: only reorder weight, not groups. Slightly lower accuracy but also lower
+    latency when compared to group actorder\n
+    Dynamic: alias for Group\n
+    Static: alias for Weight\n
     """

     GROUP = "group"
     WEIGHT = "weight"
+    # aliases
+    DYNAMIC = "dynamic"
+    STATIC = "static"
+
+    @staticmethod
+    def get_aliases() -> Dict[str, str]:
+        return {
+            "dynamic": "group",
+            "static": "weight",
+        }


 class QuantizationArgs(BaseModel, use_enum_values=True):
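ActivationOrdering now mixes in Aliasable, a helper added in compressed_tensors/utils/helpers.py (+33 -1 in the file list) whose implementation is not shown in this diff. The sketch below is therefore only an assumption about the behavior the alias table implies (alias members compare equal to their targets); AliasableSketch and Ordering are stand-in names, not the library's classes.

from enum import Enum
from typing import Dict


class AliasableSketch:
    # illustrative stand-in for compressed_tensors.utils.Aliasable (assumed behavior)

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        raise NotImplementedError

    def __eq__(self, other):
        # resolve both sides through the alias table before comparing
        aliases = self.get_aliases()
        self_value = aliases.get(self.value, self.value)
        other_value = getattr(other, "value", other)
        other_value = aliases.get(other_value, other_value)
        return self_value == other_value

    def __hash__(self):
        aliases = self.get_aliases()
        return hash(aliases.get(self.value, self.value))


class Ordering(AliasableSketch, str, Enum):
    GROUP = "group"
    WEIGHT = "weight"
    DYNAMIC = "dynamic"  # alias for GROUP
    STATIC = "static"    # alias for WEIGHT

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        return {"dynamic": "group", "static": "weight"}


assert Ordering.DYNAMIC == Ordering.GROUP
assert Ordering.STATIC == Ordering.WEIGHT
assert Ordering.GROUP != Ordering.WEIGHT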
@@ -114,20 +127,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         :return: torch quantization FakeQuantize built based on these QuantizationArgs
         """
-
-
-        # No observer required for the dynamic case
-        if self.dynamic:
-            self.observer = None
-            return self.observer
-
-        return Observer.load_from_registry(self.observer, quantization_args=self)
-
-    def get_kv_cache(self):
-        """Get the singleton KV Cache"""
-        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-
-        return QuantizedKVParameterCache(self)
+        return self.observer

     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
@@ -210,6 +210,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                     "activation ordering"
                 )

+        # infer observer w.r.t. dynamic
         if dynamic:
             if strategy not in (
                 QuantizationStrategy.TOKEN,
@@ -221,18 +222,19 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                     "quantization",
                 )
             if observer is not None:
-
-
-
-
+                if observer != "memoryless":  # avoid annoying users with old configs
+                    warnings.warn(
+                        "No observer is used for dynamic quantization, setting to None"
+                    )
+                observer = None

-
-
-
-            model.observer = "minmax"
+        elif observer is None:
+            # default to minmax for non-dynamic cases
+            observer = "minmax"

         # write back modified values
         model.strategy = strategy
+        model.observer = observer
         return model

     def pytorch_dtype(self) -> torch.dtype:
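With get_kv_cache removed and get_observer reduced to returning the stored field, this validator is now the single place where the observer is inferred: dynamic schemes drop it (warning if one was set), and non-dynamic schemes default to "minmax". A small usage sketch of that behavior; the keyword arguments are ordinary QuantizationArgs fields, but the values shown in the comments are expectations read off the hunk above, not documented API guarantees.

from compressed_tensors.quantization import QuantizationArgs

# static (non-dynamic) quantization: the validator fills in a "minmax" observer
static_args = QuantizationArgs(num_bits=8, type="int", strategy="channel")
print(static_args.observer)   # expected: "minmax"

# dynamic quantization: the observer is dropped; scales/zero points are computed on the fly
dynamic_args = QuantizationArgs(num_bits=8, type="int", strategy="token", dynamic=True)
print(dynamic_args.observer)  # expected: None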