reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,98 +1,524 @@
|
|
|
1
|
-
|
|
2
|
-
import
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
""
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import matplotlib.ticker as mticker
|
|
4
|
+
from matplotlib.lines import Line2D
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def print_prediction_results(prediction_dict, param_names=None, width=10, precision=3, header=True):
    """Print a fixed-width table of predicted (and optionally polished) parameters.

    Args:
        prediction_dict: mapping read with ``.get``. The keys used are
            ``"param_names"`` (only when ``param_names`` is None),
            ``"predicted_params_array"``, ``"polished_params_array"`` and
            ``"polished_params_error_array"``. The two polished columns are
            printed only when their key is present with a non-None value.
        param_names: row labels; one row is printed per name.
        width: minimum width of each numeric column.
        precision: number of decimals used for every numeric cell.
        header: when True, print a header line followed by a dashed rule.

    Returns:
        None. The table is written with ``print``.
    """
    def _flat(key):
        # Flatten so that batched arrays (e.g. shape (1, n)) can be indexed
        # per parameter instead of failing on formatting/indexing.
        values = prediction_dict.get(key, None)
        if values is None:
            return None
        return np.asarray(values, dtype=float).ravel()

    if param_names is None:
        param_names = prediction_dict.get("param_names", [])

    predicted = _flat("predicted_params_array")
    if predicted is None:
        predicted = np.empty(0, dtype=float)

    columns = [("Predicted", predicted)]
    optional_columns = (
        ("Polished", "polished_params_array"),
        ("Polished err", "polished_params_error_array"),
    )
    for title, key in optional_columns:
        values = _flat(key)
        if values is not None:
            columns.append((title, values))

    name_w = max(14, max((len(str(n)) for n in param_names), default=14))

    # A column is never narrower than its title; otherwise a long title such
    # as 'Polished err' overflows `width` and the header no longer lines up
    # with the numbers underneath it.
    col_widths = [max(width, len(title)) for title, _ in columns]

    if header:
        hdr = "Parameter".ljust(name_w) + "".join(
            f" {title.rjust(col_w)}" for (title, _), col_w in zip(columns, col_widths)
        )
        print(hdr)
        print("-" * len(hdr))

    for i, name in enumerate(param_names):
        cells = []
        for (_, values), col_w in zip(columns, col_widths):
            # Missing entries are shown as nan rather than raising.
            value = values[i] if i < values.size else float("nan")
            cells.append(f" {value:>{col_w}.{precision}f}")
        print(str(name).ljust(name_w) + "".join(cells))
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def plot_prediction_results(
    prediction_dict: dict,
    q_exp: np.ndarray,
    curve_exp: np.ndarray,
    sigmas_exp: np.ndarray = None,
    logx=False,
):
    """Plot experimental data together with the curves stored in a prediction dict.

    Reads ``'q_plot_pred'`` and ``'predicted_curve'`` (required) and, when
    present, ``'polished_curve'``, ``'predicted_sld_xaxis'``,
    ``'predicted_sld_profile'`` and ``'sld_profile_polished'`` from
    ``prediction_dict`` and forwards them to ``plot_reflectivity``.

    The SLD panel is drawn only when the z axis and at least one SLD profile
    are available. If either profile is complex, the real parts are plotted
    by ``plot_reflectivity`` (labelled ``(Re)``) and the imaginary parts are
    added to the SLD axis afterwards.

    Args:
        prediction_dict: dictionary holding the keys listed above.
        q_exp: q values of the experimental curve.
        curve_exp: experimental reflectivity values.
        sigmas_exp: optional y-errors of the experimental curve.
        logx: passed through to ``plot_reflectivity``.

    Returns:
        Whatever ``plot_reflectivity`` returns: ``(fig, ax_r)`` without an SLD
        panel, ``(fig, (ax_r, ax_s))`` with one.
    """
    import warnings

    def _as_array(key):
        # Convert up front so .real/.imag-style access works for lists too.
        value = prediction_dict.get(key, None)
        return None if value is None else np.asarray(value)

    q_pred = prediction_dict['q_plot_pred']
    r_pred = prediction_dict['predicted_curve']
    r_pol = prediction_dict.get('polished_curve', None)

    # The polished curve carries no q axis of its own; pick the grid whose
    # length matches (prediction grid first, then the experimental grid).
    q_pol = None
    if r_pol is not None:
        if len(r_pol) == len(q_pred):
            q_pol = q_pred
        elif len(r_pol) == len(q_exp):
            q_pol = q_exp
        else:
            warnings.warn(
                "polished_curve matches neither the predicted nor the "
                "experimental q grid; it will not be plotted.",
                stacklevel=2,
            )

    z_sld = prediction_dict.get('predicted_sld_xaxis', None)
    sld_pred_c = _as_array('predicted_sld_profile')
    sld_pol_c = _as_array('sld_profile_polished')

    plot_sld = (z_sld is not None) and (sld_pred_c is not None or sld_pol_c is not None)

    # Either profile being complex switches both labels to the '(Re)' variant.
    sld_is_complex = any(
        profile is not None and np.iscomplexobj(profile)
        for profile in (sld_pred_c, sld_pol_c)
    )

    sld_pred_label = 'pred. SLD (Re)' if sld_is_complex else 'pred. SLD'
    sld_pol_label = 'polished SLD (Re)' if sld_is_complex else 'polished SLD'

    fig, axes = plot_reflectivity(
        q_exp=q_exp, r_exp=curve_exp, yerr=sigmas_exp,
        q_pred=q_pred, r_pred=r_pred,
        q_pol=q_pol, r_pol=r_pol,
        z_sld=z_sld,
        sld_pred=np.real(sld_pred_c) if sld_pred_c is not None else None,
        sld_pol=np.real(sld_pol_c) if sld_pol_c is not None else None,
        sld_pred_label=sld_pred_label,
        sld_pol_label=sld_pol_label,
        plot_sld_profile=plot_sld,
        logx=logx,
    )

    if sld_is_complex and plot_sld:
        # With plot_sld True the second return value is the (ax_r, ax_s) pair.
        ax_r, ax_s = axes
        if sld_pred_c is not None:
            ax_s.plot(z_sld, np.imag(sld_pred_c), color='darkgreen', lw=2.0, ls='-', zorder=4, label='pred. SLD (Im)')
        if sld_pol_c is not None:
            ax_s.plot(z_sld, np.imag(sld_pol_c), color='cyan', lw=2.0, ls='--', zorder=5, label='polished SLD (Im)')
        # Rebuild the legend so the imaginary-part lines are listed as well.
        ax_s.legend(fontsize=14, frameon=True)

    return fig, axes
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def plot_reflectivity(
    *,
    q_exp=None,
    r_exp=None,
    yerr=None,
    xerr=None,
    q_pred=None,
    r_pred=None,
    q_pol=None,
    r_pol=None,
    z_sld=None,
    sld_pred=None,
    sld_pol=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    xlim=None,
    axis_label_size=20,
    tick_label_size=15,
    legend_fontsize=14,
    exp_style='auto',
    exp_color='blue',
    exp_facecolor='none',
    exp_marker='o',
    exp_ms=3,
    exp_alpha=1.0,
    exp_errcolor='purple',
    exp_elinewidth=1.0,
    exp_capsize=1.0,
    exp_capthick=1.0,
    exp_zorder=2,
    pred_color='red',
    pred_lw=2.0,
    pred_ls='-',
    pred_alpha=1.0,
    pred_zorder=3,
    pol_color='orange',
    pol_lw=2.0,
    pol_ls='--',
    pol_alpha=1.0,
    pol_zorder=4,
    sld_pred_color='red',
    sld_pred_lw=2.0,
    sld_pred_ls='-',
    sld_pol_color='orange',
    sld_pol_lw=2.0,
    sld_pol_ls='--',
    exp_label='exp. data',
    pred_label='prediction',
    pol_label='polished prediction',
    sld_pred_label='pred. SLD',
    sld_pol_label='polished SLD',
    legend=True,
    legend_kwargs=None
):
    """Plot experimental, predicted and polished R(q) curves, optionally with SLD profiles.

    All arguments are keyword-only. Each curve is drawn only when both its
    x and y arrays are given. Points that are non-finite, or non-positive on
    a log-scaled axis, are masked out before plotting.

    Args:
        q_exp, r_exp: experimental curve.
        yerr, xerr: symmetric errors of the experimental curve; each may be
            None, a scalar, or a 1-D array aligned with ``q_exp``.
        q_pred, r_pred: predicted curve (drawn as a line).
        q_pol, r_pol: polished curve (drawn as a line).
        z_sld, sld_pred, sld_pol: SLD axis and profiles for the second panel.
        plot_sld_profile: when True a second axes is created for the SLD.
        figsize: figure size; defaults to (12, 6) with the SLD panel and
            (6, 6) without it.
        logx, logy: use log scale on the R(q) axes.
        x_ticks_log, y_ticks_log: fixed major tick positions, applied only
            when the corresponding axis is log-scaled.
        xlim: None, a scalar (right limit only, left limit kept from the
            plotted data) or an ``(xmin, xmax)`` pair.
        exp_style: 'auto', 'errorbar', 'scatter', or anything else for a line.
            'auto' selects 'errorbar' when ``yerr`` or ``xerr`` is given and
            'scatter' otherwise.
        legend_kwargs: extra keyword arguments for both legends; entries
            here override the ``fontsize``/``loc`` defaults.
        Remaining ``*_label``, ``*_size``, ``exp_*``, ``pred_*``, ``pol_*``
        and ``sld_*`` arguments are labels and style options passed to
        matplotlib.

    Returns:
        ``(fig, (ax_r, ax_s))`` when the SLD panel exists, else ``(fig, ax_r)``.

    Raises:
        ValueError: if an error array is neither scalar nor 1-D, or if
            ``logx`` is set and ``xlim`` gives a non-positive lower bound.
    """
    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        # Keep only points that can actually be drawn on the chosen scales.
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym_err(err, mask):
        if err is None:
            return None
        e = np.asarray(err)
        if e.ndim == 0:
            # Scalars arrive here as 0-d arrays (np.isscalar is False for
            # those), so unwrap them instead of rejecting them.
            return e.item()
        if e.ndim != 1:
            raise ValueError("Errors must be scalar or 1-D array.")
        return e[mask]

    def _marker_proxy():
        # errorbar/scatter artists are added with label=None; this proxy is
        # what represents the experimental data in the legend.
        return Line2D([], [], color=exp_color, marker=exp_marker,
                      linestyle='none', markersize=exp_ms,
                      markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
                      alpha=exp_alpha, label=exp_label)

    q_exp, r_exp, yerr, xerr = _np(q_exp), _np(r_exp), _np(yerr), _np(xerr)
    q_pred, r_pred = _np(q_pred), _np(r_pred)
    q_pol, r_pol = _np(q_pol), _np(r_pol)
    z_sld, sld_pred, sld_pol = _np(z_sld), _np(sld_pred), _np(sld_pol)

    # Validate the limits before a figure is created so that a bad value
    # does not leave a half-built figure behind.
    xlim_is_scalar = xlim is not None and np.isscalar(xlim)
    xmin = xmax = None
    if xlim is not None and not xlim_is_scalar:
        xmin, xmax = xlim
        if logx and xmin is not None and xmin <= 0:
            raise ValueError("For log-x, xmin must be > 0.")

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    handles = []

    # Experimental plot
    if q_exp is not None and r_exp is not None:
        m = _mask(q_exp, r_exp)
        if exp_style == 'auto':
            has_errors = (yerr is not None) or (xerr is not None)
            style = 'errorbar' if has_errors else 'scatter'
        else:
            style = exp_style

        if style == 'errorbar':
            ax_r.errorbar(
                q_exp[m], r_exp[m],
                yerr=_slice_sym_err(yerr, m), xerr=_slice_sym_err(xerr, m),
                color=exp_color, ecolor=exp_errcolor,
                elinewidth=exp_elinewidth, capsize=exp_capsize,
                capthick=(exp_elinewidth if exp_capthick is None else exp_capthick),
                marker=exp_marker, linestyle='none', markersize=exp_ms,
                markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
                alpha=exp_alpha, zorder=exp_zorder, label=None
            )
            handles.append(_marker_proxy())
        elif style == 'scatter':
            ax_r.scatter(
                q_exp[m], r_exp[m],
                s=exp_ms**2, marker=exp_marker,
                facecolors=exp_facecolor, edgecolors=exp_color,
                alpha=exp_alpha, zorder=exp_zorder, label=None
            )
            handles.append(_marker_proxy())
        else:  # 'line'
            handles.append(ax_r.plot(
                q_exp[m], r_exp[m], color=exp_color, lw=1.0, ls='-',
                alpha=exp_alpha, zorder=exp_zorder, label=exp_label
            )[0])

    # Predicted and polished lines share the same drawing logic.
    model_curves = (
        (q_pred, r_pred, pred_color, pred_lw, pred_ls, pred_alpha, pred_zorder, pred_label),
        (q_pol, r_pol, pol_color, pol_lw, pol_ls, pol_alpha, pol_zorder, pol_label),
    )
    for q, r, color, lw, ls, alpha, zorder, label in model_curves:
        if q is None or r is None:
            continue
        m = _mask(q, r)
        handles.append(ax_r.plot(
            q[m], r[m],
            color=color, lw=lw, ls=ls,
            alpha=alpha, zorder=zorder, label=label
        )[0])

    # Apply x-limits (right-only or both). This happens after plotting so
    # that the right-only form keeps the left limit derived from the data
    # rather than the default of an empty axes.
    if xlim is not None:
        if xlim_is_scalar:
            cur_left, _ = ax_r.get_xlim()
            if logx and cur_left <= 0:
                cur_left = 1e-12
            ax_r.set_xlim(left=cur_left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xmin, right=xmax)

    # Caller-supplied legend options take precedence over the defaults.
    legend_opts = {'fontsize': legend_fontsize, 'loc': 'best'}
    legend_opts.update(legend_kwargs or {})

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **legend_opts)

    # SLD panel (optional)
    if ax_s is not None:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        if z_sld is not None and sld_pred is not None:
            ax_s.plot(z_sld, sld_pred,
                      color=sld_pred_color, lw=sld_pred_lw, ls=sld_pred_ls,
                      label=sld_pred_label)
        if z_sld is not None and sld_pol is not None:
            ax_s.plot(z_sld, sld_pol,
                      color=sld_pol_color, lw=sld_pol_lw, ls=sld_pol_ls,
                      label=sld_pol_label)

        # Only request a legend when there is something to list in it.
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**legend_opts)

    plt.tight_layout()
    return (fig, (ax_r, ax_s)) if ax_s is not None else (fig, ax_r)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def plot_reflectivity_multi(
    *,
    rq_series,
    sld_series=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    xlim=None,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    axis_label_size=20,
    tick_label_size=15,
    legend=True,
    legend_fontsize=12,
    legend_kwargs=None,
):
    """
    Plot multiple R(q) series (and optional SLD lines) with per-series styling.

    rq_series: list of dicts, each with:
        required:
            - x: 1D array
            - y: 1D array
        optional (per series):
            - kind: 'errorbar' | 'scatter' | 'line' (default 'line')
            - label: str
            - color: str
            - alpha: float (0..1)
            - zorder: int
            # scatter / marker for errorbar:
            - marker: str (default 'o')
            - ms: float (marker size in pt; for scatter internally converted to s=ms**2)
            - facecolor: str (scatter marker face)
            # errorbar only:
            - yerr: scalar or 1D array
            - xerr: scalar or 1D array
            - ecolor: str
            - elinewidth: float
            - capsize: float
            - capthick: float
            # line only:
            - lw: float
            - ls: str (e.g. '-', '--', ':')

    sld_series: list of dicts (only lines), each with:
        - x: 1D array (z-axis)
        - y: 1D array (SLD)
        - label: str (optional)
        - color: str (optional)
        - lw: float (optional)
        - ls: str (optional)
        - alpha: float (optional)
        - zorder: int (optional)

    Series missing 'x' or 'y' are skipped. Points that are non-finite, or
    non-positive on a log-scaled axis, are masked out of the R(q) series.
    Only series with a non-None label appear in the R(q) legend, in the
    order given.

    xlim: None, a scalar (right limit only; the left limit is the one
        obtained from the plotted data) or an (xmin, xmax) pair.
    legend_kwargs: extra keyword arguments for both legends; entries here
        override the fontsize/loc defaults.

    Returns (fig, (ax_r, ax_s)) when plot_sld_profile is True, else (fig, ax_r).

    Raises ValueError if an error array is neither scalar nor 1-D, or if
    logx is set and xlim gives a non-positive lower bound.
    """
    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        # Keep only points that can actually be drawn on the chosen scales.
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym(err, mask):
        # Symmetric error input: scalar (including 0-d arrays) or 1-D.
        if err is None:
            return None
        arr = np.asarray(err)
        if arr.ndim == 0:
            return arr.item()
        if arr.ndim != 1:
            raise ValueError("For symmetric error bars, provide scalar or 1-D array.")
        return arr[mask]

    # Validate the limits before a figure is created so that a bad value
    # does not leave a half-built figure behind.
    xlim_is_scalar = xlim is not None and np.isscalar(xlim)
    xmin = xmax = None
    if xlim is not None and not xlim_is_scalar:
        xmin, xmax = xlim
        if logx and xmin is not None and xmin <= 0:
            raise ValueError("For log-x, xmin must be > 0.")

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    # Plot all R(q) series in the given order (order = legend order)
    handles = []
    for s in rq_series:
        kind = s.get('kind', 'line')
        x = _np(s.get('x'))
        y = _np(s.get('y'))
        if x is None or y is None:
            continue

        label = s.get('label', None)
        color = s.get('color', None)
        alpha = s.get('alpha', 1.0)
        zord = s.get('zorder', None)
        ms = s.get('ms', 5.0)
        marker = s.get('marker', 'o')

        m = _mask(x, y)

        if kind == 'errorbar':
            elinewidth = s.get('elinewidth', 1.0)
            facecolor = s.get('facecolor', 'none')
            ax_r.errorbar(
                x[m], y[m],
                yerr=_slice_sym(s.get('yerr', None), m),
                xerr=_slice_sym(s.get('xerr', None), m),
                color=color, ecolor=s.get('ecolor', color),
                elinewidth=elinewidth,
                capsize=s.get('capsize', 0.0),
                capthick=s.get('capthick', elinewidth),
                marker=marker, linestyle='none', markersize=ms,
                markerfacecolor=facecolor,
                markeredgecolor=color,
                alpha=alpha, zorder=zord, label=None
            )
            # The artist itself is unlabeled; a marker-only proxy stands in
            # for it in the legend.
            handle = Line2D([], [], color=color, marker=marker, linestyle='none',
                            markersize=ms, markerfacecolor=facecolor,
                            markeredgecolor=color, alpha=alpha, label=label)

        elif kind == 'scatter':
            facecolor = s.get('facecolor', 'none')
            ax_r.scatter(
                x[m], y[m], s=ms**2, marker=marker,
                facecolors=facecolor, edgecolors=color,
                alpha=alpha, zorder=zord, label=None
            )
            handle = Line2D([], [], color=color, marker=marker, linestyle='none',
                            markersize=ms, markerfacecolor=facecolor,
                            markeredgecolor=color, alpha=alpha, label=label)

        else:  # 'line'
            handle = ax_r.plot(
                x[m], y[m],
                color=color, lw=s.get('lw', 2.0), ls=s.get('ls', '-'),
                alpha=alpha, zorder=zord, label=label
            )[0]

        if label is not None:
            handles.append(handle)

    # Apply x-limits (right-only or both). This happens after plotting so
    # that the right-only form keeps the left limit derived from the data
    # rather than the default of an empty axes.
    if xlim is not None:
        if xlim_is_scalar:
            left, _ = ax_r.get_xlim()
            if logx and left <= 0:
                left = 1e-12
            ax_r.set_xlim(left=left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xmin, right=xmax)

    # Caller-supplied legend options take precedence over the defaults.
    legend_opts = {'fontsize': legend_fontsize, 'loc': 'best'}
    legend_opts.update(legend_kwargs or {})

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **legend_opts)

    # Optional SLD panel
    if plot_sld_profile:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        for s in (sld_series or ()):
            zx = _np(s.get('x'))
            zy = _np(s.get('y'))
            if zx is None or zy is None:
                continue
            ax_s.plot(
                zx, zy,
                color=s.get('color', None),
                lw=s.get('lw', 2.0),
                ls=s.get('ls', '-'),
                alpha=s.get('alpha', 1.0),
                zorder=s.get('zorder', None),
                label=s.get('label', None),
            )

        # Only request a legend when there is something to list in it.
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**legend_opts)

    plt.tight_layout()
    return (fig, (ax_r, ax_s)) if plot_sld_profile else (fig, ax_r)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from reflectorch.inference.preprocess_exp.preprocess import (
|
|
2
|
-
standard_preprocessing,
|
|
3
|
-
StandardPreprocessing,
|
|
4
|
-
)
|
|
5
|
-
from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
|
|
6
|
-
from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
|
|
1
|
+
from reflectorch.inference.preprocess_exp.preprocess import (
|
|
2
|
+
standard_preprocessing,
|
|
3
|
+
StandardPreprocessing,
|
|
4
|
+
)
|
|
5
|
+
from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
|
|
6
|
+
from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
|
|
7
7
|
from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction
|