xarpes 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xarpes/functions.py CHANGED
@@ -1,90 +1,251 @@
1
- # Copyright (C) 2024 xARPES Developers
1
+ # Copyright (C) 2025 xARPES Developers
2
2
  # This program is free software under the terms of the GNU GPLv3 license.
3
3
 
4
4
  """Separate functions mostly used in conjunction with various classes."""
5
5
 
6
6
  import numpy as np
7
+ from .constants import fwhm_to_std, sigma_extend
7
8
 
8
- def error_function(p, xdata, ydata, function, extra_args):
9
+ def resolve_param_name(params, label, pname):
10
+ """
11
+ Try to find the lmfit param key corresponding to this component `label`
12
+ and bare parameter name `pname` (e.g., 'amplitude', 'peak', 'broadening').
13
+ Works with common token separators.
14
+ """
15
+ import re
16
+ names = list(params.keys())
17
+ # Fast exact candidates
18
+ candidates = (
19
+ f"{pname}_{label}", f"{label}_{pname}",
20
+ f"{pname}:{label}", f"{label}:{pname}",
21
+ f"{label}.{pname}", f"{label}|{pname}",
22
+ f"{label}-{pname}", f"{pname}-{label}",
23
+ )
24
+ for c in candidates:
25
+ if c in params:
26
+ return c
27
+
28
+ # Regex fallback: label and pname as tokens in any order
29
+ esc_l = re.escape(str(label))
30
+ esc_p = re.escape(str(pname))
31
+ tok = r"[.:/_\-]" # common separators
32
+ pat = re.compile(rf"(^|{tok}){esc_l}({tok}|$).*({tok}){esc_p}({tok}|$)")
33
+ for n in names:
34
+ if pat.search(n):
35
+ return n
36
+
37
+ # Last resort: unique tail match on pname that also contains the label somewhere
38
+ tails = [n for n in names if n.endswith(pname) and str(label) in n]
39
+ if len(tails) == 1:
40
+ return tails[0]
41
+
42
+ # Give up
43
+ return None
44
+
45
+
46
+ def build_distributions(distributions, parameters):
47
+ r"""TBD
48
+ """
49
+ for dist in distributions:
50
+ if dist.class_name == 'Constant':
51
+ dist.offset = parameters['offset_' + dist.label].value
52
+ elif dist.class_name == 'Linear':
53
+ dist.offset = parameters['offset_' + dist.label].value
54
+ dist.slope = parameters['slope_' + dist.label].value
55
+ elif dist.class_name == 'SpectralLinear':
56
+ dist.amplitude = parameters['amplitude_' + dist.label].value
57
+ dist.peak = parameters['peak_' + dist.label].value
58
+ dist.broadening = parameters['broadening_' + dist.label].value
59
+ elif dist.class_name == 'SpectralQuadratic':
60
+ dist.amplitude = parameters['amplitude_' + dist.label].value
61
+ dist.peak = parameters['peak_' + dist.label].value
62
+ dist.broadening = parameters['broadening_' + dist.label].value
63
+ return distributions
64
+
65
+
66
+ def construct_parameters(distribution_list, matrix_args=None):
67
+ r"""TBD
68
+ """
69
+ from lmfit import Parameters
70
+
71
+ parameters = Parameters()
72
+
73
+ for dist in distribution_list:
74
+ if dist.class_name == 'Constant':
75
+ parameters.add(name='offset_' + dist.label, value=dist.offset)
76
+ elif dist.class_name == 'Linear':
77
+ parameters.add(name='offset_' + dist.label, value=dist.offset)
78
+ parameters.add(name='slope_' + dist.label, value=dist.slope)
79
+ elif dist.class_name == 'SpectralLinear':
80
+ parameters.add(name='amplitude_' + dist.label,
81
+ value=dist.amplitude, min=0)
82
+ parameters.add(name='peak_' + dist.label, value=dist.peak)
83
+ parameters.add(name='broadening_' + dist.label,
84
+ value=dist.broadening, min=0)
85
+ elif dist.class_name == 'SpectralQuadratic':
86
+ parameters.add(name='amplitude_' + dist.label,
87
+ value=dist.amplitude, min=0)
88
+ parameters.add(name='peak_' + dist.label, value=dist.peak)
89
+ parameters.add(name='broadening_' + dist.label,
90
+ value=dist.broadening, min=0)
91
+
92
+ if matrix_args is not None:
93
+ element_names = list()
94
+ for key, value in matrix_args.items():
95
+ parameters.add(name=key, value=value)
96
+ element_names.append(key)
97
+ return parameters, element_names
98
+ else:
99
+ return parameters
100
+
101
+
102
+ def residual(parameters, xdata, ydata, angle_resolution, new_distributions,
103
+ kinetic_energy, hnuminphi, matrix_element=None,
104
+ element_names=None):
105
+ r"""
106
+ """
107
+ from scipy.ndimage import gaussian_filter
108
+ from xarpes.distributions import Dispersion
109
+
110
+ if matrix_element is not None:
111
+ matrix_parameters = {}
112
+ for name in element_names:
113
+ if name in parameters:
114
+ matrix_parameters[name] = parameters[name].value
115
+
116
+ new_distributions = build_distributions(new_distributions, parameters)
117
+
118
+ extend, step, numb = extend_function(xdata, angle_resolution)
119
+
120
+ model = np.zeros_like(extend)
121
+
122
+ for dist in new_distributions:
123
+ if getattr(dist, 'class_name', type(dist).__name__) == 'SpectralQuadratic':
124
+ part = dist.evaluate(extend, kinetic_energy, hnuminphi)
125
+ else:
126
+ part = dist.evaluate(extend)
127
+
128
+ if (matrix_element is not None) and isinstance(dist, Dispersion):
129
+ part *= matrix_element(extend, **matrix_parameters)
130
+
131
+ model += part
132
+
133
+ model = gaussian_filter(model, sigma=step)[numb:-numb if numb else None]
134
+ return model - ydata
135
+
136
+
137
+ def extend_function(abscissa_range, abscissa_resolution):
138
+ r"""TBD
139
+ """
140
+ step_size = np.abs(abscissa_range[1] - abscissa_range[0])
141
+ step = abscissa_resolution / (step_size * fwhm_to_std)
142
+ numb = int(sigma_extend * step)
143
+ extend = np.linspace(abscissa_range[0] - numb * step_size,
144
+ abscissa_range[-1] + numb * step_size,
145
+ len(abscissa_range) + 2 * numb)
146
+ return extend, step, numb
147
+
148
+
149
+ def error_function(p, xdata, ydata, function, resolution, yerr, extra_args):
9
150
  r"""The error function used inside the fit_leastsq function.
10
151
 
11
152
  Parameters
12
153
  ----------
13
154
  p : ndarray
14
- Array of parameters during the optimization
155
+ Array of parameters during the optimization.
15
156
  xdata : ndarray
16
- Array of abscissa values the function is evaluated on
157
+ Abscissa values the function is evaluated on.
17
158
  ydata : ndarray
18
- Outcomes on ordinate the evaluated function is compared to
19
- function : function
20
- Function or class with call method to be evaluated
21
- extra_args :
22
- Arguments provided to function that should not be optimized
159
+ Measured values to compare to.
160
+ function : callable
161
+ Function or class with __call__ method to evaluate.
162
+ resolution : float or None
163
+ Convolution resolution (sigma), if applicable.
164
+ yerr : ndarray
165
+ Standard deviations of ydata.
166
+ extra_args : tuple
167
+ Additional arguments passed to function.
23
168
 
24
169
  Returns
25
170
  -------
26
- residual :
27
- Residual between evaluated function and ydata
171
+ residual : ndarray
172
+ Normalized residuals between model and ydata.
28
173
  """
29
- residual = function(xdata, *p, extra_args) - ydata
174
+ from scipy.ndimage import gaussian_filter
175
+
176
+ if resolution:
177
+ extend, step, numb = extend_function(xdata, resolution)
178
+ model = gaussian_filter(function(extend, *p, *extra_args),
179
+ sigma=step)
180
+ model = model[numb:-numb if numb else None]
181
+ else:
182
+ model = function(xdata, *p, *extra_args)
183
+
184
+ residual = (model - ydata) / yerr
30
185
  return residual
31
186
 
32
187
 
33
- def fit_leastsq(p0, xdata, ydata, function, extra_args):
34
- r"""Wrapper arround scipy.optimize.leastsq.
188
+ def fit_leastsq(p0, xdata, ydata, function, resolution=None,
189
+ yerr=None, *extra_args):
190
+ r"""Wrapper around scipy.optimize.leastsq.
35
191
 
36
192
  Parameters
37
193
  ----------
38
194
  p0 : ndarray
39
- Initial guess for parameters to be optimized
195
+ Initial guess for parameters to be optimized.
40
196
  xdata : ndarray
41
- Array of abscissa values the function is evaluated on
197
+ Abscissa values the function is evaluated on.
42
198
  ydata : ndarray
43
- Outcomes on ordinate the evaluated function is compared to
44
- function : function
45
- Function or class with call method to be evaluated
46
- extra_args :
47
- Arguments provided to function that should not be optimized
199
+ Measured values to compare to.
200
+ function : callable
201
+ Function or class with __call__ method to evaluate.
202
+ resolution : float or None, optional
203
+ Convolution resolution (sigma), if applicable.
204
+ yerr : ndarray or None, optional
205
+ Standard deviations of ydata. Defaults to ones if None.
206
+ extra_args : tuple
207
+ Additional arguments passed to the function.
48
208
 
49
209
  Returns
50
210
  -------
51
211
  pfit_leastsq : ndarray
52
- Array containing the optimized parameters
53
- perr_leastsq : ndarray
54
- Covariance matrix of the optimized parameters
212
+ Optimized parameters.
213
+ pcov : ndarray or float
214
+ Scaled covariance matrix of the optimized parameters.
215
+ If the covariance could not be estimated, returns np.inf.
55
216
  """
56
217
  from scipy.optimize import leastsq
57
-
218
+
219
+ if yerr is None:
220
+ yerr = np.ones_like(ydata)
221
+
58
222
  pfit, pcov, infodict, errmsg, success = leastsq(
59
- error_function, p0, args=(xdata, ydata, function, extra_args),
60
- full_output=1)
223
+ error_function,
224
+ p0,
225
+ args=(xdata, ydata, function, resolution, yerr, extra_args),
226
+ full_output=1
227
+ )
61
228
 
62
229
  if (len(ydata) > len(p0)) and pcov is not None:
63
- s_sq = (error_function(pfit, xdata, ydata, function,
64
- extra_args) ** 2).sum() / (len(ydata) - len(p0))
65
- pcov = pcov * s_sq
230
+ s_sq = (
231
+ error_function(pfit, xdata, ydata, function, resolution,
232
+ yerr, extra_args) ** 2
233
+ ).sum() / (len(ydata) - len(p0))
234
+ pcov *= s_sq
66
235
  else:
67
236
  pcov = np.inf
68
237
 
69
- error = []
70
- for i in range(len(pfit)):
71
- try:
72
- error.append(np.absolute(pcov[i][i]) ** 0.5)
73
- except:
74
- error.append(0.00)
75
- pfit_leastsq = pfit
76
- perr_leastsq = np.array(error)
238
+ return pfit, pcov
77
239
 
78
- return pfit_leastsq, perr_leastsq
79
-
80
240
 
81
241
  def download_examples():
82
- """Downloads the examples folder from the xARPES code only if it does not
83
- already exist. Prints executed steps and a final cleanup/failure message.
84
-
242
+ """Downloads the examples folder from the main xARPES repository only if it
243
+ does not already exist in the current directory. Prints executed steps and a
244
+ final cleanup/failure message.
245
+
85
246
  Returns
86
247
  -------
87
- 0, 1 : int
248
+ 0 or 1 : int
88
249
  Returns 0 if the execution succeeds, 1 if it fails.
89
250
  """
90
251
  import requests
@@ -92,22 +253,25 @@ def download_examples():
92
253
  import os
93
254
  import shutil
94
255
  import io
95
-
96
- repo_url = 'https://github.com/xARPES/xARPES_examples'
256
+ import jupytext
257
+
258
+ # Main xARPES repo (examples now live in /examples here)
259
+ repo_url = 'https://github.com/xARPES/xARPES'
97
260
  output_dir = '.' # Directory from which the function is called
98
-
99
- # Check if 'examples' directory already exists
261
+
262
+ # Target 'examples' directory in the user's current location
100
263
  final_examples_path = os.path.join(output_dir, 'examples')
101
264
  if os.path.exists(final_examples_path):
102
- print("Warning: 'examples' folder already exists. No download will be performed.")
103
- return 1 # Exit the function if 'examples' directory exists
265
+ print("Warning: 'examples' folder already exists. "
266
+ "No download will be performed.")
267
+ return 1 # Exit the function if 'examples' directory exists
104
268
 
105
269
  # Proceed with download if 'examples' directory does not exist
106
- repo_parts = repo_url.replace("https://github.com/", "").rstrip('/')
107
- zip_url = f"https://github.com/{repo_parts}/archive/refs/heads/main.zip"
270
+ repo_parts = repo_url.replace('https://github.com/', '').rstrip('/')
271
+ zip_url = f'https://github.com/{repo_parts}/archive/refs/heads/main.zip'
108
272
 
109
273
  # Make the HTTP request to download the zip file
110
- print(f"Downloading {zip_url}")
274
+ print(f'Downloading {zip_url}')
111
275
  response = requests.get(zip_url)
112
276
  if response.status_code == 200:
113
277
  zip_file_bytes = io.BytesIO(response.content)
@@ -115,21 +279,59 @@ def download_examples():
115
279
  with zipfile.ZipFile(zip_file_bytes, 'r') as zip_ref:
116
280
  zip_ref.extractall(output_dir)
117
281
 
118
- # Path to the extracted main folder
119
- main_folder_path = os.path.join(output_dir, repo_parts.split('/')[-1] + '-main')
282
+ # Path to the extracted main folder (e.g. xARPES-main)
283
+ main_folder_path = os.path.join(
284
+ output_dir,
285
+ repo_parts.split('/')[-1] + '-main'
286
+ )
120
287
  examples_path = os.path.join(main_folder_path, 'examples')
121
288
 
122
- # Move the 'examples' directory to the target location if it was extracted
289
+ # Move the 'examples' directory to the target location
123
290
  if os.path.exists(examples_path):
124
291
  shutil.move(examples_path, final_examples_path)
125
292
  print(f"'examples' subdirectory moved to {final_examples_path}")
126
- else:
127
- print("'examples' subdirectory not found in the repository.")
293
+
294
+ # Convert all .Rmd files in the examples directory to .ipynb
295
+ # and delete the .Rmd files
296
+ for dirpath, dirnames, filenames in os.walk(final_examples_path):
297
+ for filename in filenames:
298
+ if filename.endswith('.Rmd'):
299
+ full_path = os.path.join(dirpath, filename)
300
+ jupytext.write(
301
+ jupytext.read(full_path),
302
+ full_path.replace('.Rmd', '.ipynb')
303
+ )
304
+ os.remove(full_path) # Deletes .Rmd file afterwards
305
+ print(f'Converted and deleted {full_path}')
128
306
 
129
307
  # Remove the rest of the extracted content
130
308
  shutil.rmtree(main_folder_path)
131
- print(f"Cleaned up temporary files in {main_folder_path}")
309
+ print(f'Cleaned up temporary files in {main_folder_path}')
132
310
  return 0
133
311
  else:
134
- print(f"Failed to download the repository. Status code: {response.status_code}")
135
- return 1
312
+ print('Failed to download the repository. Status code: '
313
+ f'{response.status_code}')
314
+ return 1
315
+
316
+
317
+ def set_script_dir():
318
+ r"""This function sets the directory such that the xARPES code can be
319
+ executed either inside IPython environments or as .py scripts from
320
+ arbitrary locations.
321
+ """
322
+ import os
323
+ import inspect
324
+ try:
325
+ # This block checks if the script is running in an IPython environment
326
+ cfg = get_ipython().config
327
+ script_dir = os.getcwd()
328
+ except NameError:
329
+ # If not in IPython, get the caller's file location
330
+ frame = inspect.stack()[1]
331
+ module = inspect.getmodule(frame[0])
332
+ script_dir = os.path.dirname(os.path.abspath(module.__file__))
333
+ except:
334
+ # If __file__ isn't defined, fall back to current working directory
335
+ script_dir = os.getcwd()
336
+
337
+ return script_dir
xarpes/plotting.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2024 xARPES Developers
1
+ # Copyright (C) 2025 xARPES Developers
2
2
  # This program is free software under the terms of the GNU GPLv3 license.
3
3
 
4
4
  # get_ax_fig_plt and add_fig_kwargs originate from pymatgen/util/plotting.py.
@@ -18,15 +18,17 @@ import matplotlib as mpl
18
18
  def plot_settings(name='default'):
19
19
  mpl.rc('xtick', labelsize=10, direction='in')
20
20
  mpl.rc('ytick', labelsize=10, direction='in')
21
+ plt.rcParams['legend.frameon'] = False
21
22
  lw = dict(default=2.0, large=4.0)[name]
22
- mpl.rcParams['lines.linewidth'] = lw
23
- mpl.rcParams['lines.markersize'] = 3
24
- mpl.rcParams['xtick.major.size'] = 4
25
- mpl.rcParams['xtick.minor.size'] = 2
26
- mpl.rcParams['xtick.major.width'] = 0.8
27
- mpl.rcParams.update({'font.size': 16})
28
- mpl.use('Qt5Agg') # Backend for showing plots in terminal
29
-
23
+ mpl.rcParams.update({
24
+ 'lines.linewidth': lw,
25
+ 'lines.markersize': 3,
26
+ 'xtick.major.size': 4,
27
+ 'xtick.minor.size': 2,
28
+ 'xtick.major.width': 0.8,
29
+ 'font.size': 16,
30
+ 'axes.ymargin': 0.15,
31
+ })
30
32
 
31
33
  def get_ax_fig_plt(ax=None, **kwargs):
32
34
  r"""Helper function used in plot functions supporting an optional `Axes`
@@ -59,12 +61,14 @@ def get_ax_fig_plt(ax=None, **kwargs):
59
61
 
60
62
  return ax, fig, plt
61
63
 
64
+
62
65
  def add_fig_kwargs(func):
63
66
  """Decorator that adds keyword arguments for functions returning matplotlib
64
67
  figures.
65
68
 
66
- The function should return either a matplotlib figure or None to signal
67
- some sort of error/unexpected event.
69
+ The function should return either a matplotlib figure or a tuple, where the
70
+ first element is a matplotlib figure, or None to signal some sort of
71
+ error/unexpected event.
68
72
  """
69
73
  @wraps(func)
70
74
  def wrapper(*args, **kwargs):
@@ -76,14 +80,26 @@ def add_fig_kwargs(func):
76
80
  tight_layout = kwargs.pop('tight_layout', False)
77
81
  ax_grid = kwargs.pop('ax_grid', None)
78
82
  ax_annotate = kwargs.pop('ax_annotate', None)
79
- fig_close = kwargs.pop('fig_close', False)
80
-
81
- # Call func and return immediately if None is returned.
82
- fig = func(*args, **kwargs)
83
+ fig_close = kwargs.pop('fig_close', True)
84
+
85
+ import string
86
+
87
+ # Call the original function
88
+ result = func(*args, **kwargs)
89
+
90
+ # Determine if result is a figure or tuple with first element as figure
91
+ if isinstance(result, tuple):
92
+ fig = result[0]
93
+ rest = result[1:]
94
+ else:
95
+ fig = result
96
+ rest = None
97
+
98
+ # Return immediately if no figure is returned
83
99
  if fig is None:
84
- return fig
100
+ return result
85
101
 
86
- # Operate on matplotlib figure.
102
+ # Operate on the matplotlib figure
87
103
  if title is not None:
88
104
  fig.suptitle(title)
89
105
 
@@ -96,9 +112,10 @@ def add_fig_kwargs(func):
96
112
  ax.grid(bool(ax_grid))
97
113
 
98
114
  if ax_annotate:
99
- tags = ascii_letters
115
+ tags = string.ascii_letters
100
116
  if len(fig.axes) > len(tags):
101
- tags = (1 + len(ascii_letters) // len(fig.axes)) * ascii_letters
117
+ tags = (1 + len(string.ascii_letters) // len(fig.axes)) * \
118
+ string.ascii_letters
102
119
  for ax, tag in zip(fig.axes, tags):
103
120
  ax.annotate(f'({tag})', xy=(0.05, 0.95),
104
121
  xycoords='axes fraction')
@@ -107,10 +124,8 @@ def add_fig_kwargs(func):
107
124
  try:
108
125
  fig.tight_layout()
109
126
  except Exception as exc:
110
- # For some unknown reason, this problem shows up only on travis.
111
- # https://stackoverflow.com/questions/22708888/valueerror-when-using-matplotlib-tight-layout
112
- print('Ignoring Exception raised by fig.tight_layout\n',
113
- str(exc))
127
+ print('Ignoring Exception raised by fig.tight_layout ' +
128
+ '\n', str(exc))
114
129
 
115
130
  if savefig:
116
131
  fig.savefig(savefig)
@@ -120,12 +135,16 @@ def add_fig_kwargs(func):
120
135
  if fig_close:
121
136
  plt.close(fig=fig)
122
137
 
123
- return fig
138
+ # Reassemble the tuple if necessary and return
139
+ if rest is not None:
140
+ return (fig, *rest)
141
+ else:
142
+ return fig
124
143
 
125
144
  # Add docstring to the decorated method.
126
145
  doc_str = """\n\n
127
146
 
128
- notes
147
+ Notes
129
148
  -----
130
149
 
131
150
  Keyword arguments controlling the display of the figure:
@@ -142,18 +161,16 @@ def add_fig_kwargs(func):
142
161
  ax_grid True (False) to add (remove) grid from all axes in
143
162
  fig.
144
163
  Default: None i.e. fig is left unchanged.
145
- ax_annotate Add labels to subplots e.g. (a), (b).
164
+ ax_annotate Add labels to subplots e.g. (a), (b).
146
165
  Default: False
147
- fig_close Close figure. Default: False.
166
+ fig_close Close figure. Default: True.
148
167
  ================ ====================================================
149
168
 
150
169
  """
151
170
 
152
171
  if wrapper.__doc__ is not None:
153
- # Add s at the end of the docstring.
154
172
  wrapper.__doc__ += f'\n{doc_str}'
155
173
  else:
156
- # Use s
157
174
  wrapper.__doc__ = doc_str
158
175
 
159
176
  return wrapper