compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +17 -0
  3. compressed_tensors/compressors/__init__.py +22 -0
  4. compressed_tensors/compressors/base.py +59 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +137 -0
  7. compressed_tensors/compressors/int_quantized.py +95 -0
  8. compressed_tensors/compressors/model_compressor.py +264 -0
  9. compressed_tensors/compressors/sparse_bitmask.py +239 -0
  10. compressed_tensors/config/__init__.py +18 -0
  11. compressed_tensors/config/base.py +43 -0
  12. compressed_tensors/config/dense.py +36 -0
  13. compressed_tensors/config/sparse_bitmask.py +36 -0
  14. compressed_tensors/quantization/__init__.py +21 -0
  15. compressed_tensors/quantization/lifecycle/__init__.py +23 -0
  16. compressed_tensors/quantization/lifecycle/apply.py +196 -0
  17. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  18. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  19. compressed_tensors/quantization/lifecycle/forward.py +333 -0
  20. compressed_tensors/quantization/lifecycle/frozen.py +50 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +99 -0
  22. compressed_tensors/quantization/observers/__init__.py +21 -0
  23. compressed_tensors/quantization/observers/base.py +130 -0
  24. compressed_tensors/quantization/observers/helpers.py +54 -0
  25. compressed_tensors/quantization/observers/memoryless.py +48 -0
  26. compressed_tensors/quantization/observers/min_max.py +80 -0
  27. compressed_tensors/quantization/quant_args.py +125 -0
  28. compressed_tensors/quantization/quant_config.py +210 -0
  29. compressed_tensors/quantization/quant_scheme.py +39 -0
  30. compressed_tensors/quantization/utils/__init__.py +16 -0
  31. compressed_tensors/quantization/utils/helpers.py +131 -0
  32. compressed_tensors/registry/__init__.py +17 -0
  33. compressed_tensors/registry/registry.py +360 -0
  34. compressed_tensors/utils/__init__.py +16 -0
  35. compressed_tensors/utils/helpers.py +45 -0
  36. compressed_tensors/utils/safetensors_load.py +237 -0
  37. compressed_tensors/version.py +50 -0
  38. compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
  39. compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
  40. compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
  41. compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
  42. compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
compressed_tensors/quantization/observers/base.py
@@ -0,0 +1,130 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import torch
+ from compressed_tensors.quantization.quant_args import (
+     QuantizationArgs,
+     QuantizationStrategy,
+ )
+ from compressed_tensors.registry.registry import RegistryMixin
+ from torch import FloatTensor, IntTensor, Tensor
+ from torch.nn import Module
+
+
+ __all__ = ["Observer"]
+
+
+ class Observer(Module, RegistryMixin):
+     """
+     Base Observer class to be subclassed for specific implementations.
+     Subclasses should override calculate_qparams to return a scale, zero_point
+     pair
+     """
+
+     def __init__(self, quantization_args: QuantizationArgs):
+         self.quantization_args: QuantizationArgs = quantization_args
+         super().__init__()
+         self._scale = None
+         self._zero_point = None
+
+     @torch.no_grad()
+     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+         """
+         Maps directly to get_qparams
+
+         :param observed: observed tensor to calculate quantization parameters from
+         :return: tuple of scale and zero point based on last observed value
+         """
+         return self.get_qparams(observed=observed)
+
+     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+         """
+         :param observed: observed tensor to calculate quantization parameters for
+         :return: tuple of scale and zero point derived from the observed tensor
+         """
+         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
+
+     def post_calculate_qparams(self) -> None:
+         """
+         Run any observer-specific logic after calculate_qparams has run
+         """
+         ...
+
+     def get_qparams(
+         self, observed: Optional[Tensor] = None
+     ) -> Tuple[FloatTensor, IntTensor]:
+         """
+         Convenience function wrapping the overridden calculate_qparams; makes the
+         observed tensor optional and tracks the latest calculated scale and
+         zero point
+
+         :param observed: optional observed tensor to calculate quantization
+             parameters from
+         :return: tuple of scale and zero point based on last observed value
+         """
+         if observed is not None:
+             group_size = self.quantization_args.group_size
+
+             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
+                 # re-calculate scale and zero point, update the stored values
+                 self._scale, self._zero_point = self.calculate_qparams(observed)
+
+             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                 columns = observed.shape[1]
+                 scales, zero_points = [], []
+                 for i in range(0, columns, group_size):
+                     scale, zero_point = self.get_qparams_along_dim(
+                         observed[:, i : (i + group_size)],
+                         0,
+                     )
+                     scales.append(scale)
+                     zero_points.append(zero_point)
+                 self._scale = torch.stack(scales, dim=1, out=self._scale)
+                 self._zero_point = torch.stack(
+                     zero_points, dim=1, out=self._zero_point
+                 )
+
+             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+                 # assume observed is transposed because it is the output, hence dim 0
+                 self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+
+             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
+                 # use dim 1, assuming observed.shape = [batch, token, hidden];
+                 # scale and zero point will be shaped [batch, token]
+                 self._scale, self._zero_point = self.get_qparams_along_dim(
+                     observed, dim=1
+                 )
+
+         return self._scale, self._zero_point
+
+     def get_qparams_along_dim(self, observed, dim: int):
+         # TODO: add documentation specifying that the shape must be padded with
+         # 1-dims so the scales land along the right channel
+         # TODO: generalize the logic for reduce_dims
+         scales, zero_points = [], []
+
+         # TODO: make a more generic way to get the channel
+         num_dims = observed.shape[dim]
+
+         for dim_idx in range(num_dims):
+             scale, zero_point = self.calculate_qparams(
+                 observed.select(dim=dim, index=dim_idx)
+             )
+
+             scales.append(scale)
+             zero_points.append(zero_point)
+
+         return torch.stack(scales), torch.stack(zero_points)
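For orientation, here is a minimal sketch (not part of the package) of how a concrete observer plugs into this base class: RegistryMixin supplies the register decorator and load_from_registry, and a subclass only has to implement calculate_qparams. The observer name "absmax" and the class itself are hypothetical; the calculate_qparams helper comes from the next file.

    import torch
    from compressed_tensors.quantization.observers.base import Observer
    from compressed_tensors.quantization.observers.helpers import calculate_qparams
    from compressed_tensors.quantization.quant_args import QuantizationArgs


    @Observer.register("absmax")  # hypothetical name, alongside "minmax"/"memoryless"
    class AbsMaxObserver(Observer):
        """Toy observer: derives qparams from the absolute max of each batch"""

        def calculate_qparams(self, observed):
            abs_max = observed.abs().amax()
            # symmetric range around zero; the helper clamps the scale above eps
            return calculate_qparams(-abs_max, abs_max, self.quantization_args)


    args = QuantizationArgs(num_bits=8, symmetric=True)
    observer = Observer.load_from_registry("absmax", quantization_args=args)
    scale, zero_point = observer(torch.randn(16, 32))  # forward -> get_qparams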
compressed_tensors/quantization/observers/helpers.py
@@ -0,0 +1,54 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from torch import FloatTensor, IntTensor, Tensor
+
+
+ __all__ = ["calculate_qparams"]
+
+
+ def calculate_qparams(
+     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
+ ) -> Tuple[FloatTensor, IntTensor]:
+     """
+     :param min_vals: tensor of min value(s) to calculate scale(s) and
+         zero point(s) from
+     :param max_vals: tensor of max value(s) to calculate scale(s) and
+         zero point(s) from
+     :param quantization_args: settings for quantization
+     :return: tuple of the calculated scale(s) and zero point(s)
+     """
+     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+     device = min_vals.device
+
+     bit_range = 2**quantization_args.num_bits - 1
+     bit_min = -(bit_range + 1) / 2
+     bit_max = bit_min + bit_range
+     if quantization_args.symmetric:
+         max_val_pos = torch.max(-min_vals, max_vals)
+         scales = max_val_pos / (float(bit_range) / 2)
+         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+         zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
+     else:
+         scales = (max_vals - min_vals) / float(bit_range)
+         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+         zero_points = bit_min - torch.round(min_vals / scales)
+         zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
+
+     return scales, zero_points
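To make the arithmetic concrete, a quick sketch of calling the helper directly (numbers worked out from the formulas above; for num_bits=8, bit_range = 255, bit_min = -128, bit_max = 127):

    import torch
    from compressed_tensors.quantization.observers.helpers import calculate_qparams
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    # symmetric: max(|min|, |max|) = 2.0 -> scale = 2.0 / 127.5 ~= 0.0157, zp = 0
    scale, zp = calculate_qparams(
        torch.tensor(-2.0), torch.tensor(2.0), QuantizationArgs(symmetric=True)
    )

    # asymmetric: scale = (3.0 - (-1.0)) / 255 ~= 0.0157,
    # zp = -128 - round(-1.0 / scale) = -64
    scale, zp = calculate_qparams(
        torch.tensor(-1.0), torch.tensor(3.0), QuantizationArgs(symmetric=False)
    )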
compressed_tensors/quantization/observers/memoryless.py
@@ -0,0 +1,48 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from compressed_tensors.quantization.observers.base import Observer
+ from compressed_tensors.quantization.observers.helpers import calculate_qparams
+ from torch import FloatTensor, IntTensor, Tensor
+
+
+ __all__ = ["MemorylessObserver"]
+
+
+ @Observer.register("memoryless", alias=["dynamic"])
+ class MemorylessObserver(Observer):
+     """
+     Implements a quantization observer that sets the scale and
+     zero point based on the latest observed value without tracking state
+     """
+
+     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+         """
+         Calculates the scale and zero point from the min and max values of observed
+
+         :param observed: observed tensor to calculate quantization parameters for
+         :return: tuple of scale and zero point derived from the observed tensor
+         """
+         # TODO: add support for the full range of QuantizationArgs; only 8-bit
+         # per-tensor quantization is supported for now
+         min_val, max_val = torch.aminmax(observed)
+
+         # ensure zero is in the range
+         min_val = torch.min(min_val, torch.zeros_like(min_val))
+         max_val = torch.max(max_val, torch.zeros_like(max_val))
+
+         return calculate_qparams(min_val, max_val, self.quantization_args)
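Because nothing is tracked between calls, each forward pass derives qparams from that batch alone; a brief sketch (constructed directly here rather than through the registry):

    import torch
    from compressed_tensors.quantization.observers.memoryless import MemorylessObserver
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    obs = MemorylessObserver(QuantizationArgs(num_bits=8, symmetric=True))
    scale_a, _ = obs(torch.randn(8, 8))       # qparams from this batch only
    scale_b, _ = obs(torch.randn(8, 8) * 10)  # no memory of the first batch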
compressed_tensors/quantization/observers/min_max.py
@@ -0,0 +1,80 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import torch
+ from compressed_tensors.quantization.observers.base import Observer
+ from compressed_tensors.quantization.observers.helpers import calculate_qparams
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from torch import FloatTensor, IntTensor, Tensor
+
+
+ __all__ = ["MovingAverageMinMaxObserver"]
+
+
+ @Observer.register("minmax")
+ class MovingAverageMinMaxObserver(Observer):
+     """
+     Implements a quantization observer that sets the scale and zero point based
+     on a moving average of the overall min and max observed values
+     """
+
+     def __init__(
+         self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
+     ):
+         super().__init__(quantization_args=quantization_args)
+
+         self.min_val = None
+         self.max_val = None
+         self.averaging_constant = averaging_constant
+
+     def calculate_qparams(
+         self,
+         observed: Tensor,
+         reduce_dims: Optional[Tuple[int]] = None,
+     ) -> Tuple[FloatTensor, IntTensor]:
+         """
+         Updates the observed min and max using a moving average smoothed by the
+         averaging_constant
+
+         :param observed: observed tensor to calculate quantization parameters for
+         :param reduce_dims: optional tuple of dimensions to reduce along;
+             the returned scale and zero point will be shaped (1,) along the
+             reduced dimensions
+         :return: tuple of scale and zero point derived from the observed tensor
+         """
+         if not reduce_dims:
+             min_val, max_val = torch.aminmax(observed)
+         else:
+             min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
+             max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)
+
+         if self.min_val is None and self.max_val is None:
+             self.min_val = min_val
+             self.max_val = max_val
+         else:
+             self.min_val = self.min_val + self.averaging_constant * (
+                 min_val - self.min_val
+             )
+             self.max_val = self.max_val + self.averaging_constant * (
+                 max_val - self.max_val
+             )
+
+         return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
+
+     def get_qparams_along_dim(self, observed, dim: int):
+         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+         return self.calculate_qparams(observed, reduce_dims=reduce_dims)
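The exponential moving average means each new batch nudges the tracked range by averaging_constant toward its own extremes; a sketch of a calibration-style loop with made-up shapes:

    import torch
    from compressed_tensors.quantization.observers.min_max import (
        MovingAverageMinMaxObserver,
    )
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    obs = MovingAverageMinMaxObserver(
        QuantizationArgs(num_bits=8, symmetric=True), averaging_constant=0.01
    )
    for _ in range(10):  # e.g. ten calibration batches
        scale, zero_point = obs(torch.randn(64, 128))
    # min_val/max_val move 1% of the way toward each batch's observed extremes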
compressed_tensors/quantization/quant_args.py
@@ -0,0 +1,125 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from enum import Enum
+ from typing import Any, Dict, Optional
+
+ from pydantic import BaseModel, Field, validator
+
+
+ __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
+
+
+ class QuantizationType(str, Enum):
+     """
+     Enum storing quantization type options
+     """
+
+     INT = "int"
+     FLOAT = "float"
+
+
+ class QuantizationStrategy(str, Enum):
+     """
+     Enum storing quantization strategy options
+     """
+
+     TENSOR = "tensor"
+     CHANNEL = "channel"
+     GROUP = "group"
+     BLOCK = "block"
+     TOKEN = "token"
+
+
+ class QuantizationArgs(BaseModel):
+     """
+     User-facing arguments used to define a quantization config for weights or
+     activations
+
+     :param num_bits: quantization bit depth
+     :param type: dtype to quantize to, either int or float
+     :param symmetric: whether or not the quantization scale is symmetric about
+         the zero point
+     :param strategy: string id determining the scope of scale/zero-point to apply
+     :param group_size: group length to use for the group strategy
+     :param block_structure: 2d block structure to use for the block strategy;
+         must be of the format "2x4", "8x16", etc.
+     :param dynamic: set True to perform dynamic quantization - values will not be
+         calibrated during a calibration phase; instead, new quantization ranges
+         will be observed with every sample during inference. Defaults to False
+         for static quantization. Note that enabling dynamic quantization changes
+         the default observer to a memoryless one
+     """
+
+     num_bits: int = 8
+     type: QuantizationType = QuantizationType.INT
+     symmetric: bool = True
+     group_size: Optional[int] = None
+     strategy: Optional[QuantizationStrategy] = None
+     block_structure: Optional[str] = None
+     dynamic: bool = False
+     observer: str = Field(
+         default="minmax",
+         description=(
+             "The class to use to compute the quantization params - "
+             "scale and zero point"
+         ),
+     )
+     observer_kwargs: Dict[str, Any] = Field(
+         default_factory=dict,
+         description=(
+             "optional dict of kwargs to be passed directly to the torch "
+             "quantization Observer's constructor, excluding quantization range "
+             "or symmetry"
+         ),
+     )
+
+     def get_observer(self):
+         """
+         :return: torch quantization FakeQuantize built based on these QuantizationArgs
+         """
+         from compressed_tensors.quantization.observers.base import Observer
+
+         if self.observer == "minmax" and self.dynamic:
+             # override the default observer for dynamic quantization; minmax
+             # keeps state across samples, which is never wanted for dynamic
+             self.observer = "memoryless"
+
+         return Observer.load_from_registry(self.observer, quantization_args=self)
+
+     @validator("strategy", pre=True, always=True)
+     def validate_strategy(cls, value, values):
+         group_size = values.get("group_size")
+
+         # use group_size to determine strategy if not given explicitly
+         if group_size is not None and value is None:
+             if group_size > 0:
+                 return QuantizationStrategy.GROUP
+
+             elif group_size == -1:
+                 return QuantizationStrategy.CHANNEL
+
+             else:
+                 raise ValueError(
+                     f"group_size={group_size} with strategy {value} is invalid. "
+                     "Use group_size > 0 for strategy='group' and "
+                     "group_size = -1 for 'channel'"
+                 )
+
+         if value == QuantizationStrategy.GROUP:
+             if group_size is None:
+                 raise ValueError(f"strategy {value} requires group_size to be set.")
+
+         if value is None:
+             return QuantizationStrategy.TENSOR
+
+         return value
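A sketch of the strategy inference the validator performs, plus the dynamic-observer override in get_observer:

    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    assert QuantizationArgs().strategy == QuantizationStrategy.TENSOR
    assert QuantizationArgs(group_size=128).strategy == QuantizationStrategy.GROUP
    assert QuantizationArgs(group_size=-1).strategy == QuantizationStrategy.CHANNEL

    # dynamic=True swaps the default "minmax" observer for "memoryless"
    observer = QuantizationArgs(dynamic=True).get_observer()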
compressed_tensors/quantization/quant_config.py
@@ -0,0 +1,210 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from enum import Enum
+ from typing import Dict, List, Optional
+
+ from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.quantization.utils import (
+     calculate_compression_ratio,
+     is_module_quantized,
+     iter_named_leaf_modules,
+     module_type,
+ )
+ from pydantic import BaseModel, Field
+ from torch.nn import Module
+ from transformers import AutoConfig
+
+
+ __all__ = [
+     "QuantizationStatus",
+     "QuantizationConfig",
+     "LIFECYCLE_ORDER",
+ ]
+
+
+ class QuantizationStatus(str, Enum):
+     """
+     Enum storing the different states a quantized layer can be in
+
+     Initialized: scale, zero points and observers have been attached to the
+         layer but are set to dummy values (not yet calibrated)
+     Calibration: scale and zero points have been calibrated through OBCQ or a
+         similar algorithm; observers are still attached
+     Frozen: scale and zero points are finalized, observers have been deleted,
+         weights are still in their original precision
+     Compressed: weights have been converted to their target type or compressed
+         to their closest approximation
+     """
+
+     INITIALIZED = "initialized"
+     CALIBRATION = "calibration"
+     FROZEN = "frozen"
+     COMPRESSED = "compressed"
+
+     @classmethod
+     def lifecycle_order(cls) -> List["QuantizationStatus"]:
+         """
+         :return: list of the correct quantization lifecycle order
+         """
+         return LIFECYCLE_ORDER
+
+     def __ge__(self, other):
+         if other is None:
+             return True
+         if not isinstance(other, self.__class__):
+             raise NotImplementedError
+         return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)
+
+     def __gt__(self, other):
+         if other is None:
+             return True
+         if not isinstance(other, self.__class__):
+             raise NotImplementedError
+         return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)
+
+     def __lt__(self, other):
+         if other is None:
+             return False
+         if not isinstance(other, self.__class__):
+             raise NotImplementedError
+         return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)
+
+     def __le__(self, other):
+         if other is None:
+             return False
+         if not isinstance(other, self.__class__):
+             raise NotImplementedError
+         return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)
+
+
+ LIFECYCLE_ORDER = [
+     QuantizationStatus.INITIALIZED,
+     QuantizationStatus.CALIBRATION,
+     QuantizationStatus.FROZEN,
+     QuantizationStatus.COMPRESSED,
+ ]
+
+
+ class QuantizationConfig(BaseModel):
+     """
+     Full configuration specifying how a model is quantized. Each quantized layer is
+     mapped to a QuantizationScheme in config_groups.
+
+     :param config_groups: dict of QuantizationSchemes specifying the quantization
+         settings for each quantized layer
+     :param quant_method: a constant used to differentiate SparseML quantization
+         from other quantization configs
+     :param format: specifies how the quantized model is stored on disk
+     :param quantization_status: specifies the current status of all quantized
+         layers. It is assumed all layers are in the same state.
+     :param global_compression_ratio: optional informational config to report the
+         model compression ratio achieved by the quantization config
+     :param ignore: optional list of layers to ignore from config_groups. Layers
+         in this list are not quantized even if they match up with a target in
+         config_groups
+     """
+
+     config_groups: Dict[str, QuantizationScheme]
+     quant_method: str = "sparseml"
+     format: str = "fakequant"
+     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
+     global_compression_ratio: Optional[float] = None
+     ignore: Optional[List[str]] = Field(default_factory=list)
+
+     @staticmethod
+     def from_model_config(model_name_or_path) -> "QuantizationConfig":
+         """
+         Given a path to a model config, extract a quantization config if it exists
+
+         :param model_name_or_path: path to model config on disk or HF hub
+         :return: instantiated QuantizationConfig if config contains a quant config
+         """
+         config = AutoConfig.from_pretrained(model_name_or_path)
+         quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+         if quantization_config is None:
+             return None
+
+         return QuantizationConfig.parse_obj(quantization_config)
+
+     @staticmethod
+     def from_pretrained(
+         model: Module, format: Optional[str] = None
+     ) -> Optional["QuantizationConfig"]:
+         """
+         Converts a model into its associated QuantizationConfig based on the
+         QuantizationScheme attached to each quantized module
+
+         :param model: model to calculate the quantization scheme of
+         :return: filled-out QuantizationConfig for the input model
+         """
+         quant_scheme_to_layers = []
+         quantization_status = None
+         ignore = {}
+         quantization_type_names = set()
+         for name, submodule in iter_named_leaf_modules(model):
+             layer_type = module_type(submodule)
+             if not is_module_quantized(submodule):
+                 if layer_type not in ignore:
+                     ignore[layer_type] = []
+                 ignore[layer_type].append(name)
+             else:
+                 quantization_status = submodule.quantization_status
+                 scheme = submodule.quantization_scheme
+                 quantization_type_names.add(layer_type)
+
+                 match_found = False
+                 for existing_scheme in quant_scheme_to_layers:
+                     if scheme == existing_scheme:
+                         match_found = True
+                         break
+                 if not match_found:
+                     quant_scheme_to_layers.append(scheme)
+
+         if len(quant_scheme_to_layers) == 0:  # no quantized layers
+             return None
+
+         # clean up the ignore list; layer types can be left out if none of their
+         # instances are quantized
+         consolidated_ignore = []
+         for layer_type, ignore_names in ignore.items():
+             if layer_type in quantization_type_names:
+                 # specific layers of a quantized type are ignored
+                 consolidated_ignore += ignore_names
+             # else leave it off the ignore list - it doesn't fall under any of
+             # the existing quantization schemes, so it won't be quantized
+
+         config_groups = {}
+         for idx, scheme in enumerate(quant_scheme_to_layers):
+             group_name = "group_" + str(idx)
+             config_groups[group_name] = scheme
+
+         # TODO: this is incorrect in compressed mode; since the original weight
+         # is overwritten, the uncompressed bit_depth info is lost
+         compression_ratio = calculate_compression_ratio(model)
+
+         if format is None:
+             if quantization_status == QuantizationStatus.COMPRESSED:
+                 format = CompressionFormat.int_quantized.value
+             else:
+                 format = CompressionFormat.dense.value
+
+         return QuantizationConfig(
+             config_groups=config_groups,
+             quantization_status=quantization_status,
+             global_compression_ratio=compression_ratio,
+             format=format,
+             ignore=consolidated_ignore,
+         )
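A sketch of the two entry points: lifecycle statuses compare according to LIFECYCLE_ORDER, and from_model_config pulls a stored config off a HF model config (the path below is a placeholder):

    from compressed_tensors.quantization.quant_config import (
        QuantizationConfig,
        QuantizationStatus,
    )

    assert QuantizationStatus.CALIBRATION < QuantizationStatus.FROZEN
    assert QuantizationStatus.COMPRESSED >= QuantizationStatus.INITIALIZED

    # returns None when the config carries no quantization_config entry
    config = QuantizationConfig.from_model_config("path/to/quantized-model")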
compressed_tensors/quantization/quant_scheme.py
@@ -0,0 +1,39 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional
+
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from pydantic import BaseModel
+
+
+ __all__ = ["QuantizationScheme"]
+
+
+ class QuantizationScheme(BaseModel):
+     """
+     Set of QuantizationArgs defining how the weights, inputs and outputs of a
+     target list of modules should be quantized
+
+     :param targets: list of modules to apply the QuantizationArgs to; can be
+         layer names, layer types or a regular expression
+     :param weights: quantization config for layer weights
+     :param input_activations: quantization config for layer inputs
+     :param output_activations: quantization config for layer outputs
+     """
+
+     targets: List[str]
+     weights: Optional[QuantizationArgs] = None
+     input_activations: Optional[QuantizationArgs] = None
+     output_activations: Optional[QuantizationArgs] = None
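Finally, a sketch of a scheme targeting all Linear layers with 4-bit grouped weights and 8-bit dynamic input activations (the particular values are illustrative):

    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme

    scheme = QuantizationScheme(
        targets=["Linear"],  # layer type; names or regexes also work
        weights=QuantizationArgs(num_bits=4, symmetric=True, group_size=128),
        input_activations=QuantizationArgs(num_bits=8, dynamic=True),
    )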