reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (41):
  1. reflectorch/data_generation/__init__.py +4 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +91 -16
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +97 -43
  8. reflectorch/data_generation/reflectivity/__init__.py +53 -11
  9. reflectorch/data_generation/reflectivity/kinematical.py +4 -5
  10. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  11. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  12. reflectorch/data_generation/smearing.py +42 -11
  13. reflectorch/data_generation/utils.py +93 -18
  14. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  15. reflectorch/inference/inference_model.py +795 -159
  16. reflectorch/inference/loading_data.py +37 -0
  17. reflectorch/inference/plotting.py +517 -0
  18. reflectorch/inference/preprocess_exp/interpolation.py +5 -2
  19. reflectorch/inference/scipy_fitter.py +98 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +131 -23
  26. reflectorch/models/__init__.py +2 -1
  27. reflectorch/models/encoders/__init__.py +2 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  31. reflectorch/models/networks/__init__.py +2 -0
  32. reflectorch/models/networks/mlp_networks.py +331 -153
  33. reflectorch/models/networks/residual_net.py +31 -5
  34. reflectorch/runs/train.py +0 -1
  35. reflectorch/runs/utils.py +48 -11
  36. reflectorch/utils.py +30 -0
  37. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
  38. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
  39. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
  40. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
  41. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,37 @@
1
+ from pathlib import Path
2
+ import numpy as np
3
+
4
def load_mft_data(filepath):
    """
    Load q, reflectivity, reflectivity error, and q-resolution from an .mft file.

    The column-header line is the first line whose stripped text starts with
    ``q`` and which contains ``q_res``. Every later line that splits into
    exactly four whitespace-separated numeric fields is taken as one data
    row; all other lines are ignored.

    Parameters:
        filepath (str or Path): Path to the .mft file

    Returns:
        np.ndarray: array of shape (4, n_points) that unpacks as
        q, refl, refl_err, q_res

    Raises:
        ValueError: if the column-header line is missing, or if no valid
            data row follows it.
    """
    filepath = Path(filepath)

    with filepath.open('r', encoding='utf-8') as f:
        lines = f.readlines()

    # Use a default so a missing header does not leak a bare StopIteration.
    header_idx = next(
        (
            i for i, line in enumerate(lines)
            if line.strip().startswith('q') and 'q_res' in line
        ),
        None,
    )
    if header_idx is None:
        raise ValueError(
            f"No column header line (starting with 'q' and containing 'q_res') found in {filepath}"
        )

    data = []
    for line in lines[header_idx + 1:]:
        parts = line.split()
        if len(parts) != 4:
            continue
        try:
            # float() accepts both 'e' and 'E' exponents.
            data.append([float(p) for p in parts])
        except ValueError:
            # Four fields but not all numeric: not a data row.
            continue

    if not data:
        raise ValueError(f"No valid data found in {filepath}")

    data_array = np.array(data)
    return data_array.T
@@ -0,0 +1,517 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.ticker as mticker
4
+ from matplotlib.lines import Line2D
5
+
6
+
7
def print_prediction_results(prediction_dict, param_names=None, width=10, precision=3, header=True):
    """Print a text table of predicted (and, if present, polished) parameters.

    Args:
        prediction_dict: mapping read for the keys ``"param_names"``,
            ``"predicted_params_array"`` and ``"polished_params_array"``.
        param_names: row labels; when None they are taken from
            ``prediction_dict["param_names"]`` (empty if absent).
        width: width of each numeric column.
        precision: number of digits after the decimal point.
        header: whether to print the title row and the dashed rule.

    A value is printed as ``nan`` when its array is shorter than the list
    of names. The "Polished" column only appears when the dictionary holds
    a non-None ``"polished_params_array"``.
    """
    names = prediction_dict.get("param_names", []) if param_names is None else param_names

    predicted = np.asarray(prediction_dict.get("predicted_params_array", []), dtype=float)
    polished = prediction_dict.get("polished_params_array", None)
    if polished is not None:
        polished = np.asarray(polished, dtype=float)

    # The label column is at least 14 characters wide.
    label_width = max([14] + [len(str(n)) for n in names])

    def _cell(values, idx):
        # Right-aligned fixed-point number; NaN when the array is too short.
        value = values[idx] if idx < values.size else float("nan")
        return f"{value:>{width}.{precision}f}"

    if header:
        titles = ['Parameter'.ljust(label_width), 'Predicted'.rjust(width)]
        if polished is not None:
            titles.append('Polished'.rjust(width))
        banner = " ".join(titles)
        print(banner)
        print("-" * len(banner))

    for idx, name in enumerate(names):
        cells = [str(name).ljust(label_width), _cell(predicted, idx)]
        if polished is not None:
            cells.append(_cell(polished, idx))
        print(" ".join(cells))
34
+
35
+
36
def plot_prediction_results(
    prediction_dict: dict,
    q_exp: np.ndarray,
    curve_exp: np.ndarray,
    sigmas_exp: np.ndarray = None,
    logx=False,
):
    """Plot experimental data together with the curves stored in a prediction dictionary.

    Args:
        prediction_dict: mapping with the required keys ``'q_plot_pred'`` and
            ``'predicted_curve'`` and the optional keys ``'polished_curve'``,
            ``'predicted_sld_xaxis'``, ``'predicted_sld_profile'`` and
            ``'sld_profile_polished'``.
        q_exp: q values of the experimental curve.
        curve_exp: experimental reflectivity values.
        sigmas_exp: optional experimental uncertainties (drawn as y error bars).
        logx: whether the q axis is logarithmic.

    Returns:
        ``(fig, axes)`` exactly as returned by ``plot_reflectivity``: ``axes``
        is the pair ``(ax_r, ax_s)`` when an SLD panel is drawn, otherwise the
        single reflectivity axis.

    The SLD panel is drawn when the SLD x-axis and at least one SLD profile
    are present. If either profile is complex, the real parts are labelled
    "(Re)" and the imaginary parts are added as extra lines.
    """
    import warnings

    q_pred = prediction_dict['q_plot_pred']
    r_pred = prediction_dict['predicted_curve']
    r_pol = prediction_dict.get('polished_curve', None)

    # The polished curve carries no q axis of its own: pair it with whichever
    # known grid has the same length.
    q_pol = None
    if r_pol is not None:
        if len(r_pol) == len(q_pred):
            q_pol = q_pred
        elif len(r_pol) == len(q_exp):
            q_pol = q_exp
        else:
            warnings.warn(
                "polished_curve matches neither the prediction q grid nor q_exp; "
                "it is not plotted."
            )

    z_sld = prediction_dict.get('predicted_sld_xaxis', None)
    sld_pred_c = prediction_dict.get('predicted_sld_profile', None)
    sld_pol_c = prediction_dict.get('sld_profile_polished', None)

    # Normalise to arrays so .real/.imag are always available.
    if sld_pred_c is not None:
        sld_pred_c = np.asarray(sld_pred_c)
    if sld_pol_c is not None:
        sld_pol_c = np.asarray(sld_pol_c)

    plot_sld = (z_sld is not None) and (sld_pred_c is not None or sld_pol_c is not None)

    # Either profile may be complex; np.iscomplexobj(None) is False.
    sld_is_complex = np.iscomplexobj(sld_pred_c) or np.iscomplexobj(sld_pol_c)

    sld_pred_label = 'pred. SLD (Re)' if sld_is_complex else 'pred. SLD'
    sld_pol_label = 'polished SLD (Re)' if sld_is_complex else 'polished SLD'

    fig, axes = plot_reflectivity(
        q_exp=q_exp, r_exp=curve_exp, yerr=sigmas_exp,
        q_pred=q_pred, r_pred=r_pred,
        q_pol=q_pol, r_pol=r_pol,
        z_sld=z_sld,
        sld_pred=sld_pred_c.real if sld_pred_c is not None else None,
        sld_pol=sld_pol_c.real if sld_pol_c is not None else None,
        sld_pred_label=sld_pred_label,
        sld_pol_label=sld_pol_label,
        plot_sld_profile=plot_sld,
        logx=logx,
    )

    if sld_is_complex and plot_sld:
        _, ax_s = axes
        if sld_pred_c is not None:
            ax_s.plot(z_sld, sld_pred_c.imag, color='darkgreen', lw=2.0, ls='-', zorder=4, label='pred. SLD (Im)')
        if sld_pol_c is not None:
            ax_s.plot(z_sld, sld_pol_c.imag, color='cyan', lw=2.0, ls='--', zorder=5, label='polished SLD (Im)')
        # Rebuild the legend so it includes the imaginary-part lines.
        ax_s.legend(fontsize=14, frameon=True)

    return fig, axes
86
+
87
+
88
def plot_reflectivity(
    *,
    q_exp=None,
    r_exp=None,
    yerr=None,
    xerr=None,
    q_pred=None,
    r_pred=None,
    q_pol=None,
    r_pol=None,
    z_sld=None,
    sld_pred=None,
    sld_pol=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    xlim=None,
    axis_label_size=20,
    tick_label_size=15,
    legend_fontsize=14,
    exp_style='auto',
    exp_color='blue',
    exp_facecolor='none',
    exp_marker='o',
    exp_ms=3,
    exp_alpha=1.0,
    exp_errcolor='purple',
    exp_elinewidth=1.0,
    exp_capsize=1.0,
    exp_capthick=1.0,
    exp_zorder=2,
    pred_color='red',
    pred_lw=2.0,
    pred_ls='-',
    pred_alpha=1.0,
    pred_zorder=3,
    pol_color='orange',
    pol_lw=2.0,
    pol_ls='--',
    pol_alpha=1.0,
    pol_zorder=4,
    sld_pred_color='red',
    sld_pred_lw=2.0,
    sld_pred_ls='-',
    sld_pol_color='orange',
    sld_pol_lw=2.0,
    sld_pol_ls='--',
    exp_label='exp. data',
    pred_label='prediction',
    pol_label='polished prediction',
    sld_pred_label='pred. SLD',
    sld_pol_label='polished SLD',
    legend=True,
    legend_kwargs=None
):
    """Plot experimental, predicted and polished R(q) curves, optionally with an SLD panel.

    All arguments are keyword-only.

    Data:
        q_exp, r_exp: experimental curve; drawn only when both are given.
        yerr, xerr: symmetric errors for the experimental points, each a
            scalar or a 1-D array with one entry per point.
        q_pred, r_pred: predicted curve (line).
        q_pol, r_pol: polished curve (line).
        z_sld, sld_pred, sld_pol: SLD x-axis and profiles for the SLD panel.

    Layout:
        plot_sld_profile: add a second panel for the SLD profiles.
        figsize: figure size; defaults to (12, 6) with the SLD panel and
            (6, 6) without.
        logx, logy: logarithmic axes. Points that are non-finite, or
            non-positive on a logarithmic axis, are dropped before plotting.
        x_ticks_log, y_ticks_log: fixed major tick positions, used only when
            the corresponding axis is logarithmic.
        xlim: a scalar sets only the right limit and keeps the autoscaled
            left limit; a pair ``(xmin, xmax)`` sets both.
        q_label, r_label, z_label, sld_label, axis_label_size,
        tick_label_size: axis text and sizes.

    Styling:
        exp_style: 'auto', 'errorbar', 'scatter' or anything else for a line.
            'auto' means error bars when ``yerr`` is given, else scatter.
            'errorbar' without ``yerr`` falls through to the line style.
        exp_*, pred_*, pol_*, sld_pred_*, sld_pol_*: colours, markers, line
            widths/styles, alpha, z-order and labels of the respective curves.
        legend: whether to draw legends.
        legend_fontsize: default legend font size.
        legend_kwargs: extra keyword arguments for ``Axes.legend``; entries
            here override the default ``fontsize`` and ``loc``.

    Returns:
        ``(fig, (ax_r, ax_s))`` when the SLD panel is drawn, else
        ``(fig, ax_r)``.

    Raises:
        ValueError: if an error array is neither scalar nor 1-D, or if
            ``logx`` is set and ``xlim`` is a pair with ``xmin <= 0``.
    """

    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        # Keep points that are finite and positive where a log axis needs it.
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym_err(err, mask):
        if err is None:
            return None
        e = np.asarray(err)
        if e.ndim == 0:
            # Scalar error (possibly a 0-d array): applies to every point.
            return e.item()
        if e.ndim != 1:
            raise ValueError("Errors must be scalar or 1-D array.")
        return e[mask]

    q_exp, r_exp = _np(q_exp), _np(r_exp)
    q_pred, r_pred = _np(q_pred), _np(r_pred)
    q_pol, r_pol = _np(q_pol), _np(r_pol)
    z_sld, sld_pred, sld_pol = _np(z_sld), _np(sld_pred), _np(sld_pol)

    # Validate a two-sided xlim before any figure is created.
    xlim_pair = None
    if xlim is not None and not np.isscalar(xlim):
        xmin, xmax = xlim
        if logx and xmin is not None and xmin <= 0:
            raise ValueError("For log-x, xmin must be > 0.")
        xlim_pair = (xmin, xmax)

    # Legend defaults; user-supplied entries take precedence.
    legend_opts = {'fontsize': legend_fontsize, 'loc': 'best'}
    legend_opts.update(legend_kwargs or {})

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    handles = []

    # Experimental plot
    exp_handle = None
    if q_exp is not None and r_exp is not None:
        m = _mask(q_exp, r_exp)
        style = exp_style if exp_style != 'auto' else ('errorbar' if yerr is not None else 'scatter')

        if style == 'errorbar' and (yerr is not None):
            yerr_m = _slice_sym_err(yerr, m)
            xerr_m = _slice_sym_err(xerr, m)
            ax_r.errorbar(
                q_exp[m], r_exp[m], yerr=yerr_m, xerr=xerr_m,
                color=exp_color, ecolor=exp_errcolor,
                elinewidth=exp_elinewidth, capsize=exp_capsize,
                capthick=(exp_elinewidth if exp_capthick is None else exp_capthick),
                marker=exp_marker, linestyle='none', markersize=exp_ms,
                markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
                alpha=exp_alpha, zorder=exp_zorder, label=None
            )
            # Proxy handle so the legend shows only the marker.
            exp_handle = Line2D([], [], color=exp_color, marker=exp_marker,
                                linestyle='none', markersize=exp_ms,
                                markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
                                alpha=exp_alpha, label=exp_label)
        elif style == 'scatter':
            ax_r.scatter(
                q_exp[m], r_exp[m],
                s=exp_ms**2, marker=exp_marker,
                facecolors=exp_facecolor, edgecolors=exp_color,
                alpha=exp_alpha, zorder=exp_zorder, label=None
            )
            exp_handle = Line2D([], [], color=exp_color, marker=exp_marker,
                                linestyle='none', markersize=exp_ms,
                                markerfacecolor=exp_facecolor, markeredgecolor=exp_color,
                                alpha=exp_alpha, label=exp_label)
        else:  # 'line' (also reached for 'errorbar' without yerr)
            exp_handle = ax_r.plot(
                q_exp[m], r_exp[m], color=exp_color, lw=1.0, ls='-',
                alpha=exp_alpha, zorder=exp_zorder, label=exp_label
            )[0]

    if exp_handle is not None:
        handles.append(exp_handle)

    # Predicted line
    if q_pred is not None and r_pred is not None:
        m = _mask(q_pred, r_pred)
        pred_handle = ax_r.plot(
            q_pred[m], r_pred[m],
            color=pred_color, lw=pred_lw, ls=pred_ls,
            alpha=pred_alpha, zorder=pred_zorder, label=pred_label
        )[0]
        handles.append(pred_handle)

    # Polished line
    if q_pol is not None and r_pol is not None:
        m = _mask(q_pol, r_pol)
        pol_handle = ax_r.plot(
            q_pol[m], r_pol[m],
            color=pol_color, lw=pol_lw, ls=pol_ls,
            alpha=pol_alpha, zorder=pol_zorder, label=pol_label
        )[0]
        handles.append(pol_handle)

    # Apply x-limits after plotting so a scalar xlim keeps the autoscaled
    # left edge instead of the empty-axes default.
    if xlim is not None:
        if xlim_pair is None:
            cur_left, _ = ax_r.get_xlim()
            if logx and cur_left <= 0:
                cur_left = 1e-12
            ax_r.set_xlim(left=cur_left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xlim_pair[0], right=xlim_pair[1])

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **legend_opts)

    # SLD panel (optional)
    if ax_s is not None:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        if z_sld is not None and sld_pred is not None:
            ax_s.plot(z_sld, sld_pred,
                      color=sld_pred_color, lw=sld_pred_lw, ls=sld_pred_ls,
                      label=sld_pred_label)
        if z_sld is not None and sld_pol is not None:
            ax_s.plot(z_sld, sld_pol,
                      color=sld_pol_color, lw=sld_pol_lw, ls=sld_pol_ls,
                      label=sld_pol_label)

        # Only request a legend when there is something to label.
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**legend_opts)

    plt.tight_layout()
    return (fig, (ax_r, ax_s)) if ax_s is not None else (fig, ax_r)
304
+
305
+
306
def plot_reflectivity_multi(
    *,
    rq_series,
    sld_series=None,
    plot_sld_profile=False,
    figsize=None,
    logx=False,
    logy=True,
    xlim=None,
    x_ticks_log=None,
    y_ticks_log=(10.0 ** -np.arange(0, 12, 2)),
    q_label=r'q [$\mathrm{\AA^{-1}}$]',
    r_label='R(q)',
    z_label=r'z [$\mathrm{\AA}$]',
    sld_label=r'SLD [$10^{-6}\ \mathrm{\AA^{-2}}$]',
    axis_label_size=20,
    tick_label_size=15,
    legend=True,
    legend_fontsize=12,
    legend_kwargs=None,
):
    """
    Plot multiple R(q) series (and optional SLD lines) with per-series styling.

    rq_series: list of dicts, each with:
        required:
            - x: 1D array
            - y: 1D array
        optional (per series):
            - kind: 'errorbar' | 'scatter' | 'line' (default 'line')
            - label: str
            - color: str
            - alpha: float (0..1)
            - zorder: int
            # scatter / marker for errorbar:
            - marker: str (default 'o')
            - ms: float (marker size in pt; for scatter internally converted to s=ms**2)
            - facecolor: str (scatter marker face)
            # errorbar only:
            - yerr: scalar or 1D array
            - xerr: scalar or 1D array
            - ecolor: str
            - elinewidth: float
            - capsize: float
            - capthick: float
            # line only:
            - lw: float
            - ls: str (e.g. '-', '--', ':')

    sld_series: list of dicts (only lines), each with:
        - x: 1D array (z-axis)
        - y: 1D array (SLD)
        - label: str (optional)
        - color: str (optional)
        - lw: float (optional)
        - ls: str (optional)
        - alpha: float (optional)
        - zorder: int (optional)

    Other arguments:
        plot_sld_profile: add a second panel for ``sld_series``.
        figsize: defaults to (12, 6) with the SLD panel, (6, 6) without.
        logx, logy: logarithmic axes; non-finite points, and non-positive
            points on a logarithmic axis, are dropped before plotting.
        xlim: a scalar sets only the right limit and keeps the autoscaled
            left limit; a pair (xmin, xmax) sets both.
        x_ticks_log, y_ticks_log: fixed major ticks for logarithmic axes.
        legend, legend_fontsize, legend_kwargs: legend control; entries of
            legend_kwargs override the default fontsize and loc.

    Series are drawn in the given order, which is also the legend order.
    Only series with a label appear in the legend. Series missing x or y
    are skipped.

    Returns:
        (fig, (ax_r, ax_s)) when plot_sld_profile is true, else (fig, ax_r).

    Raises:
        ValueError: if an error array is neither scalar nor 1-D, or if logx
            is set and xlim is a pair with xmin <= 0.
    """

    def _np(a):
        return None if a is None else np.asarray(a)

    def _mask(x, y):
        m = np.isfinite(x) & np.isfinite(y)
        if logx:
            m &= (x > 0.0)
        if logy:
            m &= (y > 0.0)
        return m

    def _slice_sym(err, mask):
        # Symmetric error input: scalar (incl. 0-d array) or 1-D array.
        if err is None:
            return None
        arr = np.asarray(err)
        if arr.ndim == 0:
            return arr.item()
        if arr.ndim != 1:
            raise ValueError("For symmetric error bars, provide scalar or 1-D array.")
        return arr[mask]

    # Validate a two-sided xlim before any figure is created.
    xlim_pair = None
    if xlim is not None and not np.isscalar(xlim):
        xmin, xmax = xlim
        if logx and xmin is not None and xmin <= 0:
            raise ValueError("For log-x, xmin must be > 0.")
        xlim_pair = (xmin, xmax)

    # Legend defaults; user-supplied entries take precedence.
    legend_opts = {'fontsize': legend_fontsize, 'loc': 'best'}
    legend_opts.update(legend_kwargs or {})

    # Figure & axes
    if figsize is None:
        figsize = (12, 6) if plot_sld_profile else (6, 6)
    if plot_sld_profile:
        fig, (ax_r, ax_s) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax_r = plt.subplots(1, 1, figsize=figsize)
        ax_s = None

    # Axis scales / labels / ticks
    if logx:
        ax_r.set_xscale('log')
    if logy:
        ax_r.set_yscale('log')

    ax_r.set_xlabel(q_label, fontsize=axis_label_size)
    ax_r.set_ylabel(r_label, fontsize=axis_label_size)
    ax_r.tick_params(axis='both', which='major', labelsize=tick_label_size)
    ax_r.tick_params(axis='both', which='minor', labelsize=tick_label_size)
    if logx and x_ticks_log is not None:
        ax_r.xaxis.set_major_locator(mticker.FixedLocator(x_ticks_log))
    if logy and y_ticks_log is not None:
        ax_r.yaxis.set_major_locator(mticker.FixedLocator(y_ticks_log))

    # Plot all R(q) series in the given order (order = legend order)
    handles = []
    for s in rq_series:
        kind = s.get('kind', 'line')
        x = _np(s.get('x'))
        y = _np(s.get('y'))
        if x is None or y is None:
            continue

        label = s.get('label', None)
        color = s.get('color', None)
        alpha = s.get('alpha', 1.0)
        zord = s.get('zorder', None)
        ms = s.get('ms', 5.0)
        marker = s.get('marker', 'o')

        m = _mask(x, y)

        if kind == 'errorbar':
            elinewidth = s.get('elinewidth', 1.0)
            facecolor = s.get('facecolor', 'none')

            ax_r.errorbar(
                x[m], y[m],
                yerr=_slice_sym(s.get('yerr', None), m),
                xerr=_slice_sym(s.get('xerr', None), m),
                color=color, ecolor=s.get('ecolor', color),
                elinewidth=elinewidth, capsize=s.get('capsize', 0.0),
                capthick=s.get('capthick', elinewidth),
                marker=marker, linestyle='none', markersize=ms,
                markerfacecolor=facecolor,
                markeredgecolor=color,
                alpha=alpha, zorder=zord, label=None
            )

            # Proxy handle so the legend shows only the marker.
            h = Line2D([], [], color=color, marker=marker, linestyle='none',
                       markersize=ms, markerfacecolor=facecolor,
                       markeredgecolor=color, alpha=alpha, label=label)
            if label is not None:
                handles.append(h)

        elif kind == 'scatter':
            facecolor = s.get('facecolor', 'none')
            ax_r.scatter(
                x[m], y[m], s=ms**2, marker=marker,
                facecolors=facecolor, edgecolors=color,
                alpha=alpha, zorder=zord, label=None
            )
            h = Line2D([], [], color=color, marker=marker, linestyle='none',
                       markersize=ms, markerfacecolor=facecolor,
                       markeredgecolor=color, alpha=alpha, label=label)
            if label is not None:
                handles.append(h)

        else:  # 'line'
            line = ax_r.plot(
                x[m], y[m],
                color=color, lw=s.get('lw', 2.0), ls=s.get('ls', '-'),
                alpha=alpha, zorder=zord, label=label
            )[0]
            if label is not None:
                handles.append(line)

    # Apply x-limits after plotting so a scalar xlim keeps the autoscaled
    # left edge instead of the empty-axes default.
    if xlim is not None:
        if xlim_pair is None:
            left, _ = ax_r.get_xlim()
            if logx and left <= 0:
                left = 1e-12
            ax_r.set_xlim(left=left, right=float(xlim))
        else:
            ax_r.set_xlim(left=xlim_pair[0], right=xlim_pair[1])

    if legend and handles:
        ax_r.legend(handles=handles,
                    labels=[h.get_label() for h in handles],
                    **legend_opts)

    # Optional SLD panel
    if plot_sld_profile:
        ax_s.set_xlabel(z_label, fontsize=axis_label_size)
        ax_s.set_ylabel(sld_label, fontsize=axis_label_size)
        ax_s.tick_params(axis='both', which='major', labelsize=tick_label_size)
        ax_s.tick_params(axis='both', which='minor', labelsize=tick_label_size)

        for s in (sld_series or []):
            zx = _np(s.get('x'))
            zy = _np(s.get('y'))
            if zx is None or zy is None:
                continue
            ax_s.plot(
                zx, zy,
                color=s.get('color', None),
                lw=s.get('lw', 2.0),
                ls=s.get('ls', '-'),
                alpha=s.get('alpha', 1.0),
                zorder=s.get('zorder', None),
                label=s.get('label', None),
            )

        # Only request a legend when there is something to label.
        if legend and ax_s.get_legend_handles_labels()[0]:
            ax_s.legend(**legend_opts)

    plt.tight_layout()
    return (fig, (ax_r, ax_s)) if plot_sld_profile else (fig, ax_r)
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
 
3
3
 
4
- def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10):
4
+ def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10, logspace = False):
5
5
  """Interpolate data on a base 10 logarithmic scale
6
6
 
7
7
  Args:
@@ -13,4 +13,7 @@ def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10):
13
13
  Returns:
14
14
  array-like: interpolated reflectivity curve
15
15
  """
16
- return 10 ** np.interp(q_interp, q, np.log10(np.clip(reflectivity, min_value, None)))
16
+ if not(logspace):
17
+ return 10 ** np.interp(q_interp, q, np.log10(np.clip(reflectivity, min_value, None)))
18
+ else:
19
+ return 10 ** np.interp(np.log10(q_interp), np.log10(q), np.log10(np.clip(reflectivity, min_value, None)))