compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +21 -0
- compressed_tensors/base.py +17 -0
- compressed_tensors/compressors/__init__.py +22 -0
- compressed_tensors/compressors/base.py +59 -0
- compressed_tensors/compressors/dense.py +34 -0
- compressed_tensors/compressors/helpers.py +137 -0
- compressed_tensors/compressors/int_quantized.py +95 -0
- compressed_tensors/compressors/model_compressor.py +264 -0
- compressed_tensors/compressors/sparse_bitmask.py +239 -0
- compressed_tensors/config/__init__.py +18 -0
- compressed_tensors/config/base.py +43 -0
- compressed_tensors/config/dense.py +36 -0
- compressed_tensors/config/sparse_bitmask.py +36 -0
- compressed_tensors/quantization/__init__.py +21 -0
- compressed_tensors/quantization/lifecycle/__init__.py +23 -0
- compressed_tensors/quantization/lifecycle/apply.py +196 -0
- compressed_tensors/quantization/lifecycle/calibration.py +51 -0
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +333 -0
- compressed_tensors/quantization/lifecycle/frozen.py +50 -0
- compressed_tensors/quantization/lifecycle/initialize.py +99 -0
- compressed_tensors/quantization/observers/__init__.py +21 -0
- compressed_tensors/quantization/observers/base.py +130 -0
- compressed_tensors/quantization/observers/helpers.py +54 -0
- compressed_tensors/quantization/observers/memoryless.py +48 -0
- compressed_tensors/quantization/observers/min_max.py +80 -0
- compressed_tensors/quantization/quant_args.py +125 -0
- compressed_tensors/quantization/quant_config.py +210 -0
- compressed_tensors/quantization/quant_scheme.py +39 -0
- compressed_tensors/quantization/utils/__init__.py +16 -0
- compressed_tensors/quantization/utils/helpers.py +131 -0
- compressed_tensors/registry/__init__.py +17 -0
- compressed_tensors/registry/registry.py +360 -0
- compressed_tensors/utils/__init__.py +16 -0
- compressed_tensors/utils/helpers.py +45 -0
- compressed_tensors/utils/safetensors_load.py +237 -0
- compressed_tensors/version.py +50 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
@@ -0,0 +1,69 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
import logging
|
17
|
+
|
18
|
+
import torch
|
19
|
+
from compressed_tensors.quantization.lifecycle.forward import quantize
|
20
|
+
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
21
|
+
from torch.nn import Module
|
22
|
+
|
23
|
+
|
24
|
+
__all__ = [
|
25
|
+
"compress_quantized_weights",
|
26
|
+
]
|
27
|
+
|
28
|
+
|
29
|
+
_LOGGER = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
def compress_quantized_weights(module: Module):
    """
    Quantizes the module weight representation to use fewer bits in memory

    apply to full model with `model.apply(compress_quantized_weights)`

    Modules without a quantization scheme, or whose scheme does not quantize
    weights, are left untouched. Modules that are already marked as compressed
    are skipped so that their weights are never quantized a second time.

    :param module: module to compress to quantized representation
    """
    scheme = getattr(module, "quantization_scheme", None)
    if not scheme or not scheme.weights:
        # no quantization scheme or weights not quantized, nothing to do
        return

    # the lifecycle status is stored on the module, not on the scheme; comparing
    # the scheme object against the enum could never match, which allowed
    # already-compressed weights to be quantized again
    status = getattr(module, "quantization_status", None)
    if status == QuantizationStatus.COMPRESSED:
        # module is already compressed, nothing to do
        return

    weight = getattr(module, "weight", None)
    scale = getattr(module, "weight_scale", None)
    zero_point = getattr(module, "weight_zero_point", None)

    if weight is None or scale is None or zero_point is None:
        # no weight, scale, or ZP, nothing to do

        # mark as compressed here to maintain consistent status throughout the model
        module.quantization_status = QuantizationStatus.COMPRESSED
        return

    module.weight.requires_grad = False  # cannot use auto grad after compression
    module.weight.data = quantize(
        x=weight,
        scale=scale,
        zero_point=zero_point,
        args=scheme.weights,
        dtype=torch.int8,
    )

    module.quantization_status = QuantizationStatus.COMPRESSED
|
@@ -0,0 +1,333 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from functools import wraps
|
16
|
+
from math import ceil
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from compressed_tensors.quantization.quant_args import (
|
21
|
+
QuantizationArgs,
|
22
|
+
QuantizationStrategy,
|
23
|
+
)
|
24
|
+
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
25
|
+
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
26
|
+
from torch.nn import Module
|
27
|
+
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
"quantize",
|
31
|
+
"dequantize",
|
32
|
+
"fake_quantize",
|
33
|
+
"wrap_module_forward_quantized",
|
34
|
+
"maybe_calibrate_or_quantize",
|
35
|
+
]
|
36
|
+
|
37
|
+
|
38
|
+
@torch.no_grad()
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Quantize the input tensor x with the QuantizationStrategy given in args.
    Only the quantize step is run (no dequantization), so the result holds the
    rounded and clamped values. For group quantization, the number of columns
    of x must be divisible by the group_size.

    :param x: Input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param dtype: optional dtype to cast the quantized output to
    :return: quantized tensor
    """
    quantized = _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=True,
        do_dequantize=False,
        dtype=dtype,
    )
    return quantized
|
69
|
+
|
70
|
+
|
71
|
+
@torch.no_grad()
def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: Optional["QuantizationArgs"] = None,
) -> torch.Tensor:
    """
    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
    args is not provided, the strategy will be inferred from the shape of scale.

    :param x_q: quantized input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args used to quantize x_q
    :return: dequantized float tensor
    :raises ValueError: if args is not provided and no strategy can be inferred
        from the number of dimensions of scale
    """
    if args is None:
        if scale.ndim == 0 or (scale.ndim == 1 and scale.numel() == 1):
            # a single scale value can only describe per-tensor quantization
            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
        elif scale.ndim == 2:
            args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
        elif scale.ndim == 3:
            group_size = int(x_q.shape[1] / scale.shape[1])
            args = QuantizationArgs(
                strategy=QuantizationStrategy.GROUP, group_size=group_size
            )
        else:
            # fail here with a clear message rather than later with an
            # attribute error on a None args object
            raise ValueError(
                "Could not infer a quantization strategy from a scale of shape "
                f"{tuple(scale.shape)}, please provide args explicitly"
            )
    return _process_quantization(
        x=x_q,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=False,
        do_dequantize=True,
    )
|
106
|
+
|
107
|
+
|
108
|
+
@torch.no_grad()
def fake_quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
) -> torch.Tensor:
    """
    Fake quantize the input tensor x: quantize it and immediately dequantize it
    again using the QuantizationStrategy given in args. For group quantization,
    the number of columns of x must be divisible by the group_size.

    :param x: Input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :return: fake quantized tensor
    """
    fake_quantized = _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_dequantize=True,
        do_quantize=True,
    )
    return fake_quantized
|
136
|
+
|
137
|
+
|
138
|
+
@torch.no_grad()
def _process_quantization(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    do_quantize: bool = True,
    do_dequantize: bool = True,
) -> torch.Tensor:
    """
    Shared implementation behind quantize, dequantize and fake_quantize.

    :param x: tensor to process (already quantized when do_quantize is False)
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args; num_bits, group_size and strategy are read
    :param dtype: optional dtype to cast quantized values to
    :param do_quantize: whether to run the quantize step
    :param do_dequantize: whether to run the dequantize step
    :return: the processed tensor
    :raises ValueError: for the group strategy, when x has at least group_size
        columns but the column count is not divisible by group_size
    """
    # signed integer range for the requested bit width
    bit_range = 2**args.num_bits
    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
    q_min = torch.tensor(-bit_range / 2, device=x.device)
    group_size = args.group_size

    # group
    if args.strategy == QuantizationStrategy.GROUP:

        if do_dequantize:  # if dequantizing the output should be a fp type
            output = torch.zeros_like(x, dtype=scale.dtype)
        else:
            output_dtype = dtype if dtype is not None else x.dtype
            output = torch.zeros_like(x, dtype=output_dtype)

        # TODO: vectorize the for loop
        # TODO: fix generic assumption about the tensor size for computing group

        # TODO: make validation step for inputs

        while scale.ndim < 2:
            # pad scale and zero point dims for slicing; a trailing dim is
            # appended so that 0-dim tensors can be padded as well
            scale = scale.unsqueeze(-1)
            zero_point = zero_point.unsqueeze(-1)

        columns = x.shape[1]
        if columns >= group_size:
            if columns % group_size != 0:
                raise ValueError(
                    "tensor column shape must be divisible "
                    f"by the given group_size {group_size}"
                )
        for i in range(ceil(columns / group_size)):
            # scale.shape should be [nchan, ndim]
            # sc.shape should be [nchan, 1] after the view
            sc = scale[:, i].view(-1, 1)
            zp = zero_point[:, i].view(-1, 1)

            idx = i * group_size
            if do_quantize:
                output[:, idx : (idx + group_size)] = _quantize(
                    x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
                )
            if do_dequantize:
                input = (
                    output[:, idx : (idx + group_size)]
                    if do_quantize
                    else x[:, idx : (idx + group_size)]
                )
                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)

    # channel-wise, per-token and per-tensor share the same broadcasted math
    else:
        if args.strategy == QuantizationStrategy.TOKEN:
            # add a dimension at index 1 to scale and zero point so that they
            # broadcast against x
            scale = scale.unsqueeze(1)
            zero_point = zero_point.unsqueeze(1)

        if do_quantize:
            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
        if do_dequantize:
            output = _dequantize(output if do_quantize else x, scale, zero_point)

    return output
|
227
|
+
|
228
|
+
|
229
|
+
def wrap_module_forward_quantized(module: Module, scheme: "QuantizationScheme"):
    """
    Replaces `module.forward` with a wrapper that calibrates and (fake)
    quantizes inputs, weights and outputs according to scheme.

    expects a module already initialized and injected with the parameters in
    initialize_module_for_quantization

    :param module: module whose forward call is wrapped in place
    :param scheme: scheme whose input_activations, weights and
        output_activations entries decide which tensors are processed
    """
    forward_func_orig = module.forward.__func__

    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        input_ = args[0]

        if scheme.input_activations is not None:
            # calibrate and (fake) quantize input activations when applicable
            input_ = maybe_calibrate_or_quantize(
                module, input_, "input", scheme.input_activations
            )

        quantize_weights = scheme.weights is not None
        if quantize_weights:
            # keep a copy so the original weights can be put back afterwards
            unquantized_weight = self.weight.data.clone()

        try:
            if quantize_weights:
                # calibrate and (fake) quantize weights when applicable
                self.weight.data = maybe_calibrate_or_quantize(
                    module, self.weight, "weight", scheme.weights
                )

            # perform wrapped forward call
            output = forward_func_orig.__get__(module, module.__class__)(
                input_, *args[1:], **kwargs
            )

            if scheme.output_activations is not None:
                # calibrate and (fake) quantize output activations when applicable
                output = maybe_calibrate_or_quantize(
                    module, output, "output", scheme.output_activations
                )
        finally:
            # restore back to unquantized_value, even when the forward call
            # raised, so the module is never left with fake quantized weights
            if quantize_weights:
                self.weight.data = unquantized_weight

        return output

    # bind wrapped forward to module class so reference to `self` is correct
    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
    # set forward to wrapped forward
    setattr(module, "forward", bound_wrapped_forward)
|
272
|
+
|
273
|
+
|
274
|
+
def maybe_calibrate_or_quantize(
    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
    """
    Fake quantizes value when the module is in the calibration or frozen stage,
    refreshing the scale and zero point from the observer during calibration.

    :param module: module owning the `{base_name}_observer`, `{base_name}_scale`
        and `{base_name}_zero_point` attributes
    :param value: tensor to (fake) quantize
    :param base_name: attribute prefix, e.g. "input", "weight" or "output"
    :param args: quantization args for value
    :return: value unchanged outside of the calibration and frozen stages,
        otherwise the fake quantized tensor
    """
    quantized_stages = {
        QuantizationStatus.CALIBRATION,
        QuantizationStatus.FROZEN,
    }
    # only run quantized for the included stages
    if module.quantization_status not in quantized_stages:
        return value

    observer_name = f"{base_name}_observer"

    if not args.dynamic:
        # static quantization - get previous scale and zero point from layer
        scale = getattr(module, f"{base_name}_scale")
        zero_point = getattr(module, f"{base_name}_zero_point")
    else:
        # dynamic quantization - get scale and zero point directly from observer
        scale, zero_point = getattr(module, observer_name)(value)

    if module.quantization_status == QuantizationStatus.CALIBRATION:
        # calibration mode - get new quant params from observer
        calibration_observer = getattr(module, observer_name)
        new_scale, new_zero_point = calibration_observer(value)

        # update scale and zero point on the device of the module parameters
        target_device = next(module.parameters()).device
        scale.data = new_scale.to(target_device)
        zero_point.data = new_zero_point.to(target_device)

    return fake_quantize(value, scale, zero_point, args)
|
304
|
+
|
305
|
+
|
306
|
+
@torch.no_grad()
def _quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    q_min: torch.Tensor,
    q_max: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Computes round(x / scale + zero_point) clamped to [q_min, q_max].

    :param x: tensor to quantize
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param q_min: lower clamp bound
    :param q_max: upper clamp bound
    :param dtype: optional dtype to cast the result to
    :return: quantized tensor, cast to dtype when one is given
    """
    rounded = torch.round(x / scale + zero_point)
    clamped = torch.clamp(rounded, q_min, q_max)

    if dtype is None:
        return clamped
    return clamped.to(dtype)
|
325
|
+
|
326
|
+
|
327
|
+
@torch.no_grad()
def _dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
) -> torch.Tensor:
    """
    Maps quantized values back with (x_q - zero_point) * scale.

    :param x_q: quantized tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :return: dequantized tensor
    """
    shifted = x_q - zero_point
    return shifted * scale
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
17
|
+
from torch.nn import Module
|
18
|
+
|
19
|
+
|
20
|
+
__all__ = [
|
21
|
+
"freeze_module_quantization",
|
22
|
+
]
|
23
|
+
|
24
|
+
|
25
|
+
def freeze_module_quantization(module: Module):
    """
    deletes observers so static quantization is completed.

    apply to full model with `model.apply(freeze_module_quantization)`

    Observers of dynamic quantization args are kept. An observer that was never
    attached to the module is skipped instead of raising.

    :param module: module to freeze quantization for
    """
    scheme = getattr(module, "quantization_scheme", None)
    if not scheme:
        # no quantization scheme nothing to do
        return

    if module.quantization_status == QuantizationStatus.FROZEN:
        # nothing to do, already frozen
        return

    # delete observers from module if not dynamic
    observed = (
        ("input", scheme.input_activations),
        ("weight", scheme.weights),
        ("output", scheme.output_activations),
    )
    for base_name, quant_args in observed:
        if not quant_args or quant_args.dynamic:
            continue
        observer_name = f"{base_name}_observer"
        # the observer can be absent, e.g. a module without a weight attribute
        # that was targeted for weight quantization
        if hasattr(module, observer_name):
            delattr(module, observer_name)

    module.quantization_status = QuantizationStatus.FROZEN
|
@@ -0,0 +1,99 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
import logging
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from compressed_tensors.quantization.lifecycle.forward import (
|
21
|
+
wrap_module_forward_quantized,
|
22
|
+
)
|
23
|
+
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
24
|
+
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
25
|
+
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
26
|
+
from torch.nn import Module, Parameter
|
27
|
+
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
"initialize_module_for_quantization",
|
31
|
+
]
|
32
|
+
|
33
|
+
|
34
|
+
_LOGGER = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
|
37
|
+
def initialize_module_for_quantization(
    module: Module,
    scheme: Optional[QuantizationScheme] = None,
):
    """
    Prepares a layer for quantization: attaches the observers, scales and zero
    points required by its scheme, records the scheme and the INITIALIZED
    status on the module and wraps its forward call.

    apply to full model with `model.apply(initialize_module_for_quantization)`

    :param module: module to set for calibration
    :param scheme: scheme to use for quantization. if None is provided,
        will attempt to use scheme stored in the module under `quantization_scheme`,
        if not provided, the layer will be skipped
    """
    if not scheme:
        scheme = getattr(module, "quantization_scheme", None)
    if scheme is None:
        # no scheme passed and layer not targeted for quantization - skip
        return

    input_args = scheme.input_activations
    if input_args is not None:
        _initialize_scale_zero_point_observer(module, "input", input_args)

    if scheme.weights is not None:
        if not hasattr(module, "weight"):
            _LOGGER.warning(
                f"module type {type(module)} targeted for weight quantization but "
                "has no attribute weight, skipping weight quantization "
                f"for {type(module)}"
            )
        else:
            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)

    output_args = scheme.output_activations
    if output_args is not None:
        _initialize_scale_zero_point_observer(module, "output", output_args)

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED

    # wrap forward call of module to perform quantized actions based on calltime status
    wrap_module_forward_quantized(module, scheme)
|
78
|
+
|
79
|
+
|
80
|
+
def _initialize_scale_zero_point_observer(
    module: Module, base_name: str, quantization_args: QuantizationArgs
):
    """
    Registers `{base_name}_observer` on the module and, for non-dynamic
    quantization args, empty `{base_name}_scale` and `{base_name}_zero_point`
    parameters on the device of the first module parameter.

    :param module: module to attach the observer and parameters to
    :param base_name: attribute prefix, e.g. "input", "weight" or "output"
    :param quantization_args: args providing the observer and the dynamic flag
    """
    # initialize observer module and attach as submodule
    module.register_module(
        f"{base_name}_observer", quantization_args.get_observer()
    )

    if quantization_args.dynamic:
        # no need to register a scale and zero point for a dynamic observer
        return

    param_device = next(module.parameters()).device

    # initializes empty scale and zero point parameters for the module
    empty_scale = torch.empty(0, device=param_device)
    module.register_parameter(
        f"{base_name}_scale", Parameter(empty_scale, requires_grad=False)
    )

    empty_zero_point = torch.empty(0, device=param_device, dtype=int)
    module.register_parameter(
        f"{base_name}_zero_point", Parameter(empty_zero_point, requires_grad=False)
    )
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# flake8: noqa
|
16
|
+
# isort: skip_file
|
17
|
+
|
18
|
+
from .helpers import *
|
19
|
+
from .base import *
|
20
|
+
from .memoryless import *
|
21
|
+
from .min_max import *
|