compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- compressed_tensors/base.py +3 -1
- compressed_tensors/compressors/__init__.py +9 -1
- compressed_tensors/compressors/base.py +12 -55
- compressed_tensors/compressors/dense.py +5 -5
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/marlin_24.py +251 -0
- compressed_tensors/compressors/model_compressor.py +336 -0
- compressed_tensors/compressors/naive_quantized.py +144 -0
- compressed_tensors/compressors/pack_quantized.py +219 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/config/base.py +9 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -31
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +214 -62
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +62 -5
- compressed_tensors/quantization/observers/base.py +66 -23
- compressed_tensors/quantization/observers/helpers.py +69 -11
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +47 -3
- compressed_tensors/quantization/quant_config.py +104 -23
- compressed_tensors/quantization/quant_scheme.py +183 -2
- compressed_tensors/quantization/utils/helpers.py +142 -8
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +54 -7
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/utils/permutations_24.py +65 -0
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py:

```diff
@@ -14,18 +14,28 @@
 
 from functools import wraps
 from math import ceil
+from typing import Optional
 
 import torch
+from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import update_parameter_data
 from torch.nn import Module
 
 
-__all__ = [
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]
 
 
 @torch.no_grad()
@@ -33,14 +43,39 @@ def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-… (old lines 36-37 not captured)
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the group_size must be divisible by the column size. The input scale
+    and zero_points are reshaped to support vectorization (Assumes 1 is the
+    channel dimension)
 
-… (old lines 40-43 not captured)
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: fake quantized tensor
+    """
+    # ensure all tensors are on the same device
+    # assumes that the target device is the input
+    # tensor's device
+    if x.device != scale.device:
+        scale = scale.to(x.device)
+    if x.device != zero_point.device:
+        zero_point = zero_point.to(x.device)
+
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
     )
 
 
```
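For orientation, a minimal sketch (not part of the diff) of how the new `quantize` signature might be called once 0.5.0 is installed. The tensor values, scale, and zero point below are invented, and `QuantizationArgs` is left at its defaults apart from the strategy:

```python
import torch

from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)

# hand-picked example values; any 2D float tensor works the same way
x = torch.tensor([[0.10, -0.25, 0.40, 0.05]])
scale = torch.tensor(0.02)
zero_point = torch.tensor(0)

# per-tensor quantization; the other QuantizationArgs fields keep their defaults
args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)

# routes through _process_quantization with do_quantize=True, do_dequantize=False
x_q = quantize(x, scale, zero_point, args, dtype=torch.int8)
```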
compressed_tensors/quantization/lifecycle/forward.py (continued):

```diff
@@ -48,9 +83,50 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    args: QuantizationArgs = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-… (old line 53 not captured)
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
+    args is not provided, the strategy will be inferred.
+
+    :param x: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :param dtype: optional dtype to cast the dequantized output to
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0 or scale.ndim == 1:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            if scale.shape[1] == 1:
+                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+            else:
+                group_size = int(x_q.shape[1] / scale.shape[1])
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
+        else:
+            raise ValueError(
+                f"Could not infer a quantization strategy from scale with {scale.ndim} "
+                "dimmensions. Expected 0 or 2 dimmensions."
+            )
+
+    if dtype is None:
+        dtype = scale.dtype
+
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+        dtype=dtype,
+    )
 
 
 @torch.no_grad()
```
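The strategy inference in `dequantize` keys entirely off the shape of `scale`. A rough illustration of the three recognized cases, with shapes chosen arbitrarily for a quantized tensor `x_q` of shape `(512, 4096)`:

```python
import torch

scale_tensor = torch.tensor(0.02)    # 0-dim    -> QuantizationStrategy.TENSOR
scale_channel = torch.rand(512, 1)   # (512, 1)  -> QuantizationStrategy.CHANNEL
scale_group = torch.rand(512, 32)    # (512, 32) -> QuantizationStrategy.GROUP,
                                     # group_size inferred as 4096 // 32 == 128
```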
compressed_tensors/quantization/lifecycle/forward.py (continued):

```diff
@@ -61,30 +137,45 @@ def fake_quantize(
     args: QuantizationArgs,
 ) -> torch.Tensor:
     """
-    Fake quantize the input tensor x
-… (old lines 65-68 not captured)
-    the channel dimension)
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the group_size must be divisible
+    by the column size. The input scale and zero_points are reshaped to support
+    vectorization (Assumes 1 is the channel dimension)
 
     :param x: Input tensor
     :param scale: scale tensor
     :param zero_point: zero point tensor
-    :param args: quantization args
+    :param args: quantization args dictating how to quantize x
     :return: fake quantized tensor
-… (old line 76 not captured)
     """
-… (old lines 78-80 not captured)
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
+
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
+) -> torch.Tensor:
 
+    q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    # group
     if args.strategy == QuantizationStrategy.GROUP:
-… (old lines 86-87 not captured)
+        output_dtype = dtype if dtype is not None else x.dtype
+        output = torch.zeros_like(x).to(output_dtype)
 
         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
@@ -94,7 +185,7 @@ def fake_quantize(
         while scale.ndim < 2:
             # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
-            zero_point = zero_point.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
         columns = x.shape[1]
         if columns >= group_size:
@@ -106,51 +197,60 @@ def fake_quantize(
         for i in range(ceil(columns / group_size)):
             # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
-… (old lines 109-110 not captured)
-            zp = zero_point[:, i].unsqueeze(1)
+            sc = scale[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
 
             idx = i * group_size
-… (old lines 114-142 not captured)
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)],
+                    sc,
+                    zp,
+                    q_min,
+                    q_max,
+                    args,
+                    dtype=dtype,
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+
+    else:  # covers channel, token and tensor strategies
+        if do_quantize:
+            output = _quantize(
+                x,
+                scale,
+                zero_point,
+                q_min,
+                q_max,
+                args,
+                dtype=dtype,
+            )
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
 
-    return
+    return output
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     # expects a module already initialized and injected with the parameters in
     # initialize_module_for_quantization
-… (old line 150 not captured)
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
+        if not getattr(module, "quantization_enabled", True):
+            # quantization is disabled on forward passes, return baseline
+            # forward call
+            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
         input_ = args[0]
 
         if scheme.input_activations is not None:
```
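To make the group bookkeeping in the loop above concrete, a small standalone sketch (values invented) of how each column block maps onto one column of the scale tensor:

```python
from math import ceil

columns, group_size = 8, 4

for i in range(ceil(columns / group_size)):  # two groups: i = 0 and i = 1
    idx = i * group_size
    # group i covers x[:, idx : idx + group_size] and is scaled by scale[:, i]
    print(i, (idx, idx + group_size))        # prints "0 (0, 4)" then "1 (4, 8)"
```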
compressed_tensors/quantization/lifecycle/forward.py (continued):

```diff
@@ -199,6 +299,11 @@ def maybe_calibrate_or_quantize(
     }:
         return value
 
+    if value.numel() == 0:
+        # if the tensor is empty,
+        # skip quantization
+        return value
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
@@ -208,14 +313,61 @@ def maybe_calibrate_or_quantize(
     scale = getattr(module, f"{base_name}_scale")
     zero_point = getattr(module, f"{base_name}_zero_point")
 
-    if
+    if (
+        module.quantization_status == QuantizationStatus.CALIBRATION
+        and base_name != "weight"
+    ):
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")
 
         updated_scale, updated_zero_point = observer(value)
 
         # update scale and zero point
-… (old lines 218-220 not captured)
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+
     return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+
+    scaled = x / scale + zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
+        q_min,
+        q_max,
+    )
+    quantized_value = round_to_quantized_type(clamped_value, args)
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+
+    dequant_value = x_q
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value.to(scale.dtype) * scale
+
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
+
+    return dequant_value
```
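`_quantize` and `_dequantize` implement the usual affine round trip. A plain-torch sketch of the same arithmetic for an 8-bit integer range, with made-up values and `round_to_quantized_type` approximated by `torch.round`:

```python
import torch

x = torch.tensor([0.30, -0.12, 0.049])
scale, zero_point = torch.tensor(0.02), torch.tensor(0)
q_min, q_max = -128, 127

# _quantize: scale and shift, clamp to the representable range, then round
x_q = torch.round(torch.clamp(x / scale + zero_point, q_min, q_max))
# -> tensor([15., -6.,  2.])

# _dequantize: undo the shift and rescale
x_hat = (x_q - zero_point) * scale
# -> tensor([ 0.3000, -0.1200,  0.0400]); 0.049 is not exactly representable
```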
compressed_tensors/quantization/lifecycle/frozen.py:

```diff
@@ -35,6 +35,10 @@ def freeze_module_quantization(module: Module):
         # no quantization scheme nothing to do
         return
 
+    if module.quantization_status == QuantizationStatus.FROZEN:
+        # nothing to do, already frozen
+        return
+
     # delete observers from module if not dynamic
     if scheme.input_activations and not scheme.input_activations.dynamic:
         delattr(module, "input_observer")
```
New file compressed_tensors/quantization/lifecycle/helpers.py:

```diff
@@ -0,0 +1,53 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Miscelaneous helpers for the quantization lifecycle
+"""
+
+
+from torch.nn import Module
+
+
+__all__ = [
+    "update_layer_weight_quant_params",
+    "enable_quantization",
+    "disable_quantization",
+]
+
+
+def update_layer_weight_quant_params(layer: Module):
+    weight = getattr(layer, "weight", None)
+    scale = getattr(layer, "weight_scale", None)
+    zero_point = getattr(layer, "weight_zero_point", None)
+    observer = getattr(layer, "weight_observer", None)
+
+    if weight is None or observer is None or scale is None or zero_point is None:
+        # scale, zp, or observer not calibratable or weight not available
+        return
+
+    updated_scale, updated_zero_point = observer(weight)
+
+    # update scale and zero point
+    device = next(layer.parameters()).device
+    scale.data = updated_scale.to(device)
+    zero_point.data = updated_zero_point.to(device)
+
+
+def enable_quantization(module: Module):
+    module.quantization_enabled = True
+
+
+def disable_quantization(module: Module):
+    module.quantization_enabled = False
```
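The `enable_quantization`/`disable_quantization` toggles pair with the `quantization_enabled` check added to `wrapped_forward` in forward.py above. A small sketch of the attribute they flip (the `Linear` here is just a stand-in for a quantization-wrapped module):

```python
import torch

from compressed_tensors.quantization.lifecycle.helpers import (
    disable_quantization,
    enable_quantization,
)

layer = torch.nn.Linear(8, 8)  # stand-in for a module wrapped for quantization

disable_quantization(layer)    # wrapped_forward would now return the baseline call
assert layer.quantization_enabled is False

enable_quantization(layer)     # quantized forward path is active again
assert layer.quantization_enabled is True
```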
compressed_tensors/quantization/lifecycle/initialize.py:

```diff
@@ -17,12 +17,18 @@ import logging
 from typing import Optional
 
 import torch
+from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+from accelerate.utils import PrefixedDataset
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter
 
 
@@ -58,7 +64,12 @@ def initialize_module_for_quantization(
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-… (old line 61 not captured)
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -73,12 +84,38 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
 
+    offloaded = False
+    if is_module_offloaded(module):
+        offloaded = True
+        hook = module._hf_hook
+        prefix_dict = module._hf_hook.weights_map
+        new_prefix = {}
+
+        # recreate the prefix dict (since it is immutable)
+        # and add quantization parameters
+        for key, data in module.named_parameters():
+            if key not in prefix_dict:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+            else:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+        remove_hook_from_module(module)
+
     # wrap forward call of module to perform quantized actions based on calltime status
     wrap_module_forward_quantized(module, scheme)
 
+    if offloaded:
+        # we need to re-add the hook for offloading now that we've wrapped forward
+        add_hook_to_module(module, hook)
+        if prefix_dict is not None:
+            module._hf_hook.weights_map = new_prefix_dict
+
 
 def _initialize_scale_zero_point_observer(
-    module: Module,
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -88,12 +125,32 @@ def _initialize_scale_zero_point_observer(
         return  # no need to register a scale and zero point for a dynamic observer
 
     device = next(module.parameters()).device
+    if is_module_offloaded(module):
+        device = get_execution_device(module)
+
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
 
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
+    )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
+    zp_dtype = quantization_args.pytorch_dtype()
     init_zero_point = Parameter(
-        torch.empty(
+        torch.empty(expected_shape, device=device, dtype=zp_dtype),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
```
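A quick check of the shapes this logic yields for a hypothetical `torch.nn.Linear` weight of shape `(512, 4096)` with `group_size=128`:

```python
weight_shape = (512, 4096)  # hypothetical (out_features, in_features)
group_size = 128

per_tensor = 1                                                 # TENSOR strategy
per_channel = (weight_shape[0], 1)                             # CHANNEL -> (512, 1)
per_group = (weight_shape[0], weight_shape[1] // group_size)   # GROUP   -> (512, 32)
```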