ChessAnalysisPipeline 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. Click here for more details.

CHAP/utils/general.py CHANGED
@@ -26,13 +26,57 @@ from sys import float_info
26
26
  import numpy as np
27
27
  try:
28
28
  import matplotlib.pyplot as plt
29
- from matplotlib.widgets import AxesWidget, RadioButtons
30
- from matplotlib import cbook
31
29
  except ImportError:
32
30
  pass
33
31
 
34
32
  logger = getLogger(__name__)
35
33
 
34
+ tiny = np.finfo(np.float64).resolution
35
+
36
def gformat(val, length=11):
    """
    Format a number with '%g'-like format, while:
      - the length of the output string will be of the requested length
      - positive numbers will have a leading blank
      - the precision will be as high as possible
      - trailing zeros will not be trimmed

    :param val: The value to format.
    :type val: object
    :param length: The requested length of the output string (raised
        to 7 for numeric values if smaller), defaults to `11`.
    :type length: int, optional
    :return: The formatted number; `None`, booleans and values that
        cannot be treated as a number are returned as their
        right-justified `repr`.
    :rtype: str
    """
    # Code taken from lmfit library
    if val is None or isinstance(val, bool):
        return f'{repr(val):>{length}s}'
    try:
        # log10(0) is -inf, which int() rejects with an OverflowError
        # that is handled below; silence numpy's divide-by-zero
        # warning so that formatting zero is quiet (and does not
        # raise when warnings are turned into errors)
        with np.errstate(divide='ignore'):
            expon = int(np.log10(abs(val)))
    except (OverflowError, ValueError):
        # Zero, infinity or NaN: no usable exponent
        expon = 0
    except TypeError:
        # Not a number
        return f'{repr(val):>{length}s}'

    length = max(length, 7)
    form = 'e'
    prec = length - 7
    if abs(expon) > 99:
        # A three digit exponent takes one extra character
        prec -= 1
    elif ((expon > 0 and expon < (prec+4)) or
            (expon <= 0 and -expon < (prec-1))):
        # The number fits in fixed-point notation: reuse the
        # characters otherwise spent on the exponent
        form = 'f'
        prec += 4
        if expon > 0:
            prec -= expon
    return f'{val:{length}.{prec}{form}}'
66
+
67
+
68
def getfloat_attr(obj, attr, length=11):
    """Return a printable string for attribute `attr` of `obj`.

    :param obj: The object holding the attribute.
    :type obj: object
    :param attr: The name of the attribute to format.
    :type attr: str
    :param length: The field length handed to `gformat` for floats,
        defaults to `11`.
    :type length: int, optional
    :return: `'unknown'` if the attribute is missing or `None`, the
        plain decimal string for integers, the stripped `gformat`
        result for floats and the `repr` of anything else.
    :rtype: str
    """
    # Code taken from lmfit library
    attr_value = getattr(obj, attr, None)
    if attr_value is None:
        text = 'unknown'
    elif isinstance(attr_value, int):
        text = f'{attr_value}'
    elif isinstance(attr_value, float):
        text = gformat(attr_value, length=length).strip()
    else:
        text = repr(attr_value)
    return text
79
+
36
80
 
37
81
  def depth_list(_list):
38
82
  """Return the depth of a list."""
@@ -88,6 +132,13 @@ def illegal_combination(
88
132
  raise ValueError(error_msg)
89
133
 
90
134
 
135
def not_zero(value):
    """Return `value` pushed away from zero: its absolute size is
    at least the module-level `tiny` and its sign is preserved.

    :param value: The input number.
    :type value: float
    :return: The input value, or `tiny` carrying the sign of the
        input when the absolute input value does not exceed `tiny`.
    :rtype: float
    """
    magnitude = abs(value)
    if not magnitude > tiny:
        magnitude = tiny
    signed = np.copysign(magnitude, value)
    return float(signed)
140
+
141
+
91
142
  def test_ge_gt_le_lt(
92
143
  ge, gt, le, lt, func, location=None, raise_error=False, log=True):
93
144
  """
@@ -449,8 +500,8 @@ def index_nearest_down(a, value):
449
500
  return index
450
501
 
451
502
 
452
- def index_nearest_upp(a, value):
453
- """Return index of nearest array value, rounded upp."""
503
+ def index_nearest_up(a, value):
504
+ """Return index of nearest array value, rounded up."""
454
505
  a = np.asarray(a)
455
506
  if a.ndim > 1:
456
507
  raise ValueError(
@@ -479,14 +530,14 @@ def get_consecutive_int_range(a):
479
530
 
480
531
 
481
532
  def round_to_n(x, n=1):
482
- """Round to a specific number of decimals."""
533
+ """Round to a specific number of sig figs ."""
483
534
  if x == 0.0:
484
535
  return 0
485
536
  return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x))))))
486
537
 
487
538
 
488
539
  def round_up_to_n(x, n=1):
489
- """Round up to a specific number of decimals."""
540
+ """Round up to a specific number of sig figs."""
490
541
  x_round = round_to_n(x, n)
491
542
  if abs(x/x_round) > 1.0:
492
543
  x_round += np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
@@ -494,7 +545,7 @@ def round_up_to_n(x, n=1):
494
545
 
495
546
 
496
547
  def trunc_to_n(x, n=1):
497
- """Truncate to a specific number of decimals."""
548
+ """Truncate to a specific number of sig figs."""
498
549
  x_round = round_to_n(x, n)
499
550
  if abs(x_round/x) > 1.0:
500
551
  x_round -= np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
@@ -512,7 +563,9 @@ def almost_equal(a, b, sig_figs):
512
563
  f'b: {b}, {type(b)})')
513
564
 
514
565
 
515
- def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
566
+ def string_to_list(
567
+ s, split_on_dash=True, remove_duplicates=True, sort=True,
568
+ raise_error=False):
516
569
  """
517
570
  Return a list of numbers by splitting/expanding a string on any
518
571
  combination of commas, whitespaces, or dashes (when
@@ -526,8 +579,11 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
526
579
  return []
527
580
  try:
528
581
  list1 = re_split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
529
- except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
530
- return None
582
+ except (ValueError, TypeError, SyntaxError, MemoryError,
583
+ RecursionError) as e:
584
+ if not raise_error:
585
+ return None
586
+ raise e
531
587
  if split_on_dash:
532
588
  try:
533
589
  l_of_i = []
@@ -542,8 +598,10 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
542
598
  else:
543
599
  raise ValueError
544
600
  except (ValueError, TypeError, SyntaxError, MemoryError,
545
- RecursionError):
546
- return None
601
+ RecursionError) as e:
602
+ if not raise_error:
603
+ return None
604
+ raise e
547
605
  else:
548
606
  l_of_i = [literal_eval(x) for x in list1]
549
607
  if remove_duplicates:
@@ -655,8 +713,8 @@ def _input_int_or_num(
655
713
 
656
714
 
657
715
  def input_int_list(
658
- s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
659
- sort=True, raise_error=False, log=True):
716
+ s=None, num_max=None, ge=None, le=None, split_on_dash=True,
717
+ remove_duplicates=True, sort=True, raise_error=False, log=True):
660
718
  """
661
719
  Prompt the user to input a list of interger and split the entered
662
720
  string on any combination of commas, whitespaces, or dashes (when
@@ -669,13 +727,13 @@ def input_int_list(
669
727
  return None upon an illegal input
670
728
  """
671
729
  return _input_int_or_num_list(
672
- 'int', s, ge, le, split_on_dash, remove_duplicates, sort, raise_error,
673
- log)
730
+ 'int', s, num_max, ge, le, split_on_dash, remove_duplicates, sort,
731
+ raise_error, log)
674
732
 
675
733
 
676
734
  def input_num_list(
677
- s=None, ge=None, le=None, remove_duplicates=True, sort=True,
678
- raise_error=False, log=True):
735
+ s=None, num_max=None, ge=None, le=None, remove_duplicates=True,
736
+ sort=True, raise_error=False, log=True):
679
737
  """
680
738
  Prompt the user to input a list of numbers and split the entered
681
739
  string on any combination of commas or whitespaces.
@@ -687,11 +745,12 @@ def input_num_list(
687
745
  return None upon an illegal input
688
746
  """
689
747
  return _input_int_or_num_list(
690
- 'num', s, ge, le, False, remove_duplicates, sort, raise_error, log)
748
+ 'num', s, num_max, ge, le, False, remove_duplicates, sort, raise_error,
749
+ log)
691
750
 
692
751
 
693
752
  def _input_int_or_num_list(
694
- type_str, s=None, ge=None, le=None, split_on_dash=True,
753
+ type_str, s=None, num_max=None, ge=None, le=None, split_on_dash=True,
695
754
  remove_duplicates=True, sort=True, raise_error=False, log=True):
696
755
  # RV do we want a limit on max dimension?
697
756
  if type_str == 'int':
@@ -707,6 +766,9 @@ def _input_int_or_num_list(
707
766
  else:
708
767
  illegal_value(type_str, 'type_str', '_input_int_or_num_list')
709
768
  return None
769
+ if (num_max is not None
770
+ and not is_int(num_max, gt=0, raise_error=raise_error, log=log)):
771
+ return None
710
772
  v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
711
773
  if v_range:
712
774
  v_range = f' (each value in {v_range})'
@@ -721,14 +783,17 @@ def _input_int_or_num_list(
721
783
  except:
722
784
  print('Unexpected error')
723
785
  raise
724
- if (not isinstance(_list, list) or any(
725
- not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)):
786
+ if (not isinstance(_list, list)
787
+ or (num_max is not None and len(_list) > num_max)
788
+ or any(
789
+ not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)):
790
+ num = '' if num_max is None else f'up to {num_max} '
726
791
  if split_on_dash:
727
- print('Invalid input: enter a valid set of dash/comma/whitespace '
728
- 'separated integers e.g. 1 3,5-8 , 12')
792
+ print(f'Invalid input: enter a valid set of {num}dash/comma/'
793
+ 'whitespace separated numbers e.g. 1 3,5-8 , 12')
729
794
  else:
730
- print('Invalid input: enter a valid set of comma/whitespace '
731
- 'separated integers e.g. 1 3,5 8 , 12')
795
+ print(f'Invalid input: enter a valid set of {num}comma/whitespace '
796
+ 'separated numbers e.g. 1 3,5 8 , 12')
732
797
  _list = _input_int_or_num_list(
733
798
  type_str, s, ge, le, split_on_dash, remove_duplicates, sort,
734
799
  raise_error, log)
@@ -897,11 +962,255 @@ def file_exists_and_readable(f):
897
962
  return f
898
963
 
899
964
 
965
def rolling_average(
        y, x=None, dtype=None, start=0, end=None, width=None,
        stride=None, num=None, average=True, mode='valid',
        use_convolve=None):
    """
    Returns the rolling sum or average of an array over the last
    dimension.

    :param y: The input data, the rolling window runs over its last
        dimension.
    :type y: array-like
    :param x: Coordinates along the last dimension of `y`; when given,
        their average over each window is returned as well,
        defaults to `None`.
    :type x: array-like, optional
    :param dtype: The data type used when casting the averaged
        (or, with convolution, rescaled) result, defaults to the data
        type of `y` when `average` is `True`, otherwise
        `numpy.float32`.
    :type dtype: numpy.dtype, optional
    :param start: The first index in the last dimension to include,
        defaults to `0`.
    :type start: int, optional
    :param end: The index in the last dimension at which to stop
        (exclusive), defaults to the size of the last dimension.
    :type end: int, optional
    :param width: The width of the rolling window.
    :type width: int, optional
    :param stride: The step between consecutive windows.
    :type stride: int, optional
    :param num: The requested number of windows; only used to derive
        `width` and/or `stride` when `stride` is not specified. The
        number of returned windows is recomputed from the final
        `width` and `stride`.
    :type num: int, optional
    :param average: Return the rolling average (`True`) or the
        rolling sum (`False`), defaults to `True`.
    :type average: bool, optional
    :param mode: `'valid'` to only return windows of the full width,
        `'full'` to also return the partial windows at the end of
        the range (their sums are rescaled to the full width),
        defaults to `'valid'`.
    :type mode: Literal['valid', 'full'], optional
    :param use_convolve: Compute the window sums with
        `numpy.convolve`, defaults to `True` for a one-dimensional
        `y` and `False` otherwise.
    :type use_convolve: bool, optional
    :raises ValueError: Invalid input parameter.
    :return: The rolling sum or average of `y`, and the window
        averaged `x` if `x` was specified.
    :rtype: numpy.ndarray [, numpy.ndarray]
    """
    # Flatten all leading dimensions so that y is always 2D
    y = np.asarray(y)
    y_shape = y.shape
    if y.ndim == 1:
        y = np.expand_dims(y, 0)
    else:
        y = y.reshape((np.prod(y.shape[0:-1]), y.shape[-1]))
    if x is not None:
        x = np.asarray(x)
        if x.ndim != 1:
            raise ValueError('Parameter "x" must be a 1D array-like')
        if x.size != y.shape[1]:
            raise ValueError(f'Dimensions of "x" and "y[1]" do not '
                             f'match ({x.size} vs {y.shape[1]})')
    if dtype is None:
        if average:
            dtype = y.dtype
        else:
            dtype = np.float32

    # Validate the input parameters
    if width is None and stride is None and num is None:
        raise ValueError('Invalid input parameters, specify at least one of '
                         '"width", "stride" or "num"')
    if width is not None and not is_int(width, ge=1):
        raise ValueError(f'Invalid "width" parameter ({width})')
    if stride is not None and not is_int(stride, ge=1):
        raise ValueError(f'Invalid "stride" parameter ({stride})')
    if num is not None and not is_int(num, ge=1):
        raise ValueError(f'Invalid "num" parameter ({num})')
    if not isinstance(average, bool):
        raise ValueError(f'Invalid "average" parameter ({average})')
    if mode not in ('valid', 'full'):
        raise ValueError(f'Invalid "mode" parameter ({mode})')
    size = y.shape[1]
    if size < 2:
        raise ValueError(f'Invalid y[1] dimension ({size})')
    if not is_int(start, ge=0, lt=size):
        raise ValueError(f'Invalid "start" parameter ({start})')
    if end is None:
        end = size
    elif not is_int(end, gt=start, le=size):
        raise ValueError(f'Invalid "end" parameter ({end})')
    if use_convolve is None:
        use_convolve = len(y_shape) == 1

    # The convolution works on the trimmed arrays, the explicit loop
    # below indexes the untrimmed arrays offset by start
    if use_convolve and (start or end < size):
        y = np.take(y, range(start, end), axis=1)
        if x is not None:
            x = x[start:end]
        size = y.shape[1]
    else:
        size = end-start

    # Derive the window width and stride
    if stride is None:
        if width is None:
            width = max(1, int(size/num))
            stride = width
        else:
            width = min(width, size)
            if num is None:
                stride = width
            elif num == 1:
                # A single window: any stride spanning the full range
                # works and avoids dividing by num-1 = 0
                stride = size
            else:
                stride = max(1, int((size-width) / (num-1)))
    else:
        # NOTE(review): a stride >= the size of the range results in
        # a non-positive stride here - confirm the intended handling
        stride = min(stride, size-stride)
        if width is None:
            width = stride

    # Get the actual number of windows
    if mode == 'valid':
        num = 1 + max(0, int((size-width) / stride))
    else:
        num = int(size/stride)
        if num*stride < size:
            num += 1

    if use_convolve:
        # Number of data points contributing to each window
        n_start = 0
        n_end = width
        weight = np.empty((num))
        for n in range(num):
            n_num = n_end-n_start
            weight[n] = n_num
            n_start += stride
            n_end = min(size, n_end+stride)

        # Element width-1+i of the full convolution holds the sum of
        # the window starting at index i; the full width windows end
        # at element size-1. Slicing up to size (and not 1-width)
        # also works for a width of 1.
        window = np.ones((width))
        if x is not None:
            if mode == 'valid':
                rx = np.convolve(x, window)[width-1:size:stride]
            else:
                rx = np.convolve(x, window)[width-1::stride]
            rx /= weight

        ry = []
        if mode == 'valid':
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1:size:stride])
        else:
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1::stride])
        ry = np.reshape(ry, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (np.asarray(ry).astype(np.float32)/weight).astype(dtype)
        elif mode != 'valid':
            # Rescale the sums over the partial windows to full width
            weight = np.where(weight < width, width/weight, 1.0)
            ry = (np.asarray(ry).astype(np.float32)*weight).astype(dtype)
    else:
        ry = np.zeros((num, y.shape[0]), dtype=y.dtype)
        if x is not None:
            rx = np.zeros(num, dtype=x.dtype)
        n_start = start
        n_end = n_start+width
        for n in range(num):
            y_sum = np.sum(y[:,n_start:n_end], 1)
            n_num = n_end-n_start
            if n_num < width:
                # Rescale a partial window to full width (out of place:
                # an in-place float scaling fails for integer data)
                y_sum = y_sum * (width/n_num)
            ry[n] = y_sum
            if x is not None:
                rx[n] = np.sum(x[n_start:n_end])/n_num
            n_start += stride
            n_end = min(start+size, n_end+stride)
        ry = np.reshape(ry.T, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (ry.astype(np.float32)/width).astype(dtype)

    if x is None:
        return ry
    return ry, rx
1106
+
1107
+
1108
def baseline_arPLS(
        y, mask=None, w=None, tol=1.e-8, lam=1.e6, max_iter=20,
        full_output=False):
    """Returns the smoothed baseline estimate of a spectrum.

    Based on S.-J. Baek, A. Park, Y.-J Ahn, and J. Choo,
    "Baseline correction using asymmetrically reweighted penalized
    least squares smoothing", Analyst, 2015,140, 250-257

    :param y: The spectrum.
    :type y: array-like
    :param mask: A mask to apply to the spectrum before baseline
        construction, defaults to `None`.
    :type mask: array-like, optional
    :param w: The weights (allows restart for additional iterations),
        defaults to `None`.
    :type w: numpy.array, optional
    :param tol: The convergence tolerance, defaults to `1.e-8`.
    :type tol: float, optional
    :param lam: The lambda (smoothness) parameter (the balance
        between the residual of the data and the baseline and the
        smoothness of the baseline). The suggested range is between
        100 and 10^8, defaults to `10^6`.
    :type lam: float, optional
    :param max_iter: The maximum number of iterations,
        defaults to `20`.
    :type max_iter: int, optional
    :param full_output: Whether or not to also output the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result, defaults to `False`.
    :type full_output: bool, optional
    :raises ValueError: Invalid `tol`, `lam`, `max_iter` or
        `full_output` parameter.
    :return: The smoothed baseline, with optionally the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result.
    :rtype: numpy.array [, numpy.array, numpy.array, int, float]
    """
    # With credit to: Daniel Casas-Orozco
    # https://stackoverflow.com/questions/29156532/python-baseline-correction-library
    # System modules
    from sys import float_info

    # Third party modules
    from scipy.sparse import (
        spdiags,
        linalg,
    )

    if not is_num(tol, gt=0):
        raise ValueError(f'Invalid tol parameter ({tol})')
    if not is_num(lam, gt=0):
        raise ValueError(f'Invalid lam parameter ({lam})')
    if not is_int(max_iter, gt=0):
        raise ValueError(f'Invalid max_iter parameter ({max_iter})')
    if not isinstance(full_output, bool):
        raise ValueError(f'Invalid full_output parameter ({full_output})')
    y = np.asarray(y)
    if mask is not None:
        # Accept any array-like mask, not only numpy arrays
        mask = np.asarray(mask).astype(bool)
        y_org = y
        y = y[mask]
    num = y.size

    # Second order difference matrix and the smoothness penalty
    diag = np.ones((num-2))
    D = spdiags([diag, -2*diag, diag], [0, -1, -2], num, num-2)

    H = lam * D.dot(D.T)

    if w is None:
        w = np.ones(num)
    W = spdiags(w, 0, num, num)

    error = 1
    num_iter = 0

    # Largest exponent that does not overflow a float in np.exp
    exp_max = int(np.log(float_info.max))
    # Always run the first iteration so that the baseline is defined
    # even when tol is not smaller than the initial error
    while num_iter == 0 or (error > tol and num_iter < max_iter):
        z = linalg.spsolve(W + H, W * y)
        d = y - z
        dn = d[d < 0]

        m = np.mean(dn)
        s = np.std(dn)

        # Reweight with a logistic function of the residuals
        w_new = 1.0 / (1.0 + np.exp(
            np.clip(2.0 * (d - (2.0*s - m))/s, None, exp_max)))
        error = np.linalg.norm(w_new - w) / np.linalg.norm(w)
        num_iter += 1
        w = w_new
        W.setdiag(w)

    if mask is not None:
        # Return the baseline on the unmasked grid (zero outside of
        # the mask)
        zz = np.zeros(y_org.size)
        zz[mask] = z
        z = zz
        if full_output:
            d = y_org - z
    if full_output:
        return z, d, w, num_iter, float(error)
    return z
1207
+
1208
+
900
1209
  def select_mask_1d(
901
1210
  y, x=None, label=None, ref_data=[], preselected_index_ranges=None,
902
1211
  preselected_mask=None, title=None, xlabel=None, ylabel=None,
903
1212
  min_num_index_ranges=None, max_num_index_ranges=None,
904
- interactive=True):
1213
+ interactive=True, filename=None):
905
1214
  """Display a lineplot and have the user select a mask.
906
1215
 
907
1216
  :param y: One-dimensional data array for which a mask will be
@@ -940,18 +1249,21 @@ def select_mask_1d(
940
1249
  ranges, defaults to `None`.
941
1250
  :type max_num_index_ranges: int, optional
942
1251
  :param interactive: Show the plot and allow user interactions with
943
- the matplotlib figure, defults to `True`.
1252
+ the matplotlib figure, defaults to `True`.
944
1253
  :type interactive: bool, optional
945
- :return: A Matplotlib figure, a boolean mask array and the list of
946
- selected index ranges.
947
- :rtype: matplotlib.figure.Figure, numpy.ndarray,
948
- list[tuple(int, int)]
1254
+ :param filename: Save a .png of the plot to filename, defaults to
1255
+ `None`, in which case the plot is not saved.
1256
+ :type filename: str, optional
1257
+ :return: A boolean mask array and the list of selected index
1258
+ ranges.
1259
+ :rtype: numpy.ndarray, list[tuple(int, int)]
949
1260
  """
950
1261
  # Third party modules
951
- from matplotlib.patches import Patch
952
- from matplotlib.widgets import Button, SpanSelector
1262
+ if interactive or filename is not None:
1263
+ from matplotlib.patches import Patch
1264
+ from matplotlib.widgets import Button, SpanSelector
953
1265
 
954
- # local modules
1266
+ # Local modules
955
1267
  from CHAP.utils.general import index_nearest
956
1268
 
957
1269
  def change_fig_title(title):
@@ -1105,23 +1417,26 @@ def select_mask_1d(
1105
1417
  for v in preselected_index_ranges)):
1106
1418
  raise ValueError('Invalid parameter preselected_index_ranges '
1107
1419
  f'({preselected_index_ranges})')
1108
- if (min_num_index_ranges is not None
1109
- and len(preselected_index_ranges) < min_num_index_ranges):
1110
- raise ValueError(
1111
- 'Invalid parameter preselected_index_ranges '
1112
- f'({preselected_index_ranges}), number of selected index '
1113
- f'ranges must be >= {min_num_index_ranges}')
1114
- if (max_num_index_ranges is not None
1115
- and len(preselected_index_ranges) > max_num_index_ranges):
1116
- raise ValueError(
1117
- 'Invalid parameter preselected_index_ranges '
1118
- f'({preselected_index_ranges}), number of selected index '
1119
- f'ranges must be <= {max_num_index_ranges}')
1420
+
1421
+ # Setup the preselected mask and index ranges if provided
1422
+ if preselected_mask is not None:
1423
+ preselected_index_ranges = update_index_ranges(
1424
+ update_mask(
1425
+ np.copy(np.asarray(preselected_mask, dtype=bool)),
1426
+ preselected_index_ranges))
1427
+
1428
+ if not interactive and filename is None:
1429
+
1430
+ # Update the mask with the preselected index ranges
1431
+ selected_mask = update_mask(len(x)*[False], preselected_index_ranges)
1432
+
1433
+ return selected_mask, preselected_index_ranges
1120
1434
 
1121
1435
  spans = []
1122
1436
  fig_title = []
1123
1437
  error_texts = []
1124
1438
 
1439
+ # Setup the Matplotlib figure
1125
1440
  title_pos = (0.5, 0.95)
1126
1441
  title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1127
1442
  'verticalalignment': 'bottom'}
@@ -1145,12 +1460,7 @@ def select_mask_1d(
1145
1460
  ax.set_xlim(x[0], x[-1])
1146
1461
  fig.subplots_adjust(bottom=0.0, top=0.85)
1147
1462
 
1148
- # Setup the preselected mask and index ranges if provided
1149
- if preselected_mask is not None:
1150
- preselected_index_ranges = update_index_ranges(
1151
- update_mask(
1152
- np.copy(np.asarray(preselected_mask, dtype=bool)),
1153
- preselected_index_ranges))
1463
+ # Add the preselected index ranges
1154
1464
  for min_, max_ in preselected_index_ranges:
1155
1465
  add_span(None, xrange_init=(x[min_], x[min(max_, num_data-1)]))
1156
1466
 
@@ -1165,7 +1475,8 @@ def select_mask_1d(
1165
1475
  fig.subplots_adjust(bottom=0.2)
1166
1476
 
1167
1477
  # Setup "Add span" button
1168
- add_span_btn = Button(plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
1478
+ add_span_btn = Button(
1479
+ plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
1169
1480
  add_span_cid = add_span_btn.on_clicked(add_span)
1170
1481
 
1171
1482
  # Setup "Reset" button
@@ -1191,21 +1502,22 @@ def select_mask_1d(
1191
1502
  plt.subplots_adjust(bottom=0.0)
1192
1503
 
1193
1504
  selected_index_ranges = get_selected_index_ranges()
1194
- if not selected_index_ranges:
1195
- selected_index_ranges = None
1196
1505
 
1197
1506
  # Update the mask with the currently selected index ranges
1198
1507
  selected_mask = update_mask(len(x)*[False], selected_index_ranges)
1199
1508
 
1200
- fig_title[0].set_in_layout(True)
1201
- fig.tight_layout(rect=(0, 0, 1, 0.95))
1509
+ if filename is not None:
1510
+ fig_title[0].set_in_layout(True)
1511
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
1512
+ fig.savefig(filename)
1513
+ plt.close()
1202
1514
 
1203
- return fig, selected_mask, selected_index_ranges
1515
+ return selected_mask, selected_index_ranges
1204
1516
 
1205
1517
 
1206
1518
  def select_roi_1d(
1207
1519
  y, x=None, preselected_roi=None, title=None, xlabel=None, ylabel=None,
1208
- interactive=True):
1520
+ interactive=True, filename=None):
1209
1521
  """Display a 2D plot and have the user select a single region
1210
1522
  of interest.
1211
1523
 
@@ -1226,8 +1538,11 @@ def select_roi_1d(
1226
1538
  defaults to `None`.
1227
1539
  :type ylabel: str, optional
1228
1540
  :param interactive: Show the plot and allow user interactions with
1229
- the matplotlib figure, defults to `True`.
1541
+ the matplotlib figure, defaults to `True`.
1230
1542
  :type interactive: bool, optional
1543
+ :param filename: Save a .png of the plot to filename, defaults to
1544
+ `None`, in which case the plot is not saved.
1545
+ :type filename: str, optional
1231
1546
  :return: The selected region of interest as array indices and a
1232
1547
  matplotlib figure.
1233
1548
  :rtype: matplotlib.figure.Figure, tuple(int, int)
@@ -1242,16 +1557,17 @@ def select_roi_1d(
1242
1557
  f'({preselected_roi})')
1243
1558
  preselected_roi = [preselected_roi]
1244
1559
 
1245
- fig, mask, roi = select_mask_1d(
1560
+ mask, roi = select_mask_1d(
1246
1561
  y, x=x, preselected_index_ranges=preselected_roi, title=title,
1247
1562
  xlabel=xlabel, ylabel=ylabel, min_num_index_ranges=1,
1248
- max_num_index_ranges=1, interactive=interactive)
1563
+ max_num_index_ranges=1, interactive=interactive, filename=filename)
1249
1564
 
1250
- return fig, tuple(roi[0])
1565
+ return tuple(roi[0])
1251
1566
 
1252
1567
  def select_roi_2d(
1253
1568
  a, preselected_roi=None, title=None, title_a=None,
1254
- row_label='row index', column_label='column index', interactive=True):
1569
+ row_label='row index', column_label='column index', interactive=True,
1570
+ filename=None):
1255
1571
  """Display a 2D image and have the user select a single rectangular
1256
1572
  region of interest.
1257
1573
 
@@ -1274,12 +1590,15 @@ def select_roi_2d(
1274
1590
  :param interactive: Show the plot and allow user interactions with
1275
1591
  the matplotlib figure, defaults to `True`.
1276
1592
  :type interactive: bool, optional
1277
- :return: The selected region of interest as array indices and a
1278
- matplotlib figure.
1279
- :rtype: matplotlib.figure.Figure, tuple(int, int, int, int)
1593
+ :param filename: Save a .png of the plot to filename, defaults to
1594
+ `None`, in which case the plot is not saved.
1595
+ :type filename: str, optional
1596
+ :return: The selected region of interest as array indices.
1597
+ :rtype: tuple(int, int, int, int)
1280
1598
  """
1281
1599
  # Third party modules
1282
- from matplotlib.widgets import Button, RectangleSelector
1600
+ if interactive or filename is not None:
1601
+ from matplotlib.widgets import Button, RectangleSelector
1283
1602
 
1284
1603
  # Local modules
1285
1604
  from CHAP.utils.general import index_nearest
@@ -1290,46 +1609,52 @@ def select_roi_2d(
1290
1609
  fig_title.pop()
1291
1610
  fig_title.append(plt.figtext(*title_pos, title, **title_props))
1292
1611
 
1293
- def change_error_text(error):
1294
- if error_texts:
1295
- error_texts[0].remove()
1296
- error_texts.pop()
1297
- error_texts.append(plt.figtext(*error_pos, error, **error_props))
1612
+ def change_subfig_title(error):
1613
+ if subfig_title:
1614
+ subfig_title[0].remove()
1615
+ subfig_title.pop()
1616
+ subfig_title.append(plt.figtext(*error_pos, error, **error_props))
1617
+
1618
+ def clear_selection():
1619
+ rects[0].set_visible(False)
1620
+ rects.pop()
1621
+ rects.append(
1622
+ RectangleSelector(
1623
+ ax, on_rect_select, props=rect_props,
1624
+ useblit=True, interactive=interactive, drag_from_anywhere=True,
1625
+ ignore_event_outside=False))
1298
1626
 
1299
1627
  def on_rect_select(eclick, erelease):
1300
1628
  """Callback function for the RectangleSelector widget."""
1301
- change_error_text(
1302
- f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
1629
+ if (not int(rects[0].extents[1]) - int(rects[0].extents[0])
1630
+ or not int(rects[0].extents[3]) - int(rects[0].extents[2])):
1631
+ clear_selection()
1632
+ change_subfig_title(
1633
+ f'Selected ROI too small, try again')
1634
+ else:
1635
+ change_subfig_title(
1636
+ f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
1303
1637
  plt.draw()
1304
1638
 
1305
1639
  def reset(event):
1306
1640
  """Callback function for the "Reset" button."""
1307
- if error_texts:
1308
- error_texts[0].remove()
1309
- error_texts.pop()
1310
- rects[0].set_visible(False)
1311
- rects.pop()
1312
- rects.append(
1313
- RectangleSelector(
1314
- ax, on_rect_select, props=rect_props, useblit=True,
1315
- interactive=interactive, drag_from_anywhere=True,
1316
- ignore_event_outside=False))
1641
+ if subfig_title:
1642
+ subfig_title[0].remove()
1643
+ subfig_title.pop()
1644
+ clear_selection()
1317
1645
  plt.draw()
1318
1646
 
1319
1647
  def confirm(event):
1320
1648
  """Callback function for the "Confirm" button."""
1321
- if error_texts:
1322
- error_texts[0].remove()
1323
- error_texts.pop()
1649
+ if subfig_title:
1650
+ subfig_title[0].remove()
1651
+ subfig_title.pop()
1324
1652
  roi = tuple(int(v) for v in rects[0].extents)
1325
1653
  if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1326
1654
  roi = None
1327
1655
  change_fig_title(f'Selected ROI: {roi}')
1328
1656
  plt.close()
1329
1657
 
1330
- fig_title = []
1331
- error_texts = []
1332
-
1333
1658
  # Check inputs
1334
1659
  a = np.asarray(a)
1335
1660
  if a.ndim != 2:
@@ -1342,6 +1667,12 @@ def select_roi_2d(
1342
1667
  if title is None:
1343
1668
  title = 'Click and drag to select or adjust a region of interest (ROI)'
1344
1669
 
1670
+ if not interactive and filename is None:
1671
+ return preselected_roi
1672
+
1673
+ fig_title = []
1674
+ subfig_title = []
1675
+
1345
1676
  title_pos = (0.5, 0.95)
1346
1677
  title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1347
1678
  'verticalalignment': 'bottom'}
@@ -1370,14 +1701,15 @@ def select_roi_2d(
1370
1701
 
1371
1702
  if not interactive:
1372
1703
 
1373
- change_fig_title(
1374
- f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')
1704
+ if preselected_roi is not None:
1705
+ change_fig_title(
1706
+ f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')
1375
1707
 
1376
1708
  else:
1377
1709
 
1378
1710
  change_fig_title(title)
1379
1711
  if preselected_roi is not None:
1380
- change_error_text(
1712
+ change_subfig_title(
1381
1713
  f'Preselected ROI: {tuple(int(v) for v in preselected_roi)}')
1382
1714
  fig.subplots_adjust(bottom=0.2)
1383
1715
 
@@ -1400,20 +1732,26 @@ def select_roi_2d(
1400
1732
  reset_btn.ax.remove()
1401
1733
  confirm_btn.ax.remove()
1402
1734
 
1403
- fig_title[0].set_in_layout(True)
1404
- fig.tight_layout(rect=(0, 0, 1, 0.95))
1735
+ if filename is not None:
1736
+ if fig_title:
1737
+ fig_title[0].set_in_layout(True)
1738
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
1739
+ else:
1740
+ fig.tight_layout(rect=(0, 0, 1, 1))
1405
1741
 
1406
- # Remove the handles before returning the figure
1407
- if interactive:
1408
- rects[0]._center_handle.set_visible(False)
1409
- rects[0]._corner_handles.set_visible(False)
1410
- rects[0]._edge_handles.set_visible(False)
1742
+ # Remove the handles
1743
+ if interactive:
1744
+ rects[0]._center_handle.set_visible(False)
1745
+ rects[0]._corner_handles.set_visible(False)
1746
+ rects[0]._edge_handles.set_visible(False)
1747
+ fig.savefig(filename)
1748
+ plt.close()
1411
1749
 
1412
1750
  roi = tuple(int(v) for v in rects[0].extents)
1413
1751
  if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1414
1752
  roi = None
1415
1753
 
1416
- return fig, roi
1754
+ return roi
1417
1755
 
1418
1756
 
1419
1757
  def select_image_indices(
@@ -1436,7 +1774,7 @@ def select_image_indices(
1436
1774
  :param preselected_indices: Preselected image indices,
1437
1775
  defaults to `None`.
1438
1776
  :type preselected_indices: tuple(int), list(int), optional
1439
- :param axis_index_offset: Offset in axes index range and
1777
+ :param axis_index_offset: Offset in axis index range and
1440
1778
  preselected indices, defaults to `0`.
1441
1779
  :type axis_index_offset: int, optional
1442
1780
  :param min_range: The minimal range spanned by the selected
@@ -1535,7 +1873,7 @@ def select_image_indices(
1535
1873
  try:
1536
1874
  index = int(expression)
1537
1875
  if (index < axis_index_offset
1538
- or index >= axis_index_offset+a.shape[axis]):
1876
+ or index > axis_index_offset+a.shape[axis]):
1539
1877
  raise ValueError
1540
1878
  except ValueError:
1541
1879
  change_error_text(
@@ -1872,3 +2210,128 @@ def quick_plot(
1872
2210
  if save_fig:
1873
2211
  plt.savefig(path)
1874
2212
  plt.show(block=block)
2213
+
2214
+
2215
def nxcopy(
        nxobject, exclude_nxpaths=None, nxpath_prefix=None,
        nxpathabs_prefix=None, nxpath_copy_abspath=None):
    """
    Function that returns a copy of a nexus object, optionally
    excluding certain child items.

    :param nxobject: The input nexus object to "copy".
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: A relative path or a list of relative
        paths to child nexus objects that should be excluded from the
        returned "copy", defaults to `None` (nothing is excluded).
    :type exclude_nxpaths: str, list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only.
    :type nxpath_prefix: str
    :param nxpathabs_prefix: For use in recursive calls from inside this
        function only.
    :type nxpathabs_prefix: str
    :param nxpath_copy_abspath: For use in recursive calls from inside this
        function only.
    :type nxpath_copy_abspath: str
    :raises ValueError: An excluded path is not relative.
    :return: Copy of the input `nxobject` with some children optionally
        excluded.
    :rtype: nexusformat.nexus.NXobject
    """
    # Third party modules
    from nexusformat.nexus import (
        NXentry,
        NXfield,
        NXgroup,
        NXlink,
        NXlinkgroup,
        NXroot,
    )

    if isinstance(nxobject, NXlinkgroup):
        # The top level nxobject is a linked group
        # Create a group with the same name as the top level's target
        nxobject_copy = nxobject[nxobject.nxtarget].__class__(
            name=nxobject.nxname)
    elif isinstance(nxobject, (NXlink, NXfield)):
        # The top level nxobject is a (linked) field: return a copy
        # Collect the attributes in a new dictionary so that dropping
        # 'target' does not modify the attributes of the input object
        attrs = {k: v for k, v in nxobject.attrs.items() if k != 'target'}
        nxobject_copy = NXfield(
            value=nxobject.nxdata, name=nxobject.nxname,
            attrs=attrs)
        return nxobject_copy
    else:
        # Create a group with the same type/name as the nxobject
        nxobject_copy = nxobject.__class__(name=nxobject.nxname)

    # Copy attributes
    if isinstance(nxobject, NXroot):
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.default
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    # Setup paths
    if exclude_nxpaths is None:
        exclude_nxpaths = []
    elif isinstance(exclude_nxpaths, str):
        exclude_nxpaths = [exclude_nxpaths]
    for exclude_nxpath in exclude_nxpaths:
        # startswith also copes with an empty path string
        if exclude_nxpath.startswith('/'):
            raise ValueError(
                f'Invalid parameter in exclude_nxpaths ({exclude_nxpaths}), '
                'excluded paths should be relative')
    if nxpath_prefix is None:
        nxpath_prefix = ''
    if nxpathabs_prefix is None:
        if isinstance(nxobject, NXentry):
            nxpathabs_prefix = nxobject.nxpath
        else:
            nxpathabs_prefix = nxobject.nxpath.removesuffix(nxobject.nxname)
    if nxpath_copy_abspath is None:
        nxpath_copy_abspath = ''

    # Loop over all nxobject's children
    for k, v in nxobject.items():
        nxpath = os_path.join(nxpath_prefix, k)
        nxpathabs = os_path.join(nxpathabs_prefix, nxpath)
        if nxpath in exclude_nxpaths:
            # Do not leave a default pointing to an excluded child
            if 'default' in nxobject_copy.attrs and nxobject_copy.default == k:
                nxobject_copy.attrs.pop('default')
            continue
        if isinstance(v, NXlinkgroup):
            if nxpathabs == v.nxpath and not any(
                    v.nxtarget.startswith(os_path.join(nxpathabs_prefix, p))
                    for p in exclude_nxpaths):
                # The link target is not excluded: keep the link
                nxobject_copy[k] = NXlink(v.nxtarget)
            else:
                # Replace the link by a copy of the linked group
                nxobject_copy[k] = nxcopy(
                    v, exclude_nxpaths=exclude_nxpaths,
                    nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
                    nxpath_copy_abspath=os_path.join(nxpath_copy_abspath, k))
        elif isinstance(v, NXlink):
            if nxpathabs == v.nxpath and not any(
                    v.nxtarget.startswith(os_path.join(nxpathabs_prefix, p))
                    for p in exclude_nxpaths):
                # The link target is not excluded: keep the link
                nxobject_copy[k] = v
            else:
                # Replace the link by a copy of the linked field
                nxobject_copy[k] = v.nxdata
                for kk, vv in v.attrs.items():
                    nxobject_copy[k].attrs[kk] = vv
                nxobject_copy[k].attrs.pop('target', None)
        elif isinstance(v, NXgroup):
            nxobject_copy[k] = nxcopy(
                v, exclude_nxpaths=exclude_nxpaths,
                nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
                nxpath_copy_abspath=os_path.join(nxpath_copy_abspath, k))
        else:
            nxobject_copy[k] = v.nxdata
            for kk, vv in v.attrs.items():
                nxobject_copy[k].attrs[kk] = vv
            if nxpathabs != os_path.join(nxpath_copy_abspath, k):
                nxobject_copy[k].attrs.pop('target', None)

    return nxobject_copy