compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +21 -0
- compressed_tensors/base.py +17 -0
- compressed_tensors/compressors/__init__.py +22 -0
- compressed_tensors/compressors/base.py +59 -0
- compressed_tensors/compressors/dense.py +34 -0
- compressed_tensors/compressors/helpers.py +137 -0
- compressed_tensors/compressors/int_quantized.py +95 -0
- compressed_tensors/compressors/model_compressor.py +264 -0
- compressed_tensors/compressors/sparse_bitmask.py +239 -0
- compressed_tensors/config/__init__.py +18 -0
- compressed_tensors/config/base.py +43 -0
- compressed_tensors/config/dense.py +36 -0
- compressed_tensors/config/sparse_bitmask.py +36 -0
- compressed_tensors/quantization/__init__.py +21 -0
- compressed_tensors/quantization/lifecycle/__init__.py +23 -0
- compressed_tensors/quantization/lifecycle/apply.py +196 -0
- compressed_tensors/quantization/lifecycle/calibration.py +51 -0
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +333 -0
- compressed_tensors/quantization/lifecycle/frozen.py +50 -0
- compressed_tensors/quantization/lifecycle/initialize.py +99 -0
- compressed_tensors/quantization/observers/__init__.py +21 -0
- compressed_tensors/quantization/observers/base.py +130 -0
- compressed_tensors/quantization/observers/helpers.py +54 -0
- compressed_tensors/quantization/observers/memoryless.py +48 -0
- compressed_tensors/quantization/observers/min_max.py +80 -0
- compressed_tensors/quantization/quant_args.py +125 -0
- compressed_tensors/quantization/quant_config.py +210 -0
- compressed_tensors/quantization/quant_scheme.py +39 -0
- compressed_tensors/quantization/utils/__init__.py +16 -0
- compressed_tensors/quantization/utils/helpers.py +131 -0
- compressed_tensors/registry/__init__.py +17 -0
- compressed_tensors/registry/registry.py +360 -0
- compressed_tensors/utils/__init__.py +16 -0
- compressed_tensors/utils/helpers.py +45 -0
- compressed_tensors/utils/safetensors_load.py +237 -0
- compressed_tensors/version.py +50 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
compressed_tensors/compressors/sparse_bitmask.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, Generator, List, Tuple, Union
+
+import numpy
+import torch
+from compressed_tensors.compressors import Compressor
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
+from torch import Tensor
+from tqdm import tqdm
+
+
+__all__ = [
+    "BitmaskCompressor",
+    "BitmaskTensor",
+    "bitmask_compress",
+    "bitmask_decompress",
+    "pack_bitmasks",
+    "unpack_bitmasks",
+]
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+@Compressor.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskCompressor(Compressor):
+    """
+    Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
+    values tensor, with their locations stored in a 2d bitmask
+    """
+
+    COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"]
+
+    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Compresses a dense state dict using bitmask compression
+
+        :param model_state: state dict of uncompressed model
+        :return: compressed state dict
+        """
+        compressed_dict = {}
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+        )
+        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            bitmask_tensor = BitmaskTensor.from_dense(value)
+            bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+            for key in bitmask_dict.keys():
+                if key in compressed_dict:
+                    _LOGGER.warning(
+                        f"Expected all compressed state_dict keys to be unique, but "
+                        f"found an existing entry for {key}. The existing entry will "
+                        "be replaced."
+                    )
+            compressed_dict |= bitmask_dict
+
+        return compressed_dict
+
+    def decompress(
+        self, path_to_model_or_tensors: str, device: str = "cpu"
+    ) -> Generator[Tuple[str, Tensor], None, None]:
+        """
+        Reads a bitmask compressed state dict located
+        at path_to_model_or_tensors and returns a generator
+        for sequentially decompressing back to a dense state dict
+
+        :param path_to_model_or_tensors: path to compressed safetensors model
+            (directory with one or more safetensors files) or compressed tensors file
+        :param device: device to load decompressed weights onto
+        :return: iterator for generating decompressed weights
+        """
+        weight_mappings = get_nested_weight_mappings(
+            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, safe_path in weight_mappings[weight_name].items():
+                full_name = merge_names(weight_name, param_name)
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    weight_data[param_name] = f.get_tensor(full_name)
+            data = BitmaskTensor(**weight_data)
+            decompressed = data.decompress()
+            yield weight_name, decompressed
+
+
+class BitmaskTensor:
+    """
+    Owns compression and decompression for a single bitmask compressed tensor.
+    Adapted from: https://github.com/mgoin/torch_bitmask/tree/main
+
+    :param shape: shape of dense tensor
+    :param compressed: flat tensor of non-zero values
+    :param bitmask: 2d bitmask of non-zero values
+    :param row_offsets: flat tensor indicating what index in values each dense row starts at
+    """
+
+    def __init__(
+        self,
+        shape: Union[torch.Size, List],
+        compressed: Tensor,
+        bitmask: Tensor,
+        row_offsets: Tensor,
+    ):
+        self.shape = list(shape)
+        self.compressed = compressed
+        self.bitmask = bitmask
+        self.row_offsets = row_offsets
+
+    @staticmethod
+    def from_dense(tensor: Tensor) -> "BitmaskTensor":
+        """
+        :param tensor: dense tensor to compress
+        :return: instantiated compressed tensor
+        """
+        shape = tensor.shape
+        compressed, bitmask, row_offsets = bitmask_compress(tensor.cpu())
+        return BitmaskTensor(
+            shape=shape, compressed=compressed, bitmask=bitmask, row_offsets=row_offsets
+        )
+
+    def decompress(self) -> Tensor:
+        """
+        :return: reconstructed dense tensor
+        """
+        return bitmask_decompress(self.compressed, self.bitmask, self.shape)
+
+    def curr_memory_size_bytes(self):
+        """
+        :return: size in bytes required to store compressed tensor on disk
+        """
+
+        def sizeof_tensor(a):
+            return a.element_size() * a.nelement()
+
+        return (
+            sizeof_tensor(self.compressed)
+            + sizeof_tensor(self.bitmask)
+            + sizeof_tensor(self.row_offsets)
+        )
+
+    def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
+        """
+        :param name_prefix: name of original tensor to store compressed weight as
+        :return: dict of compressed data for the stored weight
+        """
+        return {
+            merge_names(name_prefix, "shape"): torch.tensor(self.shape, device=device),
+            merge_names(name_prefix, "compressed"): self.compressed.to(device),
+            merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
+            merge_names(name_prefix, "row_offsets"): self.row_offsets.to(device),
+        }
+
+    def __repr__(self):
+        return f"BitmaskTensor(shape={self.shape}, compressed=True)"
+
+
+def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Compresses a dense tensor using bitmask compression
+
+    :param tensor: dense tensor to compress
+    :return: tuple of compressed data representing tensor
+    """
+    bytemasks = tensor != 0
+    row_counts = bytemasks.sum(dim=-1)
+    row_offsets = torch.cumsum(row_counts, 0) - row_counts
+    values = tensor[bytemasks]
+    bitmasks_packed = pack_bitmasks(bytemasks)
+
+    return values, bitmasks_packed, row_offsets
+
+
+def bitmask_decompress(
+    values: Tensor, bitmasks: Tensor, original_shape: torch.Size
+) -> Tensor:
+    """
+    Reconstructs a dense tensor from a compressed one
+
+    :param values: 1d tensor of non-zero values
+    :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
+        tensor's original shape
+    :param original_shape: shape of the dense tensor
+    :return: decompressed dense tensor
+    """
+    bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
+
+    decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
+    decompressed_tensor[bytemasks_unpacked] = values
+
+    return decompressed_tensor
+
+
+def pack_bitmasks(bytemasks: Tensor) -> Tensor:
+    """
+    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
+    compressed to R x ceil(C/8)
+    :param bytemasks: mask tensor where each byte corresponds to a weight
+    :return: mask tensor where each bit corresponds to a weight
+    """
+    packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
+    packed_bits_torch = torch.from_numpy(packed_bits_numpy)
+
+    return packed_bits_torch
+
+
+def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
+    """
+    Converts a bitmask tensor back to a bytemask tensor for use during decompression
+
+    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
+    :param original_shape: dense shape to decompress to
+    :return: boolean mask of weights in the original dense shape
+    """
+    # Unpack the bits
+    unpacked_bits = numpy.unpackbits(
+        packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
+    )
+
+    # Reshape to match the original shape
+    unpacked_bitmasks_torch = torch.from_numpy(
+        unpacked_bits.reshape(original_shape).astype(bool)
+    )
+
+    return unpacked_bitmasks_torch
compressed_tensors/config/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+from .base import *
+from .dense import *
+from .sparse_bitmask import *
compressed_tensors/config/base.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import Optional
+
+from compressed_tensors.registry import RegistryMixin
+from pydantic import BaseModel
+
+
+__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
+
+
+class CompressionFormat(Enum):
+    dense = "dense"
+    sparse_bitmask = "sparse-bitmask"
+    int_quantized = "int-quantized"
+
+
+class SparsityCompressionConfig(RegistryMixin, BaseModel):
+    """
+    Base data class for storing sparsity compression parameters
+
+    :param format: name of compression format
+    :param global_sparsity: average sparsity of the entire model
+    :param sparsity_structure: structure of the sparsity, such as
+        "unstructured", "2:4", "8:16" etc
+    """
+
+    format: str
+    global_sparsity: Optional[float] = 0.0
+    sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/config/dense.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+
+
+__all__ = ["DenseSparsityConfig"]
+
+
+@SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
+class DenseSparsityConfig(SparsityCompressionConfig):
+    """
+    Identity configuration for storing a sparse model in
+    an uncompressed dense format
+
+    :param global_sparsity: average sparsity of the entire model
+    :param sparsity_structure: structure of the sparsity, such as
+        "unstructured", "2:4", "8:16" etc
+    """
+
+    format: str = CompressionFormat.dense.value
+    global_sparsity: Optional[float] = 0.0
+    sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/config/sparse_bitmask.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+
+
+__all__ = ["BitmaskConfig"]
+
+
+@SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskConfig(SparsityCompressionConfig):
+    """
+    Configuration for storing a sparse model using
+    bitmask compression
+
+    :param global_sparsity: average sparsity of the entire model
+    :param sparsity_structure: structure of the sparsity, such as
+        "unstructured", "2:4", "8:16" etc
+    """
+
+    format: str = CompressionFormat.sparse_bitmask.value
+    global_sparsity: Optional[float] = 0.0
+    sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/quantization/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .quant_args import *
+from .quant_config import *
+from .quant_scheme import *
+from .lifecycle import *
compressed_tensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .calibration import *
+from .forward import *
+from .frozen import *
+from .initialize import *
+from .compressed import *
+from .apply import *
compressed_tensors/quantization/lifecycle/apply.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from collections import OrderedDict
+from typing import Dict, Iterable, Optional
+
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
+from compressed_tensors.quantization.lifecycle.compressed import (
+    compress_quantized_weights,
+)
+from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
+from compressed_tensors.quantization.lifecycle.initialize import (
+    initialize_module_for_quantization,
+)
+from compressed_tensors.quantization.quant_config import (
+    QuantizationConfig,
+    QuantizationStatus,
+)
+from compressed_tensors.quantization.utils import iter_named_leaf_modules
+from compressed_tensors.utils.safetensors_load import get_safetensors_folder
+from torch.nn import Module
+
+
+__all__ = [
+    "load_pretrained_quantization",
+    "apply_quantization_config",
+    "apply_quantization_status",
+    "find_first_name_or_class_match",
+]
+
+from compressed_tensors.quantization.utils.helpers import is_module_quantized
+from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
+
+
+def load_pretrained_quantization(model: Module, model_name_or_path: str):
+    """
+    Loads the quantization parameters (scale and zero point) from model_name_or_path to
+    a model that has already been initialized with a quantization config
+
+    :param model: model to load pretrained quantization parameters to
+    :param model_name_or_path: Hugging Face stub or local folder containing a quantized
+        model, which is used to load quantization parameters
+    """
+    model_path = get_safetensors_folder(model_name_or_path)
+    state_dict = get_quantization_state_dict(model_path)
+
+    for name, submodule in iter_named_leaf_modules(model):
+        if not is_module_quantized(submodule):
+            continue
+        if submodule.quantization_scheme.weights is not None:
+            base_name = "weight"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+        if submodule.quantization_scheme.input_activations is not None:
+            base_name = "input"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+        if submodule.quantization_scheme.output_activations is not None:
+            base_name = "output"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+
+
+def apply_quantization_config(model: Module, config: QuantizationConfig):
+    """
+    Initializes the model for quantization in-place based on the given config
+
+    :param model: model to apply quantization config to
+    :param config: quantization config
+    """
+    # build mapping of targets to schemes for easier matching
+    # use ordered dict to preserve target ordering in config
+    target_to_scheme = OrderedDict()
+    for scheme in config.config_groups.values():
+        for target in scheme.targets:
+            target_to_scheme[target] = scheme
+
+    # mark appropriate layers for quantization by setting their quantization schemes
+    for name, submodule in iter_named_leaf_modules(model):
+        if find_first_name_or_class_match(name, submodule, config.ignore):
+            continue  # layer matches ignore list, continue
+        target = find_first_name_or_class_match(name, submodule, target_to_scheme)
+        if target is not None:
+            # target matched - add layer and scheme to target list
+            submodule.quantization_scheme = target_to_scheme[target]
+
+    # apply current quantization status across all targeted layers
+    apply_quantization_status(model, config.quantization_status)
+
+
+def apply_quantization_status(model: Module, status: QuantizationStatus):
+    """
+    Applies in place the quantization lifecycle up to the given status
+
+    :param model: model to apply quantization to
+    :param status: status to update the module to
+    """
+    current_status = _infer_status(model)
+
+    if status >= QuantizationStatus.INITIALIZED > current_status:
+        model.apply(initialize_module_for_quantization)
+
+    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
+        model.apply(set_module_for_calibration)
+
+    if current_status < status >= QuantizationStatus.FROZEN > current_status:
+        model.apply(freeze_module_quantization)
+
+    if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
+        model.apply(compress_quantized_weights)
+
+
+def find_first_name_or_class_match(
+    name: str, module: Module, targets: Iterable[str], check_contains: bool = False
+) -> Optional[str]:
+    # returns the first element of targets that matches the given name;
+    # if no name matches, returns the first target that matches the class name;
+    # returns None otherwise
+    return _find_first_match(name, targets) or _find_first_match(
+        module.__class__.__name__, targets, check_contains
+    )
+
+
+def _find_first_match(
+    value: str, targets: Iterable[str], check_contains: bool = False
+) -> Optional[str]:
+    # returns first element of targets that matches value either
+    # exactly or as a regex after 're:'. if check_contains is set to True,
+    # additionally checks if the target string is contained within value.
+    for target in targets:
+        if target.startswith("re:"):
+            pattern = target[3:]
+            if re.match(pattern, value):
+                return target
+        elif check_contains:
+            if target.lower() in value.lower():
+                return target
+        elif target == value:
+            return target
+    return None
+
+
+def _infer_status(model: Module) -> Optional[QuantizationStatus]:
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
+    return None
+
+
+def _load_quant_args_from_state_dict(
+    base_name: str, module_name: str, module: Module, state_dict: Dict
+):
+    """
+    Loads scale and zero point from a state_dict into the specified module
+
+    :param base_name: quantization target, one of: weights, input_activations or
+        output_activations
+    :param module_name: pytorch module name to look up in state_dict
+    :param module: pytorch module associated with module_name
+    :param state_dict: state_dict to search for matching quantization parameters
+    """
+    scale_name = f"{base_name}_scale"
+    zp_name = f"{base_name}_zero_point"
+    device = next(module.parameters()).device
+
+    scale = getattr(module, scale_name)
+    zp = getattr(module, zp_name)
+    scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
+    zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
compressed_tensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "set_module_for_calibration",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def set_module_for_calibration(module: Module):
+    """
+    Marks a layer as ready for calibration, which activates observers
+    to update scales and zero points on each forward pass
+
+    Apply to a full model with `model.apply(set_module_for_calibration)`
+
+    :param module: module to set for calibration
+    """
+    if not getattr(module, "quantization_scheme", None):
+        # no quantization scheme, nothing to do
+        return
+    status = getattr(module, "quantization_status", None)
+    if not status or status != QuantizationStatus.INITIALIZED:
+        _LOGGER.warning(
+            f"Attempting to set module with status {status} to calibration mode, "
+            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
+            "be calibrating an uninitialized module, which may fail, or attempting "
+            "to re-calibrate a frozen module"
+        )
+
+    module.quantization_status = QuantizationStatus.CALIBRATION