ChessAnalysisPipeline 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. Click here for more details.

CHAP/utils/general.py CHANGED
@@ -31,6 +31,52 @@ except ImportError:
31
31
 
32
32
  logger = getLogger(__name__)
33
33
 
34
# Minimal absolute value used by not_zero: the approximate decimal
# resolution of a double precision float (10**-precision)
tiny = np.finfo(np.double).resolution
35
+
36
def gformat(val, length=11):
    """
    Format a number with '%g'-like format, while:
      - the length of the output string will be of the requested length
      - positive numbers will have a leading blank
      - the precision will be as high as possible
      - trailing zeros will not be trimmed

    :param val: The value to format. `None`, booleans and values for
        which `numpy.log10` raises a TypeError are returned as their
        right-aligned `repr`.
    :type val: numbers.Number
    :param length: The requested output length (a minimum of 7 is
        used for numbers), defaults to `11`.
    :type length: int, optional
    :return: The formatted value.
    :rtype: str
    """
    # Code taken from lmfit library
    if val is None or isinstance(val, bool):
        return f'{repr(val):>{length}s}'
    try:
        # np.log10(0) returns -inf with a "divide by zero" RuntimeWarning;
        # silence it: int(-inf) raises OverflowError, handled below
        with np.errstate(divide='ignore'):
            expon = int(np.log10(abs(val)))
    except (OverflowError, ValueError):
        # Zero, inf or nan: no meaningful exponent
        expon = 0
    except TypeError:
        return f'{repr(val):>{length}s}'

    length = max(length, 7)
    form = 'e'
    prec = length - 7
    if abs(expon) > 99:
        # A three-digit exponent takes one extra character
        prec -= 1
    elif ((expon > 0 and expon < (prec+4)) or
          (expon <= 0 and -expon < (prec-1))):
        form = 'f'
        prec += 4
        if expon > 0:
            prec -= expon
    # A negative precision is an invalid format specification
    # (ValueError); it occurs for length 7 with a three-digit exponent
    prec = max(prec, 0)
    return f'{val:{length}.{prec}{form}}'
+
67
+
68
def getfloat_attr(obj, attr, length=11):
    """Format an attribute of an object for printing.

    :param obj: The object holding the attribute.
    :type obj: object
    :param attr: The attribute name.
    :type attr: str
    :param length: The field length passed on to `gformat` for
        floating point values, defaults to `11`.
    :type length: int, optional
    :return: `'unknown'` if the attribute is missing or `None`, the
        plain string of an integer, the stripped `gformat` output of a
        floating point value, or the `repr` of any other value.
    :rtype: str
    """
    # Code taken from lmfit library
    val = getattr(obj, attr, None)
    if val is None:
        return 'unknown'
    # NumPy scalars are not subclasses of int (np.integer) or always of
    # float (e.g. np.float32), so accept them explicitly instead of
    # falling through to repr (e.g. 'np.int64(3)' under NumPy 2)
    if isinstance(val, (int, np.integer)):
        return f'{val}'
    if isinstance(val, (float, np.floating)):
        return gformat(val, length=length).strip()
    return repr(val)
+
34
80
 
35
81
  def depth_list(_list):
36
82
  """Return the depth of a list."""
@@ -86,6 +132,13 @@ def illegal_combination(
86
132
  raise ValueError(error_msg)
87
133
 
88
134
 
135
def not_zero(value):
    """Return `value` as a float whose magnitude is at least `tiny`.

    The sign of `value` is kept, so the result is never zero.
    """
    magnitude = abs(value)
    # Same selection as max(tiny, magnitude): keep tiny unless the
    # magnitude is strictly larger (which also maps nan to tiny)
    floor = magnitude if magnitude > tiny else tiny
    return float(np.copysign(floor, value))
140
+
141
+
89
142
  def test_ge_gt_le_lt(
90
143
  ge, gt, le, lt, func, location=None, raise_error=False, log=True):
91
144
  """
@@ -447,8 +500,8 @@ def index_nearest_down(a, value):
447
500
  return index
448
501
 
449
502
 
450
- def index_nearest_upp(a, value):
451
- """Return index of nearest array value, rounded upp."""
503
+ def index_nearest_up(a, value):
504
+ """Return index of nearest array value, rounded up."""
452
505
  a = np.asarray(a)
453
506
  if a.ndim > 1:
454
507
  raise ValueError(
@@ -477,14 +530,14 @@ def get_consecutive_int_range(a):
477
530
 
478
531
 
479
532
def round_to_n(x, n=1):
    """Round to a specific number of significant figures.

    :param x: The value to round.
    :type x: int, float
    :param n: The number of significant figures, defaults to `1`.
    :type n: int, optional
    :return: `x` rounded to `n` significant figures, of the same type
        as `x`. Zero and non-finite values (inf, nan) are returned
        unchanged.
    """
    # Zero has no order of magnitude and the log10 of inf/nan cannot be
    # converted to an int, so return these as is (preserving the type)
    if x == 0 or not np.isfinite(x):
        return x
    return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x))))))
484
537
 
485
538
 
486
539
  def round_up_to_n(x, n=1):
487
- """Round up to a specific number of decimals."""
540
+ """Round up to a specific number of sig figs."""
488
541
  x_round = round_to_n(x, n)
489
542
  if abs(x/x_round) > 1.0:
490
543
  x_round += np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
@@ -492,7 +545,7 @@ def round_up_to_n(x, n=1):
492
545
 
493
546
 
494
547
  def trunc_to_n(x, n=1):
495
- """Truncate to a specific number of decimals."""
548
+ """Truncate to a specific number of sig figs."""
496
549
  x_round = round_to_n(x, n)
497
550
  if abs(x_round/x) > 1.0:
498
551
  x_round -= np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
@@ -510,7 +563,9 @@ def almost_equal(a, b, sig_figs):
510
563
  f'b: {b}, {type(b)})')
511
564
 
512
565
 
513
- def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
566
+ def string_to_list(
567
+ s, split_on_dash=True, remove_duplicates=True, sort=True,
568
+ raise_error=False):
514
569
  """
515
570
  Return a list of numbers by splitting/expanding a string on any
516
571
  combination of commas, whitespaces, or dashes (when
@@ -524,8 +579,11 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
524
579
  return []
525
580
  try:
526
581
  list1 = re_split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
527
- except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
528
- return None
582
+ except (ValueError, TypeError, SyntaxError, MemoryError,
583
+ RecursionError) as e:
584
+ if not raise_error:
585
+ return None
586
+ raise e
529
587
  if split_on_dash:
530
588
  try:
531
589
  l_of_i = []
@@ -540,8 +598,10 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
540
598
  else:
541
599
  raise ValueError
542
600
  except (ValueError, TypeError, SyntaxError, MemoryError,
543
- RecursionError):
544
- return None
601
+ RecursionError) as e:
602
+ if not raise_error:
603
+ return None
604
+ raise e
545
605
  else:
546
606
  l_of_i = [literal_eval(x) for x in list1]
547
607
  if remove_duplicates:
@@ -902,11 +962,255 @@ def file_exists_and_readable(f):
902
962
  return f
903
963
 
904
964
 
965
def rolling_average(
        y, x=None, dtype=None, start=0, end=None, width=None,
        stride=None, num=None, average=True, mode='valid',
        use_convolve=None):
    """
    Returns the rolling sum or average of an array over the last
    dimension.

    :param y: The data, windows are taken over its last dimension
        (all leading dimensions are preserved in the result).
    :type y: array-like
    :param x: The 1D coordinates along the last dimension of `y`; when
        given, the average of `x` over each window is returned as
        well, defaults to `None`.
    :type x: array-like, optional
    :param dtype: The data type of the averages (and of the rescaled
        sums in 'full' mode when convolving), defaults to the data
        type of `y` when averaging and to `numpy.float32` otherwise.
        NOTE(review): plain sums from the 'valid' mode convolution and
        from the loop-based path do not have `dtype` applied; confirm
        whether that is intended.
    :type dtype: numpy.dtype, optional
    :param start: The first index along the last dimension to
        include, defaults to `0`.
    :type start: int, optional
    :param end: The index past the last one to include, defaults to
        the size of the last dimension.
    :type end: int, optional
    :param width: The number of points in each window.
    :type width: int, optional
    :param stride: The step between the starts of consecutive
        windows, defaults to `width` or is derived from `num`.
    :type stride: int, optional
    :param num: The targeted number of windows, only used to derive
        `width` and/or `stride` when those are not given; the actual
        number of windows follows from `width`, `stride` and `mode`.
    :type num: int, optional
    :param average: Return window averages (`True`) or sums
        (`False`), defaults to `True`.
    :type average: bool, optional
    :param mode: 'valid' to only return complete windows, 'full' to
        also return the trailing partial windows (whose sums are
        scaled up to a full window), defaults to `'valid'`.
    :type mode: str, optional
    :param use_convolve: Compute the sums with `numpy.convolve`,
        defaults to `True` for 1D `y` and `False` otherwise.
    :type use_convolve: bool, optional
    :raises ValueError: For invalid input parameters.
    :return: The rolling sums/averages, and the window averages of
        `x` when `x` is given.
    :rtype: numpy.ndarray [, numpy.ndarray]
    """
    y = np.asarray(y)
    y_shape = y.shape
    if y.ndim == 1:
        y = np.expand_dims(y, 0)
    else:
        # Collapse all leading dimensions into one
        y = y.reshape((np.prod(y.shape[0:-1]), y.shape[-1]))
    if x is not None:
        x = np.asarray(x)
        if x.ndim != 1:
            raise ValueError('Parameter "x" must be a 1D array-like')
        if x.size != y.shape[1]:
            raise ValueError(f'Dimensions of "x" and "y[1]" do not '
                             f'match ({x.size} vs {y.shape[1]})')
    if dtype is None:
        if average:
            dtype = y.dtype
        else:
            dtype = np.float32
    if width is None and stride is None and num is None:
        raise ValueError('Invalid input parameters, specify at least one of '
                         '"width", "stride" or "num"')
    if width is not None and not is_int(width, ge=1):
        raise ValueError(f'Invalid "width" parameter ({width})')
    if stride is not None and not is_int(stride, ge=1):
        raise ValueError(f'Invalid "stride" parameter ({stride})')
    if num is not None and not is_int(num, ge=1):
        raise ValueError(f'Invalid "num" parameter ({num})')
    if not isinstance(average, bool):
        raise ValueError(f'Invalid "average" parameter ({average})')
    if mode not in ('valid', 'full'):
        raise ValueError(f'Invalid "mode" parameter ({mode})')
    size = y.shape[1]
    if size < 2:
        raise ValueError(f'Invalid y[1] dimension ({size})')
    if not is_int(start, ge=0, lt=size):
        raise ValueError(f'Invalid "start" parameter ({start})')
    if end is None:
        end = size
    elif not is_int(end, gt=start, le=size):
        raise ValueError(f'Invalid "end" parameter ({end})')
    if use_convolve is None:
        use_convolve = len(y_shape) == 1
    if use_convolve and (start or end < size):
        # The convolution works on the selected range only
        y = np.take(y, range(start, end), axis=1)
        if x is not None:
            x = x[start:end]
        size = y.shape[1]
    else:
        size = end-start

    # Resolve the window width and stride
    if stride is None:
        if width is None:
            width = max(1, int(size/num))
            stride = width
        else:
            width = min(width, size)
            if num is None:
                stride = width
            elif num == 1:
                # A single window: any stride beyond the data size gives
                # exactly one window (and avoids dividing by num-1 == 0)
                stride = size
            else:
                stride = max(1, int((size-width) / (num-1)))
    else:
        # NOTE(review): this yields a non-positive stride when
        # stride >= size; confirm the intended clamping
        stride = min(stride, size-stride)
        if width is None:
            width = stride

    # Actual number of windows
    if mode == 'valid':
        num = 1 + max(0, int((size-width) / stride))
    else:
        num = int(size/stride)
        if num*stride < size:
            num += 1

    if use_convolve:
        # Number of data points in each window (only the trailing
        # windows in 'full' mode can hold fewer than width points)
        weight = np.empty(num)
        n_start = 0
        n_end = width
        for n in range(num):
            weight[n] = n_end-n_start
            n_start += stride
            n_end = min(size, n_end+stride)

        # Element width-1+k*stride of the full convolution is the sum
        # over [k*stride, k*stride+width); ending the slice at index
        # size keeps the complete windows only. An explicit end index
        # is required: the negative end 1-width is 0 for width == 1,
        # which would give an empty slice.
        window = np.ones(width)
        last = size if mode == 'valid' else None
        if x is not None:
            rx = np.convolve(x, window)[width-1:last:stride]
            rx /= weight

        ry = [np.convolve(y_row, window)[width-1:last:stride]
              for y_row in y]
        ry = np.reshape(ry, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (np.asarray(ry).astype(np.float32)/weight).astype(dtype)
        elif mode != 'valid':
            # Scale the partial window sums up to a full window
            weight = np.where(weight < width, width/weight, 1.0)
            ry = (np.asarray(ry).astype(np.float32)*weight).astype(dtype)
    else:
        ry = np.zeros((num, y.shape[0]), dtype=y.dtype)
        if x is not None:
            rx = np.zeros(num, dtype=x.dtype)
        n_start = start
        n_end = n_start+width
        for n in range(num):
            y_sum = np.sum(y[:,n_start:n_end], 1)
            n_num = n_end-n_start
            if n_num < width:
                # Scale a partial window sum up to a full window; not
                # done in place since y_sum may be of an integer type
                y_sum = y_sum * (width/n_num)
            ry[n] = y_sum
            if x is not None:
                rx[n] = np.sum(x[n_start:n_end])/n_num
            n_start += stride
            n_end = min(start+size, n_end+stride)
        ry = np.reshape(ry.T, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (ry.astype(np.float32)/width).astype(dtype)

    if x is None:
        return ry
    return ry, rx
+
1107
+
1108
def baseline_arPLS(
        y, mask=None, w=None, tol=1.e-8, lam=1.e6, max_iter=20,
        full_output=False):
    """Returns the smoothed baseline estimate of a spectrum.

    Based on S.-J. Baek, A. Park, Y.-J Ahn, and J. Choo,
    "Baseline correction using asymmetrically reweighted penalized
    least squares smoothing", Analyst, 2015,140, 250-257

    :param y: The spectrum.
    :type y: array-like
    :param mask: A mask to apply to the spectrum before baseline
        construction, defaults to `None`.
    :type mask: array-like, optional
    :param w: The weights (allows restart for additional iterations),
        defaults to `None`.
    :type w: numpy.array, optional
    :param tol: The convergence tolerance, defaults to `1.e-8`.
    :type tol: float, optional
    :param lam: The lambda (smoothness) parameter (the balance
        between the residual of the data and the baseline and the
        smoothness of the baseline). The suggested range is between
        100 and 10^8, defaults to `10^6`.
    :type lam: float, optional
    :param max_iter: The maximum number of iterations,
        defaults to `20`.
    :type max_iter: int, optional
    :param full_output: Whether or not to also output the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result, defaults to `False`.
    :type full_output: bool, optional
    :raises ValueError: For an invalid `tol`, `lam`, `max_iter` or
        `full_output` parameter.
    :return: The smoothed baseline, with optionally the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result.
    :rtype: numpy.array [, numpy.array, numpy.array, int, float]
    """
    # With credit to: Daniel Casas-Orozco
    # https://stackoverflow.com/questions/29156532/python-baseline-correction-library
    # System modules
    from sys import float_info

    # Third party modules
    from scipy.sparse import (
        spdiags,
        linalg,
    )

    if not is_num(tol, gt=0):
        raise ValueError(f'Invalid tol parameter ({tol})')
    if not is_num(lam, gt=0):
        raise ValueError(f'Invalid lam parameter ({lam})')
    if not is_int(max_iter, gt=0):
        raise ValueError(f'Invalid max_iter parameter ({max_iter})')
    if not isinstance(full_output, bool):
        raise ValueError(f'Invalid full_output parameter ({full_output})')
    y = np.asarray(y)
    if mask is not None:
        # Accept any array-like mask (a list has no astype method)
        mask = np.asarray(mask).astype(bool)
        y_org = y
        y = y[mask]
    num = y.size

    # Second order difference matrix and the smoothness penalty term
    diag = np.ones(num-2)
    D = spdiags([diag, -2*diag, diag], [0, -1, -2], num, num-2)
    H = lam * D.dot(D.T)

    if w is None:
        w = np.ones(num)
    W = spdiags(w, 0, num, num)

    # Start above any tolerance so that at least one iteration always
    # runs (z and d are defined by the loop)
    error = np.inf
    num_iter = 0

    # Clip the exponent to avoid an overflow in np.exp
    exp_max = int(np.log(float_info.max))
    while error > tol and num_iter < max_iter:
        # W is a sparse matrix: @ is the matrix-vector product
        z = linalg.spsolve(W + H, W @ y)
        d = y - z
        # Statistics of the negative residuals (points below the
        # current baseline estimate)
        dn = d[d < 0]

        m = np.mean(dn)
        s = np.std(dn)

        # Update the weights with the logistic function of the paper
        w_new = 1.0 / (1.0 + np.exp(
            np.clip(2.0 * (d - (2.0*s - m))/s, None, exp_max)))
        error = np.linalg.norm(w_new - w) / np.linalg.norm(w)
        num_iter += 1
        w = w_new
        W.setdiag(w)

    if mask is not None:
        # Map the baseline back onto the full spectrum (zero outside
        # the mask)
        zz = np.zeros(y_org.size)
        zz[mask] = z
        z = zz
        if full_output:
            d = y_org - z
    if full_output:
        return z, d, w, num_iter, float(error)
    return z
+
1208
+
905
1209
  def select_mask_1d(
906
1210
  y, x=None, label=None, ref_data=[], preselected_index_ranges=None,
907
1211
  preselected_mask=None, title=None, xlabel=None, ylabel=None,
908
1212
  min_num_index_ranges=None, max_num_index_ranges=None,
909
- interactive=True):
1213
+ interactive=True, filename=None):
910
1214
  """Display a lineplot and have the user select a mask.
911
1215
 
912
1216
  :param y: One-dimensional data array for which a mask will be
@@ -945,18 +1249,21 @@ def select_mask_1d(
945
1249
  ranges, defaults to `None`.
946
1250
  :type max_num_index_ranges: int, optional
947
1251
  :param interactive: Show the plot and allow user interactions with
948
- the matplotlib figure, defults to `True`.
1252
+ the matplotlib figure, defaults to `True`.
949
1253
  :type interactive: bool, optional
950
- :return: A Matplotlib figure, a boolean mask array and the list of
951
- selected index ranges.
952
- :rtype: matplotlib.figure.Figure, numpy.ndarray,
953
- list[tuple(int, int)]
1254
+ :param filename: Save a .png of the plot to filename, defaults to
1255
+ `None`, in which case the plot is not saved.
1256
+ :type filename: str, optional
1257
+ :return: A boolean mask array and the list of selected index
1258
+ ranges.
1259
+ :rtype: numpy.ndarray, list[tuple(int, int)]
954
1260
  """
955
1261
  # Third party modules
956
- from matplotlib.patches import Patch
957
- from matplotlib.widgets import Button, SpanSelector
1262
+ if interactive or filename is not None:
1263
+ from matplotlib.patches import Patch
1264
+ from matplotlib.widgets import Button, SpanSelector
958
1265
 
959
- # local modules
1266
+ # Local modules
960
1267
  from CHAP.utils.general import index_nearest
961
1268
 
962
1269
  def change_fig_title(title):
@@ -1111,10 +1418,25 @@ def select_mask_1d(
1111
1418
  raise ValueError('Invalid parameter preselected_index_ranges '
1112
1419
  f'({preselected_index_ranges})')
1113
1420
 
1421
+ # Setup the preselected mask and index ranges if provided
1422
+ if preselected_mask is not None:
1423
+ preselected_index_ranges = update_index_ranges(
1424
+ update_mask(
1425
+ np.copy(np.asarray(preselected_mask, dtype=bool)),
1426
+ preselected_index_ranges))
1427
+
1428
+ if not interactive and filename is None:
1429
+
1430
+ # Update the mask with the preselected index ranges
1431
+ selected_mask = update_mask(len(x)*[False], preselected_index_ranges)
1432
+
1433
+ return selected_mask, preselected_index_ranges
1434
+
1114
1435
  spans = []
1115
1436
  fig_title = []
1116
1437
  error_texts = []
1117
1438
 
1439
+ # Setup the Matplotlib figure
1118
1440
  title_pos = (0.5, 0.95)
1119
1441
  title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1120
1442
  'verticalalignment': 'bottom'}
@@ -1138,12 +1460,7 @@ def select_mask_1d(
1138
1460
  ax.set_xlim(x[0], x[-1])
1139
1461
  fig.subplots_adjust(bottom=0.0, top=0.85)
1140
1462
 
1141
- # Setup the preselected mask and index ranges if provided
1142
- if preselected_mask is not None:
1143
- preselected_index_ranges = update_index_ranges(
1144
- update_mask(
1145
- np.copy(np.asarray(preselected_mask, dtype=bool)),
1146
- preselected_index_ranges))
1463
+ # Add the preselected index ranges
1147
1464
  for min_, max_ in preselected_index_ranges:
1148
1465
  add_span(None, xrange_init=(x[min_], x[min(max_, num_data-1)]))
1149
1466
 
@@ -1158,7 +1475,8 @@ def select_mask_1d(
1158
1475
  fig.subplots_adjust(bottom=0.2)
1159
1476
 
1160
1477
  # Setup "Add span" button
1161
- add_span_btn = Button(plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
1478
+ add_span_btn = Button(
1479
+ plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
1162
1480
  add_span_cid = add_span_btn.on_clicked(add_span)
1163
1481
 
1164
1482
  # Setup "Reset" button
@@ -1188,15 +1506,18 @@ def select_mask_1d(
1188
1506
  # Update the mask with the currently selected index ranges
1189
1507
  selected_mask = update_mask(len(x)*[False], selected_index_ranges)
1190
1508
 
1191
- fig_title[0].set_in_layout(True)
1192
- fig.tight_layout(rect=(0, 0, 1, 0.95))
1509
+ if filename is not None:
1510
+ fig_title[0].set_in_layout(True)
1511
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
1512
+ fig.savefig(filename)
1513
+ plt.close()
1193
1514
 
1194
- return fig, selected_mask, selected_index_ranges
1515
+ return selected_mask, selected_index_ranges
1195
1516
 
1196
1517
 
1197
1518
  def select_roi_1d(
1198
1519
  y, x=None, preselected_roi=None, title=None, xlabel=None, ylabel=None,
1199
- interactive=True):
1520
+ interactive=True, filename=None):
1200
1521
  """Display a 2D plot and have the user select a single region
1201
1522
  of interest.
1202
1523
 
@@ -1217,8 +1538,11 @@ def select_roi_1d(
1217
1538
  defaults to `None`.
1218
1539
  :type ylabel: str, optional
1219
1540
  :param interactive: Show the plot and allow user interactions with
1220
- the matplotlib figure, defults to `True`.
1541
+ the matplotlib figure, defaults to `True`.
1221
1542
  :type interactive: bool, optional
1543
+ :param filename: Save a .png of the plot to filename, defaults to
1544
+ `None`, in which case the plot is not saved.
1545
+ :type filename: str, optional
1222
1546
  :return: The selected region of interest as array indices and a
1223
1547
  matplotlib figure.
1224
1548
  :rtype: matplotlib.figure.Figure, tuple(int, int)
@@ -1233,16 +1557,17 @@ def select_roi_1d(
1233
1557
  f'({preselected_roi})')
1234
1558
  preselected_roi = [preselected_roi]
1235
1559
 
1236
- fig, mask, roi = select_mask_1d(
1560
+ mask, roi = select_mask_1d(
1237
1561
  y, x=x, preselected_index_ranges=preselected_roi, title=title,
1238
1562
  xlabel=xlabel, ylabel=ylabel, min_num_index_ranges=1,
1239
- max_num_index_ranges=1, interactive=interactive)
1563
+ max_num_index_ranges=1, interactive=interactive, filename=filename)
1240
1564
 
1241
- return fig, tuple(roi[0])
1565
+ return tuple(roi[0])
1242
1566
 
1243
1567
  def select_roi_2d(
1244
1568
  a, preselected_roi=None, title=None, title_a=None,
1245
- row_label='row index', column_label='column index', interactive=True):
1569
+ row_label='row index', column_label='column index', interactive=True,
1570
+ filename=None):
1246
1571
  """Display a 2D image and have the user select a single rectangular
1247
1572
  region of interest.
1248
1573
 
@@ -1265,12 +1590,15 @@ def select_roi_2d(
1265
1590
  :param interactive: Show the plot and allow user interactions with
1266
1591
  the matplotlib figure, defaults to `True`.
1267
1592
  :type interactive: bool, optional
1268
- :return: The selected region of interest as array indices and a
1269
- matplotlib figure.
1270
- :rtype: matplotlib.figure.Figure, tuple(int, int, int, int)
1593
+ :param filename: Save a .png of the plot to filename, defaults to
1594
+ `None`, in which case the plot is not saved.
1595
+ :type filename: str, optional
1596
+ :return: The selected region of interest as array indices.
1597
+ :rtype: tuple(int, int, int, int)
1271
1598
  """
1272
1599
  # Third party modules
1273
- from matplotlib.widgets import Button, RectangleSelector
1600
+ if interactive or filename is not None:
1601
+ from matplotlib.widgets import Button, RectangleSelector
1274
1602
 
1275
1603
  # Local modules
1276
1604
  from CHAP.utils.general import index_nearest
@@ -1327,9 +1655,6 @@ def select_roi_2d(
1327
1655
  change_fig_title(f'Selected ROI: {roi}')
1328
1656
  plt.close()
1329
1657
 
1330
- fig_title = []
1331
- subfig_title = []
1332
-
1333
1658
  # Check inputs
1334
1659
  a = np.asarray(a)
1335
1660
  if a.ndim != 2:
@@ -1342,6 +1667,12 @@ def select_roi_2d(
1342
1667
  if title is None:
1343
1668
  title = 'Click and drag to select or adjust a region of interest (ROI)'
1344
1669
 
1670
+ if not interactive and filename is None:
1671
+ return preselected_roi
1672
+
1673
+ fig_title = []
1674
+ subfig_title = []
1675
+
1345
1676
  title_pos = (0.5, 0.95)
1346
1677
  title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1347
1678
  'verticalalignment': 'bottom'}
@@ -1370,8 +1701,9 @@ def select_roi_2d(
1370
1701
 
1371
1702
  if not interactive:
1372
1703
 
1373
- change_fig_title(
1374
- f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')
1704
+ if preselected_roi is not None:
1705
+ change_fig_title(
1706
+ f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')
1375
1707
 
1376
1708
  else:
1377
1709
 
@@ -1400,20 +1732,26 @@ def select_roi_2d(
1400
1732
  reset_btn.ax.remove()
1401
1733
  confirm_btn.ax.remove()
1402
1734
 
1403
- fig_title[0].set_in_layout(True)
1404
- fig.tight_layout(rect=(0, 0, 1, 0.95))
1735
+ if filename is not None:
1736
+ if fig_title:
1737
+ fig_title[0].set_in_layout(True)
1738
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
1739
+ else:
1740
+ fig.tight_layout(rect=(0, 0, 1, 1))
1405
1741
 
1406
- # Remove the handles before returning the figure
1407
- if interactive:
1408
- rects[0]._center_handle.set_visible(False)
1409
- rects[0]._corner_handles.set_visible(False)
1410
- rects[0]._edge_handles.set_visible(False)
1742
+ # Remove the handles
1743
+ if interactive:
1744
+ rects[0]._center_handle.set_visible(False)
1745
+ rects[0]._corner_handles.set_visible(False)
1746
+ rects[0]._edge_handles.set_visible(False)
1747
+ fig.savefig(filename)
1748
+ plt.close()
1411
1749
 
1412
1750
  roi = tuple(int(v) for v in rects[0].extents)
1413
1751
  if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1414
1752
  roi = None
1415
1753
 
1416
- return fig, roi
1754
+ return roi
1417
1755
 
1418
1756
 
1419
1757
  def select_image_indices(
@@ -1918,9 +2256,11 @@ def nxcopy(
1918
2256
  name=nxobject.nxname)
1919
2257
  elif isinstance(nxobject, (NXlink, NXfield)):
1920
2258
  # The top level nxobject is a (linked) field: return a copy
2259
+ attrs = nxobject.attrs
2260
+ attrs.pop('target', None)
1921
2261
  nxobject_copy = NXfield(
1922
2262
  value=nxobject.nxdata, name=nxobject.nxname,
1923
- attrs=nxobject.attrs)
2263
+ attrs=attrs)
1924
2264
  return nxobject_copy
1925
2265
  else:
1926
2266
  # Create a group with the same type/name as the nxobject