reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; see the package's registry page for more details.

Files changed (96):
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,98 +1,524 @@ reflectorch/inference/plotting.py
1
- from matplotlib import pyplot as plt
2
- import numpy as np
3
-
4
-
5
def plot_prediction_results(
    prediction_dict: dict,
    q_exp: np.ndarray = None,
    curve_exp: np.ndarray = None,
    sigmas_exp: np.ndarray = None,
    q_model: np.ndarray = None,
):
    """
    Draw a two-panel summary of a model prediction and show it.

    Left panel: reflectivity on a logarithmic y-axis. The experimental data are
    drawn as markers with error bars when both ``q_exp`` and ``curve_exp`` are
    given. The predicted / polished curves are drawn against ``q_model`` when it
    is given. Right panel: the predicted / polished SLD profiles against
    'predicted_sld_xaxis'. The figure is displayed with ``plt.show()``; nothing
    is returned.

    Args:
        prediction_dict (dict): May contain 'predicted_curve', 'polished_curve',
            'predicted_sld_xaxis', 'predicted_sld_profile' and
            'sld_profile_polished'. Missing keys are simply not drawn.
        q_exp (ndarray, optional): Experimental q-values.
        curve_exp (ndarray, optional): Experimental reflectivity curve.
        sigmas_exp (ndarray, optional): Error bars of the experimental reflectivity.
        q_model (ndarray, optional): q-values on which the curves in
            ``prediction_dict`` were computed (e.g. from
            EasyInferenceModel.interpolate_data_to_model_q).

    Example usage:
        prediction_dict = model.predict(...)
        plot_prediction_results(
            prediction_dict,
            q_exp=q_exp,
            curve_exp=curve_exp,
            sigmas_exp=sigmas_exp,
            q_model=q_model
        )
    """
    _, (ax_refl, ax_sld) = plt.subplots(1, 2, figsize=(12, 6))

    # ---- Reflectivity panel -------------------------------------------------
    ax_refl.set_yscale('log')
    ax_refl.set_xlabel('q [$Å^{-1}$]', fontsize=20)
    ax_refl.set_ylabel('R(q)', fontsize=20)
    for which in ('major', 'minor'):
        ax_refl.tick_params(axis='both', which=which, labelsize=15)

    # Major y ticks every two decades: 1, 1e-2, ..., 1e-10.
    decade_ticks = [10 ** -k for k in range(0, 12, 2)]
    ax_refl.yaxis.set_major_locator(plt.FixedLocator(decade_ticks))

    if q_exp is not None and curve_exp is not None:
        container = ax_refl.errorbar(
            q_exp, curve_exp, yerr=sigmas_exp, xerr=None,
            c='b', ecolor='purple', elinewidth=1,
            marker='o', linestyle='none', markersize=3,
            label='exp. curve', zorder=1,
        )
        # Explicitly recolour the container's second child artist
        # (presumably the error-bar collection) purple.
        children = container.get_children()
        if len(children) > 1:
            children[1].set_color('purple')

    if q_model is not None:
        curve_styles = (
            ('predicted_curve', dict(c='red', lw=2, label='pred. curve')),
            ('polished_curve', dict(c='orange', ls='--', lw=2, label='polished pred. curve')),
        )
        for key, style in curve_styles:
            if key in prediction_dict:
                ax_refl.plot(q_model, prediction_dict[key], **style)

    ax_refl.legend(fontsize=12)

    # ---- SLD panel ----------------------------------------------------------
    ax_sld.set_xlabel('z [$Å$]', fontsize=20)
    ax_sld.set_ylabel('SLD [$10^{-6} Å^{-2}$]', fontsize=20)
    for which in ('major', 'minor'):
        ax_sld.tick_params(axis='both', which=which, labelsize=15)

    if 'predicted_sld_xaxis' in prediction_dict:
        z_axis = prediction_dict['predicted_sld_xaxis']
        sld_styles = (
            ('predicted_sld_profile', dict(c='red', label='pred. sld')),
            ('sld_profile_polished', dict(c='orange', ls='--', label='polished sld')),
        )
        for key, style in sld_styles:
            if key in prediction_dict:
                ax_sld.plot(z_axis, prediction_dict[key], **style)

    ax_sld.legend(fontsize=12)

    plt.tight_layout()
    plt.show()
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.ticker as mticker
4
+ from matplotlib.lines import Line2D
5
+
6
+
7
def print_prediction_results(prediction_dict, param_names=None, width=10, precision=3, header=True):
    """
    Print a fixed-width table of predicted (and optionally polished) parameters.

    Args:
        prediction_dict (dict): Results dictionary. Reads the optional keys
            'param_names', 'predicted_params_array', 'polished_params_array'
            and 'polished_params_error_array'. The 'Polished' and
            'Polished err' columns are printed only when the corresponding
            key is present and not None.
        param_names (sequence, optional): Row labels. Defaults to
            ``prediction_dict['param_names']``. One row is printed per name.
        width (int): Minimum width of each numeric column. A column is
            widened automatically when its title is longer
            (e.g. 'Polished err'), so header and rows stay aligned.
        precision (int): Number of decimals printed for every value.
        header (bool): Whether to print the header line and its separator.

    Value arrays are flattened before indexing, so a singleton batch of shape
    ``(1, n_params)`` prints like a 1-D array. When an array holds fewer values
    than there are names, the missing entries print as ``nan``.
    """
    if param_names is None:
        param_names = prediction_dict.get("param_names", None)
    if param_names is None:
        param_names = []

    def _flat(key):
        # Missing / None entries -> None; anything else -> 1-D float array.
        values = prediction_dict.get(key, None)
        return None if values is None else np.asarray(values, dtype=float).ravel()

    pred = _flat("predicted_params_array")
    if pred is None:
        pred = np.empty(0, dtype=float)

    columns = [("Predicted", pred)]
    pol = _flat("polished_params_array")
    if pol is not None:
        columns.append(("Polished", pol))
    pol_err = _flat("polished_params_error_array")
    if pol_err is not None:
        columns.append(("Polished err", pol_err))

    # A title longer than `width` would otherwise overflow its column and
    # shift the header relative to the numeric rows.
    col_widths = [max(width, len(title)) for title, _ in columns]
    name_w = max(14, max((len(str(n)) for n in param_names), default=14))

    if header:
        hdr = " ".join(
            ["Parameter".ljust(name_w)]
            + [title.rjust(w) for (title, _), w in zip(columns, col_widths)]
        )
        print(hdr)
        print("-" * len(hdr))

    for i, name in enumerate(param_names):
        cells = [str(name).ljust(name_w)]
        for (_, values), w in zip(columns, col_widths):
            val = values[i] if i < values.size else float("nan")
            cells.append(f"{val:>{w}.{precision}f}")
        print(" ".join(cells))
41
+
42
+
43
def plot_prediction_results(
    prediction_dict: dict,
    q_exp: np.ndarray,
    curve_exp: np.ndarray,
    sigmas_exp: np.ndarray = None,
    logx=False,
):
    """
    Plot experimental data with the predicted / polished curves and SLD profiles.

    Args:
        prediction_dict (dict): Must contain 'q_plot_pred' and
            'predicted_curve'. May contain 'polished_curve',
            'predicted_sld_xaxis', 'predicted_sld_profile' and
            'sld_profile_polished'.
        q_exp (ndarray): Experimental q-values.
        curve_exp (ndarray): Experimental reflectivity curve.
        sigmas_exp (ndarray, optional): Error bars of the experimental reflectivity.
        logx (bool): Use a logarithmic q-axis.

    The polished curve is drawn against 'q_plot_pred' when its length matches
    that grid, otherwise against ``q_exp`` when those lengths match; if neither
    matches it is not drawn. An SLD panel is added when 'predicted_sld_xaxis'
    and at least one SLD profile are present. If either profile is complex,
    ``plot_reflectivity`` draws the real parts and the imaginary parts are
    added here.

    Returns:
        ``(fig, (ax_r, ax_s))`` when an SLD panel is drawn, else ``(fig, ax_r)``.
    """
    q_pred = prediction_dict['q_plot_pred']
    r_pred = prediction_dict['predicted_curve']
    r_pol = prediction_dict.get('polished_curve', None)

    # Pick the q-grid the polished curve was evaluated on by matching lengths.
    q_pol = None
    if r_pol is not None:
        if len(r_pol) == len(q_pred):
            q_pol = q_pred
        elif q_exp is not None and len(r_pol) == len(q_exp):
            q_pol = q_exp

    z_sld = prediction_dict.get('predicted_sld_xaxis', None)
    sld_pred_c = prediction_dict.get('predicted_sld_profile', None)
    sld_pol_c = prediction_dict.get('sld_profile_polished', None)
    # Normalise to arrays so `.real` / `.imag` also work for plain sequences.
    sld_pred_c = None if sld_pred_c is None else np.asarray(sld_pred_c)
    sld_pol_c = None if sld_pol_c is None else np.asarray(sld_pol_c)

    plot_sld = (z_sld is not None) and (sld_pred_c is not None or sld_pol_c is not None)

    # Treat the SLD as complex if either profile is complex
    # (np.iscomplexobj(None) is False).
    sld_is_complex = np.iscomplexobj(sld_pred_c) or np.iscomplexobj(sld_pol_c)

    sld_pred_label = 'pred. SLD (Re)' if sld_is_complex else 'pred. SLD'
    sld_pol_label = 'polished SLD (Re)' if sld_is_complex else 'polished SLD'

    fig, axes = plot_reflectivity(
        q_exp=q_exp, r_exp=curve_exp, yerr=sigmas_exp,
        q_pred=q_pred, r_pred=r_pred,
        q_pol=q_pol, r_pol=r_pol,
        z_sld=z_sld,
        sld_pred=sld_pred_c.real if sld_pred_c is not None else None,
        sld_pol=sld_pol_c.real if sld_pol_c is not None else None,
        sld_pred_label=sld_pred_label,
        sld_pol_label=sld_pol_label,
        plot_sld_profile=plot_sld,
        logx=logx,
    )

    if sld_is_complex and plot_sld:
        _, ax_s = axes
        if sld_pred_c is not None:
            ax_s.plot(z_sld, sld_pred_c.imag, color='darkgreen', lw=2.0, ls='-', zorder=4, label='pred. SLD (Im)')
        if sld_pol_c is not None:
            ax_s.plot(z_sld, sld_pol_c.imag, color='cyan', lw=2.0, ls='--', zorder=5, label='polished SLD (Im)')
        ax_s.legend(fontsize=14, frameon=True)

    return fig, axes
93
+
94
+
95
def plot_reflectivity(
    *,
    q_exp=None,
    r_exp=None,
    yerr=None,
    xerr=None,
    q_pred=None,
    r_pred=None,
    q_pol=None,
    r_pol=None,
    z_sld=None,
    sld_pred=None,
    sld_pol=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    xlim=None,
    axis_label_size=20,
    tick_label_size=15,
    legend_fontsize=14,
    exp_style='auto',
    exp_color='blue',
    exp_facecolor='none',
    exp_marker='o',
    exp_ms=3,
    exp_alpha=1.0,
    exp_errcolor='purple',
    exp_elinewidth=1.0,
    exp_capsize=1.0,
    exp_capthick=1.0,
    exp_zorder=2,
    pred_color='red',
    pred_lw=2.0,
    pred_ls='-',
    pred_alpha=1.0,
    pred_zorder=3,
    pol_color='orange',
    pol_lw=2.0,
    pol_ls='--',
    pol_alpha=1.0,
    pol_zorder=4,
    sld_pred_color='red',
    sld_pred_lw=2.0,
    sld_pred_ls='-',
    sld_pol_color='orange',
    sld_pol_lw=2.0,
    sld_pol_ls='--',
    exp_label='exp. data',
    pred_label='prediction',
    pol_label='polished prediction',
    sld_pred_label='pred. SLD',
    sld_pol_label='polished SLD',
    legend=True,
    legend_kwargs=None
):
    """
    Plot a reflectivity curve R(q) and, optionally, an SLD profile next to it.

    All data arguments are keyword-only and optional. A series whose x or y is
    ``None`` is skipped. Before plotting, points that are non-finite, or
    non-positive on a log-scaled axis, are masked out.

    Experimental data style (``exp_style``):
        'auto'     -> 'errorbar' when ``yerr`` is given, otherwise 'scatter'
        'errorbar' -> markers with error bars (markers only if no errors given)
        'scatter'  -> markers only
        other      -> thin line
    ``yerr`` / ``xerr`` must be scalars or 1-D arrays aligned with ``q_exp``.

    ``xlim`` is either a scalar (upper q limit only) or an ``(xmin, xmax)`` pair
    whose entries may be ``None``. It is applied after the data are drawn, so
    the lower limit of a scalar ``xlim`` follows the data.

    ``legend_kwargs`` is forwarded to both legends. It may override the
    defaults ``fontsize=legend_fontsize`` and ``loc='best'``.

    Returns:
        ``(fig, (ax_r, ax_s))`` if ``plot_sld_profile`` else ``(fig, ax_r)``.

    Raises:
        ValueError: If an error array is not scalar/1-D, or if ``logx`` is set
            and ``xmin`` is not positive.
    """

    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        # Keep finite points; drop non-positive values on log-scaled axes.
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym_err(err, mask):
        # Symmetric errors: a scalar is used as-is, a 1-D array is masked like the data.
        if err is None:
            return None
        if np.isscalar(err):
            return err
        e = np.asarray(err)
        if e.ndim != 1:
            raise ValueError("Errors must be scalar or 1-D array.")
        return e[mask]

    def _legend_kw():
        # User-supplied legend_kwargs override the defaults instead of
        # clashing with them ("got multiple values for keyword argument").
        return {'fontsize': legend_fontsize, 'loc': 'best', **(legend_kwargs or {})}

    # Validate xlim before creating a figure; the limits are applied after plotting.
    if xlim is not None and not np.isscalar(xlim):
        xmin, xmax = xlim
        if logx and xmin is not None and xmin <= 0:
            raise ValueError("For log-x, xmin must be > 0.")

    q_exp, r_exp, yerr, xerr = _np(q_exp), _np(r_exp), _np(yerr), _np(xerr)
    q_pred, r_pred = _np(q_pred), _np(r_pred)
    q_pol, r_pol = _np(q_pol), _np(r_pol)
    z_sld, sld_pred, sld_pol = _np(z_sld), _np(sld_pred), _np(sld_pol)

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    handles = []

    # Experimental data
    if q_exp is not None and r_exp is not None:
        m = _mask(q_exp, r_exp)
        style = exp_style if exp_style != 'auto' else ('errorbar' if yerr is not None else 'scatter')
        marker_kw = dict(marker=exp_marker, markersize=exp_ms,
                         markerfacecolor=exp_facecolor, markeredgecolor=exp_color)

        if style in ('errorbar', 'scatter'):
            if style == 'errorbar':
                # errorbar() with yerr/xerr of None just draws the markers.
                ax_r.errorbar(
                    q_exp[m], r_exp[m],
                    yerr=_slice_sym_err(yerr, m), xerr=_slice_sym_err(xerr, m),
                    color=exp_color, ecolor=exp_errcolor,
                    elinewidth=exp_elinewidth, capsize=exp_capsize,
                    capthick=(exp_elinewidth if exp_capthick is None else exp_capthick),
                    linestyle='none', alpha=exp_alpha, zorder=exp_zorder,
                    label=None, **marker_kw
                )
            else:
                ax_r.scatter(
                    q_exp[m], r_exp[m],
                    s=exp_ms**2, marker=exp_marker,
                    facecolors=exp_facecolor, edgecolors=exp_color,
                    alpha=exp_alpha, zorder=exp_zorder, label=None
                )
            # Marker-only proxy so the legend shows no error-bar glyph.
            handles.append(Line2D([], [], color=exp_color, linestyle='none',
                                  alpha=exp_alpha, label=exp_label, **marker_kw))
        else:  # 'line'
            handles.append(ax_r.plot(
                q_exp[m], r_exp[m], color=exp_color, lw=1.0, ls='-',
                alpha=exp_alpha, zorder=exp_zorder, label=exp_label
            )[0])

    # Predicted line
    if q_pred is not None and r_pred is not None:
        m = _mask(q_pred, r_pred)
        handles.append(ax_r.plot(
            q_pred[m], r_pred[m],
            color=pred_color, lw=pred_lw, ls=pred_ls,
            alpha=pred_alpha, zorder=pred_zorder, label=pred_label
        )[0])

    # Polished line
    if q_pol is not None and r_pol is not None:
        m = _mask(q_pol, r_pol)
        handles.append(ax_r.plot(
            q_pol[m], r_pol[m],
            color=pol_color, lw=pol_lw, ls=pol_ls,
            alpha=pol_alpha, zorder=pol_zorder, label=pol_label
        )[0])

    # x-limits are set only now: setting them on the still-empty axes would
    # take the left edge from the default (0, 1) view and disable autoscaling.
    if xlim is not None:
        if np.isscalar(xlim):
            left, _ = ax_r.get_xlim()
            if logx and left <= 0:
                left = 1e-12
            ax_r.set_xlim(left=left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xmin, right=xmax)

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **_legend_kw())

    # SLD panel (optional)
    if ax_s is not None:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        if z_sld is not None and sld_pred is not None:
            ax_s.plot(z_sld, sld_pred,
                      color=sld_pred_color, lw=sld_pred_lw, ls=sld_pred_ls,
                      label=sld_pred_label)
        if z_sld is not None and sld_pol is not None:
            ax_s.plot(z_sld, sld_pol,
                      color=sld_pol_color, lw=sld_pol_lw, ls=sld_pol_ls,
                      label=sld_pol_label)

        # Skip the legend when nothing labelled was drawn (avoids a warning).
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**_legend_kw())

    fig.tight_layout()
    return (fig, (ax_r, ax_s)) if ax_s is not None else (fig, ax_r)
311
+
312
+
313
def plot_reflectivity_multi(
    *,
    rq_series,
    sld_series=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    xlim=None,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    axis_label_size=20,
    tick_label_size=15,
    legend=True,
    legend_fontsize=12,
    legend_kwargs=None,
):
    """
    Plot multiple R(q) series (and optional SLD lines) with per-series styling.

    rq_series: list of dicts, each with:
        required:
            - x: 1D array
            - y: 1D array
        optional (per series):
            - kind: 'errorbar' | 'scatter' | 'line' (default 'line')
            - label: str (series without a label are left out of the legend)
            - color: str (if omitted, the next 'CN' colour of the default
              property cycle is assigned, in series order, so markers stay
              visible and legend entries match the data)
            - alpha: float (0..1)
            - zorder: int
            # scatter / marker for errorbar:
            - marker: str (default 'o')
            - ms: float (marker size in pt; for scatter internally converted to s=ms**2)
            - facecolor: str (marker face, default 'none')
            # errorbar only:
            - yerr: scalar or 1D array
            - xerr: scalar or 1D array
            - ecolor: str (default: the series colour)
            - elinewidth: float
            - capsize: float
            - capthick: float
            # line only:
            - lw: float
            - ls: str (e.g. '-', '--', ':')

    sld_series: list of dicts (only lines), each with:
        - x: 1D array (z-axis)
        - y: 1D array (SLD)
        - label: str (optional)
        - color: str (optional)
        - lw: float (optional)
        - ls: str (optional)
        - alpha: float (optional)
        - zorder: int (optional)

    Points that are non-finite, or non-positive on a log-scaled axis, are
    masked out. ``xlim`` is a scalar (upper limit only) or an ``(xmin, xmax)``
    pair and is applied after plotting. ``legend_kwargs`` may override the
    legend defaults ``fontsize=legend_fontsize`` and ``loc='best'``.

    Returns:
        ``(fig, (ax_r, ax_s))`` if ``plot_sld_profile`` else ``(fig, ax_r)``.

    Raises:
        ValueError: If an error array is not scalar/1-D, or if ``logx`` is set
            and ``xmin`` is not positive.
    """

    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        # Keep finite points; drop non-positive values on log-scaled axes.
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym(err, mask):
        # Symmetric error input: scalar (used as-is) or 1-D array (masked like the data).
        if err is None:
            return None
        if np.isscalar(err):
            return err
        arr = np.asarray(err)
        if arr.ndim != 1:
            raise ValueError("For symmetric error bars, provide scalar or 1-D array.")
        return arr[mask]

    def _legend_kw():
        # User-supplied legend_kwargs override the defaults instead of
        # clashing with them ("got multiple values for keyword argument").
        return {'fontsize': legend_fontsize, 'loc': 'best', **(legend_kwargs or {})}

    # Validate xlim before creating a figure; the limits are applied after plotting.
    if xlim is not None and not np.isscalar(xlim):
        xmin, xmax = xlim
        if logx and xmin is not None and xmin <= 0:
            raise ValueError("For log-x, xmin must be > 0.")

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    # Plot all R(q) series in the given order (order = legend order)
    handles = []
    auto_color_idx = 0
    for s in (rq_series or []):
        x = _np(s.get('x'))
        y = _np(s.get('y'))
        if x is None or y is None:
            continue

        kind = s.get('kind', 'line')
        label = s.get('label', None)
        color = s.get('color', None)
        if color is None:
            # Resolve the colour explicitly. Otherwise an unfilled scatter gets
            # edgecolor 'face' (invisible), and the legend proxies below would
            # not share the colour the artist picked from its own cycler.
            color = f"C{auto_color_idx}"
            auto_color_idx += 1
        alpha = s.get('alpha', 1.0)
        zord = s.get('zorder', None)
        ms = s.get('ms', 5.0)
        marker = s.get('marker', 'o')
        facecolor = s.get('facecolor', 'none')

        m = _mask(x, y)

        if kind in ('errorbar', 'scatter'):
            if kind == 'errorbar':
                elinewidth = s.get('elinewidth', 1.0)
                ax_r.errorbar(
                    x[m], y[m],
                    yerr=_slice_sym(s.get('yerr', None), m),
                    xerr=_slice_sym(s.get('xerr', None), m),
                    color=color, ecolor=s.get('ecolor', color),
                    elinewidth=elinewidth, capsize=s.get('capsize', 0.0),
                    capthick=s.get('capthick', elinewidth),
                    marker=marker, linestyle='none', markersize=ms,
                    markerfacecolor=facecolor, markeredgecolor=color,
                    alpha=alpha, zorder=zord, label=None
                )
            else:
                ax_r.scatter(
                    x[m], y[m], s=ms**2, marker=marker,
                    facecolors=facecolor, edgecolors=color,
                    alpha=alpha, zorder=zord, label=None
                )
            # Marker-only proxy so the legend shows no error-bar glyph.
            if label is not None:
                handles.append(Line2D([], [], color=color, marker=marker, linestyle='none',
                                      markersize=ms, markerfacecolor=facecolor,
                                      markeredgecolor=color, alpha=alpha, label=label))
        else:  # 'line'
            line = ax_r.plot(
                x[m], y[m],
                color=color, lw=s.get('lw', 2.0), ls=s.get('ls', '-'),
                alpha=alpha, zorder=zord, label=label
            )[0]
            if label is not None:
                handles.append(line)

    # x-limits are set only now: setting them on the still-empty axes would
    # take the left edge from the default (0, 1) view and disable autoscaling.
    if xlim is not None:
        if np.isscalar(xlim):
            left, _ = ax_r.get_xlim()
            if logx and left <= 0:
                left = 1e-12
            ax_r.set_xlim(left=left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xmin, right=xmax)

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **_legend_kw())

    # Optional SLD panel
    if ax_s is not None:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        for s in (sld_series or []):
            zx = _np(s.get('x'))
            zy = _np(s.get('y'))
            if zx is None or zy is None:
                continue
            ax_s.plot(zx, zy,
                      color=s.get('color', None), lw=s.get('lw', 2.0), ls=s.get('ls', '-'),
                      alpha=s.get('alpha', 1.0), zorder=s.get('zorder', None),
                      label=s.get('label', None))

        # Skip the legend when nothing labelled was drawn (avoids a warning).
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**_legend_kw())

    fig.tight_layout()
    return (fig, (ax_r, ax_s)) if ax_s is not None else (fig, ax_r)
@@ -1,7 +1,7 @@ reflectorch/inference/preprocess_exp/__init__.py
1
- from reflectorch.inference.preprocess_exp.preprocess import (
2
- standard_preprocessing,
3
- StandardPreprocessing,
4
- )
5
- from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
- from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
1
+ from reflectorch.inference.preprocess_exp.preprocess import (
2
+ standard_preprocessing,
3
+ StandardPreprocessing,
4
+ )
5
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
+ from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
7
7
  from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction