compressed-tensors 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents as they appear in the registry.
- compressed_tensors/base.py +2 -1
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +11 -54
- compressed_tensors/compressors/dense.py +4 -4
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/int_quantized.py +126 -0
- compressed_tensors/compressors/marlin_24.py +250 -0
- compressed_tensors/compressors/model_compressor.py +315 -0
- compressed_tensors/compressors/pack_quantized.py +212 -0
- compressed_tensors/compressors/sparse_bitmask.py +3 -3
- compressed_tensors/compressors/utils/__init__.py +19 -0
- compressed_tensors/compressors/utils/helpers.py +43 -0
- compressed_tensors/compressors/utils/permutations_24.py +65 -0
- compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/config/base.py +7 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +62 -11
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +161 -54
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/initialize.py +33 -5
- compressed_tensors/quantization/observers/base.py +31 -27
- compressed_tensors/quantization/observers/helpers.py +6 -1
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +2 -2
- compressed_tensors/quantization/quant_config.py +69 -21
- compressed_tensors/quantization/quant_scheme.py +81 -1
- compressed_tensors/quantization/utils/helpers.py +76 -8
- compressed_tensors/utils/helpers.py +24 -6
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -8
- compressed_tensors-0.4.0.dist-info/RECORD +48 -0
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/utils/semi_structured_conversions.py

@@ -0,0 +1,341 @@
+#
+# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
+# Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py
+#
+# flake8: noqa
+# isort: skip_file
+
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+__all__ = [
+    "sparse_semi_structured_from_dense_cutlass",
+    "sparse_semi_structured_to_dense_cutlass",
+    "mask_creator",
+]
+
+# This is PyTorch implementation of main part of reorder_meta()
+# function, from tools/util/include/cutlass/util/host_reorder.h file
+# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
+# GEMM decides upon layout of this matrix, and at the moment for the
+# sparse GEMM executed on tensor cores, this is layout described by
+# ColumnMajorInterleaved<2> data structure, in
+# include/cutlass/layout/matrix.h of CUTLASS source tree. The
+# reordering of meta matrix into meta_reordered matrix calculated
+# according to these segments of CUTLASS code is re-implemented here.
+# Note that this calculation produces offsets for scattering metadata
+# matrix elements into reordered metadata matrix elements (or,
+# equivalently, for gathering reordered metadata matrix element back
+# into metadata matrix elements).
+def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
+    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
+    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
+
+    # Reorder the rows, then swizzle the 2x2 blocks.
+    group_x = 64
+    group_y = 32 if meta_dtype.itemsize == 2 else 16
+
+    dst_rows = (
+        dst_rows // group_x * group_x
+        + (dst_rows % 2) * 2
+        + (dst_rows % 8) // 4
+        + ((dst_rows % group_y) % 4) // 2 * 32
+        + ((dst_rows % group_x) // 8) * 4
+    )
+
+    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
+    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
+    dst_rows += topright - bottomleft
+    dst_cols -= topright - bottomleft
+
+    # Assumed that meta tensor is to be stored in CUTLASS
+    # InterleavedColumnMajor layout, and reverse engineered
+    # corresponding code to store values into this tensor.
+    interleave = 2
+    cols_maj = dst_cols // interleave
+    cols_min = dst_cols % interleave
+    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
+
+
+# This function converts dense matrix into sparse semi-structured
+# representation, producing "compressed" matrix, in the layout used by
+# CUTLASS backend, and corresponding metadata matrix.
+def sparse_semi_structured_from_dense_cutlass(dense):
+    if dense.dim() != 2:
+        raise RuntimeError(
+            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
+        )
+
+    m, k = dense.shape
+    device = dense.device
+
+    meta_dtype = torch.int8
+    if dense.dtype == torch.int8:
+        meta_dtype = torch.int32
+    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
+        meta_dtype = torch.int16
+    else:
+        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
+    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+    if quadbits_per_meta_elem not in (4, 8):
+        raise RuntimeError("Invalid number of elements per meta element calculated")
+
+    if meta_dtype == torch.int32:
+        if m % 16 != 0:
+            raise RuntimeError(
+                f"Number of rows of dense matrix {m} must be divisible by 16"
+            )
+    else:
+        if m % 32 != 0:
+            raise RuntimeError(
+                f"Number of rows of dense matrix {m} must be divisible by 32"
+            )
+    if k % (4 * quadbits_per_meta_elem) != 0:
+        raise RuntimeError(
+            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
+        )
+
+    if dense.dtype != torch.float:
+        ksparse = 4
+        dense_4 = dense.view(-1, k // ksparse, ksparse)
+        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
+    else:
+        ksparse = 2
+        dense_2 = dense.view(-1, k // ksparse, ksparse)
+        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
+    meta_ncols = k // (ksparse * quadbits_per_meta_elem)
+
+    # Encoding quadruples of True/False values as follows:
+    #     [True, True, False, False] -> 0b0100
+    #     [True, False, True, False] -> 0b1000
+    #     [False, True, True, False] -> 0b1001
+    #     [True, False, False, True ] -> 0b1100
+    #     [False, True, False, True ] -> 0b1101
+    #     [False, False, True, True ] -> 0b1110
+    # Thus, lower two bits in the encoding are index of the True value
+    # at the lowest index in the quadruple, and the higher two bits in
+    # the encoding are index of the other True value in the quadruple.
+    # In case there are less than two True values, than False value or
+    # values at some index or indices are considered True for the
+    # encoding. In case there are more than two True values, then the
+    # excess True value(s) at some indices are considered False for
+    # the encoding. The exact encodings used for these cases are as
+    # follows:
+    #     [False, False, False, False] -> 0b1110
+    #     [False, False, False, True ] -> 0b1110
+    #     [False, False, True, False] -> 0b1110
+    #     [False, True, False, False] -> 0b1001
+    #     [False, True, True, True ] -> 0b1101
+    #     [True, False, False, False] -> 0b1000
+    #     [True, False, True, True ] -> 0b1100
+    #     [True, True, False, True ] -> 0b0100
+    #     [True, True, True, False] -> 0b0100
+    #     [True, True, True, True ] -> 0b0100
+    # These particular encodings are chosen, with the help of Espresso
+    # logic minimizer software, for the purpose of minimization of
+    # corresponding Boolean functions, that translate non-zero flags
+    # into encoding bits. Note also possible choices for the first
+    # and last of these encodings were limited only to (0b0100,
+    # 0b1110), in order to produce valid encodings for 1:2 sparsity
+    # case.
+
+    expr0 = m0 & m1
+    expr1 = ~m0 & m1
+    expr2 = ~m0 & ~m1
+    bit0 = expr1
+    bit1 = expr2
+    bit2 = expr0 | expr2 | m3
+    bit3 = expr1 | ~m1
+    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
+    idxs1 = bit2 | (bit3.to(torch.int64) << 1)
+
+    if dense.dtype != torch.float:
+        sparse0 = dense_4.gather(
+            -1, idxs0.unsqueeze(-1)
+        )  # type: ignore[possibly-undefined]
+        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
+        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
+    else:
+        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
+            m, k // 2
+        )  # type: ignore[possibly-undefined]
+
+    meta_4 = idxs0 | (idxs1 << 2)
+    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
+
+    if quadbits_per_meta_elem == 4:
+        meta = (
+            meta_n[:, :, 0]
+            | (meta_n[:, :, 1] << 4)
+            | (meta_n[:, :, 2] << 8)
+            | (meta_n[:, :, 3] << 12)
+        )
+    elif quadbits_per_meta_elem == 8:
+        meta = (
+            meta_n[:, :, 0]
+            | (meta_n[:, :, 1] << 4)
+            | (meta_n[:, :, 2] << 8)
+            | (meta_n[:, :, 3] << 12)
+            | (meta_n[:, :, 4] << 16)
+            | (meta_n[:, :, 5] << 20)
+            | (meta_n[:, :, 6] << 24)
+            | (meta_n[:, :, 7] << 28)
+        )
+
+    # Reorder meta tensor elements.
+    meta_reordered = meta.new_empty(
+        (m * meta_ncols,)
+    )  # type: ignore[possibly-undefined]
+    meta_offsets = _calculate_meta_reordering_scatter_offsets(
+        m, meta_ncols, meta_dtype, device
+    )
+    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
+
+    return (sparse, meta_reordered.view(m, meta_ncols))
+
+
+# This function performs reverse of the function above - it
+# reconstructs dense matrix from a pair of "compressed" matrix, given
+# in the layout used by CUTLASS backend, and accompanying metadata
+# matrix.
+def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
+    if sparse.dim() != 2:
+        raise RuntimeError(
+            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
+        )
+
+    m, k = sparse.shape
+    device = sparse.device
+
+    if meta_reordered.dim() != 2:
+        raise RuntimeError(
+            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
+        )
+    if meta_reordered.device != device:
+        raise RuntimeError(
+            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
+        )
+
+    meta_dtype = meta_reordered.dtype
+    if meta_dtype not in (torch.int16, torch.int32):
+        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
+    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+
+    ksparse = 4 if sparse.dtype != torch.float else 2
+
+    meta_nrows, meta_ncols = meta_reordered.shape
+    if meta_nrows != m:
+        raise RuntimeError(
+            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
+        )
+    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
+        raise RuntimeError(
+            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
+            "expected according to the number of columns of meta matrix"
+        )
+
+    # Undo meta tensor elements reordering.
+    meta_offsets = _calculate_meta_reordering_scatter_offsets(
+        m, meta_ncols, meta_dtype, device
+    )
+    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
+
+    # Unpack sparse tensor back to original dense tensor, using
+    # information provided by meta tensor. Note that torch.float
+    # datatype is handled pretty much the same as
+    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
+    # value is encoded as if underlying 8 bytes contain four
+    # torch.half/torch.bfloat16 values, where either first two or last
+    # two are zeros.
+    meta_2 = torch.empty(
+        (m, meta_ncols, 2 * quadbits_per_meta_elem),
+        dtype=meta_dtype,
+        device=device,
+    )
+    if quadbits_per_meta_elem == 4:
+        meta_2[:, :, 0] = meta & 0b11
+        meta_2[:, :, 1] = (meta >> 2) & 0b11
+        meta_2[:, :, 2] = (meta >> 4) & 0b11
+        meta_2[:, :, 3] = (meta >> 6) & 0b11
+        meta_2[:, :, 4] = (meta >> 8) & 0b11
+        meta_2[:, :, 5] = (meta >> 10) & 0b11
+        meta_2[:, :, 6] = (meta >> 12) & 0b11
+        meta_2[:, :, 7] = (meta >> 14) & 0b11
+    elif quadbits_per_meta_elem == 8:
+        meta_2[:, :, 0] = meta & 0b11
+        meta_2[:, :, 1] = (meta >> 2) & 0b11
+        meta_2[:, :, 2] = (meta >> 4) & 0b11
+        meta_2[:, :, 3] = (meta >> 6) & 0b11
+        meta_2[:, :, 4] = (meta >> 8) & 0b11
+        meta_2[:, :, 5] = (meta >> 10) & 0b11
+        meta_2[:, :, 6] = (meta >> 12) & 0b11
+        meta_2[:, :, 7] = (meta >> 14) & 0b11
+        meta_2[:, :, 8] = (meta >> 16) & 0b11
+        meta_2[:, :, 9] = (meta >> 18) & 0b11
+        meta_2[:, :, 10] = (meta >> 20) & 0b11
+        meta_2[:, :, 11] = (meta >> 22) & 0b11
+        meta_2[:, :, 12] = (meta >> 24) & 0b11
+        meta_2[:, :, 13] = (meta >> 26) & 0b11
+        meta_2[:, :, 14] = (meta >> 28) & 0b11
+        meta_2[:, :, 15] = (meta >> 30) & 0b11
+
+    dense_offsets = meta_2.view(-1) + (
+        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
+    ).view(-1, 1).repeat(1, 2).view(-1)
+
+    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
+    if sparse.dtype != torch.float:
+        # dense.scatter_(0, dense_offsets, sparse.view(-1))
+        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
+    else:
+        dense.view(torch.half).scatter_(
+            0, dense_offsets, sparse.view(torch.half).view(-1)
+        )
+
+    return dense.view(m, 2 * k)
+
+
+def mask_creator(tensor):
+    """
+    Class for creating N:M sparsity masks.
+    Masks will be created using the N:M ratio, where for every block of
+    M weights, N will be pruned based on ranked weight value. Each mask
+    will correspond to the given tensor.
+
+    :param N: The number of weights in a group to keep
+    :param M: The size of a weight group
+    """
+    N = 2
+    M = 4
+
+    mask = None
+    # for i, tensor in enumerate(tensors):
+    if tensor.numel() % M != 0:
+        raise ValueError(
+            f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
+        )
+
+    num_groups = tensor.numel() // M
+
+    # N:M sparsity for linear layers
+    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
+    index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
+
+    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
+    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

+    return mask
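For orientation (not part of the diff), here is a minimal round-trip sketch of the three helpers added above; the import path follows the new file's location in the file list, and the 64x64 fp16 shape is chosen to satisfy the divisibility checks in `sparse_semi_structured_from_dense_cutlass`:

```python
import torch

from compressed_tensors.compressors.utils.semi_structured_conversions import (
    mask_creator,
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)

# 64x64 fp16 weight: rows divisible by 32 and columns by 16, as the
# int16-metadata (half/bfloat16) path requires
dense = torch.randn(64, 64, dtype=torch.half)

# prune to 2:4 sparsity: mask_creator zeroes the two smallest-magnitude
# weights in every group of four
pruned = dense * mask_creator(dense).to(dense.dtype)

# compress into the CUTLASS layout, then reconstruct the dense tensor
sparse, meta = sparse_semi_structured_from_dense_cutlass(pruned)
restored = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(pruned, restored)
```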
compressed_tensors/config/base.py

@@ -19,17 +19,20 @@ from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
 
 
-__all__ = ["
+__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
 
 
 class CompressionFormat(Enum):
-
+    dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    int_quantized = "int-quantized"
+    pack_quantized = "pack-quantized"
+    marlin_24 = "marlin-24"
 
 
-class
+class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
-    Base data class for storing compression parameters
+    Base data class for storing sparsity compression parameters
 
     :param format: name of compression format
     :param global_sparsity: average sparsity of the entire model
compressed_tensors/config/dense.py

@@ -14,14 +14,14 @@
 
 from typing import Optional
 
-from compressed_tensors.config import
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 
 
 __all__ = ["DenseSparsityConfig"]
 
 
-@
-class DenseSparsityConfig(
+@SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
+class DenseSparsityConfig(SparsityCompressionConfig):
     """
     Identity configuration for storing a sparse model in
     an uncompressed dense format
@@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig):
     "unstructured", "2:4", "8:16" etc
     """
 
-    format: str = CompressionFormat.
+    format: str = CompressionFormat.dense.value
     global_sparsity: Optional[float] = 0.0
     sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/config/sparse_bitmask.py

@@ -14,14 +14,14 @@
 
 from typing import Optional
 
-from compressed_tensors.config import
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 
 
 __all__ = ["BitmaskConfig"]
 
 
-@
-class BitmaskConfig(
+@SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskConfig(SparsityCompressionConfig):
     """
     Configuration for storing a sparse model using
     bitmask compression
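The two config hunks above follow the same pattern against the renamed base class: register under a format name, subclass `SparsityCompressionConfig`, and default the `format` field. A sketch of the pattern (the `"my-format"` name and `MyFormatConfig` class are hypothetical, for illustration only):

```python
from typing import Optional

from compressed_tensors.config import SparsityCompressionConfig


# hypothetical format name and class, mirroring DenseSparsityConfig and
# BitmaskConfig above; real configs register under a CompressionFormat value
@SparsityCompressionConfig.register(name="my-format")
class MyFormatConfig(SparsityCompressionConfig):
    format: str = "my-format"
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = "unstructured"
```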
compressed_tensors/quantization/lifecycle/apply.py

@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import re
 from collections import OrderedDict
 from typing import Dict, Iterable, Optional
 
+import torch
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
+from compressed_tensors.quantization.lifecycle.compressed import (
+    compress_quantized_weights,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
@@ -27,7 +32,11 @@ from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
-from compressed_tensors.quantization.utils import
+from compressed_tensors.quantization.utils import (
+    infer_quantization_status,
+    iter_named_leaf_modules,
+)
+from compressed_tensors.utils.helpers import fix_fsdp_module_name
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
 
@@ -43,6 +52,9 @@ from compressed_tensors.quantization.utils.helpers import is_module_quantized
 from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def load_pretrained_quantization(model: Module, model_name_or_path: str):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -98,15 +110,27 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    # list of submodules to ignore
+    ignored_submodules = []
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in iter_named_leaf_modules(model):
+        # potentially fix module name to remove FSDP wrapper prefix
+        name = fix_fsdp_module_name(name)
         if find_first_name_or_class_match(name, submodule, config.ignore):
+            ignored_submodules.append(name)
             continue  # layer matches ignore list, continue
         target = find_first_name_or_class_match(name, submodule, target_to_scheme)
         if target is not None:
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = target_to_scheme[target]
 
+    if config.ignore is not None and ignored_submodules is not None:
+        if set(config.ignore) - set(ignored_submodules):
+            _LOGGER.warning(
+                "Some layers that were to be ignored were "
+                "not found in the model: "
+                f"{set(config.ignore) - set(ignored_submodules)}"
+            )
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
 
@@ -118,13 +142,19 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
-
+    current_status = infer_quantization_status(model)
+
+    if status >= QuantizationStatus.INITIALIZED > current_status:
         model.apply(initialize_module_for_quantization)
-
+
+    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         model.apply(set_module_for_calibration)
-    if status >= QuantizationStatus.FROZEN:
+    if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)
 
+    if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
+        model.apply(compress_quantized_weights)
+
 
 def find_first_name_or_class_match(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
@@ -132,9 +162,10 @@ def find_first_name_or_class_match(
     # first element of targets that matches the given name
     # if no name matches returns first target that matches the class name
     # returns None otherwise
-
-
-
+    if isinstance(targets, Iterable):
+        return _find_first_match(name, targets) or _find_first_match(
+            module.__class__.__name__, targets, check_contains
+        )
 
 
 def _find_first_match(
@@ -143,6 +174,7 @@ def _find_first_match(
    # returns first element of target that matches value either
    # exactly or as a regex after 're:'. if check_contains is set to True,
    # additionally checks if the target string is contained with value.
+
    for target in targets:
        if target.startswith("re:"):
            pattern = target[3:]
@@ -156,6 +188,14 @@ def _find_first_match(
     return None
 
 
+def _infer_status(model: Module) -> Optional[QuantizationStatus]:
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
+    return None
+
+
 def _load_quant_args_from_state_dict(
     base_name: str, module_name: str, module: Module, state_dict: Dict
 ):
@@ -172,7 +212,18 @@ def _load_quant_args_from_state_dict(
     zp_name = f"{base_name}_zero_point"
     device = next(module.parameters()).device
 
-    scale = getattr(module, scale_name)
-    zp = getattr(module, zp_name)
-    scale
-
+    scale = getattr(module, scale_name, None)
+    zp = getattr(module, zp_name, None)
+    if scale is not None:
+        state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
+        if state_dict_scale is not None:
+            scale.data = state_dict_scale.to(device).to(scale.dtype)
+        else:
+            scale.data = scale.data.to(device)
+
+    if zp is not None:
+        zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
+        if zp_from_state is not None:  # load the non-zero zero points
+            zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
+        else:  # fill with zeros matching scale shape
+            zp.data = torch.zeros_like(scale, dtype=torch.int8).to(device)
compressed_tensors/quantization/lifecycle/compressed.py

@@ -0,0 +1,69 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+import torch
+from compressed_tensors.quantization.lifecycle.forward import quantize
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "compress_quantized_weights",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def compress_quantized_weights(module: Module):
+    """
+    Quantizes the module weight representation to use fewer bits in memory
+
+    apply to full model with `model.apply(compress_quantized_weights)`
+
+    :param module: module to compress to quantized representation
+    """
+    scheme = getattr(module, "quantization_scheme", None)
+    if not scheme or not scheme.weights:
+        # no quantization scheme or weights not quantized, nothing to do
+        return
+
+    if scheme is QuantizationStatus.COMPRESSED:
+        # module is already compressed, nothing to do
+        return
+
+    weight = getattr(module, "weight", None)
+    scale = getattr(module, "weight_scale", None)
+    zero_point = getattr(module, "weight_zero_point", None)
+
+    if weight is None or scale is None or zero_point is None:
+        # no weight, scale, or ZP, nothing to do
+
+        # mark as compressed here to maintain consistent status throughout the model
+        module.quantization_status = QuantizationStatus.COMPRESSED
+        return
+
+    module.weight.requires_grad = False  # cannot use auto grad after compression
+    module.weight.data = quantize(
+        x=weight,
+        scale=scale,
+        zero_point=zero_point,
+        args=scheme.weights,
+        dtype=torch.int8,
+    )
+
+    module.quantization_status = QuantizationStatus.COMPRESSED