compressed_tensors_nightly-0.3.3.20240514-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +17 -0
  3. compressed_tensors/compressors/__init__.py +22 -0
  4. compressed_tensors/compressors/base.py +59 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +137 -0
  7. compressed_tensors/compressors/int_quantized.py +95 -0
  8. compressed_tensors/compressors/model_compressor.py +264 -0
  9. compressed_tensors/compressors/sparse_bitmask.py +239 -0
  10. compressed_tensors/config/__init__.py +18 -0
  11. compressed_tensors/config/base.py +43 -0
  12. compressed_tensors/config/dense.py +36 -0
  13. compressed_tensors/config/sparse_bitmask.py +36 -0
  14. compressed_tensors/quantization/__init__.py +21 -0
  15. compressed_tensors/quantization/lifecycle/__init__.py +23 -0
  16. compressed_tensors/quantization/lifecycle/apply.py +196 -0
  17. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  18. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  19. compressed_tensors/quantization/lifecycle/forward.py +333 -0
  20. compressed_tensors/quantization/lifecycle/frozen.py +50 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +99 -0
  22. compressed_tensors/quantization/observers/__init__.py +21 -0
  23. compressed_tensors/quantization/observers/base.py +130 -0
  24. compressed_tensors/quantization/observers/helpers.py +54 -0
  25. compressed_tensors/quantization/observers/memoryless.py +48 -0
  26. compressed_tensors/quantization/observers/min_max.py +80 -0
  27. compressed_tensors/quantization/quant_args.py +125 -0
  28. compressed_tensors/quantization/quant_config.py +210 -0
  29. compressed_tensors/quantization/quant_scheme.py +39 -0
  30. compressed_tensors/quantization/utils/__init__.py +16 -0
  31. compressed_tensors/quantization/utils/helpers.py +131 -0
  32. compressed_tensors/registry/__init__.py +17 -0
  33. compressed_tensors/registry/registry.py +360 -0
  34. compressed_tensors/utils/__init__.py +16 -0
  35. compressed_tensors/utils/helpers.py +45 -0
  36. compressed_tensors/utils/safetensors_load.py +237 -0
  37. compressed_tensors/version.py +50 -0
  38. compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
  39. compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
  40. compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
  41. compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
  42. compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
compressed_tensors/compressors/sparse_bitmask.py
@@ -0,0 +1,239 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, List, Tuple, Union
+
+ import numpy
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ __all__ = [
+     "BitmaskCompressor",
+     "BitmaskTensor",
+     "bitmask_compress",
+     "bitmask_decompress",
+     "pack_bitmasks",
+     "unpack_bitmasks",
+ ]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.sparse_bitmask.value)
+ class BitmaskCompressor(Compressor):
+     """
+     Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
+     values tensor, with their locations stored in a 2d bitmask
+     """
+
+     COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"]
+
+     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict using bitmask compression
+
+         :param model_state: state dict of uncompressed model
+         :return: compressed state dict
+         """
+         compressed_dict = {}
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             bitmask_tensor = BitmaskTensor.from_dense(value)
+             bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+             for key in bitmask_dict.keys():
+                 if key in compressed_dict:
+                     _LOGGER.warning(
+                         f"Expected all compressed state_dict keys to be unique, but "
+                         f"found an existing entry for {key}. The existing entry will "
+                         "be replaced."
+                     )
+             compressed_dict |= bitmask_dict
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu"
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a bitmask compressed state dict located at path_to_model_or_tensors
+         and returns a generator for sequentially decompressing back to a dense
+         state dict
+
+         :param path_to_model_or_tensors: path to compressed safetensors model
+             (directory with one or more safetensors files) or compressed tensors file
+         :param device: device to load decompressed weights onto
+         :return: iterator for generating decompressed weights
+         """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+             data = BitmaskTensor(**weight_data)
+             decompressed = data.decompress()
+             yield weight_name, decompressed
+
+
+ class BitmaskTensor:
+     """
+     Owns compression and decompression for a single bitmask compressed tensor.
+     Adapted from: https://github.com/mgoin/torch_bitmask/tree/main
+
+     :param shape: shape of dense tensor
+     :param compressed: flat tensor of non-zero values
+     :param bitmask: 2d bitmask of non-zero values
+     :param row_offsets: flat tensor indicating what index in values each dense row
+         starts at
+     """
+
+     def __init__(
+         self,
+         shape: Union[torch.Size, List],
+         compressed: Tensor,
+         bitmask: Tensor,
+         row_offsets: Tensor,
+     ):
+         self.shape = list(shape)
+         self.compressed = compressed
+         self.bitmask = bitmask
+         self.row_offsets = row_offsets
+
+     @staticmethod
+     def from_dense(tensor: Tensor) -> "BitmaskTensor":
+         """
+         :param tensor: dense tensor to compress
+         :return: instantiated compressed tensor
+         """
+         shape = tensor.shape
+         compressed, bitmask, row_offsets = bitmask_compress(tensor.cpu())
+         return BitmaskTensor(
+             shape=shape, compressed=compressed, bitmask=bitmask, row_offsets=row_offsets
+         )
+
+     def decompress(self) -> Tensor:
+         """
+         :return: reconstructed dense tensor
+         """
+         return bitmask_decompress(self.compressed, self.bitmask, self.shape)
+
+     def curr_memory_size_bytes(self):
+         """
+         :return: size in bytes required to store compressed tensor on disk
+         """
+
+         def sizeof_tensor(a):
+             return a.element_size() * a.nelement()
+
+         return (
+             sizeof_tensor(self.compressed)
+             + sizeof_tensor(self.bitmask)
+             + sizeof_tensor(self.row_offsets)
+         )
+
+     def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
+         """
+         :param name_prefix: name of original tensor to store compressed weight as
+         :param device: device to move compressed tensors to before returning
+         :return: dict of compressed data for the stored weight
+         """
+         return {
+             merge_names(name_prefix, "shape"): torch.tensor(self.shape, device=device),
+             merge_names(name_prefix, "compressed"): self.compressed.to(device),
+             merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
+             merge_names(name_prefix, "row_offsets"): self.row_offsets.to(device),
+         }
+
+     def __repr__(self):
+         return f"BitmaskTensor(shape={self.shape}, compressed=True)"
+
+
+ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Compresses a dense tensor using bitmask compression
+
+     :param tensor: dense tensor to compress
+     :return: tuple of compressed data representing tensor
+     """
+     bytemasks = tensor != 0
+     row_counts = bytemasks.sum(dim=-1)
+     row_offsets = torch.cumsum(row_counts, 0) - row_counts
+     values = tensor[bytemasks]
+     bitmasks_packed = pack_bitmasks(bytemasks)
+
+     return values, bitmasks_packed, row_offsets
+
+
+ def bitmask_decompress(
+     values: Tensor, bitmasks: Tensor, original_shape: torch.Size
+ ) -> Tensor:
+     """
+     Reconstructs a dense tensor from a compressed one
+
+     :param values: 1d tensor of non-zero values
+     :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
+         tensor's original shape
+     :param original_shape: shape of the dense tensor
+     :return: decompressed dense tensor
+     """
+     bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
+
+     decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
+     decompressed_tensor[bytemasks_unpacked] = values
+
+     return decompressed_tensor
+
+
+ def pack_bitmasks(bytemasks: Tensor) -> Tensor:
+     """
+     Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
+     compressed to R x ceil(C/8)
+
+     :param bytemasks: mask tensor where each byte corresponds to a weight
+     :return: mask tensor where each bit corresponds to a weight
+     """
+     packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
+     packed_bits_torch = torch.from_numpy(packed_bits_numpy)
+
+     return packed_bits_torch
+
+
+ def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
+     """
+     Converts a bitmask tensor back to a bytemask tensor for use during decompression
+
+     :param packed_bitmasks: mask tensor where each bit corresponds to a weight
+     :param original_shape: dense shape to decompress to
+     :return: boolean mask of weights in the original dense shape
+     """
+     # Unpack the bits
+     unpacked_bits = numpy.unpackbits(
+         packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
+     )
+
+     # Reshape to match the original shape
+     unpacked_bitmasks_torch = torch.from_numpy(
+         unpacked_bits.reshape(original_shape).astype(bool)
+     )
+
+     return unpacked_bitmasks_torch
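
For orientation, here is a minimal sketch of round-tripping a tensor through this bitmask format (assuming the wheel above is installed; the tensor values are illustrative):

    import torch
    from compressed_tensors.compressors.sparse_bitmask import (
        bitmask_compress,
        bitmask_decompress,
    )

    # a small tensor with a few non-zero weights
    dense = torch.tensor([[0.0, 1.5, 0.0, 0.0], [2.0, 0.0, 0.0, 3.0]])

    # compress: 1d non-zero values, packed bitmask, per-row value offsets
    values, bitmask, row_offsets = bitmask_compress(dense)
    assert values.tolist() == [1.5, 2.0, 3.0]
    assert bitmask.shape == (2, 1)  # 4 mask columns pack into ceil(4/8) = 1 byte
    assert row_offsets.tolist() == [0, 1]

    # decompress back to the original dense tensor
    restored = bitmask_decompress(values, bitmask, dense.shape)
    assert torch.equal(restored, dense)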
compressed_tensors/config/__init__.py
@@ -0,0 +1,18 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ from .base import *
+ from .dense import *
+ from .sparse_bitmask import *
compressed_tensors/config/base.py
@@ -0,0 +1,43 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from enum import Enum
+ from typing import Optional
+
+ from compressed_tensors.registry import RegistryMixin
+ from pydantic import BaseModel
+
+
+ __all__ = ["SparsityCompressionConfig", "CompressionFormat"]
+
+
+ class CompressionFormat(Enum):
+     dense = "dense"
+     sparse_bitmask = "sparse-bitmask"
+     int_quantized = "int-quantized"
+
+
+ class SparsityCompressionConfig(RegistryMixin, BaseModel):
+     """
+     Base data class for storing sparsity compression parameters
+
+     :param format: name of compression format
+     :param global_sparsity: average sparsity of the entire model
+     :param sparsity_structure: structure of the sparsity, such as
+         "unstructured", "2:4", "8:16" etc
+     """
+
+     format: str
+     global_sparsity: Optional[float] = 0.0
+     sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/config/dense.py
@@ -0,0 +1,36 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+
+
+ __all__ = ["DenseSparsityConfig"]
+
+
+ @SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
+ class DenseSparsityConfig(SparsityCompressionConfig):
+     """
+     Identity configuration for storing a sparse model in
+     an uncompressed dense format
+
+     :param global_sparsity: average sparsity of the entire model
+     :param sparsity_structure: structure of the sparsity, such as
+         "unstructured", "2:4", "8:16" etc
+     """
+
+     format: str = CompressionFormat.dense.value
+     global_sparsity: Optional[float] = 0.0
+     sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/config/sparse_bitmask.py
@@ -0,0 +1,36 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+
+
+ __all__ = ["BitmaskConfig"]
+
+
+ @SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+ class BitmaskConfig(SparsityCompressionConfig):
+     """
+     Configuration for storing a sparse model using
+     bitmask compression
+
+     :param global_sparsity: average sparsity of the entire model
+     :param sparsity_structure: structure of the sparsity, such as
+         "unstructured", "2:4", "8:16" etc
+     """
+
+     format: str = CompressionFormat.sparse_bitmask.value
+     global_sparsity: Optional[float] = 0.0
+     sparsity_structure: Optional[str] = "unstructured"
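
The config classes above all hang off SparsityCompressionConfig's registry, so a compression format can be resolved from the string stored in a model's config. A short sketch (the load_from_registry helper is assumed here from RegistryMixin in compressed_tensors/registry/registry.py; treat the exact call as an assumption):

    from compressed_tensors.config import (
        BitmaskConfig,
        CompressionFormat,
        SparsityCompressionConfig,
    )

    # direct construction of a registered pydantic config
    config = BitmaskConfig(global_sparsity=0.5, sparsity_structure="2:4")
    assert config.format == CompressionFormat.sparse_bitmask.value

    # resolving the same class by its registered name (assumed RegistryMixin API)
    config = SparsityCompressionConfig.load_from_registry(
        "sparse-bitmask", global_sparsity=0.5, sparsity_structure="2:4"
    )
    assert isinstance(config, BitmaskConfig)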
compressed_tensors/quantization/__init__.py
@@ -0,0 +1,21 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort: skip_file
+
+ from .quant_args import *
+ from .quant_config import *
+ from .quant_scheme import *
+ from .lifecycle import *
compressed_tensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,23 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort: skip_file
+
+ from .calibration import *
+ from .forward import *
+ from .frozen import *
+ from .initialize import *
+ from .compressed import *
+ from .apply import *
compressed_tensors/quantization/lifecycle/apply.py
@@ -0,0 +1,196 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from collections import OrderedDict
+ from typing import Dict, Iterable, Optional
+
+ from compressed_tensors.quantization.lifecycle.calibration import (
+     set_module_for_calibration,
+ )
+ from compressed_tensors.quantization.lifecycle.compressed import (
+     compress_quantized_weights,
+ )
+ from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
+ from compressed_tensors.quantization.lifecycle.initialize import (
+     initialize_module_for_quantization,
+ )
+ from compressed_tensors.quantization.quant_config import (
+     QuantizationConfig,
+     QuantizationStatus,
+ )
+ from compressed_tensors.quantization.utils import iter_named_leaf_modules
+ from compressed_tensors.quantization.utils.helpers import is_module_quantized
+ from compressed_tensors.utils.safetensors_load import (
+     get_quantization_state_dict,
+     get_safetensors_folder,
+ )
+ from torch.nn import Module
+
+
+ __all__ = [
+     "load_pretrained_quantization",
+     "apply_quantization_config",
+     "apply_quantization_status",
+     "find_first_name_or_class_match",
+ ]
+
+
+ def load_pretrained_quantization(model: Module, model_name_or_path: str):
+     """
+     Loads the quantization parameters (scale and zero point) from model_name_or_path
+     to a model that has already been initialized with a quantization config
+
+     :param model: model to load pretrained quantization parameters to
+     :param model_name_or_path: Hugging Face stub or local folder containing a
+         quantized model, which is used to load quantization parameters
+     """
+     model_path = get_safetensors_folder(model_name_or_path)
+     state_dict = get_quantization_state_dict(model_path)
+
+     for name, submodule in iter_named_leaf_modules(model):
+         if not is_module_quantized(submodule):
+             continue
+         if submodule.quantization_scheme.weights is not None:
+             base_name = "weight"
+             _load_quant_args_from_state_dict(
+                 base_name=base_name,
+                 module_name=name,
+                 module=submodule,
+                 state_dict=state_dict,
+             )
+         if submodule.quantization_scheme.input_activations is not None:
+             base_name = "input"
+             _load_quant_args_from_state_dict(
+                 base_name=base_name,
+                 module_name=name,
+                 module=submodule,
+                 state_dict=state_dict,
+             )
+         if submodule.quantization_scheme.output_activations is not None:
+             base_name = "output"
+             _load_quant_args_from_state_dict(
+                 base_name=base_name,
+                 module_name=name,
+                 module=submodule,
+                 state_dict=state_dict,
+             )
+
+
+ def apply_quantization_config(model: Module, config: QuantizationConfig):
+     """
+     Initializes the model for quantization in-place based on the given config
+
+     :param model: model to apply quantization config to
+     :param config: quantization config
+     """
+     # build mapping of targets to schemes for easier matching
+     # use ordered dict to preserve target ordering in config
+     target_to_scheme = OrderedDict()
+     for scheme in config.config_groups.values():
+         for target in scheme.targets:
+             target_to_scheme[target] = scheme
+
+     # mark appropriate layers for quantization by setting their quantization schemes
+     for name, submodule in iter_named_leaf_modules(model):
+         if find_first_name_or_class_match(name, submodule, config.ignore):
+             continue  # layer matches ignore list, continue
+         target = find_first_name_or_class_match(name, submodule, target_to_scheme)
+         if target is not None:
+             # target matched - add layer and scheme to target list
+             submodule.quantization_scheme = target_to_scheme[target]
+
+     # apply current quantization status across all targeted layers
+     apply_quantization_status(model, config.quantization_status)
+
+
+ def apply_quantization_status(model: Module, status: QuantizationStatus):
+     """
+     Applies in place the quantization lifecycle up to the given status
+
+     :param model: model to apply quantization to
+     :param status: status to update the module to
+     """
+     current_status = _infer_status(model)
+
+     if status >= QuantizationStatus.INITIALIZED > current_status:
+         model.apply(initialize_module_for_quantization)
+
+     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
+         model.apply(set_module_for_calibration)
+
+     if current_status < status >= QuantizationStatus.FROZEN > current_status:
+         model.apply(freeze_module_quantization)
+
+     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
+         model.apply(compress_quantized_weights)
+
+
+ def find_first_name_or_class_match(
+     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
+ ) -> Optional[str]:
+     # returns the first element of targets that matches the given name; if no
+     # name matches, returns the first target that matches the class name;
+     # returns None otherwise
+     return _find_first_match(name, targets) or _find_first_match(
+         module.__class__.__name__, targets, check_contains
+     )
+
+
+ def _find_first_match(
+     value: str, targets: Iterable[str], check_contains: bool = False
+ ) -> Optional[str]:
+     # returns the first element of targets that matches value, either exactly
+     # or as a regex after 're:'. if check_contains is set to True, additionally
+     # checks if the target string is contained within value
+     for target in targets:
+         if target.startswith("re:"):
+             pattern = target[3:]
+             if re.match(pattern, value):
+                 return target
+         elif check_contains:
+             if target.lower() in value.lower():
+                 return target
+         elif target == value:
+             return target
+     return None
+
+
+ def _infer_status(model: Module) -> Optional[QuantizationStatus]:
+     for module in model.modules():
+         status = getattr(module, "quantization_status", None)
+         if status is not None:
+             return status
+     return None
+
+
+ def _load_quant_args_from_state_dict(
+     base_name: str, module_name: str, module: Module, state_dict: Dict
+ ):
+     """
+     Loads scale and zero point from a state_dict into the specified module
+
+     :param base_name: quantization target, one of: weight, input or output
+     :param module_name: pytorch module name to look up in state_dict
+     :param module: pytorch module associated with module_name
+     :param state_dict: state_dict to search for matching quantization parameters
+     """
+     scale_name = f"{base_name}_scale"
+     zp_name = f"{base_name}_zero_point"
+     device = next(module.parameters()).device
+
+     scale = getattr(module, scale_name)
+     zp = getattr(module, zp_name)
+     scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
+     zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
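
The target-matching rules in find_first_name_or_class_match drive both the ignore list and config_groups resolution: exact module names match first, "re:"-prefixed targets are applied as regular expressions against the module name, and bare class names like "Linear" match as a fallback. A small illustration against the code above:

    import torch.nn as nn
    from compressed_tensors.quantization.lifecycle.apply import (
        find_first_name_or_class_match,
    )

    layer = nn.Linear(8, 8)

    # "re:" targets are regex patterns matched against the module name
    assert find_first_name_or_class_match(
        "model.layers.7.self_attn.q_proj", layer, ["re:.*q_proj"]
    ) == "re:.*q_proj"

    # bare targets fall back to matching the module's class name
    assert find_first_name_or_class_match(
        "model.lm_head", layer, ["Linear"]
    ) == "Linear"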
compressed_tensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,51 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from torch.nn import Module
+
+
+ __all__ = [
+     "set_module_for_calibration",
+ ]
+
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def set_module_for_calibration(module: Module):
+     """
+     Marks a layer as ready for calibration, which activates observers
+     to update scales and zero points on each forward pass
+
+     apply to full model with `model.apply(set_module_for_calibration)`
+
+     :param module: module to set for calibration
+     """
+     if not getattr(module, "quantization_scheme", None):
+         # no quantization scheme, nothing to do
+         return
+     status = getattr(module, "quantization_status", None)
+     if not status or status != QuantizationStatus.INITIALIZED:
+         _LOGGER.warning(
+             f"Attempting to set module with status {status} to calibration mode, "
+             f"but status is not {QuantizationStatus.INITIALIZED} - you may "
+             "be calibrating an uninitialized module, which may fail, or attempting "
+             "to re-calibrate a frozen module"
+         )
+
+     module.quantization_status = QuantizationStatus.CALIBRATION
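
Putting the lifecycle pieces together for a single layer mirrors what apply_quantization_config does model-wide: attach a scheme, initialize, calibrate with real data, then freeze. A rough sketch (QuantizationScheme/QuantizationArgs fields follow quant_scheme.py and quant_args.py from the file list; the defaults used here are assumptions):

    import torch
    import torch.nn as nn
    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationScheme,
        freeze_module_quantization,
        initialize_module_for_quantization,
        set_module_for_calibration,
    )

    layer = nn.Linear(16, 16)

    # attach a scheme, as apply_quantization_config does for matched targets
    layer.quantization_scheme = QuantizationScheme(
        targets=["Linear"], weights=QuantizationArgs(num_bits=8)
    )

    # INITIALIZED: create scale/zero-point parameters
    layer.apply(initialize_module_for_quantization)

    # CALIBRATION: observers update scale/zero point on each forward pass
    layer.apply(set_module_for_calibration)
    with torch.no_grad():
        layer(torch.randn(4, 16))

    # FROZEN: quantization parameters stop updating
    layer.apply(freeze_module_quantization)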