compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +21 -0
- compressed_tensors/base.py +17 -0
- compressed_tensors/compressors/__init__.py +22 -0
- compressed_tensors/compressors/base.py +59 -0
- compressed_tensors/compressors/dense.py +34 -0
- compressed_tensors/compressors/helpers.py +137 -0
- compressed_tensors/compressors/int_quantized.py +95 -0
- compressed_tensors/compressors/model_compressor.py +264 -0
- compressed_tensors/compressors/sparse_bitmask.py +239 -0
- compressed_tensors/config/__init__.py +18 -0
- compressed_tensors/config/base.py +43 -0
- compressed_tensors/config/dense.py +36 -0
- compressed_tensors/config/sparse_bitmask.py +36 -0
- compressed_tensors/quantization/__init__.py +21 -0
- compressed_tensors/quantization/lifecycle/__init__.py +23 -0
- compressed_tensors/quantization/lifecycle/apply.py +196 -0
- compressed_tensors/quantization/lifecycle/calibration.py +51 -0
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +333 -0
- compressed_tensors/quantization/lifecycle/frozen.py +50 -0
- compressed_tensors/quantization/lifecycle/initialize.py +99 -0
- compressed_tensors/quantization/observers/__init__.py +21 -0
- compressed_tensors/quantization/observers/base.py +130 -0
- compressed_tensors/quantization/observers/helpers.py +54 -0
- compressed_tensors/quantization/observers/memoryless.py +48 -0
- compressed_tensors/quantization/observers/min_max.py +80 -0
- compressed_tensors/quantization/quant_args.py +125 -0
- compressed_tensors/quantization/quant_config.py +210 -0
- compressed_tensors/quantization/quant_scheme.py +39 -0
- compressed_tensors/quantization/utils/__init__.py +16 -0
- compressed_tensors/quantization/utils/helpers.py +131 -0
- compressed_tensors/registry/__init__.py +17 -0
- compressed_tensors/registry/registry.py +360 -0
- compressed_tensors/utils/__init__.py +16 -0
- compressed_tensors/utils/helpers.py +45 -0
- compressed_tensors/utils/safetensors_load.py +237 -0
- compressed_tensors/version.py +50 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
@@ -0,0 +1,130 @@ compressed_tensors/quantization/observers/base.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)
from compressed_tensors.registry.registry import RegistryMixin
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module


__all__ = ["Observer"]


class Observer(Module, RegistryMixin):
    """
    Base Observer class to be subclassed for specific implementations.
    Subclasses should override `calculate_qparams` to return a scale, zero_point
    pair
    """

    def __init__(self, quantization_args: QuantizationArgs):
        self.quantization_args: QuantizationArgs = quantization_args
        super().__init__()
        self._scale = None
        self._zero_point = None

    @torch.no_grad()
    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
        """
        Maps directly to get_qparams

        :param observed: optional observed tensor to calculate quantization parameters
            from
        :return: tuple of scale and zero point based on last observed value
        """
        return self.get_qparams(observed=observed)

    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
        """
        :param observed: observed tensor to calculate quantization parameters for
        :return: tuple of scale and zero point derived from the observed tensor
        """
        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

    def post_calculate_qparams(self) -> None:
        """
        Run any logic specific to its observers after running calculate_qparams
        """
        ...

    def get_qparams(
        self, observed: Optional[Tensor] = None
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        Convenience function wrapping the overridden calculate_qparams;
        adds support for making the observed tensor optional and for tracking the
        latest calculated scale and zero point

        :param observed: optional observed tensor to calculate quantization parameters
            from
        :return: tuple of scale and zero point based on last observed value
        """
        if observed is not None:
            group_size = self.quantization_args.group_size

            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                # re-calculate scale and zero point, update the stored value
                self._scale, self._zero_point = self.calculate_qparams(observed)

            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
                columns = observed.shape[1]
                scales, zero_points = [], []
                for i in range(0, columns, group_size):
                    scale, zero_point = self.get_qparams_along_dim(
                        observed[:, i : (i + group_size)],
                        0,
                    )
                    scales.append(scale)
                    zero_points.append(zero_point)
                self._scale = torch.stack(scales, dim=1, out=self._scale)
                self._zero_point = torch.stack(
                    zero_points, dim=1, out=self._zero_point
                )

            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                # assume observed is transposed, because it's the output, hence use dim 0
                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                # use dim 1, assume observed.shape = [batch, token, hidden];
                # scale and zero point should be shaped [batch, token]
                self._scale, self._zero_point = self.get_qparams_along_dim(
                    observed, dim=1
                )

        return self._scale, self._zero_point

    def get_qparams_along_dim(self, observed, dim: int):
        # TODO: add documentation that specifies the shape must
        #  be padded with 1-dims so the scales are along the right channel
        # TODO: generalize the logic for reduce_dims
        scales, zero_points = [], []

        # TODO: make a more generic way to get the channel
        num_dims = observed.shape[dim]

        for dim_idx in range(num_dims):
            scale, zero_point = self.calculate_qparams(
                observed.select(dim=dim, index=dim_idx)
            )

            scales.append(scale)
            zero_points.append(zero_point)

        return torch.stack(scales), torch.stack(zero_points)
@@ -0,0 +1,54 @@ compressed_tensors/quantization/observers/helpers.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from torch import FloatTensor, IntTensor, Tensor


__all__ = ["calculate_qparams"]


def calculate_qparams(
    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
) -> Tuple[FloatTensor, IntTensor]:
    """
    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
        from
    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
        from
    :param quantization_args: settings for quantization
    :return: tuple of the calculated scale(s) and zero point(s)
    """
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
    device = min_vals.device

    bit_range = 2**quantization_args.num_bits - 1
    bit_min = -(bit_range + 1) / 2
    bit_max = bit_min + bit_range
    if quantization_args.symmetric:
        max_val_pos = torch.max(-min_vals, max_vals)
        scales = max_val_pos / (float(bit_range) / 2)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
    else:
        scales = (max_vals - min_vals) / float(bit_range)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
        zero_points = bit_min - torch.round(min_vals / scales)
        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)

    return scales, zero_points
@@ -0,0 +1,48 @@ compressed_tensors/quantization/observers/memoryless.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.observers.helpers import calculate_qparams
from torch import FloatTensor, IntTensor, Tensor


__all__ = ["MemorylessObserver"]


@Observer.register("memoryless", alias=["dynamic"])
class MemorylessObserver(Observer):
    """
    Implements a quantization observer that sets the scale and
    zero point based on the latest observed value without tracking state
    """

    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
        """
        Calculates the scale and zero point from the min and max values of observed

        :param observed: observed tensor to calculate quantization parameters for
        :return: tuple of scale and zero point derived from the observed tensor
        """
        # TODO: Add support for full range of quantization Args, only supports 8bit
        #  per tensor
        min_val, max_val = torch.aminmax(observed)

        # ensure zero is in the range
        min_val = torch.min(min_val, torch.zeros_like(min_val))
        max_val = torch.max(max_val, torch.zeros_like(max_val))

        return calculate_qparams(min_val, max_val, self.quantization_args)
@@ -0,0 +1,80 @@ compressed_tensors/quantization/observers/min_max.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.observers.helpers import calculate_qparams
from compressed_tensors.quantization.quant_args import QuantizationArgs
from torch import FloatTensor, IntTensor, Tensor


__all__ = ["MovingAverageMinMaxObserver"]


@Observer.register("minmax")
class MovingAverageMinMaxObserver(Observer):
    """
    Implements a dynamic quantization observer that sets the scale and
    zero point based on a moving average of the overall min and max observed values
    """

    def __init__(
        self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
    ):
        super().__init__(quantization_args=quantization_args)

        self.min_val = None
        self.max_val = None
        self.averaging_constant = averaging_constant

    def calculate_qparams(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        Updates the observed min and max using a moving average smoothed by the
        averaging_constant

        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned scale and zero point will be shaped (1,) along the
            reduced dimensions
        :return: tuple of scale and zero point derived from the observed tensor
        """
        if not reduce_dims:
            min_val, max_val = torch.aminmax(observed)
        else:
            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

        if self.min_val is None and self.max_val is None:
            self.min_val = min_val
            self.max_val = max_val
        else:
            self.min_val = self.min_val + self.averaging_constant * (
                min_val - self.min_val
            )
            self.max_val = self.max_val + self.averaging_constant * (
                max_val - self.max_val
            )

        return calculate_qparams(self.min_val, self.max_val, self.quantization_args)

    def get_qparams_along_dim(self, observed, dim: int):
        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
        return self.calculate_qparams(observed, reduce_dims=reduce_dims)
@@ -0,0 +1,125 @@ compressed_tensors/quantization/quant_args.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, validator


__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]


class QuantizationType(str, Enum):
    """
    Enum storing quantization type options
    """

    INT = "int"
    FLOAT = "float"


class QuantizationStrategy(str, Enum):
    """
    Enum storing quantization strategy options
    """

    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"
    TOKEN = "token"


class QuantizationArgs(BaseModel):
    """
    User-facing arguments used to define a quantization config for weights or
    activations

    :param num_bits: quantization bit depth
    :param type: dtype to quantize to, either int or float
    :param symmetric: whether or not the quantization scale is symmetric about the
        zero point
    :param strategy: string id determining the scope of scale/zero point to apply
    :param group_size: group length to use for the group strategy
    :param block_structure: 2d block structure to use for the block strategy, must be
        of the format "2x4", "8x16", etc.
    :param dynamic: set True to perform dynamic quantization - values will not be
        calibrated during the calibration phase; instead, new quantization ranges
        are observed for every sample during inference. Defaults to False for static
        quantization. Note that enabling dynamic quantization changes the default
        observer to a memoryless one
    """

    num_bits: int = 8
    type: QuantizationType = QuantizationType.INT
    symmetric: bool = True
    group_size: Optional[int] = None
    strategy: Optional[QuantizationStrategy] = None
    block_structure: Optional[str] = None
    dynamic: bool = False
    observer: str = Field(
        default="minmax",
        description=(
            "The class to use to compute the quantization parameters - "
            "scale and zero point"
        ),
    )
    observer_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "optional dict of kwargs to be passed directly to torch quantization "
            "Observers constructor excluding quantization range or symmetry"
        ),
    )

    def get_observer(self):
        """
        :return: torch quantization FakeQuantize built based on these QuantizationArgs
        """
        from compressed_tensors.quantization.observers.base import Observer

        if self.observer == "minmax" and self.dynamic:
            # override the default observer for dynamic quantization; minmax keeps
            # state across samples, which dynamic quantization must never do
            self.observer = "memoryless"

        return Observer.load_from_registry(self.observer, quantization_args=self)

    @validator("strategy", pre=True, always=True)
    def validate_strategy(cls, value, values):
        group_size = values.get("group_size")

        # use group_size to determine strategy if not given explicitly
        if group_size is not None and value is None:
            if group_size > 0:
                return QuantizationStrategy.GROUP

            elif group_size == -1:
                return QuantizationStrategy.CHANNEL

            else:
                raise ValueError(
                    f"group_size={group_size} with strategy {value} is invalid. "
                    "Use group_size > 0 for strategy='group' and "
                    "group_size = -1 for 'channel'"
                )

        if value == QuantizationStrategy.GROUP:
            if group_size is None:
                raise ValueError(f"strategy {value} requires group_size to be set.")

        if value is None:
            return QuantizationStrategy.TENSOR

        return value
@@ -0,0 +1,210 @@ compressed_tensors/quantization/quant_config.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Dict, List, Optional

from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
    calculate_compression_ratio,
    is_module_quantized,
    iter_named_leaf_modules,
    module_type,
)
from pydantic import BaseModel, Field
from torch.nn import Module
from transformers import AutoConfig


__all__ = [
    "QuantizationStatus",
    "QuantizationConfig",
    "LIFECYCLE_ORDER",
]


class QuantizationStatus(str, Enum):
    """
    Enum storing the different states a quantized layer can be in

    Initialized: scale, zero points and observers have been attached to the layer but
        are set to dummy values (not yet calibrated)
    Calibration: scale and zero points have been calibrated through OBCQ or a similar
        algorithm, observers are still attached
    Frozen: scale and zero points are finalized, observers have been deleted, weights
        are still in their original precision
    Compressed: weights have been converted to their target type or compressed to
        their closest approximation
    """

    INITIALIZED = "initialized"
    CALIBRATION = "calibration"
    FROZEN = "frozen"
    COMPRESSED = "compressed"

    @classmethod
    def lifecycle_order(cls) -> List["QuantizationStatus"]:
        """
        :return: list of statuses in correct quantization lifecycle order
        """
        return LIFECYCLE_ORDER

    def __ge__(self, other):
        if other is None:
            return True
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)

    def __gt__(self, other):
        if other is None:
            return True
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)

    def __lt__(self, other):
        if other is None:
            return False
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)

    def __le__(self, other):
        if other is None:
            return False
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)


LIFECYCLE_ORDER = [
    QuantizationStatus.INITIALIZED,
    QuantizationStatus.CALIBRATION,
    QuantizationStatus.FROZEN,
    QuantizationStatus.COMPRESSED,
]


class QuantizationConfig(BaseModel):
    """
    Full configuration specifying how a model is quantized. Each quantized layer is
    mapped to a QuantizationScheme in config_groups.

    :param config_groups: dict of QuantizationSchemes specifying the quantization
        settings for each quantized layer
    :param quant_method: a constant used to differentiate sparseML quantization from
        other quantization configs
    :param format: specifies how the quantized model is stored on disk
    :param quantization_status: specifies the current status of all quantized layers.
        It is assumed all layers are in the same state.
    :param global_compression_ratio: optional informational config to report the model
        compression ratio achieved by the quantization config
    :param ignore: optional list of layers to ignore from config_groups. Layers in this
        list are not quantized even if they match up with a target in config_groups
    """

    config_groups: Dict[str, QuantizationScheme]
    quant_method: str = "sparseml"
    format: str = "fakequant"
    quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
    global_compression_ratio: Optional[float] = None
    ignore: Optional[List[str]] = Field(default_factory=list)

    @staticmethod
    def from_model_config(model_name_or_path) -> Optional["QuantizationConfig"]:
        """
        Given a path to a model config, extract a quantization config if it exists

        :param model_name_or_path: path to model config on disk or HF hub
        :return: instantiated QuantizationConfig if config contains a quant config
        """
        config = AutoConfig.from_pretrained(model_name_or_path)
        quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
        if quantization_config is None:
            return None

        return QuantizationConfig.parse_obj(quantization_config)

    @staticmethod
    def from_pretrained(
        model: Module, format: Optional[str] = None
    ) -> Optional["QuantizationConfig"]:
        """
        Converts a model into its associated QuantizationConfig based on the
        QuantizationScheme attached to each quantized module

        :param model: model to calculate quantization scheme of
        :param format: optional compression format; inferred from the quantization
            status if not provided
        :return: filled out QuantizationConfig for the input model
        """
        quant_scheme_to_layers = []
        quantization_status = None
        ignore = {}
        quantization_type_names = set()
        for name, submodule in iter_named_leaf_modules(model):
            layer_type = module_type(submodule)
            if not is_module_quantized(submodule):
                if layer_type not in ignore:
                    ignore[layer_type] = []
                ignore[layer_type].append(name)
            else:
                quantization_status = submodule.quantization_status
                scheme = submodule.quantization_scheme
                quantization_type_names.add(layer_type)

                match_found = False
                for existing_scheme in quant_scheme_to_layers:
                    if scheme == existing_scheme:
                        match_found = True
                        break
                if not match_found:
                    quant_scheme_to_layers.append(scheme)

        if len(quant_scheme_to_layers) == 0:  # no quantized layers
            return None

        # clean up the ignore list; we can leave out layer types if none of their
        # instances are quantized
        consolidated_ignore = []
        for layer_type, ignore_names in ignore.items():
            if layer_type in quantization_type_names:
                # specific layers of a quantized type are ignored
                consolidated_ignore += ignore_names
            # else leave it off the ignore list -- it doesn't fall under any of the
            # existing quantization schemes, so it won't be quantized

        config_groups = {}
        for idx, scheme in enumerate(quant_scheme_to_layers):
            group_name = "group_" + str(idx)
            config_groups[group_name] = scheme

        # TODO: this is incorrect in compressed mode, since we are overwriting the
        #  original weight we lose the uncompressed bit_depth info
        compression_ratio = calculate_compression_ratio(model)

        if format is None:
            if quantization_status == QuantizationStatus.COMPRESSED:
                format = CompressionFormat.int_quantized.value
            else:
                format = CompressionFormat.dense.value

        return QuantizationConfig(
            config_groups=config_groups,
            quantization_status=quantization_status,
            global_compression_ratio=compression_ratio,
            format=format,
            ignore=consolidated_ignore,
        )
@@ -0,0 +1,39 @@ compressed_tensors/quantization/quant_scheme.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from compressed_tensors.quantization.quant_args import QuantizationArgs
from pydantic import BaseModel


__all__ = ["QuantizationScheme"]


class QuantizationScheme(BaseModel):
    """
    Set of QuantizationArgs defining how the weights, inputs and outputs of a target
    list of modules should be quantized

    :param targets: list of modules to apply the QuantizationArgs to, can be layer
        names, layer types or a regular expression
    :param weights: quantization config for layer weights
    :param input_activations: quantization config for layer inputs
    :param output_activations: quantization config for layer outputs
    """

    targets: List[str]
    weights: Optional[QuantizationArgs] = None
    input_activations: Optional[QuantizationArgs] = None
    output_activations: Optional[QuantizationArgs] = None