mct-nightly 2.1.0.20240725.446__py3-none-any.whl → 2.1.0.20240727.431__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/RECORD +35 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/pytorch/constants.py +6 -1
  5. model_compression_toolkit/core/pytorch/utils.py +27 -0
  6. model_compression_toolkit/data_generation/common/data_generation.py +20 -18
  7. model_compression_toolkit/data_generation/common/data_generation_config.py +8 -11
  8. model_compression_toolkit/data_generation/common/enums.py +24 -12
  9. model_compression_toolkit/data_generation/common/image_pipeline.py +50 -12
  10. model_compression_toolkit/data_generation/common/model_info_exctractors.py +0 -8
  11. model_compression_toolkit/data_generation/common/optimization_utils.py +7 -11
  12. model_compression_toolkit/data_generation/keras/constants.py +5 -2
  13. model_compression_toolkit/data_generation/keras/image_operations.py +189 -0
  14. model_compression_toolkit/data_generation/keras/image_pipeline.py +50 -104
  15. model_compression_toolkit/data_generation/keras/keras_data_generation.py +28 -36
  16. model_compression_toolkit/data_generation/keras/model_info_exctractors.py +0 -13
  17. model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py +16 -6
  18. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +219 -0
  19. model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +39 -13
  20. model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py +6 -98
  21. model_compression_toolkit/data_generation/keras/optimization_utils.py +15 -28
  22. model_compression_toolkit/data_generation/pytorch/constants.py +4 -1
  23. model_compression_toolkit/data_generation/pytorch/image_operations.py +105 -0
  24. model_compression_toolkit/data_generation/pytorch/image_pipeline.py +70 -78
  25. model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +0 -10
  26. model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py +17 -6
  27. model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +2 -2
  28. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +219 -0
  29. model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py +55 -21
  30. model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py +15 -0
  31. model_compression_toolkit/data_generation/pytorch/optimization_utils.py +32 -54
  32. model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +57 -52
  33. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/LICENSE.md +0 -0
  34. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/WHEEL +0 -0
  35. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/top_level.txt +0 -0
@@ -12,36 +12,45 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Type, Dict, List
15
+ import numpy as np
16
+ import torch
17
+ from typing import Type, Dict, Tuple, Union, List
16
18
 
17
19
  from torch import Tensor
18
20
  from torchvision.transforms import RandomCrop, RandomHorizontalFlip, CenterCrop
19
21
 
20
-
21
- from model_compression_toolkit.data_generation.common.enums import ImagePipelineType, ImageNormalizationType
22
+ from model_compression_toolkit.data_generation.common.enums import ImagePipelineType
22
23
  from model_compression_toolkit.data_generation.common.image_pipeline import BaseImagePipeline
24
+ from model_compression_toolkit.data_generation.pytorch.image_operations import Smoothing, create_valid_grid
23
25
 
24
26
 
25
27
  class PytorchIdentityImagePipeline(BaseImagePipeline):
26
28
  """
27
29
  An image pipeline implementation for PyTorch models that returns the input images as is (identity).
28
30
  """
29
- def __init__(self, output_image_size: int, extra_pixels: int = 0):
31
+ def __init__(self,
32
+ output_image_size: Union[int, Tuple[int, int]],
33
+ extra_pixels: Union[int, Tuple[int, int]] = 0,
34
+ normalization: List[List[int]] = [[0, 0, 0], [1, 1, 1]],
35
+ image_clipping: bool = True,
36
+ ):
30
37
  """
31
38
  Initialize the PytorchIdentityImagePipeline.
32
39
 
33
40
  Args:
34
- output_image_size (int): The output image size.
35
- extra_pixels (int): Extra pixels to add to the input image size (not used in identity pipeline).
41
+ output_image_size (Union[int, Tuple[int, int]]): The output image size.
42
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size (not used in identity pipeline).
43
+ normalization (List[List[float]]): The image normalization values for processing images during optimization.
44
+ image_clipping (bool): Whether to clip images during optimization.
36
45
  """
37
- super(PytorchIdentityImagePipeline, self).__init__(output_image_size)
46
+ super(PytorchIdentityImagePipeline, self).__init__(output_image_size, extra_pixels, image_clipping, normalization)
38
47
 
39
- def get_image_input_size(self) -> int:
48
+ def get_image_input_size(self) -> Tuple[int, int]:
40
49
  """
41
50
  Get the input size of the image.
42
51
 
43
52
  Returns:
44
- int: The input image size.
53
+ Tuple[int, int]: The input image size.
45
54
  """
46
55
  return self.output_image_size
47
56
 
@@ -70,31 +79,43 @@ class PytorchIdentityImagePipeline(BaseImagePipeline):
70
79
  return images
71
80
 
72
81
 
73
- class PytorchRandomCropImagePipeline(BaseImagePipeline):
82
+ class PytorchSmoothAugmentationImagePipeline(BaseImagePipeline):
74
83
  """
75
- An image pipeline implementation for PyTorch models that includes random cropping.
84
+ An image pipeline implementation for PyTorch models that includes random cropping and flipping.
76
85
  """
77
- def __init__(self, output_image_size: int, extra_pixels: int = 0):
86
+ def __init__(self,
87
+ output_image_size: Union[int, Tuple[int, int]],
88
+ extra_pixels: Union[int, Tuple[int, int]] = 0,
89
+ normalization: List[List[int]] = [[0, 0, 0], [1, 1, 1]],
90
+ image_clipping: bool = True,
91
+ smoothing_filter_size: int = 3,
92
+ smoothing_filter_sigma: float = 1.25):
78
93
  """
79
94
  Initialize the PytorchRandomCropFlipImagePipeline.
80
95
 
81
96
  Args:
82
- output_image_size (int): The output image size.
83
- extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
84
- """
85
- super(PytorchRandomCropImagePipeline, self).__init__(output_image_size)
86
- self.extra_pixels = extra_pixels
97
+ output_image_size (Union[int, Tuple[int, int]]): The output image size.
98
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
99
+ normalization (List[List[float]]): The image normalization values for processing images during optimization.
100
+ image_clipping (bool): Whether to clip images during optimization.
101
+ smoothing_filter_size (int): The size of the smoothing filter. Defaults to 3.
102
+ smoothing_filter_sigma (float): The standard deviation of the smoothing filter. Defaults to 1.25.
103
+ """
104
+ super(PytorchSmoothAugmentationImagePipeline, self).__init__(output_image_size, extra_pixels, image_clipping, normalization)
105
+ self.smoothing = Smoothing(size=smoothing_filter_size, sigma=smoothing_filter_sigma)
87
106
  self.random_crop = RandomCrop(self.output_image_size)
107
+ self.random_flip = RandomHorizontalFlip(0.5)
88
108
  self.center_crop = CenterCrop(self.output_image_size)
109
+ self.valid_grid = create_valid_grid(means=self.normalization[0], stds=self.normalization[1])
89
110
 
90
- def get_image_input_size(self) -> int:
111
+ def get_image_input_size(self) -> Tuple[int, int]:
91
112
  """
92
113
  Get the input size of the image.
93
114
 
94
115
  Returns:
95
- int: The input image size.
116
+ Tuple[int, int]: The input image size.
96
117
  """
97
- return self.output_image_size + self.extra_pixels
118
+ return tuple([o + e for (o, e) in zip(self.output_image_size, self.extra_pixels)])
98
119
 
99
120
  def image_input_manipulation(self, images: Tensor) -> Tensor:
100
121
  """
@@ -104,9 +125,14 @@ class PytorchRandomCropImagePipeline(BaseImagePipeline):
104
125
  images (Tensor): The input images.
105
126
 
106
127
  Returns:
107
- Tensor: The manipulated images (randomly flipped and cropped).
128
+ Tensor: The manipulated images.
108
129
  """
109
- return self.random_crop(images)
130
+ new_images = self.random_flip(images)
131
+ new_images = self.smoothing(new_images)
132
+ new_images = self.random_crop(new_images)
133
+ if self.image_clipping:
134
+ new_images = self.clip_images(new_images, self.valid_grid)
135
+ return new_images
110
136
 
111
137
  def image_output_finalize(self, images: Tensor) -> Tensor:
112
138
  """
@@ -118,71 +144,37 @@ class PytorchRandomCropImagePipeline(BaseImagePipeline):
118
144
  Returns:
119
145
  Tensor: The finalized images (center cropped).
120
146
  """
121
- return self.center_crop(images)
122
-
147
+ new_images = self.smoothing(images)
148
+ new_images = self.center_crop(new_images)
149
+ if self.image_clipping:
150
+ new_images = self.clip_images(new_images, self.valid_grid)
151
+ return new_images
123
152
 
124
- class PytorchRandomCropFlipImagePipeline(BaseImagePipeline):
125
- """
126
- An image pipeline implementation for PyTorch models that includes random cropping and flipping.
127
- """
128
- def __init__(self, output_image_size: int, extra_pixels: int = 0):
153
+ @staticmethod
154
+ def clip_images(images: Tensor, valid_grid: Tensor, reflection: bool = False) -> Tensor:
129
155
  """
130
- Initialize the PytorchRandomCropFlipImagePipeline.
156
+ Clip the images based on a valid grid.
131
157
 
132
158
  Args:
133
- output_image_size (int): The output image size.
134
- extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
135
- """
136
- super(PytorchRandomCropFlipImagePipeline, self).__init__(output_image_size)
137
- self.extra_pixels = extra_pixels
138
- self.random_crop = RandomCrop(self.output_image_size)
139
- self.random_flip = RandomHorizontalFlip(0.5)
140
- self.center_crop = CenterCrop(self.output_image_size)
141
-
142
- def get_image_input_size(self) -> int:
143
- """
144
- Get the input size of the image.
159
+ images (Tensor): The images to be clipped.
160
+ valid_grid (Tensor): The valid grid for clipping.
161
+ reflection (bool): Whether to apply reflection during clipping. Defaults to False.
145
162
 
146
163
  Returns:
147
- int: The input image size.
148
- """
149
- return self.output_image_size + self.extra_pixels
150
-
151
- def image_input_manipulation(self, images: Tensor) -> Tensor:
152
- """
153
- Manipulate the input images with random flipping and cropping.
154
-
155
- Args:
156
- images (Tensor): The input images.
157
-
158
- Returns:
159
- Tensor: The manipulated images (randomly flipped and cropped).
160
- """
161
- random_flipped_data = self.random_flip(images)
162
- return self.random_crop(random_flipped_data)
163
-
164
- def image_output_finalize(self, images: Tensor) -> Tensor:
165
- """
166
- Finalize the output images with center cropping.
167
-
168
- Args:
169
- images (Tensor): The output images.
170
-
171
- Returns:
172
- Tensor: The finalized images (center cropped).
173
- """
174
- return self.center_crop(images)
164
+ Tensor: The clipped images.
165
+ """
166
+ with torch.no_grad():
167
+ for i_ch in range(valid_grid.shape[0]):
168
+ clamp = torch.clamp(images[:, i_ch, :, :], valid_grid[i_ch, :].min(), valid_grid[i_ch, :].max())
169
+ if reflection:
170
+ images[:, i_ch, :, :] = 2 * clamp - images[:, i_ch, :, :]
171
+ else:
172
+ images[:, i_ch, :, :] = clamp
173
+ return images
175
174
 
176
175
 
177
176
  # Dictionary mapping ImagePipelineType to corresponding image pipeline classes
178
177
  image_pipeline_dict: Dict[ImagePipelineType, Type[BaseImagePipeline]] = {
179
178
  ImagePipelineType.IDENTITY: PytorchIdentityImagePipeline,
180
- ImagePipelineType.RANDOM_CROP: PytorchRandomCropImagePipeline,
181
- ImagePipelineType.RANDOM_CROP_FLIP: PytorchRandomCropFlipImagePipeline
182
- }
183
-
184
- # Dictionary mapping ImageNormalizationType to corresponding normalization values
185
- image_normalization_dict: Dict[ImageNormalizationType, List[List[float]]] = {
186
- ImageNormalizationType.TORCHVISION: [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
187
- ImageNormalizationType.NO_NORMALIZATION: [[0, 0, 0], [1, 1, 1]]
179
+ ImagePipelineType.SMOOTHING_AND_AUGMENTATION: PytorchSmoothAugmentationImagePipeline
188
180
  }
@@ -126,7 +126,6 @@ class PytorchActivationExtractor(ActivationExtractor):
126
126
  self.layer_types_to_extract_inputs = tuple(layer_types_to_extract_inputs)
127
127
  self.last_layer_types_to_extract_inputs = tuple(last_layer_types_to_extract_inputs)
128
128
  self.num_layers = sum([1 if isinstance(layer, tuple(layer_types_to_extract_inputs)) else 0 for layer in model.modules()])
129
- Logger.info(f'Number of layers = {self.num_layers}')
130
129
  self.hooks = {} # Dictionary to store InputHook instances by layer name
131
130
  self.last_linear_layers_hooks = {} # Dictionary to store InputHook instances by layer name
132
131
  self.hook_handles = [] # List to store hook handles
@@ -206,15 +205,6 @@ class PytorchActivationExtractor(ActivationExtractor):
206
205
  """
207
206
  return self.last_linear_layer_weights
208
207
 
209
- def get_num_extractor_layers(self) -> int:
210
- """
211
- Get the number of hooked layers in the model.
212
-
213
- Returns:
214
- int: Number of hooked layers in the model.
215
- """
216
- return self.num_layers
217
-
218
208
  def get_extractor_layer_names(self) -> List:
219
209
  """
220
210
  Get a list of the hooked layer names.
@@ -18,16 +18,21 @@ import torch
18
18
 
19
19
  from model_compression_toolkit.data_generation.common.enums import BNLayerWeightingType
20
20
  from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import OriginalBNStatsHolder, \
21
- ActivationExtractor
21
+ ActivationExtractor, PytorchActivationExtractor
22
22
 
23
23
 
24
- def average_bn_layer_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder, **kwargs) -> Dict[str, float]:
24
+ def average_bn_layer_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder,
25
+ activation_extractor: PytorchActivationExtractor,
26
+ i_iter: int,
27
+ n_iter: int) -> Dict[str, float]:
25
28
  """
26
29
  Calculate average weighting for each batch normalization layer.
27
30
 
28
31
  Args:
29
32
  orig_bn_stats_holder (OriginalBNStatsHolder): Holder for original batch normalization statistics.
30
- **kwargs: Additional arguments if needed.
33
+ activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
34
+ i_iter (int): Current optimization iteration.
35
+ n_iter (int): Total number of optimization iterations.
31
36
 
32
37
  Returns:
33
38
  Dict[str, float]: A dictionary containing layer names as keys and average weightings as values.
@@ -35,19 +40,25 @@ def average_bn_layer_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder, *
35
40
  num_bn_layers = orig_bn_stats_holder.get_num_bn_layers()
36
41
  return {bn_layer_name: 1 / num_bn_layers for bn_layer_name in orig_bn_stats_holder.get_bn_layer_names()}
37
42
 
38
- def first_bn_multiplier_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder, **kwargs) -> Dict[str, float]:
43
+ def first_bn_multiplier_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder,
44
+ activation_extractor: PytorchActivationExtractor,
45
+ i_iter: int,
46
+ n_iter: int) -> Dict[str, float]:
39
47
  """
40
48
  Calculate layer weightings with a higher multiplier for the first batch normalization layer.
41
49
 
42
50
  Args:
43
51
  orig_bn_stats_holder (OriginalBNStatsHolder): Holder for original batch normalization statistics.
44
- **kwargs: Additional arguments if needed.
52
+ activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
53
+ i_iter (int): Current optimization iteration.
54
+ n_iter (int): Total number of optimization iterations.
45
55
 
46
56
  Returns:
47
57
  Dict[str, float]: A dictionary containing layer names as keys and weightings as values.
48
58
  """
49
59
  layer_weighting_dict = {orig_bn_stats_holder.get_bn_layer_names()[0]: 10}
50
- return layer_weighting_dict.update({bn_layer_name: 1 for bn_layer_name in orig_bn_stats_holder.get_bn_layer_names()[1:]})
60
+ layer_weighting_dict.update({bn_layer_name: 1 for bn_layer_name in orig_bn_stats_holder.get_bn_layer_names()[1:]})
61
+ return layer_weighting_dict
51
62
 
52
63
 
53
64
  # Dictionary of layer weighting functions
@@ -104,7 +104,7 @@ def diverse_sample(size: Tuple[int, ...]) -> Tensor:
104
104
 
105
105
  def default_data_init_fn(
106
106
  n_images: int = 1000,
107
- size: tuple = (224, 224),
107
+ size: Union[int, Tuple[int, int]] = (224, 224),
108
108
  crop: int = 32,
109
109
  sample_fn: Callable = diverse_sample,
110
110
  batch_size: int = 50) -> Tuple[int, DataLoader]:
@@ -113,7 +113,7 @@ def default_data_init_fn(
113
113
 
114
114
  Args:
115
115
  n_images (int): The number of random samples.
116
- size (Tuple[int, int]): The size of each sample.
116
+ size (Union[int, Tuple[int, int]]): The size of each sample.
117
117
  crop (int): The crop size.
118
118
  sample_fn (Callable): The function to generate a random sample.
119
119
  batch_size (int): The batch size.
@@ -0,0 +1,219 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from torch.optim.optimizer import Optimizer
16
+ from torch import inf
17
+ from typing import Union, List, Dict, Any
18
+
19
+ from model_compression_toolkit.logger import Logger
20
+
21
+
22
+ class ReduceLROnPlateauWithReset:
23
+ """
24
+ Reduce learning rate when a metric has stopped improving. This scheduler allows resetting
25
+ the learning rate to the initial value after a specified number of bad epochs.
26
+ """
27
+
28
+ def __init__(self, optimizer: Optimizer, mode: str = 'min', factor: float = 0.1, patience: int = 10,
29
+ threshold: float = 1e-4, threshold_mode: str = 'rel', cooldown: int = 0,
30
+ min_lr: Union[float, List[float]] = 0, eps: float = 1e-8, verbose: bool = False):
31
+ """
32
+ Initialize the ReduceLROnPlateauWithReset scheduler.
33
+
34
+ Args:
35
+ optimizer (Optimizer): Wrapped optimizer.
36
+ mode (str): One of `min`, `max`. In `min` mode, lr will be reduced when the quantity
37
+ monitored has stopped decreasing; in `max` mode it will be reduced when the
38
+ quantity monitored has stopped increasing. Default: 'min'.
39
+ factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
40
+ Default: 0.1.
41
+ patience (int): Number of epochs with no improvement after which learning rate will be reduced.
42
+ Default: 10.
43
+ threshold (float): Threshold for measuring the new optimum, to only focus on significant changes.
44
+ Default: 1e-4.
45
+ threshold_mode (str): One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * ( 1 + threshold )
46
+ in 'max' mode or best * ( 1 - threshold ) in `min` mode. In `abs` mode, dynamic_threshold
47
+ = best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'.
48
+ cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced.
49
+ Default: 0.
50
+ min_lr (float or list): A scalar or a list of scalars. A lower bound on the learning rate of all param groups
51
+ or each group respectively. Default: 0.
52
+ eps (float): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps,
53
+ the update is ignored. Default: 1e-8.
54
+ verbose (bool): If True, prints a message to stdout for each update. Default: False.
55
+ """
56
+ if factor >= 1.0:
57
+ Logger.critical('Factor should be < 1.0.') # pragma: no cover
58
+ self.factor = factor
59
+
60
+ # Attach optimizer
61
+ if not isinstance(optimizer, Optimizer):
62
+ Logger.critical('{} is not an Optimizer'.format(
63
+ type(optimizer).__name__)) # pragma: no cover
64
+ self.optimizer = optimizer
65
+
66
+ if isinstance(min_lr, (list, tuple)):
67
+ if len(min_lr) != len(optimizer.param_groups):
68
+ Logger.critical("expected {} min_lrs, got {}".format(
69
+ len(optimizer.param_groups), len(min_lr))) # pragma: no cover
70
+ self.min_lrs = list(min_lr)
71
+ else:
72
+ self.min_lrs = [min_lr] * len(optimizer.param_groups)
73
+
74
+ self.patience = patience
75
+ self.verbose = verbose
76
+ self.cooldown = cooldown
77
+ self.cooldown_counter = 0
78
+ self.mode = mode
79
+ self.threshold = threshold
80
+ self.threshold_mode = threshold_mode
81
+ self.best = None
82
+ self.num_bad_epochs = None
83
+ self.mode_worse = None # the worse value for the chosen mode
84
+ self.eps = eps
85
+ self.last_epoch = 0
86
+
87
+ self._init_is_better()
88
+ self._reset()
89
+
90
+ def _reset(self) -> None:
91
+ """
92
+ Resets num_bad_epochs counter and cooldown counter.
93
+ """
94
+ self.best = self.mode_worse
95
+ self.cooldown_counter = 0
96
+ self.num_bad_epochs = 0
97
+
98
+ def step(self, metrics: float, epoch: Union[int, None] = None) -> None:
99
+ """
100
+ Update learning rate based on the given metrics.
101
+
102
+ Args:
103
+ metrics (float): The value of the metric to evaluate.
104
+ epoch (int, optional): The current epoch number. If not provided, it is incremented.
105
+ """
106
+ # Convert `metrics` to float, in case it's a zero-dim Tensor
107
+ current = float(metrics)
108
+ if epoch is None:
109
+ epoch = self.last_epoch + 1
110
+ self.last_epoch = epoch
111
+
112
+ # Check if the current metrics are better than the best
113
+ if self.is_better(current, self.best):
114
+ self.best = current
115
+ self.num_bad_epochs = 0
116
+ else:
117
+ self.num_bad_epochs += 1
118
+
119
+ # Handle cooldown period
120
+ if self.in_cooldown:
121
+ self.cooldown_counter -= 1
122
+ self.num_bad_epochs = 0 # Ignore any bad epochs in cooldown
123
+
124
+ # Reduce learning rate if the number of bad epochs exceeds patience
125
+ if self.num_bad_epochs > self.patience:
126
+ self._reduce_lr(epoch)
127
+ self.cooldown_counter = self.cooldown
128
+ self.num_bad_epochs = 0
129
+ self.best = self.mode_worse
130
+
131
+ self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
132
+
133
+ def _reduce_lr(self, epoch: int) -> None:
134
+ """
135
+ Reduce the learning rate for each parameter group.
136
+
137
+ Args:
138
+ epoch (int): The current epoch number.
139
+ """
140
+ for i, param_group in enumerate(self.optimizer.param_groups):
141
+ old_lr = float(param_group['lr'])
142
+ new_lr = max(old_lr * self.factor, self.min_lrs[i])
143
+ if old_lr - new_lr > self.eps:
144
+ param_group['lr'] = new_lr
145
+ if self.verbose:
146
+ epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
147
+ print('Epoch {}: reducing learning rate'
148
+ ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
149
+
150
+ @property
151
+ def in_cooldown(self) -> bool:
152
+ """
153
+ Check if the scheduler is in a cooldown period.
154
+
155
+ Returns:
156
+ bool: True if in cooldown period, False otherwise.
157
+ """
158
+ return self.cooldown_counter > 0
159
+
160
+ def is_better(self, a: float, best: Union[float, None]) -> bool:
161
+ """
162
+ Determine if the new value is better than the best value based on mode and threshold.
163
+
164
+ Args:
165
+ a (float): The new value to compare.
166
+ best (float): The best value to compare against.
167
+
168
+ Returns:
169
+ bool: True if the new value is better, False otherwise.
170
+ """
171
+ if best is None:
172
+ return True
173
+
174
+ if self.mode == 'min' and self.threshold_mode == 'rel':
175
+ rel_epsilon = 1. - self.threshold
176
+ return a < best * rel_epsilon
177
+ elif self.mode == 'min' and self.threshold_mode == 'abs':
178
+ return a < best - self.threshold
179
+ elif self.mode == 'max' and self.threshold_mode == 'rel':
180
+ rel_epsilon = self.threshold + 1.
181
+ return a > best * rel_epsilon
182
+ else: # mode == 'max' and threshold_mode == 'abs':
183
+ return a > best + self.threshold
184
+
185
+ def _init_is_better(self) -> None:
186
+ """
187
+ Initialize the comparison function for determining if a new value is better.
188
+
189
+ Raises:
190
+ ValueError: If an unknown mode or threshold mode is provided.
191
+ """
192
+ if self.mode not in {'min', 'max'}:
193
+ Logger.critical('mode ' + self.mode + ' is unknown!') # pragma: no cover
194
+ if self.threshold_mode not in {'rel', 'abs'}:
195
+ Logger.critical('threshold mode ' + self.threshold_mode + ' is unknown!') # pragma: no cover
196
+
197
+ if self.mode == 'min':
198
+ self.mode_worse = float('inf')
199
+ else: # mode == 'max':
200
+ self.mode_worse = float('-inf')
201
+
202
+ def state_dict(self) -> Dict[str, Any]:
203
+ """
204
+ Return the state of the scheduler as a dictionary.
205
+
206
+ Returns:
207
+ dict: The state of the scheduler.
208
+ """
209
+ return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
210
+
211
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
212
+ """
213
+ Load the scheduler state.
214
+
215
+ Args:
216
+ state_dict (dict): The state dictionary to load.
217
+ """
218
+ self.__dict__.update(state_dict)
219
+ self._init_is_better()