compressed-tensors 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +38 -102
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +95 -106
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/quantized_compressors/base.py +146 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
- compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/linear/compressed_linear.py +2 -2
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +19 -3
- compressed_tensors/quantization/lifecycle/calibration.py +2 -3
- compressed_tensors/quantization/lifecycle/forward.py +58 -7
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -47
- compressed_tensors/quantization/lifecycle/initialize.py +116 -67
- compressed_tensors/quantization/observers/__init__.py +0 -1
- compressed_tensors/quantization/observers/helpers.py +40 -2
- compressed_tensors/quantization/quant_args.py +34 -4
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +8 -4
- compressed_tensors/quantization/utils/helpers.py +43 -18
- compressed_tensors/utils/helpers.py +17 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/METADATA +1 -1
- compressed_tensors-0.7.1.dist-info/RECORD +58 -0
- compressed_tensors/quantization/observers/memoryless.py +0 -56
- compressed_tensors-0.6.0.dist-info/RECORD +0 -52
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/cache.py (new file)
@@ -0,0 +1,201 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from compressed_tensors.quantization.observers import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import Tensor
+from transformers import DynamicCache as HFDyanmicCache
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+class QuantizedKVParameterCache(HFDyanmicCache):
+
+    """
+    Quantized KV cache used in the forward call based on HF's dynamic cache.
+    Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
+    Singleton, so that the same cache gets reused in all forward call of self_attn.
+    Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
+    gets called appropriately.
+    The size of tensor is
+    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+
+
+    Triggered by adding kv_cache_scheme in the recipe.
+
+    Example:
+
+    ```python3
+    recipe = '''
+    quant_stage:
+        quant_modifiers:
+            QuantizationModifier:
+                kv_cache_scheme:
+                    num_bits: 8
+                    type: float
+                    strategy: tensor
+                    dynamic: false
+                    symmetric: true
+    '''
+
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        """Singleton"""
+        if cls._instance is None:
+            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        if not self._initialized:
+            super().__init__()
+
+            self.quantization_args = quantization_args
+
+            self.k_observers: List[Observer] = []
+            self.v_observers: List[Observer] = []
+
+            # each index corresponds to layer_idx of the attention layer
+            self.k_scales: List[Tensor] = []
+            self.v_scales: List[Tensor] = []
+
+            self.k_zps: List[Tensor] = []
+            self.v_zps: List[Tensor] = []
+
+            self._initialized = True
+
+    def update(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the k_scale and v_scale and output the
+        fakequant-ed key_states and value_states
+        """
+
+        if len(self.k_observers) <= layer_idx:
+            k_observer = self.quantization_args.get_observer()
+            v_observer = self.quantization_args.get_observer()
+
+            self.k_observers.append(k_observer)
+            self.v_observers.append(v_observer)
+
+        q_key_states = self._quantize(
+            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
+        )
+        q_value_states = self._quantize(
+            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
+        )
+
+        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
+        qdq_value_states = self._dequantize(
+            q_value_states, KVCacheScaleType.VALUE, layer_idx
+        )
+
+        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """
+        Returns the sequence length of the cached states.
+        A layer index can be optionally passed.
+        """
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and
+        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
+        # this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to
+        # verify attn_weight shape in some models
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def reset_states(self):
+        """reset the kv states (used in calibration)"""
+        self.key_cache: List[Tensor] = []
+        self.value_cache: List[Tensor] = []
+        # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0
+        self._quantized_key_cache: List[Tensor] = []
+        self._quantized_value_cache: List[Tensor] = []
+
+    def reset(self):
+        """
+        Reset the instantiation, create new instance on init
+        """
+        QuantizedKVParameterCache._instance = None
+        QuantizedKVParameterCache._initialized = False
+
+    def _quantize(self, tensor, kv_type, layer_idx):
+        """Quantizes a key/value using a defined quantization method."""
+        from compressed_tensors.quantization.lifecycle.forward import quantize
+
+        if kv_type == KVCacheScaleType.KEY:  # key type
+            observer = self.k_observers[layer_idx]
+            scales = self.k_scales
+            zps = self.k_zps
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            observer = self.v_observers[layer_idx]
+            scales = self.v_scales
+            zps = self.v_zps
+
+        scale, zp = observer(tensor)
+        if len(scales) <= layer_idx:
+            scales.append(scale)
+            zps.append(zp)
+        else:
+            scales[layer_idx] = scale
+            zps[layer_idx] = zp
+
+        q_tensor = quantize(
+            x=tensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return q_tensor
+
+    def _dequantize(self, qtensor, kv_type, layer_idx):
+        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+        from compressed_tensors.quantization.lifecycle.forward import dequantize
+
+        if kv_type == KVCacheScaleType.KEY:
+            scale = self.k_scales[layer_idx]
+            zp = self.k_zps[layer_idx]
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            scale = self.v_scales[layer_idx]
+            zp = self.v_zps[layer_idx]
+
+        qdq_tensor = dequantize(
+            x_q=qtensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return qdq_tensor
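For orientation, the snippet below is a minimal hand-written sketch of driving the new cache directly; it is not taken from the package's docs or tests. The constructor and `update()` signature are the ones shown in the diff above, while the tensor shapes and the `QuantizationArgs` values (mirroring the docstring's `kv_cache_scheme` recipe) are illustrative assumptions.

```python
import torch

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.cache import QuantizedKVParameterCache

# kv_cache_scheme from the docstring example, expressed as QuantizationArgs
args = QuantizationArgs(
    num_bits=8, type="float", strategy="tensor", dynamic=False, symmetric=True
)
cache = QuantizedKVParameterCache(args)  # singleton: later constructions reuse it

# illustrative key/value states: [batch_size, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 8, 16, 64)
value_states = torch.randn(1, 8, 16, 64)

# returns fake-quantized (quantize -> dequantize) states and records the
# per-layer k_scale / v_scale for layer_idx 0
qdq_keys, qdq_values = cache.update(key_states, value_states, layer_idx=0)
print(cache.k_scales[0], cache.v_scales[0])

cache.reset()  # drop the singleton so the next init builds a fresh cache
```

In normal use the cache is not driven by hand: `wrap_module_forward_quantized_attn` (added to `forward.py` below) injects it as `past_key_value` during calibration and copies the resulting scales onto the attention module's `k_scale`/`v_scale` parameters.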
compressed_tensors/quantization/lifecycle/apply.py
@@ -43,6 +43,7 @@ from compressed_tensors.quantization.utils import (
     infer_quantization_status,
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
+    iter_named_quantizable_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
@@ -106,8 +107,8 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
 
 
 def apply_quantization_config(
-    model: Module, config: QuantizationConfig, run_compressed: bool = False
-) ->
+    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config
 
@@ -116,6 +117,10 @@ def apply_quantization_config(
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
+    # Workaround for when HF Quantizer passes None, see PR #180
+    if config is None:
+        return OrderedDict()
+
     # remove reference to the original `config`
     # argument. This function can mutate it, and we'd
     # like to keep the original `config` as it is.
@@ -135,15 +140,23 @@
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_leaf_modules(model):
+    for name, submodule in iter_named_quantizable_modules(
+        model,
+        include_children=True,
+        include_attn=True,
+    ):  # child modules and attention modules
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
             for match in matches:
                 ignored_submodules[match].append(name)
             continue  # layer matches ignore list, continue
+
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+
         if targets:
+            # mark modules to be quantized by adding
+            # quant scheme to the matching layers
             scheme = _scheme_from_targets(target_to_scheme, targets, name)
             if run_compressed:
                 format = config.format
@@ -200,6 +213,9 @@ def process_kv_cache_config(
     :param config: the QuantizationConfig
     :return: the QuantizationConfig with additional "kv_cache" group
     """
+    if targets == KV_CACHE_TARGETS:
+        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
+
     kv_cache_dict = config.kv_cache_scheme.model_dump()
     kv_cache_scheme = QuantizationScheme(
         output_activations=QuantizationArgs(**kv_cache_dict),
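One user-visible effect of the `apply.py` changes is the `None` handling: when the HF quantizer hands over no config, the call now returns an empty mapping instead of failing. A minimal sketch (the model is arbitrary; only the workaround shown above is exercised):

```python
from collections import OrderedDict

import torch
from compressed_tensors.quantization import apply_quantization_config

model = torch.nn.Sequential(torch.nn.Linear(16, 16))

# HF Quantizer may pass config=None; per the workaround above this is a no-op
result = apply_quantization_config(model, config=None)
assert result == OrderedDict()  # model is left untouched
```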
compressed_tensors/quantization/lifecycle/calibration.py
@@ -56,10 +56,9 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
         observer = module.weight_observer
         g_idx = getattr(module, "weight_g_idx", None)
 
-        offloaded = False
-        if is_module_offloaded(module):
+        offloaded = is_module_offloaded(module)
+        if offloaded:
             module._hf_hook.pre_forward(module)
-            offloaded = True
 
         scale, zero_point = observer(module.weight, g_idx=g_idx)
         update_parameter_data(module, scale, "weight_scale")
compressed_tensors/quantization/lifecycle/forward.py
@@ -14,10 +14,14 @@
 
 from functools import wraps
 from math import ceil
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
-from compressed_tensors.quantization.observers.helpers import calculate_range
+from compressed_tensors.quantization.cache import QuantizedKVParameterCache
+from compressed_tensors.quantization.observers.helpers import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -62,6 +66,7 @@ def quantize(
     :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
+
     return _process_quantization(
         x=x,
         scale=scale,
@@ -165,8 +170,8 @@ def _process_quantization(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-    g_idx: Optional[torch.Tensor],
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
@@ -266,6 +271,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
         input_ = args[0]
+
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
@@ -285,9 +291,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-
         if scheme.output_activations is not None:
+
             # calibrate and (fake) quantize output activations when applicable
+            # kv_cache scales updated on model self_attn forward call in
+            # wrap_module_forward_quantized_attn
             output = maybe_calibrate_or_quantize(
                 module, output, "output", scheme.output_activations
             )
@@ -304,6 +312,50 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)
 
 
+def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
+    # expects a module already initialized and injected with the parameters in
+    # initialize_module_for_quantization
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
+
+    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+    def wrapped_forward(self, *args, **kwargs):
+
+        # kv cache stored under weights
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            quantization_args: QuantizationArgs = scheme.output_activations
+            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
+            kwargs["past_key_value"] = past_key_value
+
+            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
+            # does not store quantized_key_states and quantized_value_state
+            kwargs["use_cache"] = False
+
+            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
+
+            past_key_value.reset_states()
+
+            rtn = attn_forward(*args, **kwargs)
+
+            update_parameter_data(
+                module, past_key_value.k_scales[module.layer_idx], "k_scale"
+            )
+            update_parameter_data(
+                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+            )
+
+            return rtn
+
+        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
+    # bind wrapped forward to module class so reference to `self` is correct
+    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+    # set forward to wrapped forward
+    setattr(module, "forward", bound_wrapped_forward)
+
+
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
@@ -327,9 +379,8 @@ def maybe_calibrate_or_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization - get scale and zero point directly from observer
-        observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value, g_idx=g_idx)
+        # dynamic quantization - no need to invoke observer
+        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
compressed_tensors/quantization/lifecycle/frozen.py
@@ -14,6 +14,7 @@
 
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from torch.nn import Module
 
 
@@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module):
         delattr(module, "input_observer")
     if scheme.weights and not scheme.weights.dynamic:
         delattr(module, "weight_observer")
-    if scheme.output_activations and not scheme.output_activations.dynamic:
+    if (
+        scheme.output_activations
+        and not is_kv_cache_quant_scheme(scheme)
+        and not scheme.output_activations.dynamic
+    ):
         delattr(module, "output_observer")
 
     module.quantization_status = QuantizationStatus.FROZEN
compressed_tensors/quantization/lifecycle/helpers.py
@@ -16,62 +16,15 @@
 Miscelaneous helpers for the quantization lifecycle
 """
 
-from typing import Optional
-
-import torch
 from torch.nn import Module
 
 
 __all__ = [
-    "update_layer_weight_quant_params",
     "enable_quantization",
     "disable_quantization",
 ]
 
 
-def update_layer_weight_quant_params(
-    layer: Module,
-    weight: Optional[torch.Tensor] = None,
-    g_idx: Optional[torch.Tensor] = None,
-    reset_obs: bool = False,
-):
-    """
-    Update quantization parameters on layer
-
-    :param layer: input layer
-    :param weight: weight to update quant params with, defaults to layer weight
-    :param g_idx: optional mapping from column index to group index
-    :param reset_obs: reset the observer before calculating quant params,
-        defaults to False
-    """
-    attached_weight = getattr(layer, "weight", None)
-
-    if weight is None:
-        weight = attached_weight
-    scale = getattr(layer, "weight_scale", None)
-    zero_point = getattr(layer, "weight_zero_point", None)
-    if g_idx is None:
-        g_idx = getattr(layer, "weight_g_idx", None)
-    observer = getattr(layer, "weight_observer", None)
-
-    if weight is None or observer is None or scale is None or zero_point is None:
-        # scale, zp, or observer not calibratable or weight not available
-        return
-
-    if reset_obs:
-        observer.reset()
-
-    if attached_weight is not None:
-        weight = weight.to(attached_weight.dtype)
-
-    updated_scale, updated_zero_point = observer(weight)
-
-    # update scale and zero point
-    device = next(layer.parameters()).device
-    scale.data = updated_scale.to(device)
-    zero_point.data = updated_zero_point.to(device)
-
-
 def enable_quantization(module: Module):
     module.quantization_enabled = True
 