reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,524 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.ticker as mticker
4
+ from matplotlib.lines import Line2D
5
+
6
+
7
+ def print_prediction_results(prediction_dict, param_names=None, width=10, precision=3, header=True, print_err=False):
8
+
9
+ if param_names is None:
10
+ param_names = prediction_dict.get("param_names", [])
11
+
12
+ pred = np.asarray(prediction_dict.get("predicted_params_array", []), dtype=float)
13
+ pol = prediction_dict.get("polished_params_array", None)
14
+ pol = np.asarray(pol, dtype=float) if pol is not None else None
15
+ pol_err = prediction_dict.get('polished_params_error_array', None) if print_err else None
16
+ pol_err = np.asarray(pol_err, dtype=float) if pol_err is not None else None
17
+
18
+ name_w = max(14, max((len(str(n)) for n in param_names), default=14))
19
+
20
+ num_fmt = f"{{:>{width}.{precision}f}}"
21
+
22
+ if header:
23
+ hdr = f"{'Parameter'.ljust(name_w)} {'Predicted'.rjust(width)}"
24
+ if pol is not None:
25
+ hdr += f" {'Polished'.rjust(width)}"
26
+ if pol_err is not None:
27
+ hdr += f" {'Polished err'.rjust(width)}"
28
+ print(hdr)
29
+ print("-" * len(hdr))
30
+
31
+ for i, name in enumerate(param_names):
32
+ pred_val = pred[i] if i < pred.size else float("nan")
33
+ row = f"{str(name).ljust(name_w)} {num_fmt.format(pred_val)}"
34
+ if pol is not None:
35
+ pol_val = pol[i] if i < pol.size else float("nan")
36
+ row += f" {num_fmt.format(pol_val)}"
37
+ if pol_err is not None:
38
+ pol_err_val = pol_err[i] if i < pol_err.size else float('nan')
39
+ row += f' {num_fmt.format(pol_err_val)}'
40
+ print(row)
41
+
42
+
43
+ def plot_prediction_results(
44
+ prediction_dict: dict,
45
+ q_exp: np.ndarray,
46
+ curve_exp: np.ndarray,
47
+ sigmas_exp: np.ndarray = None,
48
+ logx=False,
49
+ ):
50
+ q_pred = prediction_dict['q_plot_pred']
51
+ r_pred = prediction_dict['predicted_curve']
52
+ r_pol = prediction_dict.get('polished_curve', None)
53
+
54
+ q_pol = None
55
+ if r_pol is not None:
56
+ if len(r_pol) == len(q_pred):
57
+ q_pol = q_pred
58
+ elif len(r_pol) == len(q_exp):
59
+ q_pol = q_exp
60
+
61
+ z_sld = prediction_dict.get('predicted_sld_xaxis', None)
62
+ sld_pred_c = prediction_dict.get('predicted_sld_profile', None)
63
+ sld_pol_c = prediction_dict.get('sld_profile_polished', None)
64
+
65
+ plot_sld = (z_sld is not None) and (sld_pred_c is not None or sld_pol_c is not None)
66
+
67
+ sld_is_complex = np.iscomplexobj(sld_pred_c)
68
+
69
+ sld_pred_label = 'pred. SLD (Re)' if sld_is_complex else 'pred. SLD'
70
+ sld_pol_label = 'polished SLD (Re)' if sld_is_complex else 'polished SLD'
71
+
72
+ fig, axes = plot_reflectivity(
73
+ q_exp=q_exp, r_exp=curve_exp, yerr=sigmas_exp,
74
+ q_pred=q_pred, r_pred=r_pred,
75
+ q_pol=q_pol, r_pol=r_pol,
76
+ z_sld=z_sld,
77
+ sld_pred=sld_pred_c.real if sld_pred_c is not None else None,
78
+ sld_pol=sld_pol_c.real if sld_pol_c is not None else None,
79
+ sld_pred_label=sld_pred_label,
80
+ sld_pol_label=sld_pol_label,
81
+ plot_sld_profile=plot_sld,
82
+ logx=logx,
83
+ )
84
+
85
+ if sld_is_complex and plot_sld:
86
+ ax_r, ax_s = axes
87
+ ax_s.plot(z_sld, sld_pred_c.imag, color='darkgreen', lw=2.0, ls='-', zorder=4, label='pred. SLD (Im)')
88
+ if sld_pol_c is not None:
89
+ ax_s.plot(z_sld, sld_pol_c.imag, color='cyan', lw=2.0, ls='--', zorder=5, label='polished SLD (Im)')
90
+ ax_s.legend(fontsize=14, frameon=True)
91
+
92
+ return fig, axes
93
+
94
+
95
+ def plot_reflectivity(
96
+ *,
97
+ q_exp=None,
98
+ r_exp=None,
99
+ yerr=None,
100
+ xerr=None,
101
+ q_pred=None,
102
+ r_pred=None,
103
+ q_pol=None,
104
+ r_pol=None,
105
+ z_sld=None,
106
+ sld_pred=None,
107
+ sld_pol=None,
108
+ plot_sld_profile=False,
109
+ figsize=None,
110
+ logx=False,
111
+ logy=True,
112
+ x_ticks_log=None,
113
+ y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
114
+ q_label=r'q [$\mathrm{\AA^{-1}}$]',
115
+ r_label='R(q)',
116
+ z_label=r'z [$\mathrm{\AA}$]',
117
+ sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
118
+ xlim=None,
119
+ axis_label_size=20,
120
+ tick_label_size=15,
121
+ legend_fontsize=14,
122
+ exp_style='auto',
123
+ exp_color='blue',
124
+ exp_facecolor='none',
125
+ exp_marker='o',
126
+ exp_ms=3,
127
+ exp_alpha=1.0,
128
+ exp_errcolor='purple',
129
+ exp_elinewidth=1.0,
130
+ exp_capsize=1.0,
131
+ exp_capthick=1.0,
132
+ exp_zorder=2,
133
+ pred_color='red',
134
+ pred_lw=2.0,
135
+ pred_ls='-',
136
+ pred_alpha=1.0,
137
+ pred_zorder=3,
138
+ pol_color='orange',
139
+ pol_lw=2.0,
140
+ pol_ls='--',
141
+ pol_alpha=1.0,
142
+ pol_zorder=4,
143
+ sld_pred_color='red',
144
+ sld_pred_lw=2.0,
145
+ sld_pred_ls='-',
146
+ sld_pol_color='orange',
147
+ sld_pol_lw=2.0,
148
+ sld_pol_ls='--',
149
+ exp_label='exp. data',
150
+ pred_label='prediction',
151
+ pol_label='polished prediction',
152
+ sld_pred_label='pred. SLD',
153
+ sld_pol_label='polished SLD',
154
+ legend=True,
155
+ legend_kwargs=None
156
+ ):
157
+
158
+ def _np(a):
159
+ return None if a is None else np.asarray(a)
160
+
161
+ def _mask(x, y):
162
+ m = np.isfinite(x) & np.isfinite(y)
163
+ if logx: m &= (x > 0.0)
164
+ if logy: m &= (y > 0.0)
165
+ return m
166
+
167
+ def _slice_sym_err(err, mask):
168
+ if err is None:
169
+ return None
170
+ if np.isscalar(err):
171
+ return err
172
+ e = np.asarray(err)
173
+ if e.ndim != 1:
174
+ raise ValueError("Errors must be scalar or 1-D array.")
175
+ return e[mask]
176
+
177
+ q_exp, r_exp, yerr, xerr = _np(q_exp), _np(r_exp), _np(yerr), _np(xerr)
178
+ q_pred, r_pred = _np(q_pred), _np(r_pred)
179
+ q_pol, r_pol = _np(q_pol), _np(r_pol)
180
+ z_sld, sld_pred, sld_pol = _np(z_sld), _np(sld_pred), _np(sld_pol)
181
+
182
+ # Figure & axes
183
+ if figsize is None:
184
+ figsize = (12, 6) if plot_sld_profile else (6, 6)
185
+ if plot_sld_profile:
186
+ fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
187
+ else:
188
+ fig, ax_r = plt.subplots(1, 1, figsize=figsize)
189
+ ax_s = None
190
+
191
+ # Apply x-limits (right-only or both)
192
+ if xlim is not None:
193
+ if np.isscalar(xlim):
194
+ cur_left, _ = ax_r.get_xlim()
195
+ if logx and cur_left <= 0:
196
+ cur_left = 1e-12
197
+ ax_r.set_xlim(left=cur_left, right=float(xlim))
198
+ else:
199
+ xmin, xmax = xlim
200
+ if logx and xmin is not None and xmin <= 0:
201
+ raise ValueError("For log-x, xmin must be > 0.")
202
+ ax_r.set_xlim(left=xmin, right=xmax)
203
+
204
+ # Axis scales / labels / ticks
205
+ if logx: ax_r.set_xscale('log')
206
+ if logy: ax_r.set_yscale('log')
207
+
208
+ ax_r.set_xlabel(q_label, fontsize=axis_label_size)
209
+ ax_r.set_ylabel(r_label, fontsize=axis_label_size)
210
+ ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
211
+ ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
212
+ if logx and x_ticks_log is not None:
213
+ ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
214
+ if logy and y_ticks_log is not None:
215
+ ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))
216
+
217
+ handles = []
218
+
219
+ # Experimental plot
220
+ exp_handle = None
221
+ if q_exp is not None and r_exp is not None:
222
+ m = _mask(q_exp, r_exp)
223
+ style = exp_style if exp_style != 'auto' else ('errorbar' if yerr is not None else 'scatter')
224
+
225
+ if style == 'errorbar' and (yerr is not None):
226
+ yerr_m = _slice_sym_err(yerr, m)
227
+ xerr_m = _slice_sym_err(xerr, m)
228
+ ax_r.errorbar(
229
+ q_exp[m], r_exp[m], yerr=yerr_m, xerr=xerr_m,
230
+ color=exp_color, ecolor=exp_errcolor,
231
+ elinewidth=exp_elinewidth, capsize=exp_capsize,
232
+ capthick=(exp_elinewidth if exp_capthick is None else exp_capthick),
233
+ marker=exp_marker, linestyle='none', markersize=exp_ms,
234
+ markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
235
+ alpha=exp_alpha, zorder=exp_zorder, label=None
236
+ )
237
+ exp_handle = Line2D([], [], color=exp_color, marker=exp_marker,
238
+ linestyle='none', markersize=exp_ms,
239
+ markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
240
+ alpha=exp_alpha, label=exp_label)
241
+ elif style == 'scatter':
242
+ ax_r.scatter(
243
+ q_exp[m], r_exp[m],
244
+ s=exp_ms**2, marker=exp_marker,
245
+ facecolors=exp_facecolor, edgecolors=exp_color,
246
+ alpha=exp_alpha, zorder=exp_zorder, label=None
247
+ )
248
+ exp_handle = Line2D([], [], color=exp_color, marker=exp_marker,
249
+ linestyle='none', markersize=exp_ms,
250
+ markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
251
+ alpha=exp_alpha, label=exp_label)
252
+ else: # 'line'
253
+ ln = ax_r.plot(
254
+ q_exp[m], r_exp[m], color=exp_color, lw=1.0, ls='-',
255
+ alpha=exp_alpha, zorder=exp_zorder, label=exp_label
256
+ )[0]
257
+ exp_handle = ln
258
+
259
+ if exp_handle is not None:
260
+ handles.append(exp_handle)
261
+
262
+ # Predicted line
263
+ pred_handle = None
264
+ if q_pred is not None and r_pred is not None:
265
+ m = _mask(q_pred, r_pred)
266
+ pred_handle = ax_r.plot(
267
+ q_pred[m], r_pred[m],
268
+ color=pred_color, lw=pred_lw, ls=pred_ls,
269
+ alpha=pred_alpha, zorder=pred_zorder, label=pred_label
270
+ )[0]
271
+ handles.append(pred_handle)
272
+
273
+ # Polished line
274
+ pol_handle = None
275
+ if q_pol is not None and r_pol is not None:
276
+ m = _mask(q_pol, r_pol)
277
+ pol_handle = ax_r.plot(
278
+ q_pol[m], r_pol[m],
279
+ color=pol_color, lw=pol_lw, ls=pol_ls,
280
+ alpha=pol_alpha, zorder=pol_zorder, label=pol_label
281
+ )[0]
282
+ handles.append(pol_handle)
283
+
284
+ if legend and handles:
285
+ lk = {} if legend_kwargs is None else dict(legend_kwargs)
286
+ ax_r.legend(handles=handles,
287
+ labels=[h.get_label() for h in handles],
288
+ fontsize=legend_fontsize, loc='best', **lk)
289
+
290
+ # SLD panel (optional)
291
+ if ax_s is not None:
292
+ ax_s.set_xlabel(z_label, fontsize=axis_label_size)
293
+ ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
294
+ ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
295
+ ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)
296
+
297
+ if z_sld is not None and sld_pred is not None:
298
+ ax_s.plot(z_sld, sld_pred,
299
+ color=sld_pred_color, lw=sld_pred_lw, ls=sld_pred_ls,
300
+ label=sld_pred_label)
301
+ if z_sld is not None and sld_pol is not None:
302
+ ax_s.plot(z_sld, sld_pol,
303
+ color=sld_pol_color, lw=sld_pol_lw, ls=sld_pol_ls,
304
+ label=sld_pol_label)
305
+
306
+ if legend:
307
+ ax_s.legend(fontsize=legend_fontsize, loc='best', **(legend_kwargs or {}))
308
+
309
+ plt.tight_layout()
310
+ return (fig, (ax_r, ax_s)) if ax_s is not None else (fig, ax_r)
311
+
312
+
313
+ def plot_reflectivity_multi(
314
+ *,
315
+ rq_series,
316
+ sld_series=None,
317
+ plot_sld_profile=False,
318
+ figsize=None,
319
+ logx=False,
320
+ logy=True,
321
+ xlim=None,
322
+ x_ticks_log=None,
323
+ y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
324
+ q_label=r'q [$\mathrm{\AA^{-1}}$]',
325
+ r_label='R(q)',
326
+ z_label=r'z [$\mathrm{\AA}$]',
327
+ sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
328
+ axis_label_size=20,
329
+ tick_label_size=15,
330
+ legend=True,
331
+ legend_fontsize=12,
332
+ legend_kwargs=None,
333
+ ):
334
+ """
335
+ Plot multiple R(q) series (and optional SLD lines) with per-series styling.
336
+
337
+ rq_series: list of dicts, each with:
338
+ required:
339
+ - x: 1D array
340
+ - y: 1D array
341
+ optional (per series):
342
+ - kind: 'errorbar' | 'scatter' | 'line' (default 'line')
343
+ - label: str
344
+ - color: str
345
+ - alpha: float (0..1)
346
+ - zorder: int
347
+ # scatter / marker for errorbar:
348
+ - marker: str (default 'o')
349
+ - ms: float (marker size in pt; for scatter internally converted to s=ms**2)
350
+ - facecolor: str (scatter marker face)
351
+ # errorbar only:
352
+ - yerr: scalar or 1D array
353
+ - xerr: scalar or 1D array
354
+ - ecolor: str
355
+ - elinewidth: float
356
+ - capsize: float
357
+ - capthick: float
358
+ # line only:
359
+ - lw: float
360
+ - ls: str (e.g. '-', '--', ':')
361
+
362
+ sld_series: list of dicts (only lines), each with:
363
+ - x: 1D array (z-axis)
364
+ - y: 1D array (SLD)
365
+ - label: str (optional)
366
+ - color: str (optional)
367
+ - lw: float (optional)
368
+ - ls: str (optional)
369
+ - alpha: float (optional)
370
+ - zorder: int (optional)
371
+ """
372
+
373
+ def _np(a): return None if a is None else np.asarray(a)
374
+
375
+ def _mask(x, y):
376
+ m = np.isfinite(x) & np.isfinite(y)
377
+ if logx: m &= (x > 0.0)
378
+ if logy: m &= (y > 0.0)
379
+ return m
380
+
381
+ # Figure & axes
382
+ if figsize is None:
383
+ figsize = (12, 6) if plot_sld_profile else (6, 6)
384
+ if plot_sld_profile:
385
+ fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
386
+ else:
387
+ fig, ax_r = plt.subplots(1, 1, figsize=figsize)
388
+ ax_s = None
389
+
390
+ # Axis scales / labels / ticks
391
+ if logx: ax_r.set_xscale('log')
392
+ if logy: ax_r.set_yscale('log')
393
+
394
+ ax_r.set_xlabel(q_label, fontsize=axis_label_size)
395
+ ax_r.set_ylabel(r_label, fontsize=axis_label_size)
396
+ ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
397
+ ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
398
+ if logx and x_ticks_log is not None:
399
+ ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
400
+ if logy and y_ticks_log is not None:
401
+ ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))
402
+
403
+ # Apply x-limits (right-only or both)
404
+ if xlim is not None:
405
+ if np.isscalar(xlim):
406
+ left, _ = ax_r.get_xlim()
407
+ if logx and left <= 0:
408
+ left = 1e-12
409
+ ax_r.set_xlim(left=left, right=float(xlim))
410
+ else:
411
+ xmin, xmax = xlim
412
+ if logx and xmin is not None and xmin <= 0:
413
+ raise ValueError("For log-x, xmin must be > 0.")
414
+ ax_r.set_xlim(left=xmin, right=xmax)
415
+
416
+ # Plot all R(q) series in the given order (order = legend order)
417
+ handles = []
418
+ for s in rq_series:
419
+ kind = s.get('kind', 'line')
420
+ x = _np(s.get('x'))
421
+ y = _np(s.get('y'))
422
+ if x is None or y is None:
423
+ continue
424
+
425
+ label = s.get('label', None)
426
+ color = s.get('color', None)
427
+ alpha = s.get('alpha', 1.0)
428
+ zord = s.get('zorder', None)
429
+ ms = s.get('ms', 5.0)
430
+ marker = s.get('marker', 'o')
431
+
432
+ m = _mask(x, y)
433
+
434
+ if kind == 'errorbar':
435
+ yerr = s.get('yerr', None)
436
+ xerr = s.get('xerr', None)
437
+ ecolor = s.get('ecolor', color)
438
+ elinewidth = s.get('elinewidth', 1.0)
439
+ capsize = s.get('capsize', 0.0)
440
+ capthick = s.get('capthick', elinewidth)
441
+
442
+ # Symmetric error input: scalar or 1D
443
+ def _slice_sym(err):
444
+ if err is None: return None
445
+ if np.isscalar(err): return err
446
+ arr = np.asarray(err)
447
+ if arr.ndim != 1:
448
+ raise ValueError("For symmetric error bars, provide scalar or 1-D array.")
449
+ return arr[m]
450
+
451
+ yerr_m = _slice_sym(yerr)
452
+ xerr_m = _slice_sym(xerr)
453
+
454
+ ax_r.errorbar(
455
+ x[m], y[m], yerr=yerr_m, xerr=xerr_m,
456
+ color=color, ecolor=ecolor,
457
+ elinewidth=elinewidth, capsize=capsize, capthick=capthick,
458
+ marker=marker, linestyle='none', markersize=ms,
459
+ markerfacecolor=s.get('facecolor', 'none'),
460
+ markeredgecolor=color,
461
+ alpha=alpha, zorder=zord, label=None
462
+ )
463
+
464
+ h = Line2D([], [], color=color, marker=marker, linestyle='none',
465
+ markersize=ms, markerfacecolor=s.get('facecolor','none'),
466
+ markeredgecolor=color, alpha=alpha, label=label)
467
+ if label is not None:
468
+ handles.append(h)
469
+
470
+ elif kind == 'scatter':
471
+ facecolor = s.get('facecolor', 'none')
472
+ sc = ax_r.scatter(
473
+ x[m], y[m], s=ms**2, marker=marker,
474
+ facecolors=facecolor, edgecolors=color,
475
+ alpha=alpha, zorder=zord, label=None
476
+ )
477
+ h = Line2D([], [], color=color, marker=marker, linestyle='none',
478
+ markersize=ms, markerfacecolor=facecolor,
479
+ markeredgecolor=color, alpha=alpha, label=label)
480
+ if label is not None:
481
+ handles.append(h)
482
+
483
+ else: # 'line'
484
+ lw = s.get('lw', 2.0)
485
+ ls = s.get('ls', '-')
486
+ line = ax_r.plot(
487
+ x[m], y[m],
488
+ color=color, lw=lw, ls=ls,
489
+ alpha=alpha, zorder=zord, label=label
490
+ )[0]
491
+ if label is not None:
492
+ handles.append(line)
493
+
494
+ if legend and handles:
495
+ lk = {} if legend_kwargs is None else dict(legend_kwargs)
496
+ ax_r.legend(handles=handles,
497
+ labels=[h.get_label() for h in handles],
498
+ fontsize=legend_fontsize, loc='best', **lk)
499
+
500
+ # Optional SLD panel
501
+ if plot_sld_profile:
502
+ ax_s.set_xlabel(z_label, fontsize=axis_label_size)
503
+ ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
504
+ ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
505
+ ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)
506
+
507
+ if sld_series:
508
+ for s in sld_series:
509
+ zx = _np(s.get('x')); zy = _np(s.get('y'))
510
+ if zx is None or zy is None:
511
+ continue
512
+ label = s.get('label', None)
513
+ color = s.get('color', None)
514
+ lw = s.get('lw', 2.0)
515
+ ls = s.get('ls', '-')
516
+ alpha = s.get('alpha', 1.0)
517
+ zord = s.get('zorder', None)
518
+ ax_s.plot(zx, zy, color=color, lw=lw, ls=ls, alpha=alpha, zorder=zord, label=label)
519
+
520
+ if legend:
521
+ ax_s.legend(fontsize=legend_fontsize, loc='best', **(legend_kwargs or {}))
522
+
523
+ plt.tight_layout()
524
+ return (fig, (ax_r, ax_s)) if plot_sld_profile else (fig, ax_r)
@@ -0,0 +1,7 @@
1
+ from reflectorch.inference.preprocess_exp.preprocess import (
2
+ standard_preprocessing,
3
+ StandardPreprocessing,
4
+ )
5
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
+ from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
7
+ from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction
@@ -0,0 +1,36 @@
1
+ import numpy as np
2
+
3
+
4
+ def apply_attenuation_correction(
5
+ intensity: np.ndarray,
6
+ attenuation: np.ndarray,
7
+ scattering_angle: np.ndarray = None,
8
+ correct_discontinuities: bool = True
9
+ ) -> np.ndarray:
10
+ """Applies attenuation correction to experimental reflectivity curves
11
+
12
+ Args:
13
+ intensity (np.ndarray): intensities of an experimental reflectivity curve
14
+ attenuation (np.ndarray): attenuation factors for each measured point
15
+ scattering_angle (np.ndarray, optional): scattering angles of the measured points. Defaults to None.
16
+ correct_discontinuities (bool, optional): whether to correct discontinuities in the measured curves. Defaults to True.
17
+
18
+ Returns:
19
+ np.ndarray: the corrected reflectivity curve
20
+ """
21
+ intensity = intensity / attenuation
22
+ if correct_discontinuities:
23
+ if scattering_angle is None:
24
+ raise ValueError("correct_discontinuities options requires scattering_angle, but scattering_angle is None.")
25
+ intensity = apply_discontinuities_correction(intensity, scattering_angle)
26
+ return intensity
27
+
28
+
29
+ def apply_discontinuities_correction(intensity: np.ndarray, scattering_angle: np.ndarray) -> np.ndarray:
30
+ intensity = intensity.copy()
31
+ diff_angle = np.diff(scattering_angle)
32
+ for i in range(len(diff_angle)):
33
+ if diff_angle[i] == 0:
34
+ factor = intensity[i] / intensity[i + 1]
35
+ intensity[(i + 1):] *= factor
36
+ return intensity
@@ -0,0 +1,31 @@
1
+ import numpy as np
2
+
3
+ from reflectorch.utils import angle_to_q
4
+
5
+
6
+ def cut_curve(q: np.ndarray, curve: np.ndarray, max_q: float, max_angle: float, wavelength: float):
7
+ """Cuts an experimental reflectivity curve at a maximum q position
8
+
9
+ Args:
10
+ q (np.ndarray): the array of q points
11
+ curve (np.ndarray): the experimental reflectivity curve
12
+ max_q (float): the maximum q value at which the curve is cut
13
+ max_angle (float): the maximum scattering angle at which the curve is cut; only used if max_q is not provided
14
+ wavelength (float): the wavelength of the beam
15
+
16
+ Returns:
17
+ tuple: the q array after cutting, the reflectivity curve after cutting, and the ratio between the maximum q after cutting and before cutting
18
+ """
19
+ if max_angle is None and max_q is None:
20
+ q_ratio = 1.
21
+ else:
22
+ if max_q is None:
23
+ max_q = angle_to_q(max_angle, wavelength)
24
+
25
+ q_ratio = max_q / q.max()
26
+
27
+ if q_ratio < 1.:
28
+ idx = np.argmax(q > max_q)
29
+ q = q[:idx] / q_ratio
30
+ curve = curve[:idx]
31
+ return q, curve, q_ratio
@@ -0,0 +1,81 @@
1
+ try:
2
+ from typing import Literal
3
+ except ImportError:
4
+ from typing_extensions import Literal
5
+
6
+ import numpy as np
7
+ from scipy.special import erf
8
+
9
+ __all__ = [
10
+ "apply_footprint_correction",
11
+ "remove_footprint_correction",
12
+ "BEAM_SHAPE",
13
+ ]
14
+
15
+
16
+ BEAM_SHAPE = Literal["gauss", "box"]
17
+
18
+
19
+ def apply_footprint_correction(
20
+ intensity: np.ndarray,
21
+ scattering_angle: np.ndarray,
22
+ beam_width: float,
23
+ sample_length: float,
24
+ beam_shape: BEAM_SHAPE = "gauss",
25
+ ) -> np.ndarray:
26
+ """Applies footprint correction to an experimental reflectivity curve
27
+
28
+ Args:
29
+ intensity (np.ndarray): reflectivity curve
30
+ scattering_angle (np.ndarray): array of scattering angles
31
+ beam_width (float): the beam width
32
+ sample_length (float): the sample length
33
+ beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".
34
+
35
+ Returns:
36
+ np.ndarray: the footprint corrected reflectivity curve
37
+ """
38
+ factors = _get_factors_by_beam_shape(
39
+ scattering_angle, beam_width, sample_length, beam_shape
40
+ )
41
+ return intensity.copy() * factors
42
+
43
+
44
+ def remove_footprint_correction(
45
+ intensity: np.ndarray,
46
+ scattering_angle: np.ndarray,
47
+ beam_width: float,
48
+ sample_length: float,
49
+ beam_shape: BEAM_SHAPE = "gauss",
50
+ ):
51
+ factors = _get_factors_by_beam_shape(
52
+ scattering_angle, beam_width, sample_length, beam_shape
53
+ )
54
+ return intensity.copy() / factors
55
+
56
+
57
+ def _get_factors_by_beam_shape(
58
+ scattering_angle: np.ndarray, beam_width: float, sample_length: float, beam_shape: BEAM_SHAPE
59
+ ):
60
+ if beam_shape == "gauss":
61
+ return gaussian_factors(scattering_angle, beam_width, sample_length)
62
+ elif beam_shape == "box":
63
+ return box_factors(scattering_angle, beam_width, sample_length)
64
+ else:
65
+ raise ValueError("invalid beam shape")
66
+
67
+
68
+ def box_factors(scattering_angle, beam_width, sample_length):
69
+ max_angle = 2 * np.arcsin(beam_width / sample_length) / np.pi * 180
70
+ ratios = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
71
+ ones = np.ones_like(scattering_angle)
72
+ return np.where(scattering_angle < max_angle, ones * ratios, ones)
73
+
74
+
75
+ def gaussian_factors(scattering_angle, beam_width, sample_length):
76
+ ratio = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
77
+ return 1 / erf(np.sqrt(np.log(2)) / ratio)
78
+
79
+
80
+ def beam_footprint_ratio(scattering_angle, beam_width, sample_length):
81
+ return beam_width / sample_length / np.sin(scattering_angle / 2 * np.pi / 180)