compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +137 -9
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
  10. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
  11. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/config/base.py +6 -1
  19. compressed_tensors/linear/__init__.py +13 -0
  20. compressed_tensors/linear/compressed_linear.py +87 -0
  21. compressed_tensors/quantization/__init__.py +1 -0
  22. compressed_tensors/quantization/cache.py +201 -0
  23. compressed_tensors/quantization/lifecycle/apply.py +63 -9
  24. compressed_tensors/quantization/lifecycle/calibration.py +7 -7
  25. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  26. compressed_tensors/quantization/lifecycle/forward.py +126 -44
  27. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  28. compressed_tensors/quantization/lifecycle/helpers.py +0 -20
  29. compressed_tensors/quantization/lifecycle/initialize.py +138 -55
  30. compressed_tensors/quantization/observers/__init__.py +1 -0
  31. compressed_tensors/quantization/observers/base.py +54 -14
  32. compressed_tensors/quantization/observers/min_max.py +8 -0
  33. compressed_tensors/quantization/observers/mse.py +162 -0
  34. compressed_tensors/quantization/quant_args.py +102 -24
  35. compressed_tensors/quantization/quant_config.py +14 -2
  36. compressed_tensors/quantization/quant_scheme.py +12 -13
  37. compressed_tensors/quantization/utils/helpers.py +44 -19
  38. compressed_tensors/utils/__init__.py +1 -0
  39. compressed_tensors/utils/helpers.py +30 -1
  40. compressed_tensors/utils/offload.py +14 -2
  41. compressed_tensors/utils/permute.py +70 -0
  42. compressed_tensors/utils/safetensors_load.py +2 -0
  43. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  44. compressed_tensors/version.py +1 -1
  45. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
  46. compressed_tensors-0.7.0.dist-info/RECORD +59 -0
  47. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
  48. compressed_tensors/compressors/pack_quantized.py +0 -219
  49. compressed_tensors-0.5.0.dist-info/RECORD +0 -48
  50. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
  51. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
@@ -16,36 +16,51 @@ import logging
  from typing import Dict, Generator, Tuple

  import torch
- from compressed_tensors.compressors import Compressor
- from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.quantization import QuantizationArgs
- from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
- from compressed_tensors.quantization.utils import can_quantize
  from compressed_tensors.utils import get_nested_weight_mappings, merge_names
  from safetensors import safe_open
  from torch import Tensor
  from tqdm import tqdm


- __all__ = [
-     "QuantizationCompressor",
-     "IntQuantizationCompressor",
-     "FloatQuantizationCompressor",
- ]
-
  _LOGGER: logging.Logger = logging.getLogger(__name__)

+ __all__ = ["BaseQuantizationCompressor"]
+

- @Compressor.register(name=CompressionFormat.naive_quantized.value)
- class QuantizationCompressor(Compressor):
+ class BaseQuantizationCompressor(BaseCompressor):
      """
-     Implements naive compression for quantized models. Weight of each
-     quantized layer is converted from its original float type to the closest Pytorch
-     type to the type specified by the layer's QuantizationArgs.
+     Base class representing a quant compression algorithm. Each child class should
+     implement compression_param_info, compress_weight and decompress_weight.
+
+     Compressors support compressing/decompressing a full module state dict or a single
+     quantized PyTorch leaf module.
+
+     Model Load Lifecycle (run_compressed=False):
+         - ModelCompressor.decompress()
+             - apply_quantization_config()
+             - BaseQuantizationCompressor.decompress()
+                 - BaseQuantizationCompressor.decompress_weight()
+
+     Model Save Lifecycle:
+         - ModelCompressor.compress()
+             - BaseQuantizationCompressor.compress()
+                 - BaseQuantizationCompressor.compress_weight()
+
+     Module Lifecycle (run_compressed=True):
+         - apply_quantization_config()
+         - compressed_module = CompressedLinear(module)
+             - initialize_module_for_quantization()
+             - BaseQuantizationCompressor.compression_param_info()
+             - register_parameters()
+         - compressed_module.forward()
+             - compressed_module.decompress()
+
+
+     :param config: config specifying compression parameters
      """

-     COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
-
      def compress(
          self,
          model_state: Dict[str, Tensor],
@@ -57,7 +72,7 @@ class QuantizationCompressor(Compressor):

          :param model_state: state dict of uncompressed model
          :param names_to_scheme: quantization args for each quantized weight, needed for
-             quantize function to calculate bit depth
+             quantize function to calculate bit depth
          :return: compressed state dict
          """
          compressed_dict = {}
@@ -66,42 +81,50 @@ class QuantizationCompressor(Compressor):
              f"Compressing model with {len(model_state)} parameterized layers..."
          )

-         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+         for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
              if name.endswith(weight_suffix):
                  prefix = name[: -(len(weight_suffix))]
                  scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                  zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-                 if scale is not None and zp is not None:
+                 g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                 if scale is not None:
                      # weight is quantized, compress it
                      quant_args = names_to_scheme[prefix]
-                     if can_quantize(value, quant_args):
-                         # only quantize if not already quantized
-                         value = quantize(
-                             x=value,
-                             scale=scale,
-                             zero_point=zp,
-                             args=quant_args,
-                             dtype=quant_args.pytorch_dtype(),
-                         )
-             elif name.endswith("zero_point"):
-                 if torch.all(value == 0):
-                     # all zero_points are 0, no need to include in
-                     # compressed state_dict
-                     continue
-             compressed_dict[name] = value.to("cpu")
+                     compressed_data = self.compress_weight(
+                         weight=value,
+                         scale=scale,
+                         zero_point=zp,
+                         g_idx=g_idx,
+                         quantization_args=quant_args,
+                         device="cpu",
+                     )
+                     for key, value in compressed_data.items():
+                         compressed_dict[merge_names(prefix, key)] = value
+                 else:
+                     compressed_dict[name] = value.to("cpu")
+             elif name.endswith("zero_point") and torch.all(value == 0):
+                 continue
+             elif name.endswith("g_idx") and torch.any(value <= -1):
+                 continue
+             else:
+                 compressed_dict[name] = value.to("cpu")

          return compressed_dict

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+         self,
+         path_to_model_or_tensors: str,
+         names_to_scheme: Dict[str, QuantizationArgs],
+         device: str = "cpu",
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
          Reads a compressed state dict located at path_to_model_or_tensors
          and returns a generator for sequentially decompressing back to a
          dense state dict

-         :param model_path: path to compressed safetensors model (directory with
-             one or more safetensors files) or compressed tensors file
+         :param path_to_model_or_tensors: path to compressed safetensors model (directory
+             with one or more safetensors files) or compressed tensors file
+         :param names_to_scheme: quantization args for each quantized weight
          :param device: optional device to load intermediate weights into
          :return: compressed state dict
          """
@@ -116,29 +139,8 @@ class QuantizationCompressor(Compressor):
                      weight_data[param_name] = f.get_tensor(full_name)

              if "weight_scale" in weight_data:
-                 zero_point = weight_data.get("weight_zero_point", None)
-                 scale = weight_data["weight_scale"]
-                 decompressed = dequantize(
-                     x_q=weight_data["weight"],
-                     scale=scale,
-                     zero_point=zero_point,
+                 quant_args = names_to_scheme[weight_name]
+                 decompressed = self.decompress_weight(
+                     compressed_data=weight_data, quantization_args=quant_args
                  )
                  yield merge_names(weight_name, "weight"), decompressed
-
-
- @Compressor.register(name=CompressionFormat.int_quantized.value)
- class IntQuantizationCompressor(QuantizationCompressor):
-     """
-     Alias for integer quantized models
-     """
-
-     pass
-
-
- @Compressor.register(name=CompressionFormat.float_quantized.value)
- class FloatQuantizationCompressor(QuantizationCompressor):
-     """
-     Alias for fp quantized models
-     """
-
-     pass
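
The hunk above replaces the monolithic QuantizationCompressor with a BaseQuantizationCompressor whose compress() delegates the per-weight work to compress_weight() and decompress_weight() on a child class. A minimal sketch of driving that entry point through one of the registered subclasses follows; the layer name, tensor shapes, and QuantizationArgs settings are illustrative assumptions, not values taken from this diff.

# Hedged sketch: compressing a one-layer state dict with a concrete subclass.
# "decoder.0.linear" and the QuantizationArgs settings are assumptions.
import torch
from compressed_tensors.compressors import IntQuantizationCompressor
from compressed_tensors.quantization import QuantizationArgs

args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="channel")
compressor = IntQuantizationCompressor()

state_dict = {
    "decoder.0.linear.weight": torch.randn(64, 512),
    "decoder.0.linear.weight_scale": torch.rand(64, 1) * 0.02,
    "decoder.0.linear.weight_zero_point": torch.zeros(64, 1, dtype=torch.int8),
}

# a scale is present, so compress() routes the weight through compress_weight();
# the all-zero zero_point entry is dropped from the compressed state dict
compressed = compressor.compress(state_dict, names_to_scheme={"decoder.0.linear": args})
print(compressed["decoder.0.linear.weight"].dtype)  # args.pytorch_dtype() -> torch.int8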
@@ -0,0 +1,140 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, Optional, Tuple
+
+ import torch
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.compressors.quantized_compressors.base import (
+     BaseQuantizationCompressor,
+ )
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+ from compressed_tensors.quantization.utils import can_quantize
+ from torch import Tensor
+
+
+ __all__ = [
+     "NaiveQuantizationCompressor",
+     "IntQuantizationCompressor",
+     "FloatQuantizationCompressor",
+ ]
+
+
+ @BaseCompressor.register(name=CompressionFormat.naive_quantized.value)
+ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
+     """
+     Implements naive compression for quantized models. Weight of each
+     quantized layer is converted from its original float type to the closest Pytorch
+     type to the type specified by the layer's QuantizationArgs.
+     """
+
+     COMPRESSION_PARAM_NAMES = [
+         "weight",
+         "weight_scale",
+         "weight_zero_point",
+         "weight_g_idx",
+     ]
+
+     def compression_param_info(
+         self,
+         weight_shape: torch.Size,
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+         """
+         Creates a dictionary of expected shapes and dtypes for each compression
+         parameter used by the compressor
+
+         :param weight_shape: uncompressed weight shape
+         :param quantization_args: quantization parameters for the weight
+         :return: dictionary mapping compressed parameter names to shape and dtype
+         """
+         dtype = quantization_args.pytorch_dtype()
+         return {"weight": (weight_shape, dtype)}
+
+     def compress_weight(
+         self,
+         weight: Tensor,
+         scale: Tensor,
+         zero_point: Optional[Tensor] = None,
+         g_idx: Optional[torch.Tensor] = None,
+         quantization_args: Optional[QuantizationArgs] = None,
+         device: Optional[torch.device] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Compresses a single uncompressed weight
+
+         :param weight: uncompressed weight tensor
+         :param scale: quantization scale for weight
+         :param zero_point: quantization zero point for weight
+         :param g_idx: optional mapping from column index to group index
+         :param quantization_args: quantization parameters for weight
+         :param device: optional device to move compressed output to
+         :return: dictionary of compressed weight data
+         """
+         if can_quantize(weight, quantization_args):
+             quantized_weight = quantize(
+                 x=weight,
+                 scale=scale,
+                 zero_point=zero_point,
+                 g_idx=g_idx,
+                 args=quantization_args,
+                 dtype=quantization_args.pytorch_dtype(),
+             )
+
+         if device is not None:
+             quantized_weight = quantized_weight.to(device)
+
+         return {"weight": quantized_weight}
+
+     def decompress_weight(
+         self,
+         compressed_data: Dict[str, Tensor],
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> torch.Tensor:
+         """
+         Decompresses a single compressed weight
+
+         :param compressed_data: dictionary of data needed for decompression
+         :param quantization_args: quantization parameters for the weight
+         :return: tensor of the decompressed weight
+         """
+         weight = compressed_data["weight"]
+         scale = compressed_data["weight_scale"]
+         zero_point = compressed_data.get("weight_zero_point", None)
+         g_idx = compressed_data.get("weight_g_idx", None)
+         decompressed_weight = dequantize(
+             x_q=weight, scale=scale, zero_point=zero_point, g_idx=g_idx
+         )
+
+         return decompressed_weight
+
+
+ @BaseCompressor.register(name=CompressionFormat.int_quantized.value)
+ class IntQuantizationCompressor(NaiveQuantizationCompressor):
+     """
+     Alias for integer quantized models
+     """
+
+     pass
+
+
+ @BaseCompressor.register(name=CompressionFormat.float_quantized.value)
+ class FloatQuantizationCompressor(NaiveQuantizationCompressor):
+     """
+     Alias for fp quantized models
+     """
+
+     pass
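
Because NaiveQuantizationCompressor keeps the quantized values in a plain PyTorch dtype, the per-weight contract can be exercised directly. The sketch below round-trips one tensor through compress_weight() and decompress_weight(); the shapes, the per-channel scale, and the QuantizationArgs choices are assumptions made for the example.

# Hedged sketch: one-weight round trip through the naive compressor (values assumed).
import torch
from compressed_tensors.compressors import NaiveQuantizationCompressor
from compressed_tensors.quantization import QuantizationArgs

args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="channel")
compressor = NaiveQuantizationCompressor()

weight = torch.randn(128, 256)
scale = weight.abs().amax(dim=1, keepdim=True) / 127  # per-output-channel scale
zero_point = torch.zeros_like(scale, dtype=torch.int8)

compressed = compressor.compress_weight(
    weight=weight, scale=scale, zero_point=zero_point, quantization_args=args
)
# decompress_weight() expects the scale/zero_point alongside the stored weight
compressed["weight_scale"] = scale
compressed["weight_zero_point"] = zero_point
restored = compressor.decompress_weight(compressed, quantization_args=args)
print((weight - restored).abs().max())  # error is bounded by roughly scale / 2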
@@ -0,0 +1,211 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import math
+ from typing import Dict, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.compressors.quantized_compressors.base import (
+     BaseQuantizationCompressor,
+ )
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+ from compressed_tensors.quantization.utils import can_quantize
+ from torch import Tensor
+
+
+ __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
+
+
+ @BaseCompressor.register(name=CompressionFormat.pack_quantized.value)
+ class PackedQuantizationCompressor(BaseQuantizationCompressor):
+     """
+     Compresses a quantized model by packing every eight 4-bit weights into an int32
+     """
+
+     COMPRESSION_PARAM_NAMES = [
+         "weight_packed",
+         "weight_scale",
+         "weight_zero_point",
+         "weight_g_idx",
+         "weight_shape",
+     ]
+
+     def compression_param_info(
+         self,
+         weight_shape: torch.Size,
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+         """
+         Creates a dictionary of expected shapes and dtypes for each compression
+         parameter used by the compressor
+
+         :param weight_shape: uncompressed weight shape
+         :param quantization_args: quantization parameters for the weight
+         :return: dictionary mapping compressed parameter names to shape and dtype
+         """
+         pack_factor = 32 // quantization_args.num_bits
+         packed_size = math.ceil(weight_shape[1] / pack_factor)
+         return {
+             "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
+             "weight_shape": (torch.Size((2,)), torch.int32),
+         }
+
+     def compress_weight(
+         self,
+         weight: Tensor,
+         scale: Tensor,
+         zero_point: Optional[Tensor] = None,
+         g_idx: Optional[torch.Tensor] = None,
+         quantization_args: Optional[QuantizationArgs] = None,
+         device: Optional[torch.device] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Compresses a single uncompressed weight
+
+         :param weight: uncompressed weight tensor
+         :param scale: quantization scale for weight
+         :param zero_point: quantization zero point for weight
+         :param g_idx: optional mapping from column index to group index
+         :param quantization_args: quantization parameters for weight
+         :param device: optional device to move compressed output to
+         :return: dictionary of compressed weight data
+         """
+         compressed_dict = {}
+         if can_quantize(weight, quantization_args):
+             quantized_weight = quantize(
+                 x=weight,
+                 scale=scale,
+                 zero_point=zero_point,
+                 g_idx=g_idx,
+                 args=quantization_args,
+                 dtype=torch.int8,
+             )
+
+         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+         weight_shape = torch.tensor(weight.shape)
+         if device is not None:
+             packed_weight = packed_weight.to(device)
+             weight_shape = weight_shape.to(device)
+
+         compressed_dict["weight_shape"] = weight_shape
+         compressed_dict["weight_packed"] = packed_weight
+
+         return compressed_dict
+
+     def decompress_weight(
+         self,
+         compressed_data: Dict[str, Tensor],
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> torch.Tensor:
+         """
+         Decompresses a single compressed weight
+
+         :param compressed_data: dictionary of data needed for decompression
+         :param quantization_args: quantization parameters for the weight
+         :return: tensor of the decompressed weight
+         """
+         weight = compressed_data["weight_packed"]
+         scale = compressed_data["weight_scale"]
+         zero_point = compressed_data.get("weight_zero_point", None)
+         g_idx = compressed_data.get("weight_g_idx", None)
+         original_shape = torch.Size(compressed_data["weight_shape"])
+         num_bits = quantization_args.num_bits
+         unpacked = unpack_from_int32(weight, num_bits, original_shape)
+         decompressed_weight = dequantize(
+             x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
+         )
+
+         return decompressed_weight
+
+
+ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
+     """
+     Packs a tensor of quantized weights stored in int8 into int32s with padding
+
+     :param value: tensor to pack
+     :param num_bits: number of bits used to store underlying data
+     :returns: packed int32 tensor
+     """
+     if value.dtype is not torch.int8:
+         raise ValueError("Tensor must be quantized to torch.int8 before packing")
+
+     if num_bits > 8:
+         raise ValueError("Packing is only supported for less than 8 bits")
+
+     # convert to unsigned for packing
+     offset = pow(2, num_bits) // 2
+     value = (value + offset).to(torch.uint8)
+     value = value.cpu().numpy().astype(np.uint32)
+     pack_factor = 32 // num_bits
+
+     # pad input tensor and initialize packed output
+     packed_size = math.ceil(value.shape[1] / pack_factor)
+     packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
+     padding = packed.shape[1] * pack_factor - value.shape[1]
+     value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
+
+     # pack values
+     for i in range(pack_factor):
+         packed |= value[:, i::pack_factor] << num_bits * i
+
+     # convert back to signed and torch
+     packed = np.ascontiguousarray(packed).view(np.int32)
+     return torch.from_numpy(packed)
+
+
+ def unpack_from_int32(
+     value: torch.Tensor, num_bits: int, shape: torch.Size
+ ) -> torch.Tensor:
+     """
+     Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
+     original bit range
+
+     :param value: tensor to unpack
+     :param num_bits: number of bits to unpack each data point into
+     :param shape: shape to unpack into, used to remove padding
+     :returns: unpacked int8 tensor
+     """
+     if value.dtype is not torch.int32:
+         raise ValueError(
+             f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+         )
+
+     if num_bits > 8:
+         raise ValueError("Unpacking is only supported for less than 8 bits")
+
+     pack_factor = 32 // num_bits
+
+     # unpack
+     mask = pow(2, num_bits) - 1
+     unpacked = torch.zeros(
+         (value.shape[0], value.shape[1] * pack_factor),
+         device=value.device,
+         dtype=torch.int32,
+     )
+     for i in range(pack_factor):
+         unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
+
+     # remove padding
+     original_row_size = int(shape[1])
+     unpacked = unpacked[:, :original_row_size]
+
+     # bits are packed in unsigned format, reformat to signed
+     # update the value range from unsigned to signed
+     offset = pow(2, num_bits) // 2
+     unpacked = (unpacked - offset).to(torch.int8)
+
+     return unpacked
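
pack_to_int32 and unpack_from_int32 are self-contained helpers, so the packing scheme introduced in this file can be sanity-checked in isolation; the shapes below are arbitrary toy values.

# Round-trip check of the int32 packing helpers introduced above (toy shapes).
import torch
from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
    pack_to_int32,
    unpack_from_int32,
)

# int8 values already restricted to the signed 4-bit range [-8, 7]
quantized = torch.randint(-8, 8, (4, 10), dtype=torch.int8)

packed = pack_to_int32(quantized, num_bits=4)
# 32 // 4 = 8 values per int32 word, so 10 columns pad out to ceil(10 / 8) = 2
assert packed.shape == (4, 2) and packed.dtype == torch.int32

unpacked = unpack_from_int32(packed, num_bits=4, shape=quantized.shape)
assert torch.equal(unpacked, quantized)  # padding stripped, sign offset restored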
@@ -0,0 +1,18 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # flake8: noqa
+
+ from .base import *
+ from .dense import *
+ from .sparse_bitmask import *
@@ -0,0 +1,110 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, Tuple
+
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ __all__ = ["BaseSparseCompressor"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ class BaseSparseCompressor(BaseCompressor):
+     """
+     Base class representing a sparse compression algorithm. Each child class should
+     implement compression_param_info, compress_weight and decompress_weight.
+
+     Compressors support compressing/decompressing a full module state dict or a single
+     quantized PyTorch leaf module.
+
+     Model Load Lifecycle (run_compressed=False):
+         - ModelCompressor.decompress()
+             - apply_quantization_config()
+             - BaseSparseCompressor.decompress()
+                 - BaseSparseCompressor.decompress_weight()
+
+     Model Save Lifecycle:
+         - ModelCompressor.compress()
+             - BaseSparseCompressor.compress()
+                 - BaseSparseCompressor.compress_weight()
+
+     Module Lifecycle (run_compressed=True):
+         - apply_quantization_config()
+         - compressed_module = CompressedLinear(module)
+             - initialize_module_for_quantization()
+             - BaseSparseCompressor.compression_param_info()
+             - register_parameters()
+         - compressed_module.forward()
+             - compressed_module.decompress()
+
+
+     :param config: config specifying compression parameters
+     """
+
+     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict using bitmask compression
+
+         :param model_state: state dict of uncompressed model
+         :return: compressed state dict
+         """
+         compressed_dict = {}
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             compression_data = self.compress_weight(name, value)
+             for key in compression_data.keys():
+                 if key in compressed_dict:
+                     _LOGGER.warn(
+                         f"Expected all compressed state_dict keys to be unique, but "
+                         f"found an existing entry for {key}. The existing entry will "
+                         "be replaced."
+                     )
+
+             compressed_dict.update(compression_data)
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a bitmask compressed state dict located
+         at path_to_model_or_tensors and returns a generator
+         for sequentially decompressing back to a dense state dict
+
+         :param model_path: path to compressed safetensors model (directory with
+             one or more safetensors files) or compressed tensors file
+         :param device: device to load decompressed weights onto
+         :return: iterator for generating decompressed weights
+         """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+             decompressed = self.decompress_weight(weight_data)
+             yield weight_name, decompressed
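
BaseSparseCompressor leaves compress_weight() and decompress_weight() to its children (the shipped bitmask implementation lives in sparse_bitmask.py). A hedged skeleton of what a child must provide is sketched below; the "weight_values"/"weight_mask" parameter names and the dense boolean mask are invented for illustration and do not mirror the real bitmask format.

# Hypothetical BaseSparseCompressor child; the mask-based storage scheme is
# illustrative only, not the format used by the shipped bitmask compressor.
from typing import Dict

import torch
from torch import Tensor

from compressed_tensors.compressors.sparse_compressors import BaseSparseCompressor
from compressed_tensors.utils import merge_names


class ToyMaskCompressor(BaseSparseCompressor):
    # parameter names the inherited decompress() will look up per weight
    COMPRESSION_PARAM_NAMES = ["weight_values", "weight_mask"]

    def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
        # base compress() passes the full parameter name and the dense tensor
        prefix = name[: -len(".weight")] if name.endswith(".weight") else name
        mask = value != 0
        return {
            merge_names(prefix, "weight_values"): value[mask].clone(),
            merge_names(prefix, "weight_mask"): mask,
        }

    def decompress_weight(self, weight_data: Dict[str, Tensor]) -> Tensor:
        # base decompress() collects the per-weight tensors before calling this
        mask = weight_data["weight_mask"]
        dense = torch.zeros(mask.shape, dtype=weight_data["weight_values"].dtype)
        dense[mask] = weight_data["weight_values"]
        return dense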
@@ -14,13 +14,13 @@

  from typing import Dict, Generator, Tuple

- from compressed_tensors.compressors import Compressor
+ from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.config import CompressionFormat
  from torch import Tensor


- @Compressor.register(name=CompressionFormat.dense.value)
- class DenseCompressor(Compressor):
+ @BaseCompressor.register(name=CompressionFormat.dense.value)
+ class DenseCompressor(BaseCompressor):
      """
      Identity compressor for dense models, returns the original state_dict
      """