reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of reflectorch might be problematic.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/ml/basic_trainer.py
CHANGED
@@ -1,292 +1,292 @@

(The hunk removes all 292 lines of the module and re-adds them verbatim; the unchanged content is shown once below.)

from typing import Optional, Tuple, Iterable, Any, Union, Type
from collections import defaultdict

from tqdm import tqdm as standard_tqdm
from tqdm.notebook import tqdm as notebook_tqdm
import numpy as np

import torch
from torch.nn import Module

from reflectorch.ml.loggers import Logger, Loggers

from .utils import is_divisor

__all__ = [
    'Trainer',
    'TrainerCallback',
    'DataLoader',
    'PeriodicTrainerCallback',
]


class Trainer(object):
    """Trainer class

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm_max (int, optional): maximum norm for gradient clipping if it is not ``None``. Defaults to None.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): PyTorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'

    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[int] = None,
                 logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: dict = None,
                 **kwargs
                 ):

        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max

        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        self.lrs = []
        self.losses = defaultdict(list)

        self.logger = _init_logger(logger)
        self.callback_params = {}

        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        pass

    def log(self, name: str, data):
        """log data"""
        self.logger.log(name, data)

    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
              disable_tqdm: bool = False,
              use_notebook_tqdm: bool = False,
              update_tqdm_freq: int = 1,
              grad_accumulation_steps: int = 1,
              ):
        """starts the training process

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback']): the trainer callbacks. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter Notebook. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 1.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.
        """

        if isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)

        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)

        callbacks.start_training(self)

        for batch_num in pbar:
            self.model.train()

            self.optim.zero_grad()
            total_loss, avr_loss_dict = 0, defaultdict(list)

            for _ in range(grad_accumulation_steps):

                batch_data = self.get_batch_by_idx(batch_num)
                loss_dict = self.get_loss_dict(batch_data)
                loss = loss_dict['loss'] / grad_accumulation_steps
                total_loss += loss.item()
                _update_loss_dict(avr_loss_dict, loss_dict)

                if not torch.isfinite(loss).item():
                    raise ValueError('Loss is not finite!')

                loss.backward()

            if self.clip_grad_norm_max is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)

            self.optim.step()

            avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
            self._update_losses(avr_loss_dict, total_loss)

            if not disable_tqdm:
                self._update_tqdm(pbar, batch_num, update_tqdm_freq)

            break_epoch = callbacks.end_batch(self, batch_num)

            if break_epoch:
                break

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

            postfix = {}
            for key in self.losses.keys():
                if key != self.TOTAL_LOSS_KEY:
                    last_value = self.losses[key][-1]
                    postfix[key] = f'{last_value:.4f}'

            postfix['lr'] = f'{self.lr():.2e}'

            pbar.set_postfix(postfix)

    def get_batch_by_idx(self, batch_num: int) -> Any:
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def configure_optimizer(self, optim_cls, lr: float, **kwargs) -> torch.optim.Optimizer:
        """configure the optimizer based on the optimizer class, the learning rate and the optimizer keyword arguments

        Args:
            optim_cls: the class of the optimizer
            lr (float): the learning rate

        Returns:
            torch.optim.Optimizer:
        """
        optim = optim_cls(self.model.parameters(), lr, **kwargs)
        return optim

    def lr(self, param_group: int = 0) -> float:
        """get the learning rate"""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """set the learning rate"""
        self.optim.param_groups[param_group]['lr'] = lr


class TrainerCallback(object):
    """Base class for trainer callbacks
    """
    def start_training(self, trainer: Trainer) -> None:
        """add functionality at the start of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_training(self, trainer: Trainer) -> None:
        """add functionality at the end of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]:
        """
        pass

    def __repr__(self):
        return f'{self.__class__.__name__}()'


class DataLoader(TrainerCallback):
    pass


class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically after a number of iterations

    Args:
        step (int, optional): Number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): the last training iteration for which the action is performed. Defaults to -1.
    """
    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]:
        """
        if (
                is_divisor(batch_num, self.step) and
                (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        pass


class _StackedTrainerCallbacks(TrainerCallback):
    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        break_epoch = False
        for c in self.callbacks:
            break_epoch += bool(c.end_batch(trainer, batch_num))
        return break_epoch

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'


def _init_logger(logger: Union[Logger, Tuple[Logger, ...], Loggers] = None):
    if not logger:
        return Logger()
    if isinstance(logger, Logger):
        return logger
    return Loggers(*logger)


def _update_loss_dict(loss_dict: dict, new_values: dict):
    for k, v in new_values.items():
        loss_dict[k].append(v.item())
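The callback hooks can be exercised through PeriodicTrainerCallback as well. Below is a hedged sketch of a periodic learning-rate decay built only on the lr()/set_lr() accessors shown above (PeriodicLrDecay is a hypothetical illustration, not one of the callbacks shipped in reflectorch/ml/callbacks.py or reflectorch/ml/schedulers.py):

from reflectorch.ml.basic_trainer import PeriodicTrainerCallback, Trainer


class PeriodicLrDecay(PeriodicTrainerCallback):
    """Hypothetical callback: every `step` iterations, multiply the learning rate
    by `gamma` and request early stopping once it falls below `min_lr`."""
    def __init__(self, step: int = 500, gamma: float = 0.5, min_lr: float = 1e-6):
        super().__init__(step=step)
        self.gamma = gamma
        self.min_lr = min_lr

    def _end_batch(self, trainer: Trainer, batch_num: int):
        trainer.set_lr(trainer.lr() * self.gamma)
        # a truthy return value propagates through _StackedTrainerCallbacks.end_batch()
        # and makes Trainer.train() break out of its loop
        return trainer.lr() < self.min_lr


# trainer.train(num_batches=10_000, callbacks=PeriodicLrDecay(step=500))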