compressed-tensors 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries. Removed lines whose content was truncated in the rendered diff are shown below as "…".
- compressed_tensors/compressors/base.py +200 -8
- compressed_tensors/compressors/model_compressor.py +68 -1
- compressed_tensors/compressors/naive_quantized.py +71 -75
- compressed_tensors/compressors/pack_quantized.py +83 -94
- compressed_tensors/config/base.py +6 -1
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/apply.py +46 -8
- compressed_tensors/quantization/lifecycle/calibration.py +5 -4
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +76 -43
- compressed_tensors/quantization/lifecycle/helpers.py +29 -2
- compressed_tensors/quantization/lifecycle/initialize.py +51 -16
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +96 -24
- compressed_tensors/quantization/quant_scheme.py +7 -9
- compressed_tensors/quantization/utils/helpers.py +1 -1
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +13 -0
- compressed_tensors/utils/offload.py +14 -2
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -23
- compressed_tensors-0.6.0.dist-info/RECORD +52 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
- compressed_tensors-0.5.0.dist-info/RECORD +0 -48
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/pack_quantized.py

@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import logging
 import math
-from typing import Dict, Generator, Tuple
+from typing import Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -23,16 +21,11 @@ from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
 from compressed_tensors.quantization.utils import can_quantize
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
-from safetensors import safe_open
 from torch import Tensor
-from tqdm import tqdm
 
 
 __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 @Compressor.register(name=CompressionFormat.pack_quantized.value)
 class PackedQuantizationCompressor(Compressor):
@@ -44,102 +37,96 @@ class PackedQuantizationCompressor(Compressor):
         "weight_packed",
         "weight_scale",
         "weight_zero_point",
+        "weight_g_idx",
         "weight_shape",
     ]
 
-    def compress(
+    def compression_param_info(
         self,
-        …
-        …
-        …
-    ) -> Dict[str, Tensor]:
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
         """
-        …
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
 
-        :param …
-        :param …
-        …
-        …
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        pack_factor = 32 // quantization_args.num_bits
+        packed_size = math.ceil(weight_shape[1] / pack_factor)
+        return {
+            "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
+            "weight_shape": (torch.Size((2,)), torch.int32),
+        }
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
         """
         compressed_dict = {}
-        …
-                        scale=scale,
-                        zero_point=zp,
-                        args=quant_args,
-                        dtype=torch.int8,
-                    )
-                    value = pack_to_int32(value.cpu(), quant_args.num_bits)
-                    compressed_dict[merge_names(prefix, "weight_shape")] = shape
-                    compressed_dict[merge_names(prefix, "weight_packed")] = value
-                    continue
-
-            elif name.endswith("zero_point"):
-                if torch.all(value == 0):
-                    # all zero_points are 0, no need to include in
-                    # compressed state_dict
-                    continue
-
-            compressed_dict[name] = value.to("cpu")
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=torch.int8,
+            )
+
+        packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+        weight_shape = torch.tensor(weight.shape)
+        if device is not None:
+            packed_weight = packed_weight.to(device)
+            weight_shape = weight_shape.to(device)
+
+        compressed_dict["weight_shape"] = weight_shape
+        compressed_dict["weight_packed"] = packed_weight
 
         return compressed_dict
 
-    def decompress(
+    def decompress_weight(
         self,
-        …
-        …
-        …
-    ) -> Generator[Tuple[str, Tensor], None, None]:
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
         """
-        …
-        …
-        …
-        …
-        :param …
-            one or more safetensors files) or compressed tensors file
-        :param device: optional device to load intermediate weights into
-        :return: compressed state dict
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
         """
-        …
-        …
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        original_shape = torch.Size(compressed_data["weight_shape"])
+        num_bits = quantization_args.num_bits
+        unpacked = unpack_from_int32(weight, num_bits, original_shape)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
         )
-        …
-        …
-        for param_name, safe_path in weight_mappings[weight_name].items():
-            weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
-            full_name = merge_names(weight_name, param_name)
-            with safe_open(safe_path, framework="pt", device=device) as f:
-                weight_data[param_name] = f.get_tensor(full_name)
-
-        if "weight_scale" in weight_data:
-            zero_point = weight_data.get("weight_zero_point", None)
-            scale = weight_data["weight_scale"]
-            weight = weight_data["weight_packed"]
-            num_bits = weight_data["num_bits"]
-            original_shape = torch.Size(weight_data["weight_shape"])
-            unpacked = unpack_from_int32(weight, num_bits, original_shape)
-            decompressed = dequantize(
-                x_q=unpacked,
-                scale=scale,
-                zero_point=zero_point,
-            )
-            yield merge_names(weight_name, "weight"), decompressed
+
+        return decompressed_weight
 
 
 def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
@@ -197,13 +184,15 @@ def unpack_from_int32(
     if num_bits > 8:
         raise ValueError("Unpacking is only supported for less than 8 bits")
 
-    # convert packed input to unsigned numpy
-    value = value.numpy().view(np.uint32)
     pack_factor = 32 // num_bits
 
    # unpack
     mask = pow(2, num_bits) - 1
-    unpacked = …
+    unpacked = torch.zeros(
+        (value.shape[0], value.shape[1] * pack_factor),
+        device=value.device,
+        dtype=torch.int32,
+    )
     for i in range(pack_factor):
         unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
 
@@ -214,6 +203,6 @@ def unpack_from_int32(
     # bits are packed in unsigned format, reformat to signed
     # update the value range from unsigned to signed
     offset = pow(2, num_bits) // 2
-    unpacked = (unpacked …
+    unpacked = (unpacked - offset).to(torch.int8)
 
-    return …
+    return unpacked
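A note on the scheme these helpers implement: pack_to_int32 shifts signed values into the unsigned range, then stores 32 // num_bits of them per 32-bit word, interleaved with stride pack_factor; unpack_from_int32 reverses both steps, and as of 0.6.0 stays in torch instead of round-tripping through numpy. A self-contained sketch of the same layout, with helper names that are illustrative and not part of the package API:

    import torch

    def pack_rows_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
        # shift signed values into the unsigned range, e.g. [-8, 7] -> [0, 15] for 4 bits
        offset = 1 << (num_bits - 1)
        value = value.to(torch.int32) + offset
        pack_factor = 32 // num_bits  # e.g. 8 values per int32 word for 4-bit weights
        # assumes the column count divides evenly; the library pads via math.ceil
        packed = torch.zeros(
            (value.shape[0], value.shape[1] // pack_factor), dtype=torch.int32
        )
        for i in range(pack_factor):
            # columns with j % pack_factor == i land at bit offset num_bits * i
            # of word j // pack_factor
            packed |= value[:, i::pack_factor] << (num_bits * i)
        return packed

    def unpack_rows_from_int32(
        packed: torch.Tensor, num_bits: int, original_shape: torch.Size
    ) -> torch.Tensor:
        pack_factor = 32 // num_bits
        mask = (1 << num_bits) - 1
        unpacked = torch.zeros(
            (packed.shape[0], packed.shape[1] * pack_factor), dtype=torch.int32
        )
        for i in range(pack_factor):
            unpacked[:, i::pack_factor] = (packed >> (num_bits * i)) & mask
        # trim any padding columns, then shift back into the signed range
        unpacked = unpacked[:, : original_shape[1]]
        return (unpacked - (1 << (num_bits - 1))).to(torch.int8)

    w_q = torch.randint(-8, 8, (4, 16), dtype=torch.int8)  # 4-bit signed range
    packed = pack_rows_to_int32(w_q, num_bits=4)           # shape (4, 2)
    assert torch.equal(unpack_rows_from_int32(packed, 4, w_q.shape), w_q)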
compressed_tensors/config/base.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Optional
+from typing import List, Optional
 
 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
@@ -37,11 +37,16 @@ class SparsityCompressionConfig(RegistryMixin, BaseModel):
     Base data class for storing sparsity compression parameters
 
     :param format: name of compression format
+    :param targets: List of layer names or layer types that aren't sparse and should
+        be ignored during compression. By default, assume all layers are targeted
+    :param ignore: List of layer names (unique) to ignore from targets. Defaults to None
     :param global_sparsity: average sparsity of the entire model
     :param sparsity_structure: structure of the sparsity, such as
         "unstructured", "2:4", "8:16" etc
     """
 
     format: str
+    targets: Optional[List[str]] = None
+    ignore: Optional[List[str]] = None
     global_sparsity: Optional[float] = 0.0
     sparsity_structure: Optional[str] = "unstructured"
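Since the new fields are ordinary pydantic attributes, older configs stay valid and new ones can scope compression to specific layers. An illustrative construction (the field values here are hypothetical):

    from compressed_tensors.config import SparsityCompressionConfig

    config = SparsityCompressionConfig(
        format="sparse-bitmask",
        targets=["Linear"],   # layer types to compress
        ignore=["lm_head"],   # layer names to skip
        global_sparsity=0.5,
        sparsity_structure="2:4",
    )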
compressed_tensors/linear/__init__.py

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
compressed_tensors/linear/compressed_linear.py

@@ -0,0 +1,87 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.compressors.base import Compressor
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStatus,
+    initialize_module_for_quantization,
+)
+from torch import Tensor
+from torch.nn import Parameter
+from torch.nn.functional import linear
+from torch.nn.modules import Linear
+
+
+class CompressedLinear(Linear):
+    """
+    Wrapper module for running a compressed forward pass of a quantized Linear module.
+    The wrapped layer will decompressed on each forward call.
+
+    :param module: dense linear module to replace
+    :param quantization_scheme: quantization config for the module to wrap
+    :param quantization_format: compression format module is stored as
+    """
+
+    @classmethod
+    @torch.no_grad()
+    def from_linear(
+        cls,
+        module: Linear,
+        quantization_scheme: QuantizationScheme,
+        quantization_format: str,
+    ):
+        module.__class__ = CompressedLinear
+        module.compressor = Compressor.load_from_registry(quantization_format)
+        device = next(module.parameters()).device
+
+        # this will initialize all the scales and zero points
+        initialize_module_for_quantization(
+            module, quantization_scheme, force_zero_point=False
+        )
+
+        # get the shape and dtype of compressed parameters
+        compression_params = module.compressor.compression_param_info(
+            module.weight.shape, quantization_scheme.weights
+        )
+
+        # no need for this once quantization is initialized, will be replaced
+        # with the compressed parameter
+        delattr(module, "weight")
+
+        # populate compressed weights and quantization parameters
+        for name, (shape, dtype) in compression_params.items():
+            param = Parameter(
+                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+            )
+            module.register_parameter(name, param)
+
+        # mark module as compressed
+        module.quantization_status = QuantizationStatus.COMPRESSED
+
+        # handles case where forward is wrapped in new_forward by accelerate hooks
+        if hasattr(module, "_old_forward"):
+            module._old_forward = CompressedLinear.forward.__get__(
+                module, CompressedLinear
+            )
+
+        return module
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Decompresses the weight, then runs the wrapped forward pass
+        """
+        uncompressed_weight = self.compressor.decompress_module(self)
+        return linear(input, uncompressed_weight, self.bias)
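A sketch of using the new module directly; the scheme below is illustrative, and in practice apply_quantization_config (next file) performs the swap while a compressed checkpoint supplies the parameter values. Note that from_linear mutates the module in place (reassigning __class__) rather than allocating a new one, which keeps existing references and accelerate hooks valid:

    import torch
    from compressed_tensors.linear.compressed_linear import CompressedLinear
    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
    )
    layer = torch.nn.Linear(512, 256, bias=False)
    compressed = CompressedLinear.from_linear(
        layer, quantization_scheme=scheme, quantization_format="pack-quantized"
    )
    # weight_packed / weight_scale are allocated but uninitialized here;
    # loading a compressed state_dict fills them in
    out = compressed(torch.randn(1, 512))  # decompresses on the fly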
compressed_tensors/quantization/lifecycle/apply.py

@@ -14,12 +14,14 @@
 
 import logging
 import re
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
 from typing import Union
 
 import torch
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
@@ -42,7 +44,7 @@ from compressed_tensors.quantization.utils import (
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
@@ -103,13 +105,21 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
         )
 
 
-def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
+def apply_quantization_config(
+    model: Module, config: QuantizationConfig, run_compressed: bool = False
+) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config
 
     :param model: model to apply quantization config to
     :param config: quantization config
+    :param run_compressed: Whether the model will be run in compressed mode or
+        decompressed fully on load
     """
+    # remove reference to the original `config`
+    # argument. This function can mutate it, and we'd
+    # like to keep the original `config` as it is.
+    config = deepcopy(config)
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
@@ -119,21 +129,39 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    if run_compressed:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
     # list of submodules to ignore
-    ignored_submodules = []
+    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in iter_named_leaf_modules(model):
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
-        if find_name_or_class_matches(name, submodule, config.ignore):
-            ignored_submodules.append(name)
+        if matches := find_name_or_class_matches(name, submodule, config.ignore):
+            for match in matches:
+                ignored_submodules[match].append(name)
             continue  # layer matches ignore list, continue
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
         if targets:
+            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        # TODO: expand to more module types
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = _scheme_from_targets(
                 target_to_scheme, targets, name
             )
+
             names_to_scheme[name] = submodule.quantization_scheme.weights
 
     if config.ignore is not None and ignored_submodules is not None:
@@ -143,8 +171,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
             "not found in the model: "
             f"{set(config.ignore) - set(ignored_submodules)}"
         )
-    # apply current quantization status across all targeted layers
 
+    # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
 
@@ -192,7 +220,12 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
-        model.apply(initialize_module_for_quantization)
+        force_zero_point_init = status != QuantizationStatus.COMPRESSED
+        model.apply(
+            lambda module: initialize_module_for_quantization(
+                module, force_zero_point=force_zero_point_init
+            )
+        )
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         # only quantize weights up front when our end goal state is calibration,
@@ -273,9 +306,11 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
+    g_idx_name = f"{base_name}_g_idx"
 
     state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
     state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+    state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)
 
     if state_dict_scale is not None:
         # module is quantized
@@ -285,6 +320,9 @@ def _load_quant_args_from_state_dict(
             state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
         update_parameter_data(module, state_dict_zp, zp_name)
 
+    if state_dict_g_idx is not None:
+        update_parameter_data(module, state_dict_g_idx, g_idx_name)
+
 
 def _scheme_from_targets(
     target_to_scheme: OrderedDictType[str, QuantizationScheme],
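Putting the apply.py changes together, run_compressed=True swaps every targeted Linear whose format is not dense for a CompressedLinear at load time. A hedged sketch of the call pattern (the config here is illustrative; real ones come from a checkpoint's quantization_config):

    import torch
    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        apply_quantization_config,
    )

    model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(num_bits=4, strategy="group", group_size=32),
            )
        },
        format="pack-quantized",
        quantization_status="compressed",
    )

    # matching Linear layers are replaced with CompressedLinear in-place
    names_to_scheme = apply_quantization_config(model, config, run_compressed=True)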
compressed_tensors/quantization/lifecycle/calibration.py

@@ -36,15 +36,15 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
     apply to full model with `model.apply(set_module_for_calibration)`
 
     :param module: module to set for calibration
-    :param quantize_weights_upfront: whether to automatically
-        …
+    :param quantize_weights_upfront: whether to automatically
+        run weight quantization at the start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
         return
     status = getattr(module, "quantization_status", None)
     if not status or status != QuantizationStatus.INITIALIZED:
-        …
+        _LOGGER.warning(
             f"Attempting set module with status {status} to calibration mode. "
             f"but status is not {QuantizationStatus.INITIALIZED} - you may "
             "be calibrating an uninitialized module which may fail or attempting "
@@ -54,13 +54,14 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
     if quantize_weights_upfront and module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
         observer = module.weight_observer
+        g_idx = getattr(module, "weight_g_idx", None)
 
         offloaded = False
         if is_module_offloaded(module):
             module._hf_hook.pre_forward(module)
             offloaded = True
 
-        scale, zero_point = observer(module.weight)
+        scale, zero_point = observer(module.weight, g_idx=g_idx)
         update_parameter_data(module, scale, "weight_scale")
         update_parameter_data(module, zero_point, "weight_zero_point")
 
compressed_tensors/quantization/lifecycle/compressed.py

@@ -49,8 +49,9 @@ def compress_quantized_weights(module: Module):
     weight = getattr(module, "weight", None)
     scale = getattr(module, "weight_scale", None)
     zero_point = getattr(module, "weight_zero_point", None)
+    g_idx = getattr(module, "weight_g_idx", None)
 
-    if weight is None or scale is None or zero_point is None:
+    if weight is None or scale is None:
         # no weight, scale, or ZP, nothing to do
 
         # mark as compressed here to maintain consistent status throughout the model
@@ -62,6 +63,7 @@ def compress_quantized_weights(module: Module):
         x=weight,
         scale=scale,
         zero_point=zero_point,
+        g_idx=g_idx,
         args=scheme.weights,
         dtype=torch.int8,
     )
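The g_idx threaded through quantize, dequantize, calibration, and compression above is a per-column group index for group quantization: rather than assuming column j belongs to group j // group_size, the group is looked up explicitly, which is what permits GPTQ-style activation-order column permutations. A minimal sketch of the semantics (tensor names are illustrative, not the library's API):

    import torch

    rows, cols, group_size = 2, 8, 4
    w_q = torch.randint(-8, 8, (rows, cols), dtype=torch.int8)
    scales = torch.tensor([[0.1, 0.2], [0.3, 0.4]])  # one scale per (row, group)

    # default mapping: columns 0-3 -> group 0, columns 4-7 -> group 1
    g_idx = torch.arange(cols) // group_size
    # an activation-ordered layout just stores a different mapping; scales stay put
    g_idx_permuted = g_idx[torch.randperm(cols)]

    # symmetric dequantize: w[r, c] = w_q[r, c] * scales[r, g_idx[c]]
    w = w_q.float() * scales[:, g_idx_permuted]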