mct-nightly 2.2.0.20240917.426__py3-none-any.whl → 2.2.0.20240919.455__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240919.455.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240919.455.dist-info}/RECORD +30 -20
- {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240919.455.dist-info}/top_level.txt +1 -0
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +3 -0
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -1
- model_compression_toolkit/core/keras/reader/node_builder.py +23 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +13 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +12 -3
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +10 -1
- model_compression_toolkit/gptq/__init__.py +17 -5
- model_compression_toolkit/gptq/common/gptq_config.py +88 -75
- model_compression_toolkit/gptq/pytorch/gptq_training.py +18 -9
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +49 -29
- model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py +80 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +10 -10
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +6 -49
- model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py +39 -0
- model_compression_toolkit/trainable_infrastructure/pytorch/util.py +29 -0
- tests_pytest/__init__.py +14 -0
- tests_pytest/pytorch/__init__.py +14 -0
- tests_pytest/pytorch/gptq/__init__.py +14 -0
- tests_pytest/pytorch/gptq/test_annealing_cfg.py +40 -0
- tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +100 -0
- tests_pytest/pytorch/trainable_infrastructure/__init__.py +14 -0
- tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +49 -0
- {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240919.455.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240919.455.dist-info}/WHEEL +0 -0
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -21,6 +21,8 @@ import copy
 import torch
 
 from model_compression_toolkit.core.common.hessian import HessianInfoService
+from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \
+    get_gradual_activation_quantizer_wrapper_factory
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
@@ -36,6 +38,7 @@ from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable
 from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
 from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
 from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
+from model_compression_toolkit.trainable_infrastructure.pytorch.util import get_total_grad_steps
 
 
 class PytorchGPTQTrainer(GPTQTrainer):
@@ -66,6 +69,13 @@ class PytorchGPTQTrainer(GPTQTrainer):
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """
+        def _get_total_grad_steps():
+            return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs
+
+        # must be set prior to model building in the base class constructor
+        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
+            gptq_config, _get_total_grad_steps)
+
         super().__init__(graph_float,
                          graph_quant,
                          gptq_config,
@@ -98,7 +108,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
         self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights())
 
-        self.reg_func = get_regularization(self.gptq_config,
+        self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)
 
     def _is_gptq_weights_trainable(self,
                                    node: BaseNode) -> bool:
@@ -145,7 +155,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
     def get_activation_quantizer_holder(self, n: BaseNode) -> Callable:
         """
         Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
-        If the layer is not supposed to be wrapped with an activation quantizer - return None.
         Args:
             n: Node to attach a PytorchActivationQuantizationHolder to its output.
         Returns:
@@ -153,13 +162,13 @@ class PytorchGPTQTrainer(GPTQTrainer):
         """
         _, activation_quantizers = quantization_builder(n, self.gptq_config)
         # Holder by definition uses a single quantizer for the activation quantization
-        # thus we make sure this is the only possible case
-
-
-
-
-
-
+        # thus we make sure this is the only possible case
+        if len(activation_quantizers) != 1:
+            Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, "
+                            f"but {len(activation_quantizers)} were found for node {n.name}. "
+                            f"Ensure the node is configured with a single activation quantizer.")
+        quantizer = self.gradual_act_quantizer_wrapper_factory(activation_quantizers[0])
+        return PytorchActivationQuantizationHolder(quantizer)
 
     def build_gptq_model(self):
         """
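The `_get_total_grad_steps` closure added above defers the (potentially expensive) dry run over the representative dataset until a consumer actually asks for the step count. A minimal sketch of the same deferral pattern, using a dummy data generator rather than the MCT classes (all names here are illustrative):

from functools import lru_cache

def make_total_steps_fn(data_gen, n_epochs):
    @lru_cache
    def total_grad_steps():
        # the dry run over data_gen happens only on the first invocation
        return sum(1 for _ in data_gen()) * n_epochs
    return total_grad_steps

steps_fn = make_total_steps_fn(lambda: iter([1, 2, 3]), n_epochs=2)
print(steps_fn())  # 6: three batches, two epochs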
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -13,26 +13,26 @@
 # limitations under the License.
 # ==============================================================================
 import copy
+from typing import Callable, Union
 
-from
-from model_compression_toolkit.core import
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
+from model_compression_toolkit.core import CoreConfig
+from model_compression_toolkit.core.analyzer import analyzer_model_quantization
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
+    MixedPrecisionQuantizationConfig
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
+    ResourceUtilization
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
-from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
-from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
-from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.runner import core_runner
+from model_compression_toolkit.gptq.common.gptq_config import (
+    GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig)
+from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
 from model_compression_toolkit.gptq.runner import gptq_runner
-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.
-
-from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.metadata import create_model_metadata
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
+from model_compression_toolkit.verify_packages import FOUND_TORCH
 
 LR_DEFAULT = 1e-4
 LR_REST_DEFAULT = 1e-4
@@ -53,33 +53,38 @@ if FOUND_TORCH:
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
 
     def get_pytorch_gptq_config(n_epochs: int,
-                                optimizer: Optimizer =
-                                optimizer_rest: Optimizer =
+                                optimizer: Optimizer = None,
+                                optimizer_rest: Optimizer = None,
                                 loss: Callable = multiple_tensors_mse_loss,
                                 log_function: Callable = None,
                                 use_hessian_based_weights: bool = True,
                                 regularization_factor: float = REG_DEFAULT,
-                                hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
+                                hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                                gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False,
                                 ) -> GradientPTQConfig:
         """
-        Create a
+        Create a GradientPTQConfig instance for Pytorch models.
 
         args:
             n_epochs (int): Number of epochs for running the representative dataset for fine-tuning.
             optimizer (Optimizer): Pytorch optimizer to use for fine-tuning for auxiliry variable.
             optimizer_rest (Optimizer): Pytorch optimizer to use for fine-tuning of the bias variable.
-            loss (Callable): loss to use during fine-tuning.
+            loss (Callable): loss to use during fine-tuning. See the default loss function for the exact interface.
             log_function (Callable): Function to log information about the gptq process.
             use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             regularization_factor (float): A floating point number that defines the regularization factor.
             hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+            gradual_activation_quantization (bool, GradualActivationQuantizationConfig):
+                If False, GradualActivationQuantization is disabled.
+                If True, GradualActivationQuantization is enabled with the default settings.
+                GradualActivationQuantizationConfig object can be passed to use non-default settings.
 
         returns:
-            a
+            a GradientPTQConfig object to use when fine-tuning the quantized model using gptq.
 
         Examples:
 
-            Import MCT and Create a
+            Import MCT and Create a GradientPTQConfig to run for 5 epochs:
 
             >>> import model_compression_toolkit as mct
             >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
@@ -89,16 +94,31 @@ if FOUND_TORCH:
             >>> import torch
             >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
 
-
+            To enable Gradual Activation Quantization with non-default settings build GradualActivationQuantizationConfig:
+            >>> gradual_act_conf = mct.gptq.GradualActivationQuantizationConfig(mct.gptq.QFractionLinearAnnealingConfig(initial_q_fraction=0.2))
+            >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, gradual_activation_quantization=gradual_act_conf)
+            The configuration can be passed to :func:`~model_compression_toolkit.pytorch_gradient_post_training_quantization` in order to quantize a pytorch model using gptq.
 
         """
+        optimizer = optimizer or Adam([torch.Tensor([])], lr=LR_DEFAULT)
+        optimizer_rest = optimizer_rest or Adam([torch.Tensor([])], lr=LR_REST_DEFAULT)
+
         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
+
+        if isinstance(gradual_activation_quantization, bool):
+            gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
+        elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig):
+            gradual_quant_config = gradual_activation_quantization
+        else:
+            raise TypeError(f'gradual_activation_quantization argument should be bool or '
+                            f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')    # pragma: no cover
+
         return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
                                  log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
                                  use_hessian_based_weights=use_hessian_based_weights,
                                  regularization_factor=regularization_factor,
-                                 hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)
-
+                                 hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size),
+                                 gradual_activation_quantization_config=gradual_quant_config)
 
     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
@@ -222,11 +242,11 @@ if FOUND_TORCH:
 else:
     # If torch is not installed,
     # we raise an exception when trying to use these functions.
-    def get_pytorch_gptq_config(*args, **kwargs):
+    def get_pytorch_gptq_config(*args, **kwargs):  # pragma: no cover
         Logger.critical("PyTorch must be installed to use 'get_pytorch_gptq_config'. "
-                        "The 'torch' package is missing.")
+                        "The 'torch' package is missing.")
 
 
-    def pytorch_gradient_post_training_quantization(*args, **kwargs):
+    def pytorch_gradient_post_training_quantization(*args, **kwargs):  # pragma: no cover
         Logger.critical("PyTorch must be installed to use 'pytorch_gradient_post_training_quantization'. "
-                        "The 'torch' package is missing.")
+                        "The 'torch' package is missing.")
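For orientation, a rough end-to-end sketch of how the new argument is consumed; the toy model and data generator are placeholders, not from this diff, and passing True selects the default GradualActivationQuantizationConfig:

import torch
import model_compression_toolkit as mct

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())  # placeholder model

def data_gen():
    yield [torch.randn(4, 3, 32, 32)]  # one batch of representative data

gptq_cfg = mct.gptq.get_pytorch_gptq_config(n_epochs=3, gradual_activation_quantization=True)
qmodel, _ = mct.gptq.pytorch_gradient_post_training_quantization(model, data_gen, gptq_config=gptq_cfg)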
model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py
ADDED
@@ -0,0 +1,80 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from functools import partial
+from typing import Callable
+
+from model_compression_toolkit.gptq import GradientPTQConfig, QFractionLinearAnnealingConfig
+from model_compression_toolkit.trainable_infrastructure import BasePytorchTrainableQuantizer
+
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
+
+
+def get_gradual_activation_quantizer_wrapper_factory(gptq_config: GradientPTQConfig,
+                                                     get_total_grad_steps_fn: Callable[[], int]) \
+        -> Callable[[BasePytorchTrainableQuantizer], 'GradualActivationQuantizerWrapper']:
+    """
+    Get a factory for 'GradualActivationQuantizerWrapper'.
+
+    Args:
+        gptq_config: GPTQ configuration.
+        get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps.
+
+    Returns:
+        A factory function to build 'GradualActivationQuantizerWrapper' from Quantizer.
+    """
+    if gptq_config.gradual_activation_quantization_config is None:
+        return lambda q: q
+
+    annealing_cfg = gptq_config.gradual_activation_quantization_config.q_fraction_scheduler_policy
+    if isinstance(annealing_cfg, QFractionLinearAnnealingConfig):
+        t_end = annealing_cfg.end_step or get_total_grad_steps_fn()
+        factor_scheduler = LinearAnnealingScheduler(t_start=annealing_cfg.start_step, t_end=t_end,
+                                                    initial_val=annealing_cfg.initial_q_fraction,
+                                                    target_val=annealing_cfg.target_q_fraction)
+    else:
+        raise ValueError(f'Unknown annealing policy {annealing_cfg}')
+
+    return partial(GradualActivationQuantizerWrapper, q_fraction_scheduler=factor_scheduler)
+
+
+class GradualActivationQuantizerWrapper:
+    # TODO update paper's url
+    """
+    Quantizer wrapper for Gradual Activation Quantization training (https://arxiv.org/abs/2309.11531).
+
+    It computes the weighted sum of the float activation 'x' and the quantized activation 'q(x)':
+
+      out = (1 - q_fraction) * x + q_fraction * q(x)
+
+    where 'q_fraction' is a tensor fraction to quantize in the range [0, 1] provided by a scheduler.
+
+    Args:
+        quantizer: quantizer to wrap.
+        q_fraction_scheduler: a callable that accepts a gradient step and returns the corresponding quantized fraction.
+    """
+    def __init__(self, quantizer: BasePytorchTrainableQuantizer, q_fraction_scheduler: Callable[[int], float]):
+        self.quantizer = quantizer
+        self.q_fraction_scheduler = q_fraction_scheduler
+        self.step_cnt = 0
+
+    def __call__(self, x, training: bool = True):
+        q_fraction = self.q_fraction_scheduler(self.step_cnt)
+        out_q = self.quantizer(x, training)
+        out = (1 - q_fraction) * x + q_fraction * out_q
+        self.step_cnt += 1
+        return out
+
+    def initialize_quantization(self, *args, **kwargs):
+        self.quantizer.initialize_quantization(*args, **kwargs)
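The wrapper's forward pass is plain linear interpolation between the float and quantized activations. A self-contained sketch with a toy rounding function standing in for a trained quantizer (illustrative names only):

import torch

def toy_quantizer(x, training=True):
    return torch.round(x)  # stand-in for a trained activation quantizer

def q_fraction_sched(t, t_end=4):
    return min(t / t_end, 1.0)

x = torch.tensor([0.3, 1.7])
for step in range(5):
    q = q_fraction_sched(step)
    out = (1 - q) * x + q * toy_quantizer(x)
    print(step, [round(v, 3) for v in out.tolist()])
# step 0 -> [0.3, 1.7] (pure float); step 4 -> [0.0, 2.0] (fully quantized)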
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py
@@ -12,33 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from tqdm import tqdm
 from typing import Callable
 
-from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig
 from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
     SoftQuantizerRegularization
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
 
 
-
+WARMUP_STEP_FRACTION = 0.2
+
+def get_regularization(gptq_config: GradientPTQConfig, get_total_grad_steps_fn: Callable[[], int]) -> Callable:
    """
    Returns a function that computes the regularization term for GPTQ training based on the given
    rounding type in the GPTQ configuration.

    Args:
        gptq_config: A GPTQ configuration.
-
+        get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps.

    Returns: A function for computing the regularization. If there is no regularization function defined for the given
    rounding type, then it returns a function that just returns 0.

    """
    if gptq_config.rounding_type == RoundingType.SoftQuantizer:
-
-
-
-
-
-        return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
+        total_gradient_steps = get_total_grad_steps_fn()
+        t_start = int(WARMUP_STEP_FRACTION * total_gradient_steps)
+        scheduler = LinearAnnealingScheduler(t_start=t_start, t_end=total_gradient_steps, initial_val=20, target_val=2)
+        return SoftQuantizerRegularization(scheduler)
    else:
        return lambda m, e_reg: 0
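Concretely, with this schedule the soft-rounding temperature beta holds at 20 for the first 20% of gradient steps and then decays linearly to 2. A pure-Python sketch of the resulting values (mirroring, not importing, the LinearAnnealingScheduler defined later in this diff):

def beta(t, total_steps=100, warmup_fraction=0.2, initial_val=20.0, target_val=2.0):
    t_start = int(warmup_fraction * total_steps)
    factor = min(max((t - t_start) / (total_steps - t_start), 0.0), 1.0)
    return initial_val + factor * (target_val - initial_val)

print([beta(t) for t in (0, 20, 60, 100)])  # [20.0, 20.0, 11.0, 2.0]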
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -12,57 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import List
+from typing import List, Callable
 
 import torch
-import numpy as np
 from torch import nn
 
+from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from mct_quantizers import PytorchQuantizationWrapper
-
-
-class LinearTempDecay:
-    """
-    Annealing process for the soft quantizer regularization temperature term.
-    """
-
-    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
-        """
-        Initializes a LinearTempDecay object.
-
-        Args:
-            t_max: maximal time step.
-            rel_start_decay: Decay step size at the beginning of the process.
-            start_b: Starting value of the regularization term.
-            end_b: Target value of the regularization term.
-        """
-
-        self.t_max = t_max
-        self.start_decay = rel_start_decay * t_max
-        self.start_b = start_b
-        self.end_b = end_b
-
-    def __call__(self, t: float) -> float:
-        """
-        Cosine annealing scheduler for soft quantizer regularization temperature term.
-
-        Args:
-            t: The current time step.
-
-        Returns: Scheduled temperature.
-        """
-
-        is_before_start_decay = (t < self.start_decay)
-
-        rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
-
-        return self.start_b * is_before_start_decay + \
-               (1 - is_before_start_decay) * \
-               (self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])),
-                                                                         to_torch_tensor(np.array((1 - rel_t)))))
 
 
 class SoftQuantizerRegularization:
@@ -70,16 +27,16 @@ class SoftQuantizerRegularization:
     A class to handle the computation of soft quantizer regularization for GPTQ training.
     """
 
-    def __init__(self,
+    def __init__(self, beta_scheduler: Callable[[int], float]):
         """
         Initializes the regularization computation object with a LinearDecay object.
 
         Args:
-
+            beta_scheduler: a callable that accepts current time step and returns a corresponding beta value.
         """
 
         # Initializing the temperature decay according to the number of expected gradient steps
-        self.
+        self.beta_scheduler = beta_scheduler
 
         self.count_iter = 0
 
@@ -95,7 +52,7 @@ class SoftQuantizerRegularization:
         """
 
         soft_reg_aux: List[torch.Tensor] = []
-        b = self.
+        b = self.beta_scheduler(self.count_iter)
         for layer in model.modules():
            if isinstance(layer, PytorchQuantizationWrapper):
                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
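The removed LinearTempDecay and the LinearAnnealingScheduler that replaces it compute the same clipped linear ramp; the refactor mainly drops the numpy/tensor round-trip. A quick numeric check with pure-Python re-implementations of both formulas (illustrative only, not the MCT classes):

def old_temp_decay(t, t_max=100, rel_start_decay=0.2, start_b=20, end_b=2):
    start_decay = rel_start_decay * t_max
    if t < start_decay:
        return float(start_b)
    rel_t = (t - start_decay) / (t_max - start_decay)
    return end_b + (start_b - end_b) * max(0.0, 1 - rel_t)

def linear_annealing(t, t_start=20, t_end=100, initial_val=20.0, target_val=2.0):
    factor = min(max((t - t_start) / (t_end - t_start), 0.0), 1.0)
    return initial_val + factor * (target_val - initial_val)

assert all(abs(old_temp_decay(t) - linear_annealing(t)) < 1e-9 for t in range(101))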
model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py
ADDED
@@ -0,0 +1,39 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# (standard Apache-2.0 license header, identical to the other new files in this diff)
+# ==============================================================================
+from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+
+
+class LinearAnnealingScheduler:
+    def __init__(self, t_start: int, t_end: int, initial_val: float, target_val: float):
+        """
+        Linear annealing scheduler. Returns the corresponding annealed value per time step.
+
+        Args:
+            t_start: time step to begin annealing.
+            t_end: time step to complete annealing.
+            initial_val: initial value.
+            target_val: target value.
+        """
+        if not (0 <= t_start < t_end):
+            raise ValueError(f'Expected 0 <= t_start < t_end, actual {t_end=} {t_start=}')
+
+        self.t_start = t_start
+        self.t_end = t_end
+        self.initial_val = initial_val
+        self.target_val = target_val
+
+    def __call__(self, t: int) -> float:
+        factor = to_torch_tensor((t - self.t_start) / (self.t_end - self.t_start)).clip(0, 1)
+        return self.initial_val + factor * (self.target_val - self.initial_val)
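A brief usage sketch of the scheduler's clipped interpolation (values come back as torch tensors via to_torch_tensor, so they are cast to float for printing):

sched = LinearAnnealingScheduler(t_start=10, t_end=20, initial_val=1.0, target_val=0.0)
for t in (0, 10, 15, 20, 30):
    print(t, float(sched(t)))
# 0 and 10 -> 1.0 (before/at start), 15 -> 0.5 (midway), 20 and later -> 0.0 (clipped)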
model_compression_toolkit/trainable_infrastructure/pytorch/util.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# (standard Apache-2.0 license header, identical to the other new files in this diff)
+# ==============================================================================
+from functools import lru_cache
+from typing import Callable
+
+from tqdm import tqdm
+
+
+@lru_cache
+def get_total_grad_steps(representative_data_gen: Callable) -> int:
+    # dry run on the representative dataset to count number of batches
+    num_batches = 0
+    for _ in tqdm(representative_data_gen(), "Estimating representative dataset size"):
+        num_batches += 1
+    return num_batches
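Note that @lru_cache keys on the generator function object, so the dry run happens once per distinct representative_data_gen; repeat calls return the cached batch count. A toy illustration of the same caching behavior (dummy names, no tqdm dependency):

from functools import lru_cache

@lru_cache
def count_batches(gen_fn):
    print('dry run...')  # printed only on the first call per gen_fn
    return sum(1 for _ in gen_fn())

def data_gen():
    yield from range(3)

print(count_batches(data_gen))  # dry run... then 3
print(count_batches(data_gen))  # cached: 3, no second dry run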
tests_pytest/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# (standard Apache-2.0 license header only; the new file contains no code)
tests_pytest/pytorch/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# (standard Apache-2.0 license header only; the new file contains no code)
tests_pytest/pytorch/gptq/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# (standard Apache-2.0 license header only; the new file contains no code)
tests_pytest/pytorch/gptq/test_annealing_cfg.py
ADDED
@@ -0,0 +1,40 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# (standard Apache-2.0 license header, identical to the other new files in this diff)
+# ==============================================================================
+import pytest
+
+from model_compression_toolkit.gptq import QFractionLinearAnnealingConfig
+
+
+def test_linear_annealing_cfg_validation():
+    with pytest.raises(ValueError, match='Expected.* target_q_fraction <= 1'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=1.1, start_step=0, end_step=None)
+
+    with pytest.raises(ValueError, match='Expected.* 0 <= initial_q_fraction'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=-0.1, target_q_fraction=-0.9, start_step=0, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* initial_q_fraction < target_q_fraction'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=0.1, start_step=0, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* initial_q_fraction < target_q_fraction'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0.2, target_q_fraction=0.1, start_step=0, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* start_step >= 0'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=-1, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* start_step < end_step'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=100, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* start_step < end_step'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=100, end_step=99)