xarpes 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xarpes/mdcs.py ADDED
@@ -0,0 +1,1035 @@
1
+ # Copyright (C) 2025 xARPES Developers
2
+ # This program is free software under the terms of the GNU GPLv3 license.
3
+
4
+ # get_ax_fig_plt and add_fig_kwargs originate from pymatgen/util/plotting.py.
5
+ # Copyright (C) 2011-2024 Shyue Ping Ong and the pymatgen Development Team
6
+ # Pymatgen is released under the MIT License.
7
+
8
+ # See also abipy/tools/plotting.py.
9
+ # Copyright (C) 2021 Matteo Giantomassi and the AbiPy Group
10
+ # AbiPy is free software under the terms of the GNU GPLv2 license.
11
+
12
+ """File containing the MDCs class."""
13
+
14
+ import numpy as np
15
+ from .plotting import get_ax_fig_plt, add_fig_kwargs
16
+ from .functions import extend_function
17
+ from .constants import KILO
18
+
19
+ class MDCs:
20
+ r"""
21
+ Container for momentum distribution curves (MDCs) and their fits.
22
+
23
+ This class stores the MDC intensity maps, angular and energy grids, and
24
+ the aggregated fit results produced by :meth:`fit_selection`.
25
+
26
+ Parameters
27
+ ----------
28
+ intensities : ndarray
29
+ MDC intensity data. Typically a 2D array with shape
30
+ ``(n_energy, n_angle)`` or a 1D array for a single curve.
31
+ angles : ndarray
32
+ Angular grid corresponding to the MDCs [degrees].
33
+ angle_resolution : float
34
+ Angular step size or effective angular resolution [degrees].
35
+ enel : ndarray or float
36
+ Electron binding energies of the MDC slices [eV].
37
+ Can be a scalar for a single MDC.
38
+ hnuminPhi : float
39
+ Photon energy minus work function, used to convert ``enel`` to
40
+ kinetic energy [eV].
41
+
42
+ Attributes
43
+ ----------
44
+ intensities : ndarray
45
+ MDC intensity data (same object as passed to the constructor).
46
+ angles : ndarray
47
+ Angular grid [degrees].
48
+ angle_resolution : float
49
+ Angular step size or resolution [degrees].
50
+ enel : ndarray or float
51
+ Electron binding energies [eV], as given at construction.
52
+ ekin : ndarray or float
53
+ Kinetic energies [eV], computed as ``enel + hnuminPhi``.
54
+ hnuminPhi : float
55
+ Photon energy minus work function [eV].
56
+ ekin_range : ndarray
57
+ Kinetic-energy values of the slices that were actually fitted.
58
+ Set by :meth:`fit_selection`.
59
+ individual_properties : dict
60
+ Nested mapping of fitted parameters and their uncertainties for each
61
+ component and each energy slice. Populated by :meth:`fit_selection`.
62
+
63
+ Notes
64
+ -----
65
+ After calling :meth:`fit_selection`, :attr:`individual_properties` has the
66
+ structure::
67
+
68
+ {
69
+ label: {
70
+ class_name: {
71
+ 'label': label,
72
+ '_class': class_name,
73
+ param: [values per energy slice],
74
+ param_sigma: [1σ per slice or None],
75
+ ...
76
+ }
77
+ }
78
+ }
79
+
80
+ where ``param`` is typically one of ``'offset'``, ``'slope'``,
81
+ ``'amplitude'``, ``'peak'``, ``'broadening'``, and ``param_sigma`` stores
82
+ the corresponding uncertainty for each slice.
83
+
84
+ """
85
+
86
def __init__(self, intensities, angles, angle_resolution, enel, hnuminPhi):
    """Store the raw MDC data and clear the fit-result slots.

    The inputs are stored by reference (no copies are made). The fit
    results stay ``None`` until :meth:`fit_selection` populates them.
    """
    # Raw inputs, exposed through properties.
    self._intensities, self._angles = intensities, angles
    self._angle_resolution = angle_resolution
    self._enel, self._hnuminPhi = enel, hnuminPhi

    # Result slots written by fit_selection(); None means "not fitted yet".
    self._ekin_range = self._individual_properties = None
97
+
98
+ # -------------------- Immutable physics inputs --------------------
99
+
100
@property
def angles(self):
    """Angular grid of the MDCs [degrees], exactly as passed in."""
    angle_grid = self._angles
    return angle_grid
104
+
105
@property
def angle_resolution(self):
    """Angular step size or effective angular resolution [degrees]."""
    resolution = self._angle_resolution
    return resolution
109
+
110
@property
def enel(self):
    """Electron binding energies of the MDC slices [eV]. Read-only."""
    binding_energies = self._enel
    return binding_energies

@enel.setter
def enel(self, _value):
    # Binding energies are fixed at construction; reassignment is refused.
    raise AttributeError("`enel` is read-only; set it via the constructor.")
118
+
119
@property
def hnuminPhi(self):
    """Photon energy minus work function [eV]. Read-only.

    Added to ``enel`` to obtain the kinetic energies exposed as ``ekin``.
    """
    offset = self._hnuminPhi
    return offset

@hnuminPhi.setter
def hnuminPhi(self, _value):
    # Fixed at construction, like the binding energies it offsets.
    raise AttributeError("`hnuminPhi` is read-only; set it via the constructor.")
127
+
128
@property
def ekin(self):
    """Kinetic energies [eV], recomputed on every access as ``enel + hnuminPhi``."""
    binding, offset = self._enel, self._hnuminPhi
    return binding + offset

@ekin.setter
def ekin(self, _value):
    # Derived quantity: change enel/hnuminPhi (via the constructor) instead.
    raise AttributeError("`ekin` is derived and read-only.")
136
+
137
+ # -------------------- Data arrays --------------------
138
+
139
@property
def intensities(self):
    """MDC intensity data.

    A 2D array of shape ``(n_energy, n_angle)`` for a collection of MDCs,
    or a 1D array for a single curve, exactly as passed to the
    constructor.
    """
    return self._intensities

@intensities.setter
def intensities(self, x):
    # Unlike the energy axes, the intensity data may be replaced.
    self._intensities = x
147
+
148
+ # -------------------- Results populated by fit_selection --------------------
149
+
150
@property
def ekin_range(self):
    """Kinetic energies [eV] of the slices fitted by :meth:`fit_selection`.

    Raises
    ------
    AttributeError
        If :meth:`fit_selection` has not been run yet.
    """
    fitted = self._ekin_range
    if fitted is not None:
        return fitted
    raise AttributeError("`ekin_range` not yet set. Run `.fit_selection()` first.")
156
+
157
@property
def individual_properties(self):
    """
    Fitted parameter values and 1σ uncertainties, grouped per component.

    Returns
    -------
    dict
        Nested mapping::

            {
                label: {
                    class_name: {
                        'label': label,
                        '_class': class_name,
                        <param>: [values per slice],
                        <param>_sigma: [1σ per slice or None],
                        ...
                    }
                }
            }

    Raises
    ------
    AttributeError
        If :meth:`fit_selection` has not been run yet.
    """
    props = self._individual_properties
    if props is None:
        raise AttributeError(
            "`individual_properties` not yet set. Run `.fit_selection()` first."
        )
    return props
184
+
185
def energy_check(self, energy_value):
    r"""
    Select the MDC (counts and kinetic energy) to operate on.

    Parameters
    ----------
    energy_value : float or None
        Binding energy [eV] of the requested slice. Must be ``None`` when
        the dataset holds a single MDC. Must be given, and lie within the
        range of :attr:`enel`, when the dataset holds several MDCs; the
        slice whose binding energy is closest to it is selected.

    Returns
    -------
    counts : array-like
        Intensities of the selected MDC.
    kinergy : float
        Kinetic energy [eV] of the selected MDC.

    Raises
    ------
    ValueError
        If ``energy_value`` is given for a single-MDC dataset, is missing
        for a multi-MDC dataset, or lies outside the available range.
    """
    # Single MDC: there is nothing to select.
    if np.isscalar(self.ekin):
        if energy_value is not None:
            raise ValueError("This dataset contains only one "
                             "momentum-distribution curve; do not provide "
                             "energy_value.")
        return self.intensities, self.ekin

    if energy_value is None:
        raise ValueError("This dataset contains multiple "
                         "momentum-distribution curves. Please provide an "
                         "energy_value for which to plot the MDCs.")

    enel_min, enel_max = self.enel.min(), self.enel.max()
    # Validate before selecting, so an out-of-range request is rejected
    # instead of being resolved to the nearest edge slice first.
    if not (enel_min <= energy_value <= enel_max):
        raise ValueError(
            f"Selected energy_value={energy_value:.3f} "
            f"is outside the available energy range "
            f"({enel_min:.3f} – {enel_max:.3f}) "
            "of the MDC collection."
        )

    energy_index = np.abs(self.enel - energy_value).argmin()
    return self.intensities[energy_index, :], self.ekin[energy_index]
214
+
215
+
216
def plot(self, energy_value=None, energy_range=None, ax=None, **kwargs):
    """
    Plot MDC intensities, statically or with a slider over energy slices.

    Parameters
    ----------
    energy_value : float, optional
        Binding energy [eV] of one slice to show. The nearest slice is
        plotted without a slider. Must lie within the range of ``enel``.
    energy_range : tuple of float, optional
        ``(e_min, e_max)`` binding-energy window [eV]. The slices inside
        the inclusive window can be browsed with a slider. When neither
        selector is given for a multi-slice dataset, every slice is
        browsable.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; forwarded to ``get_ax_fig_plt``.
    **kwargs
        Figure options handled here: ``title``, ``savefig``, ``show``
        (default True), ``fig_close``, ``tight_layout``, ``ax_grid``,
        ``ax_annotate`` and ``size_kwargs`` (a dict with ``w``/``h`` plus
        extra ``set_size_inches`` arguments). Other keys are ignored.

    Returns
    -------
    matplotlib.figure.Figure or None
        ``None`` when ``show`` is false and either ``fig_close`` is set or
        the session is not interactive; otherwise the figure.

    Raises
    ------
    ValueError
        If both selectors are given, if a selector is given for a
        single-MDC dataset, if ``energy_value`` is out of range, or if
        ``energy_range`` selects no slice.
    """
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider
    import string
    import sys
    import warnings

    # Figure-wrapper options (consumed here, not forwarded).
    title = kwargs.pop("title", None)
    savefig = kwargs.pop("savefig", None)
    show = kwargs.pop("show", True)
    fig_close = kwargs.pop("fig_close", False)
    tight_layout = kwargs.pop("tight_layout", False)
    ax_grid = kwargs.pop("ax_grid", None)
    ax_annotate = kwargs.pop("ax_annotate", False)
    size_kwargs = kwargs.pop("size_kwargs", None)

    if energy_value is not None and energy_range is not None:
        raise ValueError(
            "Provide at most energy_value or energy_range, not both.")

    ax, fig, plt = get_ax_fig_plt(ax=ax)

    angles = self.angles
    energies = self.enel

    def _autoscale_y_keep_x():
        # Rescale y to the plotted data while keeping the current x-range.
        x0, x1 = ax.get_xlim()
        ax.relim(visible_only=True)
        ax.autoscale_view(scalex=False, scaley=True)
        ax.set_xlim(x0, x1)

    if np.isscalar(energies):
        # ---- Single MDC in the dataset ----
        if energy_value is not None or energy_range is not None:
            raise ValueError(
                "This dataset contains only one momentum-distribution "
                "curve; do not provide energy_value or energy_range."
            )
        ax.scatter(angles, self.intensities, label="Data")
        ax.set_title(f"Energy slice: {energies * KILO:.3f} meV")
        _autoscale_y_keep_x()

    elif energy_value is not None:
        # ---- Single-slice path (no slider) ----
        emin, emax = energies.min(), energies.max()
        if energy_value < emin or energy_value > emax:
            raise ValueError(
                f"Requested energy_value {energy_value:.3f} eV is "
                f"outside the available energy range "
                f"[{emin:.3f}, {emax:.3f}] eV."
            )
        idx = int(np.abs(energies - energy_value).argmin())
        ax.scatter(angles, self.intensities[idx], label="Data")
        ax.set_title(f"Energy slice: {energies[idx] * KILO:.3f} meV")
        _autoscale_y_keep_x()

    else:
        # ---- Multi-slice path (slider) ----
        if energy_range is not None:
            e_min, e_max = energy_range
            mask = (energies >= e_min) & (energies <= e_max)
        else:
            mask = np.ones_like(energies, dtype=bool)

        indices = np.where(mask)[0]
        if len(indices) == 0:
            raise ValueError("No energies found in the specified selection.")

        intensities = self.intensities[indices]

        fig.subplots_adjust(bottom=0.25)
        scatter = ax.scatter(angles, intensities[0], label="Data")
        ax.set_title(f"Energy slice: "
                     f"{energies[indices[0]] * KILO:.3f} meV")

        slider_ax = fig.add_axes([0.2, 0.08, 0.6, 0.04])
        # A single selected slice gives valmin == valmax, which triggers a
        # matplotlib warning about identical limits. Silence it only while
        # the slider is built rather than installing a process-wide filter.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Attempting to set identical left == right",
                category=UserWarning,
            )
            slider = Slider(slider_ax, "Index", 0, len(indices) - 1,
                            valinit=0, valstep=1)

        def update(val):
            i = int(slider.val)
            yi = intensities[i]
            scatter.set_offsets(np.c_[angles, yi])

            x0, x1 = ax.get_xlim()

            # Rescale y to the finite data of this slice, padded by the
            # rcParams y-margin (or a unit-scaled pad for flat data).
            yv = np.asarray(yi, dtype=float).ravel()
            finite = np.isfinite(yv)
            if finite.any():
                y_min = float(yv[finite].min())
                y_max = float(yv[finite].max())
                span = y_max - y_min
                frac = plt.rcParams['axes.ymargin']
                if span <= 0 or not np.isfinite(span):
                    pad = frac * max(abs(y_max), 1.0)
                else:
                    pad = frac * span
                ax.set_ylim(y_min - pad, y_max + pad)

            # Keep x unchanged.
            ax.set_xlim(x0, x1)

            ax.set_title(f"Energy slice: "
                         f"{energies[indices[i]] * KILO:.3f} meV")
            fig.canvas.draw_idle()

        slider.on_changed(update)
        # Keep references so the widget is not garbage-collected.
        self._slider = slider
        self._line = scatter

    ax.set_xlabel("Angle (°)")
    ax.set_ylabel("Counts (-)")
    ax.legend()
    self._fig = fig

    if size_kwargs:
        # Copy so the caller's dict is not consumed by the pops below.
        size_kwargs = dict(size_kwargs)
        fig.set_size_inches(size_kwargs.pop("w"),
                            size_kwargs.pop("h"), **size_kwargs)
    if title:
        fig.suptitle(title)
    if tight_layout:
        fig.tight_layout()
    if ax_grid is not None:
        for axis in fig.axes:
            axis.grid(bool(ax_grid))
    if ax_annotate:
        tags = string.ascii_lowercase
        for i, axis in enumerate(fig.axes):
            axis.annotate(f"({tags[i]})", xy=(0.05, 0.95),
                          xycoords="axes fraction")
    # Save only after grid and annotations are applied, so the file
    # matches what is shown on screen.
    if savefig:
        fig.savefig(savefig)

    is_interactive = hasattr(sys, 'ps1') or 'ipykernel' in sys.modules
    is_cli = not is_interactive

    if show:
        if is_cli:
            plt.show()
    if fig_close:
        plt.close(fig)

    if not show and (fig_close or is_cli):
        return None
    return fig
393
+
394
+
395
@add_fig_kwargs
def visualize_guess(self, distributions, energy_value=None,
                    matrix_element=None, matrix_args=None,
                    ax=None, **kwargs):
    r"""
    Plot one MDC together with the initial (unfitted) model components.

    Parameters
    ----------
    distributions : sequence
        Model components, evaluated as given (no fitting is performed).
    energy_value : float, optional
        Binding energy [eV] of the slice to show; see :meth:`energy_check`.
    matrix_element : callable, optional
        Optional matrix element forwarded to :meth:`_merge_and_plot`.
    matrix_args : mapping, optional
        Arguments for ``matrix_element``; a shallow copy is forwarded.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; forwarded to ``get_ax_fig_plt``.
    **kwargs
        Not used in the method body.

    Returns
    -------
    matplotlib.figure.Figure
        Figure holding data, components, their sum and the residual.
    """
    counts, kinergy = self.energy_check(energy_value)
    ax, fig, _plt = get_ax_fig_plt(ax=ax)

    binding_mev = (kinergy - self.hnuminPhi) * KILO
    ax.set_xlabel('Angle ($\\degree$)')
    ax.set_ylabel('Counts (-)')
    ax.set_title(f"Energy slice: {binding_mev:.3f} meV")
    ax.scatter(self.angles, counts, label='Data')

    guess_args = dict(matrix_args) if matrix_args else None
    model_curve = self._merge_and_plot(
        ax=ax,
        distributions=distributions,
        kinetic_energy=kinergy,
        matrix_element=matrix_element,
        matrix_args=guess_args,
        plot_individual=True,
    )

    # Residual between the data and the initial guess.
    ax.scatter(self.angles, counts - model_curve, label='Residual')
    ax.legend()

    return fig
424
+
425
+
426
def fit_selection(self, distributions, energy_value=None, energy_range=None,
                  matrix_element=None, matrix_args=None, ax=None, **kwargs):
    r"""
    Fit a selection of MDCs slice by slice and plot the results.

    Each selected MDC is fitted independently with lmfit's
    ``least_squares`` method. The fit curve, residual and smoothed
    individual components are plotted, with an index slider when the
    dataset holds several slices. The fitted parameters and their 1σ
    uncertainties are aggregated into :attr:`individual_properties`.

    Parameters
    ----------
    distributions : sequence
        Model components. They are deep-copied, turned into lmfit
        parameters by ``construct_parameters`` and rebuilt by
        ``build_distributions``.
    energy_value : float, optional
        Binding energy [eV] of a single slice to fit (the nearest slice).
        Must lie within the range of ``enel``.
    energy_range : tuple of float, optional
        ``(e_min, e_max)`` binding-energy window [eV]. All slices inside
        the inclusive window are fitted. When neither selector is given,
        every slice is fitted.
    matrix_element : callable, optional
        Called as ``matrix_element(extended_angles, **args)`` and
        multiplied into every component that has an ``index`` attribute.
    matrix_args : mapping, optional
        Initial matrix-element arguments. They are required together with
        ``matrix_element``, and each key is looked up among the fitted
        lmfit parameters.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; forwarded to ``get_ax_fig_plt``.
    **kwargs
        Figure options handled here: ``title``, ``savefig``, ``show``
        (default True), ``fig_close``, ``tight_layout``, ``ax_grid``,
        ``ax_annotate`` and ``size_kwargs``. Other keys are ignored.

    Returns
    -------
    matplotlib.figure.Figure or None
        ``None`` when ``show`` is false and either ``fig_close`` is set or
        the session is not interactive; otherwise the figure.

    Raises
    ------
    ValueError
        If both selectors are given, if a selector is given for a
        single-MDC dataset, if ``energy_value`` is out of range, or if
        ``energy_range`` selects no slice.

    Notes
    -----
    Slices are fitted in descending kinetic-energy order, but all stored
    and plotted per-slice results follow the order of the selected slices
    in the input. Side effects: sets ``_ekin_range``,
    ``_individual_properties`` and ``_fig``, and, for multi-slice data,
    keeps references to the slider and plotted artists on ``self``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider
    from copy import deepcopy
    import string
    import sys
    import warnings
    from lmfit import Minimizer
    from scipy.ndimage import gaussian_filter
    from .functions import construct_parameters, build_distributions, \
        residual, resolve_param_name

    # Figure-wrapper options (consumed here, not forwarded).
    title = kwargs.pop("title", None)
    savefig = kwargs.pop("savefig", None)
    show = kwargs.pop("show", True)
    fig_close = kwargs.pop("fig_close", False)
    tight_layout = kwargs.pop("tight_layout", False)
    ax_grid = kwargs.pop("ax_grid", None)
    ax_annotate = kwargs.pop("ax_annotate", False)
    size_kwargs = kwargs.pop("size_kwargs", None)

    energies = self.enel

    # ---- Validate and select slices before any figure is created ----
    if energy_value is not None and energy_range is not None:
        raise ValueError(
            "Provide at most energy_value or energy_range, not both.")

    energies_sel = None  # binding energies of the selected slices
    if np.isscalar(energies):
        if energy_value is not None or energy_range is not None:
            raise ValueError(
                "This dataset contains only one momentum-distribution "
                "curve; do not provide energy_value or energy_range."
            )
        kinergies = self.ekin
        intensities = self.intensities
    elif energy_value is not None:
        if energy_value < energies.min() or energy_value > energies.max():
            raise ValueError(
                f"Requested energy_value {energy_value:.3f} eV is "
                f"outside the available energy range "
                f"[{energies.min():.3f}, {energies.max():.3f}] eV.")
        indices = np.atleast_1d(np.abs(energies - energy_value).argmin())
        kinergies = self.ekin[indices]
        intensities = self.intensities[indices, :]
        energies_sel = energies[indices]
    elif energy_range is not None:
        e_min, e_max = energy_range
        indices = np.where((energies >= e_min) & (energies <= e_max))[0]
        if len(indices) == 0:
            raise ValueError("No energies found in the specified energy_range.")
        kinergies = self.ekin[indices]
        intensities = self.intensities[indices, :]
        energies_sel = energies[indices]
    else:  # Without a selection, all MDCs are fitted.
        kinergies = self.ekin
        intensities = self.intensities
        energies_sel = energies

    # Shape guard: one kinetic energy per row of intensities.
    kinergies = np.atleast_1d(kinergies)
    intensities = np.atleast_2d(intensities)

    ax, fig, plt = get_ax_fig_plt(ax=ax)
    new_distributions = deepcopy(distributions)

    all_final_results = []
    all_residuals = []
    all_individual_results = []  # one (n_individuals, n_angles) per slice
    individual_labels = []
    aggregated_properties = {}

    # class_name -> parameter names to extract per slice
    param_spec = {
        'Constant': ('offset',),
        'Linear': ('offset', 'slope'),
        'SpectralLinear': ('amplitude', 'peak', 'broadening'),
        'SpectralQuadratic': ('amplitude', 'peak', 'broadening'),
    }

    # Fit from the highest to the lowest kinetic energy. The same
    # new_distributions object is threaded through successive slices;
    # NOTE(review): whether this warm-starts each fit depends on
    # build_distributions, which is not visible here.
    order = np.argsort(kinergies)[::-1]
    for idx in order:
        kinergy = kinergies[idx]
        intensity = intensities[idx]

        if matrix_element is not None:
            parameters, element_names = construct_parameters(
                new_distributions, matrix_args)
            new_distributions = build_distributions(new_distributions,
                                                    parameters)
            fcn_args = (self.angles, intensity, self.angle_resolution,
                        new_distributions, kinergy, self.hnuminPhi,
                        matrix_element, element_names)
        else:
            parameters = construct_parameters(new_distributions)
            new_distributions = build_distributions(new_distributions,
                                                    parameters)
            fcn_args = (self.angles, intensity, self.angle_resolution,
                        new_distributions, kinergy, self.hnuminPhi)

        outcome = Minimizer(residual, parameters,
                            fcn_args=fcn_args).minimize('least_squares')

        # 1σ per parameter: covariance diagonal when usable, otherwise
        # lmfit's stderr, otherwise None.
        pcov = outcome.covar
        var_names = getattr(outcome, 'var_names', None)
        if not var_names:
            var_names = [n for n, p in outcome.params.items() if p.vary]
        var_idx = {n: i for i, n in enumerate(var_names)}

        param_sigma_full = {}
        for name, par in outcome.params.items():
            sigma = None
            if pcov is not None and name in var_idx:
                d = pcov[var_idx[name], var_idx[name]]
                if np.isfinite(d) and d >= 0:
                    sigma = float(np.sqrt(d))
            if sigma is None:
                s = getattr(par, 'stderr', None)
                sigma = float(s) if s is not None else None
            param_sigma_full[name] = sigma

        # Rebuild the *fitted* distributions from the optimized params.
        fitted_distributions = build_distributions(new_distributions,
                                                   outcome.params)

        # With a matrix element, take the slice-specific fitted arguments.
        if matrix_element is not None:
            new_matrix_args = {key: outcome.params[key].value
                               for key in matrix_args}
        else:
            new_matrix_args = None

        # Evaluate on the extended grid, then smooth and crop the ``numb``
        # extra points at each end so curves align with self.angles.
        extend, step, numb = extend_function(self.angles,
                                             self.angle_resolution)
        crop = slice(numb, -numb if numb else None)

        total_result_ext = np.zeros_like(extend)
        indiv_rows = []
        individual_labels = []

        for dist in fitted_distributions:
            cls = getattr(dist, 'class_name', None)
            if cls == 'SpectralQuadratic':
                if (getattr(dist, 'center_angle', None) is not None) and (
                        kinergy is None or self.hnuminPhi is None):
                    raise ValueError(
                        'Spectral quadratic function is defined in terms '
                        'of a center angle. Please provide a kinetic energy '
                        'and hnuminPhi.'
                    )
                extended_result = dist.evaluate(extend, kinergy,
                                                self.hnuminPhi)
            else:
                extended_result = dist.evaluate(extend)

            if matrix_element is not None and hasattr(dist, 'index'):
                extended_result *= matrix_element(extend,
                                                  **(new_matrix_args or {}))

            total_result_ext += extended_result
            indiv_rows.append(np.asarray(
                gaussian_filter(extended_result, sigma=step)[crop]))

            label = getattr(dist, 'label', str(dist))
            individual_labels.append(label)

            # ---- Collect this component's parameters for this slice ----
            wanted = param_spec.get(cls, ())
            label_bucket = aggregated_properties.setdefault(label, {})
            class_bucket = label_bucket.setdefault(
                cls, {'label': label, '_class': cls})

            # center_wavevector is stored once (scalar) for SpectralQuadratic.
            if cls == 'SpectralQuadratic' and hasattr(dist,
                                                      'center_wavevector'):
                class_bucket.setdefault('center_wavevector',
                                        dist.center_wavevector)

            for pname in wanted:
                class_bucket.setdefault(pname, [])
                class_bucket.setdefault(f"{pname}_sigma", [])

            for pname in wanted:
                param_key = resolve_param_name(outcome.params, label, pname)
                if param_key is not None and param_key in outcome.params:
                    value = outcome.params[param_key].value
                    sigma = param_sigma_full.get(param_key, None)
                else:
                    # Not fitted in this slice: keep the value present on
                    # the distribution (if any) and report no uncertainty.
                    value = getattr(dist, pname, None)
                    sigma = None
                class_bucket[pname].append(value)
                class_bucket[f"{pname}_sigma"].append(sigma)

        final_result_i = np.asarray(
            gaussian_filter(total_result_ext, sigma=step)[crop])
        all_final_results.append(final_result_i)
        all_residuals.append(np.asarray(intensity) - final_result_i)
        all_individual_results.append(np.vstack(indiv_rows))

    # ---- Undo the fitting order: inverse permutation of ``order`` ----
    inverse_order = np.argsort(order)
    all_final_results[:] = [all_final_results[i] for i in inverse_order]
    all_residuals[:] = [all_residuals[i] for i in inverse_order]
    all_individual_results[:] = [all_individual_results[i]
                                 for i in inverse_order]

    # Per-slice lists in aggregated_properties follow the same order.
    for label_dict in aggregated_properties.values():
        for cls_dict in label_dict.values():
            for key, val in cls_dict.items():
                if isinstance(val, list) and len(val) == len(kinergies):
                    cls_dict[key] = [val[i] for i in inverse_order]

    self._ekin_range = kinergies
    self._individual_properties = aggregated_properties

    if np.isscalar(energies):
        # One slice only: plot data, individuals, fit and residual.
        ydata = np.asarray(intensities).squeeze()
        yfit = np.asarray(all_final_results[0]).squeeze()
        yres = np.asarray(all_residuals[0]).squeeze()
        yind = np.asarray(all_individual_results[0])

        ax.scatter(self.angles, ydata, label="Data")
        for j, lab in enumerate(individual_labels or []):
            ax.plot(self.angles, yind[j], label=str(lab))
        ax.plot(self.angles, yfit, label="Fit")
        ax.scatter(self.angles, yres, label="Residual")

        ax.set_title(f"Energy slice: {energies * KILO:.3f} meV")
        ax.relim()
        ax.autoscale_view()

    else:
        n_slices = len(all_final_results)
        assert (intensities.shape[0] == n_slices == len(all_residuals)
                == len(all_individual_results)), (
            f"Mismatch: data {intensities.shape[0]}, "
            f"fits {len(all_final_results)}, "
            f"residuals {len(all_residuals)}, "
            f"individuals {len(all_individual_results)}."
        )
        n_individuals = all_individual_results[0].shape[0] if n_slices else 0

        fig.subplots_adjust(bottom=0.25)
        idx = 0

        # Initial draw (data + individuals + fit + residual) at slice 0.
        scatter = ax.scatter(self.angles, intensities[idx], label="Data")

        individual_lines = []
        for j in range(n_individuals):
            if individual_labels and j < len(individual_labels):
                label = str(individual_labels[j])
            else:
                label = f"Comp {j}"
            line, = ax.plot(self.angles, all_individual_results[idx][j],
                            label=label)
            individual_lines.append(line)

        result_line, = ax.plot(self.angles, all_final_results[idx],
                               label="Fit")
        resid_scatter = ax.scatter(self.angles, all_residuals[idx],
                                   label="Residual")

        ax.set_title(f"Energy slice: {energies_sel[idx] * KILO:.3f} meV")
        ax.relim()
        ax.autoscale_view()

        slider_ax = fig.add_axes([0.2, 0.08, 0.6, 0.04])
        # A single selected slice gives valmin == valmax, which triggers a
        # matplotlib warning about identical limits. Silence it only while
        # the slider is built rather than installing a process-wide filter.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Attempting to set identical left == right",
                category=UserWarning,
            )
            slider = Slider(slider_ax, "Index", 0, n_slices - 1,
                            valinit=idx, valstep=1)

        def update(val):
            i = int(slider.val)
            scatter.set_offsets(np.c_[self.angles, intensities[i]])

            if n_individuals:
                rows = all_individual_results[i]  # (n_individuals, n_angles)
                for j, ln in enumerate(individual_lines):
                    ln.set_ydata(rows[j])

            result_line.set_ydata(all_final_results[i])
            resid_scatter.set_offsets(np.c_[self.angles, all_residuals[i]])

            ax.relim()
            ax.autoscale_view()

            ax.set_title(f"Energy slice: "
                         f"{energies_sel[i] * KILO:.3f} meV")
            fig.canvas.draw_idle()

        slider.on_changed(update)
        # Keep references so the widget and artists stay alive.
        self._slider = slider
        self._line = scatter
        self._individual_lines = individual_lines
        self._result_line = result_line
        self._resid_scatter = resid_scatter

    ax.set_xlabel("Angle (°)")
    ax.set_ylabel("Counts (-)")
    ax.legend()
    self._fig = fig

    if size_kwargs:
        # Copy so the caller's dict is not consumed by the pops below.
        size_kwargs = dict(size_kwargs)
        fig.set_size_inches(size_kwargs.pop("w"),
                            size_kwargs.pop("h"), **size_kwargs)
    if title:
        fig.suptitle(title)
    if tight_layout:
        fig.tight_layout()
    if ax_grid is not None:
        for axis in fig.axes:
            axis.grid(bool(ax_grid))
    if ax_annotate:
        tags = string.ascii_lowercase
        for i, axis in enumerate(fig.axes):
            axis.annotate(f"({tags[i]})", xy=(0.05, 0.95),
                          xycoords="axes fraction")
    # Save only after grid and annotations are applied, so the file
    # matches what is shown on screen.
    if savefig:
        fig.savefig(savefig)

    is_interactive = hasattr(sys, 'ps1') or 'ipykernel' in sys.modules
    is_cli = not is_interactive

    if show:
        if is_cli:
            plt.show()
    if fig_close:
        plt.close(fig)

    if not show and (fig_close or is_cli):
        return None
    return fig
816
+
817
+
818
@add_fig_kwargs
def fit(self, distributions, energy_value=None, matrix_element=None,
        matrix_args=None, ax=None, **kwargs):
    r"""
    Fit a single MDC and plot data, fitted components, their sum and the
    residual.

    Parameters
    ----------
    distributions : sequence
        Model components. They are deep-copied before fitting, so the
        caller's objects are not passed to the fitting helpers.
    energy_value : float, optional
        Binding energy [eV] of the slice to fit. Must be ``None`` for a
        single-MDC dataset and given otherwise; see :meth:`energy_check`.
    matrix_element : callable, optional
        Called as ``matrix_element(extended_angles, **args)`` and
        multiplied into components that have an ``index`` attribute.
    matrix_args : mapping, optional
        Initial matrix-element arguments. They are required together with
        ``matrix_element``, and each key is looked up among the fitted
        lmfit parameters.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; forwarded to ``get_ax_fig_plt``.
    **kwargs
        Not used by the method body. Figure options are expected to be
        handled by the ``add_fig_kwargs`` decorator.

    Returns
    -------
    fig : matplotlib.figure.Figure
    new_distributions : sequence
        Components rebuilt from the optimized parameters.
    pcov : ndarray or None
        Covariance matrix reported by lmfit (``outcome.covar``).
    new_matrix_args : dict
        Fitted matrix-element arguments; returned only when
        ``matrix_element`` is given.
    """
    from copy import deepcopy
    from lmfit import Minimizer
    from .functions import construct_parameters, build_distributions, \
        residual

    counts, kinergy = self.energy_check(energy_value)

    ax, fig, plt = get_ax_fig_plt(ax=ax)

    ax.set_xlabel('Angle ($\\degree$)')
    ax.set_ylabel('Counts (-)')
    ax.set_title(f"Energy slice: "
                 f"{(kinergy - self.hnuminPhi) * KILO:.3f} meV")

    ax.scatter(self.angles, counts, label='Data')

    # Work on a copy so the caller's distribution objects stay untouched.
    new_distributions = deepcopy(distributions)

    if matrix_element is not None:
        parameters, element_names = construct_parameters(new_distributions,
                                                         matrix_args)
        new_distributions = build_distributions(new_distributions,
                                                parameters)
        fcn_args = (self.angles, counts, self.angle_resolution,
                    new_distributions, kinergy, self.hnuminPhi,
                    matrix_element, element_names)
    else:
        parameters = construct_parameters(new_distributions)
        new_distributions = build_distributions(new_distributions,
                                                parameters)
        fcn_args = (self.angles, counts, self.angle_resolution,
                    new_distributions, kinergy, self.hnuminPhi)

    outcome = Minimizer(residual, parameters,
                        fcn_args=fcn_args).minimize('least_squares')
    pcov = outcome.covar

    # Rebuild from the optimized parameters so the plotted and returned
    # components carry the fitted values, regardless of how ``residual``
    # treats the distributions it receives.
    new_distributions = build_distributions(new_distributions,
                                            outcome.params)

    # If matrix parameters were fitted, pass the fitted values on.
    if matrix_element is not None:
        new_matrix_args = {key: outcome.params[key].value
                           for key in matrix_args}
    else:
        new_matrix_args = None

    final_result = self._merge_and_plot(
        ax=ax, distributions=new_distributions, kinetic_energy=kinergy,
        matrix_element=matrix_element, matrix_args=new_matrix_args,
        plot_individual=True)

    residual_vals = counts - final_result
    ax.scatter(self.angles, residual_vals, label='Residual')
    ax.legend()

    if matrix_element is not None:
        return fig, new_distributions, pcov, new_matrix_args
    return fig, new_distributions, pcov
881
+
882
+
883
def _merge_and_plot(self, ax, distributions, kinetic_energy,
                    matrix_element=None, matrix_args=None,
                    plot_individual=True):
    r"""
    Evaluate distributions on the extended grid, apply an optional matrix
    element, smooth, and plot the individual and summed curves.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axes to draw on. Nothing is plotted when ``None``.
    distributions : iterable
        Components to evaluate. ``SpectralQuadratic`` components receive
        ``kinetic_energy`` and ``hnuminPhi`` in addition to the grid.
    kinetic_energy : float or None
        Kinetic energy [eV] of the MDC.
    matrix_element : callable, optional
        Multiplied into components that have an ``index`` attribute.
    matrix_args : mapping, optional
        Keyword arguments for ``matrix_element``.
    plot_individual : bool, optional
        Also plot each smoothed component (default True).

    Returns
    -------
    final_result : np.ndarray
        Smoothed, cropped total distribution aligned with ``self.angles``.

    Raises
    ------
    ValueError
        If a ``SpectralQuadratic`` component defined via ``center_angle``
        lacks a kinetic energy or ``hnuminPhi``.
    """
    from scipy.ndimage import gaussian_filter

    # Extended grid; the ``numb`` extra points at each end are cropped
    # after smoothing so the curves align with self.angles.
    extend, step, numb = extend_function(self.angles, self.angle_resolution)
    crop = slice(numb, -numb if numb else None)
    total_result = np.zeros_like(extend)

    for dist in distributions:
        if getattr(dist, 'class_name', None) == 'SpectralQuadratic':
            if (getattr(dist, 'center_angle', None) is not None) and (
                    kinetic_energy is None or self.hnuminPhi is None):
                raise ValueError(
                    'Spectral quadratic function is defined in terms '
                    'of a center angle. Please provide a kinetic energy '
                    'and hnuminPhi.'
                )
            extended_result = dist.evaluate(extend, kinetic_energy,
                                            self.hnuminPhi)
        else:
            extended_result = dist.evaluate(extend)

        # Optional matrix element (only for components that have an index).
        if matrix_element is not None and hasattr(dist, 'index'):
            extended_result *= matrix_element(extend, **(matrix_args or {}))

        total_result += extended_result

        if plot_individual and ax is not None:
            individual = gaussian_filter(extended_result, sigma=step)[crop]
            ax.plot(self.angles, individual,
                    label=getattr(dist, 'label', str(dist)))

    final_result = gaussian_filter(total_result, sigma=step)[crop]
    if ax is not None:
        ax.plot(self.angles, final_result, label='Distribution sum')

    return final_result
937
+
938
+
939
def expose_parameters(self, select_label, fermi_wavevector=None,
                      fermi_velocity=None, bare_mass=None, side=None):
    r"""
    Select and return fitted parameters for a given component label, plus a
    flat export dictionary containing values **and** 1σ uncertainties.

    Parameters
    ----------
    select_label : str
        Label to look for among the fitted distributions.
    fermi_wavevector : float, optional
        Optional Fermi wave vector to include.
    fermi_velocity : float, optional
        Optional Fermi velocity to include.
    bare_mass : float, optional
        Optional bare mass to include (used for SpectralQuadratic
        dispersions).
    side : {'left','right'}, optional
        Optional side selector for SpectralQuadratic dispersions.

    Returns
    -------
    ekin_range : np.ndarray
        Kinetic-energy grid corresponding to the selected label.
    hnuminPhi : float
        Photon energy minus work function [eV].
    label : str
        Label of the selected distribution.
    selected_properties : dict or list of dict
        One dictionary per fitted class bucket under ``select_label``
        (a bare dict when there is exactly one) containing <param> and
        <param>_sigma arrays. ``label``, ``_class`` and, for
        SpectralQuadratic components, a scalar ``center_wavevector`` are
        passed through unconverted.
    exported_parameters : dict
        Flat dictionary of parameters and their uncertainties, plus the
        optional Fermi quantities, ``bare_mass`` and ``side`` (``None``
        when not given). For SpectralQuadratic components,
        ``center_wavevector`` is included and taken directly from the
        fitted distribution. When several classes share a parameter
        name, the last class wins.

    Raises
    ------
    AttributeError
        If ``self._ekin_range`` is None (``fit_selection`` not yet run).
    ValueError
        If ``select_label`` is not among the stored fitted labels.
    """
    if self._ekin_range is None:
        raise AttributeError(
            "ekin_range not yet set. Run `.fit_selection()` first."
        )

    store = getattr(self, "_individual_properties", None)
    if not store or select_label not in store:
        # key=str keeps the listing from raising on non-string labels.
        all_labels = (sorted(store.keys(), key=str)
                      if isinstance(store, dict) else [])
        raise ValueError(
            f"Label '{select_label}' not found in available labels: "
            f"{all_labels}"
        )

    # Metadata keys are never exported; center_wavevector is a scalar and
    # is kept as-is instead of being wrapped in np.asarray.
    metadata_keys = ("label", "_class")
    passthrough_keys = metadata_keys + ("center_wavevector",)

    # Convert stored lists to numpy arrays, one dict per class bucket.
    per_class_dicts = [
        {key: (val if key in passthrough_keys else np.asarray(val))
         for key, val in bucket.items()}
        for bucket in store[select_label].values()
    ]

    selected_properties = (
        per_class_dicts[0] if len(per_class_dicts) == 1 else per_class_dicts
    )

    # Flat export dict: simple keys, includes optional extras.
    exported_parameters = {
        "fermi_wavevector": fermi_wavevector,
        "fermi_velocity": fermi_velocity,
        "bare_mass": bare_mass,
        "side": side,
    }

    # Merge fitted parameters without class prefixes, in bucket order, so
    # later classes overwrite same-name keys. Fitted values are merged after
    # the optional extras, so a fitted key (e.g. center_wavevector) is never
    # replaced by a function argument.
    for class_bucket in per_class_dicts:
        exported_parameters.update(
            (key, val) for key, val in class_bucket.items()
            if key not in metadata_keys
        )

    return (self._ekin_range, self.hnuminPhi, select_label,
            selected_properties, exported_parameters)