reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
reflectorch/ml/schedulers.py
@@ -0,0 +1,356 @@
+ import math
+ from torch.optim import lr_scheduler
+
+ import numpy as np
+
+ from reflectorch.ml.basic_trainer import Trainer, TrainerCallback, PeriodicTrainerCallback
+
+ __all__ = [
+     'ScheduleBatchSize',
+     'ScheduleLR',
+     'StepLR',
+     'CyclicLR',
+     'LogCyclicLR',
+     'ReduceLROnPlateau',
+     'OneCycleLR',
+     'CosineAnnealingWithWarmup',
+ ]
+
+
+ class ScheduleBatchSize(PeriodicTrainerCallback):
+     """Batch size scheduler
+
+     Args:
+         step (int): number of iterations after which the batch size is modified.
+         gamma (int, optional): quantity which is added to or multiplied with the current batch size. Defaults to 2.
+         last_epoch (int, optional): the last training iteration for which the batch size is modified. Defaults to -1.
+         mode (str, optional): ``'add'`` for addition or ``'multiply'`` for multiplication. Defaults to 'add'.
+     """
+     def __init__(self, step: int, gamma: int = 2, last_epoch: int = -1, mode: str = 'add'):
+         super().__init__(step, last_epoch)
+
+         assert mode in ('add', 'multiply')
+
+         self.gamma = gamma
+         self.mode = mode
+
+     def _end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         if self.mode == 'add':
+             trainer.batch_size += self.gamma
+         elif self.mode == 'multiply':
+             trainer.batch_size *= self.gamma
+
+
+ class ScheduleLR(TrainerCallback):
+     """Base class for learning rate schedulers
+
+     Args:
+         lr_scheduler_cls: class of the learning rate scheduler
+     """
+
+     def __init__(self, lr_scheduler_cls, **kwargs):
+         self.lr_scheduler_cls = lr_scheduler_cls
+         self.kwargs = kwargs
+         self.lr_scheduler = None
+
+     def start_training(self, trainer: Trainer) -> None:
+         """Initializes the learning rate scheduler from its class and keyword arguments at the start of training
+
+         Args:
+             trainer (Trainer): the trainer object
+         """
+         self.lr_scheduler = self.lr_scheduler_cls(trainer.optim, **self.kwargs)
+         trainer.callback_params['lrs'] = []
+
+     def end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         """Modifies the learning rate at the end of each iteration
+
+         Args:
+             trainer (Trainer): the trainer object
+             batch_num (int): index of the current iteration
+         """
+         self.lr_scheduler.step()
+
+     def simulate_and_plot(self, total_steps: int, initial_lr: float, log_scale: bool = False):
+         import torch
+         import matplotlib.pyplot as plt
+
+         dummy_optim = torch.optim.Adam([torch.zeros(1)], lr=initial_lr)
+         scheduler = self.lr_scheduler_cls(dummy_optim, **self.kwargs)
+
+         lrs = []
+         for step in range(total_steps):
+             lrs.append(dummy_optim.param_groups[0]['lr'])
+             scheduler.step()
+
+         plt.figure(figsize=(10, 6))
+         plt.plot(lrs, label='Learning Rate')
+         plt.xlabel('Steps')
+         plt.ylabel('Learning Rate')
+         plt.title('Learning Rate Schedule')
+
+         if log_scale:
+             plt.yscale('log')
+
+         plt.grid(True, which="both", linestyle='--', linewidth=0.5)
+         plt.legend()
+         plt.show()
+
+
+ class StepLR(ScheduleLR):
+     """Learning rate scheduler which decays the learning rate of each parameter group by gamma every ``step_size`` iterations.
+
+     Args:
+         step_size (int): Period of learning rate decay
+         gamma (float): Multiplicative factor of learning rate decay
+         last_epoch (int, optional): The index of the last iteration. Defaults to -1.
+     """
+     def __init__(self, step_size: int, gamma: float, last_epoch: int = -1, **kwargs):
+         super().__init__(lr_scheduler.StepLR, step_size=step_size, gamma=gamma, last_epoch=last_epoch, **kwargs)
+
+     def start_training(self, trainer: Trainer) -> None:
+         trainer.optim.param_groups[0]['initial_lr'] = trainer.lr()
+         super().start_training(trainer)
+
+
+ class CyclicLR(ScheduleLR):
+     """Cyclic learning rate scheduler
+
+     Args:
+         base_lr (float): Initial learning rate which is the lower boundary in the cycle
+         max_lr (float): Upper learning rate boundary in the cycle
+         step_size_up (int, optional): Number of training iterations in the increasing half of a cycle. Defaults to 2000.
+         cycle_momentum (bool, optional): If True, momentum is cycled inversely to the learning rate between ``base_momentum`` and ``max_momentum``. Defaults to False.
+         gamma (float, optional): Constant in the ``'exp_range'`` mode scaling function: gamma^(cycle iterations). Defaults to 1.
+         mode (str, optional): One of ``'triangular'`` (a basic triangular cycle without amplitude scaling),
+             ``'triangular2'`` (a basic triangular cycle that scales the initial amplitude by half each cycle) or ``'exp_range'``
+             (a cycle that scales the initial amplitude by gamma^iterations at each cycle iteration). Defaults to 'triangular'.
+     """
+     def __init__(self, base_lr, max_lr, step_size_up: int = 2000,
+                  cycle_momentum: bool = False, gamma: float = 1., mode: str = 'triangular',
+                  **kwargs):
+         super().__init__(
+             lr_scheduler.CyclicLR,
+             base_lr=base_lr,
+             max_lr=max_lr,
+             step_size_up=step_size_up,
+             cycle_momentum=cycle_momentum,
+             gamma=gamma,
+             mode=mode,
+             **kwargs
+         )
+
+
+ class LogCyclicLR(TrainerCallback):
+     """Cyclic learning rate scheduler on a logarithmic scale
+
+     Args:
+         base_lr (float): Lower learning rate boundary in the cycle
+         max_lr (float): Upper learning rate boundary in the cycle
+         period (int, optional): Number of training iterations in the cycle. Defaults to 2000.
+         gamma (float, optional): Constant for scaling the amplitude as ``gamma`` ^ ``iterations``. Defaults to None (no scaling).
+         start_period (int, optional): Number of starting cycles before the ``gamma`` amplitude scaling is applied. Defaults to 25.
+         log (bool, optional): If ``True``, the cycle is in the logarithmic domain. Defaults to True.
+         param_groups (tuple, optional): Parameter groups of the optimizer. Defaults to (0,).
+     """
+     def __init__(self,
+                  base_lr,
+                  max_lr,
+                  period: int = 2000,
+                  gamma: float = None,
+                  log: bool = True,
+                  param_groups: tuple = (0,),
+                  start_period: int = 25,
+                  ):
+         self.base_lr = base_lr
+         self.max_lr = max_lr
+         self.period = period
+         self.gamma = gamma
+         self.param_groups = param_groups
+         self.log = log
+         self.start_period = start_period
+         self._axis = None
+         self._period = None
+
+     def get_lr(self, batch_num: int):
+         return self._get_lr(batch_num)
+
+     def _get_lr(self, batch_num):
+         num_period, t = batch_num // self.period, batch_num % self.period
+
+         if self._period != num_period:
+             self._period = num_period
+             if self.gamma and (num_period >= self.start_period):
+                 amp = (self.max_lr - self.base_lr) * (self.gamma ** (num_period - self.start_period))
+                 max_lr = self.base_lr + amp
+             else:
+                 max_lr = self.max_lr
+
+             if self.log:
+                 self._axis = np.logspace(np.log10(self.base_lr), np.log10(max_lr), self.period // 2)
+             else:
+                 self._axis = np.linspace(self.base_lr, max_lr, self.period // 2)
+         if t < self.period // 2:
+             lr = self._axis[t]
+         else:
+             lr = self._axis[self.period - t - 1]
+         return lr
+
+     def end_batch(self, trainer: Trainer, batch_num: int):
+         lr = self.get_lr(batch_num)
+         for param_group in self.param_groups:
+             trainer.set_lr(lr, param_group)
+
+     def simulate_and_plot(self, total_steps: int, log_scale: bool = True):
+         import matplotlib.pyplot as plt
+         lrs = [self.get_lr(batch_num) for batch_num in range(total_steps)]
+
+         plt.figure(figsize=(10, 6))
+         plt.plot(lrs, label='Learning Rate')
+         plt.xlabel('Steps')
+         plt.ylabel('Learning Rate')
+         plt.title('Learning Rate Schedule')
+         if log_scale:
+             plt.yscale('log')
+         plt.grid(True, which='both', linestyle='--', linewidth=0.5)
+         plt.legend()
+         plt.show()
+
+
+ class ReduceLROnPlateau(TrainerCallback):
+     """Learning rate scheduler which reduces the learning rate when the loss stops decreasing
+
+     Args:
+         gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.5.
+         patience (int, optional): The number of allowed iterations with no improvement after which the learning rate is reduced. Defaults to 500.
+         average (int, optional): Size of the window over which the average loss is computed. Defaults to 50.
+         loss_key (str, optional): Key of the monitored loss. Defaults to 'total_loss'.
+         param_groups (tuple, optional): Parameter groups of the optimizer. Defaults to (0,).
+     """
+     def __init__(
+             self,
+             gamma: float = 0.5,
+             patience: int = 500,
+             average: int = 50,
+             loss_key: str = 'total_loss',
+             param_groups: tuple = (0,),
+     ):
+         self.patience = patience
+         self.average = average
+         self.gamma = gamma
+         self.loss_key = loss_key
+         self.param_groups = param_groups
+
+     def end_batch(self, trainer: Trainer, batch_num: int) -> None:
+         loss = trainer.losses[self.loss_key]
+
+         if len(loss) < self.patience:
+             return
+
+         if np.mean(loss[-self.patience:-(self.patience - self.average)]) <= np.mean(loss[-self.average:]):
+             for param_group in self.param_groups:
+                 trainer.set_lr(trainer.lr(param_group) * self.gamma, param_group)
+
+
+ class OneCycleLR(ScheduleLR):
+     """One-cycle learning rate scheduler (https://arxiv.org/abs/1708.07120)
+
+     Args:
+         max_lr (float): Upper learning rate boundary in the cycle
+         total_steps (int): The total number of steps in the cycle
+         pct_start (float, optional): The percentage of the cycle (in number of steps) spent increasing the learning rate. Defaults to 0.3.
+         div_factor (float, optional): Determines the initial learning rate via initial_lr = ``max_lr`` / ``div_factor``. Defaults to 25.
+         final_div_factor (float, optional): Determines the minimum learning rate via min_lr = ``initial_lr`` / ``final_div_factor``. Defaults to 1e4.
+         three_phase (bool, optional): If ``True``, use a third phase of the schedule to annihilate the learning rate according to ``final_div_factor`` instead of modifying the second phase. Defaults to True.
+     """
+     def __init__(self, max_lr: float, total_steps: int, pct_start: float = 0.3, div_factor: float = 25.,
+                  final_div_factor: float = 1e4, three_phase: bool = True, **kwargs):
+         super().__init__(
+             lr_scheduler.OneCycleLR,
+             max_lr=max_lr,
+             total_steps=total_steps,
+             pct_start=pct_start,
+             div_factor=div_factor,
+             final_div_factor=final_div_factor,
+             three_phase=three_phase,
+             **kwargs
+         )
+
+
+ class CosineAnnealingWithWarmup(TrainerCallback):
+     """Cosine annealing scheduler with a warm-up stage.
+
+     Args:
+         max_lr (float): The maximum learning rate after the warm-up phase.
+         min_lr (float): The minimum learning rate after the warm-up phase.
+         warmup_iters (int): The number of iterations for the warm-up phase.
+         total_iters (int): The total number of iterations for the scheduler (including warm-up).
+     """
+     def __init__(self, max_lr=None, min_lr=1.0e-6, warmup_iters=100, total_iters=100000):
+         self.max_lr = max_lr
+         self.min_lr = min_lr
+         self.warmup_iters = warmup_iters
+         self.total_iters = total_iters
+
+     def get_lr(self, step):
+         """Computes the learning rate for a given iteration.
+
+         Args:
+             step (int): The current iteration.
+
+         Returns:
+             float: The learning rate for the current iteration.
+         """
+         if step < self.warmup_iters:
+             # Warm-up stage: linear increase from 0 to max_lr
+             return self.max_lr * step / self.warmup_iters
+         elif step < self.total_iters:
+             # Cosine annealing stage
+             t = (step - self.warmup_iters) / (self.total_iters - self.warmup_iters)
+             return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * t))
+         else:
+             # Beyond the total number of iterations: return min_lr
+             return self.min_lr
+
+     def start_training(self, trainer: Trainer) -> None:
+         self.max_lr = trainer.lr()
+
+     def end_batch(self, trainer: Trainer, batch_num: int):
+         """Updates the learning rate at the end of each batch.
+
+         Args:
+             trainer (Trainer): The trainer object.
+             batch_num (int): The current batch number.
+         """
+         lr = self.get_lr(batch_num)
+         trainer.set_lr(lr)
+
+     def simulate_and_plot(self, total_steps: int = None, log_scale: bool = False):
+         """Simulates and plots the learning rate evolution.
+
+         Args:
+             total_steps (int, optional): Total number of steps to simulate. If None, uses ``self.total_iters``.
+             log_scale (bool, optional): If True, the learning rate axis uses a logarithmic scale.
+         """
+         total_steps = total_steps or self.total_iters
+         lrs = [self.get_lr(step) for step in range(total_steps)]
+
+         import matplotlib.pyplot as plt
+         plt.figure(figsize=(10, 6))
+         plt.plot(lrs, label='Learning Rate')
+         plt.xlabel('Steps')
+         plt.ylabel('Learning Rate')
+         plt.title('Learning Rate Scheduler')
+         if log_scale:
+             plt.yscale('log')
+         plt.grid(True, which='both', linestyle='--', linewidth=0.5)
+         plt.legend()
+         plt.show()
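
Note: the schedulers above are trainer callbacks, so a schedule can be previewed without a trainer via the simulate_and_plot helpers defined in this file. A minimal sketch; the numeric values are illustrative, not package defaults, and matplotlib is needed for the plot:

    from reflectorch.ml.schedulers import StepLR, LogCyclicLR

    # ScheduleLR subclasses simulate against a dummy Adam optimizer seeded with initial_lr
    StepLR(step_size=5000, gamma=0.5).simulate_and_plot(total_steps=20000, initial_lr=1e-3, log_scale=True)

    # LogCyclicLR computes its learning rate directly, so no dummy optimizer is needed
    LogCyclicLR(base_lr=1e-5, max_lr=1e-3, period=2000).simulate_and_plot(total_steps=10000)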
reflectorch/ml/trainers.py
@@ -0,0 +1,201 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from dataclasses import dataclass
+ from typing import Optional
+
+ from reflectorch.data_generation import BATCH_DATA_TYPE
+ from reflectorch.ml.basic_trainer import Trainer
+ from reflectorch.ml.dataloaders import ReflectivityDataLoader
+
+ __all__ = [
+     'RealTimeSimTrainer',
+     'DenoisingAETrainer',
+     'PointEstimatorTrainer',
+ ]
+
+
+ @dataclass
+ class BasicBatchData:
+     scaled_curves: torch.Tensor
+     scaled_bounds: torch.Tensor
+     scaled_params: Optional[torch.Tensor] = None
+     scaled_sigmas: Optional[torch.Tensor] = None
+     scaled_q_values: Optional[torch.Tensor] = None
+     scaled_denoised_curves: Optional[torch.Tensor] = None
+     key_padding_mask: Optional[torch.Tensor] = None
+     scaled_conditioning_params: Optional[torch.Tensor] = None
+     unscaled_q_values: Optional[torch.Tensor] = None
+
+
+ class RealTimeSimTrainer(Trainer):
+     """Trainer with functionality to customize the sampled batch of data"""
+     loader: ReflectivityDataLoader
+
+     def get_batch_by_idx(self, batch_num: int):
+         """Gets a batch of data with the default batch size"""
+         batch_data = self.loader.get_batch(self.batch_size)
+         return self._get_batch(batch_data)
+
+     def get_batch_by_size(self, batch_size: int):
+         """Gets a batch of data with a custom batch size"""
+         batch_data = self.loader.get_batch(batch_size)
+         return self._get_batch(batch_data)
+
+     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+         """Modify the batch of data sampled from the data loader"""
+         raise NotImplementedError
+
+
+ class PointEstimatorTrainer(RealTimeSimTrainer):
+     """Point estimator trainer for the inverse problem."""
+
+     def init(self):
+         if getattr(self, 'use_l1_loss', False):
+             self.criterion = nn.L1Loss(reduction='none')
+         else:
+             self.criterion = nn.MSELoss(reduction='none')
+         self.use_curve_reconstruction_loss = getattr(self, 'use_curve_reconstruction_loss', False)
+         self.rescale_loss_interval_width = getattr(self, 'rescale_loss_interval_width', False)
+         if self.use_curve_reconstruction_loss:
+             self.loader.calc_denoised_curves = True
+
+         self.train_with_q_input = getattr(self, 'train_with_q_input', False)
+         self.train_with_sigmas = getattr(self, 'train_with_sigmas', False)
+         self.condition_on_q_resolutions = getattr(self, 'condition_on_q_resolutions', False)
+
+     def _get_batch(self, batch_data: BATCH_DATA_TYPE) -> BasicBatchData:
+         def get_scaled_or_none(key, scaler=None):
+             value = batch_data.get(key)
+             if value is None:
+                 return None
+             scale_func = scaler or (lambda x: x)
+             return scale_func(value).to(torch.float32)
+
+         scaled_params = batch_data['scaled_params'].to(torch.float32)
+         scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
+         scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
+         scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
+         key_padding_mask = batch_data.get('key_padding_mask', None)
+
+         scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
+         conditioning_params = []
+         if scaled_q_resolutions is not None:
+             conditioning_params.append(scaled_q_resolutions)
+         scaled_conditioning_params = torch.cat(conditioning_params, dim=-1) if len(conditioning_params) > 0 else None
+
+         num_params = scaled_params.shape[-1] // 3
+         assert num_params * 3 == scaled_params.shape[-1]
+         scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
+
+         return BasicBatchData(
+             scaled_params=scaled_params,
+             scaled_bounds=scaled_bounds,
+             scaled_curves=scaled_curves,
+             scaled_q_values=scaled_q_values,
+             scaled_denoised_curves=scaled_denoised_curves,
+             scaled_conditioning_params=scaled_conditioning_params,
+             unscaled_q_values=batch_data['q_values'],
+             key_padding_mask=key_padding_mask,
+         )
+
+     def get_loss_dict(self, batch_data: BasicBatchData):
+         """Returns the regression loss"""
+         scaled_params = batch_data.scaled_params
+         scaled_curves = batch_data.scaled_curves
+         scaled_bounds = batch_data.scaled_bounds
+         scaled_q_values = batch_data.scaled_q_values
+         key_padding_mask = batch_data.key_padding_mask
+         scaled_conditioning_params = batch_data.scaled_conditioning_params
+         unscaled_q_values = batch_data.unscaled_q_values
+
+         predicted_params = self.model(
+             curves=scaled_curves,
+             bounds=scaled_bounds,
+             q_values=scaled_q_values,
+             conditioning_params=scaled_conditioning_params,
+             key_padding_mask=key_padding_mask,
+             unscaled_q_values=unscaled_q_values,
+         )
+
+         if not self.rescale_loss_interval_width:
+             loss = self.criterion(predicted_params, scaled_params).mean()
+         else:
+             n_params = scaled_params.shape[-1]
+             b_min = scaled_bounds[..., :n_params]
+             b_max = scaled_bounds[..., n_params:]
+             interval_width = b_max - b_min
+
+             base_loss = self.criterion(predicted_params, scaled_params)
+             if isinstance(self.criterion, torch.nn.MSELoss):
+                 width_factors = (interval_width / 2) ** 2
+             elif isinstance(self.criterion, torch.nn.L1Loss):
+                 width_factors = interval_width / 2
+
+             loss = (width_factors * base_loss).mean()
+
+         return {'loss': loss}
+
+
+ # class PointEstimatorTrainer(RealTimeSimTrainer):
+ #     """Trainer for the regression inverse problem with incorporation of prior bounds"""
+ #     add_sigmas_to_context: bool = False
+
+ #     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+ #         scaled_params = batch_data['scaled_params'].to(torch.float32)
+ #         scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
+ #         if self.train_with_q_input:
+ #             q_values = batch_data['q_values'].to(torch.float32)
+ #             scaled_q_values = self.loader.q_generator.scale_q(q_values)
+ #         else:
+ #             scaled_q_values = None
+
+ #         num_params = scaled_params.shape[-1] // 3
+ #         assert num_params * 3 == scaled_params.shape[-1]
+ #         scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
+
+ #         return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
+
+ #     def get_loss_dict(self, batch_data):
+ #         """computes the loss dictionary"""
+
+ #         scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
+
+ #         if self.train_with_q_input:
+ #             predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
+ #         else:
+ #             predicted_params = self.model(scaled_curves, scaled_bounds)
+
+ #         loss = self.mse(predicted_params, scaled_params)
+ #         return {'loss': loss}
+
+ #     def init(self):
+ #         self.mse = nn.MSELoss()
+
+
+ class DenoisingAETrainer(RealTimeSimTrainer):
+     """Trainer for a denoising autoencoder model. Overrides the ``_get_batch`` and ``get_loss_dict`` methods."""
+     def init(self):
+         self.loader.calc_denoised_curves = True
+
+         if getattr(self, 'use_l1_loss', False):
+             self.criterion = nn.L1Loss()
+         else:
+             self.criterion = nn.MSELoss()
+
+     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+         """Returns scaled curves with and without noise"""
+         scaled_noisy_curves, curves = batch_data['scaled_noisy_curves'], batch_data['curves']
+         scaled_curves = self.loader.curves_scaler.scale(curves)
+
+         scaled_noisy_curves, scaled_curves = scaled_noisy_curves.to(torch.float32), scaled_curves.to(torch.float32)
+
+         return scaled_noisy_curves, scaled_curves
+
+     def get_loss_dict(self, batch_data):
+         """Returns the reconstruction loss of the autoencoder"""
+         scaled_noisy_curves, scaled_curves = batch_data
+         restored_curves = self.model(scaled_noisy_curves)
+         loss = self.criterion(scaled_curves, restored_curves)
+         return {'loss': loss}
+
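
Note: in PointEstimatorTrainer._get_batch above, batch_data['scaled_params'] packs the target parameters and their prior bounds along the last dimension as [params | lower bounds | upper bounds], which is why its width must equal 3 * num_params before the torch.split. A small sketch of that layout (the shapes are illustrative):

    import torch

    num_params = 4
    scaled = torch.randn(8, 3 * num_params)                                     # batch of 8 samples
    params, bounds = torch.split(scaled, [num_params, 2 * num_params], dim=-1)
    print(params.shape, bounds.shape)                                           # torch.Size([8, 4]) torch.Size([8, 8])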
reflectorch/ml/utils.py
@@ -0,0 +1,2 @@
+ def is_divisor(num: int, div: int):
+     return num and not num % div
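
Note: the ``num and`` guard makes is_divisor falsy when num is 0, so a periodic check based on it skips iteration zero. For illustration:

    from reflectorch.ml.utils import is_divisor

    is_divisor(100, 10)   # True
    is_divisor(101, 10)   # False
    is_divisor(0, 10)     # 0 (falsy): iteration zero is excluded by the `num and` guard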
reflectorch/models/__init__.py
@@ -0,0 +1,16 @@
+ from reflectorch.models.encoders import *
+ from reflectorch.models.networks import *
+
+ __all__ = [
+     "ConvEncoder",
+     "ConvDecoder",
+     "ConvAutoencoder",
+     "FnoEncoder",
+     "IntegralConvEmbedding",
+     "SpectralConv1d",
+     "ConvResidualNet1D",
+     "ResidualMLP",
+     "NetworkWithPriors",
+     "NetworkWithPriorsConvEmb",
+     "NetworkWithPriorsFnoEmb",
+ ]
reflectorch/models/activations.py
@@ -0,0 +1,50 @@
+ import torch
+ from torch import nn
+ from torch.nn.functional import relu
+
+
+ class Rowdy(nn.Module):
+     """Adaptive activation function"""
+     def __init__(self, K=9):
+         super().__init__()
+         self.K = K
+         self.alpha = nn.Parameter(torch.cat((torch.ones(1), torch.zeros(K - 1))))
+         self.alpha.requires_grad = True
+         self.omega = nn.Parameter(torch.ones(K))
+         self.omega.requires_grad = True
+
+     def forward(self, x):
+         rowdy = self.alpha[0] * relu(self.omega[0] * x)
+         for k in range(1, self.K):
+             rowdy += self.alpha[k] * torch.sin(self.omega[k] * k * x)
+         return rowdy
+
+
+ ACTIVATIONS = {
+     'relu': nn.ReLU,
+     'lrelu': nn.LeakyReLU,
+     'gelu': nn.GELU,
+     'selu': nn.SELU,
+     'elu': nn.ELU,
+     'sigmoid': nn.Sigmoid,
+     'tanh': nn.Tanh,
+     'silu': nn.SiLU,
+     'mish': nn.Mish,
+     'rowdy': Rowdy,
+ }
+
+
+ def activation_by_name(name):
+     """Returns the activation function module class corresponding to its name
+
+     Args:
+         name (str): string denoting the activation function ('relu', 'lrelu', 'gelu', 'selu', 'elu', 'sigmoid', 'tanh', 'silu', 'mish', 'rowdy')
+
+     Returns:
+         nn.Module: PyTorch activation function module class
+     """
+     if not isinstance(name, str):
+         return name
+     try:
+         return ACTIVATIONS[name.lower()]
+     except KeyError:
+         raise KeyError(f'Unknown activation function {name}')
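
Note: ACTIVATIONS maps names to module classes, so activation_by_name returns a class that still has to be instantiated; non-string arguments are passed through unchanged. A minimal sketch:

    import torch
    from torch import nn
    from reflectorch.models.activations import activation_by_name

    act = activation_by_name('gelu')()               # look up the nn.GELU class, then instantiate it
    y = act(torch.randn(4))
    assert activation_by_name(nn.SiLU) is nn.SiLU    # non-strings are returned as-is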
reflectorch/models/encoders/__init__.py
@@ -0,0 +1,19 @@
+ from reflectorch.models.encoders.conv_encoder import (
+     ConvEncoder,
+     ConvDecoder,
+     ConvAutoencoder,
+ )
+ from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
+ from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
+ from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D
+
+
+ __all__ = [
+     "ConvEncoder",
+     "ConvDecoder",
+     "ConvAutoencoder",
+     "ConvResidualNet1D",
+     "FnoEncoder",
+     "SpectralConv1d",
+     "IntegralConvEmbedding",
+ ]