reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
|
|
2
|
+
import warnings
|
|
3
|
+
import joblib
|
|
4
|
+
from joblib import Parallel, delayed
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.optimize import minimize, curve_fit
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
10
|
+
from reflectorch.data_generation.reflectivity import abeles_np
|
|
11
|
+
|
|
12
|
+
# Public API of this module. Every entry must name an object defined in this
# file, otherwise ``from ... import *`` fails with AttributeError. The batch
# fitting function defined in this file is ``batch_refl_fit``.
__all__ = [
    "standard_refl_fit",
    "refl_fit",
    "fit_refl_curve",
    "restore_masked_params",
    "get_fit_with_growth",
    "batch_refl_fit",
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def standard_restore_params(fitted_params) -> dict:
    """Split a flat parameter vector into thickness, roughness and sld slices.

    The number of layers ``n`` is computed as ``(fitted_params.size - 2) // 3``.
    The first ``n`` entries are returned as ``thickness``, the next ``n + 1``
    as ``roughness`` and everything after that as ``sld``.
    """
    n_layers = (fitted_params.size - 2) // 3
    sld_start = 2 * n_layers + 1

    return {
        'thickness': fitted_params[:n_layers],
        'roughness': fitted_params[n_layers:sld_start],
        'sld': fitted_params[sld_start:],
    }
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def mse_loss(curve1, curve2):
    """Return the sum of squared differences between the two curves.

    Note that, despite the name, the result is a sum and is not divided by
    the number of points.
    """
    residual = curve1 - curve2
    return np.sum(residual ** 2)
|
|
34
|
+
|
|
35
|
+
def standard_refl_fit(
        q: np.ndarray, curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit one curve with ``scipy.optimize.curve_fit`` on the scaled curve.

    Parameters
    ----------
    q : np.ndarray
        Passed as the independent variable to ``curve_fit`` and to ``refl_generator``.
    curve : np.ndarray
        The curve to fit; it is passed through ``scale_curve_func`` before fitting.
    init_params : np.ndarray
        Initial parameter guess (``p0``). Clipped to ``bounds`` when bounds are given.
    bounds : np.ndarray, optional
        Unpacked as ``(lower, upper)`` for clipping and forwarded to ``curve_fit``.
    refl_generator : callable, optional
        Called as ``refl_generator(q, **restore_params_func(params))``.
    restore_params_func : callable, optional
        Maps the flat parameter vector to keyword arguments of ``refl_generator``.
    scale_curve_func : callable, optional
        Scaling applied to both the data and the model curve. Default: ``np.log10``.
    **kwargs
        Extra keyword arguments forwarded to ``curve_fit``.

    Returns
    -------
    tuple
        The fitted parameters and the (unscaled) curve generated from them.
    """
    if bounds is not None:
        kwargs['bounds'] = bounds
        lower, upper = bounds
        init_params = np.clip(init_params, lower, upper)

    model = standard_get_scaled_curve_func(
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )
    scaled_target = scale_curve_func(curve)

    fit_result = curve_fit(model, q, scaled_target, p0=init_params, **kwargs)
    best_params = fit_result[0]

    fitted_curve = refl_generator(q, **restore_params_func(best_params))
    return best_params, fitted_curve
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def batch_refl_fit(
        q: np.ndarray,
        curves: np.ndarray,
        init_params: np.ndarray,  # (n_curves, n_params)
        prior_sampler: PriorSampler,
        bounds: np.ndarray = None,
        error_bars: np.ndarray = None,
        scale_curve_func=np.log10,
        method: str = 'trf',  # 'lm', 'trf'
        polishing_max_steps: int = None,
        reflectivity_kwargs: dict = None,
        n_jobs: int = -1,
        verbose: int = 5,
        **kwargs
):
    """
    Fit (polished fit) multiple reflectivity curves in parallel using joblib.

    Parameters
    ----------
    q : np.ndarray
        1D array of momentum transfer values (same for all curves).
    curves : np.ndarray
        2D array of reflectivity curves with shape (n_curves, n_q).
    init_params : np.ndarray
        2D array of initial parameter guesses (n_curves, n_params).
    prior_sampler : PriorSampler
        The prior sampler.
    bounds : np.ndarray, optional
        Bounds for the parameters. Either shape (2, n_params), shared by all the curves,
        or shape (n_curves, 2, n_params) with one set of bounds per curve. Default: None.
    error_bars : np.ndarray, optional
        Error bars for the curves, shape (n_curves, n_q); each curve is fitted with its own row.
        A 1D array of shape (n_q,) is also accepted and shared by all the curves. Default: None.
    scale_curve_func : callable, optional
        Function to scale the curves. Default: `np.log10`.
    method : str, optional
        The method to use for the fitting. Default: 'trf'.
    polishing_max_steps : int, optional
        The maximum number of function evaluations for the polishing step. Default: None.
    reflectivity_kwargs : dict, optional
        Keyword arguments for the reflectivity function. Default: None.
    n_jobs : int, optional
        The number of jobs to run in parallel. Default: -1 (all CPUs).
    verbose : int, optional
        The verbosity level for joblib. Default: 5.
    **kwargs : dict
        Extra keyword arguments passed to `scipy.optimize.curve_fit`.

    Returns
    -------
    params_array : np.ndarray
        Array of fitted parameter values for each curve, shape (n_curves, n_params).
    error_bars_array : np.ndarray
        Array of error bars for the fitted parameter values, shape (n_curves, n_params).
    curves_array : np.ndarray
        Array of fitted reflectivity curves, shape (n_curves, n_q).

    Raises
    ------
    ValueError
        If `bounds` is neither 2D nor 3D, or if `error_bars` is neither 1D nor a
        2D array with one row per curve.
    """
    n_curves = curves.shape[0]

    if bounds is not None:
        if bounds.ndim == 2:
            # shared bounds: replicate them for every curve
            bounds = bounds[None].repeat(n_curves, 0)
        elif bounds.ndim == 3:
            assert bounds.shape[0] == n_curves, f"Bounds must have the same number of curves as the number of curves, got {bounds.shape[0]} and {n_curves}"
        else:
            raise ValueError(f"Bounds must be a 2D or 3D array, got {bounds.ndim}D array")
    else:
        bounds = [None] * n_curves

    # Each worker must receive only the error bars of its own curve: handing
    # the full (n_curves, n_q) array to every fit yields a `sigma` whose shape
    # does not match the single curve being fitted.
    if error_bars is None:
        per_curve_error_bars = [None] * n_curves
    else:
        error_bars = np.asarray(error_bars)
        if error_bars.ndim == 1:
            # a single set of error bars shared by all the curves
            per_curve_error_bars = [error_bars] * n_curves
        elif error_bars.ndim == 2 and error_bars.shape[0] == n_curves:
            per_curve_error_bars = list(error_bars)
        else:
            raise ValueError(
                f"Error bars must have shape (n_q,) or (n_curves, n_q) with n_curves={n_curves}, "
                f"got shape {error_bars.shape}"
            )

    results = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(refl_fit)(
            q=q, curve=curve, init_params=curve_init_params,
            bounds=curve_bounds,
            prior_sampler=prior_sampler,
            error_bars=curve_error_bars,
            method=method,
            scale_curve_func=scale_curve_func,
            polishing_max_steps=polishing_max_steps,
            reflectivity_kwargs=reflectivity_kwargs,
            **kwargs
        )
        for curve, curve_init_params, curve_bounds, curve_error_bars
        in zip(curves, init_params, bounds, per_curve_error_bars)
    )

    params_array, param_errors_array, curves_array = zip(*results)
    return np.array(params_array), np.array(param_errors_array), np.array(curves_array)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def refl_fit(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        prior_sampler: PriorSampler,
        bounds: np.ndarray = None,
        error_bars: np.ndarray = None,
        scale_curve_func=np.log10,
        method: str = 'trf',  # 'lm', 'trf'
        polishing_max_steps: int = None,
        reflectivity_kwargs: dict = None,
        **kwargs
):
    """
    Fit (polish) a single reflectivity curve with `scipy.optimize.curve_fit`.

    The model curve is computed by `prior_sampler.param_model.reflectivity`
    and compared to the data after applying `scale_curve_func` to both.

    Parameters
    ----------
    q : np.ndarray
        Momentum transfer values of the curve.
    curve : np.ndarray
        The reflectivity curve. Values are clipped from below at 1e-12 before scaling.
    init_params : np.ndarray
        Initial parameter guess. Clipped to the (adjusted) bounds when bounds are given.
    prior_sampler : PriorSampler
        The prior sampler; its `param_model.reflectivity` method generates the model curves.
    bounds : np.ndarray, optional
        2D array of parameter bounds, row 0 the lower and row 1 the upper bounds.
        Parameters with identical lower and upper bound are widened by 1e-6 on each side.
        The bounds are forwarded to `curve_fit` unless `method` is 'lm'. Default: None.
    error_bars : np.ndarray, optional
        Error bars of the curve. They are only used when `scale_curve_func` is `np.log10`,
        in which case they are propagated to the log10 scale. Default: None.
    scale_curve_func : callable, optional
        Function to scale the curves. Default: `np.log10`.
    method : str, optional
        The `curve_fit` method. Default: 'trf'.
    polishing_max_steps : int, optional
        Maximum number of function evaluations, passed as `maxfev` for 'lm'
        and as `max_nfev` otherwise. Default: None.
    reflectivity_kwargs : dict, optional
        Keyword arguments for the reflectivity function. Float and numpy array
        values are converted to tensors on a copy; the dict passed by the caller
        is left untouched. Default: None.
    **kwargs : dict
        Extra keyword arguments passed to `scipy.optimize.curve_fit`.

    Returns
    -------
    params : np.ndarray
        The fitted parameters.
    param_errs : np.ndarray
        Square roots of the diagonal of the covariance matrix, or an array of NaN
        with the same shape as `params` when no finite covariance is available.
    curve : np.ndarray
        The reflectivity curve generated from the fitted parameters.

    Raises
    ------
    ValueError
        If `bounds` is given and is not a 2D array.
    """
    if bounds is not None:
        if bounds.ndim != 2:
            raise ValueError(f"Bounds must be a 2D array, got {bounds.ndim}D array")
        # introduce a small perturbation for fixed bounds;
        # done in float64 so that the perturbation is neither truncated
        # (integer bounds) nor lost to rounding (float32 bounds)
        epsilon = 1e-6
        adjusted_bounds = bounds.astype(np.float64)
        fixed = bounds[0] == bounds[1]
        adjusted_bounds[0, fixed] -= epsilon
        adjusted_bounds[1, fixed] += epsilon

        init_params = np.clip(init_params, *adjusted_bounds)
        if method != 'lm':
            kwargs['bounds'] = adjusted_bounds

    # work on a copy so that the dict owned by the caller is not modified
    reflectivity_kwargs = dict(reflectivity_kwargs or {})
    for key, value in list(reflectivity_kwargs.items()):
        if isinstance(value, float):
            reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
        elif isinstance(value, np.ndarray):
            reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)

    # avoid non-positive values before the curve is scaled
    curve = np.clip(curve, a_min=1e-12, a_max=None)

    if error_bars is not None and scale_curve_func == np.log10:
        error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
        # error propagation for log10(R): sigma_log = sigma_R / (R * ln(10))
        scaled_error_bars = error_bars / (curve * np.log(10))
    else:
        scaled_error_bars = None

    if polishing_max_steps is not None:
        if method == 'lm':
            kwargs['maxfev'] = polishing_max_steps
        else:
            kwargs['max_nfev'] = polishing_max_steps

    res = curve_fit(
        f=get_scaled_curve_func(
            scale_curve_func=scale_curve_func,
            prior_sampler=prior_sampler,
            reflectivity_kwargs=reflectivity_kwargs,
        ),
        xdata=q,
        ydata=scale_curve_func(curve).reshape(-1),
        p0=init_params,
        sigma=scaled_error_bars,
        absolute_sigma=True,
        method=method,
        **kwargs
    )
    fitted_params, covariance = res[0], res[1]

    curve = prior_sampler.param_model.reflectivity(torch.tensor(q, dtype=torch.float64),
                                                   torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0),
                                                   **reflectivity_kwargs).squeeze().numpy()
    # cov matrix --> variance of the parameter estimate
    if covariance is not None and np.ndim(covariance) == 2 and np.all(np.isfinite(covariance)):
        pol_param_errs = np.sqrt(np.diag(covariance))
    else:
        # one NaN per parameter, so the result has the same shape as in the regular case
        pol_param_errs = np.full(np.shape(fitted_params), np.nan)
    return fitted_params, pol_param_errs, curve
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def get_fit_with_growth(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        init_d_change: float = 0.,
        max_d_change: float = 30.,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit a curve with one extra parameter appended to the parameter vector.

    The extra (last) parameter starts at ``init_d_change`` and, when bounds are
    given, is bounded to ``[0, max_d_change]``. The fit is delegated to
    ``standard_refl_fit`` using ``growth_reflectivity`` and the restore function
    from ``get_restore_params_with_growth_func(q_size=q.size, d_idx=0)``.
    After the fit, half of the last parameter is added to the first one.

    Returns
    -------
    tuple
        The fitted parameters (including the extra one) and the fitted curve.
    """
    extended_init = np.array([*init_params, init_d_change])

    if bounds is not None:
        # extra column: lower bound 0, upper bound max_d_change
        growth_bounds = np.array([0, max_d_change])[..., None]
        bounds = np.concatenate([bounds, growth_bounds], -1)

    restore_func = get_restore_params_with_growth_func(q_size=q.size, d_idx=0)

    params, curve = standard_refl_fit(
        q,
        curve,
        extended_init,
        bounds,
        refl_generator=growth_reflectivity,
        restore_params_func=restore_func,
        scale_curve_func=scale_curve_func,
        **kwargs
    )

    params[0] = params[0] + params[-1] / 2
    return params, curve
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
                   init_params: np.ndarray,
                   bounds: np.ndarray = None,
                   refl_generator=abeles_np,
                   restore_params_func=standard_restore_params,
                   scale_curve_func=np.log10,
                   **kwargs
                   ) -> np.ndarray:
    """Fit one curve by minimizing a scalar loss with ``scipy.optimize.minimize``.

    The objective is built by ``get_fitting_func`` from the given generator,
    restore function and scaling function. ``bounds`` and any extra keyword
    arguments are forwarded to ``minimize``. A warning is issued when the
    optimizer does not report success.

    Returns
    -------
    np.ndarray
        The solution vector ``x`` of the optimization result.
    """
    objective = get_fitting_func(
        q=q, curve=curve,
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )

    result = minimize(objective, init_params, bounds=bounds, **kwargs)

    if result.success:
        return result.x

    warnings.warn("Minimization did not converge.")
    return result.x
|
|
273
|
+
|
|
274
|
+
def standard_get_scaled_curve_func(
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
):
    """Build a model function with the ``f(q, *params)`` signature used by ``curve_fit``.

    The returned function converts the positional parameters to an array,
    maps them to keyword arguments with ``restore_params_func``, generates the
    curve with ``refl_generator`` and returns it after ``scale_curve_func``.
    """
    def scaled_curve_func(q, *fitted_params):
        param_kwargs = restore_params_func(np.asarray(fitted_params))
        return scale_curve_func(refl_generator(q, **param_kwargs))

    return scaled_curve_func
|
|
286
|
+
|
|
287
|
+
def get_scaled_curve_func(
        scale_curve_func=np.log10,
        prior_sampler: PriorSampler = None,
        reflectivity_kwargs: dict = None,
):
    """Build a ``f(q, *params)`` model function backed by the prior sampler's model.

    The returned function converts ``q`` and the parameters to float64 tensors
    (the parameters get a leading dimension of size 1), evaluates
    ``prior_sampler.param_model.reflectivity`` with ``reflectivity_kwargs``,
    and returns the scaled curve flattened to 1D.
    """
    if not reflectivity_kwargs:
        reflectivity_kwargs = {}

    def scaled_curve_func(q, *fitted_params):
        q_values = torch.from_numpy(q).to(torch.float64)
        param_row = torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0)

        simulated = prior_sampler.param_model.reflectivity(q_values, param_row, **reflectivity_kwargs)
        simulated = simulated.squeeze().numpy()

        return scale_curve_func(simulated).reshape(-1)

    return scaled_curve_func
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def get_fitting_func(
        q: np.ndarray,
        curve: np.ndarray,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        loss_func=mse_loss,
):
    """Build a scalar objective ``f(params)`` for a general-purpose minimizer.

    The target curve is scaled once, when the objective is created. On each call
    the objective maps the parameters to keyword arguments with
    ``restore_params_func``, generates a curve with ``refl_generator``, scales it
    and returns ``loss_func(scaled_model, scaled_target)``.
    """
    target = scale_curve_func(curve)

    def fitting_func(fitted_params):
        param_kwargs = restore_params_func(fitted_params)
        model_curve = refl_generator(q, **param_kwargs)
        return loss_func(scale_curve_func(model_curve), target)

    return fitting_func
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def restore_masked_params(fixed_params, fixed_mask):
    """Build a restore function for fits in which some parameters are held fixed.

    The returned function assembles the full parameter vector: positions where
    ``fixed_mask`` is set receive ``fixed_params``, the remaining positions
    receive the fitted parameters. The full vector is then split by
    ``standard_restore_params``.
    """
    def restore_params(fitted_params) -> dict:
        full_params = np.empty(np.shape(fixed_mask), dtype=fitted_params.dtype)
        full_params[~fixed_mask] = fitted_params
        full_params[fixed_mask] = fixed_params
        return standard_restore_params(full_params)

    return restore_params
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def base_params2growth(base_params: dict, d_shift: np.ndarray, d_idx: int = 0) -> dict:
    """Expand one parameter set into ``d_shift.size`` rows with a shifted thickness.

    Each row is a copy of ``base_params['thickness']`` whose column ``d_idx`` is
    increased by the corresponding entry of ``d_shift``. The ``roughness`` and
    ``sld`` entries are broadcast (as read-only views) to the same number of rows.
    """
    n_rows = d_shift.size

    def _broadcast_rows(key):
        values = base_params[key]
        return np.broadcast_to(values[None], (n_rows, values.size))

    # repeat() makes a copy, so the thickness array of the caller is not modified
    thickness = np.repeat(base_params['thickness'][None], n_rows, axis=0)
    thickness[:, d_idx] = thickness[:, d_idx] + d_shift

    return dict(
        thickness=thickness,
        roughness=_broadcast_rows('roughness'),
        sld=_broadcast_rows('sld'),
    )
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def get_restore_params_with_growth_func(q_size: int, d_idx: int = 0):
    """Build a restore function for parameter vectors with a trailing growth parameter.

    The returned function treats the last entry as the total thickness change,
    splits the remaining entries with ``standard_restore_params`` and expands
    them with ``base_params2growth`` using ``q_size`` shifts spaced linearly
    from 0 to that total change, applied to thickness column ``d_idx``.
    """
    def restore_params_with_growth(fitted_params) -> dict:
        delta_d = fitted_params[-1]
        layer_params = standard_restore_params(fitted_params[:-1])
        shifts = np.linspace(0, delta_d, q_size)
        return base_params2growth(layer_params, shifts, d_idx)

    return restore_params_with_growth
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def growth_reflectivity(q: np.ndarray, **kwargs):
    """Evaluate ``abeles_np`` on ``q`` with an added trailing axis and flatten the result."""
    q_column = q[..., None]
    result = abeles_np(q_column, **kwargs)
    return result.flatten()
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from tqdm import trange
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn, Tensor
|
|
5
|
+
|
|
6
|
+
from reflectorch.data_generation import LogLikelihood, reflectivity, PriorSampler
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ReflGradientFit:
    """Refines thin film parameters directly by gradient descent with a Pytorch optimizer

    Args:
        q (Tensor): the q positions
        exp_curve (Tensor): the experimental reflectivity curve
        prior_sampler (PriorSampler): the prior sampler
        params (Tensor): the initial thin film parameters
        fit_indices (Tensor): the indices of the thin film parameters which are to be fitted
        sigmas (Tensor, optional): error bars of the reflectivity curve, if not provided they are computed as ``exp_curve * rel_err + abs_err``. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): the Pytorch optimizer class, ``torch.optim.Adam`` is used if not provided. Defaults to None.
        lr (float, optional): the learning rate. Defaults to 1e-2.
        rel_err (float, optional): the relative error in the reflectivity curve. Defaults to 0.1.
        abs_err (float, optional): the absolute error in the reflectivity curve. Defaults to 1e-7.
    """

    def __init__(self,
                 q: Tensor,
                 exp_curve: Tensor,
                 prior_sampler: PriorSampler,
                 params: Tensor,
                 fit_indices: Tensor,
                 sigmas: Tensor = None,
                 optim_cls=None,
                 lr: float = 1e-2,
                 rel_err: float = 0.1,
                 abs_err: float = 1e-7,
                 ):
        self.q = q

        # derive the error bars from the curve when none are supplied
        curve_sigmas = exp_curve * rel_err + abs_err if sigmas is None else sigmas
        self.likelihood = LogLikelihood(q, exp_curve, prior_sampler, curve_sigmas)

        self.num_layers = params.shape[-1] // 3
        self.fit_indices = fit_indices
        self.init_params = params.clone()

        # only the selected entries are optimized; the rest keep their initial values
        trainable_values = self.init_params[fit_indices].clone()
        self.params_to_fit = nn.Parameter(trainable_values)

        optimizer_factory = optim_cls if optim_cls else torch.optim.Adam
        self.optim = optimizer_factory([self.params_to_fit], lr)

        self.losses = []

    @property
    def params(self):
        """The full parameter tensor: the initial values with the fitted entries substituted."""
        merged = self.init_params.clone()
        merged[self.fit_indices] = self.params_to_fit
        return merged

    def calc_log_likelihood(self):
        """Returns the log likelihood of the curve simulated from the current parameters."""
        simulated_curve = self.refl()
        return self.likelihood.calc_log_likelihood(simulated_curve)

    def calc_log_prob_loss(self):
        """Returns the loss, i.e. the negative mean of the log likelihood."""
        log_likelihood = self.calc_log_likelihood()
        return - log_likelihood.mean()

    def refl(self):
        """Simulates the reflectivity curve from the current parameters.

        The last dimension of the parameters is split into three chunks of sizes
        ``num_layers``, ``num_layers + 1`` and ``num_layers + 1``
        (presumably thicknesses, roughnesses and SLDs - confirm against ``reflectivity``).
        """
        chunk_sizes = [self.num_layers, self.num_layers + 1, self.num_layers + 1]
        d, sigma, rho = self.params.split(chunk_sizes, dim=-1)
        return reflectivity(self.q, d, sigma, rho)

    def run(self, num_iterations: int = 500, disable_tqdm: bool = False):
        """Runs the optimization process

        Args:
            num_iterations (int, optional): number of iterations the optimization is run for. Defaults to 500.
            disable_tqdm (bool, optional): whether to disable the progress bar. Defaults to False.
        """
        progress = trange(num_iterations, disable=disable_tqdm)

        for _ in progress:
            self.optim.zero_grad()
            loss = self.calc_log_prob_loss()
            loss.backward()
            self.optim.step()

            loss_value = loss.item()
            self.losses.append(loss_value)
            progress.set_description(f'Loss = {loss_value:.2e}')

    def clear(self):
        """Empties the list of recorded losses in place."""
        del self.losses[:]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from reflectorch.ml.basic_trainer import *
|
|
2
|
+
from reflectorch.ml.callbacks import *
|
|
3
|
+
from reflectorch.ml.trainers import *
|
|
4
|
+
from reflectorch.ml.loggers import *
|
|
5
|
+
from reflectorch.ml.schedulers import *
|
|
6
|
+
from reflectorch.ml.dataloaders import *
|
|
7
|
+
|
|
8
|
+
# Names exported by ``from reflectorch.ml import *``.
# They are expected to be provided by the star imports of the submodules above.
__all__ = [
    'Trainer', 'TrainerCallback', 'DataLoader',
    'PeriodicTrainerCallback', 'SaveBestModel', 'LogLosses',
    'Logger', 'Loggers', 'PrintLogger', 'TensorBoardLogger',
    'ScheduleBatchSize', 'ScheduleLR', 'StepLR', 'CyclicLR', 'LogCyclicLR',
    'ReduceLROnPlateau', 'OneCycleLR', 'CosineAnnealingWithWarmup',
    'ReflectivityDataLoader', 'MultilayerDataLoader',
    'RealTimeSimTrainer', 'DenoisingAETrainer', 'PointEstimatorTrainer',
]
|