reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,248 +1,272 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
from scipy.optimize import minimize, curve_fit
|
|
5
|
-
import torch
|
|
6
|
-
|
|
7
|
-
from reflectorch.data_generation.priors.base import PriorSampler
|
|
8
|
-
from reflectorch.data_generation.reflectivity import abeles_np
|
|
9
|
-
|
|
10
|
-
__all__ = [
|
|
11
|
-
"standard_refl_fit",
|
|
12
|
-
"refl_fit",
|
|
13
|
-
"fit_refl_curve",
|
|
14
|
-
"restore_masked_params",
|
|
15
|
-
"get_fit_with_growth",
|
|
16
|
-
]
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def standard_restore_params(fitted_params) -> dict:
    """Split a flat parameter vector into thickness, roughness and SLD parts.

    The layer count ``n`` is computed as ``(fitted_params.size - 2) // 3``.
    The first ``n`` entries are returned as ``thickness``, the entries up to
    index ``2 * n + 1`` as ``roughness`` and everything after that as ``sld``.
    The values are slices of the input, not copies.
    """
    n_layers = (fitted_params.size - 2) // 3
    roughness_stop = 2 * n_layers + 1

    return {
        'thickness': fitted_params[:n_layers],
        'roughness': fitted_params[n_layers:roughness_stop],
        'sld': fitted_params[roughness_stop:],
    }
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def mse_loss(curve1, curve2):
    """Return the summed squared difference between two curves.

    Note: despite the name, no mean is taken; the result is a plain sum of
    squared residuals.
    """
    residual = curve1 - curve2
    return np.sum(residual ** 2)
|
|
31
|
-
|
|
32
|
-
def standard_refl_fit(
        q: np.ndarray, curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit a reflectivity curve with ``scipy.optimize.curve_fit``.

    The fit is performed on ``scale_curve_func(curve)``. When ``bounds`` is
    given, the initial parameters are clipped into it and it is forwarded to
    ``curve_fit``. Remaining keyword arguments are passed to ``curve_fit``.

    Returns:
        tuple: the fitted flat parameter vector and the (unscaled) curve
        produced by ``refl_generator`` at those parameters.
    """
    if bounds is not None:
        init_params = np.clip(init_params, *bounds)
        kwargs['bounds'] = bounds

    model_func = standard_get_scaled_curve_func(
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )
    target = scale_curve_func(curve)

    fit_result = curve_fit(model_func, q, target, p0=init_params, **kwargs)
    best_params = fit_result[0]

    fitted_curve = refl_generator(q, **restore_params_func(best_params))
    return best_params, fitted_curve
|
|
57
|
-
|
|
58
|
-
def refl_fit(
|
|
59
|
-
q: np.ndarray,
|
|
60
|
-
curve: np.ndarray,
|
|
61
|
-
init_params: np.ndarray,
|
|
62
|
-
prior_sampler: PriorSampler,
|
|
63
|
-
bounds: np.ndarray = None,
|
|
64
|
-
error_bars: np.ndarray = None,
|
|
65
|
-
scale_curve_func=np.log10,
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
q
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
scale_curve_func=
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
)
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
):
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
def
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from scipy.optimize import minimize, curve_fit
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
8
|
+
from reflectorch.data_generation.reflectivity import abeles_np
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"standard_refl_fit",
|
|
12
|
+
"refl_fit",
|
|
13
|
+
"fit_refl_curve",
|
|
14
|
+
"restore_masked_params",
|
|
15
|
+
"get_fit_with_growth",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def standard_restore_params(fitted_params) -> dict:
    """Split a flat parameter vector into thickness, roughness and SLD parts.

    The layer count ``n`` is computed as ``(fitted_params.size - 2) // 3``.
    The first ``n`` entries are returned as ``thickness``, the entries up to
    index ``2 * n + 1`` as ``roughness`` and everything after that as ``sld``.
    The values are slices of the input, not copies.
    """
    n_layers = (fitted_params.size - 2) // 3
    roughness_stop = 2 * n_layers + 1

    return {
        'thickness': fitted_params[:n_layers],
        'roughness': fitted_params[n_layers:roughness_stop],
        'sld': fitted_params[roughness_stop:],
    }
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def mse_loss(curve1, curve2):
    """Return the summed squared difference between two curves.

    Note: despite the name, no mean is taken; the result is a plain sum of
    squared residuals.
    """
    residual = curve1 - curve2
    return np.sum(residual ** 2)
|
|
31
|
+
|
|
32
|
+
def standard_refl_fit(
        q: np.ndarray, curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit a reflectivity curve with ``scipy.optimize.curve_fit``.

    The fit is performed on ``scale_curve_func(curve)``. When ``bounds`` is
    given, the initial parameters are clipped into it and it is forwarded to
    ``curve_fit``. Remaining keyword arguments are passed to ``curve_fit``.

    Returns:
        tuple: the fitted flat parameter vector and the (unscaled) curve
        produced by ``refl_generator`` at those parameters.
    """
    if bounds is not None:
        init_params = np.clip(init_params, *bounds)
        kwargs['bounds'] = bounds

    model_func = standard_get_scaled_curve_func(
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )
    target = scale_curve_func(curve)

    fit_result = curve_fit(model_func, q, target, p0=init_params, **kwargs)
    best_params = fit_result[0]

    fitted_curve = refl_generator(q, **restore_params_func(best_params))
    return best_params, fitted_curve
|
|
57
|
+
|
|
58
|
+
def refl_fit(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        prior_sampler: PriorSampler,
        bounds: np.ndarray = None,
        error_bars: np.ndarray = None,
        scale_curve_func=np.log10,
        method: str = 'trf',  # 'lm', 'trf'
        polishing_max_steps: int = None,
        reflectivity_kwargs: dict = None,
        **kwargs
):
    """Polish a parameter estimate with ``scipy.optimize.curve_fit``.

    Args:
        q: momentum-transfer values passed to ``curve_fit`` as ``xdata``.
        curve: measured reflectivity; clipped from below at ``1e-12`` before
            scaling.
        init_params: starting parameter vector (clipped into ``bounds``).
        prior_sampler: object whose ``param_model.reflectivity(q, params,
            **reflectivity_kwargs)`` simulates the curve.
        bounds: array indexed as ``bounds[0, i]`` (lower) and
            ``bounds[1, i]`` (upper). Entries with equal lower and upper
            values are widened by ``1e-6`` on each side. Not forwarded to
            ``curve_fit`` when ``method == 'lm'``.
        error_bars: measurement uncertainties; only used when
            ``scale_curve_func`` is ``np.log10``, otherwise ignored.
        scale_curve_func: transform applied to measured and simulated curves.
        method: ``curve_fit`` optimisation method.
        polishing_max_steps: forwarded as ``maxfev`` for ``'lm'`` and as
            ``max_nfev`` for other methods.
        reflectivity_kwargs: extra keyword arguments for the reflectivity
            call; ``float`` and ``np.ndarray`` values are converted to
            tensors. The caller's dict is left unmodified.
        **kwargs: forwarded to ``curve_fit``.

    Returns:
        tuple: ``(fitted_params, param_errors, fitted_curve)``.
        ``param_errors`` always has the shape of ``fitted_params``; it is
        filled with NaN when the covariance matrix is unavailable or
        non-finite.
    """
    if bounds is not None:
        # introduce a small perturbation for fixed bounds
        epsilon = 1e-6
        # Float copy: integer-typed bounds could not absorb the perturbation.
        adjusted_bounds = np.array(bounds, dtype=float)
        fixed = adjusted_bounds[0] == adjusted_bounds[1]
        adjusted_bounds[0, fixed] -= epsilon
        adjusted_bounds[1, fixed] += epsilon

        init_params = np.clip(init_params, *adjusted_bounds)
        if method != 'lm':
            kwargs['bounds'] = adjusted_bounds

    # Build a new dict so the caller's ``reflectivity_kwargs`` is not mutated.
    # NOTE(review): floats become float64 tensors while arrays become float32;
    # kept as in the original -- confirm whether the mismatch is intended.
    tensor_kwargs = {}
    for key, value in (reflectivity_kwargs or {}).items():
        if isinstance(value, float):
            value = torch.tensor([[value]], dtype=torch.float64)
        elif isinstance(value, np.ndarray):
            value = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
        tensor_kwargs[key] = value

    curve = np.clip(curve, a_min=1e-12, a_max=None)

    if error_bars is not None and scale_curve_func is np.log10:
        error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
        # First-order error propagation through log10: d(log10 R) = dR / (R ln 10).
        scaled_error_bars = error_bars / (curve * np.log(10))
    else:
        scaled_error_bars = None

    if polishing_max_steps is not None:
        # curve_fit names the evaluation limit differently per method.
        if method == 'lm':
            kwargs['maxfev'] = polishing_max_steps
        else:
            kwargs['max_nfev'] = polishing_max_steps

    res = curve_fit(
        f=get_scaled_curve_func(
            scale_curve_func=scale_curve_func,
            prior_sampler=prior_sampler,
            reflectivity_kwargs=tensor_kwargs,
        ),
        xdata=q,
        ydata=scale_curve_func(curve).reshape(-1),
        p0=init_params,
        sigma=scaled_error_bars,
        absolute_sigma=True,
        method=method,
        **kwargs
    )
    popt, pcov = res[0], res[1]

    curve = prior_sampler.param_model.reflectivity(
        torch.tensor(q, dtype=torch.float64),
        torch.tensor(popt, dtype=torch.float64).unsqueeze(0),
        **tensor_kwargs
    ).squeeze().numpy()

    # cov matrix --> variance of the parameter estimate
    if pcov is not None and np.ndim(pcov) == 2 and np.all(np.isfinite(pcov)):
        pol_param_errs = np.sqrt(np.diag(pcov))
    else:
        # One NaN per parameter, so the result shape does not depend on
        # whether the covariance could be estimated.
        pol_param_errs = np.full(np.shape(popt), np.nan)
    return popt, pol_param_errs, curve
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def get_fit_with_growth(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        init_d_change: float = 0.,
        max_d_change: float = 30.,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit a curve while also fitting a thickness change across the scan.

    One extra parameter (initialised to ``init_d_change`` and bounded by
    ``[0, max_d_change]`` when ``bounds`` is given) is appended to the
    parameter vector and fitted through ``standard_refl_fit`` with
    ``growth_reflectivity`` as the generator.

    Returns:
        tuple: the fitted parameters (including the trailing growth value) and
        the fitted curve. Half of the fitted growth is added to the first
        parameter before returning -- presumably to report the mid-scan
        thickness; confirm against callers.
    """
    extended_params = np.array([*init_params, init_d_change])

    if bounds is not None:
        growth_bounds = np.array([0, max_d_change])[..., None]
        bounds = np.concatenate([bounds, growth_bounds], -1)

    restore_func = get_restore_params_with_growth_func(q_size=q.size, d_idx=0)

    params, curve = standard_refl_fit(
        q,
        curve,
        extended_params,
        bounds,
        refl_generator=growth_reflectivity,
        restore_params_func=restore_func,
        scale_curve_func=scale_curve_func,
        **kwargs
    )

    params[0] += params[-1] / 2
    return params, curve
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
                   init_params: np.ndarray,
                   bounds: np.ndarray = None,
                   refl_generator=abeles_np,
                   restore_params_func=standard_restore_params,
                   scale_curve_func=np.log10,
                   **kwargs
                   ) -> np.ndarray:
    """Fit a reflectivity curve with ``scipy.optimize.minimize``.

    The objective is built by ``get_fitting_func`` from the given generator,
    parameter-restoring function and curve scaling. ``bounds`` and any extra
    keyword arguments are forwarded to ``minimize``.

    Returns:
        The parameter vector ``res.x``. A warning is issued (and the vector is
        still returned) when the optimiser reports failure.
    """
    objective = get_fitting_func(
        q=q, curve=curve,
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )

    result = minimize(objective, init_params, bounds=bounds, **kwargs)

    if result.success:
        return result.x

    warnings.warn("Minimization did not converge.")
    return result.x
|
|
181
|
+
|
|
182
|
+
def standard_get_scaled_curve_func(
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
):
    """Build a model function ``f(q, *params)`` suitable for ``curve_fit``.

    The returned callable packs the positional parameters into an array,
    converts them to keyword arguments with ``restore_params_func``, simulates
    the curve with ``refl_generator`` and returns it after applying
    ``scale_curve_func``.
    """
    def scaled_curve_func(q, *fitted_params):
        generator_kwargs = restore_params_func(np.asarray(fitted_params))
        return scale_curve_func(refl_generator(q, **generator_kwargs))

    return scaled_curve_func
|
|
194
|
+
|
|
195
|
+
def get_scaled_curve_func(
        scale_curve_func=np.log10,
        prior_sampler: PriorSampler = None,
        reflectivity_kwargs: dict = None,
):
    """Build a model function ``f(q, *params)`` backed by ``prior_sampler``.

    The returned callable converts ``q`` and the parameters to float64
    tensors (the parameters gain a leading axis of size one), evaluates
    ``prior_sampler.param_model.reflectivity`` with ``reflectivity_kwargs``,
    and returns the scaled result flattened to one dimension.
    """
    extra_kwargs = reflectivity_kwargs or {}

    def scaled_curve_func(q, *fitted_params):
        q_t = torch.from_numpy(q).to(torch.float64)
        params_t = torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0)

        simulated = prior_sampler.param_model.reflectivity(q_t, params_t, **extra_kwargs)
        simulated_np = simulated.squeeze().numpy()

        return scale_curve_func(simulated_np).reshape(-1)

    return scaled_curve_func
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def get_fitting_func(
        q: np.ndarray,
        curve: np.ndarray,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        loss_func=mse_loss,
):
    """Build a scalar objective ``f(params)`` for a generic minimiser.

    The target curve is scaled once up front. The returned callable restores
    keyword arguments from the flat parameter vector, simulates the curve with
    ``refl_generator``, scales it, and returns ``loss_func(simulated, target)``.
    """
    target = scale_curve_func(curve)

    def fitting_func(fitted_params):
        generator_kwargs = restore_params_func(fitted_params)
        simulated = scale_curve_func(refl_generator(q, **generator_kwargs))
        return loss_func(simulated, target)

    return fitting_func
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def restore_masked_params(fixed_params, fixed_mask):
    """Build a restore function that re-inserts fixed parameters.

    The returned callable allocates an array shaped like ``fixed_mask`` with
    the dtype of the fitted parameters, writes ``fixed_params`` where the mask
    is true and the fitted values where it is false, and passes the merged
    vector to ``standard_restore_params``.
    """
    def restore_params(fitted_params) -> dict:
        merged = np.empty_like(fixed_mask, dtype=fitted_params.dtype)
        merged[fixed_mask] = fixed_params
        merged[~fixed_mask] = fitted_params
        return standard_restore_params(merged)

    return restore_params
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def base_params2growth(base_params: dict, d_shift: np.ndarray, d_idx: int = 0) -> dict:
    """Expand one parameter set into one row per entry of ``d_shift``.

    The thickness vector is copied ``d_shift.size`` times and ``d_shift`` is
    added to column ``d_idx``. ``roughness`` and ``sld`` are broadcast (not
    copied) to the same number of rows.
    """
    n_rows = d_shift.size

    thickness = np.repeat(base_params['thickness'][None], n_rows, axis=0)
    thickness[:, d_idx] = thickness[:, d_idx] + d_shift

    def _as_rows(values):
        # Read-only broadcast view with one row per shift value.
        return np.broadcast_to(values[None], (n_rows, values.size))

    return {
        'thickness': thickness,
        'roughness': _as_rows(base_params['roughness']),
        'sld': _as_rows(base_params['sld']),
    }
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def get_restore_params_with_growth_func(q_size: int, d_idx: int = 0):
    """Build a restore function for parameter vectors with a trailing growth value.

    The returned callable treats the last entry as the total thickness change,
    restores the remaining entries with ``standard_restore_params`` and expands
    them with ``base_params2growth`` using ``q_size`` shifts spaced linearly
    from 0 to that change, applied to thickness column ``d_idx``.
    """
    def restore_params_with_growth(fitted_params) -> dict:
        layer_params = fitted_params[:-1]
        total_growth = fitted_params[-1]

        shifts = np.linspace(0, total_growth, q_size)
        return base_params2growth(standard_restore_params(layer_params), shifts, d_idx)

    return restore_params_with_growth
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def growth_reflectivity(q: np.ndarray, **kwargs):
    """Evaluate ``abeles_np`` with a trailing axis added to ``q`` and flatten the result.

    NOTE(review): presumably each ``q`` value is paired with its own row of the
    keyword parameters -- confirm against ``abeles_np``.
    """
    q_column = q[..., None]
    reflectivity = abeles_np(q_column, **kwargs)
    return reflectivity.flatten()
|