ChessAnalysisPipeline 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. See the package registry's advisory page for this release for more details.

CHAP/utils/fit.py CHANGED
@@ -862,9 +862,13 @@ class Fit:
862
862
  # Third party modules
863
863
  from asteval import Interpreter
864
864
 
865
+ # Local modules
866
+ from CHAP.utils.general import is_num_pair
867
+
865
868
  if centers_range is None:
866
869
  centers_range = (self._x[0], self._x[-1])
867
- elif not is_index_range(centers_range, ge=self._x[0], le=self._x[-1]):
870
+ elif (not is_num_pair(centers_range) or len(centers_range) != 2
871
+ or centers_range[0] >= centers_range[1]):
868
872
  raise ValueError(
869
873
  f'Invalid parameter centers_range ({centers_range})')
870
874
  if self._model is not None:
@@ -1021,9 +1025,8 @@ class Fit:
1021
1025
  {'name': par_name, 'min': min_value})
1022
1026
  elif len(index) == 1:
1023
1027
  parameter = parameters[index[0]]
1024
- _min = parameter.get('min', None)
1025
- if _min is None or _min < min_value:
1026
- parameter['min'] = min_value
1028
+ _min = parameter.get('min', min_value)
1029
+ parameter['min'] = max(_min, min_value)
1027
1030
  else:
1028
1031
  raise ValueError(
1029
1032
  'Invalid parameters value in '
@@ -1045,14 +1048,47 @@ class Fit:
1045
1048
  {'name': par_name, 'min': min_value})
1046
1049
  elif len(index) == 1:
1047
1050
  parameter = parameters[index[0]]
1048
- _min = parameter.get('min', None)
1049
- if _min is None or _min < min_value:
1050
- parameter['min'] = min_value
1051
+ _min = parameter.get('min', min_value)
1052
+ parameter['min'] = max(_min, min_value)
1051
1053
  else:
1052
1054
  raise ValueError(
1053
1055
  'Invalid parameters value in '
1054
1056
  f'background model {name} '
1055
1057
  f'({parameters})')
1058
+ if name == 'gaussian':
1059
+ if parameters is None:
1060
+ parameters = {
1061
+ 'name': 'center',
1062
+ 'value': 0.5 * (
1063
+ centers_range[0] + centers_range[1]),
1064
+ 'min': centers_range[0],
1065
+ 'min': centers_range[1],
1066
+ }
1067
+ else:
1068
+ index = [i for i, par in enumerate(parameters)
1069
+ if par['name'] == 'center']
1070
+ if not len(index):
1071
+ parameters.append({
1072
+ 'name': 'center',
1073
+ 'value': 0.5 * (
1074
+ centers_range[0] + centers_range[1]),
1075
+ 'min': centers_range[0],
1076
+ 'max': centers_range[1],
1077
+ })
1078
+ elif len(index) == 1:
1079
+ parameter = parameters[index[0]]
1080
+ if 'value' not in parameter:
1081
+ parameter['value'] = 0.5 * (
1082
+ centers_range[0]+centers_range[1])
1083
+ _min = parameter.get('min', centers_range[0])
1084
+ parameter['min'] = max(_min, centers_range[0])
1085
+ _max = parameter.get('max', centers_range[1])
1086
+ parameter['max'] = min(_max, centers_range[1])
1087
+ else:
1088
+ raise ValueError(
1089
+ 'Invalid parameters value in '
1090
+ f'background model {name} '
1091
+ f'({parameters})')
1056
1092
  self.add_model(
1057
1093
  name, prefix=prefix, parameters=parameters,
1058
1094
  **model)
@@ -1320,14 +1356,11 @@ class Fit:
1320
1356
  self._reset_par_at_boundary()
1321
1357
 
1322
1358
  # Perform the fit
1323
- fit_kws = None
1324
- # if 'Dfun' in kwargs:
1325
- # fit_kws = {'Dfun': kwargs.pop('Dfun')}
1326
- # self._result = self._model.fit(
1327
- # self._y_norm, self._parameters, x=self._x, fit_kws=fit_kws,
1328
- # **kwargs)
1359
+ fit_kws = {}
1329
1360
  if self._param_constraint:
1330
1361
  fit_kws = {'xtol': 1.e-5, 'ftol': 1.e-5, 'gtol': 1.e-5}
1362
+ # if 'Dfun' in kwargs:
1363
+ # fit_kws['Dfun'] = kwargs.pop('Dfun')
1331
1364
  if self._mask is None:
1332
1365
  self._result = self._model.fit(
1333
1366
  self._y_norm, self._parameters, x=self._x, fit_kws=fit_kws,
@@ -1866,6 +1899,7 @@ class Fit:
1866
1899
  self._result.residual *= self._norm[1]
1867
1900
 
1868
1901
  def _reset_par_at_boundary(self):
1902
+ fraction = 0.02
1869
1903
  for name, par in self._parameters.items():
1870
1904
  if par.vary:
1871
1905
  value = par.value
@@ -1874,26 +1908,26 @@ class Fit:
1874
1908
  if np.isinf(_min):
1875
1909
  if not np.isinf(_max):
1876
1910
  if self._parameter_norms.get(name, False):
1877
- upp = _max-0.1*self._y_range
1911
+ upp = _max - fraction*self._y_range
1878
1912
  elif _max == 0.0:
1879
- upp = _max-0.1
1913
+ upp = _max - fraction
1880
1914
  else:
1881
- upp = _max-0.1*abs(_max)
1915
+ upp = _max - fraction*abs(_max)
1882
1916
  if value >= upp:
1883
1917
  par.set(value=upp)
1884
1918
  else:
1885
1919
  if np.isinf(_max):
1886
1920
  if self._parameter_norms.get(name, False):
1887
- low = _min + 0.1*self._y_range
1921
+ low = _min + fraction*self._y_range
1888
1922
  elif _min == 0.0:
1889
- low = _min+0.1
1923
+ low = _min + fraction
1890
1924
  else:
1891
- low = _min + 0.1*abs(_min)
1925
+ low = _min + fraction*abs(_min)
1892
1926
  if value <= low:
1893
1927
  par.set(value=low)
1894
1928
  else:
1895
- low = 0.9*_min + 0.1*_max
1896
- upp = 0.1*_min + 0.9*_max
1929
+ low = (1.0-fraction)*_min + fraction*_max
1930
+ upp = fraction*_min + (1.0-fraction)*_max
1897
1931
  if value <= low:
1898
1932
  par.set(value=low)
1899
1933
  if value >= upp:
@@ -1917,6 +1951,7 @@ class FitMap(Fit):
1917
1951
  self._max_nfev = None
1918
1952
  self._memfolder = None
1919
1953
  self._new_parameters = None
1954
+ self._num_func_eval = None
1920
1955
  self._out_of_bounds = None
1921
1956
  self._plot = False
1922
1957
  self._print_report = False
@@ -1932,6 +1967,7 @@ class FitMap(Fit):
1932
1967
  # map dimensions
1933
1968
  if isinstance(ymap, (tuple, list, np.ndarray)):
1934
1969
  self._x = np.asarray(x)
1970
+ ymap = np.asarray(ymap)
1935
1971
  elif HAVE_XARRAY and isinstance(ymap, xr.DataArray):
1936
1972
  if x is not None:
1937
1973
  logger.warning('Ignoring superfluous input x ({x})')
@@ -2149,7 +2185,8 @@ class FitMap(Fit):
2149
2185
  @property
2150
2186
  def max_nfev(self):
2151
2187
  """
2152
- Return the maximum number of function evaluations for each fit.
2188
+ Return if the maximum number of function evaluations is reached
2189
+ for each fit.
2153
2190
  """
2154
2191
  return self._max_nfev
2155
2192
 
@@ -2158,7 +2195,7 @@ class FitMap(Fit):
2158
2195
  """
2159
2196
  Return the number of function evaluations for each best fit.
2160
2197
  """
2161
- logger.warning('Undefined property num_func_eval')
2198
+ return self._num_func_eval
2162
2199
 
2163
2200
  @property
2164
2201
  def out_of_bounds(self):
@@ -2466,6 +2503,7 @@ class FitMap(Fit):
2466
2503
  if self._result is not None:
2467
2504
  self._out_of_bounds = None
2468
2505
  self._max_nfev = None
2506
+ self._num_func_eval = None
2469
2507
  self._redchi = None
2470
2508
  self._success = None
2471
2509
  self._best_fit = None
@@ -2508,6 +2546,7 @@ class FitMap(Fit):
2508
2546
  if num_proc == 1:
2509
2547
  self._out_of_bounds_flat = np.zeros(self._map_dim, dtype=bool)
2510
2548
  self._max_nfev_flat = np.zeros(self._map_dim, dtype=bool)
2549
+ self._num_func_eval_flat = np.zeros(self._map_dim, dtype=np.intc)
2511
2550
  self._redchi_flat = np.zeros(self._map_dim, dtype=np.float64)
2512
2551
  self._success_flat = np.zeros(self._map_dim, dtype=bool)
2513
2552
  self._best_fit_flat = np.zeros(
@@ -2537,6 +2576,11 @@ class FitMap(Fit):
2537
2576
  filename_memmap = path.join(self._memfolder, 'max_nfev_memmap')
2538
2577
  self._max_nfev_flat = np.memmap(
2539
2578
  filename_memmap, dtype=bool, shape=(self._map_dim), mode='w+')
2579
+ filename_memmap = path.join(
2580
+ self._memfolder, 'num_func_eval_memmap')
2581
+ self._num_func_eval_flat = np.memmap(
2582
+ filename_memmap, dtype=np.intc, shape=(self._map_dim),
2583
+ mode='w+')
2540
2584
  filename_memmap = path.join(self._memfolder, 'redchi_memmap')
2541
2585
  self._redchi_flat = np.memmap(
2542
2586
  filename_memmap, dtype=np.float64, shape=(self._map_dim),
@@ -2598,29 +2642,32 @@ class FitMap(Fit):
2598
2642
  except AttributeError:
2599
2643
  pass
2600
2644
 
2601
- if num_proc == 1:
2602
- # Perform the remaining fits serially
2603
- for n in range(1, self._map_dim):
2604
- self._fit(n, current_best_values, **kwargs)
2605
- else:
2606
- # Perform the remaining fits in parallel
2607
- num_fit = self._map_dim-1
2608
- if num_proc > num_fit:
2609
- logger.warning(
2610
- f'The requested number of processors ({num_proc}) exceeds '
2611
- f'the number of fits, num_proc reduced to {num_fit}')
2612
- num_proc = num_fit
2613
- num_fit_per_proc = 1
2645
+ if self._map_dim > 1:
2646
+ if num_proc == 1:
2647
+ # Perform the remaining fits serially
2648
+ for n in range(1, self._map_dim):
2649
+ self._fit(n, current_best_values, **kwargs)
2614
2650
  else:
2615
- num_fit_per_proc = round((num_fit)/num_proc)
2616
- if num_proc*num_fit_per_proc < num_fit:
2617
- num_fit_per_proc += 1
2618
- num_fit_batch = min(num_fit_per_proc, 40)
2619
- with Parallel(n_jobs=num_proc) as parallel:
2620
- parallel(
2621
- delayed(self._fit_parallel)
2622
- (current_best_values, num_fit_batch, n_start, **kwargs)
2623
- for n_start in range(1, self._map_dim, num_fit_batch))
2651
+ # Perform the remaining fits in parallel
2652
+ num_fit = self._map_dim-1
2653
+ if num_proc > num_fit:
2654
+ logger.warning(
2655
+ f'The requested number of processors ({num_proc}) '
2656
+ 'exceeds the number of fits, num_proc reduced to '
2657
+ f'{num_fit}')
2658
+ num_proc = num_fit
2659
+ num_fit_per_proc = 1
2660
+ else:
2661
+ num_fit_per_proc = round((num_fit)/num_proc)
2662
+ if num_proc*num_fit_per_proc < num_fit:
2663
+ num_fit_per_proc += 1
2664
+ num_fit_batch = min(num_fit_per_proc, 40)
2665
+ with Parallel(n_jobs=num_proc) as parallel:
2666
+ parallel(
2667
+ delayed(self._fit_parallel)
2668
+ (current_best_values, num_fit_batch, n_start,
2669
+ **kwargs)
2670
+ for n_start in range(1, self._map_dim, num_fit_batch))
2624
2671
 
2625
2672
  # Renormalize the initial parameters for external use
2626
2673
  if self._norm is not None and self._normalized:
@@ -2649,6 +2696,8 @@ class FitMap(Fit):
2649
2696
  self._out_of_bounds_flat, self._map_shape))
2650
2697
  self._max_nfev = np.copy(np.reshape(
2651
2698
  self._max_nfev_flat, self._map_shape))
2699
+ self._num_func_eval = np.copy(np.reshape(
2700
+ self._num_func_eval_flat, self._map_shape))
2652
2701
  self._redchi = np.copy(np.reshape(self._redchi_flat, self._map_shape))
2653
2702
  self._success = np.copy(np.reshape(
2654
2703
  self._success_flat, self._map_shape))
@@ -2662,6 +2711,8 @@ class FitMap(Fit):
2662
2711
  self._out_of_bounds = np.transpose(
2663
2712
  self._out_of_bounds, self._inv_transpose)
2664
2713
  self._max_nfev = np.transpose(self._max_nfev, self._inv_transpose)
2714
+ self._num_func_eval = np.transpose(
2715
+ self._num_func_eval, self._inv_transpose)
2665
2716
  self._redchi = np.transpose(self._redchi, self._inv_transpose)
2666
2717
  self._success = np.transpose(self._success, self._inv_transpose)
2667
2718
  self._best_fit = np.transpose(
@@ -2673,6 +2724,7 @@ class FitMap(Fit):
2673
2724
  self._best_errors, [0] + [i+1 for i in self._inv_transpose])
2674
2725
  del self._out_of_bounds_flat
2675
2726
  del self._max_nfev_flat
2727
+ del self._num_func_eval_flat
2676
2728
  del self._redchi_flat
2677
2729
  del self._success_flat
2678
2730
  del self._best_fit_flat
@@ -2761,6 +2813,7 @@ class FitMap(Fit):
2761
2813
 
2762
2814
  if result.redchi >= self._redchi_cutoff:
2763
2815
  result.success = False
2816
+ self._num_func_eval_flat[n] = result.nfev
2764
2817
  if result.nfev == result.max_nfev:
2765
2818
  if result.redchi < self._redchi_cutoff:
2766
2819
  result.success = True
CHAP/utils/general.py CHANGED
@@ -26,8 +26,6 @@ from sys import float_info
26
26
  import numpy as np
27
27
  try:
28
28
  import matplotlib.pyplot as plt
29
- from matplotlib.widgets import AxesWidget, RadioButtons
30
- from matplotlib import cbook
31
29
  except ImportError:
32
30
  pass
33
31
 
@@ -655,8 +653,8 @@ def _input_int_or_num(
655
653
 
656
654
 
657
655
  def input_int_list(
658
- s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
659
- sort=True, raise_error=False, log=True):
656
+ s=None, num_max=None, ge=None, le=None, split_on_dash=True,
657
+ remove_duplicates=True, sort=True, raise_error=False, log=True):
660
658
  """
661
659
  Prompt the user to input a list of interger and split the entered
662
660
  string on any combination of commas, whitespaces, or dashes (when
@@ -669,13 +667,13 @@ def input_int_list(
669
667
  return None upon an illegal input
670
668
  """
671
669
  return _input_int_or_num_list(
672
- 'int', s, ge, le, split_on_dash, remove_duplicates, sort, raise_error,
673
- log)
670
+ 'int', s, num_max, ge, le, split_on_dash, remove_duplicates, sort,
671
+ raise_error, log)
674
672
 
675
673
 
676
674
  def input_num_list(
677
- s=None, ge=None, le=None, remove_duplicates=True, sort=True,
678
- raise_error=False, log=True):
675
+ s=None, num_max=None, ge=None, le=None, remove_duplicates=True,
676
+ sort=True, raise_error=False, log=True):
679
677
  """
680
678
  Prompt the user to input a list of numbers and split the entered
681
679
  string on any combination of commas or whitespaces.
@@ -687,11 +685,12 @@ def input_num_list(
687
685
  return None upon an illegal input
688
686
  """
689
687
  return _input_int_or_num_list(
690
- 'num', s, ge, le, False, remove_duplicates, sort, raise_error, log)
688
+ 'num', s, num_max, ge, le, False, remove_duplicates, sort, raise_error,
689
+ log)
691
690
 
692
691
 
693
692
  def _input_int_or_num_list(
694
- type_str, s=None, ge=None, le=None, split_on_dash=True,
693
+ type_str, s=None, num_max=None, ge=None, le=None, split_on_dash=True,
695
694
  remove_duplicates=True, sort=True, raise_error=False, log=True):
696
695
  # RV do we want a limit on max dimension?
697
696
  if type_str == 'int':
@@ -707,6 +706,9 @@ def _input_int_or_num_list(
707
706
  else:
708
707
  illegal_value(type_str, 'type_str', '_input_int_or_num_list')
709
708
  return None
709
+ if (num_max is not None
710
+ and not is_int(num_max, gt=0, raise_error=raise_error, log=log)):
711
+ return None
710
712
  v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
711
713
  if v_range:
712
714
  v_range = f' (each value in {v_range})'
@@ -721,14 +723,17 @@ def _input_int_or_num_list(
721
723
  except:
722
724
  print('Unexpected error')
723
725
  raise
724
- if (not isinstance(_list, list) or any(
725
- not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)):
726
+ if (not isinstance(_list, list)
727
+ or (num_max is not None and len(_list) > num_max)
728
+ or any(
729
+ not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)):
730
+ num = '' if num_max is None else f'up to {num_max} '
726
731
  if split_on_dash:
727
- print('Invalid input: enter a valid set of dash/comma/whitespace '
728
- 'separated integers e.g. 1 3,5-8 , 12')
732
+ print(f'Invalid input: enter a valid set of {num}dash/comma/'
733
+ 'whitespace separated numbers e.g. 1 3,5-8 , 12')
729
734
  else:
730
- print('Invalid input: enter a valid set of comma/whitespace '
731
- 'separated integers e.g. 1 3,5 8 , 12')
735
+ print(f'Invalid input: enter a valid set of {num}comma/whitespace '
736
+ 'separated numbers e.g. 1 3,5 8 , 12')
732
737
  _list = _input_int_or_num_list(
733
738
  type_str, s, ge, le, split_on_dash, remove_duplicates, sort,
734
739
  raise_error, log)
@@ -1105,18 +1110,6 @@ def select_mask_1d(
1105
1110
  for v in preselected_index_ranges)):
1106
1111
  raise ValueError('Invalid parameter preselected_index_ranges '
1107
1112
  f'({preselected_index_ranges})')
1108
- if (min_num_index_ranges is not None
1109
- and len(preselected_index_ranges) < min_num_index_ranges):
1110
- raise ValueError(
1111
- 'Invalid parameter preselected_index_ranges '
1112
- f'({preselected_index_ranges}), number of selected index '
1113
- f'ranges must be >= {min_num_index_ranges}')
1114
- if (max_num_index_ranges is not None
1115
- and len(preselected_index_ranges) > max_num_index_ranges):
1116
- raise ValueError(
1117
- 'Invalid parameter preselected_index_ranges '
1118
- f'({preselected_index_ranges}), number of selected index '
1119
- f'ranges must be <= {max_num_index_ranges}')
1120
1113
 
1121
1114
  spans = []
1122
1115
  fig_title = []
@@ -1191,8 +1184,6 @@ def select_mask_1d(
1191
1184
  plt.subplots_adjust(bottom=0.0)
1192
1185
 
1193
1186
  selected_index_ranges = get_selected_index_ranges()
1194
- if not selected_index_ranges:
1195
- selected_index_ranges = None
1196
1187
 
1197
1188
  # Update the mask with the currently selected index ranges
1198
1189
  selected_mask = update_mask(len(x)*[False], selected_index_ranges)
@@ -1290,37 +1281,46 @@ def select_roi_2d(
1290
1281
  fig_title.pop()
1291
1282
  fig_title.append(plt.figtext(*title_pos, title, **title_props))
1292
1283
 
1293
- def change_error_text(error):
1294
- if error_texts:
1295
- error_texts[0].remove()
1296
- error_texts.pop()
1297
- error_texts.append(plt.figtext(*error_pos, error, **error_props))
1284
+ def change_subfig_title(error):
1285
+ if subfig_title:
1286
+ subfig_title[0].remove()
1287
+ subfig_title.pop()
1288
+ subfig_title.append(plt.figtext(*error_pos, error, **error_props))
1289
+
1290
+ def clear_selection():
1291
+ rects[0].set_visible(False)
1292
+ rects.pop()
1293
+ rects.append(
1294
+ RectangleSelector(
1295
+ ax, on_rect_select, props=rect_props,
1296
+ useblit=True, interactive=interactive, drag_from_anywhere=True,
1297
+ ignore_event_outside=False))
1298
1298
 
1299
1299
  def on_rect_select(eclick, erelease):
1300
1300
  """Callback function for the RectangleSelector widget."""
1301
- change_error_text(
1302
- f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
1301
+ if (not int(rects[0].extents[1]) - int(rects[0].extents[0])
1302
+ or not int(rects[0].extents[3]) - int(rects[0].extents[2])):
1303
+ clear_selection()
1304
+ change_subfig_title(
1305
+ f'Selected ROI too small, try again')
1306
+ else:
1307
+ change_subfig_title(
1308
+ f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
1303
1309
  plt.draw()
1304
1310
 
1305
1311
  def reset(event):
1306
1312
  """Callback function for the "Reset" button."""
1307
- if error_texts:
1308
- error_texts[0].remove()
1309
- error_texts.pop()
1310
- rects[0].set_visible(False)
1311
- rects.pop()
1312
- rects.append(
1313
- RectangleSelector(
1314
- ax, on_rect_select, props=rect_props, useblit=True,
1315
- interactive=interactive, drag_from_anywhere=True,
1316
- ignore_event_outside=False))
1313
+ if subfig_title:
1314
+ subfig_title[0].remove()
1315
+ subfig_title.pop()
1316
+ clear_selection()
1317
1317
  plt.draw()
1318
1318
 
1319
1319
  def confirm(event):
1320
1320
  """Callback function for the "Confirm" button."""
1321
- if error_texts:
1322
- error_texts[0].remove()
1323
- error_texts.pop()
1321
+ if subfig_title:
1322
+ subfig_title[0].remove()
1323
+ subfig_title.pop()
1324
1324
  roi = tuple(int(v) for v in rects[0].extents)
1325
1325
  if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1326
1326
  roi = None
@@ -1328,7 +1328,7 @@ def select_roi_2d(
1328
1328
  plt.close()
1329
1329
 
1330
1330
  fig_title = []
1331
- error_texts = []
1331
+ subfig_title = []
1332
1332
 
1333
1333
  # Check inputs
1334
1334
  a = np.asarray(a)
@@ -1377,7 +1377,7 @@ def select_roi_2d(
1377
1377
 
1378
1378
  change_fig_title(title)
1379
1379
  if preselected_roi is not None:
1380
- change_error_text(
1380
+ change_subfig_title(
1381
1381
  f'Preselected ROI: {tuple(int(v) for v in preselected_roi)}')
1382
1382
  fig.subplots_adjust(bottom=0.2)
1383
1383
 
@@ -1436,7 +1436,7 @@ def select_image_indices(
1436
1436
  :param preselected_indices: Preselected image indices,
1437
1437
  defaults to `None`.
1438
1438
  :type preselected_indices: tuple(int), list(int), optional
1439
- :param axis_index_offset: Offset in axes index range and
1439
+ :param axis_index_offset: Offset in axis index range and
1440
1440
  preselected indices, defaults to `0`.
1441
1441
  :type axis_index_offset: int, optional
1442
1442
  :param min_range: The minimal range spanned by the selected
@@ -1535,7 +1535,7 @@ def select_image_indices(
1535
1535
  try:
1536
1536
  index = int(expression)
1537
1537
  if (index < axis_index_offset
1538
- or index >= axis_index_offset+a.shape[axis]):
1538
+ or index > axis_index_offset+a.shape[axis]):
1539
1539
  raise ValueError
1540
1540
  except ValueError:
1541
1541
  change_error_text(
@@ -1872,3 +1872,126 @@ def quick_plot(
1872
1872
  if save_fig:
1873
1873
  plt.savefig(path)
1874
1874
  plt.show(block=block)
1875
+
1876
+
1877
+ def nxcopy(
1878
+ nxobject, exclude_nxpaths=None, nxpath_prefix=None,
1879
+ nxpathabs_prefix=None, nxpath_copy_abspath=None):
1880
+ """
1881
+ Function that returns a copy of a nexus object, optionally exluding
1882
+ certain child items.
1883
+
1884
+ :param nxobject: The input nexus object to "copy".
1885
+ :type nxobject: nexusformat.nexus.NXobject
1886
+ :param exlude_nxpaths: A list of relative paths to child nexus
1887
+ objects that should be excluded from the returned "copy",
1888
+ defaults to `[]`.
1889
+ :type exclude_nxpaths: str, list[str], optional
1890
+ :param nxpath_prefix: For use in recursive calls from inside this
1891
+ function only.
1892
+ :type nxpath_prefix: str
1893
+ :param nxpathabs_prefix: For use in recursive calls from inside this
1894
+ function only.
1895
+ :type nxpathabs_prefix: str
1896
+ :param nxpath_copy_abspath: For use in recursive calls from inside this
1897
+ function only.
1898
+ :type nxpath_copy_abspath: str
1899
+ :return: Copy of the input `nxobject` with some children optionally
1900
+ exluded.
1901
+ :rtype: nexusformat.nexus.NXobject
1902
+ """
1903
+ # Third party modules
1904
+ from nexusformat.nexus import (
1905
+ NXentry,
1906
+ NXfield,
1907
+ NXgroup,
1908
+ NXlink,
1909
+ NXlinkgroup,
1910
+ NXroot,
1911
+ )
1912
+
1913
+
1914
+ if isinstance(nxobject, NXlinkgroup):
1915
+ # The top level nxobject is a linked group
1916
+ # Create a group with the same name as the top level's target
1917
+ nxobject_copy = nxobject[nxobject.nxtarget].__class__(
1918
+ name=nxobject.nxname)
1919
+ elif isinstance(nxobject, (NXlink, NXfield)):
1920
+ # The top level nxobject is a (linked) field: return a copy
1921
+ nxobject_copy = NXfield(
1922
+ value=nxobject.nxdata, name=nxobject.nxname,
1923
+ attrs=nxobject.attrs)
1924
+ return nxobject_copy
1925
+ else:
1926
+ # Create a group with the same type/name as the nxobject
1927
+ nxobject_copy = nxobject.__class__(name=nxobject.nxname)
1928
+
1929
+ # Copy attributes
1930
+ if isinstance(nxobject, NXroot):
1931
+ if 'default' in nxobject.attrs:
1932
+ nxobject_copy.attrs['default'] = nxobject.default
1933
+ else:
1934
+ for k, v in nxobject.attrs.items():
1935
+ nxobject_copy.attrs[k] = v
1936
+
1937
+ # Setup paths
1938
+ if exclude_nxpaths is None:
1939
+ exclude_nxpaths = []
1940
+ elif isinstance(exclude_nxpaths, str):
1941
+ exclude_nxpaths = [exclude_nxpaths]
1942
+ for exclude_nxpath in exclude_nxpaths:
1943
+ if exclude_nxpath[0] == '/':
1944
+ raise ValueError(
1945
+ f'Invalid parameter in exclude_nxpaths ({exclude_nxpaths}), '
1946
+ 'excluded paths should be relative')
1947
+ if nxpath_prefix is None:
1948
+ nxpath_prefix = ''
1949
+ if nxpathabs_prefix is None:
1950
+ if isinstance(nxobject, NXentry):
1951
+ nxpathabs_prefix = nxobject.nxpath
1952
+ else:
1953
+ nxpathabs_prefix = nxobject.nxpath.removesuffix(nxobject.nxname)
1954
+ if nxpath_copy_abspath is None:
1955
+ nxpath_copy_abspath = ''
1956
+
1957
+ # Loop over all nxobject's children
1958
+ for k, v in nxobject.items():
1959
+ nxpath = os_path.join(nxpath_prefix, k)
1960
+ nxpathabs = os_path.join(nxpathabs_prefix, nxpath)
1961
+ if nxpath in exclude_nxpaths:
1962
+ if 'default' in nxobject_copy.attrs and nxobject_copy.default == k:
1963
+ nxobject_copy.attrs.pop('default')
1964
+ continue
1965
+ if isinstance(v, NXlinkgroup):
1966
+ if nxpathabs == v.nxpath and not any(
1967
+ v.nxtarget.startswith(os_path.join(nxpathabs_prefix, p))
1968
+ for p in exclude_nxpaths):
1969
+ nxobject_copy[k] = NXlink(v.nxtarget)
1970
+ else:
1971
+ nxobject_copy[k] = nxcopy(
1972
+ v, exclude_nxpaths=exclude_nxpaths,
1973
+ nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
1974
+ nxpath_copy_abspath=os_path.join(nxpath_copy_abspath, k))
1975
+ elif isinstance(v, NXlink):
1976
+ if nxpathabs == v.nxpath and not any(
1977
+ v.nxtarget.startswith(os_path.join(nxpathabs_prefix, p))
1978
+ for p in exclude_nxpaths):
1979
+ nxobject_copy[k] = v
1980
+ else:
1981
+ nxobject_copy[k] = v.nxdata
1982
+ for kk, vv in v.attrs.items():
1983
+ nxobject_copy[k].attrs[kk] = vv
1984
+ nxobject_copy[k].attrs.pop('target', None)
1985
+ elif isinstance(v, NXgroup):
1986
+ nxobject_copy[k] = nxcopy(
1987
+ v, exclude_nxpaths=exclude_nxpaths,
1988
+ nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
1989
+ nxpath_copy_abspath=os_path.join(nxpath_copy_abspath, k))
1990
+ else:
1991
+ nxobject_copy[k] = v.nxdata
1992
+ for kk, vv in v.attrs.items():
1993
+ nxobject_copy[k].attrs[kk] = vv
1994
+ if nxpathabs != os_path.join(nxpath_copy_abspath, k):
1995
+ nxobject_copy[k].attrs.pop('target', None)
1996
+
1997
+ return nxobject_copy