reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import matplotlib.ticker as mticker
|
|
4
|
+
from matplotlib.lines import Line2D
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def print_prediction_results(prediction_dict, param_names=None, width=10, precision=3, header=True, print_err=False):
    """Print a table of predicted (and optionally polished) parameter values.

    Args:
        prediction_dict (dict): prediction output. Keys read: ``"param_names"`` (only when
            ``param_names`` is None), ``"predicted_params_array"``, ``"polished_params_array"``
            and, when ``print_err`` is True, ``"polished_params_error_array"``.
        param_names (list, optional): row labels. Defaults to ``prediction_dict["param_names"]``.
        width (int): minimum width of each numeric column. A column is widened to the length
            of its header label so that header and values stay aligned.
        precision (int): number of decimals printed per value.
        header (bool): whether to print the header row and the separator line.
        print_err (bool): whether to print the polished-error column (if present in the dict).

    Values missing from an array shorter than ``param_names`` are printed as ``nan``.
    """
    if param_names is None:
        param_names = prediction_dict.get("param_names", [])

    pred = np.asarray(prediction_dict.get("predicted_params_array", []), dtype=float)
    pol = prediction_dict.get("polished_params_array", None)
    pol = np.asarray(pol, dtype=float) if pol is not None else None
    pol_err = prediction_dict.get('polished_params_error_array', None) if print_err else None
    pol_err = np.asarray(pol_err, dtype=float) if pol_err is not None else None

    # (header label, values) for every numeric column, in display order.
    columns = [("Predicted", pred)]
    if pol is not None:
        columns.append(("Polished", pol))
    if pol_err is not None:
        columns.append(("Polished err", pol_err))

    name_w = max(14, max((len(str(n)) for n in param_names), default=14))
    # A label longer than ``width`` (e.g. "Polished err") would otherwise overflow its
    # column and shift the header out of line with the values below it.
    col_widths = [max(width, len(label)) for label, _ in columns]

    if header:
        hdr = 'Parameter'.ljust(name_w) + "".join(
            f" {label.rjust(w)}" for (label, _), w in zip(columns, col_widths)
        )
        print(hdr)
        print("-" * len(hdr))

    for i, name in enumerate(param_names):
        row = str(name).ljust(name_w)
        for (_, values), w in zip(columns, col_widths):
            val = values[i] if i < values.size else float("nan")
            row += f" {val:>{w}.{precision}f}"
        print(row)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def plot_prediction_results(
    prediction_dict: dict,
    q_exp: np.ndarray,
    curve_exp: np.ndarray,
    sigmas_exp: np.ndarray = None,
    logx=False,
):
    """Plot experimental data against the predicted (and polished) reflectivity and SLD.

    Args:
        prediction_dict (dict): prediction output. Required keys: ``'q_plot_pred'`` and
            ``'predicted_curve'``. Optional keys: ``'polished_curve'``,
            ``'predicted_sld_xaxis'``, ``'predicted_sld_profile'`` and
            ``'sld_profile_polished'``.
        q_exp (np.ndarray): experimental q values.
        curve_exp (np.ndarray): experimental reflectivity.
        sigmas_exp (np.ndarray, optional): experimental error bars on the reflectivity.
        logx (bool): whether to use a log-scaled q axis.

    The polished curve is drawn against ``q_plot_pred`` or ``q_exp``, whichever its
    length matches; if it matches neither it is not drawn. The SLD panel is shown
    when the SLD x-axis and at least one SLD profile are available. For a complex SLD
    profile the real part is drawn by ``plot_reflectivity`` and the imaginary part is
    added to the SLD panel here.

    Returns:
        tuple: ``(fig, axes)`` exactly as returned by ``plot_reflectivity``.
    """
    q_pred = prediction_dict['q_plot_pred']
    r_pred = prediction_dict['predicted_curve']
    r_pol = prediction_dict.get('polished_curve', None)

    q_pol = None
    if r_pol is not None:
        if len(r_pol) == len(q_pred):
            q_pol = q_pred
        elif len(r_pol) == len(q_exp):
            q_pol = q_exp

    z_sld = prediction_dict.get('predicted_sld_xaxis', None)
    sld_pred_c = prediction_dict.get('predicted_sld_profile', None)
    sld_pol_c = prediction_dict.get('sld_profile_polished', None)
    # Arrays guarantee ``.real`` / ``.imag`` also for list inputs.
    sld_pred_c = None if sld_pred_c is None else np.asarray(sld_pred_c)
    sld_pol_c = None if sld_pol_c is None else np.asarray(sld_pol_c)

    plot_sld = (z_sld is not None) and (sld_pred_c is not None or sld_pol_c is not None)

    # Check each profile separately: a complex polished SLD must keep its imaginary
    # part even when the predicted SLD is real or missing.
    pred_is_complex = np.iscomplexobj(sld_pred_c)
    pol_is_complex = np.iscomplexobj(sld_pol_c)

    sld_pred_label = 'pred. SLD (Re)' if pred_is_complex else 'pred. SLD'
    sld_pol_label = 'polished SLD (Re)' if pol_is_complex else 'polished SLD'

    fig, axes = plot_reflectivity(
        q_exp=q_exp, r_exp=curve_exp, yerr=sigmas_exp,
        q_pred=q_pred, r_pred=r_pred,
        q_pol=q_pol, r_pol=r_pol,
        z_sld=z_sld,
        sld_pred=sld_pred_c.real if sld_pred_c is not None else None,
        sld_pol=sld_pol_c.real if sld_pol_c is not None else None,
        sld_pred_label=sld_pred_label,
        sld_pol_label=sld_pol_label,
        plot_sld_profile=plot_sld,
        logx=logx,
    )

    if plot_sld and (pred_is_complex or pol_is_complex):
        ax_s = axes[1]
        if pred_is_complex:
            ax_s.plot(z_sld, sld_pred_c.imag, color='darkgreen', lw=2.0, ls='-', zorder=4, label='pred. SLD (Im)')
        if pol_is_complex:
            ax_s.plot(z_sld, sld_pol_c.imag, color='cyan', lw=2.0, ls='--', zorder=5, label='polished SLD (Im)')
        ax_s.legend(fontsize=14, frameon=True)

    return fig, axes
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def plot_reflectivity(
    *,
    q_exp=None,
    r_exp=None,
    yerr=None,
    xerr=None,
    q_pred=None,
    r_pred=None,
    q_pol=None,
    r_pol=None,
    z_sld=None,
    sld_pred=None,
    sld_pol=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    xlim=None,
    axis_label_size=20,
    tick_label_size=15,
    legend_fontsize=14,
    exp_style='auto',
    exp_color='blue',
    exp_facecolor='none',
    exp_marker='o',
    exp_ms=3,
    exp_alpha=1.0,
    exp_errcolor='purple',
    exp_elinewidth=1.0,
    exp_capsize=1.0,
    exp_capthick=1.0,
    exp_zorder=2,
    pred_color='red',
    pred_lw=2.0,
    pred_ls='-',
    pred_alpha=1.0,
    pred_zorder=3,
    pol_color='orange',
    pol_lw=2.0,
    pol_ls='--',
    pol_alpha=1.0,
    pol_zorder=4,
    sld_pred_color='red',
    sld_pred_lw=2.0,
    sld_pred_ls='-',
    sld_pol_color='orange',
    sld_pol_lw=2.0,
    sld_pol_ls='--',
    exp_label='exp. data',
    pred_label='prediction',
    pol_label='polished prediction',
    sld_pred_label='pred. SLD',
    sld_pol_label='polished SLD',
    legend=True,
    legend_kwargs=None
):
    """Plot reflectivity curves and, optionally, SLD profiles in a second panel.

    All arguments are keyword-only and all data arguments are optional; only the
    series that are given are drawn.

    Data:
        q_exp, r_exp: experimental data. ``yerr`` / ``xerr`` are symmetric error bars,
            either scalars or 1-D arrays with the same length as ``q_exp``.
        q_pred, r_pred: predicted curve (line).
        q_pol, r_pol: polished curve (line).
        z_sld, sld_pred, sld_pol: SLD profiles, drawn on the right panel when
            ``plot_sld_profile`` is True.
        Non-finite points, and non-positive points on a log-scaled axis, are dropped.

    Layout:
        figsize: defaults to (12, 6) with the SLD panel and (6, 6) without.
        logx, logy: log-scale the reflectivity axes.
        x_ticks_log, y_ticks_log: fixed major tick positions used on log axes.
        xlim: scalar -> right limit only (the left limit follows the data);
            (xmin, xmax) -> both limits, either may be None. Applied after the
            data are plotted so that autoscaling sees the data first.

    Experimental style:
        exp_style: 'auto' (errorbar if ``yerr`` is given, else scatter), 'errorbar',
            'scatter' or 'line'. 'errorbar' without ``yerr`` draws plain markers.

    Legends:
        legend_kwargs are forwarded to ``Axes.legend`` for both panels and may
        override ``fontsize`` and ``loc``.

    Returns:
        ``(fig, (ax_r, ax_s))`` when the SLD panel is drawn, else ``(fig, ax_r)``.

    Raises:
        ValueError: if an error array is not scalar/1-D, or if ``logx`` is set and
            ``xlim`` gives a non-positive lower limit.
    """

    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        # Keep finite points; log-scaled axes additionally need strictly positive values.
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym_err(err, mask):
        # Symmetric errors: a scalar is used as-is, a 1-D array is masked like the data.
        if err is None:
            return None
        e = np.asarray(err)
        # np.isscalar() is False for 0-d arrays, which is what _np() turns a scalar into.
        if e.ndim == 0:
            return e.item()
        if e.ndim != 1:
            raise ValueError("Errors must be scalar or 1-D array.")
        return e[mask]

    def _legend_kw():
        # Caller-supplied kwargs override the defaults instead of raising
        # "got multiple values for keyword argument".
        return {'fontsize': legend_fontsize, 'loc': 'best', **(legend_kwargs or {})}

    q_exp, r_exp, yerr, xerr = _np(q_exp), _np(r_exp), _np(yerr), _np(xerr)
    q_pred, r_pred = _np(q_pred), _np(r_pred)
    q_pol, r_pol = _np(q_pol), _np(r_pol)
    z_sld, sld_pred, sld_pol = _np(z_sld), _np(sld_pred), _np(sld_pol)

    # Validate an (xmin, xmax) range before any figure is created.
    xlim_min = xlim_max = None
    if xlim is not None and np.ndim(xlim) != 0:
        xlim_min, xlim_max = xlim
        if logx and xlim_min is not None and xlim_min <= 0:
            raise ValueError("For log-x, xmin must be > 0.")

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    handles = []

    # Experimental data
    if q_exp is not None and r_exp is not None:
        m = _mask(q_exp, r_exp)
        style = exp_style if exp_style != 'auto' else ('errorbar' if yerr is not None else 'scatter')
        # Marker styling shared by the plotted artist and its legend proxy.
        marker_kw = dict(marker=exp_marker, markersize=exp_ms,
                         markerfacecolor=exp_facecolor, markeredgecolor=exp_color)

        if style == 'errorbar':
            # Without yerr this draws plain markers (plus xerr, if given) instead of
            # silently falling through to a connected line.
            ax_r.errorbar(
                q_exp[m], r_exp[m],
                yerr=_slice_sym_err(yerr, m), xerr=_slice_sym_err(xerr, m),
                color=exp_color, ecolor=exp_errcolor,
                elinewidth=exp_elinewidth, capsize=exp_capsize,
                capthick=(exp_elinewidth if exp_capthick is None else exp_capthick),
                linestyle='none', alpha=exp_alpha, zorder=exp_zorder, label=None,
                **marker_kw,
            )
            handles.append(Line2D([], [], color=exp_color, linestyle='none',
                                  alpha=exp_alpha, label=exp_label, **marker_kw))
        elif style == 'scatter':
            ax_r.scatter(
                q_exp[m], r_exp[m],
                s=exp_ms ** 2, marker=exp_marker,
                facecolors=exp_facecolor, edgecolors=exp_color,
                alpha=exp_alpha, zorder=exp_zorder, label=None,
            )
            handles.append(Line2D([], [], color=exp_color, linestyle='none',
                                  alpha=exp_alpha, label=exp_label, **marker_kw))
        else:  # 'line'
            handles.append(ax_r.plot(
                q_exp[m], r_exp[m], color=exp_color, lw=1.0, ls='-',
                alpha=exp_alpha, zorder=exp_zorder, label=exp_label,
            )[0])

    # Predicted line
    if q_pred is not None and r_pred is not None:
        m = _mask(q_pred, r_pred)
        handles.append(ax_r.plot(
            q_pred[m], r_pred[m],
            color=pred_color, lw=pred_lw, ls=pred_ls,
            alpha=pred_alpha, zorder=pred_zorder, label=pred_label,
        )[0])

    # Polished line
    if q_pol is not None and r_pol is not None:
        m = _mask(q_pol, r_pol)
        handles.append(ax_r.plot(
            q_pol[m], r_pol[m],
            color=pol_color, lw=pol_lw, ls=pol_ls,
            alpha=pol_alpha, zorder=pol_zorder, label=pol_label,
        )[0])

    # x-limits go last: applied to empty axes they would freeze the left limit at the
    # empty-axes default (0, or 1e-12 on log-x) instead of the data's left edge.
    if xlim is not None:
        if np.ndim(xlim) == 0:
            cur_left, _ = ax_r.get_xlim()
            if logx and cur_left <= 0:
                cur_left = 1e-12
            ax_r.set_xlim(left=cur_left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xlim_min, right=xlim_max)

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **_legend_kw())

    # SLD panel (optional)
    if ax_s is not None:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        if z_sld is not None and sld_pred is not None:
            ax_s.plot(z_sld, sld_pred,
                      color=sld_pred_color, lw=sld_pred_lw, ls=sld_pred_ls,
                      label=sld_pred_label)
        if z_sld is not None and sld_pol is not None:
            ax_s.plot(z_sld, sld_pol,
                      color=sld_pol_color, lw=sld_pol_lw, ls=sld_pol_ls,
                      label=sld_pol_label)

        # Only build a legend when there is something labelled to show.
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**_legend_kw())

    fig.tight_layout()
    return (fig, (ax_r, ax_s)) if ax_s is not None else (fig, ax_r)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def plot_reflectivity_multi(
    *,
    rq_series,
    sld_series=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    xlim=None,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    axis_label_size=20,
    tick_label_size=15,
    legend=True,
    legend_fontsize=12,
    legend_kwargs=None,
):
    """Plot several R(q) series (and optional SLD lines) with per-series styling.

    Series are drawn in the given order, which is also the legend order. Only
    series with a ``label`` appear in the legend, and series missing ``x`` or ``y``
    are skipped. Non-finite points, and non-positive points on log-scaled axes,
    are dropped before plotting.

    rq_series: list of dicts, each with:
        required:
            - x: 1D array
            - y: 1D array
        optional (per series):
            - kind: 'errorbar' | 'scatter' | 'line' (default 'line'; any other
              value is drawn as a line)
            - label: str
            - color: str
            - alpha: float (0..1, default 1.0)
            - zorder: int
            # scatter / errorbar markers:
            - marker: str (default 'o')
            - ms: float (marker size in pt, default 5.0; for scatter converted to s=ms**2)
            - facecolor: str (marker face, default 'none')
            # errorbar only:
            - yerr: scalar or 1D array
            - xerr: scalar or 1D array
            - ecolor: str (default: the series color)
            - elinewidth: float (default 1.0)
            - capsize: float (default 0.0)
            - capthick: float (default: elinewidth)
            # line only:
            - lw: float (default 2.0)
            - ls: str (e.g. '-', '--', ':'; default '-')

    sld_series: list of dicts (only lines), drawn when ``plot_sld_profile`` is True:
            - x: 1D array (z-axis)
            - y: 1D array (SLD)
            - label: str (optional)
            - color: str (optional)
            - lw: float (optional, default 2.0)
            - ls: str (optional, default '-')
            - alpha: float (optional, default 1.0)
            - zorder: int (optional)

    xlim: scalar -> right limit only (the left limit follows the data);
        (xmin, xmax) -> both limits, either may be None. Applied after all series
        are plotted so that autoscaling sees the data first.
    legend_kwargs: forwarded to ``Axes.legend`` on both panels; may override
        ``fontsize`` and ``loc``.

    Returns:
        ``(fig, (ax_r, ax_s))`` if ``plot_sld_profile`` else ``(fig, ax_r)``.

    Raises:
        ValueError: for error bars that are not scalar/1-D, or a non-positive
            xmin together with ``logx``.
    """

    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        # Keep finite points; log-scaled axes additionally need strictly positive values.
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym_err(err, mask):
        # Symmetric errors: a scalar (incl. 0-d array) is used as-is, a 1-D array is masked.
        if err is None:
            return None
        arr = np.asarray(err)
        if arr.ndim == 0:
            return arr.item()
        if arr.ndim != 1:
            raise ValueError("For symmetric error bars, provide scalar or 1-D array.")
        return arr[mask]

    def _legend_kw():
        # Caller-supplied kwargs override the defaults instead of colliding with them.
        return {'fontsize': legend_fontsize, 'loc': 'best', **(legend_kwargs or {})}

    # Validate an (xmin, xmax) range before any figure is created.
    xlim_min = xlim_max = None
    if xlim is not None and np.ndim(xlim) != 0:
        xlim_min, xlim_max = xlim
        if logx and xlim_min is not None and xlim_min <= 0:
            raise ValueError("For log-x, xmin must be > 0.")

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    # Plot all R(q) series in the given order (order = legend order)
    handles = []
    for s in rq_series:
        x = _np(s.get('x'))
        y = _np(s.get('y'))
        if x is None or y is None:
            continue

        kind = s.get('kind', 'line')
        label = s.get('label', None)
        color = s.get('color', None)
        alpha = s.get('alpha', 1.0)
        zord = s.get('zorder', None)
        ms = s.get('ms', 5.0)
        marker = s.get('marker', 'o')
        m = _mask(x, y)

        if kind in ('errorbar', 'scatter'):
            facecolor = s.get('facecolor', 'none')
            if kind == 'errorbar':
                elinewidth = s.get('elinewidth', 1.0)
                ax_r.errorbar(
                    x[m], y[m],
                    yerr=_slice_sym_err(s.get('yerr', None), m),
                    xerr=_slice_sym_err(s.get('xerr', None), m),
                    color=color, ecolor=s.get('ecolor', color),
                    elinewidth=elinewidth, capsize=s.get('capsize', 0.0),
                    capthick=s.get('capthick', elinewidth),
                    marker=marker, linestyle='none', markersize=ms,
                    markerfacecolor=facecolor, markeredgecolor=color,
                    alpha=alpha, zorder=zord, label=None,
                )
            else:
                ax_r.scatter(
                    x[m], y[m], s=ms ** 2, marker=marker,
                    facecolors=facecolor, edgecolors=color,
                    alpha=alpha, zorder=zord, label=None,
                )
            # Marker-only proxy so the legend shows the marker style, not the error bars.
            h = Line2D([], [], color=color, marker=marker, linestyle='none',
                       markersize=ms, markerfacecolor=facecolor,
                       markeredgecolor=color, alpha=alpha, label=label)
        else:  # 'line'
            h = ax_r.plot(
                x[m], y[m],
                color=color, lw=s.get('lw', 2.0), ls=s.get('ls', '-'),
                alpha=alpha, zorder=zord, label=label,
            )[0]

        if label is not None:
            handles.append(h)

    # x-limits go last: applied to empty axes they would freeze the left limit at the
    # empty-axes default instead of the data's left edge.
    if xlim is not None:
        if np.ndim(xlim) == 0:
            left, _ = ax_r.get_xlim()
            if logx and left <= 0:
                left = 1e-12
            ax_r.set_xlim(left=left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xlim_min, right=xlim_max)

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **_legend_kw())

    # Optional SLD panel
    if plot_sld_profile:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        for s in (sld_series or []):
            zx = _np(s.get('x'))
            zy = _np(s.get('y'))
            if zx is None or zy is None:
                continue
            ax_s.plot(
                zx, zy,
                color=s.get('color', None), lw=s.get('lw', 2.0), ls=s.get('ls', '-'),
                alpha=s.get('alpha', 1.0), zorder=s.get('zorder', None),
                label=s.get('label', None),
            )

        # Only build a legend when there is something labelled to show.
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**_legend_kw())

    fig.tight_layout()
    return (fig, (ax_r, ax_s)) if plot_sld_profile else (fig, ax_r)
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
from reflectorch.inference.preprocess_exp.preprocess import (
|
|
2
|
+
standard_preprocessing,
|
|
3
|
+
StandardPreprocessing,
|
|
4
|
+
)
|
|
5
|
+
from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
|
|
6
|
+
from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
|
|
7
|
+
from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def apply_attenuation_correction(
    intensity: np.ndarray,
    attenuation: np.ndarray,
    scattering_angle: np.ndarray = None,
    correct_discontinuities: bool = True
) -> np.ndarray:
    """Divide measured intensities by their attenuation factors.

    Args:
        intensity (np.ndarray): intensities of an experimental reflectivity curve
        attenuation (np.ndarray): attenuation factor of each measured point
        scattering_angle (np.ndarray, optional): scattering angles of the measured points;
            required when ``correct_discontinuities`` is True. Defaults to None.
        correct_discontinuities (bool, optional): whether to additionally rescale the curve
            at repeated scattering angles via ``apply_discontinuities_correction``.
            Defaults to True.

    Returns:
        np.ndarray: the attenuation-corrected reflectivity curve

    Raises:
        ValueError: if ``correct_discontinuities`` is True and ``scattering_angle`` is None.
    """
    corrected = intensity / attenuation

    if not correct_discontinuities:
        return corrected

    if scattering_angle is None:
        raise ValueError("correct_discontinuities options requires scattering_angle, but scattering_angle is None.")

    return apply_discontinuities_correction(corrected, scattering_angle)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def apply_discontinuities_correction(intensity: np.ndarray, scattering_angle: np.ndarray) -> np.ndarray:
    """Rescale a curve so it is continuous across repeated scattering angles.

    Wherever two consecutive points share the same scattering angle (``np.diff`` is 0),
    every point from the second one onwards is multiplied by the ratio of the pair's
    intensities, which makes the pair equal. Corrections are applied one after another
    in increasing index order, each using the already-corrected values.

    Args:
        intensity (np.ndarray): measured intensities. The input is not modified.
        scattering_angle (np.ndarray): scattering angles, indexed like ``intensity``.

    Returns:
        np.ndarray: corrected copy of ``intensity``. Integer input is promoted to float,
        because the multiplicative rescaling cannot be stored in an integer array.
    """
    corrected = np.array(intensity, copy=True)
    if not np.issubdtype(corrected.dtype, np.inexact):
        # An in-place ``*=`` with a float factor on an integer array raises a casting error.
        corrected = corrected.astype(float)

    diff_angle = np.diff(scattering_angle)
    for i in np.flatnonzero(diff_angle == 0):
        factor = corrected[i] / corrected[i + 1]
        corrected[i + 1:] *= factor
    return corrected
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from reflectorch.utils import angle_to_q
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def cut_curve(q: np.ndarray, curve: np.ndarray, max_q: float, max_angle: float, wavelength: float):
    """Truncate an experimental reflectivity curve above a maximum q.

    Points from the first q value exceeding ``max_q`` onwards are dropped (so ``q`` is
    expected to be sorted in ascending order), and the remaining q values are divided
    by the ratio ``max_q / q.max()``. If that ratio is not below 1, nothing is cut.

    Args:
        q (np.ndarray): the array of q points
        curve (np.ndarray): the experimental reflectivity curve
        max_q (float): the maximum q value at which the curve is cut
        max_angle (float): the maximum scattering angle at which the curve is cut; only
            used (together with ``wavelength``) if ``max_q`` is None
        wavelength (float): the wavelength of the beam

    Returns:
        tuple: the q array after cutting, the reflectivity curve after cutting, and the
        ratio between the maximum q after cutting and before cutting (1. if neither
        ``max_q`` nor ``max_angle`` is given)
    """
    if max_q is None and max_angle is None:
        return q, curve, 1.

    if max_q is None:
        max_q = angle_to_q(max_angle, wavelength)

    q_ratio = max_q / q.max()
    # Written as ``not <`` so a NaN ratio leaves the curve untouched.
    if not (q_ratio < 1.):
        return q, curve, q_ratio

    cut_idx = np.argmax(q > max_q)
    return q[:cut_idx] / q_ratio, curve[:cut_idx], q_ratio
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
try:
|
|
2
|
+
from typing import Literal
|
|
3
|
+
except ImportError:
|
|
4
|
+
from typing_extensions import Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from scipy.special import erf
|
|
8
|
+
|
|
9
|
+
# Public API of the footprint-correction module.
__all__ = [
    "apply_footprint_correction",
    "remove_footprint_correction",
    "BEAM_SHAPE",
]


# Beam profiles accepted by the footprint-correction functions below.
BEAM_SHAPE = Literal["gauss", "box"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def apply_footprint_correction(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    beam_width: float,
    sample_length: float,
    beam_shape: BEAM_SHAPE = "gauss",
) -> np.ndarray:
    """Multiply a reflectivity curve by its footprint correction factors.

    Args:
        intensity (np.ndarray): reflectivity curve
        scattering_angle (np.ndarray): array of scattering angles
        beam_width (float): the beam width
        sample_length (float): the sample length
        beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box".
            Defaults to "gauss".

    Returns:
        np.ndarray: a new, footprint-corrected reflectivity curve (the input is not modified)

    Raises:
        ValueError: if ``beam_shape`` is not supported.
    """
    correction = _get_factors_by_beam_shape(scattering_angle, beam_width, sample_length, beam_shape)
    corrected = intensity.copy()
    return corrected * correction
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def remove_footprint_correction(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    beam_width: float,
    sample_length: float,
    beam_shape: BEAM_SHAPE = "gauss",
):
    """Undo a footprint correction by dividing by the same factors it applied.

    Inverse of ``apply_footprint_correction`` for identical arguments.

    Args:
        intensity (np.ndarray): footprint-corrected reflectivity curve
        scattering_angle (np.ndarray): array of scattering angles
        beam_width (float): the beam width
        sample_length (float): the sample length
        beam_shape (BEAM_SHAPE, optional): "gauss" or "box". Defaults to "gauss".

    Returns:
        np.ndarray: a new curve with the footprint correction removed

    Raises:
        ValueError: if ``beam_shape`` is not supported.
    """
    correction = _get_factors_by_beam_shape(scattering_angle, beam_width, sample_length, beam_shape)
    uncorrected = intensity.copy()
    return uncorrected / correction
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _get_factors_by_beam_shape(
    scattering_angle: np.ndarray, beam_width: float, sample_length: float, beam_shape: BEAM_SHAPE
):
    """Compute footprint correction factors with the function matching ``beam_shape``.

    ``"gauss"`` uses ``gaussian_factors`` and ``"box"`` uses ``box_factors``.

    Raises:
        ValueError: if ``beam_shape`` is neither ``"gauss"`` nor ``"box"``.
    """
    if beam_shape not in ("gauss", "box"):
        raise ValueError("invalid beam shape")

    factor_fn = gaussian_factors if beam_shape == "gauss" else box_factors
    return factor_fn(scattering_angle, beam_width, sample_length)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def box_factors(scattering_angle, beam_width, sample_length):
    """Footprint correction factors for a box-shaped (uniform) beam.

    Below the scattering angle ``2 * arcsin(beam_width / sample_length)`` (in degrees),
    the factor is ``beam_footprint_ratio(...)``. At and above that angle it is 1.

    Args:
        scattering_angle: scattering angle(s) in degrees.
        beam_width (float): beam width, in the same unit as ``sample_length``.
        sample_length (float): sample length.

    Returns:
        np.ndarray: multiplicative correction factors, shaped like ``scattering_angle``.
    """
    # When the beam is wider than the sample (w / L > 1), arcsin would return NaN. Every
    # comparison with NaN is False, so no correction would be applied (plus a warning).
    # Clipping to 1 gives a limiting angle of 180 deg, so the correction applies at all angles.
    width_ratio = np.minimum(beam_width / sample_length, 1.0)
    max_angle = 2 * np.arcsin(width_ratio) / np.pi * 180
    ratios = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
    return np.where(scattering_angle < max_angle, ratios, np.ones_like(ratios))
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def gaussian_factors(scattering_angle, beam_width, sample_length):
    """Footprint correction factors for a Gaussian beam.

    Returns ``1 / erf(sqrt(ln 2) / ratio)``, where ``ratio`` is
    ``beam_footprint_ratio(scattering_angle, beam_width, sample_length)``.

    Args:
        scattering_angle: scattering angle(s) in degrees.
        beam_width (float): beam width, in the same unit as ``sample_length``.
        sample_length (float): sample length.

    Returns:
        np.ndarray: multiplicative correction factors.
    """
    footprint_ratio = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
    illuminated = erf(np.sqrt(np.log(2)) / footprint_ratio)
    return 1 / illuminated
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def beam_footprint_ratio(scattering_angle, beam_width, sample_length):
    """Return ``beam_width / sample_length / sin(scattering_angle / 2)``.

    ``scattering_angle`` is given in degrees. Half of it is converted to radians before
    taking the sine.
    """
    half_angle_rad = scattering_angle / 2 * np.pi / 180
    relative_width = beam_width / sample_length
    return relative_width / np.sin(half_angle_rad)
|