compressed-tensors 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (40)
  1. compressed_tensors/base.py +2 -1
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +11 -54
  4. compressed_tensors/compressors/dense.py +4 -4
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/int_quantized.py +126 -0
  7. compressed_tensors/compressors/marlin_24.py +250 -0
  8. compressed_tensors/compressors/model_compressor.py +315 -0
  9. compressed_tensors/compressors/pack_quantized.py +212 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/compressors/utils/__init__.py +19 -0
  12. compressed_tensors/compressors/utils/helpers.py +43 -0
  13. compressed_tensors/compressors/utils/permutations_24.py +65 -0
  14. compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
  15. compressed_tensors/config/base.py +7 -4
  16. compressed_tensors/config/dense.py +4 -4
  17. compressed_tensors/config/sparse_bitmask.py +3 -3
  18. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  19. compressed_tensors/quantization/lifecycle/apply.py +75 -19
  20. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  21. compressed_tensors/quantization/lifecycle/forward.py +208 -22
  22. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  23. compressed_tensors/quantization/lifecycle/initialize.py +33 -5
  24. compressed_tensors/quantization/observers/base.py +70 -5
  25. compressed_tensors/quantization/observers/helpers.py +6 -1
  26. compressed_tensors/quantization/observers/memoryless.py +17 -9
  27. compressed_tensors/quantization/observers/min_max.py +44 -13
  28. compressed_tensors/quantization/quant_args.py +33 -4
  29. compressed_tensors/quantization/quant_config.py +69 -21
  30. compressed_tensors/quantization/quant_scheme.py +81 -1
  31. compressed_tensors/quantization/utils/helpers.py +77 -8
  32. compressed_tensors/utils/helpers.py +26 -122
  33. compressed_tensors/utils/safetensors_load.py +3 -2
  34. compressed_tensors/version.py +53 -0
  35. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -9
  36. compressed_tensors-0.4.0.dist-info/RECORD +48 -0
  37. compressed_tensors-0.3.2.dist-info/RECORD +0 -38
  38. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
  39. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
  40. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/utils/semi_structured_conversions.py
@@ -0,0 +1,341 @@
+ #
+ # Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
+ # Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py
+ #
+ # flake8: noqa
+ # isort: skip_file
+
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ __all__ = [
+     "sparse_semi_structured_from_dense_cutlass",
+     "sparse_semi_structured_to_dense_cutlass",
+     "mask_creator",
+ ]
+
+ # This is PyTorch implementation of main part of reorder_meta()
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
+ # GEMM decides upon layout of this matrix, and at the moment for the
+ # sparse GEMM executed on tensor cores, this is layout described by
+ # ColumnMajorInterleaved<2> data structure, in
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
+ # reordering of meta matrix into meta_reordered matrix calculated
+ # according to these segments of CUTLASS code is re-implemented here.
+ # Note that this calculation produces offsets for scattering metadata
+ # matrix elements into reordered metadata matrix elements (or,
+ # equivalently, for gathering reordered metadata matrix element back
+ # into metadata matrix elements).
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
+     dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
+     dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
+
+     # Reorder the rows, then swizzle the 2x2 blocks.
+     group_x = 64
+     group_y = 32 if meta_dtype.itemsize == 2 else 16
+
+     dst_rows = (
+         dst_rows // group_x * group_x
+         + (dst_rows % 2) * 2
+         + (dst_rows % 8) // 4
+         + ((dst_rows % group_y) % 4) // 2 * 32
+         + ((dst_rows % group_x) // 8) * 4
+     )
+
+     topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
+     bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
+     dst_rows += topright - bottomleft
+     dst_cols -= topright - bottomleft
+
+     # Assumed that meta tensor is to be stored in CUTLASS
+     # InterleavedColumnMajor layout, and reverse engineered
+     # corresponding code to store values into this tensor.
+     interleave = 2
+     cols_maj = dst_cols // interleave
+     cols_min = dst_cols % interleave
+     return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
+
+
+ # This function converts dense matrix into sparse semi-structured
+ # representation, producing "compressed" matrix, in the layout used by
+ # CUTLASS backend, and corresponding metadata matrix.
+ def sparse_semi_structured_from_dense_cutlass(dense):
+     if dense.dim() != 2:
+         raise RuntimeError(
+             f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
+         )
+
+     m, k = dense.shape
+     device = dense.device
+
+     meta_dtype = torch.int8
+     if dense.dtype == torch.int8:
+         meta_dtype = torch.int32
+     elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
+         meta_dtype = torch.int16
+     else:
+         raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
+     quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+     if quadbits_per_meta_elem not in (4, 8):
+         raise RuntimeError("Invalid number of elements per meta element calculated")
+
+     if meta_dtype == torch.int32:
+         if m % 16 != 0:
+             raise RuntimeError(
+                 f"Number of rows of dense matrix {m} must be divisible by 16"
+             )
+     else:
+         if m % 32 != 0:
+             raise RuntimeError(
+                 f"Number of rows of dense matrix {m} must be divisible by 32"
+             )
+     if k % (4 * quadbits_per_meta_elem) != 0:
+         raise RuntimeError(
+             f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
+         )
+
+     if dense.dtype != torch.float:
+         ksparse = 4
+         dense_4 = dense.view(-1, k // ksparse, ksparse)
+         m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
+     else:
+         ksparse = 2
+         dense_2 = dense.view(-1, k // ksparse, ksparse)
+         m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
+     meta_ncols = k // (ksparse * quadbits_per_meta_elem)
+
+     # Encoding quadruples of True/False values as follows:
+     # [True, True, False, False] -> 0b0100
+     # [True, False, True, False] -> 0b1000
+     # [False, True, True, False] -> 0b1001
+     # [True, False, False, True ] -> 0b1100
+     # [False, True, False, True ] -> 0b1101
+     # [False, False, True, True ] -> 0b1110
+     # Thus, lower two bits in the encoding are index of the True value
+     # at the lowest index in the quadruple, and the higher two bits in
+     # the encoding are index of the other True value in the quadruple.
+     # In case there are less than two True values, than False value or
+     # values at some index or indices are considered True for the
+     # encoding. In case there are more than two True values, then the
+     # excess True value(s) at some indices are considered False for
+     # the encoding. The exact encodings used for these cases are as
+     # follows:
+     # [False, False, False, False] -> 0b1110
+     # [False, False, False, True ] -> 0b1110
+     # [False, False, True, False] -> 0b1110
+     # [False, True, False, False] -> 0b1001
+     # [False, True, True, True ] -> 0b1101
+     # [True, False, False, False] -> 0b1000
+     # [True, False, True, True ] -> 0b1100
+     # [True, True, False, True ] -> 0b0100
+     # [True, True, True, False] -> 0b0100
+     # [True, True, True, True ] -> 0b0100
+     # These particular encodings are chosen, with the help of Espresso
+     # logic minimizer software, for the purpose of minimization of
+     # corresponding Boolean functions, that translate non-zero flags
+     # into encoding bits. Note also possible choices for the first
+     # and last of these encodings were limited only to (0b0100,
+     # 0b1110), in order to produce valid encodings for 1:2 sparsity
+     # case.
+
+     expr0 = m0 & m1
+     expr1 = ~m0 & m1
+     expr2 = ~m0 & ~m1
+     bit0 = expr1
+     bit1 = expr2
+     bit2 = expr0 | expr2 | m3
+     bit3 = expr1 | ~m1
+     idxs0 = bit0 | (bit1.to(torch.int64) << 1)
+     idxs1 = bit2 | (bit3.to(torch.int64) << 1)
+
+     if dense.dtype != torch.float:
+         sparse0 = dense_4.gather(
+             -1, idxs0.unsqueeze(-1)
+         ) # type: ignore[possibly-undefined]
+         sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
+         sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
+     else:
+         sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
+             m, k // 2
+         ) # type: ignore[possibly-undefined]
+
+     meta_4 = idxs0 | (idxs1 << 2)
+     meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
+
+     if quadbits_per_meta_elem == 4:
+         meta = (
+             meta_n[:, :, 0]
+             | (meta_n[:, :, 1] << 4)
+             | (meta_n[:, :, 2] << 8)
+             | (meta_n[:, :, 3] << 12)
+         )
+     elif quadbits_per_meta_elem == 8:
+         meta = (
+             meta_n[:, :, 0]
+             | (meta_n[:, :, 1] << 4)
+             | (meta_n[:, :, 2] << 8)
+             | (meta_n[:, :, 3] << 12)
+             | (meta_n[:, :, 4] << 16)
+             | (meta_n[:, :, 5] << 20)
+             | (meta_n[:, :, 6] << 24)
+             | (meta_n[:, :, 7] << 28)
+         )
+
+     # Reorder meta tensor elements.
+     meta_reordered = meta.new_empty(
+         (m * meta_ncols,)
+     ) # type: ignore[possibly-undefined]
+     meta_offsets = _calculate_meta_reordering_scatter_offsets(
+         m, meta_ncols, meta_dtype, device
+     )
+     meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
+
+     return (sparse, meta_reordered.view(m, meta_ncols))
+
+
+ # This function performs reverse of the function above - it
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
+ # in the layout used by CUTLASS backend, and accompanying metadata
+ # matrix.
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
+     if sparse.dim() != 2:
+         raise RuntimeError(
+             f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
+         )
+
+     m, k = sparse.shape
+     device = sparse.device
+
+     if meta_reordered.dim() != 2:
+         raise RuntimeError(
+             f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
+         )
+     if meta_reordered.device != device:
+         raise RuntimeError(
+             f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
+         )
+
+     meta_dtype = meta_reordered.dtype
+     if meta_dtype not in (torch.int16, torch.int32):
+         raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
+     quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+
+     ksparse = 4 if sparse.dtype != torch.float else 2
+
+     meta_nrows, meta_ncols = meta_reordered.shape
+     if meta_nrows != m:
+         raise RuntimeError(
+             f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
+         )
+     if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
+         raise RuntimeError(
+             f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
+             "expected according to the number of columns of meta matrix"
+         )
+
+     # Undo meta tensor elements reordering.
+     meta_offsets = _calculate_meta_reordering_scatter_offsets(
+         m, meta_ncols, meta_dtype, device
+     )
+     meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
+
+     # Unpack sparse tensor back to original dense tensor, using
+     # information provided by meta tensor. Note that torch.float
+     # datatype is handled pretty much the same as
+     # torch.half/torch.bfloat16, as metadata for a pair of torch.float
+     # value is encoded as if underlying 8 bytes contain four
+     # torch.half/torch.bfloat16 values, where either first two or last
+     # two are zeros.
+     meta_2 = torch.empty(
+         (m, meta_ncols, 2 * quadbits_per_meta_elem),
+         dtype=meta_dtype,
+         device=device,
+     )
+     if quadbits_per_meta_elem == 4:
+         meta_2[:, :, 0] = meta & 0b11
+         meta_2[:, :, 1] = (meta >> 2) & 0b11
+         meta_2[:, :, 2] = (meta >> 4) & 0b11
+         meta_2[:, :, 3] = (meta >> 6) & 0b11
+         meta_2[:, :, 4] = (meta >> 8) & 0b11
+         meta_2[:, :, 5] = (meta >> 10) & 0b11
+         meta_2[:, :, 6] = (meta >> 12) & 0b11
+         meta_2[:, :, 7] = (meta >> 14) & 0b11
+     elif quadbits_per_meta_elem == 8:
+         meta_2[:, :, 0] = meta & 0b11
+         meta_2[:, :, 1] = (meta >> 2) & 0b11
+         meta_2[:, :, 2] = (meta >> 4) & 0b11
+         meta_2[:, :, 3] = (meta >> 6) & 0b11
+         meta_2[:, :, 4] = (meta >> 8) & 0b11
+         meta_2[:, :, 5] = (meta >> 10) & 0b11
+         meta_2[:, :, 6] = (meta >> 12) & 0b11
+         meta_2[:, :, 7] = (meta >> 14) & 0b11
+         meta_2[:, :, 8] = (meta >> 16) & 0b11
+         meta_2[:, :, 9] = (meta >> 18) & 0b11
+         meta_2[:, :, 10] = (meta >> 20) & 0b11
+         meta_2[:, :, 11] = (meta >> 22) & 0b11
+         meta_2[:, :, 12] = (meta >> 24) & 0b11
+         meta_2[:, :, 13] = (meta >> 26) & 0b11
+         meta_2[:, :, 14] = (meta >> 28) & 0b11
+         meta_2[:, :, 15] = (meta >> 30) & 0b11
+
+     dense_offsets = meta_2.view(-1) + (
+         torch.arange(0, 2 * m * k // ksparse, device=device) * 4
+     ).view(-1, 1).repeat(1, 2).view(-1)
+
+     dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
+     if sparse.dtype != torch.float:
+         # dense.scatter_(0, dense_offsets, sparse.view(-1))
+         dense.scatter_(0, dense_offsets, sparse.reshape(-1))
+     else:
+         dense.view(torch.half).scatter_(
+             0, dense_offsets, sparse.view(torch.half).view(-1)
+         )
+
+     return dense.view(m, 2 * k)
+
+
+ def mask_creator(tensor):
+     """
+     Class for creating N:M sparsity masks.
+     Masks will be created using the N:M ratio, where for every block of
+     M weights, N will be pruned based on ranked weight value. Each mask
+     will correspond to the given tensor.
+
+     :param N: The number of weights in a group to keep
+     :param M: The size of a weight group
+     """
+     N = 2
+     M = 4
+
+     mask = None
+     # for i, tensor in enumerate(tensors):
+     if tensor.numel() % M != 0:
+         raise ValueError(
+             f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
+         )
+
+     num_groups = tensor.numel() // M
+
+     # N:M sparsity for linear layers
+     tensor_temp = tensor.detach().abs().reshape(num_groups, M)
+     index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
+
+     w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
+     mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
+
+     return mask
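
The new file above carries the 2:4 (semi-structured) sparsity utilities pulled in for the new compressors: mask_creator builds a 2:4 mask, and the two CUTLASS helpers convert between the dense layout and the compressed-values-plus-metadata layout. A minimal round-trip sketch of how they compose (hedged: the shapes and dtype are chosen only to satisfy the divisibility checks above, and the direct module path is imported rather than any re-export):

import torch

from compressed_tensors.compressors.utils.semi_structured_conversions import (
    mask_creator,
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)

# fp16 weight; rows divisible by 32 and columns divisible by 16 per the checks above
weight = torch.randn(64, 128, dtype=torch.float16)

# keep the 2 largest-magnitude values in every group of 4 (2:4 sparsity)
mask = mask_creator(weight.float()).to(weight.dtype)
pruned = weight * mask

# compress to the CUTLASS layout, then reconstruct the dense tensor
sparse, meta = sparse_semi_structured_from_dense_cutlass(pruned)
restored = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(restored, pruned)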
compressed_tensors/config/base.py
@@ -19,17 +19,20 @@ from compressed_tensors.registry import RegistryMixin
  from pydantic import BaseModel


- __all__ = ["CompressionConfig", "CompressionFormat"]
+ __all__ = ["SparsityCompressionConfig", "CompressionFormat"]


  class CompressionFormat(Enum):
-     dense_sparsity = "dense-sparsity"
+     dense = "dense"
      sparse_bitmask = "sparse-bitmask"
+     int_quantized = "int-quantized"
+     pack_quantized = "pack-quantized"
+     marlin_24 = "marlin-24"


- class CompressionConfig(RegistryMixin, BaseModel):
+ class SparsityCompressionConfig(RegistryMixin, BaseModel):
      """
-     Base data class for storing compression parameters
+     Base data class for storing sparsity compression parameters

      :param format: name of compression format
      :param global_sparsity: average sparsity of the entire model
compressed_tensors/config/dense.py
@@ -14,14 +14,14 @@

  from typing import Optional

- from compressed_tensors.config import CompressionConfig, CompressionFormat
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig


  __all__ = ["DenseSparsityConfig"]


- @CompressionConfig.register(name=CompressionFormat.dense_sparsity.value)
- class DenseSparsityConfig(CompressionConfig):
+ @SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
+ class DenseSparsityConfig(SparsityCompressionConfig):
      """
      Identity configuration for storing a sparse model in
      an uncompressed dense format
@@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig):
      "unstructured", "2:4", "8:16" etc
      """

-     format: str = CompressionFormat.dense_sparsity.value
+     format: str = CompressionFormat.dense.value
      global_sparsity: Optional[float] = 0.0
      sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/config/sparse_bitmask.py
@@ -14,14 +14,14 @@

  from typing import Optional

- from compressed_tensors.config import CompressionConfig, CompressionFormat
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig


  __all__ = ["BitmaskConfig"]


- @CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
- class BitmaskConfig(CompressionConfig):
+ @SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+ class BitmaskConfig(SparsityCompressionConfig):
      """
      Configuration for storing a sparse model using
      bitmask compression
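
Net effect of the three config hunks: the base class is now SparsityCompressionConfig, the old dense-sparsity format is simply dense, and three quantization formats join the enum. A small usage sketch of the renamed pieces (hedged: it relies only on the enum values and DenseSparsityConfig fields visible above):

from compressed_tensors.config import CompressionFormat
from compressed_tensors.config.dense import DenseSparsityConfig

# the enum values are plain strings used as registry names and "format" fields
assert CompressionFormat.dense.value == "dense"
assert CompressionFormat.marlin_24.value == "marlin-24"

# DenseSparsityConfig is a pydantic model; defaults mirror the hunk above
config = DenseSparsityConfig(global_sparsity=0.5, sparsity_structure="2:4")
assert config.format == CompressionFormat.dense.value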
compressed_tensors/quantization/lifecycle/__init__.py
@@ -19,4 +19,5 @@ from .calibration import *
  from .forward import *
  from .frozen import *
  from .initialize import *
+ from .compressed import *
  from .apply import *
compressed_tensors/quantization/lifecycle/apply.py
@@ -12,13 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import logging
  import re
  from collections import OrderedDict
  from typing import Dict, Iterable, Optional

+ import torch
  from compressed_tensors.quantization.lifecycle.calibration import (
      set_module_for_calibration,
  )
+ from compressed_tensors.quantization.lifecycle.compressed import (
+     compress_quantized_weights,
+ )
  from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
  from compressed_tensors.quantization.lifecycle.initialize import (
      initialize_module_for_quantization,
@@ -27,7 +32,11 @@ from compressed_tensors.quantization.quant_config import (
      QuantizationConfig,
      QuantizationStatus,
  )
- from compressed_tensors.quantization.utils import iter_named_leaf_modules
+ from compressed_tensors.quantization.utils import (
+     infer_quantization_status,
+     iter_named_leaf_modules,
+ )
+ from compressed_tensors.utils.helpers import fix_fsdp_module_name
  from compressed_tensors.utils.safetensors_load import get_safetensors_folder
  from torch.nn import Module

@@ -36,12 +45,16 @@ __all__ = [
      "load_pretrained_quantization",
      "apply_quantization_config",
      "apply_quantization_status",
+     "find_first_name_or_class_match",
  ]

  from compressed_tensors.quantization.utils.helpers import is_module_quantized
  from compressed_tensors.utils.safetensors_load import get_quantization_state_dict


+ _LOGGER = logging.getLogger(__name__)
+
+
  def load_pretrained_quantization(model: Module, model_name_or_path: str):
      """
      Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -97,15 +110,27 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
          for target in scheme.targets:
              target_to_scheme[target] = scheme

+     # list of submodules to ignore
+     ignored_submodules = []
      # mark appropriate layers for quantization by setting their quantization schemes
      for name, submodule in iter_named_leaf_modules(model):
-         if _find_first_name_or_class_match(name, submodule, config.ignore):
+         # potentially fix module name to remove FSDP wrapper prefix
+         name = fix_fsdp_module_name(name)
+         if find_first_name_or_class_match(name, submodule, config.ignore):
+             ignored_submodules.append(name)
              continue # layer matches ignore list, continue
-         target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
+         target = find_first_name_or_class_match(name, submodule, target_to_scheme)
          if target is not None:
              # target matched - add layer and scheme to target list
              submodule.quantization_scheme = target_to_scheme[target]

+     if config.ignore is not None and ignored_submodules is not None:
+         if set(config.ignore) - set(ignored_submodules):
+             _LOGGER.warning(
+                 "Some layers that were to be ignored were "
+                 "not found in the model: "
+                 f"{set(config.ignore) - set(ignored_submodules)}"
+             )
      # apply current quantization status across all targeted layers
      apply_quantization_status(model, config.quantization_status)

@@ -117,40 +142,60 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
      :param model: model to apply quantization to
      :param status: status to update the module to
      """
-     if status >= QuantizationStatus.INITIALIZED:
+     current_status = infer_quantization_status(model)
+
+     if status >= QuantizationStatus.INITIALIZED > current_status:
          model.apply(initialize_module_for_quantization)
-     if status >= QuantizationStatus.CALIBRATION:
+
+     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
          model.apply(set_module_for_calibration)
-     if status >= QuantizationStatus.FROZEN:
+     if current_status < status >= QuantizationStatus.FROZEN > current_status:
          model.apply(freeze_module_quantization)

+     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
+         model.apply(compress_quantized_weights)

- def _find_first_name_or_class_match(
-     name: str,
-     module: Module,
-     targets: Iterable[str],
+
+ def find_first_name_or_class_match(
+     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
  ) -> Optional[str]:
      # first element of targets that matches the given name
      # if no name matches returns first target that matches the class name
      # returns None otherwise
-     return _find_first_match(name, targets) or _find_first_match(
-         module.__class__.__name__, targets
-     )
+     if isinstance(targets, Iterable):
+         return _find_first_match(name, targets) or _find_first_match(
+             module.__class__.__name__, targets, check_contains
+         )


- def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
+ def _find_first_match(
+     value: str, targets: Iterable[str], check_contains: bool = False
+ ) -> Optional[str]:
      # returns first element of target that matches value either
-     # exactly or as a regex after 're:'
+     # exactly or as a regex after 're:'. if check_contains is set to True,
+     # additionally checks if the target string is contained with value.
+
      for target in targets:
          if target.startswith("re:"):
              pattern = target[3:]
              if re.match(pattern, value):
                  return target
+         elif check_contains:
+             if target.lower() in value.lower():
+                 return target
          elif target == value:
              return target
      return None


+ def _infer_status(model: Module) -> Optional[QuantizationStatus]:
+     for module in model.modules():
+         status = getattr(module, "quantization_status", None)
+         if status is not None:
+             return status
+     return None
+
+
  def _load_quant_args_from_state_dict(
      base_name: str, module_name: str, module: Module, state_dict: Dict
  ):
@@ -167,7 +212,18 @@ def _load_quant_args_from_state_dict(
      zp_name = f"{base_name}_zero_point"
      device = next(module.parameters()).device

-     scale = getattr(module, scale_name)
-     zp = getattr(module, zp_name)
-     scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
-     zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
+     scale = getattr(module, scale_name, None)
+     zp = getattr(module, zp_name, None)
+     if scale is not None:
+         state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
+         if state_dict_scale is not None:
+             scale.data = state_dict_scale.to(device).to(scale.dtype)
+         else:
+             scale.data = scale.data.to(device)
+
+     if zp is not None:
+         zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
+         if zp_from_state is not None: # load the non-zero zero points
+             zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
+         else: # fill with zeros matching scale shape
+             zp.data = torch.zeros_like(scale, dtype=torch.int8).to(device)
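
Of the apply.py changes above, the matcher is the piece most likely to be reused downstream: it was renamed from _find_first_name_or_class_match, exported via __all__, and grew a check_contains flag that falls back to a case-insensitive substring match against the module's class name. A small sketch of that behaviour, following the logic shown in the hunk:

import torch
from compressed_tensors.quantization.lifecycle import find_first_name_or_class_match

layer = torch.nn.Linear(16, 16)

# exact-name match, then "re:" regex match against the module name
assert find_first_name_or_class_match("lm_head", layer, ["lm_head"]) == "lm_head"
assert (
    find_first_name_or_class_match("layers.0.q_proj", layer, ["re:.*q_proj"])
    == "re:.*q_proj"
)

# falls back to the class name when no name target hits
assert find_first_name_or_class_match("layers.0.v_proj", layer, ["Linear"]) == "Linear"

# new: check_contains allows a case-insensitive substring hit on the class name
assert (
    find_first_name_or_class_match("mlp.down", layer, ["linear"], check_contains=True)
    == "linear"
)
assert find_first_name_or_class_match("mlp.down", layer, ["lm_head"]) is None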
compressed_tensors/quantization/lifecycle/compressed.py
@@ -0,0 +1,69 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+
+ import torch
+ from compressed_tensors.quantization.lifecycle.forward import quantize
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from torch.nn import Module
+
+
+ __all__ = [
+     "compress_quantized_weights",
+ ]
+
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def compress_quantized_weights(module: Module):
+     """
+     Quantizes the module weight representation to use fewer bits in memory
+
+     apply to full model with `model.apply(compress_quantized_weights)`
+
+     :param module: module to compress to quantized representation
+     """
+     scheme = getattr(module, "quantization_scheme", None)
+     if not scheme or not scheme.weights:
+         # no quantization scheme or weights not quantized, nothing to do
+         return
+
+     if scheme is QuantizationStatus.COMPRESSED:
+         # module is already compressed, nothing to do
+         return
+
+     weight = getattr(module, "weight", None)
+     scale = getattr(module, "weight_scale", None)
+     zero_point = getattr(module, "weight_zero_point", None)
+
+     if weight is None or scale is None or zero_point is None:
+         # no weight, scale, or ZP, nothing to do
+
+         # mark as compressed here to maintain consistent status throughout the model
+         module.quantization_status = QuantizationStatus.COMPRESSED
+         return
+
+     module.weight.requires_grad = False # cannot use auto grad after compression
+     module.weight.data = quantize(
+         x=weight,
+         scale=scale,
+         zero_point=zero_point,
+         args=scheme.weights,
+         dtype=torch.int8,
+     )
+
+     module.quantization_status = QuantizationStatus.COMPRESSED
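
compress_quantized_weights is written as a per-module hook, so a whole model is compressed with Module.apply exactly as the docstring says; modules without a quantization_scheme (or without weight quantization args) are skipped. A hedged sketch of the call pattern, shown here only on unquantized layers where the hook is a no-op:

import torch
from compressed_tensors.quantization.lifecycle import compress_quantized_weights

# layers without a quantization_scheme fall through the early return above,
# so broadcasting the hook over a mixed model is safe
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
model.apply(compress_quantized_weights)
assert getattr(model[0], "quantization_status", None) is None  # left untouched

# on layers that do carry a scheme with weight args plus calibrated
# weight_scale / weight_zero_point parameters, the same call replaces
# weight.data with its int8 representation and marks the module COMPRESSED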