compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +137 -9
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/config/base.py +6 -1
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +63 -9
- compressed_tensors/quantization/lifecycle/calibration.py +7 -7
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +126 -44
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -20
- compressed_tensors/quantization/lifecycle/initialize.py +138 -55
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +102 -24
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +12 -13
- compressed_tensors/quantization/utils/helpers.py +44 -19
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +30 -1
- compressed_tensors/utils/offload.py +14 -2
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
- compressed_tensors-0.7.0.dist-info/RECORD +59 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/pack_quantized.py +0 -219
- compressed_tensors-0.5.0.dist-info/RECORD +0 -48
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0

compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py (moved from compressed_tensors/compressors/sparse_bitmask.py):

@@ -12,17 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import
-from typing import Dict, Generator, List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import numpy
 import torch
-from compressed_tensors.compressors import
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.utils import
-from safetensors import safe_open
+from compressed_tensors.utils import merge_names
 from torch import Tensor
-from tqdm import tqdm
 
 
 __all__ = [
@@ -34,11 +32,9 @@ __all__ = [
     "unpack_bitmasks",
 ]
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
 
-
-
-class BitmaskCompressor(Compressor):
+@BaseCompressor.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskCompressor(BaseSparseCompressor):
     """
     Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
     values tensor, with their locations stored in a 2d bitmask
@@ -46,56 +42,15 @@ class BitmaskCompressor(Compressor):
 
     COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"]
 
-    def
-
-
+    def compress_weight(self, name, value):
+        bitmask_tensor = BitmaskTensor.from_dense(value)
+        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+        return bitmask_dict
 
-
-
-
-
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            bitmask_tensor = BitmaskTensor.from_dense(value)
-            bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
-            for key in bitmask_dict.keys():
-                if key in compressed_dict:
-                    _LOGGER.warn(
-                        f"Expected all compressed state_dict keys to be unique, but "
-                        f"found an existing entry for {key}. The existing entry will "
-                        "be replaced."
-                    )
-            compressed_dict.update(bitmask_dict)
-
-        return compressed_dict
-
-    def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
-    ) -> Generator[Tuple[str, Tensor], None, None]:
-        """
-        Reads a bitmask compressed state dict located
-        at path_to_model_or_tensors and returns a generator
-        for sequentially decompressing back to a dense state dict
-
-        :param model_path: path to compressed safetensors model (directory with
-            one or more safetensors files) or compressed tensors file
-        :param device: device to load decompressed weights onto
-        :return: iterator for generating decompressed weights
-        """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
-        )
-        for weight_name in weight_mappings.keys():
-            weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
-            data = BitmaskTensor(**weight_data)
-            decompressed = data.decompress()
-            yield weight_name, decompressed
+    def decompress_weight(self, weight_data):
+        data = BitmaskTensor(**weight_data)
+        decompressed = data.decompress()
+        return decompressed
 
 
 class BitmaskTensor:
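
The hunk above removes the state-dict iteration, progress logging, and safetensors loading from `BitmaskCompressor`; those concerns now live in the shared `BaseSparseCompressor`, and a concrete sparse compressor only supplies the per-weight hooks `compress_weight` and `decompress_weight`. The stand-alone sketch below illustrates that split under invented names; it is not the actual `BaseSparseCompressor` from `compressed_tensors/compressors/sparse_compressors/base.py`.

```python
# Illustrative sketch only: a toy base class that owns iteration, and a subclass that
# handles one weight at a time, mirroring the compress_weight/decompress_weight split
# introduced above. Nothing here is the real compressed-tensors API.
from typing import Dict, Generator, Tuple

import torch
from torch import Tensor


class ToySparseCompressorBase:
    """Owns the state-dict loop; subclasses supply per-weight (de)compression."""

    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
        compressed = {}
        for name, value in model_state.items():
            compressed.update(self.compress_weight(name, value))
        return compressed

    def decompress(
        self, grouped: Dict[str, Dict[str, Tensor]]
    ) -> Generator[Tuple[str, Tensor], None, None]:
        for name, weight_data in grouped.items():
            yield name, self.decompress_weight(weight_data)

    def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
        raise NotImplementedError

    def decompress_weight(self, weight_data: Dict[str, Tensor]) -> Tensor:
        raise NotImplementedError


class ToyMaskCompressor(ToySparseCompressorBase):
    """Stores each weight as its non-zero values plus a boolean mask."""

    def compress_weight(self, name, value):
        mask = value != 0
        return {f"{name}.values": value[mask], f"{name}.mask": mask}

    def decompress_weight(self, weight_data):
        dense = torch.zeros(
            weight_data["mask"].shape, dtype=weight_data["values"].dtype
        )
        dense[weight_data["mask"]] = weight_data["values"]
        return dense


if __name__ == "__main__":
    state = {"layer.weight": torch.tensor([[0.0, 1.5], [2.0, 0.0]])}
    compressor = ToyMaskCompressor()
    packed = compressor.compress(state)
    restored = dict(
        compressor.decompress(
            {
                "layer.weight": {
                    "values": packed["layer.weight.values"],
                    "mask": packed["layer.weight.mask"],
                }
            }
        )
    )
    assert torch.equal(restored["layer.weight"], state["layer.weight"])
```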

compressed_tensors/compressors/sparse_quantized_compressors/__init__.py (new file):

@@ -0,0 +1,16 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+from .marlin_24 import Marlin24Compressor

compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py (moved from compressed_tensors/compressors/marlin_24.py):

@@ -17,7 +17,7 @@ from typing import Dict, Generator, Tuple
 
 import numpy as np
 import torch
-from compressed_tensors.compressors import
+from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import quantize
@@ -35,8 +35,8 @@ from tqdm import tqdm
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-@
-class Marlin24Compressor(
+@BaseCompressor.register(name=CompressionFormat.marlin_24.value)
+class Marlin24Compressor(BaseCompressor):
     """
     Compresses a quantized model with 2:4 sparsity structure for inference with the
     Marlin24 kernel. Decompression is not implemented for this compressor.
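
Both compressors above now attach themselves to a registry keyed by `CompressionFormat` names via `@BaseCompressor.register(...)`, and callers resolve them with `BaseCompressor.load_from_registry(...)` (as `CompressedLinear.from_linear` does later in this diff). A minimal lookup sketch follows, assuming the registered compressors need no extra constructor arguments:

```python
# Hedged sketch of registry lookup; anything beyond the format name passed to
# load_from_registry is an assumption, not something this diff spells out.
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.config import CompressionFormat

# Resolve a compressor instance by the same name it was registered under.
bitmask_compressor = BaseCompressor.load_from_registry(
    CompressionFormat.sparse_bitmask.value
)
print(type(bitmask_compressor).__name__)  # expected: BitmaskCompressor
```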

compressed_tensors/config/base.py:

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Optional
+from typing import List, Optional
 
 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
@@ -37,11 +37,16 @@ class SparsityCompressionConfig(RegistryMixin, BaseModel):
     Base data class for storing sparsity compression parameters
 
     :param format: name of compression format
+    :param targets: List of layer names or layer types that aren't sparse and should
+        be ignored during compression. By default, assume all layers are targeted
+    :param ignore: List of layer names (unique) to ignore from targets. Defaults to None
     :param global_sparsity: average sparsity of the entire model
     :param sparsity_structure: structure of the sparsity, such as
         "unstructured", "2:4", "8:16" etc
     """
 
     format: str
+    targets: Optional[List[str]] = None
+    ignore: Optional[List[str]] = None
     global_sparsity: Optional[float] = 0.0
     sparsity_structure: Optional[str] = "unstructured"
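
The new `targets` and `ignore` entries are optional pydantic fields, so existing configs remain valid. A small construction sketch follows; the class is a pydantic `BaseModel`, and the field values here are invented purely for illustration:

```python
# Hedged sketch: the field names come straight from the hunk above; the values are
# invented for illustration and carry no particular recommendation.
from compressed_tensors.config import CompressionFormat
from compressed_tensors.config.base import SparsityCompressionConfig

config = SparsityCompressionConfig(
    format=CompressionFormat.sparse_bitmask.value,
    targets=["Linear"],
    ignore=["lm_head"],
    global_sparsity=0.5,
    sparsity_structure="2:4",
)
print(config.model_dump())
```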

compressed_tensors/linear/__init__.py (new file):

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

compressed_tensors/linear/compressed_linear.py (new file):

@@ -0,0 +1,87 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStatus,
+    initialize_module_for_quantization,
+)
+from torch import Tensor
+from torch.nn import Parameter
+from torch.nn.functional import linear
+from torch.nn.modules import Linear
+
+
+class CompressedLinear(Linear):
+    """
+    Wrapper module for running a compressed forward pass of a quantized Linear module.
+    The wrapped layer will decompressed on each forward call.
+
+    :param module: dense linear module to replace
+    :param quantization_scheme: quantization config for the module to wrap
+    :param quantization_format: compression format module is stored as
+    """
+
+    @classmethod
+    @torch.no_grad()
+    def from_linear(
+        cls,
+        module: Linear,
+        quantization_scheme: QuantizationScheme,
+        quantization_format: str,
+    ):
+        module.__class__ = CompressedLinear
+        module.compressor = BaseCompressor.load_from_registry(quantization_format)
+        device = next(module.parameters()).device
+
+        # this will initialize all the scales and zero points
+        initialize_module_for_quantization(
+            module, quantization_scheme, force_zero_point=False
+        )
+
+        # get the shape and dtype of compressed parameters
+        compression_params = module.compressor.compression_param_info(
+            module.weight.shape, quantization_scheme.weights
+        )
+
+        # no need for this once quantization is initialized, will be replaced
+        # with the compressed parameter
+        delattr(module, "weight")
+
+        # populate compressed weights and quantization parameters
+        for name, (shape, dtype) in compression_params.items():
+            param = Parameter(
+                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+            )
+            module.register_parameter(name, param)
+
+        # mark module as compressed
+        module.quantization_status = QuantizationStatus.COMPRESSED
+
+        # handles case where forward is wrapped in new_forward by accelerate hooks
+        if hasattr(module, "_old_forward"):
+            module._old_forward = CompressedLinear.forward.__get__(
+                module, CompressedLinear
+            )
+
+        return module
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Decompresses the weight, then runs the wrapped forward pass
+        """
+        uncompressed_weight = self.compressor.decompress_module(self)
+        return linear(input, uncompressed_weight, self.bias)
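
`CompressedLinear.from_linear` rebinds an existing `torch.nn.Linear` in place: it drops the dense `weight`, registers the compressed parameters reported by the format's compressor, and swaps the forward pass for one that decompresses on every call. The sketch below is a hedged illustration of that call path only; the `QuantizationScheme`/`QuantizationArgs` constructor fields are assumed from the wider library, and the freshly registered compressed parameters are uninitialized, so the output values are meaningless.

```python
# Hedged sketch: wraps a plain Linear so its forward pass decompresses on the fly.
# The scheme fields below (targets, weights, num_bits) are assumed API, and because
# the compressed parameters are never filled with real data, the output is garbage;
# this only illustrates how from_linear and forward are wired together.
import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

layer = torch.nn.Linear(512, 512)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4),
)

compressed = CompressedLinear.from_linear(
    layer,
    quantization_scheme=scheme,
    quantization_format=CompressionFormat.pack_quantized.value,
)

out = compressed(torch.randn(1, 512))  # decompress_module(...) then F.linear
print(out.shape)
```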

compressed_tensors/quantization/cache.py (new file):

@@ -0,0 +1,201 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from compressed_tensors.quantization.observers import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import Tensor
+from transformers import DynamicCache as HFDyanmicCache
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+class QuantizedKVParameterCache(HFDyanmicCache):
+
+    """
+    Quantized KV cache used in the forward call based on HF's dynamic cache.
+    Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
+    Singleton, so that the same cache gets reused in all forward call of self_attn.
+    Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
+    gets called appropriately.
+    The size of tensor is
+    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+
+
+    Triggered by adding kv_cache_scheme in the recipe.
+
+    Example:
+
+    ```python3
+    recipe = '''
+    quant_stage:
+        quant_modifiers:
+            QuantizationModifier:
+                kv_cache_scheme:
+                    num_bits: 8
+                    type: float
+                    strategy: tensor
+                    dynamic: false
+                    symmetric: true
+    '''
+
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        """Singleton"""
+        if cls._instance is None:
+            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        if not self._initialized:
+            super().__init__()
+
+            self.quantization_args = quantization_args
+
+            self.k_observers: List[Observer] = []
+            self.v_observers: List[Observer] = []
+
+            # each index corresponds to layer_idx of the attention layer
+            self.k_scales: List[Tensor] = []
+            self.v_scales: List[Tensor] = []
+
+            self.k_zps: List[Tensor] = []
+            self.v_zps: List[Tensor] = []
+
+            self._initialized = True
+
+    def update(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the k_scale and v_scale and output the
+        fakequant-ed key_states and value_states
+        """
+
+        if len(self.k_observers) <= layer_idx:
+            k_observer = self.quantization_args.get_observer()
+            v_observer = self.quantization_args.get_observer()
+
+            self.k_observers.append(k_observer)
+            self.v_observers.append(v_observer)
+
+        q_key_states = self._quantize(
+            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
+        )
+        q_value_states = self._quantize(
+            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
+        )
+
+        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
+        qdq_value_states = self._dequantize(
+            q_value_states, KVCacheScaleType.VALUE, layer_idx
+        )
+
+        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """
+        Returns the sequence length of the cached states.
+        A layer index can be optionally passed.
+        """
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and
+        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
+        # this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to
+        # verify attn_weight shape in some models
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def reset_states(self):
+        """reset the kv states (used in calibration)"""
+        self.key_cache: List[Tensor] = []
+        self.value_cache: List[Tensor] = []
+        # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0
+        self._quantized_key_cache: List[Tensor] = []
+        self._quantized_value_cache: List[Tensor] = []
+
+    def reset(self):
+        """
+        Reset the instantiation, create new instance on init
+        """
+        QuantizedKVParameterCache._instance = None
+        QuantizedKVParameterCache._initialized = False
+
+    def _quantize(self, tensor, kv_type, layer_idx):
+        """Quantizes a key/value using a defined quantization method."""
+        from compressed_tensors.quantization.lifecycle.forward import quantize
+
+        if kv_type == KVCacheScaleType.KEY:  # key type
+            observer = self.k_observers[layer_idx]
+            scales = self.k_scales
+            zps = self.k_zps
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            observer = self.v_observers[layer_idx]
+            scales = self.v_scales
+            zps = self.v_zps
+
+        scale, zp = observer(tensor)
+        if len(scales) <= layer_idx:
+            scales.append(scale)
+            zps.append(zp)
+        else:
+            scales[layer_idx] = scale
+            zps[layer_idx] = scale
+
+        q_tensor = quantize(
+            x=tensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return q_tensor
+
+    def _dequantize(self, qtensor, kv_type, layer_idx):
+        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+        from compressed_tensors.quantization.lifecycle.forward import dequantize
+
+        if kv_type == KVCacheScaleType.KEY:
+            scale = self.k_scales[layer_idx]
+            zp = self.k_zps[layer_idx]
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            scale = self.v_scales[layer_idx]
+            zp = self.v_zps[layer_idx]
+
+        qdq_tensor = dequantize(
+            x_q=qtensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return qdq_tensor
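
`QuantizedKVParameterCache` is a process-wide singleton layered on Hugging Face's `DynamicCache`: each `update()` observes the incoming key/value states, refreshes the per-layer scales and zero points, and returns fake-quantized tensors. A hedged usage sketch follows, assuming default `QuantizationArgs` behaviour (min-max observer, per-tensor strategy) and invented tensor sizes in the `[batch, heads, seq_len, head_dim]` layout named in the docstring:

```python
# Hedged sketch: constructor/update signatures come from the file above; the
# QuantizationArgs fields and tensor sizes are assumptions for illustration.
import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.cache import QuantizedKVParameterCache

args = QuantizationArgs(num_bits=8, symmetric=True)
cache = QuantizedKVParameterCache(args)  # singleton: later constructions return this object

keys = torch.randn(1, 8, 16, 64)    # [batch, heads, seq_len, head_dim]
values = torch.randn(1, 8, 16, 64)
qdq_keys, qdq_values = cache.update(keys, values, layer_idx=0)

assert cache is QuantizedKVParameterCache(args)  # same instance reused across layers
cache.reset()  # drop the singleton so the next construction starts fresh
```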

compressed_tensors/quantization/lifecycle/apply.py:

@@ -14,12 +14,14 @@
 
 import logging
 import re
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
 from typing import Union
 
 import torch
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
@@ -41,8 +43,9 @@ from compressed_tensors.quantization.utils import (
     infer_quantization_status,
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
+    iter_named_quantizable_modules,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
@@ -103,13 +106,25 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
     )
 
 
-def apply_quantization_config(
+def apply_quantization_config(
+    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config
 
     :param model: model to apply quantization config to
     :param config: quantization config
+    :param run_compressed: Whether the model will be run in compressed mode or
+        decompressed fully on load
     """
+    # Workaround for when HF Quantizer passes None, see PR #180
+    if config is None:
+        return OrderedDict()
+
+    # remove reference to the original `config`
+    # argument. This function can mutate it, and we'd
+    # like to keep the original `config` as it is.
+    config = deepcopy(config)
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
@@ -119,21 +134,47 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    if run_compressed:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
     # list of submodules to ignore
-    ignored_submodules =
+    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in
+    for name, submodule in iter_named_quantizable_modules(
+        model,
+        include_children=True,
+        include_attn=True,
+    ):  # child modules and attention modules
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
-        if find_name_or_class_matches(name, submodule, config.ignore):
-
+        if matches := find_name_or_class_matches(name, submodule, config.ignore):
+            for match in matches:
+                ignored_submodules[match].append(name)
            continue  # layer matches ignore list, continue
+
        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+
        if targets:
+            # mark modules to be quantized by adding
+            # quant scheme to the matching layers
+            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        # TODO: expand to more module types
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
            # target matched - add layer and scheme to target list
            submodule.quantization_scheme = _scheme_from_targets(
                target_to_scheme, targets, name
            )
+
            names_to_scheme[name] = submodule.quantization_scheme.weights
 
     if config.ignore is not None and ignored_submodules is not None:
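
With this change, `apply_quantization_config` tolerates `config=None` (returning an empty `OrderedDict`), deep-copies the config before mutating it, and, when `run_compressed=True` and the format is not dense, swaps matching `torch.nn.Linear` layers for `CompressedLinear`. The sketch below is a hedged end-to-end illustration; the `QuantizationConfig`/`QuantizationScheme` constructor fields are assumed from the wider library rather than spelled out in this hunk:

```python
# Hedged sketch: module layout and config field names are assumptions for illustration.
from collections import OrderedDict

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
)
from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config

model = torch.nn.Sequential(
    OrderedDict(fc1=torch.nn.Linear(64, 64), fc2=torch.nn.Linear(64, 8))
)

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    },
    ignore=["fc2"],
)

# None is now tolerated and simply returns an empty OrderedDict (HF Quantizer workaround).
assert apply_quantization_config(model, None) == OrderedDict()

# The default format is dense, so no CompressedLinear swap happens here even with
# run_compressed=True; the scheme is just attached to the matching layers.
names_to_scheme = apply_quantization_config(model, config)
print(list(names_to_scheme))  # expected: ["fc1"]
```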
@@ -143,8 +184,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
             "not found in the model: "
             f"{set(config.ignore) - set(ignored_submodules)}"
         )
-    # apply current quantization status across all targeted layers
 
+    # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
 
@@ -172,6 +213,9 @@ def process_kv_cache_config(
     :param config: the QuantizationConfig
     :return: the QuantizationConfig with additional "kv_cache" group
     """
+    if targets == KV_CACHE_TARGETS:
+        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
+
     kv_cache_dict = config.kv_cache_scheme.model_dump()
     kv_cache_scheme = QuantizationScheme(
         output_activations=QuantizationArgs(**kv_cache_dict),
@@ -192,7 +236,12 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
-
+        force_zero_point_init = status != QuantizationStatus.COMPRESSED
+        model.apply(
+            lambda module: initialize_module_for_quantization(
+                module, force_zero_point=force_zero_point_init
+            )
+        )
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         # only quantize weights up front when our end goal state is calibration,
@@ -273,9 +322,11 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
+    g_idx_name = f"{base_name}_g_idx"
 
     state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
     state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+    state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)
 
     if state_dict_scale is not None:
         # module is quantized
@@ -285,6 +336,9 @@ def _load_quant_args_from_state_dict(
         state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
         update_parameter_data(module, state_dict_zp, zp_name)
 
+    if state_dict_g_idx is not None:
+        update_parameter_data(module, state_dict_g_idx, g_idx_name)
+
 
 
 def _scheme_from_targets(

compressed_tensors/quantization/lifecycle/calibration.py:

@@ -36,15 +36,15 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
     apply to full model with `model.apply(set_module_for_calibration)`
 
     :param module: module to set for calibration
-    :param quantize_weights_upfront: whether to automatically
-
+    :param quantize_weights_upfront: whether to automatically
+        run weight quantization at the start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
         return
     status = getattr(module, "quantization_status", None)
     if not status or status != QuantizationStatus.INITIALIZED:
-
+        _LOGGER.warning(
             f"Attempting set module with status {status} to calibration mode. "
             f"but status is not {QuantizationStatus.INITIALIZED} - you may "
             "be calibrating an uninitialized module which may fail or attempting "
@@ -54,13 +54,13 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
     if quantize_weights_upfront and module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
         observer = module.weight_observer
+        g_idx = getattr(module, "weight_g_idx", None)
 
-        offloaded =
-        if
+        offloaded = is_module_offloaded(module)
+        if offloaded:
             module._hf_hook.pre_forward(module)
-            offloaded = True
 
-        scale, zero_point = observer(module.weight)
+        scale, zero_point = observer(module.weight, g_idx=g_idx)
         update_parameter_data(module, scale, "weight_scale")
         update_parameter_data(module, zero_point, "weight_zero_point")
 