xarpes 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xarpes/__init__.py +3 -5
- xarpes/constants.py +12 -0
- xarpes/distributions.py +500 -239
- xarpes/functions.py +263 -61
- xarpes/plotting.py +46 -29
- xarpes/spectral.py +2067 -0
- xarpes-0.3.0.dist-info/METADATA +160 -0
- xarpes-0.3.0.dist-info/RECORD +11 -0
- {xarpes-0.2.3.dist-info → xarpes-0.3.0.dist-info}/WHEEL +1 -1
- xarpes-0.3.0.dist-info/entry_points.txt +3 -0
- {xarpes-0.2.3.dist-info → xarpes-0.3.0.dist-info/licenses}/LICENSE +0 -0
- xarpes/.ipynb_checkpoints/__init__-checkpoint.py +0 -8
- xarpes/band_map.py +0 -306
- xarpes-0.2.3.dist-info/METADATA +0 -121
- xarpes-0.2.3.dist-info/RECORD +0 -10
xarpes/functions.py
CHANGED
|
@@ -1,90 +1,251 @@
|
|
|
1
|
-
# Copyright (C)
|
|
1
|
+
# Copyright (C) 2025 xARPES Developers
|
|
2
2
|
# This program is free software under the terms of the GNU GPLv3 license.
|
|
3
3
|
|
|
4
4
|
"""Separate functions mostly used in conjunction with various classes."""
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
|
+
from .constants import fwhm_to_std, sigma_extend
|
|
7
8
|
|
|
8
|
-
def
|
|
9
|
+
def resolve_param_name(params, label, pname):
    """Find the key in `params` for component `label` and parameter `pname`.

    `pname` is the bare parameter name (e.g. 'amplitude', 'peak',
    'broadening'). Three strategies are tried in order and the first hit is
    returned:

    1. Exact keys built from `label` and `pname` joined by ``_``, ``:``,
       ``.``, ``|`` or ``-`` (see the candidate tuple for the exact forms).
    2. A regular expression that requires `label` and `pname` to both occur
       as whole tokens, in either order, where a token is delimited by the
       start/end of the key or by one of the separators ``. : / _ - |``.
    3. A tail match: the single key that ends with `pname` and contains
       `label` anywhere. Ambiguous tail matches are rejected.

    Parameters
    ----------
    params : mapping
        Mapping keyed by parameter name; only ``keys()`` and ``in`` are
        used.
    label : object
        Component label; converted with ``str`` before matching.
    pname : str
        Bare parameter name.

    Returns
    -------
    str or None
        The matching key, or None if no strategy produced a match.
    """
    import re

    label_text = str(label)
    pname_text = str(pname)
    names = list(params.keys())

    # Fast exact candidates
    candidates = (
        f"{pname_text}_{label_text}", f"{label_text}_{pname_text}",
        f"{pname_text}:{label_text}", f"{label_text}:{pname_text}",
        f"{label_text}.{pname_text}", f"{label_text}|{pname_text}",
        f"{label_text}-{pname_text}", f"{pname_text}-{label_text}",
    )
    for candidate in candidates:
        if candidate in params:
            return candidate

    # Regex fallback: label and pname as whole tokens, in either order.
    # Token boundaries are zero-width lookarounds so that one separator can
    # close the first token and open the second (e.g. 'label/pname').
    sep = r"[.:/_\-|]"  # common separators, including '|' used above
    starts = rf"(?:^|(?<={sep}))"
    ends = rf"(?={sep}|$)"
    label_tok = f"{starts}{re.escape(label_text)}{ends}"
    pname_tok = f"{starts}{re.escape(pname_text)}{ends}"
    pattern = re.compile(
        f"{label_tok}.*{pname_tok}|{pname_tok}.*{label_tok}"
    )
    for name in names:
        if pattern.search(name):
            return name

    # Last resort: unique tail match on pname that also contains the label
    tails = [
        name for name in names
        if name.endswith(pname_text) and label_text in name
    ]
    if len(tails) == 1:
        return tails[0]

    # Give up
    return None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def build_distributions(distributions, parameters):
    r"""Write parameter values back onto the given distribution objects.

    For every distribution whose ``class_name`` is one of 'Constant',
    'Linear', 'SpectralLinear' or 'SpectralQuadratic', each of its fit
    attributes is overwritten with
    ``parameters['<attribute>_<label>'].value``. Distributions with any
    other ``class_name`` are left untouched.

    Parameters
    ----------
    distributions : iterable
        Objects exposing ``class_name`` and ``label``; modified in place.
    parameters : mapping
        Mapping from parameter name to an object with a ``value`` attribute.

    Returns
    -------
    distributions
        The same object that was passed in.
    """
    spectral_fields = ('amplitude', 'peak', 'broadening')
    fields_by_class = {
        'Constant': ('offset',),
        'Linear': ('offset', 'slope'),
        'SpectralLinear': spectral_fields,
        'SpectralQuadratic': spectral_fields,
    }

    for component in distributions:
        for field in fields_by_class.get(component.class_name, ()):
            key = field + '_' + component.label
            setattr(component, field, parameters[key].value)

    return distributions
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def construct_parameters(distribution_list, matrix_args=None):
    r"""Create an ``lmfit.Parameters`` object from a list of distributions.

    One parameter named ``'<attribute>_<label>'`` is added per fit attribute
    of each distribution whose ``class_name`` is 'Constant', 'Linear',
    'SpectralLinear' or 'SpectralQuadratic'; its initial value is the current
    attribute value. 'amplitude' and 'broadening' are added with ``min=0``.
    Distributions with another ``class_name`` contribute nothing.

    Parameters
    ----------
    distribution_list : iterable
        Objects exposing ``class_name``, ``label`` and the attributes listed
        above.
    matrix_args : mapping, optional
        Extra ``name -> initial value`` entries, each added as a parameter.

    Returns
    -------
    parameters : lmfit.Parameters
        Returned on its own when `matrix_args` is None.
    parameters, element_names : tuple
        Returned when `matrix_args` is given; `element_names` is the list of
        keys of `matrix_args` in iteration order.
    """
    from lmfit import Parameters

    # Each entry: (attribute name, extra keyword arguments for add()).
    free = {}
    non_negative = {'min': 0}
    spectral_layout = (
        ('amplitude', non_negative),
        ('peak', free),
        ('broadening', non_negative),
    )
    layout_by_class = {
        'Constant': (('offset', free),),
        'Linear': (('offset', free), ('slope', free)),
        'SpectralLinear': spectral_layout,
        'SpectralQuadratic': spectral_layout,
    }

    parameters = Parameters()
    for dist in distribution_list:
        for attribute, bounds in layout_by_class.get(dist.class_name, ()):
            parameters.add(name=attribute + '_' + dist.label,
                           value=getattr(dist, attribute), **bounds)

    if matrix_args is None:
        return parameters

    element_names = []
    for key, value in matrix_args.items():
        parameters.add(name=key, value=value)
        element_names.append(key)
    return parameters, element_names
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def residual(parameters, xdata, ydata, angle_resolution, new_distributions,
             kinetic_energy, hnuminphi, matrix_element=None,
             element_names=None):
    r"""Return the difference between the broadened model and ``ydata``.

    The model is the sum of all distributions evaluated on a padded copy of
    ``xdata`` (built by ``extend_function``), smoothed with a Gaussian filter
    and trimmed back by the padding width on each side.

    Parameters
    ----------
    parameters : mapping
        Current fit parameters; entries are read as
        ``parameters[name].value`` (presumably an ``lmfit.Parameters``
        instance -- confirm against the caller).
    xdata : ndarray
        Abscissa values the data were sampled on.
    ydata : ndarray
        Measured values, subtracted from the model.
    angle_resolution : float
        Resolution handed to ``extend_function`` together with ``xdata``.
    new_distributions : iterable
        Distribution objects; updated in place from `parameters` by
        ``build_distributions`` before they are evaluated.
    kinetic_energy, hnuminphi
        Forwarded unchanged to ``evaluate`` of distributions whose class
        name is 'SpectralQuadratic'; unused for all other distributions.
    matrix_element : callable, optional
        Called as ``matrix_element(extend, **kwargs)``; the result multiplies
        the contribution of every ``Dispersion`` instance.
    element_names : iterable of str, optional
        Names of the entries of `parameters` passed as keyword arguments to
        `matrix_element`. Names missing from `parameters` are skipped. None
        is treated as "no keyword arguments".

    Returns
    -------
    ndarray
        ``model - ydata``.
    """
    from scipy.ndimage import gaussian_filter
    from xarpes.distributions import Dispersion

    matrix_parameters = {}
    if matrix_element is not None:
        # element_names defaults to None; iterating None would raise.
        wanted = () if element_names is None else element_names
        for name in wanted:
            if name in parameters:
                matrix_parameters[name] = parameters[name].value

    new_distributions = build_distributions(new_distributions, parameters)

    extend, step, numb = extend_function(xdata, angle_resolution)

    model = np.zeros_like(extend)

    for dist in new_distributions:
        class_name = getattr(dist, 'class_name', type(dist).__name__)
        if class_name == 'SpectralQuadratic':
            part = dist.evaluate(extend, kinetic_energy, hnuminphi)
        else:
            part = dist.evaluate(extend)

        if (matrix_element is not None) and isinstance(dist, Dispersion):
            # Out-of-place product: do not mutate the array that
            # `evaluate` returned.
            part = part * matrix_element(extend, **matrix_parameters)

        model += part

    # Smooth on the padded grid, then drop `numb` points on each side. A
    # plain `-numb` stop would give an empty slice when numb == 0.
    model = gaussian_filter(model, sigma=step)[numb:-numb if numb else None]
    return model - ydata
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def extend_function(abscissa_range, abscissa_resolution):
    r"""Pad a sampled abscissa on both sides for a later convolution.

    The spacing is taken from the first two samples of `abscissa_range`;
    the padded grid is built with ``np.linspace`` from the first and last
    samples, the number of padding points and the original length.

    Parameters
    ----------
    abscissa_range : ndarray
        Sampled abscissa; needs at least two entries.
        NOTE(review): the linspace construction presumably assumes uniform
        spacing -- confirm against callers.
    abscissa_resolution : float
        Resolution in the units of `abscissa_range`.
        NOTE(review): the division by ``fwhm_to_std`` suggests this is a
        FWHM -- confirm against ``constants.py``.

    Returns
    -------
    extend : ndarray
        Padded abscissa with ``len(abscissa_range) + 2 * numb`` points.
    step : float
        ``abscissa_resolution / (spacing * fwhm_to_std)``; callers in this
        file pass it as ``sigma`` to ``gaussian_filter``.
    numb : int
        Number of points added on each side, ``int(sigma_extend * step)``.
    """
    first = abscissa_range[0]
    last = abscissa_range[-1]
    step_size = np.abs(abscissa_range[1] - first)

    step = abscissa_resolution / (step_size * fwhm_to_std)
    numb = int(sigma_extend * step)

    margin = numb * step_size
    n_points = len(abscissa_range) + 2 * numb
    extend = np.linspace(first - margin, last + margin, n_points)
    return extend, step, numb
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def error_function(p, xdata, ydata, function, resolution, yerr, extra_args):
    r"""The error function used inside the fit_leastsq function.

    Parameters
    ----------
    p : ndarray
        Array of parameters during the optimization.
    xdata : ndarray
        Abscissa values the function is evaluated on.
    ydata : ndarray
        Measured values to compare to.
    function : callable
        Function or class with __call__ method to evaluate; called as
        ``function(x, *p, *extra_args)``.
    resolution : float or None
        Convolution resolution. When falsy, no convolution is applied;
        otherwise it is passed to ``extend_function`` and the model is
        smoothed with a Gaussian filter on the padded abscissa.
    yerr : ndarray
        Standard deviations of ydata.
    extra_args : tuple
        Additional arguments passed to function.

    Returns
    -------
    residual : ndarray
        Normalized residuals ``(model - ydata) / yerr``.
    """
    from scipy.ndimage import gaussian_filter

    if not resolution:
        model = function(xdata, *p, *extra_args)
    else:
        padded, sigma_points, pad = extend_function(xdata, resolution)
        smoothed = gaussian_filter(function(padded, *p, *extra_args),
                                   sigma=sigma_points)
        # With pad == 0 a stop of -0 would empty the slice; use None.
        stop = -pad if pad else None
        model = smoothed[pad:stop]

    return (model - ydata) / yerr
|
|
31
186
|
|
|
32
187
|
|
|
33
|
-
def fit_leastsq(p0, xdata, ydata, function,
|
|
34
|
-
|
|
188
|
+
def fit_leastsq(p0, xdata, ydata, function, resolution=None,
                yerr=None, *extra_args):
    r"""Wrapper around scipy.optimize.leastsq.

    Parameters
    ----------
    p0 : ndarray
        Initial guess for parameters to be optimized.
    xdata : ndarray
        Abscissa values the function is evaluated on.
    ydata : ndarray
        Measured values to compare to.
    function : callable
        Function or class with __call__ method to evaluate.
    resolution : float or None, optional
        Convolution resolution, forwarded to ``error_function``.
    yerr : ndarray or None, optional
        Standard deviations of ydata. Defaults to ones if None.
    extra_args : tuple
        Additional arguments passed to the function.

    Returns
    -------
    pfit_leastsq : ndarray
        Optimized parameters.
    pcov : ndarray or float
        Covariance matrix returned by ``leastsq`` multiplied by the sum of
        squared residuals per degree of freedom. np.inf is returned when
        ``leastsq`` yields no covariance or when there are not more data
        points than parameters.
    """
    from scipy.optimize import leastsq

    weights = np.ones_like(ydata) if yerr is None else yerr
    shared_args = (xdata, ydata, function, resolution, weights, extra_args)

    pfit, pcov, _infodict, _errmsg, _success = leastsq(
        error_function, p0, args=shared_args, full_output=1
    )

    degrees_of_freedom = len(ydata) - len(p0)
    if degrees_of_freedom <= 0 or pcov is None:
        return pfit, np.inf

    # Scale the covariance by the residual variance.
    squared_sum = (error_function(pfit, *shared_args) ** 2).sum()
    pcov *= squared_sum / degrees_of_freedom
    return pfit, pcov
|
|
77
239
|
|
|
78
|
-
return pfit_leastsq, perr_leastsq
|
|
79
|
-
|
|
80
240
|
|
|
81
241
|
def download_examples():
|
|
82
|
-
"""Downloads the examples folder from the xARPES
|
|
83
|
-
already exist. Prints executed steps and a
|
|
84
|
-
|
|
242
|
+
"""Downloads the examples folder from the main xARPES repository only if it
|
|
243
|
+
does not already exist in the current directory. Prints executed steps and a
|
|
244
|
+
final cleanup/failure message.
|
|
245
|
+
|
|
85
246
|
Returns
|
|
86
247
|
-------
|
|
87
|
-
0
|
|
248
|
+
0 or 1 : int
|
|
88
249
|
Returns 0 if the execution succeeds, 1 if it fails.
|
|
89
250
|
"""
|
|
90
251
|
import requests
|
|
@@ -92,22 +253,25 @@ def download_examples():
|
|
|
92
253
|
import os
|
|
93
254
|
import shutil
|
|
94
255
|
import io
|
|
95
|
-
|
|
96
|
-
|
|
256
|
+
import jupytext
|
|
257
|
+
|
|
258
|
+
# Main xARPES repo (examples now live in /examples here)
|
|
259
|
+
repo_url = 'https://github.com/xARPES/xARPES'
|
|
97
260
|
output_dir = '.' # Directory from which the function is called
|
|
98
|
-
|
|
99
|
-
#
|
|
261
|
+
|
|
262
|
+
# Target 'examples' directory in the user's current location
|
|
100
263
|
final_examples_path = os.path.join(output_dir, 'examples')
|
|
101
264
|
if os.path.exists(final_examples_path):
|
|
102
|
-
print("Warning: 'examples' folder already exists.
|
|
103
|
-
|
|
265
|
+
print("Warning: 'examples' folder already exists. "
|
|
266
|
+
"No download will be performed.")
|
|
267
|
+
return 1 # Exit the function if 'examples' directory exists
|
|
104
268
|
|
|
105
269
|
# Proceed with download if 'examples' directory does not exist
|
|
106
|
-
repo_parts = repo_url.replace(
|
|
107
|
-
zip_url = f
|
|
270
|
+
repo_parts = repo_url.replace('https://github.com/', '').rstrip('/')
|
|
271
|
+
zip_url = f'https://github.com/{repo_parts}/archive/refs/heads/main.zip'
|
|
108
272
|
|
|
109
273
|
# Make the HTTP request to download the zip file
|
|
110
|
-
print(f
|
|
274
|
+
print(f'Downloading {zip_url}')
|
|
111
275
|
response = requests.get(zip_url)
|
|
112
276
|
if response.status_code == 200:
|
|
113
277
|
zip_file_bytes = io.BytesIO(response.content)
|
|
@@ -115,21 +279,59 @@ def download_examples():
|
|
|
115
279
|
with zipfile.ZipFile(zip_file_bytes, 'r') as zip_ref:
|
|
116
280
|
zip_ref.extractall(output_dir)
|
|
117
281
|
|
|
118
|
-
# Path to the extracted main folder
|
|
119
|
-
main_folder_path = os.path.join(
|
|
282
|
+
# Path to the extracted main folder (e.g. xARPES-main)
|
|
283
|
+
main_folder_path = os.path.join(
|
|
284
|
+
output_dir,
|
|
285
|
+
repo_parts.split('/')[-1] + '-main'
|
|
286
|
+
)
|
|
120
287
|
examples_path = os.path.join(main_folder_path, 'examples')
|
|
121
288
|
|
|
122
|
-
# Move the 'examples' directory to the target location
|
|
289
|
+
# Move the 'examples' directory to the target location
|
|
123
290
|
if os.path.exists(examples_path):
|
|
124
291
|
shutil.move(examples_path, final_examples_path)
|
|
125
292
|
print(f"'examples' subdirectory moved to {final_examples_path}")
|
|
126
|
-
|
|
127
|
-
|
|
293
|
+
|
|
294
|
+
# Convert all .Rmd files in the examples directory to .ipynb
|
|
295
|
+
# and delete the .Rmd files
|
|
296
|
+
for dirpath, dirnames, filenames in os.walk(final_examples_path):
|
|
297
|
+
for filename in filenames:
|
|
298
|
+
if filename.endswith('.Rmd'):
|
|
299
|
+
full_path = os.path.join(dirpath, filename)
|
|
300
|
+
jupytext.write(
|
|
301
|
+
jupytext.read(full_path),
|
|
302
|
+
full_path.replace('.Rmd', '.ipynb')
|
|
303
|
+
)
|
|
304
|
+
os.remove(full_path) # Deletes .Rmd file afterwards
|
|
305
|
+
print(f'Converted and deleted {full_path}')
|
|
128
306
|
|
|
129
307
|
# Remove the rest of the extracted content
|
|
130
308
|
shutil.rmtree(main_folder_path)
|
|
131
|
-
print(f
|
|
309
|
+
print(f'Cleaned up temporary files in {main_folder_path}')
|
|
132
310
|
return 0
|
|
133
311
|
else:
|
|
134
|
-
print(
|
|
135
|
-
|
|
312
|
+
print('Failed to download the repository. Status code: '
|
|
313
|
+
f'{response.status_code}')
|
|
314
|
+
return 1
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def set_script_dir():
    r"""Return the directory that the calling code should treat as its own.

    This lets xARPES code be executed either inside IPython environments or
    as .py scripts from arbitrary locations.

    Returns
    -------
    script_dir : str
        The current working directory when the name ``get_ipython`` is
        defined (also when calling it fails for a reason other than the name
        being missing). Otherwise the directory containing the file of the
        module that called this function; if that module or its ``__file__``
        cannot be determined, the current working directory.
    """
    import os
    import inspect

    try:
        # This block checks if the script is running in an IPython
        # environment: the name lookup raises NameError otherwise.
        get_ipython().config
    except NameError:
        pass
    except Exception:
        # get_ipython exists but misbehaved; keep the cwd fallback.
        return os.getcwd()
    else:
        return os.getcwd()

    # Not in IPython: use the caller's file location. This lookup needs its
    # own try block, because an exception raised inside an `except` clause
    # is not handled by sibling clauses of the same `try` statement.
    try:
        caller = inspect.stack()[1]
        module = inspect.getmodule(caller[0])
        return os.path.dirname(os.path.abspath(module.__file__))
    except (AttributeError, IndexError, TypeError):
        # No resolvable module or no usable __file__ (e.g. code run through
        # exec): fall back to the current working directory.
        return os.getcwd()
|
xarpes/plotting.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C)
|
|
1
|
+
# Copyright (C) 2025 xARPES Developers
|
|
2
2
|
# This program is free software under the terms of the GNU GPLv3 license.
|
|
3
3
|
|
|
4
4
|
# get_ax_fig_plt and add_fig_kwargs originate from pymatgen/util/plotting.py.
|
|
@@ -18,15 +18,17 @@ import matplotlib as mpl
|
|
|
18
18
|
def plot_settings(name='default'):
    """Apply the package's matplotlib rc settings.

    Parameters
    ----------
    name : {'default', 'large'}
        Selects the line width (2.0 or 4.0). Any other value raises
        KeyError after the tick and legend settings have been applied.
    """
    tick_style = dict(labelsize=10, direction='in')
    for axis in ('xtick', 'ytick'):
        mpl.rc(axis, **tick_style)
    plt.rcParams['legend.frameon'] = False

    line_widths = {'default': 2.0, 'large': 4.0}
    settings = {
        'lines.linewidth': line_widths[name],
        'lines.markersize': 3,
        'xtick.major.size': 4,
        'xtick.minor.size': 2,
        'xtick.major.width': 0.8,
        'font.size': 16,
        'axes.ymargin': 0.15,
    }
    mpl.rcParams.update(settings)
|
|
30
32
|
|
|
31
33
|
def get_ax_fig_plt(ax=None, **kwargs):
|
|
32
34
|
r"""Helper function used in plot functions supporting an optional `Axes`
|
|
@@ -59,12 +61,14 @@ def get_ax_fig_plt(ax=None, **kwargs):
|
|
|
59
61
|
|
|
60
62
|
return ax, fig, plt
|
|
61
63
|
|
|
64
|
+
|
|
62
65
|
def add_fig_kwargs(func):
|
|
63
66
|
"""Decorator that adds keyword arguments for functions returning matplotlib
|
|
64
67
|
figures.
|
|
65
68
|
|
|
66
|
-
The function should return either a matplotlib figure or
|
|
67
|
-
some sort of
|
|
69
|
+
The function should return either a matplotlib figure or a tuple, where the
|
|
70
|
+
first element is a matplotlib figure, or None to signal some sort of
|
|
71
|
+
error/unexpected event.
|
|
68
72
|
"""
|
|
69
73
|
@wraps(func)
|
|
70
74
|
def wrapper(*args, **kwargs):
|
|
@@ -76,14 +80,26 @@ def add_fig_kwargs(func):
|
|
|
76
80
|
tight_layout = kwargs.pop('tight_layout', False)
|
|
77
81
|
ax_grid = kwargs.pop('ax_grid', None)
|
|
78
82
|
ax_annotate = kwargs.pop('ax_annotate', None)
|
|
79
|
-
fig_close = kwargs.pop('fig_close',
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
+
fig_close = kwargs.pop('fig_close', True)
|
|
84
|
+
|
|
85
|
+
import string
|
|
86
|
+
|
|
87
|
+
# Call the original function
|
|
88
|
+
result = func(*args, **kwargs)
|
|
89
|
+
|
|
90
|
+
# Determine if result is a figure or tuple with first element as figure
|
|
91
|
+
if isinstance(result, tuple):
|
|
92
|
+
fig = result[0]
|
|
93
|
+
rest = result[1:]
|
|
94
|
+
else:
|
|
95
|
+
fig = result
|
|
96
|
+
rest = None
|
|
97
|
+
|
|
98
|
+
# Return immediately if no figure is returned
|
|
83
99
|
if fig is None:
|
|
84
|
-
return
|
|
100
|
+
return result
|
|
85
101
|
|
|
86
|
-
# Operate on matplotlib figure
|
|
102
|
+
# Operate on the matplotlib figure
|
|
87
103
|
if title is not None:
|
|
88
104
|
fig.suptitle(title)
|
|
89
105
|
|
|
@@ -96,9 +112,10 @@ def add_fig_kwargs(func):
|
|
|
96
112
|
ax.grid(bool(ax_grid))
|
|
97
113
|
|
|
98
114
|
if ax_annotate:
|
|
99
|
-
tags = ascii_letters
|
|
115
|
+
tags = string.ascii_letters
|
|
100
116
|
if len(fig.axes) > len(tags):
|
|
101
|
-
tags = (1 + len(ascii_letters) // len(fig.axes)) *
|
|
117
|
+
tags = (1 + len(string.ascii_letters) // len(fig.axes)) * \
|
|
118
|
+
string.ascii_letters
|
|
102
119
|
for ax, tag in zip(fig.axes, tags):
|
|
103
120
|
ax.annotate(f'({tag})', xy=(0.05, 0.95),
|
|
104
121
|
xycoords='axes fraction')
|
|
@@ -107,10 +124,8 @@ def add_fig_kwargs(func):
|
|
|
107
124
|
try:
|
|
108
125
|
fig.tight_layout()
|
|
109
126
|
except Exception as exc:
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
print('Ignoring Exception raised by fig.tight_layout\n',
|
|
113
|
-
str(exc))
|
|
127
|
+
print('Ignoring Exception raised by fig.tight_layout ' +
|
|
128
|
+
'\n', str(exc))
|
|
114
129
|
|
|
115
130
|
if savefig:
|
|
116
131
|
fig.savefig(savefig)
|
|
@@ -120,12 +135,16 @@ def add_fig_kwargs(func):
|
|
|
120
135
|
if fig_close:
|
|
121
136
|
plt.close(fig=fig)
|
|
122
137
|
|
|
123
|
-
return
|
|
138
|
+
# Reassemble the tuple if necessary and return
|
|
139
|
+
if rest is not None:
|
|
140
|
+
return (fig, *rest)
|
|
141
|
+
else:
|
|
142
|
+
return fig
|
|
124
143
|
|
|
125
144
|
# Add docstring to the decorated method.
|
|
126
145
|
doc_str = """\n\n
|
|
127
146
|
|
|
128
|
-
|
|
147
|
+
Notes
|
|
129
148
|
-----
|
|
130
149
|
|
|
131
150
|
Keyword arguments controlling the display of the figure:
|
|
@@ -142,18 +161,16 @@ def add_fig_kwargs(func):
|
|
|
142
161
|
ax_grid True (False) to add (remove) grid from all axes in
|
|
143
162
|
fig.
|
|
144
163
|
Default: None i.e. fig is left unchanged.
|
|
145
|
-
ax_annotate Add labels to
|
|
164
|
+
ax_annotate Add labels to subplots e.g. (a), (b).
|
|
146
165
|
Default: False
|
|
147
|
-
fig_close Close figure. Default:
|
|
166
|
+
fig_close Close figure. Default: True.
|
|
148
167
|
================ ====================================================
|
|
149
168
|
|
|
150
169
|
"""
|
|
151
170
|
|
|
152
171
|
if wrapper.__doc__ is not None:
|
|
153
|
-
# Add s at the end of the docstring.
|
|
154
172
|
wrapper.__doc__ += f'\n{doc_str}'
|
|
155
173
|
else:
|
|
156
|
-
# Use s
|
|
157
174
|
wrapper.__doc__ = doc_str
|
|
158
175
|
|
|
159
176
|
return wrapper
|