compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +137 -9
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/config/base.py +6 -1
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +63 -9
- compressed_tensors/quantization/lifecycle/calibration.py +7 -7
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +126 -44
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -20
- compressed_tensors/quantization/lifecycle/initialize.py +138 -55
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +102 -24
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +12 -13
- compressed_tensors/quantization/utils/helpers.py +44 -19
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +30 -1
- compressed_tensors/utils/offload.py +14 -2
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
- compressed_tensors-0.7.0.dist-info/RECORD +59 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/pack_quantized.py +0 -219
- compressed_tensors-0.5.0.dist-info/RECORD +0 -48
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
@@ -17,17 +17,19 @@ import logging
|
|
17
17
|
from typing import Optional
|
18
18
|
|
19
19
|
import torch
|
20
|
-
from
|
21
|
-
from accelerate.utils import PrefixedDataset
|
20
|
+
from compressed_tensors.quantization.cache import KVCacheScaleType
|
22
21
|
from compressed_tensors.quantization.lifecycle.forward import (
|
23
22
|
wrap_module_forward_quantized,
|
23
|
+
wrap_module_forward_quantized_attn,
|
24
24
|
)
|
25
25
|
from compressed_tensors.quantization.quant_args import (
|
26
|
+
ActivationOrdering,
|
26
27
|
QuantizationArgs,
|
27
28
|
QuantizationStrategy,
|
28
29
|
)
|
29
30
|
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
30
31
|
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
32
|
+
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
|
31
33
|
from compressed_tensors.utils import get_execution_device, is_module_offloaded
|
32
34
|
from torch.nn import Module, Parameter
|
33
35
|
|
@@ -43,6 +45,7 @@ _LOGGER = logging.getLogger(__name__)
|
|
43
45
|
def initialize_module_for_quantization(
|
44
46
|
module: Module,
|
45
47
|
scheme: Optional[QuantizationScheme] = None,
|
48
|
+
force_zero_point: bool = True,
|
46
49
|
):
|
47
50
|
"""
|
48
51
|
attaches appropriate scales, zero points, and observers to a layer
|
@@ -54,61 +57,93 @@ def initialize_module_for_quantization(
|
|
54
57
|
:param scheme: scheme to use for quantization. if None is provided,
|
55
58
|
will attempt to use scheme stored in the module under `quantization_scheme`,
|
56
59
|
if not provided, the layer will be skipped
|
60
|
+
:param force_zero_point: whether to force initialization of a zero point for
|
61
|
+
symmetric quantization
|
57
62
|
"""
|
58
63
|
scheme = scheme or getattr(module, "quantization_scheme", None)
|
59
64
|
if scheme is None:
|
60
65
|
# no scheme passed and layer not targeted for quantization - skip
|
61
66
|
return
|
62
67
|
|
63
|
-
if
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
if isinstance(module, torch.nn.Linear):
|
69
|
-
weight_shape = module.weight.shape
|
70
|
-
_initialize_scale_zero_point_observer(
|
71
|
-
module, "weight", scheme.weights, weight_shape=weight_shape
|
72
|
-
)
|
73
|
-
else:
|
74
|
-
_LOGGER.warning(
|
75
|
-
f"module type {type(module)} targeted for weight quantization but "
|
76
|
-
"has no attribute weight, skipping weight quantization "
|
77
|
-
f"for {type(module)}"
|
78
|
-
)
|
79
|
-
if scheme.output_activations is not None:
|
80
|
-
_initialize_scale_zero_point_observer(
|
81
|
-
module, "output", scheme.output_activations
|
82
|
-
)
|
68
|
+
if is_attention_module(module):
|
69
|
+
# wrap forward call of module to perform
|
70
|
+
# quantized actions based on calltime status
|
71
|
+
wrap_module_forward_quantized_attn(module, scheme)
|
72
|
+
_initialize_attn_scales(module)
|
83
73
|
|
84
|
-
|
85
|
-
module.quantization_status = QuantizationStatus.INITIALIZED
|
74
|
+
else:
|
86
75
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
76
|
+
if scheme.input_activations is not None:
|
77
|
+
_initialize_scale_zero_point_observer(
|
78
|
+
module,
|
79
|
+
"input",
|
80
|
+
scheme.input_activations,
|
81
|
+
force_zero_point=force_zero_point,
|
82
|
+
)
|
83
|
+
if scheme.weights is not None:
|
84
|
+
if hasattr(module, "weight"):
|
85
|
+
weight_shape = None
|
86
|
+
if isinstance(module, torch.nn.Linear):
|
87
|
+
weight_shape = module.weight.shape
|
88
|
+
_initialize_scale_zero_point_observer(
|
89
|
+
module,
|
90
|
+
"weight",
|
91
|
+
scheme.weights,
|
92
|
+
weight_shape=weight_shape,
|
93
|
+
force_zero_point=force_zero_point,
|
94
|
+
)
|
99
95
|
else:
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
96
|
+
_LOGGER.warning(
|
97
|
+
f"module type {type(module)} targeted for weight quantization but "
|
98
|
+
"has no attribute weight, skipping weight quantization "
|
99
|
+
f"for {type(module)}"
|
100
|
+
)
|
101
|
+
|
102
|
+
if scheme.output_activations is not None:
|
103
|
+
if not is_kv_cache_quant_scheme(scheme):
|
104
|
+
_initialize_scale_zero_point_observer(
|
105
|
+
module, "output", scheme.output_activations
|
106
|
+
)
|
107
|
+
|
108
|
+
module.quantization_scheme = scheme
|
109
|
+
module.quantization_status = QuantizationStatus.INITIALIZED
|
110
|
+
|
111
|
+
offloaded = False
|
112
|
+
if is_module_offloaded(module):
|
113
|
+
try:
|
114
|
+
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
115
|
+
from accelerate.utils import PrefixedDataset
|
116
|
+
except ModuleNotFoundError:
|
117
|
+
raise ModuleNotFoundError(
|
118
|
+
"Offloaded model detected. To use CPU offloading with "
|
119
|
+
"compressed-tensors the `accelerate` package must be installed, "
|
120
|
+
"run `pip install compressed-tensors[accelerate]`"
|
121
|
+
)
|
122
|
+
|
123
|
+
offloaded = True
|
124
|
+
hook = module._hf_hook
|
125
|
+
prefix_dict = module._hf_hook.weights_map
|
126
|
+
new_prefix = {}
|
127
|
+
|
128
|
+
# recreate the prefix dict (since it is immutable)
|
129
|
+
# and add quantization parameters
|
130
|
+
for key, data in module.named_parameters():
|
131
|
+
if key not in prefix_dict:
|
132
|
+
new_prefix[f"{prefix_dict.prefix}{key}"] = data
|
133
|
+
else:
|
134
|
+
new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
|
135
|
+
new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
|
136
|
+
remove_hook_from_module(module)
|
137
|
+
|
138
|
+
# wrap forward call of module to perform
|
139
|
+
# quantized actions based on calltime status
|
140
|
+
wrap_module_forward_quantized(module, scheme)
|
141
|
+
|
142
|
+
if offloaded:
|
143
|
+
# we need to re-add the hook for offloading now that we've wrapped forward
|
144
|
+
add_hook_to_module(module, hook)
|
145
|
+
if prefix_dict is not None:
|
146
|
+
module._hf_hook.weights_map = new_prefix_dict
|
112
147
|
|
113
148
|
|
114
149
|
def _initialize_scale_zero_point_observer(
|
@@ -116,6 +151,7 @@ def _initialize_scale_zero_point_observer(
|
|
116
151
|
base_name: str,
|
117
152
|
quantization_args: QuantizationArgs,
|
118
153
|
weight_shape: Optional[torch.Size] = None,
|
154
|
+
force_zero_point: bool = True,
|
119
155
|
):
|
120
156
|
# initialize observer module and attach as submodule
|
121
157
|
observer = quantization_args.get_observer()
|
@@ -136,21 +172,68 @@ def _initialize_scale_zero_point_observer(
|
|
136
172
|
# (output_channels, 1)
|
137
173
|
expected_shape = (weight_shape[0], 1)
|
138
174
|
elif quantization_args.strategy == QuantizationStrategy.GROUP:
|
175
|
+
num_groups = weight_shape[1] // quantization_args.group_size
|
139
176
|
expected_shape = (
|
140
177
|
weight_shape[0],
|
141
|
-
|
178
|
+
max(num_groups, 1)
|
142
179
|
)
|
143
180
|
|
144
|
-
|
181
|
+
scale_dtype = module.weight.dtype
|
182
|
+
if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
|
183
|
+
scale_dtype = torch.float16
|
184
|
+
|
185
|
+
# initializes empty scale, zero point, and g_idx parameters for the module
|
145
186
|
init_scale = Parameter(
|
146
|
-
torch.empty(expected_shape, dtype=
|
187
|
+
torch.empty(expected_shape, dtype=scale_dtype, device=device),
|
147
188
|
requires_grad=False,
|
148
189
|
)
|
149
190
|
module.register_parameter(f"{base_name}_scale", init_scale)
|
150
191
|
|
151
|
-
|
152
|
-
|
153
|
-
|
192
|
+
if force_zero_point or not quantization_args.symmetric:
|
193
|
+
zp_dtype = quantization_args.pytorch_dtype()
|
194
|
+
init_zero_point = Parameter(
|
195
|
+
torch.zeros(expected_shape, device=device, dtype=zp_dtype),
|
196
|
+
requires_grad=False,
|
197
|
+
)
|
198
|
+
module.register_parameter(f"{base_name}_zero_point", init_zero_point)
|
199
|
+
|
200
|
+
# only grouped activation ordering has g_idx
|
201
|
+
if quantization_args.actorder == ActivationOrdering.GROUP:
|
202
|
+
g_idx_shape = (weight_shape[1],)
|
203
|
+
g_idx_dtype = torch.int
|
204
|
+
init_g_idx = Parameter(
|
205
|
+
torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
|
206
|
+
requires_grad=False,
|
207
|
+
)
|
208
|
+
module.register_parameter(f"{base_name}_g_idx", init_g_idx)
|
209
|
+
|
210
|
+
|
211
|
+
def is_attention_module(module: Module):
|
212
|
+
return "attention" in module.__class__.__name__.lower() and (
|
213
|
+
hasattr(module, "k_proj")
|
214
|
+
or hasattr(module, "v_proj")
|
215
|
+
or hasattr(module, "qkv_proj")
|
216
|
+
)
|
217
|
+
|
218
|
+
|
219
|
+
def _initialize_attn_scales(module: Module) -> None:
|
220
|
+
"""Initlaize k_scale, v_scale for self_attn"""
|
221
|
+
|
222
|
+
expected_shape = 1 # per tensor
|
223
|
+
|
224
|
+
param = next(module.parameters())
|
225
|
+
scale_dtype = param.dtype
|
226
|
+
device = param.device
|
227
|
+
|
228
|
+
init_scale = Parameter(
|
229
|
+
torch.empty(expected_shape, dtype=scale_dtype, device=device),
|
230
|
+
requires_grad=False,
|
231
|
+
)
|
232
|
+
|
233
|
+
module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
|
234
|
+
|
235
|
+
init_scale = Parameter(
|
236
|
+
torch.empty(expected_shape, dtype=scale_dtype, device=device),
|
154
237
|
requires_grad=False,
|
155
238
|
)
|
156
|
-
module.register_parameter(
|
239
|
+
module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import logging
|
16
|
+
from math import ceil
|
16
17
|
from typing import Any, Iterable, Optional, Tuple, Union
|
17
18
|
|
18
19
|
import torch
|
@@ -21,6 +22,7 @@ from compressed_tensors.quantization.quant_args import (
|
|
21
22
|
QuantizationStrategy,
|
22
23
|
)
|
23
24
|
from compressed_tensors.registry.registry import RegistryMixin
|
25
|
+
from compressed_tensors.utils import safe_permute
|
24
26
|
from torch import FloatTensor, IntTensor, Tensor
|
25
27
|
from torch.nn import Module
|
26
28
|
|
@@ -46,15 +48,18 @@ class Observer(Module, RegistryMixin):
|
|
46
48
|
self._num_observed_tokens = None
|
47
49
|
|
48
50
|
@torch.no_grad()
|
49
|
-
def forward(
|
51
|
+
def forward(
|
52
|
+
self, observed: Tensor, g_idx: Optional[Tensor] = None
|
53
|
+
) -> Tuple[FloatTensor, IntTensor]:
|
50
54
|
"""
|
51
55
|
maps directly to get_qparams
|
52
|
-
:param observed: optional observed tensor to calculate
|
53
|
-
|
56
|
+
:param observed: optional observed tensor from which to calculate
|
57
|
+
quantization parameters
|
58
|
+
:param g_idx: optional mapping from column index to group index
|
54
59
|
:return: tuple of scale and zero point based on last observed value
|
55
60
|
"""
|
56
61
|
self.record_observed_tokens(observed)
|
57
|
-
return self.get_qparams(observed=observed)
|
62
|
+
return self.get_qparams(observed=observed, g_idx=g_idx)
|
58
63
|
|
59
64
|
def calculate_qparams(
|
60
65
|
self,
|
@@ -77,7 +82,9 @@ class Observer(Module, RegistryMixin):
|
|
77
82
|
...
|
78
83
|
|
79
84
|
def get_qparams(
|
80
|
-
self,
|
85
|
+
self,
|
86
|
+
observed: Optional[Tensor] = None,
|
87
|
+
g_idx: Optional[Tensor] = None,
|
81
88
|
) -> Tuple[FloatTensor, IntTensor]:
|
82
89
|
"""
|
83
90
|
Convenience function to wrap overwritten calculate_qparams
|
@@ -86,6 +93,7 @@ class Observer(Module, RegistryMixin):
|
|
86
93
|
|
87
94
|
:param observed: optional observed tensor to calculate quantization parameters
|
88
95
|
from
|
96
|
+
:param g_idx: optional mapping from column index to group index
|
89
97
|
:return: tuple of scale and zero point based on last observed value
|
90
98
|
"""
|
91
99
|
if observed is not None:
|
@@ -97,20 +105,42 @@ class Observer(Module, RegistryMixin):
|
|
97
105
|
self._scale, self._zero_point = self.calculate_qparams(observed)
|
98
106
|
|
99
107
|
elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
|
108
|
+
rows = observed.shape[0]
|
100
109
|
columns = observed.shape[1]
|
101
|
-
|
102
|
-
|
103
|
-
|
110
|
+
num_groups = int(ceil(columns / group_size))
|
111
|
+
self._scale = torch.empty(
|
112
|
+
(rows, num_groups), dtype=observed.dtype, device=observed.device
|
113
|
+
)
|
114
|
+
zp_dtype = self.quantization_args.pytorch_dtype()
|
115
|
+
self._zero_point = torch.empty(
|
116
|
+
(rows, num_groups), dtype=zp_dtype, device=observed.device
|
117
|
+
)
|
118
|
+
|
119
|
+
# support column-order (default) quantization as well as other orderings
|
120
|
+
# such as activation ordering. Below checks if g_idx has initialized
|
121
|
+
is_column_order = g_idx is None or -1 in g_idx
|
122
|
+
if is_column_order:
|
123
|
+
group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
|
124
|
+
else:
|
125
|
+
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
|
126
|
+
group_sizes = group_sizes[torch.argsort(group_indices)]
|
127
|
+
|
128
|
+
perm = torch.argsort(g_idx)
|
129
|
+
observed = safe_permute(observed, perm, dim=1)
|
130
|
+
|
131
|
+
# TODO: experiment with vectorizing for loop for performance
|
132
|
+
end = 0
|
133
|
+
for group_index, group_count in enumerate(group_sizes):
|
134
|
+
start = end
|
135
|
+
end = start + group_count
|
104
136
|
scale, zero_point = self.get_qparams_along_dim(
|
105
|
-
observed[:,
|
137
|
+
observed[:, start:end],
|
106
138
|
0,
|
107
|
-
tensor_id=
|
139
|
+
tensor_id=group_index,
|
108
140
|
)
|
109
|
-
scales.append(scale)
|
110
|
-
zero_points.append(zero_point)
|
111
141
|
|
112
|
-
|
113
|
-
|
142
|
+
self._scale[:, group_index] = scale.squeeze(1)
|
143
|
+
self._zero_point[:, group_index] = zero_point.squeeze(1)
|
114
144
|
|
115
145
|
elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
|
116
146
|
# assume observed is transposed, because its the output, hence use dim 0
|
@@ -132,6 +162,8 @@ class Observer(Module, RegistryMixin):
|
|
132
162
|
dim: Union[int, Iterable[int]],
|
133
163
|
tensor_id: Optional[Any] = None,
|
134
164
|
):
|
165
|
+
if isinstance(dim, int):
|
166
|
+
dim = [dim]
|
135
167
|
dim = set(dim)
|
136
168
|
|
137
169
|
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
|
@@ -171,3 +203,11 @@ class Observer(Module, RegistryMixin):
|
|
171
203
|
# observed_tokens (batch_size * sequence_length)
|
172
204
|
observed_tokens, _ = batch_tensor.shape
|
173
205
|
self._num_observed_tokens += observed_tokens
|
206
|
+
|
207
|
+
def reset(self):
|
208
|
+
"""
|
209
|
+
Reset the state of the observer
|
210
|
+
"""
|
211
|
+
self._num_observed_tokens = None
|
212
|
+
self._scale = None
|
213
|
+
self._zero_point = None
|
@@ -94,3 +94,11 @@ class MovingAverageMinMaxObserver(Observer):
|
|
94
94
|
return self.calculate_qparams(
|
95
95
|
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
|
96
96
|
)
|
97
|
+
|
98
|
+
def reset(self):
|
99
|
+
"""
|
100
|
+
Reset the state of the observer, including min and maximum values
|
101
|
+
"""
|
102
|
+
super().reset()
|
103
|
+
self.min_val = {}
|
104
|
+
self.max_val = {}
|
@@ -0,0 +1,162 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Any, Optional, Tuple
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from compressed_tensors.quantization.observers.base import Observer
|
19
|
+
from compressed_tensors.quantization.observers.helpers import calculate_qparams
|
20
|
+
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
21
|
+
from torch import FloatTensor, IntTensor, Tensor
|
22
|
+
|
23
|
+
|
24
|
+
__all__ = ["MovingAverageMSEObserver"]
|
25
|
+
|
26
|
+
|
27
|
+
@Observer.register("mse")
|
28
|
+
class MovingAverageMSEObserver(Observer):
|
29
|
+
"""
|
30
|
+
Implements a dynamic quantization observer that sets the scale and
|
31
|
+
zero point based on a moving average of the mse-clipped min and max observed values
|
32
|
+
"""
|
33
|
+
|
34
|
+
def __init__(
|
35
|
+
self,
|
36
|
+
quantization_args: QuantizationArgs,
|
37
|
+
averaging_constant: float = 0.01,
|
38
|
+
grid: float = 100.0,
|
39
|
+
maxshrink: float = 0.80,
|
40
|
+
norm: float = 2.4,
|
41
|
+
):
|
42
|
+
super().__init__(quantization_args=quantization_args)
|
43
|
+
|
44
|
+
self.min_val = {}
|
45
|
+
self.max_val = {}
|
46
|
+
self.averaging_constant = averaging_constant
|
47
|
+
self.grid = grid
|
48
|
+
self.maxshrink = maxshrink
|
49
|
+
self.norm = norm
|
50
|
+
|
51
|
+
def calculate_mse_min_max(
|
52
|
+
self,
|
53
|
+
observed: Tensor,
|
54
|
+
reduce_dims: Optional[Tuple[int]] = None,
|
55
|
+
):
|
56
|
+
"""
|
57
|
+
Computes the mse-clipped min and max values of the observed tensor by
|
58
|
+
optimizing for quantization error
|
59
|
+
|
60
|
+
:param observed: observed tensor to calculate quantization parameters for
|
61
|
+
:param reduce_dims: optional tuple of dimensions to reduce along,
|
62
|
+
returned values will be shaped (1,) along the reduced dimensions
|
63
|
+
:return: tuple of min and max values derived from the observed tensor
|
64
|
+
"""
|
65
|
+
from compressed_tensors.quantization.lifecycle import fake_quantize
|
66
|
+
|
67
|
+
if not reduce_dims:
|
68
|
+
absolute_min_val, absolute_max_val = torch.aminmax(observed)
|
69
|
+
else:
|
70
|
+
absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
|
71
|
+
absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
|
72
|
+
|
73
|
+
best = torch.full(absolute_min_val.shape, float("inf"))
|
74
|
+
min_val = torch.ones(absolute_min_val.shape)
|
75
|
+
max_val = torch.zeros(absolute_max_val.shape)
|
76
|
+
for i in range(int(self.maxshrink * self.grid)):
|
77
|
+
p = 1 - i / self.grid
|
78
|
+
shrinked_min_val = p * absolute_min_val
|
79
|
+
shrinked_max_val = p * absolute_max_val
|
80
|
+
|
81
|
+
candidate_scales, candidate_zero_points = calculate_qparams(
|
82
|
+
shrinked_min_val, shrinked_max_val, self.quantization_args
|
83
|
+
)
|
84
|
+
q = fake_quantize(
|
85
|
+
observed,
|
86
|
+
candidate_scales,
|
87
|
+
candidate_zero_points,
|
88
|
+
self.quantization_args,
|
89
|
+
)
|
90
|
+
|
91
|
+
q -= observed
|
92
|
+
q.abs_()
|
93
|
+
q.pow_(self.norm)
|
94
|
+
if not reduce_dims:
|
95
|
+
err = torch.sum(q)
|
96
|
+
else:
|
97
|
+
err = torch.sum(q, reduce_dims, keepdims=True)
|
98
|
+
|
99
|
+
tmp = err < best
|
100
|
+
if torch.any(tmp):
|
101
|
+
best[tmp] = err[tmp]
|
102
|
+
min_val[tmp] = shrinked_min_val[tmp]
|
103
|
+
max_val[tmp] = shrinked_max_val[tmp]
|
104
|
+
return min_val, max_val
|
105
|
+
|
106
|
+
def calculate_qparams(
|
107
|
+
self,
|
108
|
+
observed: Tensor,
|
109
|
+
reduce_dims: Optional[Tuple[int]] = None,
|
110
|
+
tensor_id: Optional[Any] = None,
|
111
|
+
) -> Tuple[FloatTensor, IntTensor]:
|
112
|
+
"""
|
113
|
+
Updates the mse-clipped min and max values of the observed tensor using
|
114
|
+
a moving average smoothed by the averaging_constant
|
115
|
+
|
116
|
+
:param observed: observed tensor to calculate quantization parameters for
|
117
|
+
:param reduce_dims: optional tuple of dimensions to reduce along,
|
118
|
+
returned scale and zero point will be shaped (1,) along the
|
119
|
+
reduced dimensions
|
120
|
+
:param tensor_id: Optional id if different ranges of observed tensors are
|
121
|
+
passed, useful for sharding tensors by group_size
|
122
|
+
:return: tuple of scale and zero point derived from the observed tensor
|
123
|
+
"""
|
124
|
+
min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
|
125
|
+
|
126
|
+
running_min_val = self.min_val.get(tensor_id, None)
|
127
|
+
running_max_val = self.max_val.get(tensor_id, None)
|
128
|
+
|
129
|
+
if running_min_val is None or running_max_val is None:
|
130
|
+
updated_min_val = min_val
|
131
|
+
updated_max_val = max_val
|
132
|
+
else:
|
133
|
+
updated_min_val = running_min_val + self.averaging_constant * (
|
134
|
+
min_val - running_min_val
|
135
|
+
)
|
136
|
+
updated_max_val = running_max_val + self.averaging_constant * (
|
137
|
+
max_val - running_max_val
|
138
|
+
)
|
139
|
+
|
140
|
+
tensor_id = tensor_id or "default"
|
141
|
+
self.min_val[tensor_id] = updated_min_val
|
142
|
+
self.max_val[tensor_id] = updated_max_val
|
143
|
+
|
144
|
+
return calculate_qparams(
|
145
|
+
updated_min_val, updated_max_val, self.quantization_args
|
146
|
+
)
|
147
|
+
|
148
|
+
def get_qparams_along_dim(
|
149
|
+
self, observed, dim: int, tensor_id: Optional[Any] = None
|
150
|
+
):
|
151
|
+
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
|
152
|
+
return self.calculate_qparams(
|
153
|
+
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
|
154
|
+
)
|
155
|
+
|
156
|
+
def reset(self):
|
157
|
+
"""
|
158
|
+
Reset the state of the observer, including min and maximum values
|
159
|
+
"""
|
160
|
+
super().reset()
|
161
|
+
self.min_val = {}
|
162
|
+
self.max_val = {}
|