compressed-tensors 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +2 -1
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +11 -54
- compressed_tensors/compressors/dense.py +4 -4
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/int_quantized.py +126 -0
- compressed_tensors/compressors/marlin_24.py +250 -0
- compressed_tensors/compressors/model_compressor.py +315 -0
- compressed_tensors/compressors/pack_quantized.py +212 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/compressors/utils/__init__.py +19 -0
- compressed_tensors/compressors/utils/helpers.py +43 -0
- compressed_tensors/compressors/utils/permutations_24.py +65 -0
- compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/config/base.py +7 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +75 -19
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +208 -22
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/initialize.py +33 -5
- compressed_tensors/quantization/observers/base.py +70 -5
- compressed_tensors/quantization/observers/helpers.py +6 -1
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +33 -4
- compressed_tensors/quantization/quant_config.py +69 -21
- compressed_tensors/quantization/quant_scheme.py +81 -1
- compressed_tensors/quantization/utils/helpers.py +77 -8
- compressed_tensors/utils/helpers.py +26 -122
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -9
- compressed_tensors-0.4.0.dist-info/RECORD +48 -0
- compressed_tensors-0.3.2.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
|
|
15
15
|
from enum import Enum
|
16
16
|
from typing import Any, Dict, Optional
|
17
17
|
|
18
|
-
from pydantic import BaseModel, Field
|
18
|
+
from pydantic import BaseModel, Field, validator
|
19
19
|
|
20
20
|
|
21
21
|
__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
|
@@ -39,9 +39,10 @@ class QuantizationStrategy(str, Enum):
|
|
39
39
|
CHANNEL = "channel"
|
40
40
|
GROUP = "group"
|
41
41
|
BLOCK = "block"
|
42
|
+
TOKEN = "token"
|
42
43
|
|
43
44
|
|
44
|
-
class QuantizationArgs(BaseModel):
|
45
|
+
class QuantizationArgs(BaseModel, use_enum_values=True):
|
45
46
|
"""
|
46
47
|
User facing arguments used to define a quantization config for weights or
|
47
48
|
activations
|
@@ -61,10 +62,10 @@ class QuantizationArgs(BaseModel):
|
|
61
62
|
"""
|
62
63
|
|
63
64
|
num_bits: int = 8
|
64
|
-
type: QuantizationType = QuantizationType.INT
|
65
|
+
type: QuantizationType = QuantizationType.INT.value
|
65
66
|
symmetric: bool = True
|
66
|
-
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
|
67
67
|
group_size: Optional[int] = None
|
68
|
+
strategy: Optional[QuantizationStrategy] = None
|
68
69
|
block_structure: Optional[str] = None
|
69
70
|
dynamic: bool = False
|
70
71
|
observer: str = Field(
|
@@ -94,3 +95,31 @@ class QuantizationArgs(BaseModel):
|
|
94
95
|
self.observer = "memoryless"
|
95
96
|
|
96
97
|
return Observer.load_from_registry(self.observer, quantization_args=self)
|
98
|
+
|
99
|
+
@validator("strategy", pre=True, always=True)
|
100
|
+
def validate_strategy(cls, value, values):
|
101
|
+
group_size = values.get("group_size")
|
102
|
+
|
103
|
+
# use group_size to determinine strategy if not given explicity
|
104
|
+
if group_size is not None and value is None:
|
105
|
+
if group_size > 0:
|
106
|
+
return QuantizationStrategy.GROUP
|
107
|
+
|
108
|
+
elif group_size == -1:
|
109
|
+
return QuantizationStrategy.CHANNEL
|
110
|
+
|
111
|
+
else:
|
112
|
+
raise ValueError(
|
113
|
+
f"group_size={group_size} with strategy {value} is invald. "
|
114
|
+
"group_size > 0 for strategy='group' and "
|
115
|
+
"group_size = -1 for 'channel'"
|
116
|
+
)
|
117
|
+
|
118
|
+
if value == QuantizationStrategy.GROUP:
|
119
|
+
if group_size is None:
|
120
|
+
raise ValueError(f"strategy {value} requires group_size to be set.")
|
121
|
+
|
122
|
+
if value is None:
|
123
|
+
return QuantizationStrategy.TENSOR
|
124
|
+
|
125
|
+
return value
|
@@ -13,10 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from enum import Enum
|
16
|
-
from typing import Dict, List, Optional
|
16
|
+
from typing import Dict, List, Optional, Union
|
17
17
|
|
18
|
-
from compressed_tensors.
|
19
|
-
from compressed_tensors.quantization.quant_scheme import
|
18
|
+
from compressed_tensors.config import CompressionFormat
|
19
|
+
from compressed_tensors.quantization.quant_scheme import (
|
20
|
+
QuantizationScheme,
|
21
|
+
preset_name_to_scheme,
|
22
|
+
)
|
20
23
|
from compressed_tensors.quantization.utils import (
|
21
24
|
calculate_compression_ratio,
|
22
25
|
is_module_quantized,
|
@@ -25,13 +28,14 @@ from compressed_tensors.quantization.utils import (
|
|
25
28
|
)
|
26
29
|
from pydantic import BaseModel, Field
|
27
30
|
from torch.nn import Module
|
28
|
-
from transformers import AutoConfig
|
29
31
|
|
30
32
|
|
31
33
|
__all__ = [
|
32
34
|
"QuantizationStatus",
|
33
35
|
"QuantizationConfig",
|
34
36
|
"LIFECYCLE_ORDER",
|
37
|
+
"DEFAULT_QUANTIZATION_METHOD",
|
38
|
+
"DEFAULT_QUANTIZATION_FORMAT",
|
35
39
|
]
|
36
40
|
|
37
41
|
|
@@ -62,10 +66,33 @@ class QuantizationStatus(str, Enum):
|
|
62
66
|
return
|
63
67
|
|
64
68
|
def __ge__(self, other):
|
69
|
+
if other is None:
|
70
|
+
return True
|
65
71
|
if not isinstance(other, self.__class__):
|
66
72
|
raise NotImplementedError
|
67
73
|
return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)
|
68
74
|
|
75
|
+
def __gt__(self, other):
|
76
|
+
if other is None:
|
77
|
+
return True
|
78
|
+
if not isinstance(other, self.__class__):
|
79
|
+
raise NotImplementedError
|
80
|
+
return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)
|
81
|
+
|
82
|
+
def __lt__(self, other):
|
83
|
+
if other is None:
|
84
|
+
return False
|
85
|
+
if not isinstance(other, self.__class__):
|
86
|
+
raise NotImplementedError
|
87
|
+
return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)
|
88
|
+
|
89
|
+
def __le__(self, other):
|
90
|
+
if other is None:
|
91
|
+
return False
|
92
|
+
if not isinstance(other, self.__class__):
|
93
|
+
raise NotImplementedError
|
94
|
+
return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)
|
95
|
+
|
69
96
|
|
70
97
|
LIFECYCLE_ORDER = [
|
71
98
|
QuantizationStatus.INITIALIZED,
|
@@ -74,6 +101,9 @@ LIFECYCLE_ORDER = [
|
|
74
101
|
QuantizationStatus.COMPRESSED,
|
75
102
|
]
|
76
103
|
|
104
|
+
DEFAULT_QUANTIZATION_METHOD = "compressed-tensors"
|
105
|
+
DEFAULT_QUANTIZATION_FORMAT = "fakequant"
|
106
|
+
|
77
107
|
|
78
108
|
class QuantizationConfig(BaseModel):
|
79
109
|
"""
|
@@ -81,7 +111,8 @@ class QuantizationConfig(BaseModel):
|
|
81
111
|
mapped to a QuantizationScheme in config_groups.
|
82
112
|
|
83
113
|
:param config_groups: dict of QuantizationSchemes specifying the quantization
|
84
|
-
settings for each quantized layer
|
114
|
+
settings for each quantized layer. A group could also be a reference to
|
115
|
+
a predefined scheme name, mapped to a list of its target layers/classes
|
85
116
|
:param quant_method: a constant used to differentiate sparseML quantization from
|
86
117
|
other quantization configs
|
87
118
|
:param format: specifies how the quantized model is stored on disk
|
@@ -93,30 +124,34 @@ class QuantizationConfig(BaseModel):
|
|
93
124
|
are not quantized even if they match up with a target in config_groups
|
94
125
|
"""
|
95
126
|
|
96
|
-
config_groups: Dict[str, QuantizationScheme]
|
97
|
-
quant_method: str =
|
98
|
-
format: str =
|
127
|
+
config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
|
128
|
+
quant_method: str = DEFAULT_QUANTIZATION_METHOD
|
129
|
+
format: str = DEFAULT_QUANTIZATION_FORMAT
|
99
130
|
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
|
100
131
|
global_compression_ratio: Optional[float] = None
|
101
132
|
ignore: Optional[List[str]] = Field(default_factory=list)
|
102
133
|
|
103
|
-
|
104
|
-
def from_model_config(model_name_or_path) -> "QuantizationConfig":
|
134
|
+
def model_post_init(self, __context):
|
105
135
|
"""
|
106
|
-
|
107
|
-
|
108
|
-
:param pretrained_model_name_or_path: path to model config on disk or HF hub
|
109
|
-
:return: instantiated QuantizationConfig if config contains a quant config
|
136
|
+
updates any quantization schemes defined as presets to be fully loaded
|
137
|
+
schemes
|
110
138
|
"""
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
139
|
+
for group_name, targets_or_scheme in self.config_groups.items():
|
140
|
+
if isinstance(targets_or_scheme, QuantizationScheme):
|
141
|
+
continue # scheme already defined
|
142
|
+
self.config_groups[group_name] = preset_name_to_scheme(
|
143
|
+
name=group_name,
|
144
|
+
targets=targets_or_scheme,
|
145
|
+
)
|
146
|
+
|
147
|
+
def to_dict(self):
|
148
|
+
# for compatibility with HFQuantizer
|
149
|
+
return self.dict()
|
117
150
|
|
118
151
|
@staticmethod
|
119
|
-
def from_pretrained(
|
152
|
+
def from_pretrained(
|
153
|
+
model: Module, format: Optional[str] = None
|
154
|
+
) -> Optional["QuantizationConfig"]:
|
120
155
|
"""
|
121
156
|
Converts a model into its associated QuantizationConfig based on the
|
122
157
|
QuantizationScheme attached to each quanitzed module
|
@@ -147,6 +182,9 @@ class QuantizationConfig(BaseModel):
|
|
147
182
|
if not match_found:
|
148
183
|
quant_scheme_to_layers.append(scheme)
|
149
184
|
|
185
|
+
if len(quant_scheme_to_layers) == 0: # No quantized layers
|
186
|
+
return None
|
187
|
+
|
150
188
|
# clean up ignore list, we can leave out layers types if none of the
|
151
189
|
# instances are quantized
|
152
190
|
consolidated_ignore = []
|
@@ -162,10 +200,20 @@ class QuantizationConfig(BaseModel):
|
|
162
200
|
group_name = "group_" + str(idx)
|
163
201
|
config_groups[group_name] = scheme
|
164
202
|
|
203
|
+
# TODO: this is incorrect in compressed mode, since we are overwriting the
|
204
|
+
# original weight we lose the uncompressed bit_depth indo
|
165
205
|
compression_ratio = calculate_compression_ratio(model)
|
206
|
+
|
207
|
+
if format is None:
|
208
|
+
if quantization_status == QuantizationStatus.COMPRESSED:
|
209
|
+
format = CompressionFormat.int_quantized.value
|
210
|
+
else:
|
211
|
+
format = CompressionFormat.dense.value
|
212
|
+
|
166
213
|
return QuantizationConfig(
|
167
214
|
config_groups=config_groups,
|
168
215
|
quantization_status=quantization_status,
|
169
216
|
global_compression_ratio=compression_ratio,
|
217
|
+
format=format,
|
170
218
|
ignore=consolidated_ignore,
|
171
219
|
)
|
@@ -12,13 +12,18 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from copy import deepcopy
|
15
16
|
from typing import List, Optional
|
16
17
|
|
17
18
|
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
18
19
|
from pydantic import BaseModel
|
19
20
|
|
20
21
|
|
21
|
-
__all__ = [
|
22
|
+
__all__ = [
|
23
|
+
"QuantizationScheme",
|
24
|
+
"preset_name_to_scheme",
|
25
|
+
"is_preset_scheme",
|
26
|
+
]
|
22
27
|
|
23
28
|
|
24
29
|
class QuantizationScheme(BaseModel):
|
@@ -37,3 +42,78 @@ class QuantizationScheme(BaseModel):
|
|
37
42
|
weights: Optional[QuantizationArgs] = None
|
38
43
|
input_activations: Optional[QuantizationArgs] = None
|
39
44
|
output_activations: Optional[QuantizationArgs] = None
|
45
|
+
|
46
|
+
@classmethod
|
47
|
+
def default_scheme(
|
48
|
+
cls,
|
49
|
+
targets: Optional[List[str]] = None,
|
50
|
+
):
|
51
|
+
|
52
|
+
if targets is None:
|
53
|
+
# default to quantizing all Linear layers
|
54
|
+
targets = ["Linear"]
|
55
|
+
|
56
|
+
# default to 8 bit integer symmetric quantization
|
57
|
+
# for weights
|
58
|
+
weights = QuantizationArgs(num_bits=8, symmetric=True)
|
59
|
+
|
60
|
+
# default to 8 bit integer asymmetric quantization
|
61
|
+
input_activations = QuantizationArgs(num_bits=8, symmetric=True)
|
62
|
+
|
63
|
+
# Do not quantize the output activations
|
64
|
+
# by default
|
65
|
+
output_activations = None
|
66
|
+
|
67
|
+
return cls(
|
68
|
+
targets=targets,
|
69
|
+
weights=weights,
|
70
|
+
input_activations=input_activations,
|
71
|
+
output_activations=output_activations,
|
72
|
+
)
|
73
|
+
|
74
|
+
|
75
|
+
"""
|
76
|
+
Pre-Set Quantization Scheme Args
|
77
|
+
"""
|
78
|
+
|
79
|
+
|
80
|
+
def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
|
81
|
+
"""
|
82
|
+
:param name: preset quantization settings name. must exist in upper case in
|
83
|
+
PRESET_SCHEMES
|
84
|
+
:param targets: list of quantization targets to be passed to the Scheme
|
85
|
+
:return: new QuantizationScheme for a given name with the given targets
|
86
|
+
"""
|
87
|
+
name = name.upper()
|
88
|
+
|
89
|
+
if name not in PRESET_SCHEMES:
|
90
|
+
raise KeyError(
|
91
|
+
f"Unknown preset scheme name {name}, "
|
92
|
+
f"available names: {list(PRESET_SCHEMES.keys())}"
|
93
|
+
)
|
94
|
+
|
95
|
+
scheme_args = deepcopy(PRESET_SCHEMES[name]) # deepcopy to avoid args references
|
96
|
+
return QuantizationScheme(
|
97
|
+
targets=targets,
|
98
|
+
**scheme_args,
|
99
|
+
)
|
100
|
+
|
101
|
+
|
102
|
+
def is_preset_scheme(name: str) -> bool:
|
103
|
+
"""
|
104
|
+
:param name: preset quantization settings name
|
105
|
+
:return: True if the name is a preset scheme name
|
106
|
+
"""
|
107
|
+
return name.upper() in PRESET_SCHEMES
|
108
|
+
|
109
|
+
|
110
|
+
W8A8 = dict(
|
111
|
+
weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=True)
|
112
|
+
)
|
113
|
+
|
114
|
+
W4A16 = dict(weights=QuantizationArgs(num_bits=4, group_size=128))
|
115
|
+
|
116
|
+
PRESET_SCHEMES = {
|
117
|
+
"W8A8": W8A8,
|
118
|
+
"W4A16": W4A16,
|
119
|
+
}
|
@@ -12,21 +12,43 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
15
|
+
import logging
|
16
|
+
from typing import Optional, Tuple
|
16
17
|
|
17
18
|
import torch
|
19
|
+
from compressed_tensors.quantization.observers.base import Observer
|
18
20
|
from torch.nn import Module
|
19
21
|
from tqdm import tqdm
|
20
22
|
|
21
23
|
|
22
24
|
__all__ = [
|
25
|
+
"infer_quantization_status",
|
23
26
|
"is_module_quantized",
|
24
27
|
"is_model_quantized",
|
25
28
|
"iter_named_leaf_modules",
|
26
29
|
"module_type",
|
27
30
|
"calculate_compression_ratio",
|
31
|
+
"get_torch_bit_depth",
|
32
|
+
"can_quantize",
|
28
33
|
]
|
29
34
|
|
35
|
+
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
36
|
+
|
37
|
+
|
38
|
+
def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa
|
39
|
+
"""
|
40
|
+
Checks the quantization status of a model. Assumes all modules in the model have
|
41
|
+
the same status, so only the first quantized model is checked.
|
42
|
+
|
43
|
+
:param model: model to check quantization status for
|
44
|
+
:return: quantization status if the model is quantized, otherwise None
|
45
|
+
"""
|
46
|
+
for module in model.modules():
|
47
|
+
status = getattr(module, "quantization_status", None)
|
48
|
+
if status is not None:
|
49
|
+
return status
|
50
|
+
return None
|
51
|
+
|
30
52
|
|
31
53
|
def is_module_quantized(module: Module) -> bool:
|
32
54
|
"""
|
@@ -78,11 +100,60 @@ def module_type(module: Module) -> str:
|
|
78
100
|
|
79
101
|
|
80
102
|
def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
|
81
|
-
|
82
|
-
|
103
|
+
"""
|
104
|
+
Yields modules that do not have any submodules except observers. The observers
|
105
|
+
themselves are not yielded
|
106
|
+
|
107
|
+
:param model: model to get leaf modules of
|
108
|
+
:returns: generator tuple of (name, leaf_submodule)
|
109
|
+
"""
|
83
110
|
for name, submodule in model.named_modules():
|
84
|
-
|
111
|
+
children = list(submodule.children())
|
112
|
+
if len(children) == 0 and not isinstance(submodule, Observer):
|
85
113
|
yield name, submodule
|
114
|
+
else:
|
115
|
+
has_non_observer_children = False
|
116
|
+
for child in children:
|
117
|
+
if not isinstance(child, Observer):
|
118
|
+
has_non_observer_children = True
|
119
|
+
|
120
|
+
if not has_non_observer_children:
|
121
|
+
yield name, submodule
|
122
|
+
|
123
|
+
|
124
|
+
def get_torch_bit_depth(value: torch.Tensor) -> int:
|
125
|
+
"""
|
126
|
+
Determine the number of bits used to represent the dtype of a tensor
|
127
|
+
|
128
|
+
:param value: tensor to check bit depth of
|
129
|
+
:return: bit depth of each element in the value tensor
|
130
|
+
"""
|
131
|
+
try:
|
132
|
+
bit_depth = torch.finfo(value.dtype).bits
|
133
|
+
except TypeError:
|
134
|
+
bit_depth = torch.iinfo(value.dtype).bits
|
135
|
+
|
136
|
+
return bit_depth
|
137
|
+
|
138
|
+
|
139
|
+
def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool: # noqa
|
140
|
+
"""
|
141
|
+
Checks if value can be quantized by quant_args.
|
142
|
+
|
143
|
+
:param value: tensor to check for quantization
|
144
|
+
:param quant_args: QuantizationArgs to use for quantization
|
145
|
+
:return: False if value is already quantized to quant_args or value is incompatible
|
146
|
+
with quant_args, True if value can be quantized with quant_args
|
147
|
+
"""
|
148
|
+
bit_depth = get_torch_bit_depth(value)
|
149
|
+
requested_depth = quant_args.num_bits
|
150
|
+
if bit_depth < quant_args.num_bits:
|
151
|
+
_LOGGER.warn(
|
152
|
+
f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}."
|
153
|
+
"The QuantizationArgs provided are not compatible with the input tensor."
|
154
|
+
)
|
155
|
+
|
156
|
+
return bit_depth > quant_args.num_bits
|
86
157
|
|
87
158
|
|
88
159
|
def calculate_compression_ratio(model: Module) -> float:
|
@@ -101,13 +172,11 @@ def calculate_compression_ratio(model: Module) -> float:
|
|
101
172
|
desc="Calculating quantization compression ratio",
|
102
173
|
):
|
103
174
|
for parameter in model.parameters():
|
104
|
-
|
105
|
-
uncompressed_bits = torch.finfo(parameter.dtype).bits
|
106
|
-
except TypeError:
|
107
|
-
uncompressed_bits = torch.iinfo(parameter.dtype).bits
|
175
|
+
uncompressed_bits = get_torch_bit_depth(parameter)
|
108
176
|
compressed_bits = uncompressed_bits
|
109
177
|
if is_module_quantized(submodule):
|
110
178
|
compressed_bits = submodule.quantization_scheme.weights.num_bits
|
179
|
+
|
111
180
|
num_weights = parameter.numel()
|
112
181
|
total_compressed += compressed_bits * num_weights
|
113
182
|
total_uncompressed += uncompressed_bits * num_weights
|
@@ -12,47 +12,20 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from pathlib import Path
|
16
|
-
from typing import Dict, Optional, Union
|
17
15
|
|
18
|
-
import
|
19
|
-
|
20
|
-
from compressed_tensors.compressors import ModelCompressor
|
21
|
-
from compressed_tensors.config import (
|
22
|
-
CompressionConfig,
|
23
|
-
CompressionFormat,
|
24
|
-
DenseSparsityConfig,
|
25
|
-
)
|
26
|
-
from safetensors.torch import save_file
|
27
|
-
from torch import Tensor
|
16
|
+
from typing import Optional
|
17
|
+
|
28
18
|
from transformers import AutoConfig
|
29
19
|
|
30
20
|
|
31
|
-
__all__ = [
|
32
|
-
"infer_compressor_from_model_config",
|
33
|
-
"infer_compression_config_from_model_config",
|
34
|
-
"load_compressed",
|
35
|
-
"save_compressed",
|
36
|
-
"save_compressed_model",
|
37
|
-
]
|
21
|
+
__all__ = ["infer_compressor_from_model_config", "fix_fsdp_module_name"]
|
38
22
|
|
39
|
-
|
40
|
-
pretrained_model_name_or_path: str,
|
41
|
-
) -> Optional[CompressionConfig]:
|
42
|
-
"""
|
43
|
-
Given a path to a model config, extract a sparsity config if it exists and return
|
44
|
-
the associated CompressionConfig
|
45
|
-
|
46
|
-
:param pretrained_model_name_or_path: path to model config on disk or HF hub
|
47
|
-
:return: matching compression config if config contains a sparsity config
|
48
|
-
"""
|
49
|
-
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
50
|
-
return getattr(config, SPARSITY_CONFIG_NAME, None)
|
23
|
+
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
|
51
24
|
|
52
25
|
|
53
26
|
def infer_compressor_from_model_config(
|
54
27
|
pretrained_model_name_or_path: str,
|
55
|
-
) -> Optional[ModelCompressor]:
|
28
|
+
) -> Optional["ModelCompressor"]: # noqa: F821
|
56
29
|
"""
|
57
30
|
Given a path to a model config, extract a sparsity config if it exists and return
|
58
31
|
the associated ModelCompressor
|
@@ -60,100 +33,31 @@ def infer_compressor_from_model_config(
|
|
60
33
|
:param pretrained_model_name_or_path: path to model config on disk or HF hub
|
61
34
|
:return: matching compressor if config contains a sparsity config
|
62
35
|
"""
|
63
|
-
|
64
|
-
|
65
|
-
return compressor
|
66
|
-
|
67
|
-
|
68
|
-
def save_compressed(
|
69
|
-
tensors: Dict[str, Tensor],
|
70
|
-
save_path: Union[str, Path],
|
71
|
-
compression_format: Optional[CompressionFormat] = None,
|
72
|
-
):
|
73
|
-
"""
|
74
|
-
Save compressed tensors to disk. If tensors are not compressed,
|
75
|
-
save them as is.
|
76
|
-
|
77
|
-
:param tensors: dictionary of tensors to compress
|
78
|
-
:param save_path: path to save compressed tensors
|
79
|
-
:param compression_format: compression format used for the tensors
|
80
|
-
:return: compression config, if tensors were compressed - None otherwise
|
81
|
-
"""
|
82
|
-
if tensors is None or len(tensors) == 0:
|
83
|
-
raise ValueError("No tensors or empty tensors provided to compress")
|
36
|
+
from compressed_tensors.compressors import ModelCompressor
|
37
|
+
from compressed_tensors.config import CompressionConfig
|
84
38
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
compression_format in ModelCompressor.registered_names()
|
90
|
-
or compression_format in ModelCompressor.registered_aliases()
|
91
|
-
):
|
92
|
-
raise ValueError(
|
93
|
-
f"Unknown compression format: {compression_format}. "
|
94
|
-
f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501
|
95
|
-
)
|
39
|
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
40
|
+
sparsity_config = ModelCompressor.parse_sparsity_config(config)
|
41
|
+
if sparsity_config is None:
|
42
|
+
return None
|
96
43
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
save_file(compressed_tensors, save_path)
|
44
|
+
format = sparsity_config.get("format")
|
45
|
+
sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
|
46
|
+
compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
|
47
|
+
return compressor
|
102
48
|
|
103
49
|
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
) -> Dict[str, Tensor]:
|
50
|
+
# TODO: There is already the same function in
|
51
|
+
# SparseML, should be moved to a shared location
|
52
|
+
# in the future
|
53
|
+
def fix_fsdp_module_name(name: str) -> str:
|
109
54
|
"""
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
:param
|
114
|
-
:
|
115
|
-
:param device: device to move tensors to. If None, tensors are loaded on CPU.
|
116
|
-
:return decompressed tensors
|
55
|
+
Remove FSDP wrapper prefixes from a module name
|
56
|
+
Accounts for scenario where FSDP_WRAPPER_NAME is
|
57
|
+
at the end of the name, as well as in the middle.
|
58
|
+
:param name: name to strip
|
59
|
+
:return: stripped name
|
117
60
|
"""
|
118
|
-
|
119
|
-
|
120
|
-
raise ValueError("No compressed tensors provided to load")
|
121
|
-
|
122
|
-
# if no compression_config specified, default to `dense_sparsity`
|
123
|
-
compression_config = compression_config or DenseSparsityConfig()
|
124
|
-
|
125
|
-
# decompress
|
126
|
-
compression_format = compression_config.format
|
127
|
-
compressor = ModelCompressor.load_from_registry(
|
128
|
-
compression_format, config=compression_config
|
61
|
+
return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
|
62
|
+
"." + FSDP_WRAPPER_NAME, ""
|
129
63
|
)
|
130
|
-
return dict(compressor.decompress(compressed_tensors, device=device))
|
131
|
-
|
132
|
-
|
133
|
-
def save_compressed_model(
|
134
|
-
model: torch.nn.Module,
|
135
|
-
filename: str,
|
136
|
-
compression_format: Optional[CompressionFormat] = None,
|
137
|
-
force_contiguous: bool = True,
|
138
|
-
):
|
139
|
-
"""
|
140
|
-
Wrapper around safetensors `save_model` helper function, which allows for
|
141
|
-
saving compressed model to disk.
|
142
|
-
|
143
|
-
Note: The model is assumed to have a
|
144
|
-
state_dict with unique entries
|
145
|
-
|
146
|
-
:param model: model to save on disk
|
147
|
-
:param filename: filename location to save the file
|
148
|
-
:param compression_format: compression format used for the model
|
149
|
-
:param force_contiguous: forcing the state_dict to be saved as contiguous tensors
|
150
|
-
"""
|
151
|
-
state_dict = model.state_dict()
|
152
|
-
if force_contiguous:
|
153
|
-
state_dict = {k: v.contiguous() for k, v in state_dict.items()}
|
154
|
-
try:
|
155
|
-
save_compressed(state_dict, filename, compression_format=compression_format)
|
156
|
-
except ValueError as e:
|
157
|
-
msg = str(e)
|
158
|
-
msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats." # noqa E501
|
159
|
-
raise ValueError(msg)
|
@@ -31,6 +31,7 @@ __all__ = [
|
|
31
31
|
"get_weight_mappings",
|
32
32
|
"get_nested_weight_mappings",
|
33
33
|
"get_quantization_state_dict",
|
34
|
+
"is_quantization_param",
|
34
35
|
]
|
35
36
|
|
36
37
|
|
@@ -214,7 +215,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
|
|
214
215
|
weight_mappings = get_weight_mappings(model_path)
|
215
216
|
state_dict = {}
|
216
217
|
for weight_name, safe_path in weight_mappings.items():
|
217
|
-
if not
|
218
|
+
if not is_quantization_param(weight_name):
|
218
219
|
continue
|
219
220
|
with safe_open(safe_path, framework="pt", device="cpu") as f:
|
220
221
|
state_dict[weight_name] = f.get_tensor(weight_name)
|
@@ -222,7 +223,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
|
|
222
223
|
return state_dict
|
223
224
|
|
224
225
|
|
225
|
-
def
|
226
|
+
def is_quantization_param(name: str) -> bool:
|
226
227
|
"""
|
227
228
|
Checks is a parameter name is associated with a quantization parameter
|
228
229
|
|