ChessAnalysisPipeline 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. Click here for more details.

Files changed (38)
  1. CHAP/__init__.py +1 -1
  2. CHAP/common/__init__.py +13 -0
  3. CHAP/common/models/integration.py +29 -26
  4. CHAP/common/models/map.py +395 -224
  5. CHAP/common/processor.py +1725 -93
  6. CHAP/common/reader.py +265 -28
  7. CHAP/common/writer.py +191 -18
  8. CHAP/edd/__init__.py +9 -2
  9. CHAP/edd/models.py +886 -665
  10. CHAP/edd/processor.py +2592 -936
  11. CHAP/edd/reader.py +889 -0
  12. CHAP/edd/utils.py +846 -292
  13. CHAP/foxden/__init__.py +6 -0
  14. CHAP/foxden/processor.py +42 -0
  15. CHAP/foxden/writer.py +65 -0
  16. CHAP/giwaxs/__init__.py +8 -0
  17. CHAP/giwaxs/models.py +100 -0
  18. CHAP/giwaxs/processor.py +520 -0
  19. CHAP/giwaxs/reader.py +5 -0
  20. CHAP/giwaxs/writer.py +5 -0
  21. CHAP/pipeline.py +48 -10
  22. CHAP/runner.py +161 -72
  23. CHAP/tomo/models.py +31 -29
  24. CHAP/tomo/processor.py +169 -118
  25. CHAP/utils/__init__.py +1 -0
  26. CHAP/utils/fit.py +1292 -1315
  27. CHAP/utils/general.py +411 -53
  28. CHAP/utils/models.py +594 -0
  29. CHAP/utils/parfile.py +10 -2
  30. ChessAnalysisPipeline-0.0.16.dist-info/LICENSE +60 -0
  31. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/METADATA +1 -1
  32. ChessAnalysisPipeline-0.0.16.dist-info/RECORD +62 -0
  33. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/WHEEL +1 -1
  34. CHAP/utils/scanparsers.py +0 -1431
  35. ChessAnalysisPipeline-0.0.14.dist-info/LICENSE +0 -21
  36. ChessAnalysisPipeline-0.0.14.dist-info/RECORD +0 -54
  37. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/entry_points.txt +0 -0
  38. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/top_level.txt +0 -0
CHAP/utils/general.py CHANGED
@@ -31,6 +31,52 @@ except ImportError:
31
31
 
32
32
  logger = getLogger(__name__)
33
33
 
34
# Resolution (approximate decimal precision) of a double precision
# float; used by `not_zero` as the smallest allowed magnitude
tiny = np.finfo(float).resolution
35
+
36
def gformat(val, length=11):
    """
    Format a number with '%g'-like format, while:
      - the length of the output string will be of the requested length
      - positive numbers will have a leading blank
      - the precision will be as high as possible
      - trailing zeros will not be trimmed

    :param val: The value to format. `None`, booleans and values whose
        magnitude cannot be computed are formatted with `repr`, right
        aligned in a field of `length` characters.
    :param length: The requested length of the output string
        (values below 7 are raised to 7 for numbers),
        defaults to `11`.
    :return: The formatted value.
    :rtype: str
    """
    # Code taken from lmfit library
    # System modules
    from math import log10

    if val is None or isinstance(val, bool):
        return f'{repr(val):>{length}s}'
    try:
        # math.log10 raises ValueError for zero (handled below), whereas
        # numpy.log10 would emit a divide-by-zero RuntimeWarning
        expon = int(log10(abs(val)))
    except (OverflowError, ValueError):
        # Zero, infinity or NaN
        expon = 0
    except TypeError:
        return f'{repr(val):>{length}s}'

    length = max(length, 7)
    form = 'e'
    prec = length - 7
    if abs(expon) > 99:
        # Three digit exponents take one more character
        prec -= 1
    elif ((expon > 0 and expon < (prec+4)) or
          (expon <= 0 and -expon < (prec-1))):
        # Fixed point notation fits in the requested length
        form = 'f'
        prec += 4
        if expon > 0:
            prec -= expon
    return f'{val:{length}.{prec}{form}}'
66
+
67
+
68
def getfloat_attr(obj, attr, length=11):
    """Format an attribute of an object for printing.

    :param obj: The object holding the attribute.
    :param attr: The attribute name.
    :type attr: str
    :param length: The output length passed on to `gformat` for
        floating point values, defaults to `11`.
    :return: `'unknown'` if the attribute is missing or `None`, the
        plain string of an integer (Python or numpy), the stripped
        `gformat` output of a float (Python or numpy), or the `repr`
        of any other value.
    :rtype: str
    """
    # Code taken from lmfit library, extended to numpy scalar types
    # (e.g. numpy.int64, numpy.float32), which would otherwise fall
    # through to repr()
    val = getattr(obj, attr, None)
    if val is None:
        return 'unknown'
    if isinstance(val, (int, np.integer)):
        return f'{val}'
    if isinstance(val, (float, np.floating)):
        return gformat(val, length=length).strip()
    return repr(val)
79
+
34
80
 
35
81
  def depth_list(_list):
36
82
  """Return the depth of a list."""
@@ -86,6 +132,13 @@ def illegal_combination(
86
132
  raise ValueError(error_msg)
87
133
 
88
134
 
135
def not_zero(value):
    """Return `value` as a float whose absolute size is at least
    `tiny`, keeping the sign of `value`.

    Values with a magnitude not exceeding `tiny` (including NaN) are
    replaced by `tiny` carrying the sign bit of `value`.
    """
    magnitude = abs(value)
    if not magnitude > tiny:
        magnitude = tiny
    return float(np.copysign(magnitude, value))
140
+
141
+
89
142
  def test_ge_gt_le_lt(
90
143
  ge, gt, le, lt, func, location=None, raise_error=False, log=True):
91
144
  """
@@ -447,8 +500,8 @@ def index_nearest_down(a, value):
447
500
  return index
448
501
 
449
502
 
450
- def index_nearest_upp(a, value):
451
- """Return index of nearest array value, rounded upp."""
503
+ def index_nearest_up(a, value):
504
+ """Return index of nearest array value, rounded up."""
452
505
  a = np.asarray(a)
453
506
  if a.ndim > 1:
454
507
  raise ValueError(
@@ -477,14 +530,14 @@ def get_consecutive_int_range(a):
477
530
 
478
531
 
def round_to_n(x, n=1):
    """Round to a specific number of significant figures.

    :param x: The value to round.
    :param n: The number of significant figures, defaults to `1`.
    :type n: int, optional
    :return: `x` rounded to `n` significant figures, converted back
        to the type of `x`.
    """
    if x == 0.0:
        # log10 is undefined at zero; return a zero of the input type
        # so the return type is consistent with the non-zero case
        return type(x)(0)
    return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x))))))
484
537
 
485
538
 
486
539
  def round_up_to_n(x, n=1):
487
- """Round up to a specific number of decimals."""
540
+ """Round up to a specific number of sig figs."""
488
541
  x_round = round_to_n(x, n)
489
542
  if abs(x/x_round) > 1.0:
490
543
  x_round += np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
@@ -492,7 +545,7 @@ def round_up_to_n(x, n=1):
492
545
 
493
546
 
494
547
  def trunc_to_n(x, n=1):
495
- """Truncate to a specific number of decimals."""
548
+ """Truncate to a specific number of sig figs."""
496
549
  x_round = round_to_n(x, n)
497
550
  if abs(x_round/x) > 1.0:
498
551
  x_round -= np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
@@ -510,7 +563,9 @@ def almost_equal(a, b, sig_figs):
510
563
  f'b: {b}, {type(b)})')
511
564
 
512
565
 
513
- def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
566
+ def string_to_list(
567
+ s, split_on_dash=True, remove_duplicates=True, sort=True,
568
+ raise_error=False):
514
569
  """
515
570
  Return a list of numbers by splitting/expanding a string on any
516
571
  combination of commas, whitespaces, or dashes (when
@@ -524,8 +579,11 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
524
579
  return []
525
580
  try:
526
581
  list1 = re_split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
527
- except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
528
- return None
582
+ except (ValueError, TypeError, SyntaxError, MemoryError,
583
+ RecursionError) as e:
584
+ if not raise_error:
585
+ return None
586
+ raise e
529
587
  if split_on_dash:
530
588
  try:
531
589
  l_of_i = []
@@ -540,8 +598,10 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
540
598
  else:
541
599
  raise ValueError
542
600
  except (ValueError, TypeError, SyntaxError, MemoryError,
543
- RecursionError):
544
- return None
601
+ RecursionError) as e:
602
+ if not raise_error:
603
+ return None
604
+ raise e
545
605
  else:
546
606
  l_of_i = [literal_eval(x) for x in list1]
547
607
  if remove_duplicates:
@@ -551,6 +611,24 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
551
611
  return l_of_i
552
612
 
553
613
 
614
def list_to_string(a):
    """Return the consecutive ranges of integers in `a` as a string.

    Single integers are written as `'i'` and consecutive runs as
    `'first-last'`, separated by `', '`, e.g. `'1-3, 5, 7-8'`.

    :param a: The integers to convert.
    :type a: array-like of int
    :return: The string notation of the integer ranges, an empty
        string when there are none.
    :rtype: str
    """
    # Each entry is assumed to be a (first, last) pair of a run of
    # consecutive integers
    int_ranges = get_consecutive_int_range(a)
    return ', '.join(
        f'{int_range[0]}' if int_range[0] == int_range[1]
        else f'{int_range[0]}-{int_range[1]}'
        for int_range in int_ranges)
630
+
631
+
554
632
  def get_trailing_int(string):
555
633
  """Get the trailing integer in a string."""
556
634
  index_regex = re_compile(r'\d+$')
@@ -902,11 +980,255 @@ def file_exists_and_readable(f):
902
980
  return f
903
981
 
904
982
 
983
def rolling_average(
        y, x=None, dtype=None, start=0, end=None, width=None,
        stride=None, num=None, average=True, mode='valid',
        use_convolve=None):
    """
    Returns the rolling sum or average of an array over the last
    dimension.

    :param y: The input array. For more than one dimension, all
        leading dimensions are flattened into rows and the rolling
        operation is applied along the last dimension.
    :type y: array-like
    :param x: Coordinates of the last dimension of `y` (same size),
        whose rolling average is returned as well, defaults to `None`.
    :type x: array-like, optional
    :param dtype: Data type used where the result is converted,
        defaults to the dtype of `y` when averaging and to
        `numpy.float32` otherwise.
    :type dtype: numpy.dtype, optional
    :param start: First index of the range along the last dimension,
        defaults to `0`.
    :type start: int, optional
    :param end: End index (exclusive) of the range along the last
        dimension, defaults to its size.
    :type end: int, optional
    :param width: The window width.
    :type width: int, optional
    :param stride: The step between consecutive windows.
    :type stride: int, optional
    :param num: The number of windows. At least one of `width`,
        `stride` or `num` must be given.
    :type num: int, optional
    :param average: Return the average (`True`) or the sum
        (`False`), defaults to `True`.
    :type average: bool, optional
    :param mode: `'valid'` only uses full windows, `'full'` also
        includes the trailing partial windows (whose sums are scaled
        up to a full window width), defaults to `'valid'`.
    :type mode: str, optional
    :param use_convolve: Compute the sums with `numpy.convolve`,
        defaults to `True` for a 1D `y` and `False` otherwise.
    :type use_convolve: bool, optional
    :raises ValueError: For invalid input parameters.
    :return: The rolling sums or averages (leading dimensions of `y`
        followed by the number of windows; squeezed for a 1D `y`),
        and the rolling averages of `x` when `x` is given.
    :rtype: numpy.ndarray [, numpy.ndarray]
    """
    y = np.asarray(y)
    y_shape = y.shape
    if y.ndim == 1:
        y = np.expand_dims(y, 0)
    else:
        # Flatten all leading dimensions into rows
        y = y.reshape((np.prod(y.shape[0:-1]), y.shape[-1]))
    if x is not None:
        x = np.asarray(x)
        if x.ndim != 1:
            raise ValueError('Parameter "x" must be a 1D array-like')
        if x.size != y.shape[1]:
            raise ValueError(f'Dimensions of "x" and "y[1]" do not '
                             f'match ({x.size} vs {y.shape[1]})')
    if dtype is None:
        if average:
            dtype = y.dtype
        else:
            dtype = np.float32
    if width is None and stride is None and num is None:
        raise ValueError('Invalid input parameters, specify at least one of '
                         '"width", "stride" or "num"')
    if width is not None and not is_int(width, ge=1):
        raise ValueError(f'Invalid "width" parameter ({width})')
    if stride is not None and not is_int(stride, ge=1):
        raise ValueError(f'Invalid "stride" parameter ({stride})')
    if num is not None and not is_int(num, ge=1):
        raise ValueError(f'Invalid "num" parameter ({num})')
    if not isinstance(average, bool):
        raise ValueError(f'Invalid "average" parameter ({average})')
    if mode not in ('valid', 'full'):
        raise ValueError(f'Invalid "mode" parameter ({mode})')
    size = y.shape[1]
    if size < 2:
        raise ValueError(f'Invalid y[1] dimension ({size})')
    if not is_int(start, ge=0, lt=size):
        raise ValueError(f'Invalid "start" parameter ({start})')
    if end is None:
        end = size
    elif not is_int(end, gt=start, le=size):
        raise ValueError(f'Invalid "end" parameter ({end})')
    if use_convolve is None:
        use_convolve = len(y_shape) == 1
    if use_convolve and (start or end < size):
        y = np.take(y, range(start, end), axis=1)
        if x is not None:
            x = x[start:end]
        size = y.shape[1]
    else:
        size = end-start

    # Resolve the window width and stride (both end up between 1 and
    # size, so every window stays within the selected range)
    if stride is None:
        if width is None:
            width = max(1, int(size/num))
            stride = width
        else:
            width = min(width, size)
            if num is None:
                stride = width
            elif num == 1:
                # Any stride exceeding size-width yields exactly one
                # window (and avoids dividing by num-1 == 0)
                stride = size
            else:
                stride = max(1, int((size-width) / (num-1)))
    else:
        if stride >= size:
            # Room for a single window only (size-stride would give a
            # zero or negative stride)
            stride = size
        else:
            # NOTE(review): clamps the stride to size-stride, confirm
            # this is the intended bound
            stride = min(stride, size-stride)
        if width is None:
            width = stride
        else:
            width = min(width, size)

    if mode == 'valid':
        num = 1 + max(0, int((size-width) / stride))
    else:
        num = int(size/stride)
        if num*stride < size:
            num += 1

    if use_convolve:
        # Number of samples in each window (only trailing windows in
        # 'full' mode can be partial)
        weight = np.empty(num)
        n_start = 0
        n_end = width
        for n in range(num):
            weight[n] = n_end-n_start
            n_start += stride
            n_end = min(size, n_end+stride)

        window = np.ones(width)
        # Index width-1 of the full convolution is the sum over the
        # first window; in 'valid' mode the last full window ends at
        # index size-1 (slicing up to 1-width is empty for width 1)
        stop = size if mode == 'valid' else None
        if x is not None:
            rx = np.convolve(x, window)[width-1:stop:stride]
            rx /= weight

        ry = [np.convolve(row, window)[width-1:stop:stride] for row in y]
        ry = np.reshape(ry, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (np.asarray(ry).astype(np.float32)/weight).astype(dtype)
        elif mode != 'valid':
            # Scale the partial window sums up to a full window width
            weight = np.where(weight < width, width/weight, 1.0)
            ry = (np.asarray(ry).astype(np.float32)*weight).astype(dtype)
    else:
        ry = np.zeros((num, y.shape[0]), dtype=y.dtype)
        if x is not None:
            rx = np.zeros(num, dtype=x.dtype)
        n_start = start
        n_end = n_start+width
        for n in range(num):
            y_sum = np.sum(y[:,n_start:n_end], 1)
            n_num = n_end-n_start
            if n_num < width:
                # Scale a partial window sum up to a full window width
                # (not in place, which fails for integer sums)
                y_sum = y_sum * (width/n_num)
            ry[n] = y_sum
            if x is not None:
                rx[n] = np.sum(x[n_start:n_end])/n_num
            n_start += stride
            n_end = min(start+size, n_end+stride)
        ry = np.reshape(ry.T, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (ry.astype(np.float32)/width).astype(dtype)

    if x is None:
        return ry
    return ry, rx
1124
+
1125
+
1126
def baseline_arPLS(
        y, mask=None, w=None, tol=1.e-8, lam=1.e6, max_iter=20,
        full_output=False):
    """Returns the smoothed baseline estimate of a spectrum.

    Based on S.-J. Baek, A. Park, Y.-J Ahn, and J. Choo,
    "Baseline correction using asymmetrically reweighted penalized
    least squares smoothing", Analyst, 2015,140, 250-257

    :param y: The spectrum.
    :type y: array-like
    :param mask: A mask to apply to the spectrum before baseline
        construction; the returned baseline is zero outside the
        mask, defaults to `None`.
    :type mask: array-like, optional
    :param w: The initial weights (allows a restart for additional
        iterations), of the same size as the (masked) spectrum,
        defaults to `None`.
    :type w: numpy.array, optional
    :param tol: The convergence tolerance, defaults to `1.e-8`.
    :type tol: float, optional
    :param lam: The lambda (smoothness) parameter (the balance
        between the residual of the data and the baseline and the
        smoothness of the baseline). The suggested range is between
        100 and 10^8, defaults to `10^6`.
    :type lam: float, optional
    :param max_iter: The maximum number of iterations,
        defaults to `20`.
    :type max_iter: int, optional
    :param full_output: Whether or not to also return the baseline
        corrected spectrum, the weights, the number of iterations and
        the error, defaults to `False`.
    :type full_output: bool, optional
    :raises ValueError: For invalid input parameters.
    :return: The smoothed baseline, with optionally the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result.
    :rtype: numpy.array [, numpy.array, numpy.array, int, float]
    """
    # With credit to: Daniel Casas-Orozco
    # https://stackoverflow.com/questions/29156532/python-baseline-correction-library
    # System modules
    from sys import float_info

    # Third party modules
    from scipy.sparse import (
        spdiags,
        linalg,
    )

    if not is_num(tol, gt=0):
        raise ValueError(f'Invalid tol parameter ({tol})')
    if not is_num(lam, gt=0):
        raise ValueError(f'Invalid lam parameter ({lam})')
    if not is_int(max_iter, gt=0):
        raise ValueError(f'Invalid max_iter parameter ({max_iter})')
    if not isinstance(full_output, bool):
        raise ValueError(f'Invalid full_output parameter ({full_output})')
    y = np.asarray(y)
    if mask is not None:
        # Accept any array-like mask (lists have no astype)
        mask = np.asarray(mask).astype(bool)
        y_org = y
        y = y[mask]
    num = y.size

    # Second order difference operator and the smoothness penalty
    diag = np.ones((num-2))
    D = spdiags([diag, -2*diag, diag], [0, -1, -2], num, num-2)
    H = lam * D.dot(D.T)

    if w is None:
        w = np.ones(num)
    W = spdiags(w, 0, num, num)

    error = 1.0
    num_iter = 0

    # Clip the exponent to avoid an overflow in np.exp
    exp_max = int(np.log(float_info.max))
    # Always solve at least once, so that z is defined even when
    # tol >= 1
    while True:
        z = linalg.spsolve(W + H, W @ y)
        d = y - z
        dn = d[d < 0]
        if not dn.size:
            # No negative residuals: the weight update is undefined
            break

        m = np.mean(dn)
        s = np.std(dn)
        if s == 0:
            # A zero spread makes the weight update undefined
            break

        w_new = 1.0 / (1.0 + np.exp(
            np.clip(2.0 * (d - (2.0*s - m))/s, None, exp_max)))
        error = np.linalg.norm(w_new - w) / np.linalg.norm(w)
        num_iter += 1
        w = w_new
        W.setdiag(w)
        if not (error > tol and num_iter < max_iter):
            break

    if mask is not None:
        # Expand the baseline to the full spectrum (zero outside mask)
        zz = np.zeros(y_org.size)
        zz[mask] = z
        z = zz
        if full_output:
            d = y_org - z
    if full_output:
        return z, d, w, num_iter, float(error)
    return z
1225
+
1226
+
905
1227
  def select_mask_1d(
906
1228
  y, x=None, label=None, ref_data=[], preselected_index_ranges=None,
907
1229
  preselected_mask=None, title=None, xlabel=None, ylabel=None,
908
1230
  min_num_index_ranges=None, max_num_index_ranges=None,
909
- interactive=True):
1231
+ interactive=True, filename=None):
910
1232
  """Display a lineplot and have the user select a mask.
911
1233
 
912
1234
  :param y: One-dimensional data array for which a mask will be
@@ -945,18 +1267,21 @@ def select_mask_1d(
945
1267
  ranges, defaults to `None`.
946
1268
  :type max_num_index_ranges: int, optional
947
1269
  :param interactive: Show the plot and allow user interactions with
948
- the matplotlib figure, defults to `True`.
1270
+ the matplotlib figure, defaults to `True`.
949
1271
  :type interactive: bool, optional
950
- :return: A Matplotlib figure, a boolean mask array and the list of
951
- selected index ranges.
952
- :rtype: matplotlib.figure.Figure, numpy.ndarray,
953
- list[tuple(int, int)]
1272
+ :param filename: Save a .png of the plot to filename, defaults to
1273
+ `None`, in which case the plot is not saved.
1274
+ :type filename: str, optional
1275
+ :return: A boolean mask array and the list of selected index
1276
+ ranges.
1277
+ :rtype: numpy.ndarray, list[tuple(int, int)]
954
1278
  """
955
1279
  # Third party modules
956
- from matplotlib.patches import Patch
957
- from matplotlib.widgets import Button, SpanSelector
1280
+ if interactive or filename is not None:
1281
+ from matplotlib.patches import Patch
1282
+ from matplotlib.widgets import Button, SpanSelector
958
1283
 
959
- # local modules
1284
+ # Local modules
960
1285
  from CHAP.utils.general import index_nearest
961
1286
 
962
1287
  def change_fig_title(title):
@@ -1111,10 +1436,25 @@ def select_mask_1d(
1111
1436
  raise ValueError('Invalid parameter preselected_index_ranges '
1112
1437
  f'({preselected_index_ranges})')
1113
1438
 
1439
+ # Setup the preselected mask and index ranges if provided
1440
+ if preselected_mask is not None:
1441
+ preselected_index_ranges = update_index_ranges(
1442
+ update_mask(
1443
+ np.copy(np.asarray(preselected_mask, dtype=bool)),
1444
+ preselected_index_ranges))
1445
+
1446
+ if not interactive and filename is None:
1447
+
1448
+ # Update the mask with the preselected index ranges
1449
+ selected_mask = update_mask(len(x)*[False], preselected_index_ranges)
1450
+
1451
+ return selected_mask, preselected_index_ranges
1452
+
1114
1453
  spans = []
1115
1454
  fig_title = []
1116
1455
  error_texts = []
1117
1456
 
1457
+ # Setup the Matplotlib figure
1118
1458
  title_pos = (0.5, 0.95)
1119
1459
  title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1120
1460
  'verticalalignment': 'bottom'}
@@ -1138,12 +1478,7 @@ def select_mask_1d(
1138
1478
  ax.set_xlim(x[0], x[-1])
1139
1479
  fig.subplots_adjust(bottom=0.0, top=0.85)
1140
1480
 
1141
- # Setup the preselected mask and index ranges if provided
1142
- if preselected_mask is not None:
1143
- preselected_index_ranges = update_index_ranges(
1144
- update_mask(
1145
- np.copy(np.asarray(preselected_mask, dtype=bool)),
1146
- preselected_index_ranges))
1481
+ # Add the preselected index ranges
1147
1482
  for min_, max_ in preselected_index_ranges:
1148
1483
  add_span(None, xrange_init=(x[min_], x[min(max_, num_data-1)]))
1149
1484
 
@@ -1158,7 +1493,8 @@ def select_mask_1d(
1158
1493
  fig.subplots_adjust(bottom=0.2)
1159
1494
 
1160
1495
  # Setup "Add span" button
1161
- add_span_btn = Button(plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
1496
+ add_span_btn = Button(
1497
+ plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
1162
1498
  add_span_cid = add_span_btn.on_clicked(add_span)
1163
1499
 
1164
1500
  # Setup "Reset" button
@@ -1188,15 +1524,18 @@ def select_mask_1d(
1188
1524
  # Update the mask with the currently selected index ranges
1189
1525
  selected_mask = update_mask(len(x)*[False], selected_index_ranges)
1190
1526
 
1191
- fig_title[0].set_in_layout(True)
1192
- fig.tight_layout(rect=(0, 0, 1, 0.95))
1527
+ if filename is not None:
1528
+ fig_title[0].set_in_layout(True)
1529
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
1530
+ fig.savefig(filename)
1531
+ plt.close()
1193
1532
 
1194
- return fig, selected_mask, selected_index_ranges
1533
+ return selected_mask, selected_index_ranges
1195
1534
 
1196
1535
 
1197
1536
  def select_roi_1d(
1198
1537
  y, x=None, preselected_roi=None, title=None, xlabel=None, ylabel=None,
1199
- interactive=True):
1538
+ interactive=True, filename=None):
1200
1539
  """Display a 2D plot and have the user select a single region
1201
1540
  of interest.
1202
1541
 
@@ -1217,8 +1556,11 @@ def select_roi_1d(
1217
1556
  defaults to `None`.
1218
1557
  :type ylabel: str, optional
1219
1558
  :param interactive: Show the plot and allow user interactions with
1220
- the matplotlib figure, defults to `True`.
1559
+ the matplotlib figure, defaults to `True`.
1221
1560
  :type interactive: bool, optional
1561
+ :param filename: Save a .png of the plot to filename, defaults to
1562
+ `None`, in which case the plot is not saved.
1563
+ :type filename: str, optional
1222
1564
  :return: The selected region of interest as array indices and a
1223
1565
  matplotlib figure.
1224
1566
  :rtype: matplotlib.figure.Figure, tuple(int, int)
@@ -1233,16 +1575,17 @@ def select_roi_1d(
1233
1575
  f'({preselected_roi})')
1234
1576
  preselected_roi = [preselected_roi]
1235
1577
 
1236
- fig, mask, roi = select_mask_1d(
1578
+ mask, roi = select_mask_1d(
1237
1579
  y, x=x, preselected_index_ranges=preselected_roi, title=title,
1238
1580
  xlabel=xlabel, ylabel=ylabel, min_num_index_ranges=1,
1239
- max_num_index_ranges=1, interactive=interactive)
1581
+ max_num_index_ranges=1, interactive=interactive, filename=filename)
1240
1582
 
1241
- return fig, tuple(roi[0])
1583
+ return tuple(roi[0])
1242
1584
 
1243
1585
  def select_roi_2d(
1244
1586
  a, preselected_roi=None, title=None, title_a=None,
1245
- row_label='row index', column_label='column index', interactive=True):
1587
+ row_label='row index', column_label='column index', interactive=True,
1588
+ filename=None):
1246
1589
  """Display a 2D image and have the user select a single rectangular
1247
1590
  region of interest.
1248
1591
 
@@ -1265,12 +1608,15 @@ def select_roi_2d(
1265
1608
  :param interactive: Show the plot and allow user interactions with
1266
1609
  the matplotlib figure, defaults to `True`.
1267
1610
  :type interactive: bool, optional
1268
- :return: The selected region of interest as array indices and a
1269
- matplotlib figure.
1270
- :rtype: matplotlib.figure.Figure, tuple(int, int, int, int)
1611
+ :param filename: Save a .png of the plot to filename, defaults to
1612
+ `None`, in which case the plot is not saved.
1613
+ :type filename: str, optional
1614
+ :return: The selected region of interest as array indices.
1615
+ :rtype: tuple(int, int, int, int)
1271
1616
  """
1272
1617
  # Third party modules
1273
- from matplotlib.widgets import Button, RectangleSelector
1618
+ if interactive or filename is not None:
1619
+ from matplotlib.widgets import Button, RectangleSelector
1274
1620
 
1275
1621
  # Local modules
1276
1622
  from CHAP.utils.general import index_nearest
@@ -1327,9 +1673,6 @@ def select_roi_2d(
1327
1673
  change_fig_title(f'Selected ROI: {roi}')
1328
1674
  plt.close()
1329
1675
 
1330
- fig_title = []
1331
- subfig_title = []
1332
-
1333
1676
  # Check inputs
1334
1677
  a = np.asarray(a)
1335
1678
  if a.ndim != 2:
@@ -1342,6 +1685,12 @@ def select_roi_2d(
1342
1685
  if title is None:
1343
1686
  title = 'Click and drag to select or adjust a region of interest (ROI)'
1344
1687
 
1688
+ if not interactive and filename is None:
1689
+ return preselected_roi
1690
+
1691
+ fig_title = []
1692
+ subfig_title = []
1693
+
1345
1694
  title_pos = (0.5, 0.95)
1346
1695
  title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1347
1696
  'verticalalignment': 'bottom'}
@@ -1370,8 +1719,9 @@ def select_roi_2d(
1370
1719
 
1371
1720
  if not interactive:
1372
1721
 
1373
- change_fig_title(
1374
- f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')
1722
+ if preselected_roi is not None:
1723
+ change_fig_title(
1724
+ f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')
1375
1725
 
1376
1726
  else:
1377
1727
 
@@ -1400,20 +1750,26 @@ def select_roi_2d(
1400
1750
  reset_btn.ax.remove()
1401
1751
  confirm_btn.ax.remove()
1402
1752
 
1403
- fig_title[0].set_in_layout(True)
1404
- fig.tight_layout(rect=(0, 0, 1, 0.95))
1753
+ if filename is not None:
1754
+ if fig_title:
1755
+ fig_title[0].set_in_layout(True)
1756
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
1757
+ else:
1758
+ fig.tight_layout(rect=(0, 0, 1, 1))
1405
1759
 
1406
- # Remove the handles before returning the figure
1407
- if interactive:
1408
- rects[0]._center_handle.set_visible(False)
1409
- rects[0]._corner_handles.set_visible(False)
1410
- rects[0]._edge_handles.set_visible(False)
1760
+ # Remove the handles
1761
+ if interactive:
1762
+ rects[0]._center_handle.set_visible(False)
1763
+ rects[0]._corner_handles.set_visible(False)
1764
+ rects[0]._edge_handles.set_visible(False)
1765
+ fig.savefig(filename)
1766
+ plt.close()
1411
1767
 
1412
1768
  roi = tuple(int(v) for v in rects[0].extents)
1413
1769
  if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1414
1770
  roi = None
1415
1771
 
1416
- return fig, roi
1772
+ return roi
1417
1773
 
1418
1774
 
1419
1775
  def select_image_indices(
@@ -1918,9 +2274,11 @@ def nxcopy(
1918
2274
  name=nxobject.nxname)
1919
2275
  elif isinstance(nxobject, (NXlink, NXfield)):
1920
2276
  # The top level nxobject is a (linked) field: return a copy
2277
+ attrs = nxobject.attrs
2278
+ attrs.pop('target', None)
1921
2279
  nxobject_copy = NXfield(
1922
2280
  value=nxobject.nxdata, name=nxobject.nxname,
1923
- attrs=nxobject.attrs)
2281
+ attrs=attrs)
1924
2282
  return nxobject_copy
1925
2283
  else:
1926
2284
  # Create a group with the same type/name as the nxobject