compressed-tensors 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +17 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +4 -2
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -0
- compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
- compressed_tensors/config/base.py +60 -2
- compressed_tensors/linear/compressed_linear.py +3 -1
- compressed_tensors/quantization/__init__.py +0 -1
- compressed_tensors/quantization/lifecycle/__init__.py +0 -2
- compressed_tensors/quantization/lifecycle/apply.py +3 -17
- compressed_tensors/quantization/lifecycle/forward.py +24 -87
- compressed_tensors/quantization/lifecycle/initialize.py +21 -24
- compressed_tensors/quantization/quant_args.py +27 -25
- compressed_tensors/quantization/quant_config.py +2 -2
- compressed_tensors/quantization/quant_scheme.py +17 -24
- compressed_tensors/quantization/utils/helpers.py +125 -8
- compressed_tensors/registry/registry.py +1 -1
- compressed_tensors/utils/helpers.py +33 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/METADATA +1 -1
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/RECORD +23 -31
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/WHEEL +1 -1
- compressed_tensors/quantization/cache.py +0 -201
- compressed_tensors/quantization/lifecycle/calibration.py +0 -70
- compressed_tensors/quantization/lifecycle/frozen.py +0 -55
- compressed_tensors/quantization/observers/__init__.py +0 -21
- compressed_tensors/quantization/observers/base.py +0 -213
- compressed_tensors/quantization/observers/helpers.py +0 -149
- compressed_tensors/quantization/observers/min_max.py +0 -104
- compressed_tensors/quantization/observers/mse.py +0 -162
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/top_level.txt +0 -0
```diff
--- a/compressed_tensors/quantization/lifecycle/calibration.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.utils import is_module_offloaded, update_parameter_data
-from torch.nn import Module
-
-
-__all__ = [
-    "set_module_for_calibration",
-]
-
-
-_LOGGER = logging.getLogger(__name__)
-
-
-def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
-    """
-    marks a layer as ready for calibration which activates observers
-    to update scales and zero points on each forward pass
-
-    apply to full model with `model.apply(set_module_for_calibration)`
-
-    :param module: module to set for calibration
-    :param quantize_weights_upfront: whether to automatically
-        run weight quantization at the start of calibration
-    """
-    if not getattr(module, "quantization_scheme", None):
-        # no quantization scheme nothing to do
-        return
-    status = getattr(module, "quantization_status", None)
-    if not status or status != QuantizationStatus.INITIALIZED:
-        _LOGGER.warning(
-            f"Attempting set module with status {status} to calibration mode. "
-            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
-            "be calibrating an uninitialized module which may fail or attempting "
-            "to re-calibrate a frozen module"
-        )
-
-    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
-        # set weight scale and zero_point up front, calibration data doesn't affect it
-        observer = module.weight_observer
-        g_idx = getattr(module, "weight_g_idx", None)
-
-        offloaded = is_module_offloaded(module)
-        if offloaded:
-            module._hf_hook.pre_forward(module)
-
-        scale, zero_point = observer(module.weight, g_idx=g_idx)
-        update_parameter_data(module, scale, "weight_scale")
-        update_parameter_data(module, zero_point, "weight_zero_point")
-
-        if offloaded:
-            module._hf_hook.post_forward(module, None)
-
-    module.quantization_status = QuantizationStatus.CALIBRATION
```
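The removed `calibration.py` held the 0.7.x entry point for putting modules into calibration mode. Its docstring shows the intended usage pattern via `Module.apply`; here is a minimal sketch of that pattern, where `model` and `samples` are placeholders and the model is assumed to already be initialized for quantization under 0.7.x:

```python
# Sketch of the removed 0.7.x calibration hook in use. Assumes `model` is a
# torch model whose quantized submodules already have observers attached and
# quantization_status == QuantizationStatus.INITIALIZED; `samples` is a
# placeholder iterable of calibration inputs.
from compressed_tensors.quantization.lifecycle.calibration import (
    set_module_for_calibration,
)

model.apply(set_module_for_calibration)  # weights are quantized up front here
for sample in samples:
    model(sample)  # observers update activation scales/zero points per pass
```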
```diff
--- a/compressed_tensors/quantization/lifecycle/frozen.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from torch.nn import Module
-
-
-__all__ = [
-    "freeze_module_quantization",
-]
-
-
-def freeze_module_quantization(module: Module):
-    """
-    deletes observers so static quantization is completed.
-
-    apply to full model with `model.apply(freeze_module_quantization)`
-
-    :param module: module to freeze quantization for
-    """
-    scheme = getattr(module, "quantization_scheme", None)
-    if not scheme:
-        # no quantization scheme nothing to do
-        return
-
-    if module.quantization_status == QuantizationStatus.FROZEN:
-        # nothing to do, already frozen
-        return
-
-    # delete observers from module if not dynamic
-    if scheme.input_activations and not scheme.input_activations.dynamic:
-        delattr(module, "input_observer")
-    if scheme.weights and not scheme.weights.dynamic:
-        delattr(module, "weight_observer")
-    if (
-        scheme.output_activations
-        and not is_kv_cache_quant_scheme(scheme)
-        and not scheme.output_activations.dynamic
-    ):
-        delattr(module, "output_observer")
-
-    module.quantization_status = QuantizationStatus.FROZEN
```
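`frozen.py` was the other half of the removed lifecycle: once calibration statistics were collected, freezing deleted the observers and pinned each module at `QuantizationStatus.FROZEN`. A hedged sketch of the full calibrate-then-freeze sequence these two deleted files supported, again with `model` and `samples` as placeholders:

```python
# Hypothetical end-to-end 0.7.x flow combining the two removed hooks;
# `model` and `samples` are placeholders, not part of this diff.
from compressed_tensors.quantization.lifecycle.calibration import (
    set_module_for_calibration,
)
from compressed_tensors.quantization.lifecycle.frozen import (
    freeze_module_quantization,
)

model.apply(set_module_for_calibration)   # INITIALIZED -> CALIBRATION
for sample in samples:
    model(sample)                         # observers accumulate min/max stats
model.apply(freeze_module_quantization)   # removes observers, -> FROZEN
```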
```diff
--- a/compressed_tensors/quantization/observers/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# flake8: noqa
-# isort: skip_file
-
-from .helpers import *
-from .base import *
-from .min_max import *
-from .mse import *
```
```diff
--- a/compressed_tensors/quantization/observers/base.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from math import ceil
-from typing import Any, Iterable, Optional, Tuple, Union
-
-import torch
-from compressed_tensors.quantization.quant_args import (
-    QuantizationArgs,
-    QuantizationStrategy,
-)
-from compressed_tensors.registry.registry import RegistryMixin
-from compressed_tensors.utils import safe_permute
-from torch import FloatTensor, IntTensor, Tensor
-from torch.nn import Module
-
-
-_LOGGER = logging.getLogger(__name__)
-
-
-__all__ = ["Observer"]
-
-
-class Observer(Module, RegistryMixin):
-    """
-    Base Observer class to be subclassed for specific implementation.
-    Subclasses should override `calculate_qparams` to return a scale, zero_point
-    pair
-    """
-
-    def __init__(self, quantization_args: QuantizationArgs):
-        self.quantization_args: QuantizationArgs = quantization_args
-        super().__init__()
-        self._scale = None
-        self._zero_point = None
-        self._num_observed_tokens = None
-
-    @torch.no_grad()
-    def forward(
-        self, observed: Tensor, g_idx: Optional[Tensor] = None
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        maps directly to get_qparams
-        :param observed: optional observed tensor from which to calculate
-            quantization parameters
-        :param g_idx: optional mapping from column index to group index
-        :return: tuple of scale and zero point based on last observed value
-        """
-        self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed, g_idx=g_idx)
-
-    def calculate_qparams(
-        self,
-        observed: Tensor,
-        reduce_dims: Optional[Tuple[int]] = None,
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        :param observed: observed tensor to calculate quantization parameters for
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :return: tuple of scale and zero point derived from the observed tensor
-        """
-        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
-
-    def post_calculate_qparams(self) -> None:
-        """
-        Run any logic specific to its observers after running calculate_qparams
-        """
-        ...
-
-    def get_qparams(
-        self,
-        observed: Optional[Tensor] = None,
-        g_idx: Optional[Tensor] = None,
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        Convenience function to wrap overwritten calculate_qparams
-        adds support to make observed tensor optional and support for tracking latest
-        calculated scale and zero point
-
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
-        :param g_idx: optional mapping from column index to group index
-        :return: tuple of scale and zero point based on last observed value
-        """
-        if observed is not None:
-            group_size = self.quantization_args.group_size
-
-            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
-
-                # re-calculate scale and zero point, update the stored value
-                self._scale, self._zero_point = self.calculate_qparams(observed)
-
-            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
-                zp_dtype = self.quantization_args.pytorch_dtype()
-                self._zero_point = torch.empty(
-                    (rows, num_groups), dtype=zp_dtype, device=observed.device
-                )
-
-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
-                    perm = torch.argsort(g_idx)
-                    observed = safe_permute(observed, perm, dim=1)
-
-                # TODO: experiment with vectorizing for loop for performance
-                end = 0
-                for group_index, group_count in enumerate(group_sizes):
-                    start = end
-                    end = start + group_count
-                    scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
-                        tensor_id=group_index,
-                    )
-
-                    self._scale[:, group_index] = scale.squeeze(1)
-                    self._zero_point[:, group_index] = zero_point.squeeze(1)
-
-            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
-
-            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-                # use dim 1, assume the obsersed.shape = [batch, token, hidden]
-                # should be batch, token
-                self._scale, self._zero_point = self.get_qparams_along_dim(
-                    observed,
-                    dim={0, 1},
-                )
-
-        return self._scale, self._zero_point
-
-    def get_qparams_along_dim(
-        self,
-        observed,
-        dim: Union[int, Iterable[int]],
-        tensor_id: Optional[Any] = None,
-    ):
-        if isinstance(dim, int):
-            dim = [dim]
-        dim = set(dim)
-
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
-        return self.calculate_qparams(
-            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
-        )
-
-    def record_observed_tokens(self, batch_tensor: Tensor):
-        """
-        Counts the number of tokens observed during the
-        forward passes. The count is aggregated in the
-        _num_observed_tokens attribute of the class.
-
-        Note: The batch_tensor is expected to have two dimensions
-        (batch_size * sequence_length, num_features). This is the
-        general shape expected by the forward pass of the expert
-        layers in a MOE model. If the input tensor does not have
-        two dimensions, the _num_observed_tokens attribute will be set
-        to None.
-        """
-        if not isinstance(batch_tensor, Tensor):
-            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
-
-        if batch_tensor.ndim != 2:
-            _LOGGER.debug(
-                "The input tensor is expected to have two dimensions "
-                "(batch_size * sequence_length, num_features). "
-                f"The input tensor has {batch_tensor.ndim} dimensions."
-            )
-            return
-
-        if self._num_observed_tokens is None:
-            # initialize the count
-            self._num_observed_tokens = 0
-
-        # batch_tensor (batch_size * sequence_length, num_features)
-        # observed_tokens (batch_size * sequence_length)
-        observed_tokens, _ = batch_tensor.shape
-        self._num_observed_tokens += observed_tokens
-
-    def reset(self):
-        """
-        Reset the state of the observer
-        """
-        self._num_observed_tokens = None
-        self._scale = None
-        self._zero_point = None
```
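The deleted `Observer` base class mixed in `RegistryMixin`, so concrete observers registered themselves by name, as `min_max.py` below does with `@Observer.register("minmax")`. To illustrate the extension point, here is a sketch of a custom observer against this removed 0.7.x API; the `"current_minmax"` name and the class are invented for illustration and never shipped with the package:

```python
# Illustrative subclass of the removed 0.7.x Observer API. It uses only the
# most recent batch's min/max (no moving average); "current_minmax" is a
# made-up registry name.
from typing import Any, Optional, Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.observers.helpers import calculate_qparams
from torch import Tensor


@Observer.register("current_minmax")
class CurrentMinMaxObserver(Observer):
    """Derives qparams from the extrema of the last observed tensor only."""

    def calculate_qparams(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
        tensor_id: Optional[Any] = None,  # passed by get_qparams_along_dim
    ) -> Tuple[Tensor, Tensor]:
        if not reduce_dims:
            min_val, max_val = torch.aminmax(observed)
        else:
            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
        # reuse the removed helper to convert extrema into scale/zero_point
        return calculate_qparams(min_val, max_val, self.quantization_args)
```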
```diff
--- a/compressed_tensors/quantization/observers/helpers.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import Counter
-from typing import Optional, Tuple
-
-import torch
-from compressed_tensors.quantization.quant_args import (
-    FP8_DTYPE,
-    QuantizationArgs,
-    QuantizationStrategy,
-    QuantizationType,
-)
-from torch import FloatTensor, IntTensor, Tensor
-
-
-__all__ = [
-    "calculate_qparams",
-    "get_observer_token_count",
-    "calculate_range",
-    "compute_dynamic_scales_and_zp",
-]
-
-
-def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
-    """
-    Returns the computed scales and zero points for dynamic activation
-    qunatization.
-
-    :param value: tensor to calculate quantization parameters for
-    :param args: quantization args
-    :param reduce_dims: optional tuple of dimensions to reduce along,
-        returned scale and zero point will be shaped (1,) along the
-        reduced dimensions
-    :return: tuple of scale and zero point derived from the observed tensor
-    """
-    if args.strategy == QuantizationStrategy.TOKEN:
-        dim = {1, 2}
-        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
-    elif args.strategy == QuantizationStrategy.TENSOR:
-        reduce_dims = None
-    else:
-        raise ValueError(
-            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
-            "must be used for dynamic quantization",
-        )
-
-    if not reduce_dims:
-        min_val, max_val = torch.aminmax(value)
-    else:
-        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
-        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
-
-    return calculate_qparams(min_val, max_val, args)
-
-
-def get_observer_token_count(module: torch.nn.Module) -> Counter:
-    """
-    Parse the module and return the number of tokens observed by
-    each module's observer.
-
-    :param module: module to parse
-    :return: counter with the number of tokens observed by each observer
-    """
-    token_counts = Counter()
-    for name, module in module.named_modules():
-        if name.endswith(".input_observer"):
-            token_counts[
-                name.replace(".input_observer", "")
-            ] = module._num_observed_tokens
-    return token_counts
-
-
-def calculate_qparams(
-    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
-) -> Tuple[FloatTensor, IntTensor]:
-    """
-    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
-        from
-    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
-        from
-    :param quantization_args: settings to quantization
-    :return: tuple of the calculated scale(s) and zero point(s)
-    """
-    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
-    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
-    device = min_vals.device
-
-    bit_min, bit_max = calculate_range(quantization_args, device)
-    bit_range = bit_max - bit_min
-    zp_dtype = quantization_args.pytorch_dtype()
-
-    if quantization_args.symmetric:
-        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        scales = max_val_pos / (float(bit_range) / 2)
-        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
-    else:
-        scales = (max_vals - min_vals) / float(bit_range)
-        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min - (min_vals / scales)
-        zero_points = torch.clamp(zero_points, bit_min, bit_max)
-
-    # match zero-points to quantized type
-    zero_points = zero_points.to(zp_dtype)
-
-    if scales.ndim == 0:
-        scales = scales.reshape(1)
-        zero_points = zero_points.reshape(1)
-
-    return scales, zero_points
-
-
-def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
-    """
-    Calculated the effective quantization range for the given Quantization Args
-
-    :param quantization_args: quantization args to get range of
-    :param device: device to store the range to
-    :return: tuple endpoints for the given quantization range
-    """
-    if quantization_args.type == QuantizationType.INT:
-        bit_range = 2**quantization_args.num_bits
-        q_max = torch.tensor(bit_range / 2 - 1, device=device)
-        q_min = torch.tensor(-bit_range / 2, device=device)
-    elif quantization_args.type == QuantizationType.FLOAT:
-        if quantization_args.num_bits != 8:
-            raise ValueError(
-                "Floating point quantization is only supported for 8 bits,"
-                f"got {quantization_args.num_bits}"
-            )
-        fp_range_info = torch.finfo(FP8_DTYPE)
-        q_max = torch.tensor(fp_range_info.max, device=device)
-        q_min = torch.tensor(fp_range_info.min, device=device)
-    else:
-        raise ValueError(f"Invalid quantization type {quantization_args.type}")
-
-    return q_min, q_max
```
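The scale/zero-point arithmetic in the removed `calculate_qparams` is easy to check by hand. For 8-bit INT, `calculate_range` returns q_min = -128 and q_max = 127, so the full range is 255; a tensor spanning [-0.5, 1.5] then gives an asymmetric scale of 2.0/255 ≈ 0.00784 and a zero point of -128 - (-0.5/0.00784) ≈ -64. A standalone re-derivation of both branches (plain Python mirroring the deleted helper with illustrative numbers; the eps clamp on the scale is elided):

```python
# Re-derivation of the removed int8 qparams math; mirrors the deleted
# calculate_qparams/calculate_range helpers without importing them.

def int8_qparams(min_val: float, max_val: float, symmetric: bool):
    q_min, q_max = -128, 127                  # calculate_range for 8-bit INT
    bit_range = q_max - q_min                 # 255
    min_val = min(min_val, 0.0)               # ranges are forced to include 0
    max_val = max(max_val, 0.0)
    if symmetric:
        scale = max(abs(min_val), abs(max_val)) / (bit_range / 2)
        zero_point = 0
    else:
        scale = (max_val - min_val) / bit_range
        # clamp to [q_min, q_max], then truncate toward zero as .to(int8) does
        zero_point = int(max(q_min, min(q_max, q_min - min_val / scale)))
    return scale, zero_point

print(int8_qparams(-0.5, 1.5, symmetric=False))  # (~0.00784, -64)
print(int8_qparams(-0.5, 1.5, symmetric=True))   # (~0.01176, 0)
```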
```diff
--- a/compressed_tensors/quantization/observers/min_max.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Optional, Tuple
-
-import torch
-from compressed_tensors.quantization.observers.base import Observer
-from compressed_tensors.quantization.observers.helpers import calculate_qparams
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from torch import FloatTensor, IntTensor, Tensor
-
-
-__all__ = ["MovingAverageMinMaxObserver"]
-
-
-@Observer.register("minmax")
-class MovingAverageMinMaxObserver(Observer):
-    """
-    Implements a dynamic quantization observer that sets the scale and
-    zero point based on a moving average of the overall min and max observed values
-    """
-
-    def __init__(
-        self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
-    ):
-        super().__init__(quantization_args=quantization_args)
-
-        self.min_val = {}
-        self.max_val = {}
-        self.averaging_constant = averaging_constant
-
-    def calculate_qparams(
-        self,
-        observed: Tensor,
-        reduce_dims: Optional[Tuple[int]] = None,
-        tensor_id: Optional[Any] = None,
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        Updates the observed min and max using a moving average smoothed by the
-        averaging_constant
-
-        :param observed: observed tensor to calculate quantization parameters for
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :param tensor_id: Optional id if different ranges of observed tensors are
-            passed, useful for sharding tensors by group_size
-        :return: tuple of scale and zero point derived from the observed tensor
-        """
-        tensor_id = tensor_id or "default"
-
-        if not reduce_dims:
-            min_val, max_val = torch.aminmax(observed)
-        else:
-            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
-            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
-
-        running_min_val = self.min_val.get(tensor_id, None)
-        running_max_val = self.max_val.get(tensor_id, None)
-
-        if running_min_val is None or running_max_val is None:
-            updated_min_val = min_val
-            updated_max_val = max_val
-        else:
-            updated_min_val = running_min_val + self.averaging_constant * (
-                min_val - running_min_val
-            )
-            updated_max_val = running_max_val + self.averaging_constant * (
-                max_val - running_max_val
-            )
-
-        self.min_val[tensor_id] = updated_min_val
-        self.max_val[tensor_id] = updated_max_val
-
-        return calculate_qparams(
-            updated_min_val, updated_max_val, self.quantization_args
-        )
-
-    def get_qparams_along_dim(
-        self, observed, dim: int, tensor_id: Optional[Any] = None
-    ):
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
-        return self.calculate_qparams(
-            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
-        )
-
-    def reset(self):
-        """
-        Reset the state of the observer, including min and maximum values
-        """
-        super().reset()
-        self.min_val = {}
-        self.max_val = {}
```
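The update in the removed observer is a standard exponential moving average, `running += averaging_constant * (observed - running)`, so with the default constant of 0.01 a single outlier batch moves the tracked extrema by only 1% of the gap. A standalone illustration of that smoothing (the values are illustrative):

```python
# EMA update used by the removed MovingAverageMinMaxObserver, in isolation:
# running <- running + c * (observed - running), with c = averaging_constant.
c = 0.01
running_max = 1.0
for observed_max in (3.0, 3.0, 3.0):        # repeated outlier batches
    running_max += c * (observed_max - running_max)
    print(round(running_max, 4))            # 1.02, 1.0398, 1.0594
```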