ChessAnalysisPipeline 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. Click here for more details.

CHAP/utils/general.py CHANGED
@@ -26,8 +26,6 @@ from sys import float_info
26
26
  import numpy as np
27
27
  try:
28
28
  import matplotlib.pyplot as plt
29
- from matplotlib.widgets import AxesWidget, RadioButtons
30
- from matplotlib import cbook
31
29
  except ImportError:
32
30
  pass
33
31
 
@@ -655,8 +653,8 @@ def _input_int_or_num(
655
653
 
656
654
 
657
655
  def input_int_list(
658
- s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
659
- sort=True, raise_error=False, log=True):
656
+ s=None, num_max=None, ge=None, le=None, split_on_dash=True,
657
+ remove_duplicates=True, sort=True, raise_error=False, log=True):
660
658
  """
661
659
  Prompt the user to input a list of interger and split the entered
662
660
  string on any combination of commas, whitespaces, or dashes (when
@@ -669,13 +667,13 @@ def input_int_list(
669
667
  return None upon an illegal input
670
668
  """
671
669
  return _input_int_or_num_list(
672
- 'int', s, ge, le, split_on_dash, remove_duplicates, sort, raise_error,
673
- log)
670
+ 'int', s, num_max, ge, le, split_on_dash, remove_duplicates, sort,
671
+ raise_error, log)
674
672
 
675
673
 
676
674
  def input_num_list(
677
- s=None, ge=None, le=None, remove_duplicates=True, sort=True,
678
- raise_error=False, log=True):
675
+ s=None, num_max=None, ge=None, le=None, remove_duplicates=True,
676
+ sort=True, raise_error=False, log=True):
679
677
  """
680
678
  Prompt the user to input a list of numbers and split the entered
681
679
  string on any combination of commas or whitespaces.
@@ -687,11 +685,12 @@ def input_num_list(
687
685
  return None upon an illegal input
688
686
  """
689
687
  return _input_int_or_num_list(
690
- 'num', s, ge, le, False, remove_duplicates, sort, raise_error, log)
688
+ 'num', s, num_max, ge, le, False, remove_duplicates, sort, raise_error,
689
+ log)
691
690
 
692
691
 
693
692
  def _input_int_or_num_list(
694
- type_str, s=None, ge=None, le=None, split_on_dash=True,
693
+ type_str, s=None, num_max=None, ge=None, le=None, split_on_dash=True,
695
694
  remove_duplicates=True, sort=True, raise_error=False, log=True):
696
695
  # RV do we want a limit on max dimension?
697
696
  if type_str == 'int':
@@ -707,6 +706,9 @@ def _input_int_or_num_list(
707
706
  else:
708
707
  illegal_value(type_str, 'type_str', '_input_int_or_num_list')
709
708
  return None
709
+ if (num_max is not None
710
+ and not is_int(num_max, gt=0, raise_error=raise_error, log=log)):
711
+ return None
710
712
  v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
711
713
  if v_range:
712
714
  v_range = f' (each value in {v_range})'
@@ -721,14 +723,17 @@ def _input_int_or_num_list(
721
723
  except:
722
724
  print('Unexpected error')
723
725
  raise
724
- if (not isinstance(_list, list) or any(
725
- not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)):
726
+ if (not isinstance(_list, list)
727
+ or (num_max is not None and len(_list) > num_max)
728
+ or any(
729
+ not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)):
730
+ num = '' if num_max is None else f'up to {num_max} '
726
731
  if split_on_dash:
727
- print('Invalid input: enter a valid set of dash/comma/whitespace '
728
- 'separated integers e.g. 1 3,5-8 , 12')
732
+ print(f'Invalid input: enter a valid set of {num}dash/comma/'
733
+ 'whitespace separated numbers e.g. 1 3,5-8 , 12')
729
734
  else:
730
- print('Invalid input: enter a valid set of comma/whitespace '
731
- 'separated integers e.g. 1 3,5 8 , 12')
735
+ print(f'Invalid input: enter a valid set of {num}comma/whitespace '
736
+ 'separated numbers e.g. 1 3,5 8 , 12')
732
737
  _list = _input_int_or_num_list(
733
738
  type_str, s, ge, le, split_on_dash, remove_duplicates, sort,
734
739
  raise_error, log)
@@ -1105,18 +1110,6 @@ def select_mask_1d(
1105
1110
  for v in preselected_index_ranges)):
1106
1111
  raise ValueError('Invalid parameter preselected_index_ranges '
1107
1112
  f'({preselected_index_ranges})')
1108
- if (min_num_index_ranges is not None
1109
- and len(preselected_index_ranges) < min_num_index_ranges):
1110
- raise ValueError(
1111
- 'Invalid parameter preselected_index_ranges '
1112
- f'({preselected_index_ranges}), number of selected index '
1113
- f'ranges must be >= {min_num_index_ranges}')
1114
- if (max_num_index_ranges is not None
1115
- and len(preselected_index_ranges) > max_num_index_ranges):
1116
- raise ValueError(
1117
- 'Invalid parameter preselected_index_ranges '
1118
- f'({preselected_index_ranges}), number of selected index '
1119
- f'ranges must be <= {max_num_index_ranges}')
1120
1113
 
1121
1114
  spans = []
1122
1115
  fig_title = []
@@ -1191,8 +1184,6 @@ def select_mask_1d(
1191
1184
  plt.subplots_adjust(bottom=0.0)
1192
1185
 
1193
1186
  selected_index_ranges = get_selected_index_ranges()
1194
- if not selected_index_ranges:
1195
- selected_index_ranges = None
1196
1187
 
1197
1188
  # Update the mask with the currently selected index ranges
1198
1189
  selected_mask = update_mask(len(x)*[False], selected_index_ranges)
@@ -1290,37 +1281,46 @@ def select_roi_2d(
1290
1281
  fig_title.pop()
1291
1282
  fig_title.append(plt.figtext(*title_pos, title, **title_props))
1292
1283
 
1293
- def change_error_text(error):
1294
- if error_texts:
1295
- error_texts[0].remove()
1296
- error_texts.pop()
1297
- error_texts.append(plt.figtext(*error_pos, error, **error_props))
1284
+ def change_subfig_title(error):
1285
+ if subfig_title:
1286
+ subfig_title[0].remove()
1287
+ subfig_title.pop()
1288
+ subfig_title.append(plt.figtext(*error_pos, error, **error_props))
1289
+
1290
+ def clear_selection():
1291
+ rects[0].set_visible(False)
1292
+ rects.pop()
1293
+ rects.append(
1294
+ RectangleSelector(
1295
+ ax, on_rect_select, props=rect_props,
1296
+ useblit=True, interactive=interactive, drag_from_anywhere=True,
1297
+ ignore_event_outside=False))
1298
1298
 
1299
1299
  def on_rect_select(eclick, erelease):
1300
1300
  """Callback function for the RectangleSelector widget."""
1301
- change_error_text(
1302
- f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
1301
+ if (not int(rects[0].extents[1]) - int(rects[0].extents[0])
1302
+ or not int(rects[0].extents[3]) - int(rects[0].extents[2])):
1303
+ clear_selection()
1304
+ change_subfig_title(
1305
+ f'Selected ROI too small, try again')
1306
+ else:
1307
+ change_subfig_title(
1308
+ f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
1303
1309
  plt.draw()
1304
1310
 
1305
1311
  def reset(event):
1306
1312
  """Callback function for the "Reset" button."""
1307
- if error_texts:
1308
- error_texts[0].remove()
1309
- error_texts.pop()
1310
- rects[0].set_visible(False)
1311
- rects.pop()
1312
- rects.append(
1313
- RectangleSelector(
1314
- ax, on_rect_select, props=rect_props, useblit=True,
1315
- interactive=interactive, drag_from_anywhere=True,
1316
- ignore_event_outside=False))
1313
+ if subfig_title:
1314
+ subfig_title[0].remove()
1315
+ subfig_title.pop()
1316
+ clear_selection()
1317
1317
  plt.draw()
1318
1318
 
1319
1319
  def confirm(event):
1320
1320
  """Callback function for the "Confirm" button."""
1321
- if error_texts:
1322
- error_texts[0].remove()
1323
- error_texts.pop()
1321
+ if subfig_title:
1322
+ subfig_title[0].remove()
1323
+ subfig_title.pop()
1324
1324
  roi = tuple(int(v) for v in rects[0].extents)
1325
1325
  if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1326
1326
  roi = None
@@ -1328,7 +1328,7 @@ def select_roi_2d(
1328
1328
  plt.close()
1329
1329
 
1330
1330
  fig_title = []
1331
- error_texts = []
1331
+ subfig_title = []
1332
1332
 
1333
1333
  # Check inputs
1334
1334
  a = np.asarray(a)
@@ -1377,7 +1377,7 @@ def select_roi_2d(
1377
1377
 
1378
1378
  change_fig_title(title)
1379
1379
  if preselected_roi is not None:
1380
- change_error_text(
1380
+ change_subfig_title(
1381
1381
  f'Preselected ROI: {tuple(int(v) for v in preselected_roi)}')
1382
1382
  fig.subplots_adjust(bottom=0.2)
1383
1383
 
@@ -1417,8 +1417,8 @@ def select_roi_2d(
1417
1417
 
1418
1418
 
1419
1419
  def select_image_indices(
1420
- a, axis, b=None, preselected_indices=None, min_range=None,
1421
- min_num_indices=2, max_num_indices=2, title=None,
1420
+ a, axis, b=None, preselected_indices=None, axis_index_offset=0,
1421
+ min_range=None, min_num_indices=2, max_num_indices=2, title=None,
1422
1422
  title_a=None, title_b=None, row_label='row index',
1423
1423
  column_label='column index', interactive=True):
1424
1424
  """Display a 2D image and have the user select a set of image
@@ -1435,7 +1435,10 @@ def select_image_indices(
1435
1435
  :type b: numpy.ndarray, optional
1436
1436
  :param preselected_indices: Preselected image indices,
1437
1437
  defaults to `None`.
1438
- :type preselected_roi: tuple(int), list(int), optional
1438
+ :type preselected_indices: tuple(int), list(int), optional
1439
+ :param axis_index_offset: Offset in axis index range and
1440
+ preselected indices, defaults to `0`.
1441
+ :type axis_index_offset: int, optional
1439
1442
  :param min_range: The minimal range spanned by the selected
1440
1443
  indices, defaults to `None`
1441
1444
  :type min_range: int, optional
@@ -1531,12 +1534,14 @@ def select_image_indices(
1531
1534
  error_texts.pop()
1532
1535
  try:
1533
1536
  index = int(expression)
1534
- if not 0 <= index <= a.shape[axis]:
1537
+ if (index < axis_index_offset
1538
+ or index > axis_index_offset+a.shape[axis]):
1535
1539
  raise ValueError
1536
1540
  except ValueError:
1537
1541
  change_error_text(
1538
1542
  f'Invalid {row_column} index ({expression}), enter an integer '
1539
- f'between 0 and {a.shape[axis]-1}')
1543
+ f'between {axis_index_offset} and '
1544
+ f'{axis_index_offset+a.shape[axis]-1}')
1540
1545
  else:
1541
1546
  try:
1542
1547
  add_index(index)
@@ -1585,9 +1590,13 @@ def select_image_indices(
1585
1590
  row_column = 'row'
1586
1591
  else:
1587
1592
  row_column = 'column'
1593
+ if not is_int(axis_index_offset, ge=0, log=False):
1594
+ raise ValueError(
1595
+ 'Invalid parameter axis_index_offset ({axis_index_offset})')
1588
1596
  if preselected_indices is not None:
1589
1597
  if not is_int_series(
1590
- preselected_indices, ge=0, le=a.shape[axis], log=False):
1598
+ preselected_indices, ge=axis_index_offset,
1599
+ le=axis_index_offset+a.shape[axis], log=False):
1591
1600
  if interactive:
1592
1601
  logger.warning(
1593
1602
  'Invalid parameter preselected_indices '
@@ -1607,7 +1616,7 @@ def select_image_indices(
1607
1616
  if a.shape[0] != b.shape[0]:
1608
1617
  raise ValueError(f'Inconsistent image shapes({a.shape} vs '
1609
1618
  f'{b.shape})')
1610
-
1619
+
1611
1620
  indices = []
1612
1621
  lines = []
1613
1622
  fig_title = []
@@ -1627,10 +1636,11 @@ def select_image_indices(
1627
1636
  fig, axs = plt.subplots(1, 2, figsize=(11, 8.5))
1628
1637
  else:
1629
1638
  fig, axs = plt.subplots(2, 1, figsize=(11, 8.5))
1630
- axs[0].imshow(a)
1639
+ extent = (0, a.shape[1], axis_index_offset+a.shape[0], axis_index_offset)
1640
+ axs[0].imshow(a, extent=extent)
1631
1641
  axs[0].set_title(title_a, fontsize='xx-large')
1632
1642
  if b is not None:
1633
- axs[1].imshow(b)
1643
+ axs[1].imshow(b, extent=extent)
1634
1644
  axs[1].set_title(title_b, fontsize='xx-large')
1635
1645
  if a.shape[0]+b.shape[0] > max(a.shape[1], b.shape[1]):
1636
1646
  axs[0].set_xlabel(column_label, fontsize='x-large')
@@ -1641,8 +1651,8 @@ def select_image_indices(
1641
1651
  axs[1].set_xlabel(column_label, fontsize='x-large')
1642
1652
  axs[1].set_ylabel(row_label, fontsize='x-large')
1643
1653
  for ax in axs:
1644
- ax.set_xlim(0, a.shape[1])
1645
- ax.set_ylim(a.shape[0], 0)
1654
+ ax.set_xlim(extent[0], extent[1])
1655
+ ax.set_ylim(extent[2], extent[3])
1646
1656
  fig.subplots_adjust(bottom=0.0, top=0.85)
1647
1657
 
1648
1658
  # Setup the preselected indices if provided
@@ -1862,3 +1872,126 @@ def quick_plot(
1862
1872
  if save_fig:
1863
1873
  plt.savefig(path)
1864
1874
  plt.show(block=block)
1875
+
1876
+
1877
def nxcopy(
        nxobject, exclude_nxpaths=None, nxpath_prefix=None,
        nxpathabs_prefix=None, nxpath_copy_abspath=None):
    """
    Function that returns a copy of a nexus object, optionally excluding
    certain child items.

    :param nxobject: The input nexus object to "copy".
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: A relative path, or a list of relative
        paths, to child nexus objects that should be excluded from the
        returned "copy", defaults to `None` (nothing is excluded).
    :type exclude_nxpaths: str, list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only.
    :type nxpath_prefix: str
    :param nxpathabs_prefix: For use in recursive calls from inside this
        function only.
    :type nxpathabs_prefix: str
    :param nxpath_copy_abspath: For use in recursive calls from inside
        this function only.
    :type nxpath_copy_abspath: str
    :raises ValueError: If an entry in `exclude_nxpaths` is not a
        non-empty relative path.
    :return: Copy of the input `nxobject` with some children optionally
        excluded.
    :rtype: nexusformat.nexus.NXobject
    """
    # Third party modules
    from nexusformat.nexus import (
        NXentry,
        NXfield,
        NXgroup,
        NXlink,
        NXlinkgroup,
        NXroot,
    )

    # NXlinkgroup is tested first: it is checked before the generic
    # NXlink branch so linked groups are copied as groups
    if isinstance(nxobject, NXlinkgroup):
        # The top level nxobject is a linked group
        # Create a group with the same name as the top level's target
        nxobject_copy = nxobject[nxobject.nxtarget].__class__(
            name=nxobject.nxname)
    elif isinstance(nxobject, (NXlink, NXfield)):
        # The top level nxobject is a (linked) field: return a copy
        nxobject_copy = NXfield(
            value=nxobject.nxdata, name=nxobject.nxname,
            attrs=nxobject.attrs)
        return nxobject_copy
    else:
        # Create a group with the same type/name as the nxobject
        nxobject_copy = nxobject.__class__(name=nxobject.nxname)

    # Copy attributes (for a root only its "default" pointer is kept)
    if isinstance(nxobject, NXroot):
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.default
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    # Setup paths
    if exclude_nxpaths is None:
        exclude_nxpaths = []
    elif isinstance(exclude_nxpaths, str):
        exclude_nxpaths = [exclude_nxpaths]
    for exclude_nxpath in exclude_nxpaths:
        # An empty string would otherwise raise an opaque IndexError and,
        # in the link-target prefix test below, match every target
        if (not isinstance(exclude_nxpath, str) or not exclude_nxpath
                or exclude_nxpath.startswith('/')):
            raise ValueError(
                f'Invalid parameter in exclude_nxpaths ({exclude_nxpaths}), '
                'excluded paths should be non-empty relative paths')
    if nxpath_prefix is None:
        nxpath_prefix = ''
    if nxpathabs_prefix is None:
        if isinstance(nxobject, NXentry):
            nxpathabs_prefix = nxobject.nxpath
        else:
            nxpathabs_prefix = nxobject.nxpath.removesuffix(nxobject.nxname)
    if nxpath_copy_abspath is None:
        nxpath_copy_abspath = ''

    def _keep_as_link(link, nxpathabs):
        # A link child stays a link only when it lives at this very path
        # and its target does not lie inside an excluded subtree
        return nxpathabs == link.nxpath and not any(
            link.nxtarget.startswith(os_path.join(nxpathabs_prefix, p))
            for p in exclude_nxpaths)

    def _copy_group(group, nxpath, name):
        # Recursively copy a child group, propagating the exclusions
        return nxcopy(
            group, exclude_nxpaths=exclude_nxpaths,
            nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
            nxpath_copy_abspath=os_path.join(nxpath_copy_abspath, name))

    # Loop over all nxobject's children
    for k, v in nxobject.items():
        nxpath = os_path.join(nxpath_prefix, k)
        nxpathabs = os_path.join(nxpathabs_prefix, nxpath)
        if nxpath in exclude_nxpaths:
            # Drop the child, and a "default" attribute pointing at it
            if ('default' in nxobject_copy.attrs
                    and nxobject_copy.default == k):
                nxobject_copy.attrs.pop('default')
            continue
        if isinstance(v, NXlinkgroup):
            if _keep_as_link(v, nxpathabs):
                nxobject_copy[k] = NXlink(v.nxtarget)
            else:
                nxobject_copy[k] = _copy_group(v, nxpath, k)
        elif isinstance(v, NXlink):
            if _keep_as_link(v, nxpathabs):
                nxobject_copy[k] = v
            else:
                # Materialize the linked data as a plain field
                nxobject_copy[k] = v.nxdata
                for kk, vv in v.attrs.items():
                    nxobject_copy[k].attrs[kk] = vv
                nxobject_copy[k].attrs.pop('target', None)
        elif isinstance(v, NXgroup):
            nxobject_copy[k] = _copy_group(v, nxpath, k)
        else:
            nxobject_copy[k] = v.nxdata
            for kk, vv in v.attrs.items():
                nxobject_copy[k].attrs[kk] = vv
            if nxpathabs != os_path.join(nxpath_copy_abspath, k):
                nxobject_copy[k].attrs.pop('target', None)

    return nxobject_copy