compressed-tensors-nightly 0.5.0.20240814__py3-none-any.whl → 0.5.0.20240830__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/base.py +198 -8
- compressed_tensors/compressors/model_compressor.py +65 -1
- compressed_tensors/compressors/naive_quantized.py +71 -75
- compressed_tensors/compressors/pack_quantized.py +83 -94
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/apply.py +36 -4
- compressed_tensors/quantization/lifecycle/calibration.py +3 -2
- compressed_tensors/quantization/lifecycle/compressed.py +1 -1
- compressed_tensors/quantization/lifecycle/forward.py +67 -43
- compressed_tensors/quantization/lifecycle/helpers.py +29 -2
- compressed_tensors/quantization/lifecycle/initialize.py +50 -16
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +48 -20
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +13 -0
- compressed_tensors/utils/offload.py +7 -1
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/METADATA +3 -2
- compressed_tensors_nightly-0.5.0.20240830.dist-info/RECORD +52 -0
- compressed_tensors_nightly-0.5.0.20240814.dist-info/RECORD +0 -48
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/top_level.txt +0 -0
@@ -11,10 +11,8 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
|
15
|
-
import logging
|
16
14
|
import math
|
17
|
-
from typing import Dict,
|
15
|
+
from typing import Dict, Optional, Tuple
|
18
16
|
|
19
17
|
import numpy as np
|
20
18
|
import torch
|
@@ -23,16 +21,11 @@ from compressed_tensors.config import CompressionFormat
|
|
23
21
|
from compressed_tensors.quantization import QuantizationArgs
|
24
22
|
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
|
25
23
|
from compressed_tensors.quantization.utils import can_quantize
|
26
|
-
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
|
27
|
-
from safetensors import safe_open
|
28
24
|
from torch import Tensor
|
29
|
-
from tqdm import tqdm
|
30
25
|
|
31
26
|
|
32
27
|
__all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
|
33
28
|
|
34
|
-
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
35
|
-
|
36
29
|
|
37
30
|
@Compressor.register(name=CompressionFormat.pack_quantized.value)
|
38
31
|
class PackedQuantizationCompressor(Compressor):
|
@@ -44,102 +37,96 @@ class PackedQuantizationCompressor(Compressor):
|
|
44
37
|
"weight_packed",
|
45
38
|
"weight_scale",
|
46
39
|
"weight_zero_point",
|
40
|
+
"weight_g_idx",
|
47
41
|
"weight_shape",
|
48
42
|
]
|
49
43
|
|
50
|
-
def
|
44
|
+
def compression_param_info(
|
51
45
|
self,
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
) -> Dict[str, Tensor]:
|
46
|
+
weight_shape: torch.Size,
|
47
|
+
quantization_args: Optional[QuantizationArgs] = None,
|
48
|
+
) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
|
56
49
|
"""
|
57
|
-
|
50
|
+
Creates a dictionary of expected shapes and dtypes for each compression
|
51
|
+
parameter used by the compressor
|
58
52
|
|
59
|
-
:param
|
60
|
-
:param
|
61
|
-
|
62
|
-
|
53
|
+
:param weight_shape: uncompressed weight shape
|
54
|
+
:param quantization_args: quantization parameters for the weight
|
55
|
+
:return: dictionary mapping compressed parameter names to shape and dtype
|
56
|
+
"""
|
57
|
+
pack_factor = 32 // quantization_args.num_bits
|
58
|
+
packed_size = math.ceil(weight_shape[1] / pack_factor)
|
59
|
+
return {
|
60
|
+
"weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
|
61
|
+
"weight_shape": (torch.Size((2,)), torch.int32),
|
62
|
+
}
|
63
|
+
|
64
|
+
def compress_weight(
|
65
|
+
self,
|
66
|
+
weight: Tensor,
|
67
|
+
scale: Tensor,
|
68
|
+
zero_point: Optional[Tensor] = None,
|
69
|
+
g_idx: Optional[torch.Tensor] = None,
|
70
|
+
quantization_args: Optional[QuantizationArgs] = None,
|
71
|
+
device: Optional[torch.device] = None,
|
72
|
+
) -> Dict[str, torch.Tensor]:
|
73
|
+
"""
|
74
|
+
Compresses a single uncompressed weight
|
75
|
+
|
76
|
+
:param weight: uncompressed weight tensor
|
77
|
+
:param scale: quantization scale for weight
|
78
|
+
:param zero_point: quantization zero point for weight
|
79
|
+
:param g_idx: optional mapping from column index to group index
|
80
|
+
:param quantization_args: quantization parameters for weight
|
81
|
+
:param device: optional device to move compressed output to
|
82
|
+
:return: dictionary of compressed weight data
|
63
83
|
"""
|
64
84
|
compressed_dict = {}
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
scale=scale,
|
84
|
-
zero_point=zp,
|
85
|
-
args=quant_args,
|
86
|
-
dtype=torch.int8,
|
87
|
-
)
|
88
|
-
value = pack_to_int32(value.cpu(), quant_args.num_bits)
|
89
|
-
compressed_dict[merge_names(prefix, "weight_shape")] = shape
|
90
|
-
compressed_dict[merge_names(prefix, "weight_packed")] = value
|
91
|
-
continue
|
92
|
-
|
93
|
-
elif name.endswith("zero_point"):
|
94
|
-
if torch.all(value == 0):
|
95
|
-
# all zero_points are 0, no need to include in
|
96
|
-
# compressed state_dict
|
97
|
-
continue
|
98
|
-
|
99
|
-
compressed_dict[name] = value.to("cpu")
|
85
|
+
if can_quantize(weight, quantization_args):
|
86
|
+
quantized_weight = quantize(
|
87
|
+
x=weight,
|
88
|
+
scale=scale,
|
89
|
+
zero_point=zero_point,
|
90
|
+
g_idx=g_idx,
|
91
|
+
args=quantization_args,
|
92
|
+
dtype=torch.int8,
|
93
|
+
)
|
94
|
+
|
95
|
+
packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
|
96
|
+
weight_shape = torch.tensor(weight.shape)
|
97
|
+
if device is not None:
|
98
|
+
packed_weight = packed_weight.to(device)
|
99
|
+
weight_shape = weight_shape.to(device)
|
100
|
+
|
101
|
+
compressed_dict["weight_shape"] = weight_shape
|
102
|
+
compressed_dict["weight_packed"] = packed_weight
|
100
103
|
|
101
104
|
return compressed_dict
|
102
105
|
|
103
|
-
def
|
106
|
+
def decompress_weight(
|
104
107
|
self,
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
) -> Generator[Tuple[str, Tensor], None, None]:
|
108
|
+
compressed_data: Dict[str, Tensor],
|
109
|
+
quantization_args: Optional[QuantizationArgs] = None,
|
110
|
+
) -> torch.Tensor:
|
109
111
|
"""
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
:
|
115
|
-
one or more safetensors files) or compressed tensors file
|
116
|
-
:param device: optional device to load intermediate weights into
|
117
|
-
:return: compressed state dict
|
112
|
+
Decompresses a single compressed weight
|
113
|
+
|
114
|
+
:param compressed_data: dictionary of data needed for decompression
|
115
|
+
:param quantization_args: quantization parameters for the weight
|
116
|
+
:return: tensor of the decompressed weight
|
118
117
|
"""
|
119
|
-
|
120
|
-
|
118
|
+
weight = compressed_data["weight_packed"]
|
119
|
+
scale = compressed_data["weight_scale"]
|
120
|
+
zero_point = compressed_data.get("weight_zero_point", None)
|
121
|
+
g_idx = compressed_data.get("weight_g_idx", None)
|
122
|
+
original_shape = torch.Size(compressed_data["weight_shape"])
|
123
|
+
num_bits = quantization_args.num_bits
|
124
|
+
unpacked = unpack_from_int32(weight, num_bits, original_shape)
|
125
|
+
decompressed_weight = dequantize(
|
126
|
+
x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
|
121
127
|
)
|
122
|
-
|
123
|
-
|
124
|
-
for param_name, safe_path in weight_mappings[weight_name].items():
|
125
|
-
weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
|
126
|
-
full_name = merge_names(weight_name, param_name)
|
127
|
-
with safe_open(safe_path, framework="pt", device=device) as f:
|
128
|
-
weight_data[param_name] = f.get_tensor(full_name)
|
129
|
-
|
130
|
-
if "weight_scale" in weight_data:
|
131
|
-
zero_point = weight_data.get("weight_zero_point", None)
|
132
|
-
scale = weight_data["weight_scale"]
|
133
|
-
weight = weight_data["weight_packed"]
|
134
|
-
num_bits = weight_data["num_bits"]
|
135
|
-
original_shape = torch.Size(weight_data["weight_shape"])
|
136
|
-
unpacked = unpack_from_int32(weight, num_bits, original_shape)
|
137
|
-
decompressed = dequantize(
|
138
|
-
x_q=unpacked,
|
139
|
-
scale=scale,
|
140
|
-
zero_point=zero_point,
|
141
|
-
)
|
142
|
-
yield merge_names(weight_name, "weight"), decompressed
|
128
|
+
|
129
|
+
return decompressed_weight
|
143
130
|
|
144
131
|
|
145
132
|
def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
|
@@ -197,13 +184,15 @@ def unpack_from_int32(
|
|
197
184
|
if num_bits > 8:
|
198
185
|
raise ValueError("Unpacking is only supported for less than 8 bits")
|
199
186
|
|
200
|
-
# convert packed input to unsigned numpy
|
201
|
-
value = value.numpy().view(np.uint32)
|
202
187
|
pack_factor = 32 // num_bits
|
203
188
|
|
204
189
|
# unpack
|
205
190
|
mask = pow(2, num_bits) - 1
|
206
|
-
unpacked =
|
191
|
+
unpacked = torch.zeros(
|
192
|
+
(value.shape[0], value.shape[1] * pack_factor),
|
193
|
+
device=value.device,
|
194
|
+
dtype=torch.int32,
|
195
|
+
)
|
207
196
|
for i in range(pack_factor):
|
208
197
|
unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
|
209
198
|
|
@@ -214,6 +203,6 @@ def unpack_from_int32(
|
|
214
203
|
# bits are packed in unsigned format, reformat to signed
|
215
204
|
# update the value range from unsigned to signed
|
216
205
|
offset = pow(2, num_bits) // 2
|
217
|
-
unpacked = (unpacked
|
206
|
+
unpacked = (unpacked - offset).to(torch.int8)
|
218
207
|
|
219
|
-
return
|
208
|
+
return unpacked
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
from compressed_tensors.compressors.base import Compressor
|
17
|
+
from compressed_tensors.quantization import (
|
18
|
+
QuantizationScheme,
|
19
|
+
QuantizationStatus,
|
20
|
+
initialize_module_for_quantization,
|
21
|
+
)
|
22
|
+
from torch import Tensor
|
23
|
+
from torch.nn import Parameter
|
24
|
+
from torch.nn.functional import linear
|
25
|
+
from torch.nn.modules import Linear
|
26
|
+
|
27
|
+
|
28
|
+
class CompressedLinear(Linear):
|
29
|
+
"""
|
30
|
+
Wrapper module for running a compressed forward pass of a quantized Linear module.
|
31
|
+
The wrapped layer will decompressed on each forward call.
|
32
|
+
|
33
|
+
:param module: dense linear module to replace
|
34
|
+
:param quantization_scheme: quantization config for the module to wrap
|
35
|
+
:param quantization_format: compression format module is stored as
|
36
|
+
"""
|
37
|
+
|
38
|
+
@classmethod
|
39
|
+
@torch.no_grad()
|
40
|
+
def from_linear(
|
41
|
+
cls,
|
42
|
+
module: Linear,
|
43
|
+
quantization_scheme: QuantizationScheme,
|
44
|
+
quantization_format: str,
|
45
|
+
):
|
46
|
+
module.__class__ = CompressedLinear
|
47
|
+
module.compressor = Compressor.load_from_registry(quantization_format)
|
48
|
+
device = next(module.parameters()).device
|
49
|
+
|
50
|
+
# this will initialize all the scales and zero points
|
51
|
+
initialize_module_for_quantization(
|
52
|
+
module, quantization_scheme, force_zero_point=False
|
53
|
+
)
|
54
|
+
|
55
|
+
# get the shape and dtype of compressed parameters
|
56
|
+
compression_params = module.compressor.compression_param_info(
|
57
|
+
module.weight.shape, quantization_scheme.weights
|
58
|
+
)
|
59
|
+
|
60
|
+
# no need for this once quantization is initialized, will be replaced
|
61
|
+
# with the compressed parameter
|
62
|
+
delattr(module, "weight")
|
63
|
+
|
64
|
+
# populate compressed weights and quantization parameters
|
65
|
+
for name, (shape, dtype) in compression_params.items():
|
66
|
+
param = Parameter(
|
67
|
+
torch.empty(shape, device=device, dtype=dtype), requires_grad=False
|
68
|
+
)
|
69
|
+
module.register_parameter(name, param)
|
70
|
+
|
71
|
+
# mark module as compressed
|
72
|
+
module.quantization_status = QuantizationStatus.COMPRESSED
|
73
|
+
|
74
|
+
# handles case where forward is wrapped in new_forward by accelerate hooks
|
75
|
+
if hasattr(module, "_old_forward"):
|
76
|
+
module._old_forward = CompressedLinear.forward.__get__(
|
77
|
+
module, CompressedLinear
|
78
|
+
)
|
79
|
+
|
80
|
+
return module
|
81
|
+
|
82
|
+
def forward(self, input: Tensor) -> Tensor:
|
83
|
+
"""
|
84
|
+
Decompresses the weight, then runs the wrapped forward pass
|
85
|
+
"""
|
86
|
+
uncompressed_weight = self.compressor.decompress_module(self)
|
87
|
+
return linear(input, uncompressed_weight, self.bias)
|
@@ -21,6 +21,7 @@ from typing import OrderedDict as OrderedDictType
|
|
21
21
|
from typing import Union
|
22
22
|
|
23
23
|
import torch
|
24
|
+
from compressed_tensors.config import CompressionFormat
|
24
25
|
from compressed_tensors.quantization.lifecycle.calibration import (
|
25
26
|
set_module_for_calibration,
|
26
27
|
)
|
@@ -43,7 +44,7 @@ from compressed_tensors.quantization.utils import (
|
|
43
44
|
is_kv_cache_quant_scheme,
|
44
45
|
iter_named_leaf_modules,
|
45
46
|
)
|
46
|
-
from compressed_tensors.utils.helpers import fix_fsdp_module_name
|
47
|
+
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
|
47
48
|
from compressed_tensors.utils.offload import update_parameter_data
|
48
49
|
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
|
49
50
|
from torch.nn import Module
|
@@ -104,12 +105,16 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
|
|
104
105
|
)
|
105
106
|
|
106
107
|
|
107
|
-
def apply_quantization_config(
|
108
|
+
def apply_quantization_config(
|
109
|
+
model: Module, config: QuantizationConfig, run_compressed: bool = False
|
110
|
+
) -> Dict:
|
108
111
|
"""
|
109
112
|
Initializes the model for quantization in-place based on the given config
|
110
113
|
|
111
114
|
:param model: model to apply quantization config to
|
112
115
|
:param config: quantization config
|
116
|
+
:param run_compressed: Whether the model will be run in compressed mode or
|
117
|
+
decompressed fully on load
|
113
118
|
"""
|
114
119
|
# remove reference to the original `config`
|
115
120
|
# argument. This function can mutate it, and we'd
|
@@ -124,6 +129,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
|
|
124
129
|
for target in scheme.targets:
|
125
130
|
target_to_scheme[target] = scheme
|
126
131
|
|
132
|
+
if run_compressed:
|
133
|
+
from compressed_tensors.linear.compressed_linear import CompressedLinear
|
134
|
+
|
127
135
|
# list of submodules to ignore
|
128
136
|
ignored_submodules = defaultdict(list)
|
129
137
|
# mark appropriate layers for quantization by setting their quantization schemes
|
@@ -136,10 +144,24 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
|
|
136
144
|
continue # layer matches ignore list, continue
|
137
145
|
targets = find_name_or_class_matches(name, submodule, target_to_scheme)
|
138
146
|
if targets:
|
147
|
+
scheme = _scheme_from_targets(target_to_scheme, targets, name)
|
148
|
+
if run_compressed:
|
149
|
+
format = config.format
|
150
|
+
if format != CompressionFormat.dense.value:
|
151
|
+
if isinstance(submodule, torch.nn.Linear):
|
152
|
+
# TODO: expand to more module types
|
153
|
+
compressed_linear = CompressedLinear.from_linear(
|
154
|
+
submodule,
|
155
|
+
quantization_scheme=scheme,
|
156
|
+
quantization_format=format,
|
157
|
+
)
|
158
|
+
replace_module(model, name, compressed_linear)
|
159
|
+
|
139
160
|
# target matched - add layer and scheme to target list
|
140
161
|
submodule.quantization_scheme = _scheme_from_targets(
|
141
162
|
target_to_scheme, targets, name
|
142
163
|
)
|
164
|
+
|
143
165
|
names_to_scheme[name] = submodule.quantization_scheme.weights
|
144
166
|
|
145
167
|
if config.ignore is not None and ignored_submodules is not None:
|
@@ -149,8 +171,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
|
|
149
171
|
"not found in the model: "
|
150
172
|
f"{set(config.ignore) - set(ignored_submodules)}"
|
151
173
|
)
|
152
|
-
# apply current quantization status across all targeted layers
|
153
174
|
|
175
|
+
# apply current quantization status across all targeted layers
|
154
176
|
apply_quantization_status(model, config.quantization_status)
|
155
177
|
return names_to_scheme
|
156
178
|
|
@@ -198,7 +220,12 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
|
|
198
220
|
current_status = infer_quantization_status(model)
|
199
221
|
|
200
222
|
if status >= QuantizationStatus.INITIALIZED > current_status:
|
201
|
-
|
223
|
+
force_zero_point_init = status != QuantizationStatus.COMPRESSED
|
224
|
+
model.apply(
|
225
|
+
lambda module: initialize_module_for_quantization(
|
226
|
+
module, force_zero_point=force_zero_point_init
|
227
|
+
)
|
228
|
+
)
|
202
229
|
|
203
230
|
if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
|
204
231
|
# only quantize weights up front when our end goal state is calibration,
|
@@ -279,9 +306,11 @@ def _load_quant_args_from_state_dict(
|
|
279
306
|
"""
|
280
307
|
scale_name = f"{base_name}_scale"
|
281
308
|
zp_name = f"{base_name}_zero_point"
|
309
|
+
g_idx_name = f"{base_name}_g_idx"
|
282
310
|
|
283
311
|
state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
|
284
312
|
state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
|
313
|
+
state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)
|
285
314
|
|
286
315
|
if state_dict_scale is not None:
|
287
316
|
# module is quantized
|
@@ -291,6 +320,9 @@ def _load_quant_args_from_state_dict(
|
|
291
320
|
state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
|
292
321
|
update_parameter_data(module, state_dict_zp, zp_name)
|
293
322
|
|
323
|
+
if state_dict_g_idx is not None:
|
324
|
+
update_parameter_data(module, state_dict_g_idx, g_idx_name)
|
325
|
+
|
294
326
|
|
295
327
|
def _scheme_from_targets(
|
296
328
|
target_to_scheme: OrderedDictType[str, QuantizationScheme],
|
@@ -44,7 +44,7 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
|
|
44
44
|
return
|
45
45
|
status = getattr(module, "quantization_status", None)
|
46
46
|
if not status or status != QuantizationStatus.INITIALIZED:
|
47
|
-
|
47
|
+
_LOGGER.warning(
|
48
48
|
f"Attempting set module with status {status} to calibration mode. "
|
49
49
|
f"but status is not {QuantizationStatus.INITIALIZED} - you may "
|
50
50
|
"be calibrating an uninitialized module which may fail or attempting "
|
@@ -54,13 +54,14 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
|
|
54
54
|
if quantize_weights_upfront and module.quantization_scheme.weights is not None:
|
55
55
|
# set weight scale and zero_point up front, calibration data doesn't affect it
|
56
56
|
observer = module.weight_observer
|
57
|
+
g_idx = getattr(module, "weight_g_idx", None)
|
57
58
|
|
58
59
|
offloaded = False
|
59
60
|
if is_module_offloaded(module):
|
60
61
|
module._hf_hook.pre_forward(module)
|
61
62
|
offloaded = True
|
62
63
|
|
63
|
-
scale, zero_point = observer(module.weight)
|
64
|
+
scale, zero_point = observer(module.weight, g_idx=g_idx)
|
64
65
|
update_parameter_data(module, scale, "weight_scale")
|
65
66
|
update_parameter_data(module, zero_point, "weight_zero_point")
|
66
67
|
|
@@ -50,7 +50,7 @@ def compress_quantized_weights(module: Module):
|
|
50
50
|
scale = getattr(module, "weight_scale", None)
|
51
51
|
zero_point = getattr(module, "weight_zero_point", None)
|
52
52
|
|
53
|
-
if weight is None or scale is None
|
53
|
+
if weight is None or scale is None:
|
54
54
|
# no weight, scale, or ZP, nothing to do
|
55
55
|
|
56
56
|
# mark as compressed here to maintain consistent status throughout the model
|