reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (41)
  1. reflectorch/data_generation/__init__.py +4 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +91 -16
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +97 -43
  8. reflectorch/data_generation/reflectivity/__init__.py +53 -11
  9. reflectorch/data_generation/reflectivity/kinematical.py +4 -5
  10. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  11. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  12. reflectorch/data_generation/smearing.py +42 -11
  13. reflectorch/data_generation/utils.py +93 -18
  14. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  15. reflectorch/inference/inference_model.py +795 -159
  16. reflectorch/inference/loading_data.py +37 -0
  17. reflectorch/inference/plotting.py +517 -0
  18. reflectorch/inference/preprocess_exp/interpolation.py +5 -2
  19. reflectorch/inference/scipy_fitter.py +98 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +131 -23
  26. reflectorch/models/__init__.py +2 -1
  27. reflectorch/models/encoders/__init__.py +2 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  31. reflectorch/models/networks/__init__.py +2 -0
  32. reflectorch/models/networks/mlp_networks.py +331 -153
  33. reflectorch/models/networks/residual_net.py +31 -5
  34. reflectorch/runs/train.py +0 -1
  35. reflectorch/runs/utils.py +48 -11
  36. reflectorch/utils.py +30 -0
  37. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
  38. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
  39. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
  40. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
  41. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
reflectorch/inference/scipy_fitter.py CHANGED
@@ -2,11 +2,14 @@ import warnings
 
 import numpy as np
 from scipy.optimize import minimize, curve_fit
+import torch
 
+from reflectorch.data_generation.priors.base import PriorSampler
 from reflectorch.data_generation.reflectivity import abeles_np
 
 __all__ = [
     "standard_refl_fit",
+    "refl_fit",
     "fit_refl_curve",
     "restore_masked_params",
     "get_fit_with_growth",
@@ -26,7 +29,6 @@ def standard_restore_params(fitted_params) -> dict:
 def mse_loss(curve1, curve2):
     return np.sum((curve1 - curve2) ** 2)
 
-
 def standard_refl_fit(
         q: np.ndarray, curve: np.ndarray,
         init_params: np.ndarray,
@@ -41,7 +43,7 @@ def standard_refl_fit(
     init_params = np.clip(init_params, *bounds)
 
     res = curve_fit(
-        get_scaled_curve_func(
+        standard_get_scaled_curve_func(
            refl_generator=refl_generator,
            restore_params_func=restore_params_func,
            scale_curve_func=scale_curve_func,
@@ -53,9 +55,73 @@ def standard_refl_fit(
     curve = refl_generator(q, **restore_params_func(res[0]))
     return res[0], curve
 
+def refl_fit(
+        q: np.ndarray,
+        curve: np.ndarray,
+        init_params: np.ndarray,
+        prior_sampler: PriorSampler,
+        bounds: np.ndarray = None,
+        error_bars: np.ndarray = None,
+        scale_curve_func=np.log10,
+        method: str = 'trf', #'lm', 'trf'
+        polishing_max_nfev: int = None,
+        reflectivity_kwargs: dict = None,
+        **kwargs
+):
+    if bounds is not None:
+        # introduce a small perturbation for fixed bounds
+        epsilon = 1e-6
+        adjusted_bounds = bounds.copy()
+
+        for i in range(bounds.shape[1]):
+            if bounds[0, i] == bounds[1, i]:
+                adjusted_bounds[0, i] -= epsilon
+                adjusted_bounds[1, i] += epsilon
+
+        init_params = np.clip(init_params, *adjusted_bounds)
+        if method != 'lm':
+            kwargs['bounds'] = adjusted_bounds
+
+    reflectivity_kwargs = reflectivity_kwargs or {}
+    for key, value in reflectivity_kwargs.items():
+        if isinstance(value, float):
+            reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
+        elif isinstance(value, np.ndarray):
+            reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
+
+    curve = np.clip(curve, a_min=1e-12, a_max=None)
+
+    if error_bars is not None and scale_curve_func == np.log10:
+        error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
+        scaled_error_bars = error_bars / (curve * np.log(10))
+    else:
+        scaled_error_bars = None
+
+    res = curve_fit(
+        f=get_scaled_curve_func(
+            scale_curve_func=scale_curve_func,
+            prior_sampler=prior_sampler,
+            reflectivity_kwargs=reflectivity_kwargs,
+        ),
+        xdata=q,
+        ydata=scale_curve_func(curve).reshape(-1),
+        p0=init_params,
+        sigma=scaled_error_bars,
+        absolute_sigma=True,
+        method=method,
+        max_nfev=polishing_max_nfev,
+        **kwargs
+    )
+
+    curve = prior_sampler.param_model.reflectivity(torch.tensor(q, dtype=torch.float64),
+                                                   torch.tensor(res[0], dtype=torch.float64).unsqueeze(0),
+                                                   **reflectivity_kwargs).squeeze().numpy()
+    return res[0], curve
+
 
 def get_fit_with_growth(
-        q: np.ndarray, curve: np.ndarray,
+        q: np.ndarray,
+        curve: np.ndarray,
         init_params: np.ndarray,
         bounds: np.ndarray = None,
         init_d_change: float = 0.,
@@ -68,10 +134,16 @@ def get_fit_with_growth(
     bounds = np.concatenate([bounds, np.array([0, max_d_change])[..., None]], -1)
 
     params, curve = standard_refl_fit(
-        q, curve, init_params, bounds, refl_generator=growth_reflectivity,
+        q,
+        curve,
+        init_params,
+        bounds,
+        refl_generator=growth_reflectivity,
         restore_params_func=get_restore_params_with_growth_func(q_size=q.size, d_idx=0),
-        scale_curve_func=scale_curve_func, **kwargs
+        scale_curve_func=scale_curve_func,
+        **kwargs
     )
+
     params[0] += params[-1] / 2
     return params, curve
 
@@ -97,8 +169,7 @@ def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
        warnings.warn(f"Minimization did not converge.")
    return res.x
 
-
-def get_scaled_curve_func(
+def standard_get_scaled_curve_func(
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
@@ -111,6 +182,26 @@ def get_scaled_curve_func(
 
    return scaled_curve_func
 
+def get_scaled_curve_func(
+        scale_curve_func=np.log10,
+        prior_sampler: PriorSampler = None,
+        reflectivity_kwargs: dict = None,
+):
+    reflectivity_kwargs = reflectivity_kwargs or {}
+
+    def scaled_curve_func(q, *fitted_params):
+        q_tensor = torch.from_numpy(q).to(torch.float64)
+        fitted_params_tensor = torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0)
+
+        fitted_curve_tensor = prior_sampler.param_model.reflectivity(q_tensor, fitted_params_tensor, **reflectivity_kwargs)
+        fitted_curve = fitted_curve_tensor.squeeze().numpy()
+
+        scaled_curve = scale_curve_func(fitted_curve)
+
+        return scaled_curve.reshape(-1)
+
+    return scaled_curve_func
+
 
 def get_fitting_func(
        q: np.ndarray,
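The new refl_fit function polishes a set of predicted film parameters with scipy.optimize.curve_fit, simulating the model curve through prior_sampler.param_model.reflectivity instead of a fixed refl_generator, and optionally weighting the residuals with measurement error bars. A minimal sketch of a call follows; the prior_sampler, the initial guess and the bounds are assumed to come from an already-loaded reflectorch model and are only placeholders here.

    import numpy as np
    from reflectorch.inference.scipy_fitter import refl_fit

    q = np.linspace(0.02, 0.15, 128)                 # momentum transfer grid (example values)
    curve = measured_reflectivity                    # hypothetical measured curve, same length as q
    error_bars = measured_uncertainties              # hypothetical 1-sigma uncertainties
    init_params = predicted_params                   # e.g. the neural network prediction (unscaled)
    bounds = np.stack([lower_bounds, upper_bounds])  # shape (2, n_params); equal bounds are nudged by 1e-6

    polished_params, polished_curve = refl_fit(
        q, curve, init_params,
        prior_sampler=prior_sampler,   # PriorSampler whose param_model defines the reflectivity
        bounds=bounds,
        error_bars=error_bars,         # converted internally to log10-scale sigmas
        method='trf',                  # 'lm' ignores the bounds
    )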
reflectorch/ml/__init__.py CHANGED
@@ -15,6 +15,7 @@ __all__ = [
     'Logger',
     'Loggers',
     'PrintLogger',
+    'TensorBoardLogger',
     'ScheduleBatchSize',
     'ScheduleLR',
     'StepLR',
@@ -22,6 +23,7 @@ __all__ = [
     'LogCyclicLR',
     'ReduceLROnPlateau',
     'OneCycleLR',
+    'CosineAnnealingWithWarmup',
     'ReflectivityDataLoader',
     'MultilayerDataLoader',
     'RealTimeSimTrainer',
reflectorch/ml/basic_trainer.py CHANGED
@@ -1,7 +1,8 @@
 from typing import Optional, Tuple, Iterable, Any, Union, Type
 from collections import defaultdict
 
-from tqdm.notebook import trange
+from tqdm import tqdm as standard_tqdm
+from tqdm.notebook import tqdm as notebook_tqdm
 import numpy as np
 
 import torch
@@ -31,7 +32,6 @@ class Trainer(object):
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
-        train_with_q_input (bool, optional): if ``True`` the q values are also used as input. Defaults to False.
    """
 
    TOTAL_LOSS_KEY: str = 'total_loss'
@@ -42,7 +42,6 @@ class Trainer(object):
            lr: float,
            batch_size: int,
            clip_grad_norm_max: Optional[int] = None,
-            train_with_q_input: bool = False,
            logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
            optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
            optim_kwargs: dict = None,
@@ -53,7 +52,6 @@ class Trainer(object):
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max
-        self.train_with_q_input = train_with_q_input
 
        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        self.lrs = []
@@ -78,7 +76,8 @@ class Trainer(object):
            num_batches: int,
            callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
            disable_tqdm: bool = False,
-            update_tqdm_freq: int = 10,
+            use_notebook_tqdm: bool = False,
+            update_tqdm_freq: int = 1,
            grad_accumulation_steps: int = 1,
    ):
        """starts the training process
@@ -87,6 +86,7 @@ class Trainer(object):
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback']): the trainer callbacks. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
+            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter Notebook. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 10.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.
        """
@@ -96,7 +96,8 @@ class Trainer(object):
 
        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])
 
-        pbar = trange(num_batches, disable=disable_tqdm)
+        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
+        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)
 
        callbacks.start_training(self)
 
@@ -121,6 +122,7 @@ class Trainer(object):
 
            if self.clip_grad_norm_max is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)
+
            self.optim.step()
 
        avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
@@ -141,6 +143,16 @@ class Trainer(object):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')
 
+            postfix = {}
+            for key in self.losses.keys():
+                if key != self.TOTAL_LOSS_KEY:
+                    last_value = self.losses[key][-1]
+                    postfix[key] = f'{last_value:.4f}'
+
+            postfix['lr'] = f'{self.lr():.2e}'
+
+            pbar.set_postfix(postfix)
+
    def get_batch_by_idx(self, batch_num: int) -> Any:
        raise NotImplementedError
 
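The progress-bar flavour is now chosen at call time instead of always importing tqdm.notebook. A short sketch, assuming trainer is an already-configured Trainer subclass and sched is one of the scheduler callbacks:

    # inside a Jupyter notebook, request the notebook-style progress bar explicitly;
    # in a terminal script the default (standard tqdm) is used
    trainer.train(
        num_batches=30_000,
        callbacks=(sched,),
        use_notebook_tqdm=True,   # new argument in this release, defaults to False
    )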
reflectorch/ml/callbacks.py CHANGED
@@ -74,7 +74,8 @@ class LogLosses(TrainerCallback):
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch
        """
-        try:
-            trainer.log('train/total_loss', trainer.losses[trainer.TOTAL_LOSS_KEY][-1])
-        except IndexError:
-            pass
+        for loss_name, loss_values in trainer.losses.items():
+            try:
+                trainer.log(f'train/{loss_name}', loss_values[-1])
+            except IndexError:
+                continue
reflectorch/ml/loggers.py CHANGED
@@ -1,7 +1,10 @@
+from torch.utils.tensorboard import SummaryWriter
+
 __all__ = [
     'Logger',
     'Loggers',
     'PrintLogger',
+    'TensorBoardLogger',
 ]
 
 
@@ -29,3 +32,25 @@ class PrintLogger(Logger):
     """Logger which prints to the console"""
     def log(self, name: str, data):
         print(name, ': ', data)
+
+class TensorBoardLogger(Logger):
+    def __init__(self, log_dir: str):
+        """
+        Args:
+            log_dir (str): Directory where TensorBoard logs will be written
+        """
+        super().__init__()
+        self.writer = SummaryWriter(log_dir=log_dir)
+        self.step = 1
+
+    def log(self, name: str, data):
+        """Log scalar data to TensorBoard
+
+        Args:
+            name (str): Name/tag for the data
+            data: Scalar value to log
+        """
+        if hasattr(data, 'item'):
+            data = data.item()
+        self.writer.add_scalar(name, data, self.step)
+        self.step += 1
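The new TensorBoardLogger writes every logged value as a TensorBoard scalar via torch.utils.tensorboard.SummaryWriter; together with the extended LogLosses callback above, each entry of trainer.losses gets its own train/<name> tag. A minimal sketch (the log directory name is only an example):

    from reflectorch.ml import TensorBoardLogger

    logger = TensorBoardLogger(log_dir='runs/reflectorch_experiment')
    logger.log('train/total_loss', 0.042)   # floats or 0-dim tensors are accepted

    # the logger is typically passed to a Trainer via its logger argument,
    # and the curves are inspected with:  tensorboard --logdir runs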
reflectorch/ml/schedulers.py CHANGED
@@ -1,3 +1,4 @@
+import math
 from torch.optim import lr_scheduler
 
 import numpy as np
@@ -12,6 +13,7 @@ __all__ = [
     'LogCyclicLR',
     'ReduceLROnPlateau',
     'OneCycleLR',
+    'CosineAnnealingWithWarmup',
 ]
 
 
@@ -69,6 +71,31 @@ class ScheduleLR(TrainerCallback):
        """
        self.lr_scheduler.step()
 
+    def simulate_and_plot(self, total_steps: int, initial_lr: float, log_scale: bool = False):
+        import torch
+        import matplotlib.pyplot as plt
+
+        dummy_optim = torch.optim.Adam([torch.zeros(1)], lr=initial_lr)
+        scheduler = self.lr_scheduler_cls(dummy_optim, **self.kwargs)
+
+        lrs = []
+        for step in range(total_steps):
+            lrs.append(dummy_optim.param_groups[0]['lr'])
+            scheduler.step()
+
+        plt.figure(figsize=(10, 6))
+        plt.plot(lrs, label='Learning Rate')
+        plt.xlabel('Steps')
+        plt.ylabel('Learning Rate')
+        plt.title('Learning Rate Schedule')
+
+        if log_scale:
+            plt.yscale('log')
+
+        plt.grid(True, which="both", linestyle='--', linewidth=0.5)
+        plt.legend()
+        plt.show()
+
 
 class StepLR(ScheduleLR):
    """Learning rate scheduler which decays the learning rate of each parameter group by gamma every ``step_size`` epochs.
@@ -176,6 +203,21 @@ class LogCyclicLR(TrainerCallback):
            for param_group in self.param_groups:
                trainer.set_lr(lr, param_group)
 
+    def simulate_and_plot(self, total_steps: int, log_scale: bool = True):
+        import matplotlib.pyplot as plt
+        lrs = [self.get_lr(batch_num) for batch_num in range(total_steps)]
+
+        plt.figure(figsize=(10, 6))
+        plt.plot(lrs, label='Learning Rate')
+        plt.xlabel('Steps')
+        plt.ylabel('Learning Rate')
+        plt.title('Learning Rate Schedule')
+        if log_scale:
+            plt.yscale('log')
+        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
+        plt.legend()
+        plt.show()
+
 
 class ReduceLROnPlateau(TrainerCallback):
    """Learning rate scheduler which reduces the learning rate when the loss stops decreasing
@@ -238,3 +280,77 @@ class OneCycleLR(ScheduleLR):
            three_phase=three_phase,
            **kwargs
        )
+
+class CosineAnnealingWithWarmup(TrainerCallback):
+    """
+    Cosine annealing scheduler with a warm-up stage.
+
+    Args:
+        max_lr (float): The maximum learning rate after the warm-up phase.
+        min_lr (float): The minimum learning rate after the warm-up phase.
+        warmup_iters (int): The number of iterations for the warm-up phase.
+        total_iters (int): The total number of iterations for the scheduler (including warm-up).
+    """
+    def __init__(self, max_lr=None, min_lr=1.0e-6, warmup_iters=100, total_iters=100000):
+        self.max_lr = max_lr
+        self.min_lr = min_lr
+        self.warmup_iters = warmup_iters
+        self.total_iters = total_iters
+
+    def get_lr(self, step):
+        """
+        Compute the learning rate for a given iteration.
+
+        Args:
+            step (int): The current iteration.
+
+        Returns:
+            float: The learning rate for the current iteration.
+        """
+        if step < self.warmup_iters:
+            # Warm-up stage: Linear increase from 0 to max_lr
+            return self.max_lr * step / self.warmup_iters
+        elif step < self.total_iters:
+            # Cosine annealing stage
+            t = (step - self.warmup_iters) / (self.total_iters - self.warmup_iters)
+            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * t))
+        else:
+            # Beyond total iterations: Return min_lr
+            return self.min_lr
+
+    def start_training(self, trainer: Trainer) -> None:
+        self.max_lr = trainer.lr()
+
+    def end_batch(self, trainer: Trainer, batch_num: int):
+        """
+        Updates the learning rate at the end of each batch.
+
+        Args:
+            trainer (Trainer): The trainer object.
+            batch_num (int): The current batch number.
+        """
+        lr = self.get_lr(batch_num)
+        trainer.set_lr(lr)
+
+    def simulate_and_plot(self, total_steps: int = None, log_scale: bool = False):
+        """
+        Simulates and plots the learning rate evolution.
+
+        Args:
+            total_batches (int, optional): Total number of batches to simulate. If None, uses self.total_iters.
+        """
+
+        total_steps = total_steps or self.total_iters
+        lrs = [self.get_lr(step) for step in range(total_steps)]
+
+        import matplotlib.pyplot as plt
+        plt.figure(figsize=(10, 6))
+        plt.plot(lrs, label='Learning Rate')
+        plt.xlabel('Steps')
+        plt.ylabel('Learning Rate')
+        plt.title('Learning Rate Scheduler')
+        if log_scale:
+            plt.yscale('log')
+        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
+        plt.legend()
+        plt.show()
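CosineAnnealingWithWarmup ramps the learning rate linearly from 0 to max_lr over warmup_iters steps and then decays it along a cosine to min_lr at total_iters; when used as a trainer callback, start_training replaces max_lr with the trainer's current learning rate. A small, self-contained sketch:

    from reflectorch.ml import CosineAnnealingWithWarmup

    sched = CosineAnnealingWithWarmup(max_lr=1e-3, min_lr=1e-6,
                                      warmup_iters=500, total_iters=20_000)

    sched.get_lr(0)        # 0.0   (start of the linear warm-up)
    sched.get_lr(250)      # 5e-4  (halfway through the warm-up)
    sched.get_lr(500)      # 1e-3  (cosine stage starts at max_lr)
    sched.get_lr(25_000)   # 1e-6  (min_lr once total_iters is exceeded)

    # visualize the whole schedule before a long run
    sched.simulate_and_plot(log_scale=True)

    # during training it is passed as a callback, e.g.
    # trainer.train(num_batches=20_000, callbacks=(sched,))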
reflectorch/ml/trainers.py CHANGED
@@ -2,6 +2,8 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
+from dataclasses import dataclass
+from typing import Optional
 
 from reflectorch.data_generation import BATCH_DATA_TYPE
 from reflectorch.ml.basic_trainer import Trainer
@@ -14,6 +16,18 @@ __all__ = [
 ]
 
 
+@dataclass
+class BasicBatchData:
+    scaled_curves: torch.Tensor
+    scaled_bounds: torch.Tensor
+    scaled_params: torch.Tensor = None
+    scaled_sigmas: Optional[torch.Tensor] = None
+    scaled_q_values: Optional[torch.Tensor] = None
+    scaled_denoised_curves: Optional[torch.Tensor] = None
+    key_padding_mask: Optional[torch.Tensor] = None
+    scaled_conditioning_params: Optional[torch.Tensor] = None
+    unscaled_q_values: Optional[torch.Tensor] = None
+
 class RealTimeSimTrainer(Trainer):
    """Trainer with functionality to customize the sampled batch of data"""
    loader: ReflectivityDataLoader
@@ -34,46 +48,140 @@ class RealTimeSimTrainer(Trainer):
 
 
 class PointEstimatorTrainer(RealTimeSimTrainer):
-    """Trainer for the regression inverse problem with incorporation of prior bounds"""
-    add_sigmas_to_context: bool = False
-
-    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
-        scaled_params = batch_data['scaled_params'].to(torch.float32)
-        scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
-        if self.train_with_q_input:
-            q_values = batch_data['q_values'].to(torch.float32)
-            scaled_q_values = self.loader.q_generator.scale_q(q_values)
+    """Point estimator trainer for the inverse problem."""
+
+    def init(self):
+        if getattr(self, 'use_l1_loss', False):
+            self.criterion = nn.L1Loss(reduction='none')
        else:
-            scaled_q_values = None
+            self.criterion = nn.MSELoss(reduction='none')
+        self.use_curve_reconstruction_loss = getattr(self, 'use_curve_reconstruction_loss', False)
+        self.rescale_loss_interval_width = getattr(self, 'rescale_loss_interval_width', False)
+        if self.use_curve_reconstruction_loss:
+            self.loader.calc_denoised_curves = True
+
+        self.train_with_q_input = getattr(self, 'train_with_q_input', False)
+        self.train_with_sigmas = getattr(self, 'train_with_sigmas', False)
+        self.condition_on_q_resolutions = getattr(self, 'condition_on_q_resolutions', False)
+
+    def _get_batch(self, batch_data: BATCH_DATA_TYPE) -> BasicBatchData:
+        def get_scaled_or_none(key, scaler=None):
+            value = batch_data.get(key)
+            if value is None:
+                return None
+            scale_func = scaler or (lambda x: x)
+            return scale_func(value).to(torch.float32)
 
+        scaled_params = batch_data['scaled_params'].to(torch.float32)
+        scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
+        scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
+        scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
+        key_padding_mask = batch_data.get('key_padding_mask', None)
+
+        scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
+        conditioning_params = []
+        if scaled_q_resolutions is not None:
+            conditioning_params.append(scaled_q_resolutions)
+        scaled_conditioning_params = torch.cat(conditioning_params, dim=-1) if len(conditioning_params) > 0 else None
+
        num_params = scaled_params.shape[-1] // 3
        assert num_params * 3 == scaled_params.shape[-1]
        scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
 
-        return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
+        return BasicBatchData(
+            scaled_params=scaled_params,
+            scaled_bounds=scaled_bounds,
+            scaled_curves=scaled_curves,
+            scaled_q_values=scaled_q_values,
+            scaled_denoised_curves=scaled_denoised_curves,
+            scaled_conditioning_params=scaled_conditioning_params,
+            unscaled_q_values=batch_data['q_values'],
+            key_padding_mask=key_padding_mask,
+        )
+
+    def get_loss_dict(self, batch_data: BasicBatchData):
+        """Returns the regression loss"""
+        scaled_params=batch_data.scaled_params
+        scaled_curves=batch_data.scaled_curves
+        scaled_bounds=batch_data.scaled_bounds
+        scaled_q_values=batch_data.scaled_q_values
+        key_padding_mask=batch_data.key_padding_mask
+        scaled_conditioning_params=batch_data.scaled_conditioning_params
+        unscaled_q_values=batch_data.unscaled_q_values
+
+        predicted_params = self.model(
+            curves = scaled_curves,
+            bounds = scaled_bounds,
+            q_values = scaled_q_values,
+            conditioning_params = scaled_conditioning_params,
+            key_padding_mask = key_padding_mask,
+            unscaled_q_values = unscaled_q_values,
+        )
+
+        if not self.rescale_loss_interval_width:
+            loss = self.criterion(predicted_params, scaled_params).mean()
+        else:
+            n_params = scaled_params.shape[-1]
+            b_min = scaled_bounds[..., :n_params]
+            b_max = scaled_bounds[..., n_params:]
+            interval_width = b_max - b_min
 
-    def get_loss_dict(self, batch_data):
-        """computes the loss dictionary"""
+            base_loss = self.criterion(predicted_params, scaled_params)
+            if isinstance(self.criterion, torch.nn.MSELoss):
+                width_factors = (interval_width / 2) ** 2
+            elif isinstance(self.criterion, torch.nn.L1Loss):
+                width_factors = interval_width / 2
 
-        scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
+            loss = (width_factors * base_loss).mean()
 
-        if self.train_with_q_input:
-            predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
-        else:
-            predicted_params = self.model(scaled_curves, scaled_bounds)
-
-        loss = self.mse(predicted_params, scaled_params)
        return {'loss': loss}
 
-    def init(self):
-        self.mse = nn.MSELoss()
+
+# class PointEstimatorTrainer(RealTimeSimTrainer):
+#     """Trainer for the regression inverse problem with incorporation of prior bounds"""
+#     add_sigmas_to_context: bool = False
+
+#     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+#         scaled_params = batch_data['scaled_params'].to(torch.float32)
+#         scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
+#         if self.train_with_q_input:
+#             q_values = batch_data['q_values'].to(torch.float32)
+#             scaled_q_values = self.loader.q_generator.scale_q(q_values)
+#         else:
+#             scaled_q_values = None
+
+#         num_params = scaled_params.shape[-1] // 3
+#         assert num_params * 3 == scaled_params.shape[-1]
+#         scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
+
+#         return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
+
+#     def get_loss_dict(self, batch_data):
+#         """computes the loss dictionary"""
+#
+#         scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
+
+#         if self.train_with_q_input:
+#             predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
+#         else:
+#             predicted_params = self.model(scaled_curves, scaled_bounds)
+
+#         loss = self.mse(predicted_params, scaled_params)
+#         return {'loss': loss}
+
+#     def init(self):
+#         self.mse = nn.MSELoss()
 
 
 class DenoisingAETrainer(RealTimeSimTrainer):
    """Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods """
    def init(self):
-        self.criterion = nn.MSELoss()
        self.loader.calc_denoised_curves = True
+
+        if getattr(self, 'use_l1_loss', False):
+            self.criterion = nn.L1Loss()
+        else:
+            self.criterion = nn.MSELoss()
 
    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
        """returns scaled curves with and without noise"""
reflectorch/models/__init__.py CHANGED
@@ -5,11 +5,12 @@ __all__ = [
     "ConvEncoder",
     "ConvDecoder",
     "ConvAutoencoder",
-    "ConvVAE",
     "FnoEncoder",
+    "IntegralConvEmbedding",
     "SpectralConv1d",
     "ConvResidualNet1D",
     "ResidualMLP",
+    "NetworkWithPriors",
     "NetworkWithPriorsConvEmb",
     "NetworkWithPriorsFnoEmb",
 ]
reflectorch/models/encoders/__init__.py CHANGED
@@ -2,9 +2,9 @@ from reflectorch.models.encoders.conv_encoder import (
     ConvEncoder,
     ConvDecoder,
     ConvAutoencoder,
-    ConvVAE,
 )
 from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
+from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
 from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D
 
 
@@ -12,8 +12,8 @@ __all__ = [
     "ConvEncoder",
     "ConvDecoder",
     "ConvAutoencoder",
-    "ConvVAE",
     "ConvResidualNet1D",
     "FnoEncoder",
     "SpectralConv1d",
+    "IntegralConvEmbedding",
 ]