compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff shows the changes between two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +137 -9
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/config/base.py +6 -1
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +63 -9
- compressed_tensors/quantization/lifecycle/calibration.py +7 -7
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +126 -44
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -20
- compressed_tensors/quantization/lifecycle/initialize.py +138 -55
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +102 -24
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +12 -13
- compressed_tensors/quantization/utils/helpers.py +44 -19
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +30 -1
- compressed_tensors/utils/offload.py +14 -2
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
- compressed_tensors-0.7.0.dist-info/RECORD +59 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/pack_quantized.py +0 -219
- compressed_tensors-0.5.0.dist-info/RECORD +0 -48
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py}
@@ -16,36 +16,51 @@ import logging
 from typing import Dict, Generator, Tuple
 
 import torch
-from compressed_tensors.compressors import
-from compressed_tensors.config import CompressionFormat
+from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
-from compressed_tensors.quantization.utils import can_quantize
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
 
 
-__all__ = [
-    "QuantizationCompressor",
-    "IntQuantizationCompressor",
-    "FloatQuantizationCompressor",
-]
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
+__all__ = ["BaseQuantizationCompressor"]
+
 
-
-class QuantizationCompressor(Compressor):
+class BaseQuantizationCompressor(BaseCompressor):
     """
-
-
-
+    Base class representing a quant compression algorithm. Each child class should
+    implement compression_param_info, compress_weight and decompress_weight.
+
+    Compressors support compressing/decompressing a full module state dict or a single
+    quantized PyTorch leaf module.
+
+    Model Load Lifecycle (run_compressed=False):
+        - ModelCompressor.decompress()
+        - apply_quantization_config()
+        - BaseQuantizationCompressor.decompress()
+        - BaseQuantizationCompressor.decompress_weight()
+
+    Model Save Lifecycle:
+        - ModelCompressor.compress()
+        - BaseQuantizationCompressor.compress()
+        - BaseQuantizationCompressor.compress_weight()
+
+    Module Lifecycle (run_compressed=True):
+        - apply_quantization_config()
+        - compressed_module = CompressedLinear(module)
+        - initialize_module_for_quantization()
+        - BaseQuantizationCompressor.compression_param_info()
+        - register_parameters()
+        - compressed_module.forward()
+        - compressed_module.decompress()
+
+
+    :param config: config specifying compression parameters
     """
 
-    COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
-
     def compress(
         self,
         model_state: Dict[str, Tensor],
@@ -57,7 +72,7 @@ class QuantizationCompressor(Compressor):
 
         :param model_state: state dict of uncompressed model
         :param names_to_scheme: quantization args for each quantized weight, needed for
-
+            quantize function to calculate bit depth
        :return: compressed state dict
        """
        compressed_dict = {}
@@ -66,42 +81,50 @@ class QuantizationCompressor(Compressor):
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
 
-        for name, value in tqdm(model_state.items(), desc="
+        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
             if name.endswith(weight_suffix):
                 prefix = name[: -(len(weight_suffix))]
                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-
+                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                if scale is not None:
                     # weight is quantized, compress it
                     quant_args = names_to_scheme[prefix]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    compressed_data = self.compress_weight(
+                        weight=value,
+                        scale=scale,
+                        zero_point=zp,
+                        g_idx=g_idx,
+                        quantization_args=quant_args,
+                        device="cpu",
+                    )
+                    for key, value in compressed_data.items():
+                        compressed_dict[merge_names(prefix, key)] = value
+                else:
+                    compressed_dict[name] = value.to("cpu")
+            elif name.endswith("zero_point") and torch.all(value == 0):
+                continue
+            elif name.endswith("g_idx") and torch.any(value <= -1):
+                continue
+            else:
+                compressed_dict[name] = value.to("cpu")
 
         return compressed_dict
 
     def decompress(
-        self,
+        self,
+        path_to_model_or_tensors: str,
+        names_to_scheme: Dict[str, QuantizationArgs],
+        device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
 
-        :param
-        one or more safetensors files) or compressed tensors file
+        :param path_to_model_or_tensors: path to compressed safetensors model (directory
+            with one or more safetensors files) or compressed tensors file
+        :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
@@ -116,29 +139,8 @@ class QuantizationCompressor(Compressor):
                     weight_data[param_name] = f.get_tensor(full_name)
 
             if "weight_scale" in weight_data:
-
-
-
-                    x_q=weight_data["weight"],
-                    scale=scale,
-                    zero_point=zero_point,
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
                 )
                 yield merge_names(weight_name, "weight"), decompressed
-
-
-@Compressor.register(name=CompressionFormat.int_quantized.value)
-class IntQuantizationCompressor(QuantizationCompressor):
-    """
-    Alias for integer quantized models
-    """
-
-    pass
-
-
-@Compressor.register(name=CompressionFormat.float_quantized.value)
-class FloatQuantizationCompressor(QuantizationCompressor):
-    """
-    Alias for fp quantized models
-    """
-
-    pass
compressed_tensors/compressors/quantized_compressors/naive_quantized.py (new file)
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional, Tuple
+
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.quantized_compressors.base import (
+    BaseQuantizationCompressor,
+)
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+from compressed_tensors.quantization.utils import can_quantize
+from torch import Tensor
+
+
+__all__ = [
+    "NaiveQuantizationCompressor",
+    "IntQuantizationCompressor",
+    "FloatQuantizationCompressor",
+]
+
+
+@BaseCompressor.register(name=CompressionFormat.naive_quantized.value)
+class NaiveQuantizationCompressor(BaseQuantizationCompressor):
+    """
+    Implements naive compression for quantized models. Weight of each
+    quantized layer is converted from its original float type to the closest Pytorch
+    type to the type specified by the layer's QuantizationArgs.
+    """
+
+    COMPRESSION_PARAM_NAMES = [
+        "weight",
+        "weight_scale",
+        "weight_zero_point",
+        "weight_g_idx",
+    ]
+
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        dtype = quantization_args.pytorch_dtype()
+        return {"weight": (weight_shape, dtype)}
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
+        """
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=quantization_args.pytorch_dtype(),
+            )
+
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)
+
+        return {"weight": quantized_weight}
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        """
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
+        """
+        weight = compressed_data["weight"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        decompressed_weight = dequantize(
+            x_q=weight, scale=scale, zero_point=zero_point, g_idx=g_idx
+        )
+
+        return decompressed_weight
+
+
+@BaseCompressor.register(name=CompressionFormat.int_quantized.value)
+class IntQuantizationCompressor(NaiveQuantizationCompressor):
+    """
+    Alias for integer quantized models
+    """
+
+    pass
+
+
+@BaseCompressor.register(name=CompressionFormat.float_quantized.value)
+class FloatQuantizationCompressor(NaiveQuantizationCompressor):
+    """
+    Alias for fp quantized models
+    """
+
+    pass
compressed_tensors/compressors/quantized_compressors/pack_quantized.py (new file)
@@ -0,0 +1,211 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.quantized_compressors.base import (
+    BaseQuantizationCompressor,
+)
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+from compressed_tensors.quantization.utils import can_quantize
+from torch import Tensor
+
+
+__all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
+
+
+@BaseCompressor.register(name=CompressionFormat.pack_quantized.value)
+class PackedQuantizationCompressor(BaseQuantizationCompressor):
+    """
+    Compresses a quantized model by packing every eight 4-bit weights into an int32
+    """
+
+    COMPRESSION_PARAM_NAMES = [
+        "weight_packed",
+        "weight_scale",
+        "weight_zero_point",
+        "weight_g_idx",
+        "weight_shape",
+    ]
+
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        pack_factor = 32 // quantization_args.num_bits
+        packed_size = math.ceil(weight_shape[1] / pack_factor)
+        return {
+            "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
+            "weight_shape": (torch.Size((2,)), torch.int32),
+        }
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
+        """
+        compressed_dict = {}
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=torch.int8,
+            )
+
+        packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+        weight_shape = torch.tensor(weight.shape)
+        if device is not None:
+            packed_weight = packed_weight.to(device)
+            weight_shape = weight_shape.to(device)
+
+        compressed_dict["weight_shape"] = weight_shape
+        compressed_dict["weight_packed"] = packed_weight
+
+        return compressed_dict
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        """
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
+        """
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        original_shape = torch.Size(compressed_data["weight_shape"])
+        num_bits = quantization_args.num_bits
+        unpacked = unpack_from_int32(weight, num_bits, original_shape)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
+        )
+
+        return decompressed_weight
+
+
+def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
+    """
+    Packs a tensor of quantized weights stored in int8 into int32s with padding
+
+    :param value: tensor to pack
+    :param num_bits: number of bits used to store underlying data
+    :returns: packed int32 tensor
+    """
+    if value.dtype is not torch.int8:
+        raise ValueError("Tensor must be quantized to torch.int8 before packing")
+
+    if num_bits > 8:
+        raise ValueError("Packing is only supported for less than 8 bits")
+
+    # convert to unsigned for packing
+    offset = pow(2, num_bits) // 2
+    value = (value + offset).to(torch.uint8)
+    value = value.cpu().numpy().astype(np.uint32)
+    pack_factor = 32 // num_bits
+
+    # pad input tensor and initialize packed output
+    packed_size = math.ceil(value.shape[1] / pack_factor)
+    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
+    padding = packed.shape[1] * pack_factor - value.shape[1]
+    value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
+
+    # pack values
+    for i in range(pack_factor):
+        packed |= value[:, i::pack_factor] << num_bits * i
+
+    # convert back to signed and torch
+    packed = np.ascontiguousarray(packed).view(np.int32)
+    return torch.from_numpy(packed)
+
+
+def unpack_from_int32(
+    value: torch.Tensor, num_bits: int, shape: torch.Size
+) -> torch.Tensor:
+    """
+    Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
+    original their bit range
+
+    :param value: tensor to upack
+    :param num_bits: number of bits to unpack each data point into
+    :param shape: shape to unpack into, used to remove padding
+    :returns: unpacked int8 tensor
+    """
+    if value.dtype is not torch.int32:
+        raise ValueError(
+            f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+        )
+
+    if num_bits > 8:
+        raise ValueError("Unpacking is only supported for less than 8 bits")
+
+    pack_factor = 32 // num_bits
+
+    # unpack
+    mask = pow(2, num_bits) - 1
+    unpacked = torch.zeros(
+        (value.shape[0], value.shape[1] * pack_factor),
+        device=value.device,
+        dtype=torch.int32,
+    )
+    for i in range(pack_factor):
+        unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
+
+    # remove padding
+    original_row_size = int(shape[1])
+    unpacked = unpacked[:, :original_row_size]
+
+    # bits are packed in unsigned format, reformat to signed
+    # update the value range from unsigned to signed
+    offset = pow(2, num_bits) // 2
+    unpacked = (unpacked - offset).to(torch.int8)
+
+    return unpacked
compressed_tensors/compressors/sparse_compressors/__init__.py (new file)
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+from .base import *
+from .dense import *
+from .sparse_bitmask import *
compressed_tensors/compressors/sparse_compressors/base.py (new file)
@@ -0,0 +1,110 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, Generator, Tuple
+
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
+from torch import Tensor
+from tqdm import tqdm
+
+
+__all__ = ["BaseSparseCompressor"]
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+class BaseSparseCompressor(BaseCompressor):
+    """
+    Base class representing a sparse compression algorithm. Each child class should
+    implement compression_param_info, compress_weight and decompress_weight.
+
+    Compressors support compressing/decompressing a full module state dict or a single
+    quantized PyTorch leaf module.
+
+    Model Load Lifecycle (run_compressed=False):
+        - ModelCompressor.decompress()
+        - apply_quantization_config()
+        - BaseSparseCompressor.decompress()
+        - BaseSparseCompressor.decompress_weight()
+
+    Model Save Lifecycle:
+        - ModelCompressor.compress()
+        - BaseSparseCompressor.compress()
+        - BaseSparseCompressor.compress_weight()
+
+    Module Lifecycle (run_compressed=True):
+        - apply_quantization_config()
+        - compressed_module = CompressedLinear(module)
+        - initialize_module_for_quantization()
+        - BaseSparseCompressor.compression_param_info()
+        - register_parameters()
+        - compressed_module.forward()
+        - compressed_module.decompress()
+
+
+    :param config: config specifying compression parameters
+    """
+
+    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Compresses a dense state dict using bitmask compression
+
+        :param model_state: state dict of uncompressed model
+        :return: compressed state dict
+        """
+        compressed_dict = {}
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+        )
+        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            compression_data = self.compress_weight(name, value)
+            for key in compression_data.keys():
+                if key in compressed_dict:
+                    _LOGGER.warn(
+                        f"Expected all compressed state_dict keys to be unique, but "
+                        f"found an existing entry for {key}. The existing entry will "
+                        "be replaced."
+                    )
+
+            compressed_dict.update(compression_data)
+
+        return compressed_dict
+
+    def decompress(
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+    ) -> Generator[Tuple[str, Tensor], None, None]:
+        """
+        Reads a bitmask compressed state dict located
+        at path_to_model_or_tensors and returns a generator
+        for sequentially decompressing back to a dense state dict
+
+        :param model_path: path to compressed safetensors model (directory with
+            one or more safetensors files) or compressed tensors file
+        :param device: device to load decompressed weights onto
+        :return: iterator for generating decompressed weights
+        """
+        weight_mappings = get_nested_weight_mappings(
+            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, safe_path in weight_mappings[weight_name].items():
+                full_name = merge_names(weight_name, param_name)
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    weight_data[param_name] = f.get_tensor(full_name)
+            decompressed = self.decompress_weight(weight_data)
+            yield weight_name, decompressed
compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py}
@@ -14,13 +14,13 @@
 
 from typing import Dict, Generator, Tuple
 
-from compressed_tensors.compressors import
+from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from torch import Tensor
 
 
-@
-class DenseCompressor(
+@BaseCompressor.register(name=CompressionFormat.dense.value)
+class DenseCompressor(BaseCompressor):
     """
     Identity compressor for dense models, returns the original state_dict
     """