reflectorch-1.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
reflectorch/ml/basic_trainer.py
@@ -0,0 +1,292 @@
from typing import Optional, Tuple, Iterable, Any, Union, Type
from collections import defaultdict

from tqdm import tqdm as standard_tqdm
from tqdm.notebook import tqdm as notebook_tqdm
import numpy as np

import torch
from torch.nn import Module

from reflectorch.ml.loggers import Logger, Loggers

from .utils import is_divisor

__all__ = [
    'Trainer',
    'TrainerCallback',
    'DataLoader',
    'PeriodicTrainerCallback',
]


class Trainer(object):
    """Trainer class

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm_max (int, optional): maximum norm for gradient clipping if it is not ``None``. Defaults to None.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'

    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[int] = None,
                 logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: dict = None,
                 **kwargs
                 ):

        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max

        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        self.lrs = []
        self.losses = defaultdict(list)

        self.logger = _init_logger(logger)
        self.callback_params = {}

        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        pass

    def log(self, name: str, data):
        """log data"""
        self.logger.log(name, data)

    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
              disable_tqdm: bool = False,
              use_notebook_tqdm: bool = False,
              update_tqdm_freq: int = 1,
              grad_accumulation_steps: int = 1,
              ):
        """starts the training process

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback']): the trainer callbacks. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter Notebook. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 1.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.
        """

        if isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)

        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)

        callbacks.start_training(self)

        for batch_num in pbar:
            self.model.train()

            self.optim.zero_grad()
            total_loss, avr_loss_dict = 0, defaultdict(list)

            for _ in range(grad_accumulation_steps):

                batch_data = self.get_batch_by_idx(batch_num)
                loss_dict = self.get_loss_dict(batch_data)
                loss = loss_dict['loss'] / grad_accumulation_steps
                total_loss += loss.item()
                _update_loss_dict(avr_loss_dict, loss_dict)

                if not torch.isfinite(loss).item():
                    raise ValueError('Loss is not finite!')

                loss.backward()

            if self.clip_grad_norm_max is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)

            self.optim.step()

            avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
            self._update_losses(avr_loss_dict, total_loss)

            if not disable_tqdm:
                self._update_tqdm(pbar, batch_num, update_tqdm_freq)

            break_epoch = callbacks.end_batch(self, batch_num)

            if break_epoch:
                break

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

            postfix = {}
            for key in self.losses.keys():
                if key != self.TOTAL_LOSS_KEY:
                    last_value = self.losses[key][-1]
                    postfix[key] = f'{last_value:.4f}'

            postfix['lr'] = f'{self.lr():.2e}'

            pbar.set_postfix(postfix)

    def get_batch_by_idx(self, batch_num: int) -> Any:
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def configure_optimizer(self, optim_cls, lr: float, **kwargs) -> torch.optim.Optimizer:
        """configure the optimizer based on the optimizer class, the learning rate and the optimizer keyword arguments

        Args:
            optim_cls: the class of the optimizer
            lr (float): the learning rate

        Returns:
            torch.optim.Optimizer:
        """
        optim = optim_cls(self.model.parameters(), lr, **kwargs)
        return optim

    def lr(self, param_group: int = 0) -> float:
        """get the learning rate"""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """set the learning rate"""
        self.optim.param_groups[param_group]['lr'] = lr


class TrainerCallback(object):
    """Base class for trainer callbacks
    """
    def start_training(self, trainer: Trainer) -> None:
        """add functionality at the start of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_training(self, trainer: Trainer) -> None:
        """add functionality at the end of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]:
        """
        pass

    def __repr__(self):
        return f'{self.__class__.__name__}()'


class DataLoader(TrainerCallback):
    pass


class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically after a number of iterations

    Args:
        step (int, optional): number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): the last training iteration for which the action is performed. Defaults to -1.
    """
    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]:
        """
        if (
                is_divisor(batch_num, self.step) and
                (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        pass


class _StackedTrainerCallbacks(TrainerCallback):
    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        break_epoch = False
        for c in self.callbacks:
            break_epoch += bool(c.end_batch(trainer, batch_num))
        return break_epoch

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'


def _init_logger(logger: Union[Logger, Tuple[Logger, ...], Loggers] = None):
    if not logger:
        return Logger()
    if isinstance(logger, Logger):
        return logger
    return Loggers(*logger)


def _update_loss_dict(loss_dict: dict, new_values: dict):
    for k, v in new_values.items():
        loss_dict[k].append(v.item())
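Usage note (not part of the package diff): the two hooks that raise ``NotImplementedError`` above, ``get_batch_by_idx`` and ``get_loss_dict``, are the extension points of ``Trainer``; ``train()`` calls them each iteration and backpropagates the ``'loss'`` entry of the returned dict. The sketch below is illustrative only: the ``ToyTrainer`` class, toy model, and random data are assumptions, not part of reflectorch, and passing a bare ``DataLoader()`` bypasses the package's actual reflectivity data loaders.

import torch
from torch import nn

from reflectorch.ml.basic_trainer import Trainer, DataLoader

class ToyTrainer(Trainer):
    # Hypothetical subclass: only the two abstract hooks are overridden.
    def get_batch_by_idx(self, batch_num: int):
        # draw a fresh random regression batch each iteration (toy data)
        x = torch.randn(self.batch_size, 8)
        y = x.sum(dim=1, keepdim=True)
        return x, y

    def get_loss_dict(self, batch_data) -> dict:
        x, y = batch_data
        loss = nn.functional.mse_loss(self.model(x), y)
        return {'loss': loss}  # the 'loss' key is what Trainer.train() backpropagates

trainer = ToyTrainer(model=nn.Linear(8, 1), loader=DataLoader(), lr=1e-3, batch_size=16)
trainer.train(num_batches=100, disable_tqdm=True)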
reflectorch/ml/callbacks.py
@@ -0,0 +1,81 @@
import torch

import numpy as np

from reflectorch.ml.basic_trainer import (
    TrainerCallback,
    Trainer,
)
from reflectorch.ml.utils import is_divisor

__all__ = [
    'SaveBestModel',
    'LogLosses',
]


class SaveBestModel(TrainerCallback):
    """Callback for periodically saving the best model weights

    Args:
        path (str): path for saving the model weights
        freq (int, optional): frequency in iterations at which the current average loss is evaluated. Defaults to 50.
        average (int, optional): number of recent iterations over which the average loss is computed. Defaults to 10.
    """

    def __init__(self, path: str, freq: int = 50, average: int = 10):
        self.path = path
        self.average = average
        self._best_loss = np.inf
        self.freq = freq

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """checks whether the current average loss has improved since the previous save; if so, the model is saved

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        if is_divisor(batch_num, self.freq):

            loss = np.mean(trainer.losses['total_loss'][-self.average:])

            if loss < self._best_loss:
                self._best_loss = loss
                self.save(trainer, batch_num)

    def save(self, trainer: Trainer, batch_num: int):
        """saves a dictionary containing the network weights, the learning rates, the losses and the current \
        best loss with its corresponding iteration to the disk

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        prev_save = trainer.callback_params.pop('saved_iteration', 0)
        trainer.callback_params['saved_iteration'] = batch_num
        save_dict = {
            'model': trainer.model.state_dict(),
            'lrs': trainer.lrs,
            'losses': trainer.losses,
            'prev_save': prev_save,
            'batch_num': batch_num,
            'best_loss': self._best_loss
        }
        torch.save(save_dict, self.path)


class LogLosses(TrainerCallback):
    """Callback for logging the training losses"""
    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """log the losses at the current iteration

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch
        """
        for loss_name, loss_values in trainer.losses.items():
            try:
                trainer.log(f'train/{loss_name}', loss_values[-1])
            except IndexError:
                continue
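A hedged illustration of how these callbacks might be passed to ``Trainer.train()``, reusing the hypothetical ``ToyTrainer`` sketched after basic_trainer.py above: ``SaveBestModel`` writes a checkpoint whenever the running average loss improves, and ``LogLosses`` forwards the latest loss values to the trainer's configured logger. The checkpoint path and all hyperparameter values are arbitrary examples.

from torch import nn

from reflectorch.ml.basic_trainer import DataLoader
from reflectorch.ml.callbacks import SaveBestModel, LogLosses
from reflectorch.ml.loggers import PrintLogger

# ToyTrainer is the hypothetical subclass from the earlier sketch, not part of reflectorch
trainer = ToyTrainer(model=nn.Linear(8, 1), loader=DataLoader(), lr=1e-3,
                     batch_size=16, logger=PrintLogger())
trainer.train(
    num_batches=500,
    callbacks=(
        SaveBestModel('best_model.pt', freq=50, average=10),  # 'best_model.pt' is an arbitrary example path
        LogLosses(),  # forwards the latest value of each loss series to the logger via trainer.log()
    ),
    disable_tqdm=True,
)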
reflectorch/ml/dataloaders.py
@@ -0,0 +1,27 @@
from torch import Tensor

from reflectorch.data_generation import BasicDataset
from reflectorch.data_generation.reflectivity import kinematical_approximation
from reflectorch.data_generation.priors import BasicParams
from reflectorch.ml.basic_trainer import DataLoader


__all__ = [
    "ReflectivityDataLoader",
    "MultilayerDataLoader",
]


class ReflectivityDataLoader(BasicDataset, DataLoader):
    """Dataloader for reflectivity data, combining functionality from the ``BasicDataset`` (basic dataset class for reflectivity) and the ``DataLoader`` (which inherits from ``TrainerCallback``) classes"""
    pass


class MultilayerDataLoader(ReflectivityDataLoader):
    """Dataloader for reflectivity curves simulated using the kinematical approximation"""
    def _sample_from_prior(self, batch_size: int):
        return self.prior_sampler.optimized_sample(batch_size)

    def _calc_curves(self, q_values: Tensor, params: BasicParams):
        return kinematical_approximation(q_values, params.thicknesses, params.roughnesses, params.slds)
reflectorch/ml/loggers.py
@@ -0,0 +1,56 @@
from torch.utils.tensorboard import SummaryWriter

__all__ = [
    'Logger',
    'Loggers',
    'PrintLogger',
    'TensorBoardLogger',
]


class Logger(object):
    """Base class defining a common interface for logging"""
    def log(self, name: str, data):
        pass

    def __setitem__(self, key, value):
        """Enable dictionary-style setting to log data."""
        self.log(key, value)


class Loggers(Logger):
    """Class for using multiple loggers"""
    def __init__(self, *loggers):
        self._loggers = tuple(loggers)

    def log(self, name: str, data):
        for logger in self._loggers:
            logger.log(name, data)


class PrintLogger(Logger):
    """Logger which prints to the console"""
    def log(self, name: str, data):
        print(name, ': ', data)


class TensorBoardLogger(Logger):
    def __init__(self, log_dir: str):
        """
        Args:
            log_dir (str): directory where TensorBoard logs will be written
        """
        super().__init__()
        self.writer = SummaryWriter(log_dir=log_dir)
        self.step = 1

    def log(self, name: str, data):
        """Log scalar data to TensorBoard

        Args:
            name (str): name/tag for the data
            data: scalar value to log
        """
        if hasattr(data, 'item'):
            data = data.item()
        self.writer.add_scalar(name, data, self.step)
        self.step += 1
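A short composition sketch (not from the package): ``_init_logger`` in basic_trainer.py wraps a tuple of loggers into ``Loggers``, so several sinks can be active at once; the same composition can also be done by hand, as below. The log directory is an arbitrary example, and ``TensorBoardLogger`` requires the tensorboard package to be installed for ``SummaryWriter`` to work.

from reflectorch.ml.loggers import Loggers, PrintLogger, TensorBoardLogger

# 'runs/demo' is an arbitrary example directory
logger = Loggers(PrintLogger(), TensorBoardLogger(log_dir='runs/demo'))

logger.log('train/total_loss', 0.42)   # printed to the console and written as a TensorBoard scalar
logger['train/total_loss'] = 0.40      # Logger.__setitem__ forwards to log()

Note that, as written, ``TensorBoardLogger`` advances a single global step counter on every ``log`` call, regardless of the tag being logged.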