reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/ml/schedulers.py
CHANGED
@@ -1,356 +1,356 @@

import math
from torch.optim import lr_scheduler

import numpy as np

from reflectorch.ml.basic_trainer import Trainer, TrainerCallback, PeriodicTrainerCallback

__all__ = [
    'ScheduleBatchSize',
    'ScheduleLR',
    'StepLR',
    'CyclicLR',
    'LogCyclicLR',
    'ReduceLROnPlateau',
    'OneCycleLR',
    'CosineAnnealingWithWarmup',
]


class ScheduleBatchSize(PeriodicTrainerCallback):
    """Batch size scheduler

    Args:
        step (int): number of iterations after which the batch size is modified.
        gamma (int, optional): quantity which is added to or multiplied with the current batch size. Defaults to 2.
        last_epoch (int, optional): the last training iteration for which the batch size is modified. Defaults to -1.
        mode (str, optional): ``'add'`` for addition or ``'multiply'`` for multiplication. Defaults to 'add'.
    """
    def __init__(self, step: int, gamma: int = 2, last_epoch: int = -1, mode: str = 'add'):
        super().__init__(step, last_epoch)

        assert mode in ('add', 'multiply')

        self.gamma = gamma
        self.mode = mode

    def _end_batch(self, trainer: Trainer, batch_num: int) -> None:
        if self.mode == 'add':
            trainer.batch_size += self.gamma
        elif self.mode == 'multiply':
            trainer.batch_size *= self.gamma


class ScheduleLR(TrainerCallback):
    """Base class for learning rate schedulers

    Args:
        lr_scheduler_cls: class of the learning rate scheduler
    """

    def __init__(self, lr_scheduler_cls, **kwargs):
        self.lr_scheduler_cls = lr_scheduler_cls
        self.kwargs = kwargs
        self.lr_scheduler = None

    def start_training(self, trainer: Trainer) -> None:
        """initializes a learning rate scheduler based on its class and keyword arguments at the start of training

        Args:
            trainer (Trainer): the trainer object
        """
        self.lr_scheduler = self.lr_scheduler_cls(trainer.optim, **self.kwargs)
        trainer.callback_params['lrs'] = []

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """modifies the learning rate at the end of each iteration

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): index of the current iteration
        """
        self.lr_scheduler.step()

    def simulate_and_plot(self, total_steps: int, initial_lr: float, log_scale: bool = False):
        import torch
        import matplotlib.pyplot as plt

        dummy_optim = torch.optim.Adam([torch.zeros(1)], lr=initial_lr)
        scheduler = self.lr_scheduler_cls(dummy_optim, **self.kwargs)

        lrs = []
        for step in range(total_steps):
            lrs.append(dummy_optim.param_groups[0]['lr'])
            scheduler.step()

        plt.figure(figsize=(10, 6))
        plt.plot(lrs, label='Learning Rate')
        plt.xlabel('Steps')
        plt.ylabel('Learning Rate')
        plt.title('Learning Rate Schedule')

        if log_scale:
            plt.yscale('log')

        plt.grid(True, which="both", linestyle='--', linewidth=0.5)
        plt.legend()
        plt.show()


class StepLR(ScheduleLR):
    """Learning rate scheduler which decays the learning rate of each parameter group by gamma every ``step_size`` epochs.

    Args:
        step_size (int): Period of learning rate decay
        gamma (float): Multiplicative factor of learning rate decay
        last_epoch (int, optional): The index of last iteration. Defaults to -1.
    """
    def __init__(self, step_size: int, gamma: float, last_epoch: int = -1, **kwargs):
        super().__init__(lr_scheduler.StepLR, step_size=step_size, gamma=gamma, last_epoch=last_epoch, **kwargs)

    def start_training(self, trainer: Trainer) -> None:
        trainer.optim.param_groups[0]['initial_lr'] = trainer.lr()
        super().start_training(trainer)


class CyclicLR(ScheduleLR):
    """Cyclic learning rate scheduler

    Args:
        base_lr (float): Initial learning rate which is the lower boundary in the cycle
        max_lr (float): Upper learning rate boundary in the cycle
        step_size_up (int, optional): Number of training iterations in the increasing half of a cycle. Defaults to 2000.
        cycle_momentum (bool, optional): If True, momentum is cycled inversely to learning rate between ``base_momentum`` and ``max_momentum``. Defaults to False.
        gamma (float, optional): Constant in ``'exp_range'`` mode scaling function: gamma^(cycle iterations). Defaults to 1.
        mode (str, optional): One of: ``'triangular'`` (a basic triangular cycle without amplitude scaling),
            ``'triangular2'`` (a basic triangular cycle that scales initial amplitude by half each cycle), ``'exp_range'``
            (a cycle that scales initial amplitude by gamma^iterations at each cycle iteration). Defaults to 'triangular'.
    """
    def __init__(self, base_lr, max_lr, step_size_up: int = 2000,
                 cycle_momentum: bool = False, gamma: float = 1., mode: str = 'triangular',
                 **kwargs):
        super().__init__(
            lr_scheduler.CyclicLR,
            base_lr=base_lr,
            max_lr=max_lr,
            step_size_up=step_size_up,
            cycle_momentum=cycle_momentum,
            gamma=gamma,
            mode=mode,
            **kwargs
        )


class LogCyclicLR(TrainerCallback):
    """Cyclic learning rate scheduler on a logarithmic scale

    Args:
        base_lr (float): Lower learning rate boundary in the cycle
        max_lr (float): Upper learning rate boundary in the cycle
        period (int, optional): Number of training iterations in the cycle. Defaults to 2000.
        gamma (float, optional): Constant for scaling the amplitude as ``gamma`` ^ ``iterations``. Defaults to 1.
        start_period (int, optional): Number of starting iterations with the default learning rate.
        log (bool, optional): If ``True``, the cycle is in the logarithmic domain.
        param_groups (tuple, optional): Parameter groups of the optimizer.
    """
    def __init__(self,
                 base_lr,
                 max_lr,
                 period: int = 2000,
                 gamma: float = None,
                 log: bool = True,
                 param_groups: tuple = (0,),
                 start_period: int = 25,
                 ):
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.period = period
        self.gamma = gamma
        self.param_groups = param_groups
        self.log = log
        self.start_period = start_period
        self._axis = None
        self._period = None

    def get_lr(self, batch_num: int):
        return self._get_lr(batch_num)

    def _get_lr(self, batch_num):
        num_period, t = batch_num // self.period, batch_num % self.period

        if self._period != num_period:
            self._period = num_period
            if self.gamma and (num_period >= self.start_period):
                amp = (self.max_lr - self.base_lr) * (self.gamma ** (num_period - self.start_period))
                max_lr = self.base_lr + amp
            else:
                max_lr = self.max_lr

            if self.log:
                self._axis = np.logspace(np.log10(self.base_lr), np.log10(max_lr), self.period // 2)
            else:
                self._axis = np.linspace(self.base_lr, max_lr, self.period // 2)
        if t < self.period // 2:
            lr = self._axis[t]
        else:
            lr = self._axis[self.period - t - 1]
        return lr

    def end_batch(self, trainer: Trainer, batch_num: int):
        lr = self.get_lr(batch_num)
        for param_group in self.param_groups:
            trainer.set_lr(lr, param_group)

    def simulate_and_plot(self, total_steps: int, log_scale: bool = True):
        import matplotlib.pyplot as plt
        lrs = [self.get_lr(batch_num) for batch_num in range(total_steps)]

        plt.figure(figsize=(10, 6))
        plt.plot(lrs, label='Learning Rate')
        plt.xlabel('Steps')
        plt.ylabel('Learning Rate')
        plt.title('Learning Rate Schedule')
        if log_scale:
            plt.yscale('log')
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
        plt.legend()
        plt.show()


class ReduceLROnPlateau(TrainerCallback):
    """Learning rate scheduler which reduces the learning rate when the loss stops decreasing

    Args:
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.5.
        patience (int, optional): The number of allowed iterations with no improvement after which the learning rate will be reduced. Defaults to 500.
        average (int, optional): Size of the window over which the average loss is computed. Defaults to 50.
        loss_key (str, optional): Defaults to 'total_loss'.
        param_groups (tuple, optional): Defaults to (0,).
    """
    def __init__(
            self,
            gamma: float = 0.5,
            patience: int = 500,
            average: int = 50,
            loss_key: str = 'total_loss',
            param_groups: tuple = (0,),
    ):
        self.patience = patience
        self.average = average
        self.gamma = gamma
        self.loss_key = loss_key
        self.param_groups = param_groups

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        loss = trainer.losses[self.loss_key]

        if len(loss) < self.patience:
            return

        if np.mean(loss[-self.patience:-(self.patience - self.average)]) <= np.mean(loss[-self.average:]):
            for param_group in self.param_groups:
                trainer.set_lr(trainer.lr(param_group) * self.gamma, param_group)


class OneCycleLR(ScheduleLR):
    """One-cycle learning rate scheduler (https://arxiv.org/abs/1708.07120)

    Args:
        max_lr (float): Upper learning rate boundary in the cycle
        total_steps (int): The total number of steps in the cycle
        pct_start (float, optional): The percentage of the cycle (in number of steps) spent increasing the learning rate. Defaults to 0.3.
        div_factor (float, optional): Determines the initial learning rate via initial_lr = ``max_lr`` / ``div_factor``. Defaults to 25.
        final_div_factor (float, optional): Determines the minimum learning rate via min_lr = ``initial_lr`` / ``final_div_factor``. Defaults to 1e4.
        three_phase (bool, optional): If ``True``, use a third phase of the schedule to annihilate the learning rate according to ``final_div_factor`` instead of modifying the second phase. Defaults to True.
    """
    def __init__(self, max_lr: float, total_steps: int, pct_start: float = 0.3, div_factor: float = 25.,
                 final_div_factor: float = 1e4, three_phase: bool = True, **kwargs):
        super().__init__(
            lr_scheduler.OneCycleLR,
            max_lr=max_lr,
            total_steps=total_steps,
            pct_start=pct_start,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            three_phase=three_phase,
            **kwargs
        )


class CosineAnnealingWithWarmup(TrainerCallback):
    """
    Cosine annealing scheduler with a warm-up stage.

    Args:
        max_lr (float): The maximum learning rate after the warm-up phase.
        min_lr (float): The minimum learning rate after the warm-up phase.
        warmup_iters (int): The number of iterations for the warm-up phase.
        total_iters (int): The total number of iterations for the scheduler (including warm-up).
    """
    def __init__(self, max_lr=None, min_lr=1.0e-6, warmup_iters=100, total_iters=100000):
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_iters = warmup_iters
        self.total_iters = total_iters

    def get_lr(self, step):
        """
        Compute the learning rate for a given iteration.

        Args:
            step (int): The current iteration.

        Returns:
            float: The learning rate for the current iteration.
        """
        if step < self.warmup_iters:
            # Warm-up stage: linear increase from 0 to max_lr
            return self.max_lr * step / self.warmup_iters
        elif step < self.total_iters:
            # Cosine annealing stage
            t = (step - self.warmup_iters) / (self.total_iters - self.warmup_iters)
            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * t))
        else:
            # Beyond total iterations: return min_lr
            return self.min_lr

    def start_training(self, trainer: Trainer) -> None:
        self.max_lr = trainer.lr()

    def end_batch(self, trainer: Trainer, batch_num: int):
        """
        Updates the learning rate at the end of each batch.

        Args:
            trainer (Trainer): The trainer object.
            batch_num (int): The current batch number.
        """
        lr = self.get_lr(batch_num)
        trainer.set_lr(lr)

    def simulate_and_plot(self, total_steps: int = None, log_scale: bool = False):
        """
        Simulates and plots the learning rate evolution.

        Args:
            total_steps (int, optional): Total number of steps to simulate. If None, uses self.total_iters.
        """
        total_steps = total_steps or self.total_iters
        lrs = [self.get_lr(step) for step in range(total_steps)]

        import matplotlib.pyplot as plt
        plt.figure(figsize=(10, 6))
        plt.plot(lrs, label='Learning Rate')
        plt.xlabel('Steps')
        plt.ylabel('Learning Rate')
        plt.title('Learning Rate Scheduler')
        if log_scale:
            plt.yscale('log')
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
        plt.legend()
        plt.show()
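As a quick orientation, the short sketch below probes the CosineAnnealingWithWarmup callback listed above directly through its get_lr method, without constructing a Trainer. It assumes reflectorch is installed; the hyperparameter values are illustrative assumptions, not package defaults beyond what the signature shows.

# Minimal sketch: inspect the warm-up and cosine phases of the schedule shown above.
from reflectorch.ml.schedulers import CosineAnnealingWithWarmup

# Illustrative values; in a real run, max_lr is picked up from the trainer in start_training().
sched = CosineAnnealingWithWarmup(max_lr=1e-3, min_lr=1e-6, warmup_iters=100, total_iters=1000)

for step in (0, 50, 100, 550, 1000):
    print(step, sched.get_lr(step))
# 0    -> 0.0      (linear warm-up starts at zero)
# 50   -> 5.0e-4   (halfway through the warm-up)
# 100  -> 1.0e-3   (warm-up done; cosine phase starts at max_lr)
# 550  -> ~5.0e-4  (midpoint of the cosine decay, roughly (max_lr + min_lr) / 2)
# 1000 -> 1.0e-6   (at and after total_iters the schedule stays at min_lr)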