reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/data_generation/__init__.py +4 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +91 -16
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +97 -43
- reflectorch/data_generation/reflectivity/__init__.py +53 -11
- reflectorch/data_generation/reflectivity/kinematical.py +4 -5
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +93 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +795 -159
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/plotting.py +517 -0
- reflectorch/inference/preprocess_exp/interpolation.py +5 -2
- reflectorch/inference/scipy_fitter.py +98 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +131 -23
- reflectorch/models/__init__.py +2 -1
- reflectorch/models/encoders/__init__.py +2 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +331 -153
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +48 -11
- reflectorch/utils.py +30 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
reflectorch/inference/scipy_fitter.py
CHANGED

@@ -2,11 +2,14 @@ import warnings

 import numpy as np
 from scipy.optimize import minimize, curve_fit
+import torch

+from reflectorch.data_generation.priors.base import PriorSampler
 from reflectorch.data_generation.reflectivity import abeles_np

 __all__ = [
     "standard_refl_fit",
+    "refl_fit",
     "fit_refl_curve",
     "restore_masked_params",
     "get_fit_with_growth",
@@ -26,7 +29,6 @@ def standard_restore_params(fitted_params) -> dict:
 def mse_loss(curve1, curve2):
     return np.sum((curve1 - curve2) ** 2)

-
 def standard_refl_fit(
         q: np.ndarray, curve: np.ndarray,
         init_params: np.ndarray,
@@ -41,7 +43,7 @@ def standard_refl_fit(
     init_params = np.clip(init_params, *bounds)

     res = curve_fit(
-        get_scaled_curve_func(
+        standard_get_scaled_curve_func(
             refl_generator=refl_generator,
             restore_params_func=restore_params_func,
             scale_curve_func=scale_curve_func,
@@ -53,9 +55,73 @@ def standard_refl_fit(
     curve = refl_generator(q, **restore_params_func(res[0]))
     return res[0], curve

+def refl_fit(
+        q: np.ndarray,
+        curve: np.ndarray,
+        init_params: np.ndarray,
+        prior_sampler: PriorSampler,
+        bounds: np.ndarray = None,
+        error_bars: np.ndarray = None,
+        scale_curve_func=np.log10,
+        method: str = 'trf',  # 'lm', 'trf'
+        polishing_max_nfev: int = None,
+        reflectivity_kwargs: dict = None,
+        **kwargs
+):
+    if bounds is not None:
+        # introduce a small perturbation for fixed bounds
+        epsilon = 1e-6
+        adjusted_bounds = bounds.copy()
+
+        for i in range(bounds.shape[1]):
+            if bounds[0, i] == bounds[1, i]:
+                adjusted_bounds[0, i] -= epsilon
+                adjusted_bounds[1, i] += epsilon
+
+        init_params = np.clip(init_params, *adjusted_bounds)
+        if method != 'lm':
+            kwargs['bounds'] = adjusted_bounds
+
+    reflectivity_kwargs = reflectivity_kwargs or {}
+    for key, value in reflectivity_kwargs.items():
+        if isinstance(value, float):
+            reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
+        elif isinstance(value, np.ndarray):
+            reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
+
+    curve = np.clip(curve, a_min=1e-12, a_max=None)
+
+    if error_bars is not None and scale_curve_func == np.log10:
+        error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
+        scaled_error_bars = error_bars / (curve * np.log(10))
+    else:
+        scaled_error_bars = None
+
+    res = curve_fit(
+        f=get_scaled_curve_func(
+            scale_curve_func=scale_curve_func,
+            prior_sampler=prior_sampler,
+            reflectivity_kwargs=reflectivity_kwargs,
+        ),
+        xdata=q,
+        ydata=scale_curve_func(curve).reshape(-1),
+        p0=init_params,
+        sigma=scaled_error_bars,
+        absolute_sigma=True,
+        method=method,
+        max_nfev=polishing_max_nfev,
+        **kwargs
+    )
+
+    curve = prior_sampler.param_model.reflectivity(torch.tensor(q, dtype=torch.float64),
+                                                   torch.tensor(res[0], dtype=torch.float64).unsqueeze(0),
+                                                   **reflectivity_kwargs).squeeze().numpy()
+    return res[0], curve
+

 def get_fit_with_growth(
-        q: np.ndarray,
+        q: np.ndarray,
+        curve: np.ndarray,
         init_params: np.ndarray,
         bounds: np.ndarray = None,
         init_d_change: float = 0.,
@@ -68,10 +134,16 @@ def get_fit_with_growth(
     bounds = np.concatenate([bounds, np.array([0, max_d_change])[..., None]], -1)

     params, curve = standard_refl_fit(
-        q,
+        q,
+        curve,
+        init_params,
+        bounds,
+        refl_generator=growth_reflectivity,
         restore_params_func=get_restore_params_with_growth_func(q_size=q.size, d_idx=0),
-        scale_curve_func=scale_curve_func,
+        scale_curve_func=scale_curve_func,
+        **kwargs
     )
+
     params[0] += params[-1] / 2
     return params, curve

@@ -97,8 +169,7 @@ def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
         warnings.warn(f"Minimization did not converge.")
     return res.x

-
-def get_scaled_curve_func(
+def standard_get_scaled_curve_func(
         refl_generator=abeles_np,
         restore_params_func=standard_restore_params,
         scale_curve_func=np.log10,
@@ -111,6 +182,26 @@ def get_scaled_curve_func(

     return scaled_curve_func

+def get_scaled_curve_func(
+        scale_curve_func=np.log10,
+        prior_sampler: PriorSampler = None,
+        reflectivity_kwargs: dict = None,
+):
+    reflectivity_kwargs = reflectivity_kwargs or {}
+
+    def scaled_curve_func(q, *fitted_params):
+        q_tensor = torch.from_numpy(q).to(torch.float64)
+        fitted_params_tensor = torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0)
+
+        fitted_curve_tensor = prior_sampler.param_model.reflectivity(q_tensor, fitted_params_tensor, **reflectivity_kwargs)
+        fitted_curve = fitted_curve_tensor.squeeze().numpy()
+
+        scaled_curve = scale_curve_func(fitted_curve)
+
+        return scaled_curve.reshape(-1)
+
+    return scaled_curve_func
+

 def get_fitting_func(
         q: np.ndarray,
reflectorch/ml/__init__.py
CHANGED

@@ -15,6 +15,7 @@ __all__ = [
     'Logger',
     'Loggers',
     'PrintLogger',
+    'TensorBoardLogger',
     'ScheduleBatchSize',
     'ScheduleLR',
     'StepLR',
@@ -22,6 +23,7 @@ __all__ = [
     'LogCyclicLR',
     'ReduceLROnPlateau',
     'OneCycleLR',
+    'CosineAnnealingWithWarmup',
     'ReflectivityDataLoader',
     'MultilayerDataLoader',
     'RealTimeSimTrainer',
reflectorch/ml/basic_trainer.py
CHANGED

@@ -1,7 +1,8 @@
 from typing import Optional, Tuple, Iterable, Any, Union, Type
 from collections import defaultdict

-from tqdm
+from tqdm import tqdm as standard_tqdm
+from tqdm.notebook import tqdm as notebook_tqdm
 import numpy as np

 import torch
@@ -31,7 +32,6 @@ class Trainer(object):
         logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
         optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
         optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
-        train_with_q_input (bool, optional): if ``True`` the q values are also used as input. Defaults to False.
     """

     TOTAL_LOSS_KEY: str = 'total_loss'
@@ -42,7 +42,6 @@ class Trainer(object):
             lr: float,
             batch_size: int,
             clip_grad_norm_max: Optional[int] = None,
-            train_with_q_input: bool = False,
             logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
             optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
             optim_kwargs: dict = None,
@@ -53,7 +52,6 @@ class Trainer(object):
         self.loader = loader
         self.batch_size = batch_size
         self.clip_grad_norm_max = clip_grad_norm_max
-        self.train_with_q_input = train_with_q_input

         self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
         self.lrs = []
@@ -78,7 +76,8 @@ class Trainer(object):
             num_batches: int,
             callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
             disable_tqdm: bool = False,
-            update_tqdm_freq: int = 10,
+            use_notebook_tqdm: bool = False,
+            update_tqdm_freq: int = 1,
             grad_accumulation_steps: int = 1,
     ):
         """starts the training process
@@ -87,6 +86,7 @@ class Trainer(object):
             num_batches (int): total number of training iterations
             callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback']): the trainer callbacks. Defaults to ().
             disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
+            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter Notebook. Defaults to False.
             update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 10.
             grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.
         """
@@ -96,7 +96,8 @@ class Trainer(object):

         callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

-
+        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
+        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)

         callbacks.start_training(self)

@@ -121,6 +122,7 @@ class Trainer(object):

         if self.clip_grad_norm_max is not None:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)
+
         self.optim.step()

         avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
@@ -141,6 +143,16 @@ class Trainer(object):
             last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
             pbar.set_description(f'Loss = {last_loss:.2e}')

+            postfix = {}
+            for key in self.losses.keys():
+                if key != self.TOTAL_LOSS_KEY:
+                    last_value = self.losses[key][-1]
+                    postfix[key] = f'{last_value:.4f}'
+
+            postfix['lr'] = f'{self.lr():.2e}'
+
+            pbar.set_postfix(postfix)
+
     def get_batch_by_idx(self, batch_num: int) -> Any:
         raise NotImplementedError

reflectorch/ml/callbacks.py
CHANGED

@@ -74,7 +74,8 @@ class LogLosses(TrainerCallback):
             trainer (Trainer): the trainer object
             batch_num (int): the index of the current iteration / batch
         """
-
-
-
-
+        for loss_name, loss_values in trainer.losses.items():
+            try:
+                trainer.log(f'train/{loss_name}', loss_values[-1])
+            except IndexError:
+                continue
reflectorch/ml/loggers.py
CHANGED

@@ -1,7 +1,10 @@
+from torch.utils.tensorboard import SummaryWriter
+
 __all__ = [
     'Logger',
     'Loggers',
     'PrintLogger',
+    'TensorBoardLogger',
 ]


@@ -29,3 +32,25 @@ class PrintLogger(Logger):
     """Logger which prints to the console"""
     def log(self, name: str, data):
         print(name, ': ', data)
+
+class TensorBoardLogger(Logger):
+    def __init__(self, log_dir: str):
+        """
+        Args:
+            log_dir (str): Directory where TensorBoard logs will be written
+        """
+        super().__init__()
+        self.writer = SummaryWriter(log_dir=log_dir)
+        self.step = 1
+
+    def log(self, name: str, data):
+        """Log scalar data to TensorBoard
+
+        Args:
+            name (str): Name/tag for the data
+            data: Scalar value to log
+        """
+        if hasattr(data, 'item'):
+            data = data.item()
+        self.writer.add_scalar(name, data, self.step)
+        self.step += 1
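
A short usage sketch of the new TensorBoardLogger; it can be passed to a Trainer via the existing logger argument in the same way as PrintLogger, and the log directory below is only an example path:

    from reflectorch.ml.loggers import TensorBoardLogger

    logger = TensorBoardLogger(log_dir='runs/reflectorch_experiment')
    logger.log('train/total_loss', 0.0123)  # each call writes one scalar and advances the internal step
    logger.writer.flush()                   # the underlying SummaryWriter is exposed as .writer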
reflectorch/ml/schedulers.py
CHANGED

@@ -1,3 +1,4 @@
+import math
 from torch.optim import lr_scheduler

 import numpy as np
@@ -12,6 +13,7 @@ __all__ = [
     'LogCyclicLR',
     'ReduceLROnPlateau',
     'OneCycleLR',
+    'CosineAnnealingWithWarmup',
 ]


@@ -69,6 +71,31 @@ class ScheduleLR(TrainerCallback):
         """
         self.lr_scheduler.step()

+    def simulate_and_plot(self, total_steps: int, initial_lr: float, log_scale: bool = False):
+        import torch
+        import matplotlib.pyplot as plt
+
+        dummy_optim = torch.optim.Adam([torch.zeros(1)], lr=initial_lr)
+        scheduler = self.lr_scheduler_cls(dummy_optim, **self.kwargs)
+
+        lrs = []
+        for step in range(total_steps):
+            lrs.append(dummy_optim.param_groups[0]['lr'])
+            scheduler.step()
+
+        plt.figure(figsize=(10, 6))
+        plt.plot(lrs, label='Learning Rate')
+        plt.xlabel('Steps')
+        plt.ylabel('Learning Rate')
+        plt.title('Learning Rate Schedule')
+
+        if log_scale:
+            plt.yscale('log')
+
+        plt.grid(True, which="both", linestyle='--', linewidth=0.5)
+        plt.legend()
+        plt.show()
+

 class StepLR(ScheduleLR):
     """Learning rate scheduler which decays the learning rate of each parameter group by gamma every ``step_size`` epochs.
@@ -176,6 +203,21 @@ class LogCyclicLR(TrainerCallback):
         for param_group in self.param_groups:
             trainer.set_lr(lr, param_group)

+    def simulate_and_plot(self, total_steps: int, log_scale: bool = True):
+        import matplotlib.pyplot as plt
+        lrs = [self.get_lr(batch_num) for batch_num in range(total_steps)]
+
+        plt.figure(figsize=(10, 6))
+        plt.plot(lrs, label='Learning Rate')
+        plt.xlabel('Steps')
+        plt.ylabel('Learning Rate')
+        plt.title('Learning Rate Schedule')
+        if log_scale:
+            plt.yscale('log')
+        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
+        plt.legend()
+        plt.show()
+

 class ReduceLROnPlateau(TrainerCallback):
     """Learning rate scheduler which reduces the learning rate when the loss stops decreasing
@@ -238,3 +280,77 @@ class OneCycleLR(ScheduleLR):
             three_phase=three_phase,
             **kwargs
         )
+
+class CosineAnnealingWithWarmup(TrainerCallback):
+    """
+    Cosine annealing scheduler with a warm-up stage.
+
+    Args:
+        max_lr (float): The maximum learning rate after the warm-up phase.
+        min_lr (float): The minimum learning rate after the warm-up phase.
+        warmup_iters (int): The number of iterations for the warm-up phase.
+        total_iters (int): The total number of iterations for the scheduler (including warm-up).
+    """
+    def __init__(self, max_lr=None, min_lr=1.0e-6, warmup_iters=100, total_iters=100000):
+        self.max_lr = max_lr
+        self.min_lr = min_lr
+        self.warmup_iters = warmup_iters
+        self.total_iters = total_iters
+
+    def get_lr(self, step):
+        """
+        Compute the learning rate for a given iteration.
+
+        Args:
+            step (int): The current iteration.
+
+        Returns:
+            float: The learning rate for the current iteration.
+        """
+        if step < self.warmup_iters:
+            # Warm-up stage: Linear increase from 0 to max_lr
+            return self.max_lr * step / self.warmup_iters
+        elif step < self.total_iters:
+            # Cosine annealing stage
+            t = (step - self.warmup_iters) / (self.total_iters - self.warmup_iters)
+            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * t))
+        else:
+            # Beyond total iterations: Return min_lr
+            return self.min_lr
+
+    def start_training(self, trainer: Trainer) -> None:
+        self.max_lr = trainer.lr()
+
+    def end_batch(self, trainer: Trainer, batch_num: int):
+        """
+        Updates the learning rate at the end of each batch.
+
+        Args:
+            trainer (Trainer): The trainer object.
+            batch_num (int): The current batch number.
+        """
+        lr = self.get_lr(batch_num)
+        trainer.set_lr(lr)
+
+    def simulate_and_plot(self, total_steps: int = None, log_scale: bool = False):
+        """
+        Simulates and plots the learning rate evolution.
+
+        Args:
+            total_steps (int, optional): Total number of steps to simulate. If None, uses self.total_iters.
+        """
+        total_steps = total_steps or self.total_iters
+        lrs = [self.get_lr(step) for step in range(total_steps)]
+
+        import matplotlib.pyplot as plt
+        plt.figure(figsize=(10, 6))
+        plt.plot(lrs, label='Learning Rate')
+        plt.xlabel('Steps')
+        plt.ylabel('Learning Rate')
+        plt.title('Learning Rate Scheduler')
+        if log_scale:
+            plt.yscale('log')
+        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
+        plt.legend()
+        plt.show()
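
The learning-rate formula of the new CosineAnnealingWithWarmup callback can be checked in isolation; the values below are a worked example of the code above (when used as a trainer callback, start_training overwrites max_lr with the trainer's current learning rate):

    from reflectorch.ml.schedulers import CosineAnnealingWithWarmup

    sched = CosineAnnealingWithWarmup(max_lr=1e-3, min_lr=1e-6, warmup_iters=100, total_iters=10000)

    sched.get_lr(0)       # 0.0      start of the linear warm-up
    sched.get_lr(50)      # 5.0e-4   halfway through the warm-up
    sched.get_lr(100)     # 1.0e-3   cosine stage starts at max_lr
    sched.get_lr(5050)    # ~5.0e-4  midpoint of the cosine decay
    sched.get_lr(20000)   # 1.0e-6   clamped to min_lr beyond total_iters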
reflectorch/ml/trainers.py
CHANGED

@@ -2,6 +2,8 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
+from dataclasses import dataclass
+from typing import Optional

 from reflectorch.data_generation import BATCH_DATA_TYPE
 from reflectorch.ml.basic_trainer import Trainer
@@ -14,6 +16,18 @@ __all__ = [
 ]


+@dataclass
+class BasicBatchData:
+    scaled_curves: torch.Tensor
+    scaled_bounds: torch.Tensor
+    scaled_params: torch.Tensor = None
+    scaled_sigmas: Optional[torch.Tensor] = None
+    scaled_q_values: Optional[torch.Tensor] = None
+    scaled_denoised_curves: Optional[torch.Tensor] = None
+    key_padding_mask: Optional[torch.Tensor] = None
+    scaled_conditioning_params: Optional[torch.Tensor] = None
+    unscaled_q_values: Optional[torch.Tensor] = None
+
 class RealTimeSimTrainer(Trainer):
     """Trainer with functionality to customize the sampled batch of data"""
     loader: ReflectivityDataLoader
@@ -34,46 +48,140 @@ class RealTimeSimTrainer(Trainer):


 class PointEstimatorTrainer(RealTimeSimTrainer):
-    """Trainer for the regression inverse problem with incorporation of prior bounds"""
-    add_sigmas_to_context: bool = False
-
-    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
-        scaled_params = batch_data['scaled_params'].to(torch.float32)
-        scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
-        if self.train_with_q_input:
-            q_values = batch_data['q_values'].to(torch.float32)
-            scaled_q_values = self.loader.q_generator.scale_q(q_values)
+    """Point estimator trainer for the inverse problem."""
+
+    def init(self):
+        if getattr(self, 'use_l1_loss', False):
+            self.criterion = nn.L1Loss(reduction='none')
         else:
-            scaled_q_values = None
+            self.criterion = nn.MSELoss(reduction='none')
+        self.use_curve_reconstruction_loss = getattr(self, 'use_curve_reconstruction_loss', False)
+        self.rescale_loss_interval_width = getattr(self, 'rescale_loss_interval_width', False)
+        if self.use_curve_reconstruction_loss:
+            self.loader.calc_denoised_curves = True
+
+        self.train_with_q_input = getattr(self, 'train_with_q_input', False)
+        self.train_with_sigmas = getattr(self, 'train_with_sigmas', False)
+        self.condition_on_q_resolutions = getattr(self, 'condition_on_q_resolutions', False)
+
+    def _get_batch(self, batch_data: BATCH_DATA_TYPE) -> BasicBatchData:
+        def get_scaled_or_none(key, scaler=None):
+            value = batch_data.get(key)
+            if value is None:
+                return None
+            scale_func = scaler or (lambda x: x)
+            return scale_func(value).to(torch.float32)

+        scaled_params = batch_data['scaled_params'].to(torch.float32)
+        scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
+        scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
+        scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
+        key_padding_mask = batch_data.get('key_padding_mask', None)
+
+        scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
+        conditioning_params = []
+        if scaled_q_resolutions is not None:
+            conditioning_params.append(scaled_q_resolutions)
+        scaled_conditioning_params = torch.cat(conditioning_params, dim=-1) if len(conditioning_params) > 0 else None
+
         num_params = scaled_params.shape[-1] // 3
         assert num_params * 3 == scaled_params.shape[-1]
         scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)

-        return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
+        return BasicBatchData(
+            scaled_params=scaled_params,
+            scaled_bounds=scaled_bounds,
+            scaled_curves=scaled_curves,
+            scaled_q_values=scaled_q_values,
+            scaled_denoised_curves=scaled_denoised_curves,
+            scaled_conditioning_params=scaled_conditioning_params,
+            unscaled_q_values=batch_data['q_values'],
+            key_padding_mask=key_padding_mask,
+        )
+
+    def get_loss_dict(self, batch_data: BasicBatchData):
+        """Returns the regression loss"""
+        scaled_params=batch_data.scaled_params
+        scaled_curves=batch_data.scaled_curves
+        scaled_bounds=batch_data.scaled_bounds
+        scaled_q_values=batch_data.scaled_q_values
+        key_padding_mask=batch_data.key_padding_mask
+        scaled_conditioning_params=batch_data.scaled_conditioning_params
+        unscaled_q_values=batch_data.unscaled_q_values
+
+        predicted_params = self.model(
+            curves = scaled_curves,
+            bounds = scaled_bounds,
+            q_values = scaled_q_values,
+            conditioning_params = scaled_conditioning_params,
+            key_padding_mask = key_padding_mask,
+            unscaled_q_values = unscaled_q_values,
+        )
+
+        if not self.rescale_loss_interval_width:
+            loss = self.criterion(predicted_params, scaled_params).mean()
+        else:
+            n_params = scaled_params.shape[-1]
+            b_min = scaled_bounds[..., :n_params]
+            b_max = scaled_bounds[..., n_params:]
+            interval_width = b_max - b_min

-    def get_loss_dict(self, batch_data):
-        """computes the loss dictionary"""
+            base_loss = self.criterion(predicted_params, scaled_params)
+            if isinstance(self.criterion, torch.nn.MSELoss):
+                width_factors = (interval_width / 2) ** 2
+            elif isinstance(self.criterion, torch.nn.L1Loss):
+                width_factors = interval_width / 2

-        scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
+            loss = (width_factors * base_loss).mean()

-        if self.train_with_q_input:
-            predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
-        else:
-            predicted_params = self.model(scaled_curves, scaled_bounds)
-
-        loss = self.mse(predicted_params, scaled_params)
         return {'loss': loss}

-    def init(self):
-        self.mse = nn.MSELoss()
+
+# class PointEstimatorTrainer(RealTimeSimTrainer):
+#     """Trainer for the regression inverse problem with incorporation of prior bounds"""
+#     add_sigmas_to_context: bool = False
+
+#     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
+#         scaled_params = batch_data['scaled_params'].to(torch.float32)
+#         scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
+#         if self.train_with_q_input:
+#             q_values = batch_data['q_values'].to(torch.float32)
+#             scaled_q_values = self.loader.q_generator.scale_q(q_values)
+#         else:
+#             scaled_q_values = None
+
+#         num_params = scaled_params.shape[-1] // 3
+#         assert num_params * 3 == scaled_params.shape[-1]
+#         scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
+
+#         return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
+
+#     def get_loss_dict(self, batch_data):
+#         """computes the loss dictionary"""
+
+#         scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
+
+#         if self.train_with_q_input:
+#             predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
+#         else:
+#             predicted_params = self.model(scaled_curves, scaled_bounds)
+
+#         loss = self.mse(predicted_params, scaled_params)
+#         return {'loss': loss}
+
+#     def init(self):
+#         self.mse = nn.MSELoss()


 class DenoisingAETrainer(RealTimeSimTrainer):
     """Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods """
     def init(self):
-        self.criterion = nn.MSELoss()
         self.loader.calc_denoised_curves = True
+
+        if getattr(self, 'use_l1_loss', False):
+            self.criterion = nn.L1Loss()
+        else:
+            self.criterion = nn.MSELoss()

     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
         """returns scaled curves with and without noise"""
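
A small numeric sketch of the rescale_loss_interval_width option used in PointEstimatorTrainer.get_loss_dict above: with an MSE criterion, each parameter's error is weighted by (interval_width / 2) ** 2, so parameters drawn from wider prior intervals contribute more to the loss. The tensors are toy values, not reflectorch outputs:

    import torch

    criterion = torch.nn.MSELoss(reduction='none')

    predicted = torch.tensor([[0.40, 0.10]])
    target    = torch.tensor([[0.50, 0.20]])
    b_min     = torch.tensor([[0.00, 0.00]])
    b_max     = torch.tensor([[1.00, 0.20]])   # the second parameter has a much narrower prior interval

    interval_width = b_max - b_min             # (1.0, 0.2)
    width_factors = (interval_width / 2) ** 2  # (0.25, 0.01)

    base_loss = criterion(predicted, target)   # (0.01, 0.01) element-wise squared errors
    loss = (width_factors * base_loss).mean()  # 0.5 * (0.0025 + 0.0001) = 0.0013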
reflectorch/models/__init__.py
CHANGED

@@ -5,11 +5,12 @@ __all__ = [
     "ConvEncoder",
     "ConvDecoder",
     "ConvAutoencoder",
-    "ConvVAE",
     "FnoEncoder",
+    "IntegralConvEmbedding",
     "SpectralConv1d",
     "ConvResidualNet1D",
     "ResidualMLP",
+    "NetworkWithPriors",
     "NetworkWithPriorsConvEmb",
     "NetworkWithPriorsFnoEmb",
 ]

reflectorch/models/encoders/__init__.py
CHANGED

@@ -2,9 +2,9 @@ from reflectorch.models.encoders.conv_encoder import (
     ConvEncoder,
     ConvDecoder,
     ConvAutoencoder,
-    ConvVAE,
 )
 from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
+from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
 from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D


@@ -12,8 +12,8 @@ __all__ = [
     "ConvEncoder",
     "ConvDecoder",
     "ConvAutoencoder",
-    "ConvVAE",
     "ConvResidualNet1D",
     "FnoEncoder",
     "SpectralConv1d",
+    "IntegralConvEmbedding",
 ]