mct-nightly 2.1.0.20240724.437__py3-none-any.whl → 2.1.0.20240726.430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/RECORD +35 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +6 -1
- model_compression_toolkit/core/pytorch/utils.py +27 -0
- model_compression_toolkit/data_generation/common/data_generation.py +20 -18
- model_compression_toolkit/data_generation/common/data_generation_config.py +8 -11
- model_compression_toolkit/data_generation/common/enums.py +24 -12
- model_compression_toolkit/data_generation/common/image_pipeline.py +50 -12
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +0 -8
- model_compression_toolkit/data_generation/common/optimization_utils.py +7 -11
- model_compression_toolkit/data_generation/keras/constants.py +5 -2
- model_compression_toolkit/data_generation/keras/image_operations.py +189 -0
- model_compression_toolkit/data_generation/keras/image_pipeline.py +50 -104
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +28 -36
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +0 -13
- model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py +16 -6
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +219 -0
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +39 -13
- model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py +6 -98
- model_compression_toolkit/data_generation/keras/optimization_utils.py +15 -28
- model_compression_toolkit/data_generation/pytorch/constants.py +4 -1
- model_compression_toolkit/data_generation/pytorch/image_operations.py +105 -0
- model_compression_toolkit/data_generation/pytorch/image_pipeline.py +70 -78
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +0 -10
- model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py +17 -6
- model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +2 -2
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +219 -0
- model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py +55 -21
- model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py +15 -0
- model_compression_toolkit/data_generation/pytorch/optimization_utils.py +32 -54
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +57 -52
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/top_level.txt +0 -0
@@ -12,36 +12,45 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
import numpy as np
|
16
|
+
import torch
|
17
|
+
from typing import Type, Dict, Tuple, Union, List
|
16
18
|
|
17
19
|
from torch import Tensor
|
18
20
|
from torchvision.transforms import RandomCrop, RandomHorizontalFlip, CenterCrop
|
19
21
|
|
20
|
-
|
21
|
-
from model_compression_toolkit.data_generation.common.enums import ImagePipelineType, ImageNormalizationType
|
22
|
+
from model_compression_toolkit.data_generation.common.enums import ImagePipelineType
|
22
23
|
from model_compression_toolkit.data_generation.common.image_pipeline import BaseImagePipeline
|
24
|
+
from model_compression_toolkit.data_generation.pytorch.image_operations import Smoothing, create_valid_grid
|
23
25
|
|
24
26
|
|
25
27
|
class PytorchIdentityImagePipeline(BaseImagePipeline):
|
26
28
|
"""
|
27
29
|
An image pipeline implementation for PyTorch models that returns the input images as is (identity).
|
28
30
|
"""
|
29
|
-
def __init__(self,
|
31
|
+
def __init__(self,
|
32
|
+
output_image_size: Union[int, Tuple[int, int]],
|
33
|
+
extra_pixels: Union[int, Tuple[int, int]] = 0,
|
34
|
+
normalization: List[List[int]] = [[0, 0, 0], [1, 1, 1]],
|
35
|
+
image_clipping: bool = True,
|
36
|
+
):
|
30
37
|
"""
|
31
38
|
Initialize the PytorchIdentityImagePipeline.
|
32
39
|
|
33
40
|
Args:
|
34
|
-
output_image_size (int): The output image size.
|
35
|
-
extra_pixels (int): Extra pixels to add to the input image size (not used in identity pipeline).
|
41
|
+
output_image_size (Union[int, Tuple[int, int]]): The output image size.
|
42
|
+
extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size (not used in identity pipeline).
|
43
|
+
normalization (List[List[float]]): The image normalization values for processing images during optimization.
|
44
|
+
image_clipping (bool): Whether to clip images during optimization.
|
36
45
|
"""
|
37
|
-
super(PytorchIdentityImagePipeline, self).__init__(output_image_size)
|
46
|
+
super(PytorchIdentityImagePipeline, self).__init__(output_image_size, extra_pixels, image_clipping, normalization)
|
38
47
|
|
39
|
-
def get_image_input_size(self) -> int:
|
48
|
+
def get_image_input_size(self) -> Tuple[int, int]:
|
40
49
|
"""
|
41
50
|
Get the input size of the image.
|
42
51
|
|
43
52
|
Returns:
|
44
|
-
int: The input image size.
|
53
|
+
Tuple[int, int]: The input image size.
|
45
54
|
"""
|
46
55
|
return self.output_image_size
|
47
56
|
|
@@ -70,31 +79,43 @@ class PytorchIdentityImagePipeline(BaseImagePipeline):
|
|
70
79
|
return images
|
71
80
|
|
72
81
|
|
73
|
-
class
|
82
|
+
class PytorchSmoothAugmentationImagePipeline(BaseImagePipeline):
|
74
83
|
"""
|
75
|
-
An image pipeline implementation for PyTorch models that includes random cropping.
|
84
|
+
An image pipeline implementation for PyTorch models that includes random cropping and flipping.
|
76
85
|
"""
|
77
|
-
def __init__(self,
|
86
|
+
def __init__(self,
|
87
|
+
output_image_size: Union[int, Tuple[int, int]],
|
88
|
+
extra_pixels: Union[int, Tuple[int, int]] = 0,
|
89
|
+
normalization: List[List[int]] = [[0, 0, 0], [1, 1, 1]],
|
90
|
+
image_clipping: bool = True,
|
91
|
+
smoothing_filter_size: int = 3,
|
92
|
+
smoothing_filter_sigma: float = 1.25):
|
78
93
|
"""
|
79
94
|
Initialize the PytorchRandomCropFlipImagePipeline.
|
80
95
|
|
81
96
|
Args:
|
82
|
-
output_image_size (int): The output image size.
|
83
|
-
extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
|
84
|
-
|
85
|
-
|
86
|
-
|
97
|
+
output_image_size (Union[int, Tuple[int, int]]): The output image size.
|
98
|
+
extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
|
99
|
+
normalization (List[List[float]]): The image normalization values for processing images during optimization.
|
100
|
+
image_clipping (bool): Whether to clip images during optimization.
|
101
|
+
smoothing_filter_size (int): The size of the smoothing filter. Defaults to 3.
|
102
|
+
smoothing_filter_sigma (float): The standard deviation of the smoothing filter. Defaults to 1.25.
|
103
|
+
"""
|
104
|
+
super(PytorchSmoothAugmentationImagePipeline, self).__init__(output_image_size, extra_pixels, image_clipping, normalization)
|
105
|
+
self.smoothing = Smoothing(size=smoothing_filter_size, sigma=smoothing_filter_sigma)
|
87
106
|
self.random_crop = RandomCrop(self.output_image_size)
|
107
|
+
self.random_flip = RandomHorizontalFlip(0.5)
|
88
108
|
self.center_crop = CenterCrop(self.output_image_size)
|
109
|
+
self.valid_grid = create_valid_grid(means=self.normalization[0], stds=self.normalization[1])
|
89
110
|
|
90
|
-
def get_image_input_size(self) -> int:
|
111
|
+
def get_image_input_size(self) -> Tuple[int, int]:
|
91
112
|
"""
|
92
113
|
Get the input size of the image.
|
93
114
|
|
94
115
|
Returns:
|
95
|
-
int: The input image size.
|
116
|
+
Tuple[int, int]: The input image size.
|
96
117
|
"""
|
97
|
-
return self.output_image_size
|
118
|
+
return tuple([o + e for (o, e) in zip(self.output_image_size, self.extra_pixels)])
|
98
119
|
|
99
120
|
def image_input_manipulation(self, images: Tensor) -> Tensor:
|
100
121
|
"""
|
@@ -104,9 +125,14 @@ class PytorchRandomCropImagePipeline(BaseImagePipeline):
|
|
104
125
|
images (Tensor): The input images.
|
105
126
|
|
106
127
|
Returns:
|
107
|
-
Tensor: The manipulated images
|
128
|
+
Tensor: The manipulated images.
|
108
129
|
"""
|
109
|
-
|
130
|
+
new_images = self.random_flip(images)
|
131
|
+
new_images = self.smoothing(new_images)
|
132
|
+
new_images = self.random_crop(new_images)
|
133
|
+
if self.image_clipping:
|
134
|
+
new_images = self.clip_images(new_images, self.valid_grid)
|
135
|
+
return new_images
|
110
136
|
|
111
137
|
def image_output_finalize(self, images: Tensor) -> Tensor:
|
112
138
|
"""
|
@@ -118,71 +144,37 @@ class PytorchRandomCropImagePipeline(BaseImagePipeline):
|
|
118
144
|
Returns:
|
119
145
|
Tensor: The finalized images (center cropped).
|
120
146
|
"""
|
121
|
-
|
122
|
-
|
147
|
+
new_images = self.smoothing(images)
|
148
|
+
new_images = self.center_crop(new_images)
|
149
|
+
if self.image_clipping:
|
150
|
+
new_images = self.clip_images(new_images, self.valid_grid)
|
151
|
+
return new_images
|
123
152
|
|
124
|
-
|
125
|
-
|
126
|
-
An image pipeline implementation for PyTorch models that includes random cropping and flipping.
|
127
|
-
"""
|
128
|
-
def __init__(self, output_image_size: int, extra_pixels: int = 0):
|
153
|
+
@staticmethod
|
154
|
+
def clip_images(images: Tensor, valid_grid: Tensor, reflection: bool = False) -> Tensor:
|
129
155
|
"""
|
130
|
-
|
156
|
+
Clip the images based on a valid grid.
|
131
157
|
|
132
158
|
Args:
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
super(PytorchRandomCropFlipImagePipeline, self).__init__(output_image_size)
|
137
|
-
self.extra_pixels = extra_pixels
|
138
|
-
self.random_crop = RandomCrop(self.output_image_size)
|
139
|
-
self.random_flip = RandomHorizontalFlip(0.5)
|
140
|
-
self.center_crop = CenterCrop(self.output_image_size)
|
141
|
-
|
142
|
-
def get_image_input_size(self) -> int:
|
143
|
-
"""
|
144
|
-
Get the input size of the image.
|
159
|
+
images (Tensor): The images to be clipped.
|
160
|
+
valid_grid (Tensor): The valid grid for clipping.
|
161
|
+
reflection (bool): Whether to apply reflection during clipping. Defaults to False.
|
145
162
|
|
146
163
|
Returns:
|
147
|
-
|
148
|
-
"""
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
Returns:
|
159
|
-
Tensor: The manipulated images (randomly flipped and cropped).
|
160
|
-
"""
|
161
|
-
random_flipped_data = self.random_flip(images)
|
162
|
-
return self.random_crop(random_flipped_data)
|
163
|
-
|
164
|
-
def image_output_finalize(self, images: Tensor) -> Tensor:
|
165
|
-
"""
|
166
|
-
Finalize the output images with center cropping.
|
167
|
-
|
168
|
-
Args:
|
169
|
-
images (Tensor): The output images.
|
170
|
-
|
171
|
-
Returns:
|
172
|
-
Tensor: The finalized images (center cropped).
|
173
|
-
"""
|
174
|
-
return self.center_crop(images)
|
164
|
+
Tensor: The clipped images.
|
165
|
+
"""
|
166
|
+
with torch.no_grad():
|
167
|
+
for i_ch in range(valid_grid.shape[0]):
|
168
|
+
clamp = torch.clamp(images[:, i_ch, :, :], valid_grid[i_ch, :].min(), valid_grid[i_ch, :].max())
|
169
|
+
if reflection:
|
170
|
+
images[:, i_ch, :, :] = 2 * clamp - images[:, i_ch, :, :]
|
171
|
+
else:
|
172
|
+
images[:, i_ch, :, :] = clamp
|
173
|
+
return images
|
175
174
|
|
176
175
|
|
177
176
|
# Dictionary mapping ImagePipelineType to corresponding image pipeline classes
|
178
177
|
image_pipeline_dict: Dict[ImagePipelineType, Type[BaseImagePipeline]] = {
|
179
178
|
ImagePipelineType.IDENTITY: PytorchIdentityImagePipeline,
|
180
|
-
ImagePipelineType.
|
181
|
-
ImagePipelineType.RANDOM_CROP_FLIP: PytorchRandomCropFlipImagePipeline
|
182
|
-
}
|
183
|
-
|
184
|
-
# Dictionary mapping ImageNormalizationType to corresponding normalization values
|
185
|
-
image_normalization_dict: Dict[ImageNormalizationType, List[List[float]]] = {
|
186
|
-
ImageNormalizationType.TORCHVISION: [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
|
187
|
-
ImageNormalizationType.NO_NORMALIZATION: [[0, 0, 0], [1, 1, 1]]
|
179
|
+
ImagePipelineType.SMOOTHING_AND_AUGMENTATION: PytorchSmoothAugmentationImagePipeline
|
188
180
|
}
|
@@ -126,7 +126,6 @@ class PytorchActivationExtractor(ActivationExtractor):
|
|
126
126
|
self.layer_types_to_extract_inputs = tuple(layer_types_to_extract_inputs)
|
127
127
|
self.last_layer_types_to_extract_inputs = tuple(last_layer_types_to_extract_inputs)
|
128
128
|
self.num_layers = sum([1 if isinstance(layer, tuple(layer_types_to_extract_inputs)) else 0 for layer in model.modules()])
|
129
|
-
Logger.info(f'Number of layers = {self.num_layers}')
|
130
129
|
self.hooks = {} # Dictionary to store InputHook instances by layer name
|
131
130
|
self.last_linear_layers_hooks = {} # Dictionary to store InputHook instances by layer name
|
132
131
|
self.hook_handles = [] # List to store hook handles
|
@@ -206,15 +205,6 @@ class PytorchActivationExtractor(ActivationExtractor):
|
|
206
205
|
"""
|
207
206
|
return self.last_linear_layer_weights
|
208
207
|
|
209
|
-
def get_num_extractor_layers(self) -> int:
|
210
|
-
"""
|
211
|
-
Get the number of hooked layers in the model.
|
212
|
-
|
213
|
-
Returns:
|
214
|
-
int: Number of hooked layers in the model.
|
215
|
-
"""
|
216
|
-
return self.num_layers
|
217
|
-
|
218
208
|
def get_extractor_layer_names(self) -> List:
|
219
209
|
"""
|
220
210
|
Get a list of the hooked layer names.
|
@@ -18,16 +18,21 @@ import torch
|
|
18
18
|
|
19
19
|
from model_compression_toolkit.data_generation.common.enums import BNLayerWeightingType
|
20
20
|
from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import OriginalBNStatsHolder, \
|
21
|
-
ActivationExtractor
|
21
|
+
ActivationExtractor, PytorchActivationExtractor
|
22
22
|
|
23
23
|
|
24
|
-
def average_bn_layer_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder,
|
24
|
+
def average_bn_layer_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder,
|
25
|
+
activation_extractor: PytorchActivationExtractor,
|
26
|
+
i_iter: int,
|
27
|
+
n_iter: int) -> Dict[str, float]:
|
25
28
|
"""
|
26
29
|
Calculate average weighting for each batch normalization layer.
|
27
30
|
|
28
31
|
Args:
|
29
32
|
orig_bn_stats_holder (OriginalBNStatsHolder): Holder for original batch normalization statistics.
|
30
|
-
|
33
|
+
activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
|
34
|
+
i_iter (int): Current optimization iteration.
|
35
|
+
n_iter (int): Total number of optimization iterations.
|
31
36
|
|
32
37
|
Returns:
|
33
38
|
Dict[str, float]: A dictionary containing layer names as keys and average weightings as values.
|
@@ -35,19 +40,25 @@ def average_bn_layer_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder, *
|
|
35
40
|
num_bn_layers = orig_bn_stats_holder.get_num_bn_layers()
|
36
41
|
return {bn_layer_name: 1 / num_bn_layers for bn_layer_name in orig_bn_stats_holder.get_bn_layer_names()}
|
37
42
|
|
38
|
-
def first_bn_multiplier_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder,
|
43
|
+
def first_bn_multiplier_weighting_fn(orig_bn_stats_holder: OriginalBNStatsHolder,
|
44
|
+
activation_extractor: PytorchActivationExtractor,
|
45
|
+
i_iter: int,
|
46
|
+
n_iter: int) -> Dict[str, float]:
|
39
47
|
"""
|
40
48
|
Calculate layer weightings with a higher multiplier for the first batch normalization layer.
|
41
49
|
|
42
50
|
Args:
|
43
51
|
orig_bn_stats_holder (OriginalBNStatsHolder): Holder for original batch normalization statistics.
|
44
|
-
|
52
|
+
activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
|
53
|
+
i_iter (int): Current optimization iteration.
|
54
|
+
n_iter (int): Total number of optimization iterations.
|
45
55
|
|
46
56
|
Returns:
|
47
57
|
Dict[str, float]: A dictionary containing layer names as keys and weightings as values.
|
48
58
|
"""
|
49
59
|
layer_weighting_dict = {orig_bn_stats_holder.get_bn_layer_names()[0]: 10}
|
50
|
-
|
60
|
+
layer_weighting_dict.update({bn_layer_name: 1 for bn_layer_name in orig_bn_stats_holder.get_bn_layer_names()[1:]})
|
61
|
+
return layer_weighting_dict
|
51
62
|
|
52
63
|
|
53
64
|
# Dictionary of layer weighting functions
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py
CHANGED
@@ -104,7 +104,7 @@ def diverse_sample(size: Tuple[int, ...]) -> Tensor:
|
|
104
104
|
|
105
105
|
def default_data_init_fn(
|
106
106
|
n_images: int = 1000,
|
107
|
-
size:
|
107
|
+
size: Union[int, Tuple[int, int]] = (224, 224),
|
108
108
|
crop: int = 32,
|
109
109
|
sample_fn: Callable = diverse_sample,
|
110
110
|
batch_size: int = 50) -> Tuple[int, DataLoader]:
|
@@ -113,7 +113,7 @@ def default_data_init_fn(
|
|
113
113
|
|
114
114
|
Args:
|
115
115
|
n_images (int): The number of random samples.
|
116
|
-
size (Tuple[int, int]): The size of each sample.
|
116
|
+
size (Union[int, Tuple[int, int]]): The size of each sample.
|
117
117
|
crop (int): The crop size.
|
118
118
|
sample_fn (Callable): The function to generate a random sample.
|
119
119
|
batch_size (int): The batch size.
|
@@ -0,0 +1,219 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from torch.optim.optimizer import Optimizer
|
16
|
+
from torch import inf
|
17
|
+
from typing import Union, List, Dict, Any
|
18
|
+
|
19
|
+
from model_compression_toolkit.logger import Logger
|
20
|
+
|
21
|
+
|
22
|
+
class ReduceLROnPlateauWithReset:
|
23
|
+
"""
|
24
|
+
Reduce learning rate when a metric has stopped improving. This scheduler allows resetting
|
25
|
+
the learning rate to the initial value after a specified number of bad epochs.
|
26
|
+
"""
|
27
|
+
|
28
|
+
def __init__(self, optimizer: Optimizer, mode: str = 'min', factor: float = 0.1, patience: int = 10,
|
29
|
+
threshold: float = 1e-4, threshold_mode: str = 'rel', cooldown: int = 0,
|
30
|
+
min_lr: Union[float, List[float]] = 0, eps: float = 1e-8, verbose: bool = False):
|
31
|
+
"""
|
32
|
+
Initialize the ReduceLROnPlateauWithReset scheduler.
|
33
|
+
|
34
|
+
Args:
|
35
|
+
optimizer (Optimizer): Wrapped optimizer.
|
36
|
+
mode (str): One of `min`, `max`. In `min` mode, lr will be reduced when the quantity
|
37
|
+
monitored has stopped decreasing; in `max` mode it will be reduced when the
|
38
|
+
quantity monitored has stopped increasing. Default: 'min'.
|
39
|
+
factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
|
40
|
+
Default: 0.1.
|
41
|
+
patience (int): Number of epochs with no improvement after which learning rate will be reduced.
|
42
|
+
Default: 10.
|
43
|
+
threshold (float): Threshold for measuring the new optimum, to only focus on significant changes.
|
44
|
+
Default: 1e-4.
|
45
|
+
threshold_mode (str): One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * ( 1 + threshold )
|
46
|
+
in 'max' mode or best * ( 1 - threshold ) in `min` mode. In `abs` mode, dynamic_threshold
|
47
|
+
= best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'.
|
48
|
+
cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced.
|
49
|
+
Default: 0.
|
50
|
+
min_lr (float or list): A scalar or a list of scalars. A lower bound on the learning rate of all param groups
|
51
|
+
or each group respectively. Default: 0.
|
52
|
+
eps (float): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps,
|
53
|
+
the update is ignored. Default: 1e-8.
|
54
|
+
verbose (bool): If True, prints a message to stdout for each update. Default: False.
|
55
|
+
"""
|
56
|
+
if factor >= 1.0:
|
57
|
+
Logger.critical('Factor should be < 1.0.') # pragma: no cover
|
58
|
+
self.factor = factor
|
59
|
+
|
60
|
+
# Attach optimizer
|
61
|
+
if not isinstance(optimizer, Optimizer):
|
62
|
+
Logger.critical('{} is not an Optimizer'.format(
|
63
|
+
type(optimizer).__name__)) # pragma: no cover
|
64
|
+
self.optimizer = optimizer
|
65
|
+
|
66
|
+
if isinstance(min_lr, (list, tuple)):
|
67
|
+
if len(min_lr) != len(optimizer.param_groups):
|
68
|
+
Logger.critical("expected {} min_lrs, got {}".format(
|
69
|
+
len(optimizer.param_groups), len(min_lr))) # pragma: no cover
|
70
|
+
self.min_lrs = list(min_lr)
|
71
|
+
else:
|
72
|
+
self.min_lrs = [min_lr] * len(optimizer.param_groups)
|
73
|
+
|
74
|
+
self.patience = patience
|
75
|
+
self.verbose = verbose
|
76
|
+
self.cooldown = cooldown
|
77
|
+
self.cooldown_counter = 0
|
78
|
+
self.mode = mode
|
79
|
+
self.threshold = threshold
|
80
|
+
self.threshold_mode = threshold_mode
|
81
|
+
self.best = None
|
82
|
+
self.num_bad_epochs = None
|
83
|
+
self.mode_worse = None # the worse value for the chosen mode
|
84
|
+
self.eps = eps
|
85
|
+
self.last_epoch = 0
|
86
|
+
|
87
|
+
self._init_is_better()
|
88
|
+
self._reset()
|
89
|
+
|
90
|
+
def _reset(self) -> None:
|
91
|
+
"""
|
92
|
+
Resets num_bad_epochs counter and cooldown counter.
|
93
|
+
"""
|
94
|
+
self.best = self.mode_worse
|
95
|
+
self.cooldown_counter = 0
|
96
|
+
self.num_bad_epochs = 0
|
97
|
+
|
98
|
+
def step(self, metrics: float, epoch: Union[int, None] = None) -> None:
|
99
|
+
"""
|
100
|
+
Update learning rate based on the given metrics.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
metrics (float): The value of the metric to evaluate.
|
104
|
+
epoch (int, optional): The current epoch number. If not provided, it is incremented.
|
105
|
+
"""
|
106
|
+
# Convert `metrics` to float, in case it's a zero-dim Tensor
|
107
|
+
current = float(metrics)
|
108
|
+
if epoch is None:
|
109
|
+
epoch = self.last_epoch + 1
|
110
|
+
self.last_epoch = epoch
|
111
|
+
|
112
|
+
# Check if the current metrics are better than the best
|
113
|
+
if self.is_better(current, self.best):
|
114
|
+
self.best = current
|
115
|
+
self.num_bad_epochs = 0
|
116
|
+
else:
|
117
|
+
self.num_bad_epochs += 1
|
118
|
+
|
119
|
+
# Handle cooldown period
|
120
|
+
if self.in_cooldown:
|
121
|
+
self.cooldown_counter -= 1
|
122
|
+
self.num_bad_epochs = 0 # Ignore any bad epochs in cooldown
|
123
|
+
|
124
|
+
# Reduce learning rate if the number of bad epochs exceeds patience
|
125
|
+
if self.num_bad_epochs > self.patience:
|
126
|
+
self._reduce_lr(epoch)
|
127
|
+
self.cooldown_counter = self.cooldown
|
128
|
+
self.num_bad_epochs = 0
|
129
|
+
self.best = self.mode_worse
|
130
|
+
|
131
|
+
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
|
132
|
+
|
133
|
+
def _reduce_lr(self, epoch: int) -> None:
|
134
|
+
"""
|
135
|
+
Reduce the learning rate for each parameter group.
|
136
|
+
|
137
|
+
Args:
|
138
|
+
epoch (int): The current epoch number.
|
139
|
+
"""
|
140
|
+
for i, param_group in enumerate(self.optimizer.param_groups):
|
141
|
+
old_lr = float(param_group['lr'])
|
142
|
+
new_lr = max(old_lr * self.factor, self.min_lrs[i])
|
143
|
+
if old_lr - new_lr > self.eps:
|
144
|
+
param_group['lr'] = new_lr
|
145
|
+
if self.verbose:
|
146
|
+
epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
|
147
|
+
print('Epoch {}: reducing learning rate'
|
148
|
+
' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
|
149
|
+
|
150
|
+
@property
|
151
|
+
def in_cooldown(self) -> bool:
|
152
|
+
"""
|
153
|
+
Check if the scheduler is in a cooldown period.
|
154
|
+
|
155
|
+
Returns:
|
156
|
+
bool: True if in cooldown period, False otherwise.
|
157
|
+
"""
|
158
|
+
return self.cooldown_counter > 0
|
159
|
+
|
160
|
+
def is_better(self, a: float, best: Union[float, None]) -> bool:
|
161
|
+
"""
|
162
|
+
Determine if the new value is better than the best value based on mode and threshold.
|
163
|
+
|
164
|
+
Args:
|
165
|
+
a (float): The new value to compare.
|
166
|
+
best (float): The best value to compare against.
|
167
|
+
|
168
|
+
Returns:
|
169
|
+
bool: True if the new value is better, False otherwise.
|
170
|
+
"""
|
171
|
+
if best is None:
|
172
|
+
return True
|
173
|
+
|
174
|
+
if self.mode == 'min' and self.threshold_mode == 'rel':
|
175
|
+
rel_epsilon = 1. - self.threshold
|
176
|
+
return a < best * rel_epsilon
|
177
|
+
elif self.mode == 'min' and self.threshold_mode == 'abs':
|
178
|
+
return a < best - self.threshold
|
179
|
+
elif self.mode == 'max' and self.threshold_mode == 'rel':
|
180
|
+
rel_epsilon = self.threshold + 1.
|
181
|
+
return a > best * rel_epsilon
|
182
|
+
else: # mode == 'max' and threshold_mode == 'abs':
|
183
|
+
return a > best + self.threshold
|
184
|
+
|
185
|
+
def _init_is_better(self) -> None:
|
186
|
+
"""
|
187
|
+
Initialize the comparison function for determining if a new value is better.
|
188
|
+
|
189
|
+
Raises:
|
190
|
+
ValueError: If an unknown mode or threshold mode is provided.
|
191
|
+
"""
|
192
|
+
if self.mode not in {'min', 'max'}:
|
193
|
+
Logger.critical('mode ' + self.mode + ' is unknown!') # pragma: no cover
|
194
|
+
if self.threshold_mode not in {'rel', 'abs'}:
|
195
|
+
Logger.critical('threshold mode ' + self.threshold_mode + ' is unknown!') # pragma: no cover
|
196
|
+
|
197
|
+
if self.mode == 'min':
|
198
|
+
self.mode_worse = float('inf')
|
199
|
+
else: # mode == 'max':
|
200
|
+
self.mode_worse = float('-inf')
|
201
|
+
|
202
|
+
def state_dict(self) -> Dict[str, Any]:
|
203
|
+
"""
|
204
|
+
Return the state of the scheduler as a dictionary.
|
205
|
+
|
206
|
+
Returns:
|
207
|
+
dict: The state of the scheduler.
|
208
|
+
"""
|
209
|
+
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
210
|
+
|
211
|
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
212
|
+
"""
|
213
|
+
Load the scheduler state.
|
214
|
+
|
215
|
+
Args:
|
216
|
+
state_dict (dict): The state dictionary to load.
|
217
|
+
"""
|
218
|
+
self.__dict__.update(state_dict)
|
219
|
+
self._init_is_better()
|