mct-nightly 2.2.0.20240917.426__py3-none-any.whl → 2.2.0.20240918.448__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/RECORD +30 -20
  3. {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/top_level.txt +1 -0
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/core/common/graph/base_node.py +3 -0
  6. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  7. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -1
  8. model_compression_toolkit/core/keras/reader/node_builder.py +23 -1
  9. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  10. model_compression_toolkit/core/pytorch/reader/graph_builders.py +13 -4
  11. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +12 -3
  12. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +10 -1
  13. model_compression_toolkit/gptq/__init__.py +17 -5
  14. model_compression_toolkit/gptq/common/gptq_config.py +88 -75
  15. model_compression_toolkit/gptq/pytorch/gptq_training.py +18 -9
  16. model_compression_toolkit/gptq/pytorch/quantization_facade.py +49 -29
  17. model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py +80 -0
  18. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +10 -10
  19. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +6 -49
  20. model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py +39 -0
  21. model_compression_toolkit/trainable_infrastructure/pytorch/util.py +29 -0
  22. tests_pytest/__init__.py +14 -0
  23. tests_pytest/pytorch/__init__.py +14 -0
  24. tests_pytest/pytorch/gptq/__init__.py +14 -0
  25. tests_pytest/pytorch/gptq/test_annealing_cfg.py +40 -0
  26. tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +100 -0
  27. tests_pytest/pytorch/trainable_infrastructure/__init__.py +14 -0
  28. tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +49 -0
  29. {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/LICENSE.md +0 -0
  30. {mct_nightly-2.2.0.20240917.426.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/WHEEL +0 -0
model_compression_toolkit/gptq/pytorch/gptq_training.py

@@ -21,6 +21,8 @@ import copy
 import torch
 
 from model_compression_toolkit.core.common.hessian import HessianInfoService
+from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \
+    get_gradual_activation_quantizer_wrapper_factory
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
@@ -36,6 +38,7 @@ from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable
 from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
 from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
 from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
+from model_compression_toolkit.trainable_infrastructure.pytorch.util import get_total_grad_steps
 
 
 class PytorchGPTQTrainer(GPTQTrainer):
@@ -66,6 +69,13 @@ class PytorchGPTQTrainer(GPTQTrainer):
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """
+        def _get_total_grad_steps():
+            return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs
+
+        # must be set prior to model building in the base class constructor
+        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
+            gptq_config, _get_total_grad_steps)
+
         super().__init__(graph_float,
                          graph_quant,
                          gptq_config,
@@ -98,7 +108,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
         self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights())
 
-        self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
+        self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)
 
     def _is_gptq_weights_trainable(self,
                                    node: BaseNode) -> bool:
@@ -145,7 +155,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
     def get_activation_quantizer_holder(self, n: BaseNode) -> Callable:
         """
         Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
-        If the layer is not supposed to be wrapped with an activation quantizer - return None.
         Args:
             n: Node to attach a PytorchActivationQuantizationHolder to its output.
         Returns:
@@ -153,13 +162,13 @@ class PytorchGPTQTrainer(GPTQTrainer):
         """
         _, activation_quantizers = quantization_builder(n, self.gptq_config)
         # Holder by definition uses a single quantizer for the activation quantization
-        # thus we make sure this is the only possible case (unless it's a node we no activation
-        # quantization, which in this case has an empty list).
-        if len(activation_quantizers) == 1:
-            return PytorchActivationQuantizationHolder(activation_quantizers[0])
-        Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, "
-                        f"but {len(activation_quantizers)} were found for node {n.name}. "
-                        f"Ensure the node is configured with a single activation quantizer.")
+        # thus we make sure this is the only possible case
+        if len(activation_quantizers) != 1:
+            Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, "
+                            f"but {len(activation_quantizers)} were found for node {n.name}. "
+                            f"Ensure the node is configured with a single activation quantizer.")
+        quantizer = self.gradual_act_quantizer_wrapper_factory(activation_quantizers[0])
+        return PytorchActivationQuantizationHolder(quantizer)
 
     def build_gptq_model(self):
         """
model_compression_toolkit/gptq/pytorch/quantization_facade.py

@@ -13,26 +13,26 @@
 # limitations under the License.
 # ==============================================================================
 import copy
+from typing import Callable, Union
 
-from typing import Callable
-from model_compression_toolkit.core import common
-from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE
-from model_compression_toolkit.verify_packages import FOUND_TORCH
+from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
+from model_compression_toolkit.core import CoreConfig
+from model_compression_toolkit.core.analyzer import analyzer_model_quantization
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
+    MixedPrecisionQuantizationConfig
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
+    ResourceUtilization
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
-from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
-from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
-from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.runner import core_runner
+from model_compression_toolkit.gptq.common.gptq_config import (
+    GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig)
+from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
 from model_compression_toolkit.gptq.runner import gptq_runner
-from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit.core import CoreConfig
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfig
-from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.metadata import create_model_metadata
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
+from model_compression_toolkit.verify_packages import FOUND_TORCH
 
 LR_DEFAULT = 1e-4
 LR_REST_DEFAULT = 1e-4
@@ -53,33 +53,38 @@ if FOUND_TORCH:
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
 
     def get_pytorch_gptq_config(n_epochs: int,
-                                optimizer: Optimizer = Adam([torch.Tensor([])], lr=LR_DEFAULT),
-                                optimizer_rest: Optimizer = Adam([torch.Tensor([])], lr=LR_REST_DEFAULT),
+                                optimizer: Optimizer = None,
+                                optimizer_rest: Optimizer = None,
                                 loss: Callable = multiple_tensors_mse_loss,
                                 log_function: Callable = None,
                                 use_hessian_based_weights: bool = True,
                                 regularization_factor: float = REG_DEFAULT,
-                                hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
+                                hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                                gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False,
                                 ) -> GradientPTQConfig:
         """
-        Create a GradientPTQConfigV2 instance for Pytorch models.
+        Create a GradientPTQConfig instance for Pytorch models.
 
        args:
            n_epochs (int): Number of epochs for running the representative dataset for fine-tuning.
            optimizer (Optimizer): Pytorch optimizer to use for fine-tuning for auxiliry variable.
            optimizer_rest (Optimizer): Pytorch optimizer to use for fine-tuning of the bias variable.
-            loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
+            loss (Callable): loss to use during fine-tuning. See the default loss function for the exact interface.
            log_function (Callable): Function to log information about the gptq process.
            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
            regularization_factor (float): A floating point number that defines the regularization factor.
            hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+            gradual_activation_quantization (bool, GradualActivationQuantizationConfig):
+                If False, GradualActivationQuantization is disabled.
+                If True, GradualActivationQuantization is enabled with the default settings.
+                GradualActivationQuantizationConfig object can be passed to use non-default settings.
 
        returns:
-            a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
+            a GradientPTQConfig object to use when fine-tuning the quantized model using gptq.
 
        Examples:
 
-            Import MCT and Create a GradientPTQConfigV2 to run for 5 epochs:
+            Import MCT and Create a GradientPTQConfig to run for 5 epochs:
 
            >>> import model_compression_toolkit as mct
            >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
@@ -89,16 +94,31 @@ if FOUND_TORCH:
            >>> import torch
            >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
 
-            The configuration can be passed to :func:`~model_compression_toolkit.pytorch_post_training_quantization` in order to quantize a pytorch model using gptq.
+            To enable Gradual Activation Quantization with non-default settings build GradualActivationQuantizationConfig:
+            >>> gradual_act_conf = mct.gptq.GradualActivationQuantizationConfig(mct.gptq.QFractionLinearAnnealingConfig(initial_q_fraction=0.2))
+            >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, gradual_activation_quantization=gradual_act_conf)
+            The configuration can be passed to :func:`~model_compression_toolkit.pytorch_gradient_post_training_quantization` in order to quantize a pytorch model using gptq.
 
        """
+        optimizer = optimizer or Adam([torch.Tensor([])], lr=LR_DEFAULT)
+        optimizer_rest = optimizer_rest or Adam([torch.Tensor([])], lr=LR_REST_DEFAULT)
+
         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
+
+        if isinstance(gradual_activation_quantization, bool):
+            gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
+        elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig):
+            gradual_quant_config = gradual_activation_quantization
+        else:
+            raise TypeError(f'gradual_activation_quantization argument should be bool or '
+                            f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')  # pragma: no cover
+
         return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
                                  log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
                                  use_hessian_based_weights=use_hessian_based_weights,
                                  regularization_factor=regularization_factor,
-                                 hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size))
-
+                                 hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size),
+                                 gradual_activation_quantization_config=gradual_quant_config)
 
     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
@@ -222,11 +242,11 @@ if FOUND_TORCH:
 else:
     # If torch is not installed,
     # we raise an exception when trying to use these functions.
-    def get_pytorch_gptq_config(*args, **kwargs):
+    def get_pytorch_gptq_config(*args, **kwargs):  # pragma: no cover
         Logger.critical("PyTorch must be installed to use 'get_pytorch_gptq_config'. "
-                        "The 'torch' package is missing.")  # pragma: no cover
+                        "The 'torch' package is missing.")
 
 
-    def pytorch_gradient_post_training_quantization(*args, **kwargs):
+    def pytorch_gradient_post_training_quantization(*args, **kwargs):  # pragma: no cover
         Logger.critical("PyTorch must be installed to use 'pytorch_gradient_post_training_quantization'. "
-                        "The 'torch' package is missing.")  # pragma: no cover
+                        "The 'torch' package is missing.")
model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py

@@ -0,0 +1,80 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from functools import partial
+from typing import Callable
+
+from model_compression_toolkit.gptq import GradientPTQConfig, QFractionLinearAnnealingConfig
+from model_compression_toolkit.trainable_infrastructure import BasePytorchTrainableQuantizer
+
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
+
+
+def get_gradual_activation_quantizer_wrapper_factory(gptq_config: GradientPTQConfig,
+                                                     get_total_grad_steps_fn: Callable[[], int]) \
+        -> Callable[[BasePytorchTrainableQuantizer], 'GradualActivationQuantizerWrapper']:
+    """
+    Get a factory for 'GradualActivationQuantizerWrapper'.
+
+    Args:
+        gptq_config: GPTQ configuration.
+        get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps.
+
+    Returns:
+        A factory function to build 'GradualActivationQuantizerWrapper' from Quantizer.
+    """
+    if gptq_config.gradual_activation_quantization_config is None:
+        return lambda q: q
+
+    annealing_cfg = gptq_config.gradual_activation_quantization_config.q_fraction_scheduler_policy
+    if isinstance(annealing_cfg, QFractionLinearAnnealingConfig):
+        t_end = annealing_cfg.end_step or get_total_grad_steps_fn()
+        factor_scheduler = LinearAnnealingScheduler(t_start=annealing_cfg.start_step, t_end=t_end,
+                                                    initial_val=annealing_cfg.initial_q_fraction,
+                                                    target_val=annealing_cfg.target_q_fraction)
+    else:
+        raise ValueError(f'Unknown annealing policy {annealing_cfg}')
+
+    return partial(GradualActivationQuantizerWrapper, q_fraction_scheduler=factor_scheduler)
+
+
+class GradualActivationQuantizerWrapper:
+    # TODO update paper's url
+    """
+    Quantizer wrapper for Gradual Activation Quantization training (https://arxiv.org/abs/2309.11531).
+
+    It computes the weighted sum of the float activation 'x' and the quantized activation 'q(x)':
+
+        out = (1 - q_fraction) * x + q_fraction * q(x)
+
+    where 'q_fraction' is a tensor fraction to quantize in the range [0, 1] provided by a scheduler.
+
+    Args:
+        quantizer: quantizer to wrap.
+        q_fraction_scheduler: a callable that accepts a gradient step and returns the corresponding quantized fraction.
+    """
+    def __init__(self, quantizer: BasePytorchTrainableQuantizer, q_fraction_scheduler: Callable[[int], float]):
+        self.quantizer = quantizer
+        self.q_fraction_scheduler = q_fraction_scheduler
+        self.step_cnt = 0
+
+    def __call__(self, x, training: bool = True):
+        q_fraction = self.q_fraction_scheduler(self.step_cnt)
+        out_q = self.quantizer(x, training)
+        out = (1 - q_fraction) * x + q_fraction * out_q
+        self.step_cnt += 1
+        return out
+
+    def initialize_quantization(self, *args, **kwargs):
+        self.quantizer.initialize_quantization(*args, **kwargs)
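
To make the weighted sum concrete, here is a small self-contained sketch with a stand-in rounding "quantizer" and a fixed q_fraction; in real training both come from MCT (the quantizer from the holder, the fraction from the annealing scheduler), so everything below is illustrative only:

import torch

def fake_quantizer(x, training=True):
    # Stand-in for the wrapped activation quantizer: round to a 0.25 grid.
    return torch.round(x / 0.25) * 0.25

x = torch.tensor([0.10, 0.40, 0.90])
q_fraction = 0.3                                    # value a scheduler might return mid-training
out = (1 - q_fraction) * x + q_fraction * fake_quantizer(x)
# x -> [0.10, 0.40, 0.90], q(x) -> [0.00, 0.50, 1.00], out -> [0.07, 0.43, 0.93]
print(out)

As q_fraction anneals toward 1, the output converges to the fully quantized activation.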
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py

@@ -12,33 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from tqdm import tqdm
 from typing import Callable
 
-from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig
 from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
     SoftQuantizerRegularization
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
 
 
-def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
+WARMUP_STEP_FRACTION = 0.2
+
+def get_regularization(gptq_config: GradientPTQConfig, get_total_grad_steps_fn: Callable[[], int]) -> Callable:
     """
     Returns a function that computes the regularization term for GPTQ training based on the given
     rounding type in the GPTQ configuration.
 
     Args:
         gptq_config: A GPTQ configuration.
-        representative_data_gen: Dataset used for the GPTQ training.
+        get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps.
 
     Returns: A function for computing the regularization. If there is no regularization function defined for the given
     rounding type, then it returns a function that just returns 0.
 
     """
     if gptq_config.rounding_type == RoundingType.SoftQuantizer:
-        # dry run on the representative dataset to count number of batches
-        num_batches = 0
-        for _ in tqdm(representative_data_gen(), "GPTQ initialization"):
-            num_batches += 1
-
-        return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
+        total_gradient_steps = get_total_grad_steps_fn()
+        t_start = int(WARMUP_STEP_FRACTION * total_gradient_steps)
+        scheduler = LinearAnnealingScheduler(t_start=t_start, t_end=total_gradient_steps, initial_val=20, target_val=2)
+        return SoftQuantizerRegularization(scheduler)
     else:
         return lambda m, e_reg: 0
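
The new scheduler effectively reproduces the removed LinearTempDecay shape: beta holds at 20 for the first WARMUP_STEP_FRACTION (20%) of the gradient steps and then anneals linearly to 2. A plain-Python sketch of that schedule (step counts are illustrative; the real code uses LinearAnnealingScheduler):

def beta(t, total_steps, warmup_fraction=0.2, initial_val=20.0, target_val=2.0):
    # Same shape as LinearAnnealingScheduler(t_start, t_end, initial_val, target_val).
    t_start = int(warmup_fraction * total_steps)
    factor = min(max((t - t_start) / (total_steps - t_start), 0.0), 1.0)
    return initial_val + factor * (target_val - initial_val)

total = 1000
print(beta(0, total), beta(200, total), beta(600, total), beta(1000, total))
# -> 20.0 20.0 11.0 2.0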
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py

@@ -12,57 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import List
+from typing import List, Callable
 
 import torch
-import numpy as np
 from torch import nn
 
+from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from mct_quantizers import PytorchQuantizationWrapper
-
-
-class LinearTempDecay:
-    """
-    Annealing process for the soft quantizer regularization temperature term.
-    """
-
-    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
-        """
-        Initializes a LinearTempDecay object.
-
-        Args:
-            t_max: maximal time step.
-            rel_start_decay: Decay step size at the beginning of the process.
-            start_b: Starting value of the regularization term.
-            end_b: Target value of the regularization term.
-        """
-
-        self.t_max = t_max
-        self.start_decay = rel_start_decay * t_max
-        self.start_b = start_b
-        self.end_b = end_b
-
-    def __call__(self, t: float) -> float:
-        """
-        Cosine annealing scheduler for soft quantizer regularization temperature term.
-
-        Args:
-            t: The current time step.
-
-        Returns: Scheduled temperature.
-        """
-
-        is_before_start_decay = (t < self.start_decay)
-
-        rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
-
-        return self.start_b * is_before_start_decay + \
-               (1 - is_before_start_decay) * \
-               (self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])),
-                                                                         to_torch_tensor(np.array((1 - rel_t)))))
 
 
 class SoftQuantizerRegularization:
@@ -70,16 +27,16 @@ class SoftQuantizerRegularization:
     A class to handle the computation of soft quantizer regularization for GPTQ training.
     """
 
-    def __init__(self, total_gradient_steps: int):
+    def __init__(self, beta_scheduler: Callable[[int], float]):
         """
         Initializes the regularization computation object with a LinearDecay object.
 
         Args:
-            total_gradient_steps: The number of gradient steps during optimization.
+            beta_scheduler: a callable that accepts current time step and returns a corresponding beta value.
         """
 
         # Initializing the temperature decay according to the number of expected gradient steps
-        self.linear_decay = LinearTempDecay(total_gradient_steps)
+        self.beta_scheduler = beta_scheduler
 
         self.count_iter = 0
 
@@ -95,7 +52,7 @@ class SoftQuantizerRegularization:
         """
 
         soft_reg_aux: List[torch.Tensor] = []
-        b = self.linear_decay(self.count_iter)
+        b = self.beta_scheduler(self.count_iter)
         for layer in model.modules():
             if isinstance(layer, PytorchQuantizationWrapper):
                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py

@@ -0,0 +1,39 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+
+
+class LinearAnnealingScheduler:
+    def __init__(self, t_start: int, t_end: int, initial_val: float, target_val: float):
+        """
+        Linear annealing scheduler. Returns the corresponding annealed value per time step.
+
+        Args:
+            t_start: time step to begin annealing.
+            t_end: time step to complete annealing.
+            initial_val: initial value.
+            target_val: target value.
+        """
+        if not (0 <= t_start < t_end):
+            raise ValueError(f'Expected 0 <= t_start < t_end, actual {t_end=} {t_start=}')
+
+        self.t_start = t_start
+        self.t_end = t_end
+        self.initial_val = initial_val
+        self.target_val = target_val
+
+    def __call__(self, t: int) -> float:
+        factor = to_torch_tensor((t - self.t_start) / (self.t_end - self.t_start)).clip(0, 1)
+        return self.initial_val + factor * (self.target_val - self.initial_val)
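
A quick check of the clipping behaviour, assuming this nightly build and a torch runtime are available (the values follow directly from the formula above):

from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler

scheduler = LinearAnnealingScheduler(t_start=100, t_end=200, initial_val=0.0, target_val=1.0)
# Constant before t_start, linear in between, clipped at the target after t_end.
print([float(scheduler(t)) for t in (0, 100, 150, 200, 300)])   # [0.0, 0.0, 0.5, 1.0, 1.0]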
model_compression_toolkit/trainable_infrastructure/pytorch/util.py

@@ -0,0 +1,29 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from functools import cache
+from typing import Callable
+
+from tqdm import tqdm
+
+
+@cache
+def get_total_grad_steps(representative_data_gen: Callable) -> int:
+    # dry run on the representative dataset to count number of batches
+    num_batches = 0
+    for _ in tqdm(representative_data_gen(), "Estimating representative dataset size"):
+        num_batches += 1
+    return num_batches
+
+
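
A usage sketch with a hypothetical data generator. Because of functools.cache, the dry run is keyed on the generator callable itself: the dataset is iterated once per callable, and repeated calls return the cached count.

from model_compression_toolkit.trainable_infrastructure.pytorch.util import get_total_grad_steps

def representative_data_gen():          # hypothetical generator yielding 32 batches
    for _ in range(32):
        yield "batch"

print(get_total_grad_steps(representative_data_gen))   # 32, counted by a single pass (with a tqdm bar)
print(get_total_grad_steps(representative_data_gen))   # 32 again, served from the cache without iterating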
tests_pytest/__init__.py

@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

tests_pytest/pytorch/__init__.py

@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

tests_pytest/pytorch/gptq/__init__.py

@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
tests_pytest/pytorch/gptq/test_annealing_cfg.py

@@ -0,0 +1,40 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import pytest
+
+from model_compression_toolkit.gptq import QFractionLinearAnnealingConfig
+
+
+def test_linear_annealing_cfg_validation():
+    with pytest.raises(ValueError, match='Expected.* target_q_fraction <= 1'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=1.1, start_step=0, end_step=None)
+
+    with pytest.raises(ValueError, match='Expected.* 0 <= initial_q_fraction'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=-0.1, target_q_fraction=-0.9, start_step=0, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* initial_q_fraction < target_q_fraction'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=0.1, start_step=0, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* initial_q_fraction < target_q_fraction'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0.2, target_q_fraction=0.1, start_step=0, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* start_step >= 0'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=-1, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* start_step < end_step'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=100, end_step=100)
+
+    with pytest.raises(ValueError, match='Expected.* start_step < end_step'):
+        QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=100, end_step=99)
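
For reference, a configuration that satisfies all of the constraints exercised above (assuming the installed build; per gradual_activation_quantization.py earlier in this diff, end_step=None defers the annealing end to the total number of gradient steps):

from model_compression_toolkit.gptq import QFractionLinearAnnealingConfig

# 0 <= initial_q_fraction < target_q_fraction <= 1, start_step >= 0, end_step either None or > start_step.
valid_cfg = QFractionLinearAnnealingConfig(initial_q_fraction=0.0,
                                           target_q_fraction=1.0,
                                           start_step=0,
                                           end_step=None)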