reflectorch-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of reflectorch might be problematic.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0

reflectorch/ml/callbacks.py
@@ -0,0 +1,86 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ import numpy as np
+
+ from reflectorch.ml.basic_trainer import (
+     TrainerCallback,
+     Trainer,
+ )
+ from reflectorch.ml.utils import is_divisor
+
+ __all__ = [
+     'SaveBestModel',
+     'LogLosses',
+ ]
+
+
+ class SaveBestModel(TrainerCallback):
+     """Callback for periodically saving the best model weights
+
+     Args:
+         path (str): path for saving the model weights
+         freq (int, optional): frequency in iterations at which the current average loss is evaluated. Defaults to 50.
+         average (int, optional): number of recent iterations over which the average loss is computed. Defaults to 10.
+     """
+
+     def __init__(self, path: str, freq: int = 50, average: int = 10):
+         self.path = path
+         self.average = average
+         self._best_loss = np.inf
+         self.freq = freq
+
+     def end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         """Checks whether the current average loss has improved since the previous save; if so, the model is saved
+
+         Args:
+             trainer (Trainer): the trainer object
+             batch_num (int): the current iteration / batch
+         """
+         if is_divisor(batch_num, self.freq):
+
+             loss = np.mean(trainer.losses['total_loss'][-self.average:])
+
+             if loss < self._best_loss:
+                 self._best_loss = loss
+                 self.save(trainer, batch_num)
+
+     def save(self, trainer: Trainer, batch_num: int):
+         """Saves a dictionary containing the network weights, the learning rates, the losses and the current \
+         best loss with its corresponding iteration to the disk
+
+         Args:
+             trainer (Trainer): the trainer object
+             batch_num (int): the current iteration / batch
+         """
+         prev_save = trainer.callback_params.pop('saved_iteration', 0)
+         trainer.callback_params['saved_iteration'] = batch_num
+         save_dict = {
+             'model': trainer.model.state_dict(),
+             'lrs': trainer.lrs,
+             'losses': trainer.losses,
+             'prev_save': prev_save,
+             'batch_num': batch_num,
+             'best_loss': self._best_loss
+         }
+         torch.save(save_dict, self.path)
+
+
+ class LogLosses(TrainerCallback):
+     """Callback for logging the training losses"""
+     def end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         """Logs the loss at the current iteration
+
+         Args:
+             trainer (Trainer): the trainer object
+             batch_num (int): the index of the current iteration / batch
+         """
+         try:
+             trainer.log('train/total_loss', trainer.losses[trainer.TOTAL_LOSS_KEY][-1])
+         except IndexError:
+             pass
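
For orientation, a minimal sketch of how these two callbacks could be driven. The `_ToyTrainer` stand-in below is an assumption for illustration only, since `basic_trainer.Trainer` (which provides `losses`, `lrs`, `callback_params`, `log`, and `TOTAL_LOSS_KEY`) is not part of this diff.

import torch
from torch import nn

from reflectorch.ml.callbacks import SaveBestModel, LogLosses

class _ToyTrainer:
    # hypothetical stand-in exposing only the attributes the callbacks touch
    TOTAL_LOSS_KEY = 'total_loss'

    def __init__(self, model):
        self.model = model
        self.losses = {'total_loss': []}
        self.lrs = []
        self.callback_params = {}

    def log(self, name, value):
        print(f'{name}: {value:.4f}')

trainer = _ToyTrainer(nn.Linear(4, 1))
callbacks = [SaveBestModel('best_model.pt', freq=10, average=5), LogLosses()]

for batch_num in range(1, 101):
    trainer.losses['total_loss'].append(torch.rand(1).item())  # stand-in for a real training step
    for cb in callbacks:
        cb.end_batch(trainer, batch_num)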

reflectorch/ml/dataloaders.py
@@ -0,0 +1,27 @@
+ from torch import Tensor
+
+
+ from reflectorch.data_generation import BasicDataset
+ from reflectorch.data_generation.reflectivity import kinematical_approximation
+ from reflectorch.data_generation.priors import BasicParams
+ from reflectorch.ml.basic_trainer import DataLoader
+
+
+ __all__ = [
+     "ReflectivityDataLoader",
+     "MultilayerDataLoader",
+ ]
+
+
+ class ReflectivityDataLoader(BasicDataset, DataLoader):
+     """Dataloader for reflectivity data, combining functionality from the ``BasicDataset`` (basic dataset class for reflectivity) and the ``DataLoader`` (which inherits from ``TrainerCallback``) classes"""
+     pass
+
+
+ class MultilayerDataLoader(ReflectivityDataLoader):
+     """Dataloader for reflectivity curves simulated using the kinematical approximation"""
+     def _sample_from_prior(self, batch_size: int):
+         return self.prior_sampler.optimized_sample(batch_size)
+
+     def _calc_curves(self, q_values: Tensor, params: BasicParams):
+         return kinematical_approximation(q_values, params.thicknesses, params.roughnesses, params.slds)
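
Following the same pattern, a hedged sketch of a custom loader. The `BackgroundMultilayerDataLoader` below is hypothetical (not part of the package) and simply adds a constant background to the kinematical curves; it relies only on the `_sample_from_prior` / `_calc_curves` hooks used by `MultilayerDataLoader` above.

from torch import Tensor

from reflectorch.data_generation.priors import BasicParams
from reflectorch.data_generation.reflectivity import kinematical_approximation
from reflectorch.ml.dataloaders import ReflectivityDataLoader

class BackgroundMultilayerDataLoader(ReflectivityDataLoader):
    """Hypothetical loader: kinematical curves with a constant background added."""
    background: float = 1e-6

    def _sample_from_prior(self, batch_size: int):
        return self.prior_sampler.optimized_sample(batch_size)

    def _calc_curves(self, q_values: Tensor, params: BasicParams):
        curves = kinematical_approximation(q_values, params.thicknesses, params.roughnesses, params.slds)
        return curves + self.background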

reflectorch/ml/loggers.py
@@ -0,0 +1,38 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ __all__ = [
+     'Logger',
+     'Loggers',
+     'PrintLogger',
+ ]
+
+
+ class Logger(object):
+     """Base class defining a common interface for logging"""
+     def log(self, name: str, data):
+         pass
+
+     def __setitem__(self, key, value):
+         """Enable dictionary-style setting to log data."""
+         self.log(key, value)
+
+
+ class Loggers(Logger):
+     """Class for using multiple loggers"""
+     def __init__(self, *loggers):
+         self._loggers = tuple(loggers)
+
+     def log(self, name: str, data):
+         for logger in self._loggers:
+             logger.log(name, data)
+
+
+ class PrintLogger(Logger):
+     """Logger which prints to the console"""
+     def log(self, name: str, data):
+         print(name, ': ', data)
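
A short usage sketch (not part of the package) showing how these classes compose: `Loggers` fans each call out to every wrapped logger, and the inherited `__setitem__` allows dictionary-style logging.

from reflectorch.ml.loggers import Loggers, PrintLogger

logger = Loggers(PrintLogger())        # any number of loggers can be combined
logger.log('train/total_loss', 0.123)  # explicit call
logger['train/lr'] = 1e-4              # dictionary-style logging via __setitem__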

reflectorch/ml/schedulers.py
@@ -0,0 +1,246 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from torch.optim import lr_scheduler
+
+ import numpy as np
+
+ from reflectorch.ml.basic_trainer import Trainer, TrainerCallback, PeriodicTrainerCallback
+
+ __all__ = [
+     'ScheduleBatchSize',
+     'ScheduleLR',
+     'StepLR',
+     'CyclicLR',
+     'LogCyclicLR',
+     'ReduceLROnPlateau',
+     'OneCycleLR',
+ ]
+
+
+ class ScheduleBatchSize(PeriodicTrainerCallback):
+     """Batch size scheduler
+
+     Args:
+         step (int): number of iterations after which the batch size is modified.
+         gamma (int, optional): quantity which is added to or multiplied with the current batch size. Defaults to 2.
+         last_epoch (int, optional): the last training iteration for which the batch size is modified. Defaults to -1.
+         mode (str, optional): ``'add'`` for addition or ``'multiply'`` for multiplication. Defaults to 'add'.
+     """
+     def __init__(self, step: int, gamma: int = 2, last_epoch: int = -1, mode: str = 'add'):
+         super().__init__(step, last_epoch)
+
+         assert mode in ('add', 'multiply')
+
+         self.gamma = gamma
+         self.mode = mode
+
+     def _end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         if self.mode == 'add':
+             trainer.batch_size += self.gamma
+         elif self.mode == 'multiply':
+             trainer.batch_size *= self.gamma
+
+
+ class ScheduleLR(TrainerCallback):
+     """Base class for learning rate schedulers
+
+     Args:
+         lr_scheduler_cls: class of the learning rate scheduler
+     """
+
+     def __init__(self, lr_scheduler_cls, **kwargs):
+         self.lr_scheduler_cls = lr_scheduler_cls
+         self.kwargs = kwargs
+         self.lr_scheduler = None
+
+     def start_training(self, trainer: Trainer) -> None:
+         """initializes a learning rate scheduler based on its class and keyword arguments at the start of training
+
+         Args:
+             trainer (Trainer): the trainer object
+         """
+         self.lr_scheduler = self.lr_scheduler_cls(trainer.optim, **self.kwargs)
+         trainer.callback_params['lrs'] = []
+
+     def end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         """modifies the learning rate at the end of each iteration
+
+         Args:
+             trainer (Trainer): the trainer object
+             batch_num (int): index of the current iteration
+         """
+         self.lr_scheduler.step()
+
+
+ class StepLR(ScheduleLR):
+     """Learning rate scheduler which decays the learning rate of each parameter group by gamma every ``step_size`` iterations.
+
+     Args:
+         step_size (int): period of the learning rate decay
+         gamma (float): multiplicative factor of the learning rate decay
+         last_epoch (int, optional): the index of the last iteration. Defaults to -1.
+     """
+     def __init__(self, step_size: int, gamma: float, last_epoch: int = -1, **kwargs):
+         super().__init__(lr_scheduler.StepLR, step_size=step_size, gamma=gamma, last_epoch=last_epoch, **kwargs)
+
+     def start_training(self, trainer: Trainer) -> None:
+         trainer.optim.param_groups[0]['initial_lr'] = trainer.lr()
+         super().start_training(trainer)
+
+
+ class CyclicLR(ScheduleLR):
+     """Cyclic learning rate scheduler
+
+     Args:
+         base_lr (float): initial learning rate, which is the lower boundary in the cycle
+         max_lr (float): upper learning rate boundary in the cycle
+         step_size_up (int, optional): number of training iterations in the increasing half of a cycle. Defaults to 2000.
+         cycle_momentum (bool, optional): if True, momentum is cycled inversely to the learning rate between ``base_momentum`` and ``max_momentum``. Defaults to False.
+         gamma (float, optional): constant in the ``'exp_range'`` mode scaling function: gamma^(cycle iterations). Defaults to 1.
+         mode (str, optional): one of ``'triangular'`` (a basic triangular cycle without amplitude scaling),
+             ``'triangular2'`` (a basic triangular cycle that scales the initial amplitude by half each cycle), or ``'exp_range'``
+             (a cycle that scales the initial amplitude by gamma^iterations at each cycle iteration). Defaults to 'triangular'.
+     """
+     def __init__(self, base_lr, max_lr, step_size_up: int = 2000,
+                  cycle_momentum: bool = False, gamma: float = 1., mode: str = 'triangular',
+                  **kwargs):
+         super().__init__(
+             lr_scheduler.CyclicLR,
+             base_lr=base_lr,
+             max_lr=max_lr,
+             step_size_up=step_size_up,
+             cycle_momentum=cycle_momentum,
+             gamma=gamma,
+             mode=mode,
+             **kwargs
+         )
+
+
+ class LogCyclicLR(TrainerCallback):
+     """Cyclic learning rate scheduler on a logarithmic scale
+
+     Args:
+         base_lr (float): lower learning rate boundary in the cycle
+         max_lr (float): upper learning rate boundary in the cycle
+         period (int, optional): number of training iterations in one cycle. Defaults to 2000.
+         gamma (float, optional): constant for scaling the amplitude as ``gamma`` ^ ``iterations``. Defaults to None (no amplitude decay).
+         start_period (int, optional): number of initial cycles before the amplitude decay is applied. Defaults to 25.
+         log (bool, optional): if ``True``, the cycle is in the logarithmic domain. Defaults to True.
+         param_groups (tuple, optional): parameter groups of the optimizer. Defaults to (0,).
+     """
+     def __init__(self,
+                  base_lr,
+                  max_lr,
+                  period: int = 2000,
+                  gamma: float = None,
+                  log: bool = True,
+                  param_groups: tuple = (0,),
+                  start_period: int = 25,
+                  ):
+         self.base_lr = base_lr
+         self.max_lr = max_lr
+         self.period = period
+         self.gamma = gamma
+         self.param_groups = param_groups
+         self.log = log
+         self.start_period = start_period
+         self._axis = None
+         self._period = None
+
+     def get_lr(self, batch_num: int):
+         return self._get_lr(batch_num)
+
+     def _get_lr(self, batch_num):
+         num_period, t = batch_num // self.period, batch_num % self.period
+
+         if self._period != num_period:
+             self._period = num_period
+             if self.gamma and (num_period >= self.start_period):
+                 amp = (self.max_lr - self.base_lr) * (self.gamma ** (num_period - self.start_period))
+                 max_lr = self.base_lr + amp
+             else:
+                 max_lr = self.max_lr
+
+             if self.log:
+                 self._axis = np.logspace(np.log10(self.base_lr), np.log10(max_lr), self.period // 2)
+             else:
+                 self._axis = np.linspace(self.base_lr, max_lr, self.period // 2)
+         if t < self.period // 2:
+             lr = self._axis[t]
+         else:
+             lr = self._axis[self.period - t - 1]
+         return lr
+
+     def end_batch(self, trainer: Trainer, batch_num: int):
+         lr = self.get_lr(batch_num)
+         for param_group in self.param_groups:
+             trainer.set_lr(lr, param_group)
+
+
+ class ReduceLROnPlateau(TrainerCallback):
+     """Learning rate scheduler which reduces the learning rate when the loss stops decreasing
+
+     Args:
+         gamma (float, optional): multiplicative factor of the learning rate decay. Defaults to 0.5.
+         patience (int, optional): number of allowed iterations with no improvement after which the learning rate is reduced. Defaults to 500.
+         average (int, optional): size of the window over which the average loss is computed. Defaults to 50.
+         loss_key (str, optional): key of the monitored loss. Defaults to 'total_loss'.
+         param_groups (tuple, optional): parameter groups of the optimizer. Defaults to (0,).
+     """
+     def __init__(
+             self,
+             gamma: float = 0.5,
+             patience: int = 500,
+             average: int = 50,
+             loss_key: str = 'total_loss',
+             param_groups: tuple = (0,),
+     ):
+         self.patience = patience
+         self.average = average
+         self.gamma = gamma
+         self.loss_key = loss_key
+         self.param_groups = param_groups
+
+     def end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         loss = trainer.losses[self.loss_key]
+
+         if len(loss) < self.patience:
+             return
+
+         if np.mean(loss[-self.patience:-(self.patience - self.average)]) <= np.mean(loss[-self.average:]):
+             for param_group in self.param_groups:
+                 trainer.set_lr(trainer.lr(param_group) * self.gamma, param_group)
+
+
+ class OneCycleLR(ScheduleLR):
+     """One-cycle learning rate scheduler (https://arxiv.org/abs/1708.07120)
+
+     Args:
+         max_lr (float): upper learning rate boundary in the cycle
+         total_steps (int): the total number of steps in the cycle
+         pct_start (float, optional): the fraction of the cycle (in number of steps) spent increasing the learning rate. Defaults to 0.3.
+         div_factor (float, optional): determines the initial learning rate via initial_lr = ``max_lr`` / ``div_factor``. Defaults to 25.
+         final_div_factor (float, optional): determines the minimum learning rate via min_lr = ``initial_lr`` / ``final_div_factor``. Defaults to 1e4.
+         three_phase (bool, optional): if ``True``, use a third phase of the schedule to annihilate the learning rate according to ``final_div_factor`` instead of modifying the second phase. Defaults to True.
+     """
+     def __init__(self, max_lr: float, total_steps: int, pct_start: float = 0.3, div_factor: float = 25.,
+                  final_div_factor: float = 1e4, three_phase: bool = True, **kwargs):
+         super().__init__(
+             lr_scheduler.OneCycleLR,
+             max_lr=max_lr,
+             total_steps=total_steps,
+             pct_start=pct_start,
+             div_factor=div_factor,
+             final_div_factor=final_div_factor,
+             three_phase=three_phase,
+             **kwargs
+         )
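
A small sketch (illustration only, not part of the package) of the schedule produced by `LogCyclicLR.get_lr`: within each period the learning rate sweeps up and back down through points spaced logarithmically between `base_lr` and `max_lr`, and once `gamma` is set and `start_period` cycles have passed, the amplitude decays geometrically. `get_lr` needs no trainer, so the shape of the schedule can be inspected directly.

import numpy as np

from reflectorch.ml.schedulers import LogCyclicLR

sched = LogCyclicLR(base_lr=1e-5, max_lr=1e-2, period=100, gamma=0.9, start_period=2)

lrs = np.array([sched.get_lr(i) for i in range(500)])
print(lrs.min(), lrs.max())     # stays within [base_lr, max_lr]
print(lrs[:100].argmax())       # peak near the middle of the first period (index ~49)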

reflectorch/ml/trainers.py
@@ -0,0 +1,126 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from reflectorch.data_generation import BATCH_DATA_TYPE
+ from reflectorch.ml.basic_trainer import Trainer
+ from reflectorch.ml.dataloaders import ReflectivityDataLoader
+
+ __all__ = [
+     'RealTimeSimTrainer',
+     'DenoisingAETrainer',
+     'VAETrainer',
+     'PointEstimatorTrainer',
+ ]
+
+
+ class RealTimeSimTrainer(Trainer):
+     """Trainer with functionality to customize the sampled batch of data"""
+     loader: ReflectivityDataLoader
+
+     def get_batch_by_idx(self, batch_num: int):
+         """Gets a batch of data with the default batch size"""
+         batch_data = self.loader.get_batch(self.batch_size)
+         return self._get_batch(batch_data)
+
+     def get_batch_by_size(self, batch_size: int):
+         """Gets a batch of data with a custom batch size"""
+         batch_data = self.loader.get_batch(batch_size)
+         return self._get_batch(batch_data)
+
+     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+         """Modifies the batch of data sampled from the data loader"""
+         raise NotImplementedError
+
+
+ class PointEstimatorTrainer(RealTimeSimTrainer):
+     """Trainer for the regression inverse problem with incorporation of prior bounds"""
+     add_sigmas_to_context: bool = False
+
+     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+         scaled_params = batch_data['scaled_params'].to(torch.float32)
+         scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
+         if self.train_with_q_input:
+             q_values = batch_data['q_values'].to(torch.float32)
+             scaled_q_values = self.loader.q_generator.scale_q(q_values)
+         else:
+             scaled_q_values = None
+
+         num_params = scaled_params.shape[-1] // 3
+         assert num_params * 3 == scaled_params.shape[-1]
+         scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
+
+         return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
+
+     def get_loss_dict(self, batch_data):
+         """computes the loss dictionary"""
+
+         scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
+
+         if self.train_with_q_input:
+             predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
+         else:
+             predicted_params = self.model(scaled_curves, scaled_bounds)
+
+         loss = self.mse(predicted_params, scaled_params)
+         return {'loss': loss}
+
+     def init(self):
+         self.mse = nn.MSELoss()
+
+
+ class DenoisingAETrainer(RealTimeSimTrainer):
+     """Trainer for a denoising autoencoder model. Overrides the _get_batch and get_loss_dict methods"""
+     def init(self):
+         self.criterion = nn.MSELoss()
+         self.loader.calc_denoised_curves = True
+
+     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+         """returns scaled curves with and without noise"""
+         scaled_noisy_curves, curves = batch_data['scaled_noisy_curves'], batch_data['curves']
+         scaled_curves = self.loader.curves_scaler.scale(curves)
+
+         scaled_noisy_curves, scaled_curves = scaled_noisy_curves.to(torch.float32), scaled_curves.to(torch.float32)
+
+         return scaled_noisy_curves, scaled_curves
+
+     def get_loss_dict(self, batch_data):
+         """returns the reconstruction loss of the autoencoder"""
+         scaled_noisy_curves, scaled_curves = batch_data
+         restored_curves = self.model(scaled_noisy_curves)
+         loss = self.criterion(scaled_curves, restored_curves)
+         return {'loss': loss}
+
+
+ class VAETrainer(DenoisingAETrainer):
+     """Trainer for a variational autoencoder (VAE) model. Overrides the init and get_loss_dict methods"""
+     def init(self):
+         self.loader.calc_denoised_curves = True
+         self.freebits = 0.05
+
+     def calc_kl(self, z_mu, z_logvar):
+         return 0.5 * (z_mu ** 2 + torch.exp(z_logvar) - 1 - z_logvar)
+
+     def gaussian_log_prob(self, z, mu, logvar):
+         return -0.5 * (np.log(2 * np.pi) + logvar + (z - mu) ** 2 / torch.exp(logvar))
+
+     def get_loss_dict(self, batch_data):
+         """returns the combined reconstruction and free-bits KL loss of the variational autoencoder"""
+         scaled_noisy_curves, scaled_curves = batch_data
+         _, (z_mu, z_logvar, restored_curves_mu, restored_curves_logvar) = self.model(scaled_noisy_curves)
+
+         l_rec = -torch.mean(self.gaussian_log_prob(scaled_curves, restored_curves_mu, restored_curves_logvar), dim=-1)
+         l_kl = torch.mean(F.relu(self.calc_kl(z_mu, z_logvar) - self.freebits * np.log(2)) + self.freebits * np.log(2), dim=-1)
+         loss = torch.mean(l_rec + l_kl) / np.log(2)
+
+         l_rec = torch.mean(l_rec)
+         l_kl = torch.mean(l_kl)
+
+         return {'loss': loss}
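
As a sanity check on the arithmetic above (illustration only), `calc_kl` is the closed-form KL divergence between the diagonal Gaussian posterior N(mu, exp(logvar)) and a standard normal prior, which can be compared against `torch.distributions`:

import torch
from torch.distributions import Normal, kl_divergence

z_mu, z_logvar = torch.randn(8), torch.randn(8)

# closed form used by VAETrainer.calc_kl
manual = 0.5 * (z_mu ** 2 + torch.exp(z_logvar) - 1 - z_logvar)

reference = kl_divergence(Normal(z_mu, torch.exp(0.5 * z_logvar)), Normal(0.0, 1.0))
assert torch.allclose(manual, reference, atol=1e-6)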

reflectorch/ml/utils.py
@@ -0,0 +1,9 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ def is_divisor(num: int, div: int):
+     return num and not num % div

reflectorch/models/__init__.py
@@ -0,0 +1,22 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from reflectorch.models.encoders import *
+ from reflectorch.models.networks import *
+
+ __all__ = [
+     "ConvEncoder",
+     "ConvDecoder",
+     "ConvAutoencoder",
+     "ConvVAE",
+     "TransformerEncoder",
+     "FnoEncoder",
+     "SpectralConv1d",
+     "ConvResidualNet1D",
+     "ResidualMLP",
+     "NetworkWithPriorsConvEmb",
+     "NetworkWithPriorsFnoEmb",
+ ]

reflectorch/models/activations.py
@@ -0,0 +1,50 @@
+ import torch
+ from torch import nn
+ from torch.nn.functional import relu
+
+ class Rowdy(nn.Module):
+     """Adaptive (Rowdy) activation function"""
+     def __init__(self, K=9):
+         super().__init__()
+         self.K = K
+         self.alpha = nn.Parameter(torch.cat((torch.ones(1), torch.zeros(K - 1))))
+         self.alpha.requires_grad = True
+         self.omega = nn.Parameter(torch.ones(K))
+         self.omega.requires_grad = True
+
+     def forward(self, x):
+         rowdy = self.alpha[0] * relu(self.omega[0] * x)
+         for k in range(1, self.K):
+             rowdy += self.alpha[k] * torch.sin(self.omega[k] * k * x)
+         return rowdy
+
+
+ ACTIVATIONS = {
+     'relu': nn.ReLU,
+     'lrelu': nn.LeakyReLU,
+     'gelu': nn.GELU,
+     'selu': nn.SELU,
+     'elu': nn.ELU,
+     'sigmoid': nn.Sigmoid,
+     'tanh': nn.Tanh,
+     'silu': nn.SiLU,
+     'mish': nn.Mish,
+     'rowdy': Rowdy,
+ }
+
+
+ def activation_by_name(name):
+     """Returns the activation function class corresponding to its name
+
+     Args:
+         name (str): string denoting the activation function ('relu', 'lrelu', 'gelu', 'selu', 'elu', 'sigmoid', 'tanh', 'silu', 'mish', 'rowdy')
+
+     Returns:
+         nn.Module: PyTorch activation function module class
+     """
+     if not isinstance(name, str):
+         return name
+     try:
+         return ACTIVATIONS[name.lower()]
+     except KeyError:
+         raise KeyError(f'Unknown activation function {name}')
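
A brief usage sketch (not part of the package): `activation_by_name` returns the activation class itself, so it is instantiated when the layer is built.

from torch import nn

from reflectorch.models.activations import activation_by_name

act_cls = activation_by_name('gelu')        # -> nn.GELU (the class, not an instance)
mlp = nn.Sequential(nn.Linear(64, 64), act_cls())

rowdy = activation_by_name('rowdy')(K=5)    # adaptive Rowdy activation with 5 terms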

reflectorch/models/encoders/__init__.py
@@ -0,0 +1,27 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from reflectorch.models.encoders.conv_encoder import (
+     ConvEncoder,
+     ConvDecoder,
+     ConvAutoencoder,
+     ConvVAE,
+ )
+ from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
+ from reflectorch.models.encoders.transformers import TransformerEncoder
+ from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D
+
+
+ __all__ = [
+     "TransformerEncoder",
+     "ConvEncoder",
+     "ConvDecoder",
+     "ConvAutoencoder",
+     "ConvVAE",
+     "ConvResidualNet1D",
+     "FnoEncoder",
+     "SpectralConv1d",
+ ]
+ ]