compressed-tensors-nightly 0.3.3.20240610__py3-none-any.whl → 0.3.3.20240612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,7 @@ from .base import Compressor
  from .dense import DenseCompressor
  from .helpers import load_compressed, save_compressed, save_compressed_model
  from .int_quantized import IntQuantizationCompressor
- from .model_compressor import ModelCompressor
+ from .marlin_24 import Marlin24Compressor
+ from .model_compressor import ModelCompressor, map_modules_to_quant_args
  from .pack_quantized import PackedQuantizationCompressor
  from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
@@ -0,0 +1,250 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, Tuple
+
+ import numpy as np
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.compressors.utils import (
+     get_permutations_24,
+     sparse_semi_structured_from_dense_cutlass,
+     tensor_follows_mask_structure,
+ )
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+ from compressed_tensors.quantization.lifecycle.forward import quantize
+ from compressed_tensors.utils import is_quantization_param, merge_names
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.marlin_24.value)
+ class Marlin24Compressor(Compressor):
+     """
+     Compresses a quantized model with 2:4 sparsity structure for inference with the
+     Marlin24 kernel. Decompression is not implemented for this compressor.
+     """
+
+     COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"]
+
+     @staticmethod
+     def validate_quant_compatability(
+         model_quant_args: Dict[str, QuantizationArgs]
+     ) -> bool:
+         """
+         Checks if every quantized module in the model is compatible with Marlin24
+         compression. Quantization must use a channel or group strategy with a
+         group_size of 128. Only symmetric quantization is supported.
+
+         :param model_quant_args: dictionary mapping module names to their
+             quantization configuration
+         :return: True if all modules are compatible with Marlin24 compression,
+             raises a ValueError otherwise
+         """
+         for name, quant_args in model_quant_args.items():
+             strategy = quant_args.strategy
+             group_size = quant_args.group_size
+             symmetric = quant_args.symmetric
+             if (
+                 strategy is not QuantizationStrategy.GROUP.value
+                 and strategy is not QuantizationStrategy.CHANNEL.value
+             ):
+                 raise ValueError(
+                     f"Marlin24 Compressor is only valid for group and channel "
+                     f"quantization strategies, got {strategy} in {name}"
+                 )
+
+             if group_size is not None and group_size != 128:
+                 raise ValueError(
+                     f"Marlin24 Compressor is only valid for group size 128, "
+                     f"got {group_size} in {name}"
+                 )
+
+             if not symmetric:
+                 raise ValueError(
+                     f"Marlin24 Compressor is only valid for symmetric quantization, "
+                     f"got symmetric={symmetric} in {name}"
+                 )
+
+         return True
+
+     @staticmethod
+     def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
+         """
+         Checks if a tensor fits the required 2:4 sparsity structure
+
+         :param name: name of the tensor to check
+         :param weight: tensor to check for sparsity structure
+         :return: True if all rows match the 2:4 sparsity structure, raises
+             ValueError otherwise
+         """
+
+         if not tensor_follows_mask_structure(weight):
+             raise ValueError(
+                 "Marlin24 Compressor is only compatible with weights that have "
+                 f"a 2:4 sparsity structure. Found segments in {name} "
+                 "that do not match the expected structure."
+             )
+
+         return True
+
+     def compress(
+         self,
+         model_state: Dict[str, Tensor],
+         model_quant_args: Dict[str, QuantizationArgs],
+         **kwargs,
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a quantized state_dict with 2:4 sparsity structure for inference
+         with the Marlin24 kernel
+
+         :param model_state: state dict of uncompressed model
+         :param model_quant_args: quantization args for each quantized weight, needed
+             for quantize function to calculate bit depth
+         :return: compressed state dict
+         """
+         self.validate_quant_compatability(model_quant_args)
+
+         compressed_dict = {}
+         weight_suffix = ".weight"
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             if name.endswith(weight_suffix):
+                 prefix = name[: -(len(weight_suffix))]
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 if scale is not None:  # weight is quantized, compress it
+
+                     # Marlin24 kernel requires float16 inputs
+                     scale = scale.to(torch.float16)
+                     value = value.to(torch.float16)
+
+                     # quantize weight, keeping it as a float16 for now
+                     quant_args = model_quant_args[prefix]
+                     value = quantize(
+                         x=value, scale=scale, zero_point=zp, args=quant_args
+                     )
+
+                     # compress based on sparsity structure
+                     self.validate_sparsity_structure(prefix, value)
+                     value, meta = compress_weight_24(value)
+                     meta = meta.cpu()
+
+                     # Marlin24 kernel expects input dim first
+                     value = value.t().contiguous().cpu()
+                     scale = scale.t().contiguous().cpu()
+                     og_weight_shape = value.shape
+
+                     # Marlin24 kernel expects unsigned values, shift zero-point
+                     value += (1 << quant_args.num_bits) // 2
+
+                     # pack quantized weight and scale
+                     value = pack_weight_24(value, quant_args)
+                     packed_scale = pack_scales_24(scale, quant_args, og_weight_shape)
+                     meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
+
+                     # save compressed values
+                     compressed_dict[merge_names(prefix, "scale_packed")] = packed_scale
+                     compressed_dict[merge_names(prefix, "weight_packed")] = value
+                     compressed_dict[merge_names(prefix, "meta")] = meta
+                     continue
+
+             if not is_quantization_param(name):
+                 # export unquantized parameters without modifying
+                 compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu"
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         raise NotImplementedError(
+             "Decompression is not implemented for the Marlin24 Compressor."
+         )
+
+
+ def compress_weight_24(weight: Tensor):
+     weight = weight.contiguous()
+     w_comp, meta = sparse_semi_structured_from_dense_cutlass(weight)
+     w_comp = w_comp.contiguous()
+     return w_comp, meta
+
+
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile):
+     assert q_w.shape == (size_k, size_n)
+     assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+     assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
+
+     # Permute weights to 16x64 marlin tiles
+     q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
+     q_w = q_w.permute((0, 2, 1, 3))
+     q_w = q_w.reshape((size_k // tile, size_n * tile))
+
+     q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+
+     return q_w
+
+
+ def pack_weight_24(
+     weight: Tensor,
+     quantization_args: QuantizationArgs,
+     tile: int = 16,
+ ):
+     size_k = weight.shape[0]
+     size_n = weight.shape[1]
+     num_bits = quantization_args.num_bits
+     pack_factor = 32 // num_bits
+
+     # Reshuffle to marlin_24 format
+     perm, _, _ = get_permutations_24(num_bits)
+     q_w = marlin_permute_weights(weight, size_k, size_n, perm, tile)
+
+     q_w = q_w.cpu().numpy().astype(np.uint32)
+
+     q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
+     for i in range(pack_factor):
+         q_packed |= q_w[:, i::pack_factor] << num_bits * i
+
+     q_packed = torch.from_numpy(q_packed.astype(np.int32))
+
+     return q_packed
+
+
+ def pack_scales_24(scales, quantization_args, w_shape):
+     size_k = w_shape[0]
+     size_n = w_shape[1]
+     num_bits = quantization_args.num_bits
+
+     _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)
+
+     if (
+         quantization_args.strategy is QuantizationStrategy.GROUP
+         and quantization_args.group_size < size_k
+     ):
+         scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
+     else:  # channelwise
+         scales = scales.reshape((-1, len(scale_perm_single_2_4)))[
+             :, scale_perm_single_2_4
+         ]
+     scales = scales.reshape((-1, size_n)).contiguous()
+
+     return scales
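
The compress path above leans on two pieces of arithmetic: `value += (1 << quant_args.num_bits) // 2` shifts symmetric signed values into the unsigned range the kernel expects, and `pack_factor = 32 // num_bits` values are packed into each uint32 by shift-and-or. A minimal standalone sketch of that arithmetic with toy data, not using the package's helpers:

    import numpy as np

    num_bits = 4
    pack_factor = 32 // num_bits  # 8 four-bit values per uint32

    # a toy row of already-quantized symmetric int4 values in [-8, 7]
    q = np.array([[-8, -1, 0, 1, 2, 3, -4, 7]], dtype=np.int32)

    # shift the zero-point so values land in [0, 15], mirroring
    # `value += (1 << quant_args.num_bits) // 2` above
    q_u = (q + (1 << num_bits) // 2).astype(np.uint32)

    # pack, mirroring the loop in pack_weight_24 (minus the Marlin tiling permutation)
    packed = np.zeros((q_u.shape[0], q_u.shape[1] // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        packed |= q_u[:, i::pack_factor] << num_bits * i

    # nibble i of a packed word is recovered with a shift and a mask
    assert (int(packed[0, 0]) >> (num_bits * 2)) & 0xF == q_u[0, 2]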
@@ -45,7 +45,7 @@ from transformers import AutoConfig
  from transformers.file_utils import CONFIG_NAME


- __all__ = ["ModelCompressor"]
+ __all__ = ["ModelCompressor", "map_modules_to_quant_args"]

  _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -190,7 +190,7 @@ class ModelCompressor:
          state_dict = model.state_dict()

          compressed_state_dict = state_dict
-         quantized_modules_to_args = _get_weight_arg_mappings(model)
+         quantized_modules_to_args = map_modules_to_quant_args(model)
          if self.quantization_compressor is not None:
              compressed_state_dict = self.quantization_compressor.compress(
                  state_dict, model_quant_args=quantized_modules_to_args
@@ -269,7 +269,7 @@ class ModelCompressor:
          data_old.data = data_new.data


- def _get_weight_arg_mappings(model: Module) -> Dict:
+ def map_modules_to_quant_args(model: Module) -> Dict:
      quantized_modules_to_args = {}
      for name, submodule in iter_named_leaf_modules(model):
          if is_module_quantized(submodule):
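
With the rename, the mapping helper becomes part of the module's public surface (note the `__all__` change above). A usage sketch, where `model` is assumed to be a torch.nn.Module that has already been through compressed-tensors' quantization lifecycle:

    from compressed_tensors.compressors import map_modules_to_quant_args

    # `model` is a hypothetical quantized module, not defined here
    quantized_modules_to_args = map_modules_to_quant_args(model)
    for name, quant_args in quantized_modules_to_args.items():
        print(name, quant_args.num_bits, quant_args.strategy)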
@@ -0,0 +1,19 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+
+ from .helpers import *
+ from .permutations_24 import *
+ from .semi_structured_conversions import *
@@ -0,0 +1,43 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ __all__ = ["tensor_follows_mask_structure"]
+
+
+ def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+     """
+     :param tensor: tensor to check
+     :param mask: mask structure to check for, in the format "n:m"
+     :return: True if the tensor follows the mask structure, raises a
+         ValueError otherwise. Note that some weights can incidentally be zero,
+         so we check for at least n zeros in each chunk of size m
+     """
+
+     n, m = tuple(map(int, mask.split(":")))
+     # Reshape the tensor into chunks of size m
+     tensor = tensor.view(-1, m)
+
+     # Count the number of zeros in each chunk
+     zero_counts = (tensor == 0).sum(dim=1)
+
+     # Check that the number of zeros in each chunk is at least n. A
+     # greater-than-or-equal comparison is needed because some weights can
+     # incidentally be zero
+     if not torch.all(zero_counts >= n).item():
+         raise ValueError()
+
+     return True
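
A quick usage sketch of the new helper: every chunk of m = 4 values needs at least n = 2 zeros, and a failing tensor raises a ValueError rather than returning False:

    import torch
    from compressed_tensors.compressors.utils import tensor_follows_mask_structure

    ok = torch.tensor([[0.0, 1.5, 0.0, 2.0],
                       [3.0, 0.0, 0.0, 4.0]])
    assert tensor_follows_mask_structure(ok)  # two zeros in each chunk of four

    bad = torch.tensor([[1.0, 2.0, 3.0, 0.0]])  # only one zero in the chunk
    try:
        tensor_follows_mask_structure(bad)
    except ValueError:
        print("tensor does not follow the 2:4 mask structure")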
@@ -0,0 +1,65 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import numpy
+ import torch
+
+
+ __all__ = ["get_permutations_24"]
+
+
+ # Precompute permutations for Marlin24 weight and scale shuffling
+ # Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py  # noqa: E501
+ #
+ # Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight
+ # data so that it is compatible with the tensor-core format that is described here:
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type  # noqa: E501
+ #
+ # As a result of this reordering, the vector loads inside the kernel will get the data
+ # as it is needed for tensor-core (without the need to use ldmatrix instructions)
+ def get_permutations_24(num_bits):
+     perm_list = []
+     for i in range(32):
+         perm1 = []
+         col = i // 4
+         col_o = col // 2
+         for block in [0, 1]:
+             for row in [
+                 2 * (i % 4),
+                 2 * (i % 4) + 1,
+                 2 * (i % 4 + 4),
+                 2 * (i % 4 + 4) + 1,
+             ]:
+                 perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
+         for j in range(4):
+             perm_list.extend([p + 1 * j for p in perm1])
+     perm = numpy.array(perm_list)
+
+     if num_bits == 4:
+         interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+     elif num_bits == 8:
+         interleave = numpy.array([0, 2, 1, 3])
+     else:
+         raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
+
+     perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+     perm = torch.from_numpy(perm)
+     scale_perm = []
+     for i in range(8):
+         scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
+     scale_perm_single = []
+     for i in range(8):
+         scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
+     return perm, scale_perm, scale_perm_single
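
A sanity-check sketch of the invariants the packing code relies on for the 4-bit case (the numbers follow from the loops above; they are not documented constants): the weight permutation is a bijective reordering of 32 * 32 = 1024 tile positions, and both scale permutations have 64 entries:

    from compressed_tensors.compressors.utils import get_permutations_24

    perm, scale_perm, scale_perm_single = get_permutations_24(num_bits=4)

    assert perm.numel() == 1024  # 32 outer iterations x 8 rows x 4 offsets
    assert sorted(perm.tolist()) == list(range(1024))  # a true permutation
    assert len(scale_perm) == 64 and len(scale_perm_single) == 64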
@@ -0,0 +1,341 @@
+ #
+ # Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
+ # Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py
+ #
+ # flake8: noqa
+ # isort: skip_file
+
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ __all__ = [
+     "sparse_semi_structured_from_dense_cutlass",
+     "sparse_semi_structured_to_dense_cutlass",
+     "mask_creator",
+ ]
+
+ # This is PyTorch implementation of main part of reorder_meta()
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
+ # GEMM decides upon layout of this matrix, and at the moment for the
+ # sparse GEMM executed on tensor cores, this is layout described by
+ # ColumnMajorInterleaved<2> data structure, in
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
+ # reordering of meta matrix into meta_reordered matrix calculated
+ # according to these segments of CUTLASS code is re-implemented here.
+ # Note that this calculation produces offsets for scattering metadata
+ # matrix elements into reordered metadata matrix elements (or,
+ # equivalently, for gathering reordered metadata matrix element back
+ # into metadata matrix elements).
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
+     dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
+     dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
+
+     # Reorder the rows, then swizzle the 2x2 blocks.
+     group_x = 64
+     group_y = 32 if meta_dtype.itemsize == 2 else 16
+
+     dst_rows = (
+         dst_rows // group_x * group_x
+         + (dst_rows % 2) * 2
+         + (dst_rows % 8) // 4
+         + ((dst_rows % group_y) % 4) // 2 * 32
+         + ((dst_rows % group_x) // 8) * 4
+     )
+
+     topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
+     bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
+     dst_rows += topright - bottomleft
+     dst_cols -= topright - bottomleft
+
+     # Assumed that meta tensor is to be stored in CUTLASS
+     # InterleavedColumnMajor layout, and reverse engineered
+     # corresponding code to store values into this tensor.
+     interleave = 2
+     cols_maj = dst_cols // interleave
+     cols_min = dst_cols % interleave
+     return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
+
+
+ # This function converts dense matrix into sparse semi-structured
+ # representation, producing "compressed" matrix, in the layout used by
+ # CUTLASS backend, and corresponding metadata matrix.
+ def sparse_semi_structured_from_dense_cutlass(dense):
+     if dense.dim() != 2:
+         raise RuntimeError(
+             f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
+         )
+
+     m, k = dense.shape
+     device = dense.device
+
+     meta_dtype = torch.int8
+     if dense.dtype == torch.int8:
+         meta_dtype = torch.int32
+     elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
+         meta_dtype = torch.int16
+     else:
+         raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
+     quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+     if quadbits_per_meta_elem not in (4, 8):
+         raise RuntimeError("Invalid number of elements per meta element calculated")
+
+     if meta_dtype == torch.int32:
+         if m % 16 != 0:
+             raise RuntimeError(
+                 f"Number of rows of dense matrix {m} must be divisible by 16"
+             )
+     else:
+         if m % 32 != 0:
+             raise RuntimeError(
+                 f"Number of rows of dense matrix {m} must be divisible by 32"
+             )
+     if k % (4 * quadbits_per_meta_elem) != 0:
+         raise RuntimeError(
+             f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
+         )
+
+     if dense.dtype != torch.float:
+         ksparse = 4
+         dense_4 = dense.view(-1, k // ksparse, ksparse)
+         m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
+     else:
+         ksparse = 2
+         dense_2 = dense.view(-1, k // ksparse, ksparse)
+         m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
+     meta_ncols = k // (ksparse * quadbits_per_meta_elem)
+
+     # Encoding quadruples of True/False values as follows:
+     #     [True,  True,  False, False] -> 0b0100
+     #     [True,  False, True,  False] -> 0b1000
+     #     [False, True,  True,  False] -> 0b1001
+     #     [True,  False, False, True ] -> 0b1100
+     #     [False, True,  False, True ] -> 0b1101
+     #     [False, False, True,  True ] -> 0b1110
+     # Thus, lower two bits in the encoding are index of the True value
+     # at the lowest index in the quadruple, and the higher two bits in
+     # the encoding are index of the other True value in the quadruple.
+     # If there are fewer than two True values, then the False value or
+     # values at some index or indices are considered True for the
+     # encoding. If there are more than two True values, the excess
+     # True value(s) at some indices are considered False for the
+     # encoding. The exact encodings used for these cases are as
+     # follows:
+     #     [False, False, False, False] -> 0b1110
+     #     [False, False, False, True ] -> 0b1110
+     #     [False, False, True,  False] -> 0b1110
+     #     [False, True,  False, False] -> 0b1001
+     #     [False, True,  True,  True ] -> 0b1101
+     #     [True,  False, False, False] -> 0b1000
+     #     [True,  False, True,  True ] -> 0b1100
+     #     [True,  True,  False, True ] -> 0b0100
+     #     [True,  True,  True,  False] -> 0b0100
+     #     [True,  True,  True,  True ] -> 0b0100
+     # These particular encodings are chosen, with the help of Espresso
+     # logic minimizer software, for the purpose of minimization of
+     # corresponding Boolean functions, that translate non-zero flags
+     # into encoding bits. Note also possible choices for the first
+     # and last of these encodings were limited only to (0b0100,
+     # 0b1110), in order to produce valid encodings for 1:2 sparsity
+     # case.
+
+     expr0 = m0 & m1
+     expr1 = ~m0 & m1
+     expr2 = ~m0 & ~m1
+     bit0 = expr1
+     bit1 = expr2
+     bit2 = expr0 | expr2 | m3
+     bit3 = expr1 | ~m1
+     idxs0 = bit0 | (bit1.to(torch.int64) << 1)
+     idxs1 = bit2 | (bit3.to(torch.int64) << 1)
+
+     if dense.dtype != torch.float:
+         sparse0 = dense_4.gather(
+             -1, idxs0.unsqueeze(-1)
+         )  # type: ignore[possibly-undefined]
+         sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
+         sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
+     else:
+         sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
+             m, k // 2
+         )  # type: ignore[possibly-undefined]
+
+     meta_4 = idxs0 | (idxs1 << 2)
+     meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
+
+     if quadbits_per_meta_elem == 4:
+         meta = (
+             meta_n[:, :, 0]
+             | (meta_n[:, :, 1] << 4)
+             | (meta_n[:, :, 2] << 8)
+             | (meta_n[:, :, 3] << 12)
+         )
+     elif quadbits_per_meta_elem == 8:
+         meta = (
+             meta_n[:, :, 0]
+             | (meta_n[:, :, 1] << 4)
+             | (meta_n[:, :, 2] << 8)
+             | (meta_n[:, :, 3] << 12)
+             | (meta_n[:, :, 4] << 16)
+             | (meta_n[:, :, 5] << 20)
+             | (meta_n[:, :, 6] << 24)
+             | (meta_n[:, :, 7] << 28)
+         )
+
+     # Reorder meta tensor elements.
+     meta_reordered = meta.new_empty(
+         (m * meta_ncols,)
+     )  # type: ignore[possibly-undefined]
+     meta_offsets = _calculate_meta_reordering_scatter_offsets(
+         m, meta_ncols, meta_dtype, device
+     )
+     meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
+
+     return (sparse, meta_reordered.view(m, meta_ncols))
+
+
+ # This function performs reverse of the function above - it
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
+ # in the layout used by CUTLASS backend, and accompanying metadata
+ # matrix.
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
+     if sparse.dim() != 2:
+         raise RuntimeError(
+             f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
+         )
+
+     m, k = sparse.shape
+     device = sparse.device
+
+     if meta_reordered.dim() != 2:
+         raise RuntimeError(
+             f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
+         )
+     if meta_reordered.device != device:
+         raise RuntimeError(
+             f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
+         )
+
+     meta_dtype = meta_reordered.dtype
+     if meta_dtype not in (torch.int16, torch.int32):
+         raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
+     quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+
+     ksparse = 4 if sparse.dtype != torch.float else 2
+
+     meta_nrows, meta_ncols = meta_reordered.shape
+     if meta_nrows != m:
+         raise RuntimeError(
+ f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
244
+         )
+     if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
+         raise RuntimeError(
+             f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
+             "expected according to the number of columns of meta matrix"
+         )
+
+     # Undo meta tensor elements reordering.
+     meta_offsets = _calculate_meta_reordering_scatter_offsets(
+         m, meta_ncols, meta_dtype, device
+     )
+     meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
+
+     # Unpack sparse tensor back to original dense tensor, using
+     # information provided by meta tensor. Note that torch.float
+     # datatype is handled pretty much the same as
+     # torch.half/torch.bfloat16, as metadata for a pair of torch.float
+     # value is encoded as if underlying 8 bytes contain four
+     # torch.half/torch.bfloat16 values, where either first two or last
+     # two are zeros.
+     meta_2 = torch.empty(
+         (m, meta_ncols, 2 * quadbits_per_meta_elem),
+         dtype=meta_dtype,
+         device=device,
+     )
+     if quadbits_per_meta_elem == 4:
+         meta_2[:, :, 0] = meta & 0b11
+         meta_2[:, :, 1] = (meta >> 2) & 0b11
+         meta_2[:, :, 2] = (meta >> 4) & 0b11
+         meta_2[:, :, 3] = (meta >> 6) & 0b11
+         meta_2[:, :, 4] = (meta >> 8) & 0b11
+         meta_2[:, :, 5] = (meta >> 10) & 0b11
+         meta_2[:, :, 6] = (meta >> 12) & 0b11
+         meta_2[:, :, 7] = (meta >> 14) & 0b11
+     elif quadbits_per_meta_elem == 8:
+         meta_2[:, :, 0] = meta & 0b11
+         meta_2[:, :, 1] = (meta >> 2) & 0b11
+         meta_2[:, :, 2] = (meta >> 4) & 0b11
+         meta_2[:, :, 3] = (meta >> 6) & 0b11
+         meta_2[:, :, 4] = (meta >> 8) & 0b11
+         meta_2[:, :, 5] = (meta >> 10) & 0b11
+         meta_2[:, :, 6] = (meta >> 12) & 0b11
+         meta_2[:, :, 7] = (meta >> 14) & 0b11
+         meta_2[:, :, 8] = (meta >> 16) & 0b11
+         meta_2[:, :, 9] = (meta >> 18) & 0b11
+         meta_2[:, :, 10] = (meta >> 20) & 0b11
+         meta_2[:, :, 11] = (meta >> 22) & 0b11
+         meta_2[:, :, 12] = (meta >> 24) & 0b11
+         meta_2[:, :, 13] = (meta >> 26) & 0b11
+         meta_2[:, :, 14] = (meta >> 28) & 0b11
+         meta_2[:, :, 15] = (meta >> 30) & 0b11
+
+     dense_offsets = meta_2.view(-1) + (
+         torch.arange(0, 2 * m * k // ksparse, device=device) * 4
+     ).view(-1, 1).repeat(1, 2).view(-1)
+
+     dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
+     if sparse.dtype != torch.float:
+         # dense.scatter_(0, dense_offsets, sparse.view(-1))
+         dense.scatter_(0, dense_offsets, sparse.reshape(-1))
+     else:
+         dense.view(torch.half).scatter_(
+             0, dense_offsets, sparse.view(torch.half).view(-1)
+         )
+
+     return dense.view(m, 2 * k)
+
+
+ def mask_creator(tensor):
+ """
314
+ Class for creating N:M sparsity masks.
315
+ Masks will be created using the N:M ratio, where for every block of
316
+ M weights, N will be pruned based on ranked weight value. Each mask
317
+ will correspond to the given tensor.
318
+
319
+ :param N: The number of weights in a group to keep
320
+ :param M: The size of a weight group
321
+ """
322
+     N = 2
+     M = 4
+
+     if tensor.numel() % M != 0:
+         raise ValueError(
+             f"Tensor of size {tensor.shape} can't be evenly divided into "
+             f"{M} groups"
+         )
+
+     num_groups = tensor.numel() // M
+
+     # N:M sparsity for linear layers
+     tensor_temp = tensor.detach().abs().reshape(num_groups, M)
+     index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
+
+     w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
+     mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
+
+     return mask
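
A round-trip sketch tying these helpers together: prune a dense matrix to 2:4 with mask_creator, compress it into the CUTLASS semi-structured layout, then reconstruct it. The shape satisfies the divisibility checks above (for float16, rows % 32 == 0 and columns % 16 == 0):

    import torch
    from compressed_tensors.compressors.utils import (
        mask_creator,
        sparse_semi_structured_from_dense_cutlass,
        sparse_semi_structured_to_dense_cutlass,
    )

    dense = torch.randn(32, 64, dtype=torch.float16)
    dense = dense * mask_creator(dense).to(dense.dtype)  # keep 2 of every 4 weights

    sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
    assert sparse.shape == (32, 32)  # half the columns survive
    assert meta.dtype == torch.int16

    restored = sparse_semi_structured_to_dense_cutlass(sparse, meta)
    assert torch.equal(restored, dense)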
@@ -27,6 +27,7 @@ class CompressionFormat(Enum):
      sparse_bitmask = "sparse-bitmask"
      int_quantized = "int-quantized"
      pack_quantized = "pack-quantized"
+     marlin_24 = "marlin-24"


  class SparsityCompressionConfig(RegistryMixin, BaseModel):
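
Since Marlin24Compressor is registered under this format string (via @Compressor.register above), the new enum member makes it resolvable through the registry. A sketch, assuming the load_from_registry classmethod provided by compressed-tensors' RegistryMixin:

    from compressed_tensors.compressors import Compressor
    from compressed_tensors.config import CompressionFormat

    compressor = Compressor.load_from_registry(CompressionFormat.marlin_24.value)
    print(type(compressor).__name__)  # Marlin24Compressor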
@@ -42,7 +42,7 @@ class QuantizationStrategy(str, Enum):
      TOKEN = "token"


- class QuantizationArgs(BaseModel):
+ class QuantizationArgs(BaseModel, use_enum_values=True):
      """
      User facing arguments used to define a quantization config for weights or
      activations
@@ -62,7 +62,7 @@ class QuantizationArgs(BaseModel):
      """

      num_bits: int = 8
-     type: QuantizationType = QuantizationType.INT
+     type: QuantizationType = QuantizationType.INT.value
      symmetric: bool = True
      group_size: Optional[int] = None
      strategy: Optional[QuantizationStrategy] = None
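
With use_enum_values=True, pydantic stores the plain string value of an enum field after validation. That is why Marlin24Compressor.validate_quant_compatability compares strategy against QuantizationStrategy.GROUP.value rather than the enum member. A small illustration:

    from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

    args = QuantizationArgs(
        num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=128
    )
    assert args.strategy == QuantizationStrategy.GROUP.value  # "group", a plain str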
@@ -31,6 +31,7 @@ __all__ = [
      "get_weight_mappings",
      "get_nested_weight_mappings",
      "get_quantization_state_dict",
+     "is_quantization_param",
  ]


@@ -214,7 +215,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
      weight_mappings = get_weight_mappings(model_path)
      state_dict = {}
      for weight_name, safe_path in weight_mappings.items():
-         if not _is_quantization_weight(weight_name):
+         if not is_quantization_param(weight_name):
              continue
          with safe_open(safe_path, framework="pt", device="cpu") as f:
              state_dict[weight_name] = f.get_tensor(weight_name)
@@ -222,7 +223,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
      return state_dict


- def _is_quantization_weight(name: str) -> bool:
+ def is_quantization_param(name: str) -> bool:
      """
      Checks if a parameter name is associated with a quantization parameter

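
The renamed predicate matches parameter names by the quantization suffixes used throughout this diff (weight_scale, weight_zero_point). A usage sketch, hedged on that naming convention since the function body is not shown in this hunk:

    from compressed_tensors.utils import is_quantization_param

    assert is_quantization_param("model.layers.0.q_proj.weight_scale")
    assert is_quantization_param("model.layers.0.q_proj.weight_zero_point")
    assert not is_quantization_param("model.layers.0.q_proj.weight")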
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: compressed-tensors-nightly
- Version: 0.3.3.20240610
+ Version: 0.3.3.20240612
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
@@ -1,20 +1,25 @@
  compressed_tensors/__init__.py,sha256=SV1csvHUVCd8kHXz6UDZim1HZ_fAVG3vfk-j_4Bb6hY,789
  compressed_tensors/base.py,sha256=OA2TOLP1gP3LSH7gp508eqr2ZtDQ-pqRHElCp-aB0vs,755
  compressed_tensors/version.py,sha256=V8krJZctm43D4AGQhJY6dB0MvP1-T9TJ8BcGa8kESrI,1512
- compressed_tensors/compressors/__init__.py,sha256=3yyoNICHll3F4HS6Yu-cgNZpDhfuobFNWCs6DrPcUyQ,992
+ compressed_tensors/compressors/__init__.py,sha256=rhqPp3YXFxCJRLZs1KRNSHTIxK2rNU--sYwDI8MW47w,1061
  compressed_tensors/compressors/base.py,sha256=LWEgbpgTxzmoqQ7Xhq2OQszUgWoDtFuGCiV1Y8nlBGw,2134
  compressed_tensors/compressors/dense.py,sha256=G_XHbvuENyupIKlXSITOQgvPkNkcMEOLcLWQr70V9EE,1257
  compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
  compressed_tensors/compressors/int_quantized.py,sha256=Ct2vCK0yoPm6vkIFlzDMGQ7m14xT1GyURsSwH9DP770,5242
- compressed_tensors/compressors/model_compressor.py,sha256=ymn4xzAstcutXxkY3Z3V_1MuJv383-lkZHzp37mA9z0,11119
+ compressed_tensors/compressors/marlin_24.py,sha256=X_BjtFB3Mn0hqiLz56UM3jGX2eNmGLnvEIPfbg7di6U,9444
+ compressed_tensors/compressors/model_compressor.py,sha256=jUktyujYdd9KqkA9IyZK6EMi09iEw4_itwhzSh805Jk,11150
  compressed_tensors/compressors/pack_quantized.py,sha256=VPiLlgJlDgARrn7YmiQoLqUfxErKBfj54epMYWRsF8k,8451
  compressed_tensors/compressors/sparse_bitmask.py,sha256=H9oZSTYI1oRCzAMbd4zThUnZd1h2rfs8DmA3tPcvuNE,8637
+ compressed_tensors/compressors/utils/__init__.py,sha256=-mbGDZh1hd9T6u62Ht_iBIK255UmMg0f5bLkSs1f9Cc,731
+ compressed_tensors/compressors/utils/helpers.py,sha256=4fq7KclSIK__jemCG9pwYlgWLrQjsaAMxhIrhjdw0BQ,1506
+ compressed_tensors/compressors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
+ compressed_tensors/compressors/utils/semi_structured_conversions.py,sha256=g1EZHzdv-ko7ufPX430dp7wE33o6FWJXuSP4zZydCu0,13488
  compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
- compressed_tensors/config/base.py,sha256=grf5tDaLep8i2-W_p7H-fW9DOGXDi4Zz7su7zjs1Qqc,1454
+ compressed_tensors/config/base.py,sha256=ZnpuOevCE0pXdA8OJfIJnxj-ccproH7o1EOwRY8_hUU,1482
  compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
  compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
  compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
- compressed_tensors/quantization/quant_args.py,sha256=A6b2V8lhsM8Ho8RjlPBQdxRUDNWhqq-ie5E3RR2_GNg,4360
+ compressed_tensors/quantization/quant_args.py,sha256=Z9Zu20ooAwEWlliAdUw1f1zwSrheuD6vqm3YXgJ1Lws,4388
  compressed_tensors/quantization/quant_config.py,sha256=Nv9rvWNrlbeJgNZhQf-cPAEWJ9NU75ATWHCacWaiQ_s,8189
  compressed_tensors/quantization/quant_scheme.py,sha256=-hAK1-C67_wJl10eaVLUvbslPBTV04WyzL_J-u9f1ck,3571
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=ggRGWRqhCxCaTTDWRcgTVX3axnS2xV6rc5YvdzK7fSg,798
@@ -35,9 +40,9 @@ compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEh
  compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
  compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
  compressed_tensors/utils/helpers.py,sha256=5ull5yFT31M2zVxKeFvpvvlvX5f1Sk1LGuj_wrfZWCY,2267
- compressed_tensors/utils/safetensors_load.py,sha256=wo9UirGrGlenBqZeqotvpCT7D5MEdjCo2J3HeRaIFoU,8502
- compressed_tensors_nightly-0.3.3.20240610.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- compressed_tensors_nightly-0.3.3.20240610.dist-info/METADATA,sha256=eWwLOgihZo6v4Fza_icFxi7Dkj8AFTwCm7OmRxcPegc,5668
- compressed_tensors_nightly-0.3.3.20240610.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- compressed_tensors_nightly-0.3.3.20240610.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
- compressed_tensors_nightly-0.3.3.20240610.dist-info/RECORD,,
+ compressed_tensors/utils/safetensors_load.py,sha256=0MheXwx1jeY12PeISppiSIZHs6rmN2YddwPpFb9V67I,8527
+ compressed_tensors_nightly-0.3.3.20240612.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ compressed_tensors_nightly-0.3.3.20240612.dist-info/METADATA,sha256=GjdOve1sMxN8qOUPu3EjXTNRFvnX0jrjA8lYwmq9CCY,5668
+ compressed_tensors_nightly-0.3.3.20240612.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ compressed_tensors_nightly-0.3.3.20240612.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+ compressed_tensors_nightly-0.3.3.20240612.dist-info/RECORD,,