compressed-tensors 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. compressed_tensors/base.py +2 -1
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +11 -54
  4. compressed_tensors/compressors/dense.py +4 -4
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/int_quantized.py +126 -0
  7. compressed_tensors/compressors/marlin_24.py +250 -0
  8. compressed_tensors/compressors/model_compressor.py +315 -0
  9. compressed_tensors/compressors/pack_quantized.py +212 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +3 -3
  11. compressed_tensors/compressors/utils/__init__.py +19 -0
  12. compressed_tensors/compressors/utils/helpers.py +43 -0
  13. compressed_tensors/compressors/utils/permutations_24.py +65 -0
  14. compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
  15. compressed_tensors/config/base.py +7 -4
  16. compressed_tensors/config/dense.py +4 -4
  17. compressed_tensors/config/sparse_bitmask.py +3 -3
  18. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  19. compressed_tensors/quantization/lifecycle/apply.py +62 -11
  20. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  21. compressed_tensors/quantization/lifecycle/forward.py +161 -54
  22. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  23. compressed_tensors/quantization/lifecycle/initialize.py +33 -5
  24. compressed_tensors/quantization/observers/base.py +31 -27
  25. compressed_tensors/quantization/observers/helpers.py +6 -1
  26. compressed_tensors/quantization/observers/memoryless.py +17 -9
  27. compressed_tensors/quantization/observers/min_max.py +44 -13
  28. compressed_tensors/quantization/quant_args.py +2 -2
  29. compressed_tensors/quantization/quant_config.py +69 -21
  30. compressed_tensors/quantization/quant_scheme.py +81 -1
  31. compressed_tensors/quantization/utils/helpers.py +76 -8
  32. compressed_tensors/utils/helpers.py +24 -6
  33. compressed_tensors/utils/safetensors_load.py +3 -2
  34. compressed_tensors/version.py +53 -0
  35. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -8
  36. compressed_tensors-0.4.0.dist-info/RECORD +48 -0
  37. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  38. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
  40. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/model_compressor.py
@@ -0,0 +1,315 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import operator
+ import os
+ from copy import deepcopy
+ from typing import Any, Dict, Optional, Union
+
+ from compressed_tensors.base import (
+     COMPRESSION_CONFIG_NAME,
+     QUANTIZATION_CONFIG_NAME,
+     SPARSITY_CONFIG_NAME,
+ )
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import SparsityCompressionConfig
+ from compressed_tensors.quantization import (
+     QuantizationConfig,
+     QuantizationStatus,
+     apply_quantization_config,
+     load_pretrained_quantization,
+ )
+ from compressed_tensors.quantization.utils import (
+     is_module_quantized,
+     iter_named_leaf_modules,
+ )
+ from compressed_tensors.utils import get_safetensors_folder
+ from compressed_tensors.utils.helpers import fix_fsdp_module_name
+ from torch import Tensor
+ from torch.nn import Module, Parameter
+ from tqdm import tqdm
+ from transformers import AutoConfig
+ from transformers.file_utils import CONFIG_NAME
+
+
+ __all__ = ["ModelCompressor", "map_modules_to_quant_args"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ class ModelCompressor:
+     """
+     Handles compression and decompression of a model with a sparsity config and/or
+     quantization config.
+
+     Compression LifeCycle
+         - compressor = ModelCompressor.from_pretrained_model(model)
+         - compressed_state_dict = compressor.compress(model, state_dict)
+             - compressor.quantization_compressor.compress(model, state_dict)
+             - compressor.sparsity_compressor.compress(model, state_dict)
+         - model.save_pretrained(output_dir, state_dict=compressed_state_dict)
+         - compressor.update_config(output_dir)
+
+     Decompression LifeCycle
+         - compressor = ModelCompressor.from_pretrained(comp_model_path)
+         - model = AutoModel.from_pretrained(comp_model_path)
+         - compressor.decompress(comp_model_path, model)
+             - compressor.sparsity_compressor.decompress(comp_model_path, model)
+             - compressor.quantization_compressor.decompress(comp_model_path, model)
+
+     :param sparsity_config: config specifying sparsity compression parameters
+     :param quantization_config: config specifying quantization compression parameters
+     """
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: str,
+     ) -> Optional["ModelCompressor"]:
+         """
+         Given a path to a model config, extract the sparsity and/or quantization
+         configs and load a ModelCompressor
+
+         :param pretrained_model_name_or_path: path to model config on disk or HF hub
+         :return: compressor for the extracted configs
+         """
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+         return cls.from_compression_config(compression_config)
+
+     @classmethod
+     def from_compression_config(cls, compression_config: Dict[str, Any]):
+         """
+         :param compression_config: compression/quantization config dictionary
+             found under key "quantization_config" in HF model config
+         :return: compressor for the extracted configs
+         """
+         if compression_config is None:
+             return None
+
+         try:
+             from transformers.utils.quantization_config import CompressedTensorsConfig
+
+             if isinstance(compression_config, CompressedTensorsConfig):
+                 compression_config = compression_config.to_dict()
+         except ImportError:
+             pass
+
+         sparsity_config = cls.parse_sparsity_config(compression_config)
+         quantization_config = cls.parse_quantization_config(compression_config)
+         if sparsity_config is None and quantization_config is None:
+             return None
+
+         if sparsity_config is not None and not isinstance(
+             sparsity_config, SparsityCompressionConfig
+         ):
+             format = sparsity_config.get("format")
+             sparsity_config = SparsityCompressionConfig.load_from_registry(
+                 format, **sparsity_config
+             )
+         if quantization_config is not None and not isinstance(
+             quantization_config, QuantizationConfig
+         ):
+             quantization_config = QuantizationConfig.parse_obj(quantization_config)
+
+         return cls(
+             sparsity_config=sparsity_config, quantization_config=quantization_config
+         )
+
+     @classmethod
+     def from_pretrained_model(
+         cls,
+         model: Module,
+         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
+         quantization_format: Optional[str] = None,
+     ) -> Optional["ModelCompressor"]:
+         """
+         Given a pytorch model and optional sparsity and/or quantization configs,
+         load the appropriate compressors
+
+         :param model: pytorch model to target for compression
+         :param sparsity_config: a filled in sparsity config or string corresponding
+             to a sparsity compression algorithm
+         :param quantization_format: string corresponding to a quantization compression
+             algorithm
+         :return: compressor for the extracted configs
+         """
+         quantization_config = QuantizationConfig.from_pretrained(
+             model, format=quantization_format
+         )
+
+         if isinstance(sparsity_config, str):  # we passed in a sparsity format
+             sparsity_config = SparsityCompressionConfig.load_from_registry(
+                 sparsity_config
+             )
+
+         if sparsity_config is None and quantization_config is None:
+             return None
+
+         return cls(
+             sparsity_config=sparsity_config, quantization_config=quantization_config
+         )
+
+     @staticmethod
+     def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
+         if compression_config is None:
+             return None
+         if SPARSITY_CONFIG_NAME not in compression_config:
+             return None
+         if hasattr(compression_config, SPARSITY_CONFIG_NAME):
+             # for loaded HFQuantizer config
+             return getattr(compression_config, SPARSITY_CONFIG_NAME)
+
+         # SparseAutoModel format
+         return compression_config.get(SPARSITY_CONFIG_NAME, None)
+
+     @staticmethod
+     def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
+         if compression_config is None:
+             return None
+
+         if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
+             # for loaded HFQuantizer config
+             return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+
+         # SparseAutoModel format
+         quantization_config = deepcopy(compression_config)
+         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
+         if len(quantization_config) == 0:
+             quantization_config = None
+         return quantization_config
+
+     def __init__(
+         self,
+         sparsity_config: Optional[SparsityCompressionConfig] = None,
+         quantization_config: Optional[QuantizationConfig] = None,
+     ):
+         self.sparsity_config = sparsity_config
+         self.quantization_config = quantization_config
+         self.sparsity_compressor = None
+         self.quantization_compressor = None
+
+         if sparsity_config is not None:
+             self.sparsity_compressor = Compressor.load_from_registry(
+                 sparsity_config.format, config=sparsity_config
+             )
+         if quantization_config is not None:
+             self.quantization_compressor = Compressor.load_from_registry(
+                 quantization_config.format, config=quantization_config
+             )
+
+     def compress(
+         self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict or model with sparsity and/or quantization
+
+         :param model: uncompressed model to compress
+         :param state_dict: optional uncompressed state_dict to insert into model
+         :return: compressed state dict
+         """
+         if state_dict is None:
+             state_dict = model.state_dict()
+
+         compressed_state_dict = state_dict
+         quantized_modules_to_args = map_modules_to_quant_args(model)
+         if self.quantization_compressor is not None:
+             compressed_state_dict = self.quantization_compressor.compress(
+                 state_dict, model_quant_args=quantized_modules_to_args
+             )
+
+         if self.sparsity_compressor is not None:
+             compressed_state_dict = self.sparsity_compressor.compress(
+                 compressed_state_dict
+             )
+
+         return compressed_state_dict
+
+     def decompress(self, model_path: str, model: Module):
+         """
+         Overwrites the weights in model with weights decompressed from model_path
+
+         :param model_path: path to compressed weights
+         :param model: pytorch model to load decompressed weights into
+         """
+         model_path = get_safetensors_folder(model_path)
+         if self.sparsity_compressor is not None:
+             dense_gen = self.sparsity_compressor.decompress(model_path)
+             self._replace_weights(dense_gen, model)
+             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+
+         if self.quantization_compressor is not None:
+             apply_quantization_config(model, self.quantization_config)
+             load_pretrained_quantization(model, model_path)
+             dense_gen = self.quantization_compressor.decompress(model_path)
+             self._replace_weights(dense_gen, model)
+
+             def update_status(module):
+                 module.quantization_status = QuantizationStatus.FROZEN
+
+             model.apply(update_status)
+             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
+
+     def update_config(self, save_directory: str):
+         """
+         Update the model config located at save_directory with compression configs
+         for sparsity and/or quantization
+
+         :param save_directory: path to a folder containing a HF model config
+         """
+         config_file_path = os.path.join(save_directory, CONFIG_NAME)
+         if not os.path.exists(config_file_path):
+             _LOGGER.warning(
+                 f"Could not find a valid model config file in "
+                 f"{save_directory}. Compression config will not be saved."
+             )
+             return
+
+         with open(config_file_path, "r") as config_file:
+             config_data = json.load(config_file)
+
+         config_data[COMPRESSION_CONFIG_NAME] = {}
+         if self.quantization_config is not None:
+             quant_config_data = self.quantization_config.model_dump()
+             config_data[COMPRESSION_CONFIG_NAME] = quant_config_data
+         if self.sparsity_config is not None:
+             sparsity_config_data = self.sparsity_config.model_dump()
+             config_data[COMPRESSION_CONFIG_NAME][
+                 SPARSITY_CONFIG_NAME
+             ] = sparsity_config_data
+
+         with open(config_file_path, "w") as config_file:
+             json.dump(config_data, config_file, indent=2, sort_keys=True)
+
+     def _replace_weights(self, dense_weight_generator, model):
+         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+             # loading the decompressed weights into the model
+             model_device = operator.attrgetter(name)(model).device
+             data_old = operator.attrgetter(name)(model)
+             data_dtype = data_old.dtype
+             data_new = Parameter(data.to(model_device).to(data_dtype))
+             data_old.data = data_new.data
+
+
+ def map_modules_to_quant_args(model: Module) -> Dict:
+     quantized_modules_to_args = {}
+     for name, submodule in iter_named_leaf_modules(model):
+         if is_module_quantized(submodule):
+             if submodule.quantization_scheme.weights is not None:
+                 name = fix_fsdp_module_name(name)
+                 quantized_modules_to_args[name] = submodule.quantization_scheme.weights
+
+     return quantized_modules_to_args
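The ModelCompressor docstring above describes the full compress/decompress lifecycle. A minimal sketch of that flow follows; the checkpoint name, output directory, and the "sparse_bitmask" format choice are illustrative placeholders rather than anything prescribed by this diff.

    # Minimal sketch of the ModelCompressor lifecycle described in the docstring above.
    # "some-org/dense-model" and "./compressed-out" are hypothetical placeholders.
    from transformers import AutoModel
    from compressed_tensors.compressors import ModelCompressor

    model = AutoModel.from_pretrained("some-org/dense-model")

    # Compression: build a compressor for the model, compress the state dict,
    # save it, then write the compression config into config.json
    compressor = ModelCompressor.from_pretrained_model(model, sparsity_config="sparse_bitmask")
    compressed_state_dict = compressor.compress(model)
    model.save_pretrained("./compressed-out", state_dict=compressed_state_dict)
    compressor.update_config("./compressed-out")

    # Decompression: rebuild the compressor from the saved config and restore weights
    compressor = ModelCompressor.from_pretrained("./compressed-out")
    model = AutoModel.from_pretrained("./compressed-out")
    if compressor is not None:
        compressor.decompress("./compressed-out", model)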
compressed_tensors/compressors/pack_quantized.py
@@ -0,0 +1,212 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import math
+ from typing import Dict, Generator, Tuple
+
+ import numpy as np
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+ from compressed_tensors.quantization.utils import can_quantize
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ __all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.pack_quantized.value)
+ class PackedQuantizationCompressor(Compressor):
+     """
+     Compresses a quantized model by packing every eight 4-bit weights into an int32
+     """
+
+     COMPRESSION_PARAM_NAMES = [
+         "weight_packed",
+         "weight_scale",
+         "weight_zero_point",
+         "weight_shape",
+     ]
+
+     def compress(
+         self,
+         model_state: Dict[str, Tensor],
+         model_quant_args: Dict[str, QuantizationArgs],
+         **kwargs,
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict
+
+         :param model_state: state dict of uncompressed model
+         :param model_quant_args: quantization args for each quantized weight, needed
+             for the quantize function to calculate bit depth
+         :return: compressed state dict
+         """
+         compressed_dict = {}
+         weight_suffix = ".weight"
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             if name.endswith(weight_suffix):
+                 prefix = name[: -(len(weight_suffix))]
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 shape = torch.tensor(value.shape)
+                 if scale is not None and zp is not None:
+                     # weight is quantized, compress it
+                     quant_args = model_quant_args[prefix]
+                     if can_quantize(value, quant_args):
+                         # convert weight to an int if not already compressed
+                         value = quantize(
+                             x=value,
+                             scale=scale,
+                             zero_point=zp,
+                             args=quant_args,
+                             dtype=torch.int8,
+                         )
+                     value = pack_4bit_ints(value.cpu())
+                     compressed_dict[merge_names(prefix, "weight_shape")] = shape
+                     compressed_dict[merge_names(prefix, "weight_packed")] = value
+                     continue
+
+             elif name.endswith("zero_point"):
+                 if torch.all(value == 0):
+                     # all zero_points are 0, no need to include in
+                     # compressed state_dict
+                     continue
+
+             compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu"
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a compressed state dict located at path_to_model_or_tensors
+         and returns a generator for sequentially decompressing back to a
+         dense state dict
+
+         :param path_to_model_or_tensors: path to compressed safetensors model
+             (directory with one or more safetensors files) or compressed tensors file
+         :param device: optional device to load intermediate weights into
+         :return: generator of the decompressed weights
+         """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+
+             if "weight_scale" in weight_data:
+                 zero_point = weight_data.get("weight_zero_point", None)
+                 scale = weight_data["weight_scale"]
+                 if zero_point is None:
+                     # zero_point assumed to be 0 if not included in state_dict
+                     zero_point = torch.zeros_like(scale)
+
+                 weight = weight_data["weight_packed"]
+                 original_shape = torch.Size(weight_data["weight_shape"])
+                 unpacked = unpack_4bit_ints(weight, original_shape)
+                 decompressed = dequantize(
+                     x_q=unpacked,
+                     scale=scale,
+                     zero_point=zero_point,
+                 )
+                 yield merge_names(weight_name, "weight"), decompressed
+
+
+ def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
+     """
+     Packs a tensor of int4 weights stored in int8 into int32s with padding
+
+     :param value: tensor to pack
+     :returns: packed int32 tensor
+     """
+     if value.dtype is not torch.int8:
+         raise ValueError("Tensor must be quantized to torch.int8 before packing")
+
+     # need to convert to unsigned 8bit to use numpy's pack/unpack
+     temp = (value - 8).to(torch.uint8)
+     bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
+     ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
+     only_4_bits = bits[:, ranges]  # top 4 bits are 0 because we're really uint4
+
+     # pad each row to fill a full 32bit int
+     pack_depth = 32
+     padding = (
+         math.ceil(only_4_bits.shape[1] / pack_depth) * pack_depth - only_4_bits.shape[1]
+     )
+     padded_bits = np.pad(
+         only_4_bits, pad_width=[(0, 0), (0, padding)], constant_values=0
+     )
+
+     # after packbits each uint8 is two packed uint4s
+     # then we keep the bit pattern the same but convert to int32
+     compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
+     compressed = np.ascontiguousarray(compressed).view(np.int32)
+
+     return torch.from_numpy(compressed)
+
+
+ def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+     """
+     Unpacks a tensor of packed int4 weights into individual int8s, maintaining
+     their original int4 range
+
+     :param value: tensor to unpack
+     :param shape: shape to unpack into, used to remove padding
+     :returns: unpacked int8 tensor
+     """
+     if value.dtype is not torch.int32:
+         raise ValueError(
+             f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+         )
+
+     # unpack bits and undo padding to nearest int32 bits
+     individual_depth = 4
+     as_uint8 = value.numpy().view(np.uint8)
+     bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
+     original_row_size = int(shape[1] * individual_depth)
+     bits = bits[:, :original_row_size]
+
+     # reformat each packed uint4 to a uint8 by filling the top 4 bits with zeros
+     # (uint8 format is required by np.packbits)
+     shape_8bit = (bits.shape[0], bits.shape[1] * 2)
+     bits_as_8bit = np.zeros(shape_8bit, dtype=np.uint8)
+     ranges = np.array([range(x, x + 4) for x in range(0, shape_8bit[1], 8)]).flatten()
+     bits_as_8bit[:, ranges] = bits
+
+     # repack the bits to uint8
+     repacked = np.packbits(bits_as_8bit, axis=-1, bitorder="little")
+
+     # bits are packed in unsigned format, reformat to signed
+     # update the value range from uint4 to int4
+     final = repacked.astype(np.int8) - 8
+
+     return torch.from_numpy(final)
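The two module-level helpers above are inverses of each other, so a quick round trip shows the packing contract: int4-range values stored in an int8 tensor go in, an int32 tensor with eight nibbles per element comes out, and the original shape is needed to strip the padding on the way back. A small sketch (the tensor values are arbitrary examples):

    # Round-trip sketch for the 4-bit packing helpers; values are arbitrary examples.
    import torch
    from compressed_tensors.compressors.pack_quantized import pack_4bit_ints, unpack_4bit_ints

    # int4-range values ([-8, 7]) stored in an int8 tensor of shape [rows, cols]
    w_q = torch.randint(-8, 8, (4, 16), dtype=torch.int8)

    packed = pack_4bit_ints(w_q)                    # torch.int32, 8 weights per element
    restored = unpack_4bit_ints(packed, w_q.shape)  # back to int8 in the int4 range

    assert torch.equal(restored, w_q)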
compressed_tensors/compressors/sparse_bitmask.py
@@ -17,7 +17,7 @@ from typing import Dict, Generator, List, Tuple, Union
 
   import numpy
   import torch
- from compressed_tensors.compressors import ModelCompressor
+ from compressed_tensors.compressors import Compressor
   from compressed_tensors.config import CompressionFormat
   from compressed_tensors.utils import get_nested_weight_mappings, merge_names
   from safetensors import safe_open
@@ -37,8 +37,8 @@ __all__ = [
   _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
- @ModelCompressor.register(name=CompressionFormat.sparse_bitmask.value)
- class BitmaskCompressor(ModelCompressor):
+ @Compressor.register(name=CompressionFormat.sparse_bitmask.value)
+ class BitmaskCompressor(Compressor):
       """
       Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
       values tensor, with their locations stored in a 2d bitmask
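This rename reflects the new split between the Compressor base class and the ModelCompressor orchestrator added above: individual compressors now register against Compressor under a CompressionFormat name and are constructed through the registry. A hedged sketch of that pattern, assuming the registry API as it appears in this diff:

    # Sketch of the registry pattern: configs and compressors are looked up by format name.
    from compressed_tensors.compressors import Compressor
    from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig

    sparsity_config = SparsityCompressionConfig.load_from_registry(
        CompressionFormat.sparse_bitmask.value
    )
    compressor = Compressor.load_from_registry(sparsity_config.format, config=sparsity_config)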
compressed_tensors/compressors/utils/__init__.py
@@ -0,0 +1,19 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+
+ from .helpers import *
+ from .permutations_24 import *
+ from .semi_structured_conversions import *
compressed_tensors/compressors/utils/helpers.py
@@ -0,0 +1,43 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ __all__ = ["tensor_follows_mask_structure"]
+
+
+ def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+     """
+     :param tensor: tensor to check
+     :param mask: mask structure to check for, in the format "n:m"
+     :return: True if the tensor follows the mask structure, False otherwise.
+         Note, some weights can incidentally be zero, so we check for
+         at least n zeros in each chunk of size m
+     """
+
+     n, m = tuple(map(int, mask.split(":")))
+     # Reshape the tensor into chunks of size m
+     tensor = tensor.view(-1, m)
+
+     # Count the number of zeros in each chunk
+     zero_counts = (tensor == 0).sum(dim=1)
+
+     # Check that the number of zeros in each chunk is at least n
+     # Greater than sign is needed as some weights can incidentally
+     # be zero
+     if not torch.all(zero_counts >= n).item():
+         raise ValueError()
+
+     return True
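A quick check of the helper above with the default 2:4 pattern; the tensors here are made-up examples, and note that, as written, a violating tensor raises ValueError rather than returning False:

    # Usage sketch for tensor_follows_mask_structure with the default 2:4 mask.
    import torch
    from compressed_tensors.compressors.utils import tensor_follows_mask_structure

    ok = torch.tensor([[0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 4.0]])
    print(tensor_follows_mask_structure(ok, mask="2:4"))  # True: every chunk of 4 has >= 2 zeros

    bad = torch.tensor([[1.0, 2.0, 3.0, 0.0]])
    try:
        tensor_follows_mask_structure(bad, mask="2:4")
    except ValueError:
        # as written, a violating tensor raises instead of returning False
        print("does not follow 2:4 structure")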
compressed_tensors/compressors/utils/permutations_24.py
@@ -0,0 +1,65 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import numpy
+ import torch
+
+
+ __all__ = ["get_permutations_24"]
+
+
+ # Precompute permutations for Marlin24 weight and scale shuffling
+ # Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py  # noqa: E501
+ #
+ # Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight
+ # data so that it is compatible with the tensor-core format that is described here:
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type  # noqa: E501
+ #
+ # As a result of this reordering, the vector loads inside the kernel will get the data
+ # as it is needed for tensor-core (without the need to use ldmatrix instructions)
+ def get_permutations_24(num_bits):
+     perm_list = []
+     for i in range(32):
+         perm1 = []
+         col = i // 4
+         col_o = col // 2
+         for block in [0, 1]:
+             for row in [
+                 2 * (i % 4),
+                 2 * (i % 4) + 1,
+                 2 * (i % 4 + 4),
+                 2 * (i % 4 + 4) + 1,
+             ]:
+                 perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
+         for j in range(4):
+             perm_list.extend([p + 1 * j for p in perm1])
+     perm = numpy.array(perm_list)
+
+     if num_bits == 4:
+         interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+     elif num_bits == 8:
+         interleave = numpy.array([0, 2, 1, 3])
+     else:
+         raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
+
+     perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+     perm = torch.from_numpy(perm)
+     scale_perm = []
+     for i in range(8):
+         scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
+     scale_perm_single = []
+     for i in range(8):
+         scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
+     return perm, scale_perm, scale_perm_single
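For reference, the helper returns one weight permutation plus two scale permutations; the sizes below follow directly from the loop bounds above (32 x 8 x 4 = 1024 entries for the weight permutation, 8 x 8 = 64 for each scale permutation). An illustrative call:

    # Illustrative call; shapes follow from the loop bounds in get_permutations_24.
    from compressed_tensors.compressors.utils import get_permutations_24

    perm, scale_perm, scale_perm_single = get_permutations_24(num_bits=4)
    print(perm.shape)               # torch.Size([1024])
    print(len(scale_perm))          # 64
    print(len(scale_perm_single))   # 64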