reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,292 +1,292 @@
1
- from typing import Optional, Tuple, Iterable, Any, Union, Type
2
- from collections import defaultdict
3
-
4
- from tqdm import tqdm as standard_tqdm
5
- from tqdm.notebook import tqdm as notebook_tqdm
6
- import numpy as np
7
-
8
- import torch
9
- from torch.nn import Module
10
-
11
- from reflectorch.ml.loggers import Logger, Loggers
12
-
13
- from .utils import is_divisor
14
-
15
# Names exported by `from reflectorch.ml.basic_trainer import *`.
__all__ = ['Trainer', 'TrainerCallback', 'DataLoader', 'PeriodicTrainerCallback']
21
-
22
-
23
class Trainer(object):
    """Base class for iteration-based training of a PyTorch model.

    Subclasses provide the data and the loss by overriding :meth:`get_batch_by_idx`
    and :meth:`get_loss_dict`; :meth:`init` can be overridden for extra setup.

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader; :meth:`train` also registers it as a callback
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm_max (float, optional): maximum norm for gradient clipping if it is not ``None``. Defaults to None.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
        **kwargs: additional values, each stored as an attribute of the trainer before :meth:`init` runs.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'

    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[float] = None,
                 logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: dict = None,
                 **kwargs
                 ):
        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max

        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        # learning rate after every optimizer step (appended in _update_losses)
        self.lrs = []
        # history of every key returned by get_loss_dict, plus TOTAL_LOSS_KEY
        self.losses = defaultdict(list)

        self.logger = _init_logger(logger)
        self.callback_params = {}

        # extra keyword arguments become plain attributes, visible to init() and subclasses
        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        """Hook for subclasses; called at the end of ``__init__``. Does nothing by default."""
        pass

    def log(self, name: str, data):
        """Forward ``data`` to the logger under the key ``name``."""
        self.logger.log(name, data)

    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
              disable_tqdm: bool = False,
              use_notebook_tqdm: bool = False,
              update_tqdm_freq: int = 1,
              grad_accumulation_steps: int = 1,
              ):
        """Run the training loop.

        Each iteration performs ``grad_accumulation_steps`` forward/backward passes on the
        batch returned by ``get_batch_by_idx(batch_num)`` (each loss is divided by
        ``grad_accumulation_steps``), optionally clips the gradient norm, makes one optimizer
        step, records the losses and notifies the callbacks. Training stops early when any
        callback's ``end_batch`` returns a truthy value.

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback']): the trainer callbacks; the trainer's
                ``loader`` is appended to them automatically. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter Notebook. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 1.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.

        Raises:
            ValueError: if ``grad_accumulation_steps`` is smaller than 1, or if a loss is not finite.
        """
        if grad_accumulation_steps < 1:
            # with zero accumulation steps the optimizer would step without any fresh gradients
            raise ValueError(f'grad_accumulation_steps must be at least 1, got {grad_accumulation_steps}')

        if isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)

        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)

        try:
            callbacks.start_training(self)

            for batch_num in pbar:
                # set train mode every iteration (presumably callbacks may switch the model to eval mode)
                self.model.train()

                self.optim.zero_grad()
                total_loss, avr_loss_dict = 0, defaultdict(list)

                for _ in range(grad_accumulation_steps):
                    batch_data = self.get_batch_by_idx(batch_num)
                    loss_dict = self.get_loss_dict(batch_data)
                    # scale so that the accumulated gradient is the mean over the accumulation steps
                    loss = loss_dict['loss'] / grad_accumulation_steps
                    total_loss += loss.item()
                    _update_loss_dict(avr_loss_dict, loss_dict)

                    if not torch.isfinite(loss).item():
                        raise ValueError('Loss is not finite!')

                    loss.backward()

                if self.clip_grad_norm_max is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)

                self.optim.step()

                avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
                self._update_losses(avr_loss_dict, total_loss)

                if not disable_tqdm:
                    self._update_tqdm(pbar, batch_num, update_tqdm_freq)

                break_epoch = callbacks.end_batch(self, batch_num)

                if break_epoch:
                    break
        finally:
            # release the progress bar also when the loop is left early via `break` or an exception
            pbar.close()

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        """Refresh the progress bar when ``is_divisor(batch_num, update_tqdm_freq)`` holds.

        The description shows the mean of the last 10 total losses; the postfix shows the most
        recent value of every other recorded loss key and the current learning rate.
        """
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

            postfix = {}
            for key in self.losses.keys():
                if key != self.TOTAL_LOSS_KEY:
                    last_value = self.losses[key][-1]
                    postfix[key] = f'{last_value:.4f}'

            postfix['lr'] = f'{self.lr():.2e}'

            pbar.set_postfix(postfix)

    def get_batch_by_idx(self, batch_num: int) -> Any:
        """Return the training data for iteration ``batch_num``. Must be overridden by subclasses."""
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        """Compute the losses for ``batch_data``. Must be overridden by subclasses.

        The returned dict must contain a ``'loss'`` tensor, which is backpropagated; every
        entry is recorded in ``self.losses``.
        """
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        """Append the averaged per-key losses, the total loss and the current learning rate to the history."""
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def configure_optimizer(self, optim_cls, lr: float, **kwargs) -> torch.optim.Optimizer:
        """configure the optimizer based on the optimizer class, the learning rate and the optimizer keyword arguments

        Args:
            optim_cls: the class of the optimizer
            lr (float): the learning rate
            **kwargs: further keyword arguments passed to ``optim_cls``

        Returns:
            torch.optim.Optimizer: the optimizer over ``self.model.parameters()``
        """
        optim = optim_cls(self.model.parameters(), lr, **kwargs)
        return optim

    def lr(self, param_group: int = 0) -> float:
        """Return the current learning rate of the given optimizer parameter group."""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """Set the learning rate of the given optimizer parameter group."""
        self.optim.param_groups[param_group]['lr'] = lr
187
-
188
-
189
class TrainerCallback(object):
    """Base class for trainer callbacks.

    Subclasses override any of the hooks below; the default implementations do nothing.
    """

    def start_training(self, trainer: Trainer) -> None:
        """Hook called once, before the first training iteration.

        Args:
            trainer (Trainer): the trainer object
        """

    def end_training(self, trainer: Trainer) -> None:
        """Hook called once, after the training loop has finished.

        Args:
            trainer (Trainer): the trainer object
        """

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Hook called after every training iteration / batch.

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: a truthy value asks the trainer to stop training; the base
            implementation returns ``None``.
        """
        return None

    def __repr__(self):
        return f'{self.__class__.__name__}()'
222
-
223
-
224
class DataLoader(TrainerCallback):
    """Base class for data loaders used by ``Trainer``.

    A data loader is itself a ``TrainerCallback``: ``Trainer.train`` appends the trainer's
    loader to the callbacks, so it receives the same training hooks.
    """
226
-
227
-
228
class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically after a number of iterations.

    Subclasses implement :meth:`_end_batch`, which :meth:`end_batch` invokes only for the
    iterations selected by ``step`` and ``last_epoch``.

    Args:
        step (int, optional): Number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): iteration index at which the periodic action stops. The action is only
            performed for iterations with ``batch_num < last_epoch`` (exclusive bound); ``-1`` means no limit.
            Defaults to -1.
    """

    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Run :meth:`_end_batch` if ``batch_num`` is selected by ``step`` and ``last_epoch``.

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: the result of :meth:`_end_batch` when it was called, otherwise ``None``.
        """
        if (
                is_divisor(batch_num, self.step) and
                (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)
        return None

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Periodic action; override in subclasses. Returning a truthy value requests a stop."""
        pass
257
-
258
-
259
class _StackedTrainerCallbacks(TrainerCallback):
    """Composite callback that forwards every hook to a sequence of callbacks, in order."""

    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> bool:
        """Call ``end_batch`` on every callback and report whether any of them requested a stop.

        All callbacks are always invoked (no short-circuit), so each one sees every iteration
        even when an earlier callback asked to stop.

        Returns:
            bool: ``True`` if at least one callback returned a truthy value.
        """
        # build the full list first so that every callback runs before any() inspects the results
        stop_requests = [bool(c.end_batch(trainer, batch_num)) for c in self.callbacks]
        return any(stop_requests)

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'
280
-
281
-
282
def _init_logger(logger: Union[Logger, Tuple[Logger, ...], Loggers] = None):
    """Normalize the ``logger`` argument of ``Trainer`` into a single logger object.

    Args:
        logger: ``None`` or any other falsy value (e.g. an empty tuple) gives a default ``Logger()``;
            a ``Logger`` or ``Loggers`` instance is returned unchanged; any other iterable of
            loggers is wrapped in ``Loggers``.
    """
    if not logger:
        return Logger()
    # a ready-made Loggers must not be re-wrapped via Loggers(*logger), which requires it to be iterable
    if isinstance(logger, (Logger, Loggers)):
        return logger
    return Loggers(*logger)
288
-
289
-
290
def _update_loss_dict(loss_dict: dict, new_values: dict):
    """Append the scalar value of every entry of ``new_values`` to the list stored under the same key.

    Args:
        loss_dict (dict): mapping from key to list of recorded values; missing keys must be
            auto-created (e.g. a ``defaultdict(list)``).
        new_values (dict): new values; tensors and numpy scalars are converted with ``.item()``,
            plain Python numbers with ``float()``.
    """
    for key, value in new_values.items():
        scalar = value.item() if hasattr(value, 'item') else float(value)
        loss_dict[key].append(scalar)
1
+ from typing import Optional, Tuple, Iterable, Any, Union, Type
2
+ from collections import defaultdict
3
+
4
+ from tqdm import tqdm as standard_tqdm
5
+ from tqdm.notebook import tqdm as notebook_tqdm
6
+ import numpy as np
7
+
8
+ import torch
9
+ from torch.nn import Module
10
+
11
+ from reflectorch.ml.loggers import Logger, Loggers
12
+
13
+ from .utils import is_divisor
14
+
15
# Names exported by `from reflectorch.ml.basic_trainer import *`.
__all__ = ['Trainer', 'TrainerCallback', 'DataLoader', 'PeriodicTrainerCallback']
21
+
22
+
23
class Trainer(object):
    """Base class for iteration-based training of a PyTorch model.

    Subclasses provide the data and the loss by overriding :meth:`get_batch_by_idx`
    and :meth:`get_loss_dict`; :meth:`init` can be overridden for extra setup.

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader; :meth:`train` also registers it as a callback
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm_max (float, optional): maximum norm for gradient clipping if it is not ``None``. Defaults to None.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
        **kwargs: additional values, each stored as an attribute of the trainer before :meth:`init` runs.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'

    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[float] = None,
                 logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: dict = None,
                 **kwargs
                 ):
        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max

        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        # learning rate after every optimizer step (appended in _update_losses)
        self.lrs = []
        # history of every key returned by get_loss_dict, plus TOTAL_LOSS_KEY
        self.losses = defaultdict(list)

        self.logger = _init_logger(logger)
        self.callback_params = {}

        # extra keyword arguments become plain attributes, visible to init() and subclasses
        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        """Hook for subclasses; called at the end of ``__init__``. Does nothing by default."""
        pass

    def log(self, name: str, data):
        """Forward ``data`` to the logger under the key ``name``."""
        self.logger.log(name, data)

    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
              disable_tqdm: bool = False,
              use_notebook_tqdm: bool = False,
              update_tqdm_freq: int = 1,
              grad_accumulation_steps: int = 1,
              ):
        """Run the training loop.

        Each iteration performs ``grad_accumulation_steps`` forward/backward passes on the
        batch returned by ``get_batch_by_idx(batch_num)`` (each loss is divided by
        ``grad_accumulation_steps``), optionally clips the gradient norm, makes one optimizer
        step, records the losses and notifies the callbacks. Training stops early when any
        callback's ``end_batch`` returns a truthy value.

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback']): the trainer callbacks; the trainer's
                ``loader`` is appended to them automatically. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter Notebook. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 1.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.

        Raises:
            ValueError: if ``grad_accumulation_steps`` is smaller than 1, or if a loss is not finite.
        """
        if grad_accumulation_steps < 1:
            # with zero accumulation steps the optimizer would step without any fresh gradients
            raise ValueError(f'grad_accumulation_steps must be at least 1, got {grad_accumulation_steps}')

        if isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)

        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)

        try:
            callbacks.start_training(self)

            for batch_num in pbar:
                # set train mode every iteration (presumably callbacks may switch the model to eval mode)
                self.model.train()

                self.optim.zero_grad()
                total_loss, avr_loss_dict = 0, defaultdict(list)

                for _ in range(grad_accumulation_steps):
                    batch_data = self.get_batch_by_idx(batch_num)
                    loss_dict = self.get_loss_dict(batch_data)
                    # scale so that the accumulated gradient is the mean over the accumulation steps
                    loss = loss_dict['loss'] / grad_accumulation_steps
                    total_loss += loss.item()
                    _update_loss_dict(avr_loss_dict, loss_dict)

                    if not torch.isfinite(loss).item():
                        raise ValueError('Loss is not finite!')

                    loss.backward()

                if self.clip_grad_norm_max is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)

                self.optim.step()

                avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
                self._update_losses(avr_loss_dict, total_loss)

                if not disable_tqdm:
                    self._update_tqdm(pbar, batch_num, update_tqdm_freq)

                break_epoch = callbacks.end_batch(self, batch_num)

                if break_epoch:
                    break
        finally:
            # release the progress bar also when the loop is left early via `break` or an exception
            pbar.close()

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        """Refresh the progress bar when ``is_divisor(batch_num, update_tqdm_freq)`` holds.

        The description shows the mean of the last 10 total losses; the postfix shows the most
        recent value of every other recorded loss key and the current learning rate.
        """
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

            postfix = {}
            for key in self.losses.keys():
                if key != self.TOTAL_LOSS_KEY:
                    last_value = self.losses[key][-1]
                    postfix[key] = f'{last_value:.4f}'

            postfix['lr'] = f'{self.lr():.2e}'

            pbar.set_postfix(postfix)

    def get_batch_by_idx(self, batch_num: int) -> Any:
        """Return the training data for iteration ``batch_num``. Must be overridden by subclasses."""
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        """Compute the losses for ``batch_data``. Must be overridden by subclasses.

        The returned dict must contain a ``'loss'`` tensor, which is backpropagated; every
        entry is recorded in ``self.losses``.
        """
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        """Append the averaged per-key losses, the total loss and the current learning rate to the history."""
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def configure_optimizer(self, optim_cls, lr: float, **kwargs) -> torch.optim.Optimizer:
        """configure the optimizer based on the optimizer class, the learning rate and the optimizer keyword arguments

        Args:
            optim_cls: the class of the optimizer
            lr (float): the learning rate
            **kwargs: further keyword arguments passed to ``optim_cls``

        Returns:
            torch.optim.Optimizer: the optimizer over ``self.model.parameters()``
        """
        optim = optim_cls(self.model.parameters(), lr, **kwargs)
        return optim

    def lr(self, param_group: int = 0) -> float:
        """Return the current learning rate of the given optimizer parameter group."""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """Set the learning rate of the given optimizer parameter group."""
        self.optim.param_groups[param_group]['lr'] = lr
187
+
188
+
189
class TrainerCallback(object):
    """Base class for trainer callbacks.

    Subclasses override any of the hooks below; the default implementations do nothing.
    """

    def start_training(self, trainer: Trainer) -> None:
        """Hook called once, before the first training iteration.

        Args:
            trainer (Trainer): the trainer object
        """

    def end_training(self, trainer: Trainer) -> None:
        """Hook called once, after the training loop has finished.

        Args:
            trainer (Trainer): the trainer object
        """

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Hook called after every training iteration / batch.

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: a truthy value asks the trainer to stop training; the base
            implementation returns ``None``.
        """
        return None

    def __repr__(self):
        return f'{self.__class__.__name__}()'
222
+
223
+
224
class DataLoader(TrainerCallback):
    """Base class for data loaders used by ``Trainer``.

    A data loader is itself a ``TrainerCallback``: ``Trainer.train`` appends the trainer's
    loader to the callbacks, so it receives the same training hooks.
    """
226
+
227
+
228
class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically after a number of iterations.

    Subclasses implement :meth:`_end_batch`, which :meth:`end_batch` invokes only for the
    iterations selected by ``step`` and ``last_epoch``.

    Args:
        step (int, optional): Number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): iteration index at which the periodic action stops. The action is only
            performed for iterations with ``batch_num < last_epoch`` (exclusive bound); ``-1`` means no limit.
            Defaults to -1.
    """

    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Run :meth:`_end_batch` if ``batch_num`` is selected by ``step`` and ``last_epoch``.

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: the result of :meth:`_end_batch` when it was called, otherwise ``None``.
        """
        if (
                is_divisor(batch_num, self.step) and
                (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)
        return None

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Periodic action; override in subclasses. Returning a truthy value requests a stop."""
        pass
257
+
258
+
259
class _StackedTrainerCallbacks(TrainerCallback):
    """Composite callback that forwards every hook to a sequence of callbacks, in order."""

    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> bool:
        """Call ``end_batch`` on every callback and report whether any of them requested a stop.

        All callbacks are always invoked (no short-circuit), so each one sees every iteration
        even when an earlier callback asked to stop.

        Returns:
            bool: ``True`` if at least one callback returned a truthy value.
        """
        # build the full list first so that every callback runs before any() inspects the results
        stop_requests = [bool(c.end_batch(trainer, batch_num)) for c in self.callbacks]
        return any(stop_requests)

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'
280
+
281
+
282
def _init_logger(logger: Union[Logger, Tuple[Logger, ...], Loggers] = None):
    """Normalize the ``logger`` argument of ``Trainer`` into a single logger object.

    Args:
        logger: ``None`` or any other falsy value (e.g. an empty tuple) gives a default ``Logger()``;
            a ``Logger`` or ``Loggers`` instance is returned unchanged; any other iterable of
            loggers is wrapped in ``Loggers``.
    """
    if not logger:
        return Logger()
    # a ready-made Loggers must not be re-wrapped via Loggers(*logger), which requires it to be iterable
    if isinstance(logger, (Logger, Loggers)):
        return logger
    return Loggers(*logger)
288
+
289
+
290
def _update_loss_dict(loss_dict: dict, new_values: dict):
    """Append the scalar value of every entry of ``new_values`` to the list stored under the same key.

    Args:
        loss_dict (dict): mapping from key to list of recorded values; missing keys must be
            auto-created (e.g. a ``defaultdict(list)``).
        new_values (dict): new values; tensors and numpy scalars are converted with ``.item()``,
            plain Python numbers with ``float()``.
    """
    for key, value in new_values.items():
        scalar = value.item() if hasattr(value, 'item') else float(value)
        loss_dict[key].append(scalar)