reflectorch-1.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,292 @@
1
+ from typing import Optional, Tuple, Iterable, Any, Union, Type
2
+ from collections import defaultdict
3
+
4
+ from tqdm import tqdm as standard_tqdm
5
+ from tqdm.notebook import tqdm as notebook_tqdm
6
+ import numpy as np
7
+
8
+ import torch
9
+ from torch.nn import Module
10
+
11
+ from reflectorch.ml.loggers import Logger, Loggers
12
+
13
+ from .utils import is_divisor
14
+
15
+ __all__ = [
16
+ 'Trainer',
17
+ 'TrainerCallback',
18
+ 'DataLoader',
19
+ 'PeriodicTrainerCallback',
20
+ ]
21
+
22
+
23
class Trainer(object):
    """Base trainer running a gradient-descent loop over batches supplied by subclasses.

    Subclasses implement ``get_batch_by_idx`` and ``get_loss_dict``. ``init`` may be overridden
    for extra setup; it is called at the end of ``__init__``.

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader; it is appended to the callbacks of every ``train`` call
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm_max (float, optional): maximum norm for gradient clipping if it is not ``None``. Defaults to None.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None (a no-op ``Logger``).
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): additional optimizer keyword arguments. Defaults to None.
        **kwargs: extra attributes set on the trainer (via ``setattr``) before ``init`` is called.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'

    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[float] = None,
                 logger: Union[Logger, Tuple[Logger, ...], Loggers, None] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: Optional[dict] = None,
                 **kwargs
                 ):

        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max

        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        # per-iteration history: learning rate, and losses keyed by loss name
        self.lrs = []
        self.losses = defaultdict(list)

        self.logger = _init_logger(logger)
        # free-form state that callbacks can read and write
        self.callback_params = {}

        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        """Hook for subclass-specific setup, called at the end of ``__init__``. No-op by default."""
        pass

    def log(self, name: str, data):
        """Forward ``data`` under ``name`` to the configured logger."""
        self.logger.log(name, data)

    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback', None] = (),
              disable_tqdm: bool = False,
              use_notebook_tqdm: bool = False,
              update_tqdm_freq: int = 1,
              grad_accumulation_steps: int = 1,
              ):
        """starts the training process

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback'], optional): the trainer callbacks.
                ``None`` is treated like no callbacks. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter Notebook. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 1.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.

        Raises:
            ValueError: if ``grad_accumulation_steps`` is smaller than 1, or if a loss is not finite.
        """
        if grad_accumulation_steps < 1:
            # with zero steps the loop below would run an optimizer step without any gradients
            raise ValueError(
                f'grad_accumulation_steps must be a positive integer, got {grad_accumulation_steps}'
            )

        if callbacks is None:
            callbacks = ()
        elif isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)

        # the loader always receives the callback hooks, after the user-supplied callbacks
        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)

        callbacks.start_training(self)

        try:
            for batch_num in pbar:
                self.model.train()

                self.optim.zero_grad()
                total_loss, avr_loss_dict = 0, defaultdict(list)

                for _ in range(grad_accumulation_steps):

                    batch_data = self.get_batch_by_idx(batch_num)
                    loss_dict = self.get_loss_dict(batch_data)
                    # scaling makes the accumulated gradient equal to the gradient of the mean loss
                    loss = loss_dict['loss'] / grad_accumulation_steps

                    if not torch.isfinite(loss).item():
                        raise ValueError('Loss is not finite!')

                    total_loss += loss.item()
                    _update_loss_dict(avr_loss_dict, loss_dict)

                    loss.backward()

                if self.clip_grad_norm_max is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)

                self.optim.step()

                # mean of every recorded loss over the accumulation steps
                avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
                self._update_losses(avr_loss_dict, total_loss)

                if not disable_tqdm:
                    self._update_tqdm(pbar, batch_num, update_tqdm_freq)

                # a truthy result from any callback stops training after this iteration
                break_epoch = callbacks.end_batch(self, batch_num)

                if break_epoch:
                    break
        finally:
            # release the progress bar also when training stops early or raises
            pbar.close()

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        """Refresh the progress bar when ``is_divisor(batch_num, update_tqdm_freq)`` holds.

        The description shows the mean of the last 10 total losses; the postfix shows the latest
        value of every other recorded loss and the current learning rate.
        """
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

            postfix = {}
            for key in self.losses.keys():
                if key != self.TOTAL_LOSS_KEY:
                    last_value = self.losses[key][-1]
                    postfix[key] = f'{last_value:.4f}'

            postfix['lr'] = f'{self.lr():.2e}'

            pbar.set_postfix(postfix)

    def get_batch_by_idx(self, batch_num: int) -> Any:
        """Return the training data for iteration ``batch_num``. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        """Compute the losses for ``batch_data``. Must be implemented by subclasses.

        The returned dict must contain a ``'loss'`` entry (a scalar tensor that is back-propagated);
        every entry is recorded in ``self.losses``.
        """
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        """Append the averaged losses, the total loss and the current learning rate to the history."""
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def configure_optimizer(self, optim_cls, lr: float, **kwargs) -> torch.optim.Optimizer:
        """configure the optimizer based on the optimizer class, the learning rate and the optimizer keyword arguments

        Args:
            optim_cls: the class of the optimizer
            lr (float): the learning rate
            **kwargs: further keyword arguments passed to ``optim_cls``

        Returns:
            torch.optim.Optimizer: optimizer over ``self.model.parameters()``
        """
        optim = optim_cls(self.model.parameters(), lr, **kwargs)
        return optim

    def lr(self, param_group: int = 0) -> float:
        """get the learning rate of the given optimizer parameter group"""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """set the learning rate of the given optimizer parameter group"""
        self.optim.param_groups[param_group]['lr'] = lr
187
+
188
+
189
class TrainerCallback(object):
    """Base class for trainer callbacks.

    Every hook is a no-op here; subclasses override only the hooks they need.
    """

    def start_training(self, trainer: Trainer) -> None:
        """Hook invoked once by ``Trainer.train`` before the first iteration.

        Args:
            trainer (Trainer): the trainer running the training loop
        """
        return None

    def end_training(self, trainer: Trainer) -> None:
        """Hook invoked by ``Trainer.train`` after the training loop has finished or was stopped.

        Args:
            trainer (Trainer): the trainer running the training loop
        """
        return None

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Hook invoked after every training iteration / batch.

        Args:
            trainer (Trainer): the trainer running the training loop
            batch_num (int): index of the iteration that has just finished

        Returns:
            Union[bool, None]: a truthy value asks the trainer to stop after this iteration;
            the base implementation returns ``None`` (keep training).
        """
        return None

    def __repr__(self):
        return '{}()'.format(type(self).__name__)
222
+
223
+
224
class DataLoader(TrainerCallback):
    """Base class for data loaders used by ``Trainer``.

    A data loader is itself a ``TrainerCallback``: ``Trainer.train`` appends the trainer's loader
    to the callbacks, so it receives the same start/end hooks as any other callback.
    """
226
+
227
+
228
class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically after a number of iterations

    The action (``_end_batch``) runs at iteration ``batch_num`` when ``is_divisor(batch_num, step)``
    holds and, unless ``last_epoch`` is -1, only while ``batch_num < last_epoch``.

    Args:
        step (int, optional): Number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): iteration index (exclusive) up to which the action is performed;
            ``-1`` disables the limit. Defaults to -1.
    """
    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """dispatch to ``_end_batch`` on the scheduled iterations

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: the result of ``_end_batch`` on scheduled iterations, otherwise ``None``
        """
        if (
                is_divisor(batch_num, self.step) and
                (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """Periodic action; override in subclasses. A truthy return value asks the trainer to stop."""
        pass
257
+
258
+
259
class _StackedTrainerCallbacks(TrainerCallback):
    """Composite callback forwarding every hook to a fixed sequence of callbacks, in order."""

    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> bool:
        """Call ``end_batch`` on every callback and report whether any of them asked to stop.

        All callbacks are always invoked (no short-circuit), so every callback still runs in the
        iteration in which another one requests to stop.

        Returns:
            bool: ``True`` if at least one callback returned a truthy value.
        """
        break_epoch = False
        for c in self.callbacks:
            if c.end_batch(trainer, batch_num):
                break_epoch = True
        return break_epoch

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'
280
+
281
+
282
def _init_logger(logger: Union[Logger, Tuple[Logger, ...], Loggers] = None):
    """Normalize the ``logger`` argument of ``Trainer`` into a single ``Logger``.

    A falsy value (``None``, an empty tuple) yields a no-op base ``Logger``; a ``Logger`` instance
    (including ``Loggers``) is returned unchanged; any other iterable of loggers is wrapped in ``Loggers``.
    """
    if logger:
        return logger if isinstance(logger, Logger) else Loggers(*logger)
    return Logger()
288
+
289
+
290
+ def _update_loss_dict(loss_dict: dict, new_values: dict):
291
+ for k, v in new_values.items():
292
+ loss_dict[k].append(v.item())
@@ -0,0 +1,81 @@
1
+ import torch
2
+
3
+ import numpy as np
4
+
5
+ from reflectorch.ml.basic_trainer import (
6
+ TrainerCallback,
7
+ Trainer,
8
+ )
9
+ from reflectorch.ml.utils import is_divisor
10
+
11
+ __all__ = [
12
+ 'SaveBestModel',
13
+ 'LogLosses',
14
+ ]
15
+
16
+
17
class SaveBestModel(TrainerCallback):
    """Callback for periodically saving the best model weights

    On iterations where ``is_divisor(batch_num, freq)`` holds, the mean of the last ``average``
    total losses is compared with the best value seen so far; if it is lower, a checkpoint is
    written to ``path``.

    Args:
        path (str): path for saving the model weights
        freq (int, optional): frequency in iterations at which the current average loss is evaluated. Defaults to 50.
        average (int, optional): number of recent iterations over which the average loss is computed. Defaults to 10.
    """

    def __init__(self, path: str, freq: int = 50, average: int = 10):
        self.path = path
        self.average = average
        self._best_loss = np.inf
        self.freq = freq

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """checks if the current average loss has improved from the previous save, if true the model is saved

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        if not is_divisor(batch_num, self.freq):
            return

        # take the key from the trainer so it cannot drift from Trainer.TOTAL_LOSS_KEY;
        # .get avoids inserting an empty entry into the trainer's defaultdict
        loss_key = getattr(trainer, 'TOTAL_LOSS_KEY', 'total_loss')
        recent_losses = trainer.losses.get(loss_key, [])[-self.average:]
        if not recent_losses:
            # nothing recorded yet; np.mean([]) would only produce nan and a RuntimeWarning
            return

        loss = np.mean(recent_losses)

        if loss < self._best_loss:
            self._best_loss = loss
            self.save(trainer, batch_num)

    def save(self, trainer: Trainer, batch_num: int):
        """saves a dictionary with the network weights, the learning rates, the losses, the iteration of
        the previous save, the current iteration and the current best loss to ``self.path``

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        # remember the iteration of this save; the previous one (0 if none) goes into the checkpoint
        prev_save = trainer.callback_params.pop('saved_iteration', 0)
        trainer.callback_params['saved_iteration'] = batch_num
        save_dict = {
            'model': trainer.model.state_dict(),
            'lrs': trainer.lrs,
            'losses': trainer.losses,
            'prev_save': prev_save,
            'batch_num': batch_num,
            'best_loss': self._best_loss
        }
        torch.save(save_dict, self.path)
66
+
67
+
68
class LogLosses(TrainerCallback):
    """Callback for logging the training losses"""
    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """log the most recent value of every recorded loss under ``train/<loss name>``

        Losses without any recorded value are skipped.

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch
        """
        for loss_name, loss_values in trainer.losses.items():
            if not loss_values:
                continue
            # kept outside any try block so errors raised by the loggers are not silently swallowed
            trainer.log(f'train/{loss_name}', loss_values[-1])
@@ -0,0 +1,27 @@
1
+ from torch import Tensor
2
+
3
+
4
+ from reflectorch.data_generation import BasicDataset
5
+ from reflectorch.data_generation.reflectivity import kinematical_approximation
6
+ from reflectorch.data_generation.priors import BasicParams
7
+ from reflectorch.ml.basic_trainer import DataLoader
8
+
9
+
10
+ __all__ = [
11
+ "ReflectivityDataLoader",
12
+ "MultilayerDataLoader",
13
+ ]
14
+
15
+
16
class ReflectivityDataLoader(BasicDataset, DataLoader):
    """Dataloader for reflectivity data.

    Combines ``BasicDataset`` (the basic dataset class for reflectivity) with ``DataLoader``
    (a ``TrainerCallback``), so one object acts both as the dataset and as a trainer callback.
    """
19
+
20
+
21
class MultilayerDataLoader(ReflectivityDataLoader):
    """Dataloader whose reflectivity curves are simulated with the kinematical approximation."""

    def _sample_from_prior(self, batch_size: int):
        """Draw ``batch_size`` parameter sets through the prior sampler's ``optimized_sample``."""
        sampler = self.prior_sampler
        return sampler.optimized_sample(batch_size)

    def _calc_curves(self, q_values: Tensor, params: BasicParams):
        """Simulate curves at ``q_values`` from the thicknesses, roughnesses and SLDs in ``params``."""
        thicknesses = params.thicknesses
        roughnesses = params.roughnesses
        slds = params.slds
        return kinematical_approximation(q_values, thicknesses, roughnesses, slds)
@@ -0,0 +1,56 @@
1
+ from torch.utils.tensorboard import SummaryWriter
2
+
3
+ __all__ = [
4
+ 'Logger',
5
+ 'Loggers',
6
+ 'PrintLogger',
7
+ 'TensorBoardLogger',
8
+ ]
9
+
10
+
11
class Logger(object):
    """Base class defining a common interface for logging.

    Subclasses override ``log``; the base implementation discards the data.
    """

    def log(self, name: str, data):
        """Record ``data`` under ``name``. No-op in the base class."""
        return None

    def __setitem__(self, key, value):
        """Allow ``logger[key] = value`` as shorthand for ``logger.log(key, value)``."""
        self.log(key, value)
19
+
20
+
21
class Loggers(Logger):
    """Fan-out logger that forwards every ``log`` call to each wrapped logger, in order."""

    def __init__(self, *loggers):
        self._loggers = tuple(loggers)

    def log(self, name: str, data):
        """Forward ``(name, data)`` to every wrapped logger."""
        for sink in self._loggers:
            sink.log(name, data)
29
+
30
+
31
class PrintLogger(Logger):
    """Logger that writes each entry to standard output."""

    def log(self, name: str, data):
        """Print ``name``, ``': '`` and ``data``, separated by single spaces."""
        print(name, ': ', data, sep=' ')
35
+
36
class TensorBoardLogger(Logger):
    """Logger writing scalar values to TensorBoard.

    Every tag keeps its own step counter: the n-th value logged under a tag is written at
    global step n (starting at 1). This keeps each curve's x-axis independent of how many other
    tags are logged in between.
    """

    def __init__(self, log_dir: str):
        """
        Args:
            log_dir (str): Directory where TensorBoard logs will be written
        """
        super().__init__()
        self.writer = SummaryWriter(log_dir=log_dir)
        # running count of all log calls (starts at 1); kept as a public attribute for compatibility
        self.step = 1
        # per-tag count of logged values, used as the TensorBoard global step
        self._tag_steps = {}

    def log(self, name: str, data):
        """Log scalar data to TensorBoard

        Args:
            name (str): Name/tag for the data
            data: Scalar value to log; objects with ``.item()`` (e.g. 0-dim tensors) are converted first
        """
        if hasattr(data, 'item'):
            data = data.item()
        tag_step = self._tag_steps.get(name, 0) + 1
        self._tag_steps[name] = tag_step
        self.writer.add_scalar(name, data, tag_step)
        self.step += 1

    def close(self):
        """Flush pending events to disk and close the underlying ``SummaryWriter``."""
        self.writer.close()