compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +137 -9
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/config/base.py +6 -1
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +63 -9
- compressed_tensors/quantization/lifecycle/calibration.py +7 -7
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +126 -44
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -20
- compressed_tensors/quantization/lifecycle/initialize.py +138 -55
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +102 -24
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +12 -13
- compressed_tensors/quantization/utils/helpers.py +44 -19
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +30 -1
- compressed_tensors/utils/offload.py +14 -2
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
- compressed_tensors-0.7.0.dist-info/RECORD +59 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/pack_quantized.py +0 -219
- compressed_tensors-0.5.0.dist-info/RECORD +0 -48
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
compressed_tensors/__init__.py
CHANGED
compressed_tensors/base.py
CHANGED
@@ -16,3 +16,5 @@ SPARSITY_CONFIG_NAME = "sparsity_config"
 QUANTIZATION_CONFIG_NAME = "quantization_config"
 COMPRESSION_CONFIG_NAME = "compression_config"
 KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+COMPRESSION_VERSION_NAME = "version"
+QUANTIZATION_METHOD_NAME = "quant_method"
compressed_tensors/compressors/__init__.py
CHANGED
@@ -14,15 +14,9 @@
 
 # flake8: noqa
 
-from .base import
-from .
-from .
-from .
-from .
-from .
-    FloatQuantizationCompressor,
-    IntQuantizationCompressor,
-    QuantizationCompressor,
-)
-from .pack_quantized import PackedQuantizationCompressor
-from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
+from .base import *
+from .helpers import *
+from .model_compressors import *
+from .quantized_compressors import *
+from .sparse_compressors import *
+from .sparse_quantized_compressors import *
compressed_tensors/compressors/base.py
CHANGED
@@ -12,20 +12,47 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from
+from abc import ABC, abstractmethod
+from typing import Dict, Generator, Optional, Tuple, Union
 
+import torch
 from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.quantization import QuantizationConfig
+from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
 from torch import Tensor
+from torch.nn import Module
 
 
-__all__ = ["
+__all__ = ["BaseCompressor"]
 
 
-class
+class BaseCompressor(RegistryMixin, ABC):
     """
-    Base class representing a model compression algorithm
+    Base class representing a model compression algorithm. Each child class should
+    implement compression_param_info, compress_weight and decompress_weight.
+
+    Compressors support compressing/decompressing a full module state dict or a single
+    quantized PyTorch leaf module.
+
+    Model Load Lifecycle (run_compressed=False):
+        - ModelCompressor.decompress()
+            - apply_quantization_config()
+            - BaseCompressor.decompress()
+
+    Model Save Lifecycle:
+        - ModelCompressor.compress()
+            - BaseCompressor.compress()
+
+
+    Module Lifecycle (run_compressed=True):
+        - apply_quantization_config()
+        - compressed_module = CompressedLinear(module)
+            - initialize_module_for_quantization()
+            - BaseCompressor.compression_param_info()
+            - register_parameters()
+        - compressed_module.forward()
+        - compressed_module.decompress()
 
     :param config: config specifying compression parameters
     """
@@ -35,26 +62,127 @@ class Compressor(RegistryMixin):
     ):
         self.config = config
 
-    def
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        **kwargs,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
         :param model_state: state dict of uncompressed model
+        :param kwargs: additional arguments for compression
         :return: compressed state dict
         """
         raise NotImplementedError()
 
+    @abstractmethod
     def decompress(
-        self,
+        self,
+        path_to_model_or_tensors: str,
+        device: str = "cpu",
+        **kwargs,
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
        and returns a generator for sequentially decompressing back to a
         dense state dict
 
-        :param
-            one or more safetensors files) or compressed tensors file
+        :param path_to_model_or_tensors: path to compressed safetensors model (directory
+            with one or more safetensors files) or compressed tensors file
+        :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
         raise NotImplementedError()
+
+    def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
+        """
+        Compresses a single quantized leaf PyTorch module. If the module is not
+        quantized, this function has no effect.
+
+        :param module: PyTorch module to compress
+        :return: dictionary of compressed weight data, or None if module is not
+            quantized
+        """
+        if not hasattr(module, "quantization_scheme"):
+            return None  # module is not quantized
+        quantization_scheme = module.quantization_scheme
+        if not hasattr(quantization_scheme, "weights"):
+            return None  # weights are not quantized
+
+        quantization_args = quantization_scheme.weights
+        weight = getattr(module, "weight", None)
+        weight_scale = getattr(module, "weight_scale", None)
+        weight_zero_point = getattr(module, "weight_zero_point", None)
+
+        return self.compress_weight(
+            weight=weight,
+            scale=weight_scale,
+            zero_point=weight_zero_point,
+            quantization_args=quantization_args,
+        )
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param kwargs: additional arguments for compression
+        """
+        raise NotImplementedError()
+
+    def decompress_module(self, module: Module):
+        """
+        Decompresses a single compressed leaf PyTorch module. If the module is not
+        quantized, this function has no effect.
+
+        :param module: PyTorch module to decompress
+        :return: tensor of the decompressed weight, or None if module is not quantized
+        """
+        if not hasattr(module, "quantization_scheme"):
+            return None  # module is not quantized
+        quantization_scheme = module.quantization_scheme
+        if not hasattr(quantization_scheme, "weights"):
+            return None  # weights are not quantized
+
+        quantization_args = quantization_scheme.weights
+        compressed_data = {}
+        for name, parameter in module.named_parameters():
+            compressed_data[name] = parameter
+
+        return self.decompress_weight(
+            compressed_data=compressed_data, quantization_args=quantization_args
+        )
+
+    def decompress_weight(
+        self, compressed_data: Dict[str, Tensor], **kwargs
+    ) -> torch.Tensor:
+        """
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param kwargs: additional arguments for decompression
+        :return: tensor of the decompressed weight
+        """
+        raise NotImplementedError()
compressed_tensors/compressors/helpers.py
CHANGED
@@ -16,7 +16,7 @@ from pathlib import Path
 from typing import Dict, Generator, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.compressors import
+from compressed_tensors.compressors import BaseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.utils.safetensors_load import get_weight_mappings
 from safetensors import safe_open
@@ -52,16 +52,16 @@ def save_compressed(
     compression_format = compression_format or CompressionFormat.dense.value
 
     if not (
-        compression_format in
-        or compression_format in
+        compression_format in BaseCompressor.registered_names()
+        or compression_format in BaseCompressor.registered_aliases()
     ):
         raise ValueError(
             f"Unknown compression format: {compression_format}. "
-            f"Must be one of {set(
+            f"Must be one of {set(BaseCompressor.registered_names() + BaseCompressor.registered_aliases())}"  # noqa E501
         )
 
     # compress
-    compressor =
+    compressor = BaseCompressor.load_from_registry(compression_format)
     # save compressed tensors
     compressed_tensors = compressor.compress(tensors)
     save_file(compressed_tensors, save_path)
@@ -102,7 +102,7 @@ def load_compressed(
     else:
         # decompress tensors
         compression_format = compression_config.format
-        compressor =
+        compressor = BaseCompressor.load_from_registry(
             compression_format, config=compression_config
         )
         yield from compressor.decompress(compressed_tensors, device=device)
compressed_tensors/compressors/model_compressors/__init__.py
ADDED
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+
+from .model_compressor import *
compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py}
RENAMED
@@ -18,18 +18,22 @@ import operator
 import os
 import re
 from copy import deepcopy
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
 
+import compressed_tensors
 import torch
 import transformers
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
+    COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
+    QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
 )
-from compressed_tensors.compressors import
-from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
+    DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
     QuantizationStatus,
     apply_quantization_config,
@@ -40,7 +44,10 @@ from compressed_tensors.quantization.utils import (
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
-from compressed_tensors.utils.helpers import
+from compressed_tensors.utils.helpers import (
+    fix_fsdp_module_name,
+    is_compressed_tensors_config,
+)
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -53,6 +60,11 @@ __all__ = ["ModelCompressor", "map_modules_to_quant_args"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+if TYPE_CHECKING:
+    # dummy type if not available from transformers
+    CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
+
+
 class ModelCompressor:
     """
     Handles compression and decompression of a model with a sparsity config and/or
@@ -88,45 +100,41 @@ class ModelCompressor:
         configs and load a ModelCompressor
 
         :param pretrained_model_name_or_path: path to model config on disk or HF hub
-        :return: compressor for the
+        :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
         return cls.from_compression_config(compression_config)
 
     @classmethod
-    def from_compression_config(
+    def from_compression_config(
+        cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ):
         """
-        :param compression_config:
-
-
+        :param compression_config:
+            A compression or quantization config
+
+            The type is one of the following:
+            1. A Dict found under either "quantization_config" or "compression_config"
+                keys in the config.json
+            2. A CompressedTensorsConfig found under key "quantization_config" in HF
+                model config
+        :return: compressor for the configs, or None if model is not compressed
         """
         if compression_config is None:
             return None
 
-        try:
-            from transformers.utils.quantization_config import CompressedTensorsConfig
-
-            if isinstance(compression_config, CompressedTensorsConfig):
-                compression_config = compression_config.to_dict()
-        except ImportError:
-            pass
-
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
         if sparsity_config is None and quantization_config is None:
             return None
 
-        if sparsity_config is not None
-            sparsity_config, SparsityCompressionConfig
-        ):
+        if sparsity_config is not None:
             format = sparsity_config.get("format")
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 format, **sparsity_config
             )
-        if quantization_config is not None
-            quantization_config, QuantizationConfig
-        ):
+        if quantization_config is not None:
             quantization_config = QuantizationConfig.parse_obj(quantization_config)
 
         return cls(
@@ -149,7 +157,7 @@ class ModelCompressor:
             to a sparsity compression algorithm
         :param quantization_format: string corresponding to a quantization compression
             algorithm
-        :return: compressor for the
+        :return: compressor for the configs, or None if model is not compressed
         """
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
@@ -168,32 +176,60 @@ class ModelCompressor:
         )
 
     @staticmethod
-    def parse_sparsity_config(
+    def parse_sparsity_config(
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ) -> Union[Dict[str, Any], None]:
+        """
+        Parse sparsity config from quantization/compression config. Sparsity
+        config is nested inside q/c config
+
+        :param compression_config: quantization/compression config
+        :return: sparsity config
+        """
         if compression_config is None:
             return None
-        if SPARSITY_CONFIG_NAME not in compression_config:
-            return None
-        if hasattr(compression_config, SPARSITY_CONFIG_NAME):
-            # for loaded HFQuantizer config
-            return getattr(compression_config, SPARSITY_CONFIG_NAME)
 
-
+        if is_compressed_tensors_config(compression_config):
+            s_config = compression_config.sparsity_config
+            return s_config.dict() if s_config is not None else None
+
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
     @staticmethod
-    def parse_quantization_config(
+    def parse_quantization_config(
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ) -> Union[Dict[str, Any], None]:
+        """
+        Parse quantization config from quantization/compression config. The
+        quantization are all the fields that are not the sparsity config or
+        metadata fields
+
+        :param compression_config: quantization/compression config
+        :return: quantization config without sparsity config or metadata fields
+        """
         if compression_config is None:
             return None
 
-        if
-
-            return
+        if is_compressed_tensors_config(compression_config):
+            q_config = compression_config.quantization_config
+            return q_config.dict() if q_config is not None else None
 
-        # SparseAutoModel format
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
+
+        # some fields are required, even if a qconfig is not present
+        # pop them off and if nothing remains, then there is no qconfig
+        quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
+        _ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)
+
         if len(quantization_config) == 0:
-
+            return None
+
+        # replace popped off values
+        # note that version is discarded for now
+        if quant_method is not None:
+            quantization_config[QUANTIZATION_METHOD_NAME] = quant_method
+
         return quantization_config
 
     def __init__(
@@ -207,11 +243,11 @@ class ModelCompressor:
         self.quantization_compressor = None
 
         if sparsity_config is not None:
-            self.sparsity_compressor =
+            self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )
         if quantization_config is not None:
-            self.quantization_compressor =
+            self.quantization_compressor = BaseCompressor.load_from_registry(
                 quantization_config.format, config=quantization_config
             )
 
@@ -222,7 +258,7 @@ class ModelCompressor:
         Compresses a dense state dict or model with sparsity and/or quantization
 
         :param model: uncompressed model to compress
-        :param
+        :param state_dict: optional uncompressed state_dict to insert into model
         :return: compressed state dict
         """
         if state_dict is None:
@@ -234,6 +270,10 @@ class ModelCompressor:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
            )
+            if self.quantization_config.format != CompressionFormat.dense.value:
+                self.quantization_config.quantization_status = (
+                    QuantizationStatus.COMPRESSED
+                )
 
         if self.sparsity_compressor is not None:
             compressed_state_dict = self.sparsity_compressor.compress(
@@ -281,6 +321,9 @@ class ModelCompressor:
 
         :param save_directory: path to a folder containing a HF model config
         """
+        if self.quantization_config is None and self.sparsity_config is None:
+            return
+
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
         if not os.path.exists(config_file_path):
             _LOGGER.warning(
@@ -292,13 +335,26 @@ class ModelCompressor:
         with open(config_file_path, "r") as config_file:
             config_data = json.load(config_file)
 
-
+        # required metadata whenever a quantization or sparsity config is present
+        # overwrite previous config and version if already existing
+        config_data[QUANTIZATION_CONFIG_NAME] = {}
+        config_data[QUANTIZATION_CONFIG_NAME][
+            COMPRESSION_VERSION_NAME
+        ] = compressed_tensors.__version__
+        if self.quantization_config is not None:
+            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
+        else:
+            config_data[QUANTIZATION_CONFIG_NAME][
+                QUANTIZATION_METHOD_NAME
+            ] = DEFAULT_QUANTIZATION_METHOD
+
+        # quantization and sparsity configs
         if self.quantization_config is not None:
             quant_config_data = self.quantization_config.model_dump()
-            config_data[
+            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
         if self.sparsity_config is not None:
             sparsity_config_data = self.sparsity_config.model_dump()
-            config_data[
+            config_data[QUANTIZATION_CONFIG_NAME][
                 SPARSITY_CONFIG_NAME
             ] = sparsity_config_data
 
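Taken together, the `update_config` changes mean a saved `config.json` always carries version and method metadata under `quantization_config`, with any sparsity config nested inside it. A sketch of the resulting layout (values are illustrative; the `DEFAULT_QUANTIZATION_METHOD` string is assumed to be "compressed-tensors" and is not shown in this diff):

# Illustrative shape of the config.json entry written by
# ModelCompressor.update_config; structure follows the hunks above,
# values are placeholders.
config_json_excerpt = {
    "quantization_config": {                    # QUANTIZATION_CONFIG_NAME
        "version": "0.7.0",                     # COMPRESSION_VERSION_NAME
        "quant_method": "compressed-tensors",   # QUANTIZATION_METHOD_NAME (assumed value)
        # ...fields from self.quantization_config.model_dump()...
        "sparsity_config": {                    # SPARSITY_CONFIG_NAME, nested
            # ...fields from self.sparsity_config.model_dump()...
        },
    },
}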
compressed_tensors/compressors/quantized_compressors/__init__.py
ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+from .base import *
+from .naive_quantized import *
+from .pack_quantized import *