pymodaq: pymodaq-4.1.5-py3-none-any.whl → pymodaq-4.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pymodaq might be problematic. Click here for more details.

Files changed (79)
  1. pymodaq/__init__.py +41 -4
  2. pymodaq/control_modules/daq_move.py +32 -73
  3. pymodaq/control_modules/daq_viewer.py +73 -98
  4. pymodaq/control_modules/daq_viewer_ui.py +2 -1
  5. pymodaq/control_modules/move_utility_classes.py +17 -7
  6. pymodaq/control_modules/utils.py +153 -5
  7. pymodaq/control_modules/viewer_utility_classes.py +31 -20
  8. pymodaq/dashboard.py +23 -5
  9. pymodaq/examples/tcp_client.py +97 -0
  10. pymodaq/extensions/__init__.py +4 -0
  11. pymodaq/extensions/bayesian/__init__.py +2 -0
  12. pymodaq/extensions/bayesian/bayesian_optimisation.py +673 -0
  13. pymodaq/extensions/bayesian/utils.py +403 -0
  14. pymodaq/extensions/daq_scan.py +4 -4
  15. pymodaq/extensions/daq_scan_ui.py +2 -1
  16. pymodaq/extensions/pid/pid_controller.py +12 -7
  17. pymodaq/extensions/pid/utils.py +9 -26
  18. pymodaq/extensions/utils.py +3 -0
  19. pymodaq/post_treatment/load_and_plot.py +42 -19
  20. pymodaq/resources/VERSION +1 -1
  21. pymodaq/resources/config_template.toml +9 -24
  22. pymodaq/resources/setup_plugin.py +1 -1
  23. pymodaq/utils/config.py +103 -5
  24. pymodaq/utils/daq_utils.py +35 -134
  25. pymodaq/utils/data.py +614 -95
  26. pymodaq/utils/enums.py +17 -1
  27. pymodaq/utils/factory.py +2 -2
  28. pymodaq/utils/gui_utils/custom_app.py +5 -2
  29. pymodaq/utils/gui_utils/dock.py +33 -4
  30. pymodaq/utils/gui_utils/utils.py +14 -1
  31. pymodaq/utils/h5modules/backends.py +9 -1
  32. pymodaq/utils/h5modules/data_saving.py +254 -57
  33. pymodaq/utils/h5modules/saving.py +1 -0
  34. pymodaq/utils/leco/daq_move_LECODirector.py +172 -0
  35. pymodaq/utils/leco/daq_xDviewer_LECODirector.py +170 -0
  36. pymodaq/utils/leco/desktop.ini +2 -0
  37. pymodaq/utils/leco/director_utils.py +58 -0
  38. pymodaq/utils/leco/leco_director.py +88 -0
  39. pymodaq/utils/leco/pymodaq_listener.py +279 -0
  40. pymodaq/utils/leco/utils.py +41 -0
  41. pymodaq/utils/managers/action_manager.py +20 -6
  42. pymodaq/utils/managers/parameter_manager.py +6 -4
  43. pymodaq/utils/managers/roi_manager.py +63 -54
  44. pymodaq/utils/math_utils.py +1 -1
  45. pymodaq/utils/plotting/data_viewers/__init__.py +3 -1
  46. pymodaq/utils/plotting/data_viewers/base.py +286 -0
  47. pymodaq/utils/plotting/data_viewers/viewer.py +29 -202
  48. pymodaq/utils/plotting/data_viewers/viewer0D.py +94 -47
  49. pymodaq/utils/plotting/data_viewers/viewer1D.py +341 -174
  50. pymodaq/utils/plotting/data_viewers/viewer1Dbasic.py +1 -1
  51. pymodaq/utils/plotting/data_viewers/viewer2D.py +271 -181
  52. pymodaq/utils/plotting/data_viewers/viewerND.py +26 -22
  53. pymodaq/utils/plotting/items/crosshair.py +3 -3
  54. pymodaq/utils/plotting/items/image.py +2 -1
  55. pymodaq/utils/plotting/plotter/plotter.py +94 -0
  56. pymodaq/utils/plotting/plotter/plotters/__init__.py +0 -0
  57. pymodaq/utils/plotting/plotter/plotters/matplotlib_plotters.py +134 -0
  58. pymodaq/utils/plotting/plotter/plotters/qt_plotters.py +78 -0
  59. pymodaq/utils/plotting/utils/axes_viewer.py +1 -1
  60. pymodaq/utils/plotting/utils/filter.py +194 -147
  61. pymodaq/utils/plotting/utils/lineout.py +13 -11
  62. pymodaq/utils/plotting/utils/plot_utils.py +89 -12
  63. pymodaq/utils/scanner/__init__.py +0 -3
  64. pymodaq/utils/scanner/scan_config.py +1 -9
  65. pymodaq/utils/scanner/scan_factory.py +10 -36
  66. pymodaq/utils/scanner/scanner.py +3 -2
  67. pymodaq/utils/scanner/scanners/_1d_scanners.py +7 -5
  68. pymodaq/utils/scanner/scanners/_2d_scanners.py +36 -49
  69. pymodaq/utils/scanner/scanners/sequential.py +10 -4
  70. pymodaq/utils/scanner/scanners/tabular.py +10 -5
  71. pymodaq/utils/slicing.py +1 -1
  72. pymodaq/utils/tcp_ip/serializer.py +38 -5
  73. pymodaq/utils/tcp_ip/tcp_server_client.py +25 -17
  74. {pymodaq-4.1.5.dist-info → pymodaq-4.2.0.dist-info}/METADATA +4 -2
  75. {pymodaq-4.1.5.dist-info → pymodaq-4.2.0.dist-info}/RECORD +78 -63
  76. pymodaq/resources/config_scan_template.toml +0 -42
  77. {pymodaq-4.1.5.dist-info → pymodaq-4.2.0.dist-info}/WHEEL +0 -0
  78. {pymodaq-4.1.5.dist-info → pymodaq-4.2.0.dist-info}/entry_points.txt +0 -0
  79. {pymodaq-4.1.5.dist-info → pymodaq-4.2.0.dist-info}/licenses/LICENSE +0 -0
pymodaq/utils/data.py CHANGED
@@ -9,9 +9,10 @@ from __future__ import annotations
9
9
  from abc import ABCMeta, abstractmethod, abstractproperty
10
10
  import numbers
11
11
  import numpy as np
12
- from typing import List, Tuple, Union, Any
12
+ from typing import List, Tuple, Union, Any, Callable
13
13
  from typing import Iterable as IterableType
14
14
  from collections.abc import Iterable
15
+ from collections import OrderedDict
15
16
  import logging
16
17
 
17
18
  import warnings
@@ -25,10 +26,22 @@ from pymodaq.utils.daq_utils import find_objects_in_list_from_attr_name_val
25
26
  from pymodaq.utils.logger import set_logger, get_module_name
26
27
  from pymodaq.utils.slicing import SpecialSlicersData
27
28
  from pymodaq.utils import math_utils as mutils
29
+ from pymodaq.utils.config import Config
30
+ from pymodaq.utils.plotting.plotter.plotter import PlotterFactory
28
31
 
32
+ config = Config()
33
+ plotter_factory = PlotterFactory()
29
34
  logger = set_logger(get_module_name(__file__))
30
35
 
31
36
 
37
+ def squeeze(data_array: np.ndarray, do_squeeze=True, squeeze_indexes: Tuple[int]=None) -> np.ndarray:
38
+ """ Squeeze numpy arrays return at least 1D arrays except if do_squeeze is False"""
39
+ if do_squeeze:
40
+ return np.atleast_1d(np.squeeze(data_array, axis=squeeze_indexes))
41
+ else:
42
+ return np.atleast_1d(data_array)
43
+
44
+
32
45
  class DataIndexWarning(Warning):
33
46
  pass
34
47
 
@@ -83,9 +96,11 @@ class DataDim(BaseEnum):
83
96
  DataND = 3
84
97
 
85
98
  def __le__(self, other_dim: 'DataDim'):
99
+ other_dim = enum_checker(DataDim, other_dim)
86
100
  return self.value.__le__(other_dim.value)
87
101
 
88
102
  def __lt__(self, other_dim: 'DataDim'):
103
+ other_dim = enum_checker(DataDim, other_dim)
89
104
  return self.value.__lt__(other_dim.value)
90
105
 
91
106
  def __ge__(self, other_dim: 'DataDim'):
@@ -93,12 +108,24 @@ class DataDim(BaseEnum):
93
108
  return self.value.__ge__(other_dim.value)
94
109
 
95
110
  def __gt__(self, other_dim: 'DataDim'):
111
+ other_dim = enum_checker(DataDim, other_dim)
96
112
  return self.value.__gt__(other_dim.value)
97
113
 
98
114
  @property
99
115
  def dim_index(self):
100
116
  return self.value
101
117
 
118
+ @staticmethod
119
+ def from_data_array(data_array: np.ndarray):
120
+ if len(data_array.shape) == 1 and data_array.size == 1:
121
+ return DataDim['Data0D']
122
+ elif len(data_array.shape) == 1 and data_array.size > 1:
123
+ return DataDim['Data1D']
124
+ elif len(data_array.shape) == 2:
125
+ return DataDim['Data2D']
126
+ else:
127
+ return DataDim['DataND']
128
+
102
129
 
103
130
  class DataSource(BaseEnum):
104
131
  """Enum for source of data"""
@@ -131,6 +158,8 @@ class Axis:
131
158
  The scaling to apply to a linspace version in order to obtain the proper scaling
132
159
  offset: float
133
160
  The offset to apply to a linspace/scaled version in order to obtain the proper axis
161
+ size: int
162
+ The size of the axis array (to be specified if data is None)
134
163
  spread_order: int
135
164
  An integer needed in the case where data has a spread DataDistribution. It refers to the index along the data's
136
165
  spread_index dimension
@@ -140,13 +169,13 @@ class Axis:
140
169
  >>> axis = Axis('myaxis', units='seconds', data=np.array([1,2,3,4,5]), index=0)
141
170
  """
142
171
 
143
- def __init__(self, label: str = '', units: str = '', data: np.ndarray = None, index: int = 0, scaling=None,
144
- offset=None, spread_order: int = 0):
172
+ def __init__(self, label: str = '', units: str = '', data: np.ndarray = None, index: int = 0,
173
+ scaling=None, offset=None, size=None, spread_order: int = 0):
145
174
  super().__init__()
146
175
 
147
176
  self.iaxis: Axis = SpecialSlicersData(self, False)
148
177
 
149
- self._size = None
178
+ self._size = size
150
179
  self._data = None
151
180
  self._index = None
152
181
  self._label = None
@@ -159,12 +188,18 @@ class Axis:
159
188
  self.data = data
160
189
  self.index = index
161
190
  self.spread_order = spread_order
162
- if (scaling is None or offset is None) and data is not None:
191
+ if (scaling is None or offset is None or size is None) and data is not None:
163
192
  self.get_scale_offset_from_data(data)
164
193
 
165
194
  def copy(self):
166
195
  return copy.copy(self)
167
196
 
197
+ def as_dwa(self) -> DataWithAxes:
198
+ dwa = DataRaw(self.label, data=[self.get_data()],
199
+ labels=[f'{self.label}_{self.units}'])
200
+ dwa.create_missing_axes()
201
+ return dwa
202
+
168
203
  @property
169
204
  def label(self) -> str:
170
205
  """str: get/set the label of this axis"""
@@ -208,7 +243,7 @@ class Axis:
208
243
  self._check_data_valid(data)
209
244
  self.get_scale_offset_from_data(data)
210
245
  self._size = data.size
211
- else:
246
+ elif self.size is None:
212
247
  self._size = 0
213
248
  self._data = data
214
249
 
@@ -216,6 +251,18 @@ class Axis:
216
251
  """Convenience method to obtain the axis data (usually None because scaling and offset are used)"""
217
252
  return self._data if self._data is not None else self._linear_data(self.size)
218
253
 
254
+ def get_data_at(self, indexes: Union[int, IterableType, slice]) -> np.ndarray:
255
+ """ Get data at specified indexes
256
+
257
+ Parameters
258
+ ----------
259
+ indexes:
260
+ """
261
+ if not (isinstance(indexes, np.ndarray) or isinstance(indexes, slice) or
262
+ isinstance(indexes, int)):
263
+ indexes = np.array(indexes)
264
+ return self.get_data()[indexes]
265
+
219
266
  def get_scale_offset_from_data(self, data: np.ndarray = None):
220
267
  """Get the scaling and offset from the axis's data
221
268
 
@@ -349,17 +396,20 @@ class Axis:
349
396
  ax._offset += offset
350
397
  return ax
351
398
 
352
- def __eq__(self, other):
353
- eq = self.label == other.label
354
- eq = eq and (self.units == other.units)
355
- eq = eq and (self.index == other.index)
356
- if self.data is not None and other.data is not None:
357
- eq = eq and (np.allclose(self.data, other.data))
358
- else:
359
- eq = eq and self.offset == other.offset
360
- eq = eq and self.scaling == other.scaling
399
+ def __eq__(self, other: Axis):
400
+ if isinstance(other, Axis):
401
+ eq = self.label == other.label
402
+ eq = eq and (self.units == other.units)
403
+ eq = eq and (self.index == other.index)
404
+ if self.data is not None and other.data is not None:
405
+ eq = eq and (np.allclose(self.data, other.data))
406
+ else:
407
+ eq = eq and self.offset == other.offset
408
+ eq = eq and self.scaling == other.scaling
361
409
 
362
- return eq
410
+ return eq
411
+ else:
412
+ return False
363
413
 
364
414
  def mean(self):
365
415
  if self._data is not None:
@@ -391,6 +441,8 @@ class Axis:
391
441
  return int((threshold - self.offset) / self.scaling)
392
442
 
393
443
  def find_indexes(self, thresholds: IterableType[float]) -> IterableType[int]:
444
+ if isinstance(thresholds, numbers.Number):
445
+ thresholds = [thresholds]
394
446
  return [self.find_index(threshold) for threshold in thresholds]
395
447
 
396
448
 
@@ -441,7 +493,9 @@ class DataLowLevel:
441
493
 
442
494
 
443
495
  class DataBase(DataLowLevel):
444
- """Base object to store homogeneous data and metadata generated by pymodaq's objects. To be inherited for real data
496
+ """Base object to store homogeneous data and metadata generated by pymodaq's objects.
497
+
498
+ To be inherited for real data
445
499
 
446
500
  Parameters
447
501
  ----------
@@ -517,8 +571,10 @@ class DataBase(DataLowLevel):
517
571
  """
518
572
 
519
573
  def __init__(self, name: str, source: DataSource = None, dim: DataDim = None,
520
- distribution: DataDistribution = DataDistribution['uniform'], data: List[np.ndarray] = None,
521
- labels: List[str] = [], origin: str = '', **kwargs):
574
+ distribution: DataDistribution = DataDistribution['uniform'],
575
+ data: List[np.ndarray] = None,
576
+ labels: List[str] = None, origin: str = '',
577
+ **kwargs):
522
578
 
523
579
  super().__init__(name=name)
524
580
  self._iter_index = 0
@@ -528,6 +584,7 @@ class DataBase(DataLowLevel):
528
584
  self._length = None
529
585
  self._labels = None
530
586
  self._dim = dim
587
+ self._errors = None
531
588
  self.origin = origin
532
589
 
533
590
  source = enum_checker(DataSource, source)
@@ -542,6 +599,10 @@ class DataBase(DataLowLevel):
542
599
  self.extra_attributes = []
543
600
  self.add_extra_attribute(**kwargs)
544
601
 
602
+ def as_dte(self, name: str = 'mydte') -> DataToExport:
603
+ """Convenience method to wrap the DataWithAxes object into a DataToExport"""
604
+ return DataToExport(name, data=[self])
605
+
545
606
  def add_extra_attribute(self, **kwargs):
546
607
  for key in kwargs:
547
608
  if key not in self.extra_attributes:
@@ -579,7 +640,7 @@ class DataBase(DataLowLevel):
579
640
  raise StopIteration
580
641
 
581
642
  def __getitem__(self, item) -> np.ndarray:
582
- if isinstance(item, int) and item < len(self):
643
+ if (isinstance(item, int) and item < len(self)) or isinstance(item, slice):
583
644
  return self.data[item]
584
645
  else:
585
646
  raise IndexError(f'The index should be an integer lower than the data length')
@@ -641,6 +702,8 @@ class DataBase(DataLowLevel):
641
702
  if isinstance(other, DataBase):
642
703
  if not(self.name == other.name and len(self) == len(other)):
643
704
  return False
705
+ if self.dim != other.dim:
706
+ return False
644
707
  eq = True
645
708
  for ind in range(len(self)):
646
709
  if self[ind].shape != other[ind].shape:
@@ -672,6 +735,9 @@ class DataBase(DataLowLevel):
672
735
  def __gt__(self, other):
673
736
  return self._comparison_common(other, '__gt__')
674
737
 
738
+ def deepcopy(self):
739
+ return copy.deepcopy(self)
740
+
675
741
  def average(self, other: 'DataBase', weight: int) -> 'DataBase':
676
742
  """ Compute the weighted average between self and other DataBase
677
743
 
@@ -696,6 +762,18 @@ class DataBase(DataLowLevel):
696
762
  new_data.data = [np.abs(dat) for dat in new_data]
697
763
  return new_data
698
764
 
765
+ def real(self):
766
+ """ Take the real part of itself"""
767
+ new_data = copy.copy(self)
768
+ new_data.data = [np.real(dat) for dat in new_data]
769
+ return new_data
770
+
771
+ def imag(self):
772
+ """ Take the imaginary part of itself"""
773
+ new_data = copy.copy(self)
774
+ new_data.data = [np.imag(dat) for dat in new_data]
775
+ return new_data
776
+
699
777
  def flipud(self):
700
778
  """Reverse the order of elements along axis 0 (up/down)"""
701
779
  new_data = copy.copy(self)
@@ -712,14 +790,42 @@ class DataBase(DataLowLevel):
712
790
  for dat in data:
713
791
  if dat.shape != self.shape:
714
792
  raise DataShapeError('Cannot append those ndarrays, they don\'t have the same shape as self')
715
- self.data = self.data + data.data
793
+ self.data += data.data
716
794
  self.labels.extend(data.labels)
717
795
 
796
+ def pop(self, index: int) -> DataBase:
797
+ """ Returns a copy of self but with data taken at the specified index"""
798
+ dwa = self.deepcopy()
799
+ dwa.data = [dwa.data[index]]
800
+ dwa.labels = [dwa.labels[index]]
801
+ return dwa
802
+
718
803
  @property
719
804
  def shape(self):
720
805
  """The shape of the nd-arrays"""
721
806
  return self._shape
722
807
 
808
+ def stack_as_array(self, axis=0, dtype=None) -> np.ndarray:
809
+ """ Stack all data arrays in a single numpy array
810
+
811
+ Parameters
812
+ ----------
813
+ axis: int
814
+ The new stack axis index, default 0
815
+ dtype: str or np.dtype
816
+ the dtype of the stacked array
817
+
818
+ Returns
819
+ -------
820
+ np.ndarray
821
+
822
+ See Also
823
+ --------
824
+ :meth:`np.stack`
825
+ """
826
+
827
+ return np.stack(self.data, axis=axis, dtype=dtype)
828
+
723
829
  @property
724
830
  def size(self):
725
831
  """The size of the nd-arrays"""
@@ -772,8 +878,8 @@ class DataBase(DataLowLevel):
772
878
  labels.append(f'CH{len(labels):02d}')
773
879
  self._labels = labels
774
880
 
775
- def get_data_index(self, index: int = 0):
776
- """Get the data by its index in the list"""
881
+ def get_data_index(self, index: int = 0) -> np.ndarray:
882
+ """Get the data by its index in the list, same as self[index]"""
777
883
  return self.data[index]
778
884
 
779
885
  @staticmethod
@@ -797,7 +903,7 @@ class DataBase(DataLowLevel):
797
903
  if isinstance(data, list):
798
904
  if len(data) == 0:
799
905
  is_valid = False
800
- if not isinstance(data[0], np.ndarray):
906
+ elif not isinstance(data[0], np.ndarray):
801
907
  is_valid = False
802
908
  elif len(data[0].shape) == 0:
803
909
  is_valid = False
@@ -863,10 +969,17 @@ class DataBase(DataLowLevel):
863
969
  @data.setter
864
970
  def data(self, data: List[np.ndarray]):
865
971
  data = self._check_data_type(data)
972
+ #data = [squeeze(data_array) for data_array in data]
866
973
  self._check_shape_dim_consistency(data)
867
974
  self._check_same_shape(data)
868
975
  self._data = data
869
976
 
977
+ def to_dict(self):
978
+ data_dict = OrderedDict([])
979
+ for ind in range(len(self)):
980
+ data_dict[self.labels[ind]] = self[ind]
981
+ return data_dict
982
+
870
983
 
871
984
  class AxesManagerBase:
872
985
  def __init__(self, data_shape: Tuple[int], axes: List[Axis], nav_indexes=None, sig_indexes=None, **kwargs):
@@ -1076,6 +1189,10 @@ class AxesManagerBase:
1076
1189
  def get_axis_from_index(self, index: int, create: bool = False) -> List[Axis]:
1077
1190
  ...
1078
1191
 
1192
+ def get_axis_from_index_spread(self, index: int, spread_order: int) -> Axis:
1193
+ """Only valid for Spread data"""
1194
+ ...
1195
+
1079
1196
  def get_nav_axes(self) -> List[Axis]:
1080
1197
  """Get the navigation axes corresponding to the data
1081
1198
 
@@ -1090,8 +1207,13 @@ class AxesManagerBase:
1090
1207
  def get_signal_axes(self):
1091
1208
  if self.sig_indexes is None:
1092
1209
  self._sig_indexes = tuple([int(axis.index) for axis in self.axes if axis.index not in self.nav_indexes])
1093
- return list(mutils.flatten([copy.copy(self.get_axis_from_index(index, create=True))
1094
- for index in self.sig_indexes]))
1210
+ axes = []
1211
+ for index in self._sig_indexes:
1212
+ axes_tmp = copy.copy(self.get_axis_from_index(index, create=True))
1213
+ for ax in axes_tmp:
1214
+ if ax.size > 1:
1215
+ axes.append(ax)
1216
+ return axes
1095
1217
 
1096
1218
  def is_axis_signal(self, axis: Axis) -> bool:
1097
1219
  """Check if an axis is considered signal or navigation"""
@@ -1270,7 +1392,7 @@ class AxesManagerSpread(AxesManagerBase):
1270
1392
  elif len(self.nav_indexes) != 1:
1271
1393
  raise ValueError('Spread data should have only one specified index in self.nav_indexes')
1272
1394
  elif axis.index in self.nav_indexes:
1273
- if axis.size != self._data_shape[self.nav_indexes[0]]:
1395
+ if axis.size != 1 and (axis.size != self._data_shape[self.nav_indexes[0]]):
1274
1396
  raise DataLengthError('all navigation axes should have the same size')
1275
1397
 
1276
1398
  def compute_shape_from_axes(self):
@@ -1384,6 +1506,11 @@ class AxesManagerSpread(AxesManagerBase):
1384
1506
  else:
1385
1507
  return None, None
1386
1508
 
1509
+ def get_axis_from_index_spread(self, index: int, spread_order: int) -> Axis:
1510
+ for axis in self.axes:
1511
+ if axis.index == index and axis.spread_order == spread_order:
1512
+ return axis
1513
+
1387
1514
  def _get_dimension_str(self):
1388
1515
  try:
1389
1516
  string = "("
@@ -1416,9 +1543,15 @@ class DataWithAxes(DataBase):
1416
1543
  For instance, nav_indexes = (3,2), means that the axis with index 3 in a at least 4D ndarray data is the first
1417
1544
  navigation axis while the axis with index 2 is the second navigation Axis. Axes with index 0 and 1 are signal
1418
1545
  axes of 2D ndarray data
1546
+ errors: list of ndarray.
1547
+ The list should match the length of the data attribute while the ndarrays
1548
+ should match the data ndarray
1419
1549
  """
1420
1550
 
1421
- def __init__(self, *args, axes: List[Axis] = [], nav_indexes: Tuple[int] = (), **kwargs):
1551
+ def __init__(self, *args, axes: List[Axis] = [],
1552
+ nav_indexes: Tuple[int] = (),
1553
+ errors: Iterable[np.ndarray] = None,
1554
+ **kwargs):
1422
1555
 
1423
1556
  if 'nav_axes' in kwargs:
1424
1557
  deprecation_msg('nav_axes parameter should not be used anymore, use nav_indexes')
@@ -1443,18 +1576,105 @@ class DataWithAxes(DataBase):
1443
1576
 
1444
1577
  self.get_dim_from_data_axes() # in DataBase, dim is processed from the shape of data, but if axes are provided
1445
1578
  #then use get_dim_from axes
1579
+ self._check_errors(errors)
1580
+
1581
+ def _check_errors(self, errors: Iterable[np.ndarray]):
1582
+ """ Make sure the errors object is adapted to the len/shape of the dwa object
1583
+
1584
+ new in 4.2.0
1585
+ """
1586
+ check = False
1587
+ if errors is None:
1588
+ self._errors = None
1589
+ return
1590
+ if isinstance(errors, (tuple, list)) and len(errors) == len(self):
1591
+ if np.all([isinstance(error, np.ndarray) for error in errors]):
1592
+ if np.all([error_array.shape == self.shape for error_array in errors]):
1593
+ check = True
1594
+ else:
1595
+ logger.warning(f'All error objects should have the same shape as the data'
1596
+ f'objects')
1597
+ else:
1598
+ logger.warning(f'All error objects should be np.ndarray')
1599
+
1600
+ if not check:
1601
+ logger.warning('the errors field is incompatible with the structure of the data')
1602
+ self._errors = None
1603
+ else:
1604
+ self._errors = errors
1605
+
1606
+ @property
1607
+ def errors(self):
1608
+ """ Get/Set the errors bar values as a list of np.ndarray
1609
+
1610
+ new in 4.2.0
1611
+ """
1612
+ return self._errors
1613
+
1614
+ @errors.setter
1615
+ def errors(self, errors: Iterable[np.ndarray]):
1616
+ self._check_errors(errors)
1617
+
1618
+ def get_error(self, index):
1619
+ """ Get a particular error ndarray at the given index in the list
1620
+
1621
+ new in 4.2.0
1622
+ """
1623
+ if self._errors is not None: #because to the initial check we know it is a list of ndarrays
1624
+ return self._errors[index]
1625
+ else:
1626
+ return np.array([0]) # this could be added to any numpy array of any shape
1627
+
1628
+ def errors_as_dwa(self):
1629
+ """ Get a dwa from self replacing the data content with the error attribute (if not None)
1630
+
1631
+ New in 4.2.0
1632
+ """
1633
+ if self.errors is not None:
1634
+ dwa = self.deepcopy_with_new_data(self.errors)
1635
+ dwa.name = f'{self.name}_errors'
1636
+ dwa.errors = None
1637
+ return dwa
1638
+ else:
1639
+ raise ValueError(f'Cannot create a dwa from a None, should be a list of ndarray')
1640
+
1641
+ def plot(self, plotter_backend: str = config('plotting', 'backend'), *args, viewer=None,
1642
+ **kwargs):
1643
+ """ Call a plotter factory and its plot method over the actual data"""
1644
+ return plotter_factory.get(plotter_backend).plot(self, *args, viewer=viewer, **kwargs)
1446
1645
 
1447
1646
  def set_axes_manager(self, data_shape, axes, nav_indexes, **kwargs):
1448
1647
  if self.distribution.name == 'uniform' or len(nav_indexes) == 0:
1449
1648
  self._distribution = DataDistribution['uniform']
1450
- self.axes_manager = AxesManagerUniform(data_shape=data_shape, axes=axes, nav_indexes=nav_indexes,
1649
+ self.axes_manager = AxesManagerUniform(data_shape=data_shape, axes=axes,
1650
+ nav_indexes=nav_indexes,
1451
1651
  **kwargs)
1452
1652
  elif self.distribution.name == 'spread':
1453
- self.axes_manager = AxesManagerSpread(data_shape=data_shape, axes=axes, nav_indexes=nav_indexes,
1653
+ self.axes_manager = AxesManagerSpread(data_shape=data_shape, axes=axes,
1654
+ nav_indexes=nav_indexes,
1454
1655
  **kwargs)
1455
1656
  else:
1456
1657
  raise ValueError(f'Such a data distribution ({data.distribution}) has no AxesManager')
1457
1658
 
1659
+ def __eq__(self, other):
1660
+ is_equal = super().__eq__(other)
1661
+ if isinstance(other, DataWithAxes):
1662
+ for ind in list(self.nav_indexes) + list(self.sig_indexes):
1663
+ axes_self = self.get_axis_from_index(ind)
1664
+ axes_other = other.get_axis_from_index(ind)
1665
+ if len(axes_other) != len(axes_self):
1666
+ return False
1667
+ for ind_ax in range(len(axes_self)):
1668
+ if axes_self[ind_ax] != axes_other[ind_ax]:
1669
+ return False
1670
+ if self.errors is None:
1671
+ is_equal = is_equal and other.errors is None
1672
+ else:
1673
+ for ind_error in range(len(self.errors)):
1674
+ if not np.allclose(self.errors[ind_error], other.errors[ind_error]):
1675
+ return False
1676
+ return is_equal
1677
+
1458
1678
  def __repr__(self):
1459
1679
  return f'<{self.__class__.__name__}: {self.name} <len:{self.length}> {self._am}>'
1460
1680
 
@@ -1494,6 +1714,15 @@ class DataWithAxes(DataBase):
1494
1714
  for axis in self.axes:
1495
1715
  axis.index = 0 if axis.index == 1 else 1
1496
1716
 
1717
+ def crop_at_along(self, coordinates_tuple: Tuple):
1718
+ slices = []
1719
+ for coordinates in coordinates_tuple:
1720
+ axis = self.get_axis_from_index(0)[0]
1721
+ indexes = axis.find_indexes(coordinates)
1722
+ slices.append(slice(indexes))
1723
+
1724
+ return self._slicer(slices, False)
1725
+
1497
1726
  def mean(self, axis: int = 0) -> DataWithAxes:
1498
1727
  """Process the mean of the data on the specified axis and returns the new data
1499
1728
 
@@ -1507,7 +1736,10 @@ class DataWithAxes(DataBase):
1507
1736
  """
1508
1737
  dat_mean = []
1509
1738
  for dat in self.data:
1510
- dat_mean.append(np.mean(dat, axis=axis))
1739
+ mean = np.mean(dat, axis=axis)
1740
+ if isinstance(mean, numbers.Number):
1741
+ mean = np.array([mean])
1742
+ dat_mean.append(mean)
1511
1743
  return self.deepcopy_with_new_data(dat_mean, remove_axes_index=axis)
1512
1744
 
1513
1745
  def sum(self, axis: int = 0) -> DataWithAxes:
@@ -1525,7 +1757,44 @@ class DataWithAxes(DataBase):
1525
1757
  for dat in self.data:
1526
1758
  dat_sum.append(np.sum(dat, axis=axis))
1527
1759
  return self.deepcopy_with_new_data(dat_sum, remove_axes_index=axis)
1528
-
1760
+
1761
+ def interp(self, new_axis_data: Union[Axis, np.ndarray], **kwargs) -> DataWithAxes:
1762
+ """Performs linear interpolation for 1D data only.
1763
+
1764
+ For more complex ones, see :py:meth:`scipy.interpolate`
1765
+
1766
+ Parameters
1767
+ ----------
1768
+ new_axis_data: Union[Axis, np.ndarray]
1769
+ The coordinates over which to do the interpolation
1770
+ kwargs: dict
1771
+ extra named parameters to be passed to the :py:meth:`~numpy.interp` method
1772
+
1773
+ Returns
1774
+ -------
1775
+ DataWithAxes
1776
+
1777
+ See Also
1778
+ --------
1779
+ :py:meth:`~numpy.interp`
1780
+ :py:meth:`~scipy.interpolate`
1781
+ """
1782
+ if self.dim != DataDim['Data1D']:
1783
+ raise ValueError('For basic interpolation, only 1D data are supported')
1784
+
1785
+ data_interpolated = []
1786
+ axis_obj = self.get_axis_from_index(0)[0]
1787
+ if isinstance(new_axis_data, np.ndarray):
1788
+ new_axis_data = Axis(axis_obj.label, axis_obj.units, data=new_axis_data)
1789
+
1790
+ for dat in self.data:
1791
+ data_interpolated.append(np.interp(new_axis_data.get_data(), axis_obj.get_data(), dat,
1792
+ **kwargs))
1793
+ new_data = DataCalculated(f'{self.name}_interp', data=data_interpolated,
1794
+ axes=[new_axis_data],
1795
+ labels=self.labels)
1796
+ return new_data
1797
+
1529
1798
  def ft(self, axis: int = 0) -> DataWithAxes:
1530
1799
  """Process the Fourier Transform of the data on the specified axis and returns the new data
1531
1800
 
@@ -1536,14 +1805,27 @@ class DataWithAxes(DataBase):
1536
1805
  Returns
1537
1806
  -------
1538
1807
  DataWithAxes
1808
+
1809
+ See Also
1810
+ --------
1811
+ :py:meth:`~pymodaq.utils.math_utils.ft`, :py:meth:`~numpy.fft.fft`
1539
1812
  """
1540
1813
  dat_ft = []
1814
+ axis_obj = self.get_axis_from_index(axis)[0]
1815
+ omega_grid, time_grid = mutils.ftAxis_time(len(axis_obj),
1816
+ np.abs(axis_obj.max() - axis_obj.min()))
1541
1817
  for dat in self.data:
1542
1818
  dat_ft.append(mutils.ft(dat, dim=axis))
1543
- return self.deepcopy_with_new_data(dat_ft)
1819
+ new_data = self.deepcopy_with_new_data(dat_ft)
1820
+ axis_obj = new_data.get_axis_from_index(axis)[0]
1821
+ axis_obj.data = omega_grid
1822
+ axis_obj.label = f'ft({axis_obj.label})'
1823
+ axis_obj.units = f'2pi/{axis_obj.units}'
1824
+ return new_data
1544
1825
 
1545
1826
  def ift(self, axis: int = 0) -> DataWithAxes:
1546
- """Process the inverse Fourier Transform of the data on the specified axis and returns the new data
1827
+ """Process the inverse Fourier Transform of the data on the specified axis and returns the
1828
+ new data
1547
1829
 
1548
1830
  Parameters
1549
1831
  ----------
@@ -1552,12 +1834,108 @@ class DataWithAxes(DataBase):
1552
1834
  Returns
1553
1835
  -------
1554
1836
  DataWithAxes
1837
+
1838
+ See Also
1839
+ --------
1840
+ :py:meth:`~pymodaq.utils.math_utils.ift`, :py:meth:`~numpy.fft.ifft`
1555
1841
  """
1556
1842
  dat_ift = []
1843
+ axis_obj = self.get_axis_from_index(axis)[0]
1844
+ omega_grid, time_grid = mutils.ftAxis_time(len(axis_obj),
1845
+ np.abs(axis_obj.max() - axis_obj.min()))
1557
1846
  for dat in self.data:
1558
1847
  dat_ift.append(mutils.ift(dat, dim=axis))
1559
- return self.deepcopy_with_new_data(dat_ift)
1560
-
1848
+ new_data = self.deepcopy_with_new_data(dat_ift)
1849
+ axis_obj.data = omega_grid
1850
+ axis_obj.label = f'ift({axis_obj.label})'
1851
+ axis_obj.units = f'2pi/{axis_obj.units}'
1852
+ return new_data
1853
+
1854
+ def fit(self, function: Callable, initial_guess: IterableType, data_index: int = None,
1855
+ axis_index: int = 0, **kwargs) -> DataCalculated:
1856
+ """ Apply 1D curve fitting using the scipy optimization package
1857
+
1858
+ Parameters
1859
+ ----------
1860
+ function: Callable
1861
+ a callable to be used for the fit
1862
+ initial_guess: Iterable
1863
+ The initial parameters for the fit
1864
+ data_index: int
1865
+ The index of the data over which to do the fit, if None apply the fit to all
1866
+ axis_index: int
1867
+ the axis index to use for the fit (if multiple) but there should be only one
1868
+ kwargs: dict
1869
+ extra named parameters applied to the curve_fit scipy method
1870
+
1871
+ Returns
1872
+ -------
1873
+ DataCalculated containing the evaluation of the fit on the specified axis
1874
+
1875
+ See Also
1876
+ --------
1877
+ :py:meth:`~scipy.optimize.curve_fit`
1878
+ """
1879
+ import scipy.optimize as opt
1880
+ if self.dim != DataDim['Data1D']:
1881
+ raise ValueError('Integrated fitting only works for 1D data')
1882
+ axis = self.get_axis_from_index(axis_index)[0].copy()
1883
+ axis_array = axis.get_data()
1884
+ if data_index is None:
1885
+ datalist_to_fit = self.data
1886
+ labels = [f'{label}_fit' for label in self.labels]
1887
+ else:
1888
+ datalist_to_fit = [self.data[data_index]]
1889
+ labels = [f'{self.labels[data_index]}_fit']
1890
+
1891
+ datalist_fitted = []
1892
+ fit_coeffs = []
1893
+ for data_array in datalist_to_fit:
1894
+ popt, pcov = opt.curve_fit(function, axis_array, data_array, p0=initial_guess, **kwargs)
1895
+ datalist_fitted.append(function(axis_array, *popt))
1896
+ fit_coeffs.append(popt)
1897
+
1898
+ return DataCalculated(f'{self.name}_fit', data=datalist_fitted,
1899
+ labels=labels,
1900
+ axes=[axis], fit_coeffs=fit_coeffs)
1901
+
1902
+ def find_peaks(self, height=None, threshold=None, **kwargs) -> DataToExport:
1903
+ """ Apply the scipy find_peaks method to 1D data
1904
+
1905
+ Parameters
1906
+ ----------
1907
+ height: number or ndarray or sequence, optional
1908
+ threshold: number or ndarray or sequence, optional
1909
+ kwargs: dict
1910
+ extra named parameters applied to the find_peaks scipy method
1911
+
1912
+ Returns
1913
+ -------
1914
+ DataCalculated
1915
+
1916
+ See Also
1917
+ --------
1918
+ :py:meth:`~scipy.optimize.find_peaks`
1919
+ """
1920
+ if self.dim != DataDim['Data1D']:
1921
+ raise ValueError('Finding peaks only works for 1D data')
1922
+ from scipy.signal import find_peaks
1923
+ peaks_indices = []
1924
+ dte = DataToExport('peaks')
1925
+ for ind in range(len(self)):
1926
+ peaks, properties = find_peaks(self[ind], height, threshold, **kwargs)
1927
+ peaks_indices.append(peaks)
1928
+
1929
+ dte.append(DataCalculated(f'{self.labels[ind]}',
1930
+ data=[self[ind][peaks_indices[-1]],
1931
+ peaks_indices[-1]
1932
+ ],
1933
+ labels=['peak value', 'peak indexes'],
1934
+ axes=[Axis('peak position', self.axes[0].units,
1935
+ data=self.axes[0].get_data_at(peaks_indices[-1]))])
1936
+ )
1937
+ return dte
1938
+
1561
1939
  def get_dim_from_data_axes(self) -> DataDim:
1562
1940
  """Get the dimensionality DataDim from data taking into account nav indexes
1563
1941
  """
@@ -1618,6 +1996,9 @@ class DataWithAxes(DataBase):
1618
1996
  def get_nav_axes(self) -> List[Axis]:
1619
1997
  return self._am.get_nav_axes()
1620
1998
 
1999
+ def get_sig_index(self) -> List[Axis]:
2000
+ return self._am.get_signal_axes()
2001
+
1621
2002
  def get_nav_axes_with_data(self) -> List[Axis]:
1622
2003
  """Get the data's navigation axes making sure there is data in the data field"""
1623
2004
  axes = self.get_nav_axes()
@@ -1633,12 +2014,36 @@ class DataWithAxes(DataBase):
1633
2014
  def get_axis_from_index(self, index, create=False):
1634
2015
  return self._am.get_axis_from_index(index, create)
1635
2016
 
2017
+ def get_axis_from_index_spread(self, index: int, spread: int):
2018
+ return self._am.get_axis_from_index_spread(index, spread)
2019
+
2020
+ def get_axis_from_label(self, label: str) -> Axis:
2021
+ """Get the axis referred by a given label
2022
+
2023
+ Parameters
2024
+ ----------
2025
+ label: str
2026
+ The label of the axis
2027
+
2028
+ Returns
2029
+ -------
2030
+ Axis or None: return the axis instance if it has the right label else None
2031
+ """
2032
+ for axis in self.axes:
2033
+ if axis.label == label:
2034
+ return axis
2035
+
1636
2036
  def create_missing_axes(self):
1637
- """Check if given the data shape, some axes are missing to properly define the data (especially for plotting)"""
2037
+ """Check if given the data shape, some axes are missing to properly define the data
2038
+ (especially for plotting)"""
1638
2039
  axes = self.axes[:]
1639
2040
  for index in self.nav_indexes + self.sig_indexes:
1640
- if len(self.get_axis_from_index(index)) != 0 and self.get_axis_from_index(index)[0] is None:
1641
- axes.extend(self.get_axis_from_index(index, create=True))
2041
+ if (len(self.get_axis_from_index(index)) != 0 and
2042
+ self.get_axis_from_index(index)[0] is None):
2043
+ axes_tmp = self.get_axis_from_index(index, create=True)
2044
+ for ax in axes_tmp:
2045
+ if ax.size > 1:
2046
+ axes.append(ax)
1642
2047
  self.axes = axes
1643
2048
 
1644
2049
  def _compute_slices(self, slices, is_navigation=True):
@@ -1657,11 +2062,23 @@ class DataWithAxes(DataBase):
1657
2062
  for ind in range(len(self.shape)):
1658
2063
  if ind in indexes:
1659
2064
  total_slices.append(slices.pop(0))
1660
- elif len(total_slices) == 0 or total_slices[-1] != Ellipsis:
2065
+ elif len(total_slices) == 0:
1661
2066
  total_slices.append(Ellipsis)
2067
+ elif not (Ellipsis in total_slices and total_slices[-1] is Ellipsis):
2068
+ total_slices.append(slice(None))
1662
2069
  total_slices = tuple(total_slices)
1663
2070
  return total_slices
1664
2071
 
2072
+ def check_squeeze(self, total_slices: List[slice], is_navigation: bool):
2073
+
2074
+ do_squeeze = True
2075
+ if 1 in self.data[0][total_slices].shape:
2076
+ if not is_navigation and self.data[0][total_slices].shape.index(1) in self.nav_indexes:
2077
+ do_squeeze = False
2078
+ elif is_navigation and self.data[0][total_slices].shape.index(1) in self.sig_indexes:
2079
+ do_squeeze = False
2080
+ return do_squeeze
2081
+
1665
2082
  def _slicer(self, slices, is_navigation=True):
1666
2083
  """Apply a given slice to the data either navigation or signal dimension
1667
2084
 
@@ -1675,18 +2092,21 @@ class DataWithAxes(DataBase):
1675
2092
  Returns
1676
2093
  -------
1677
2094
  DataWithAxes
1678
- Object of the same type as the initial data, derived from DataWithAxes. But with lower data size due to the
1679
- slicing and with eventually less axes.
2095
+ Object of the same type as the initial data, derived from DataWithAxes. But with lower
2096
+ data size due to the slicing and with possibly fewer axes.
1680
2097
  """
1681
2098
 
1682
2099
  if isinstance(slices, numbers.Number) or isinstance(slices, slice):
1683
2100
  slices = [slices]
1684
2101
  total_slices = self._compute_slices(slices, is_navigation)
1685
- new_arrays_data = [np.atleast_1d(np.squeeze(dat[total_slices])) for dat in self.data]
2102
+
2103
+ do_squeeze = self.check_squeeze(total_slices, is_navigation)
2104
+ new_arrays_data = [squeeze(dat[total_slices], do_squeeze) for dat in self.data]
1686
2105
  tmp_axes = self._am.get_signal_axes() if is_navigation else self._am.get_nav_axes()
1687
2106
  axes_to_append = [copy.deepcopy(axis) for axis in tmp_axes]
1688
2107
 
1689
- # axes_to_append are the axes to append to the new produced data (basically the ones to keep)
2108
+ # axes_to_append are the axes to append to the new produced data
2109
+ # (basically the ones to keep)
1690
2110
 
1691
2111
  indexes_to_get = self.nav_indexes if is_navigation else self.sig_indexes
1692
2112
  # indexes_to_get are the indexes of the axes where the slice should be applied
@@ -1694,42 +2114,51 @@ class DataWithAxes(DataBase):
1694
2114
  _indexes = list(self.nav_indexes)
1695
2115
  _indexes.extend(self.sig_indexes)
1696
2116
  lower_indexes = dict(zip(_indexes, [0 for _ in range(len(_indexes))]))
1697
- # lower_indexes will store for each *axis index* how much the index should be reduced because one axis has
2117
+ # lower_indexes will store for each *axis index* how much the index should be reduced
2118
+ # because one axis has
1698
2119
  # been removed
1699
2120
 
1700
2121
  axes = []
1701
2122
  nav_indexes = [] if is_navigation else list(self._am.nav_indexes)
1702
2123
  for ind_slice, _slice in enumerate(slices):
1703
- ax = self._am.get_axis_from_index(indexes_to_get[ind_slice])
1704
- if len(ax) != 0 and ax[0] is not None:
1705
- for ind in range(len(ax)):
1706
- ax[ind] = ax[ind].iaxis[_slice]
1707
-
1708
- if not(ax[0] is None or ax[0].size <= 1): # means the slice kept part of the axis
1709
- if is_navigation:
1710
- nav_indexes.append(self._am.nav_indexes[ind_slice])
1711
- axes.extend(ax)
1712
- else:
1713
- for axis in axes_to_append: # means we removed one of the axes (and data dim),
1714
- # hence axis index above current index should be lowered by 1
1715
- if axis.index > indexes_to_get[ind_slice]:
1716
- lower_indexes[axis.index] += 1
1717
- for index in indexes_to_get[ind_slice+1:]:
1718
- lower_indexes[index] += 1
2124
+ if ind_slice < len(indexes_to_get):
2125
+ ax = self._am.get_axis_from_index(indexes_to_get[ind_slice])
2126
+ if len(ax) != 0 and ax[0] is not None:
2127
+ for ind in range(len(ax)):
2128
+ ax[ind] = ax[ind].iaxis[_slice]
2129
+
2130
+ if not(ax[0] is None or ax[0].size <= 1): # means the slice kept part of the axis
2131
+ if is_navigation:
2132
+ nav_indexes.append(self._am.nav_indexes[ind_slice])
2133
+ axes.extend(ax)
2134
+ else:
2135
+ for axis in axes_to_append: # means we removed one of the axes (and data dim),
2136
+ # hence axis index above current index should be lowered by 1
2137
+ if axis.index > indexes_to_get[ind_slice]:
2138
+ lower_indexes[axis.index] += 1
2139
+ for index in indexes_to_get[ind_slice+1:]:
2140
+ lower_indexes[index] += 1
1719
2141
 
1720
2142
  axes.extend(axes_to_append)
1721
2143
  for axis in axes:
1722
2144
  axis.index -= lower_indexes[axis.index]
1723
2145
  for ind in range(len(nav_indexes)):
1724
2146
  nav_indexes[ind] -= lower_indexes[nav_indexes[ind]]
1725
- data = DataWithAxes(self.name, data=new_arrays_data, nav_indexes=tuple(nav_indexes), axes=axes,
2147
+
2148
+ if len(nav_indexes) != 0:
2149
+ distribution = self.distribution
2150
+ else:
2151
+ distribution = DataDistribution['uniform']
2152
+
2153
+ data = DataWithAxes(self.name, data=new_arrays_data, nav_indexes=tuple(nav_indexes),
2154
+ axes=axes,
1726
2155
  source='calculated', origin=self.origin,
1727
2156
  labels=self.labels[:],
1728
- distribution=self.distribution if len(nav_indexes) != 0 else DataDistribution['uniform'])
2157
+ distribution=distribution)
1729
2158
  return data
1730
2159
 
1731
2160
  def deepcopy_with_new_data(self, data: List[np.ndarray] = None,
1732
- remove_axes_index: List[int] = None,
2161
+ remove_axes_index: Union[int, List[int]] = None,
1733
2162
  source: DataSource = 'calculated',
1734
2163
  keep_dim=False) -> DataWithAxes:
1735
2164
  """deepcopy without copying the initial data (saving memory)
@@ -1762,7 +2191,6 @@ class DataWithAxes(DataBase):
1762
2191
  source = enum_checker(DataSource, source)
1763
2192
  new_data._source = source
1764
2193
 
1765
-
1766
2194
  if remove_axes_index is not None:
1767
2195
  if not isinstance(remove_axes_index, Iterable):
1768
2196
  remove_axes_index = [remove_axes_index]
@@ -1776,7 +2204,8 @@ class DataWithAxes(DataBase):
1776
2204
  sig_indexes = list(new_data.sig_indexes)
1777
2205
  for index in remove_axes_index:
1778
2206
  for axis in new_data.get_axis_from_index(index):
1779
- new_data.axes.remove(axis)
2207
+ if axis is not None:
2208
+ new_data.axes.remove(axis)
1780
2209
 
1781
2210
  if index in new_data.nav_indexes:
1782
2211
  nav_indexes.pop(nav_indexes.index(index))
@@ -1812,8 +2241,7 @@ class DataWithAxes(DataBase):
1812
2241
  finally:
1813
2242
  self._data = old_data
1814
2243
 
1815
- def deepcopy(self):
1816
- return copy.deepcopy(self)
2244
+
1817
2245
 
1818
2246
  @property
1819
2247
  def _am(self) -> AxesManagerBase:
@@ -1822,6 +2250,10 @@ class DataWithAxes(DataBase):
1822
2250
  def get_data_dimension(self) -> str:
1823
2251
  return str(self._am)
1824
2252
 
2253
+ def get_data_as_dwa(self, index: int = 0) -> DataWithAxes:
2254
+ """ Get the underlying data selected from the list at index, returned as a DataWithAxes"""
2255
+ return self.deepcopy_with_new_data([self[index]])
2256
+
1825
2257
 
1826
2258
  class DataRaw(DataWithAxes):
1827
2259
  """Specialized DataWithAxes set with source as 'raw'. To be used for raw data"""
@@ -1832,7 +2264,8 @@ class DataRaw(DataWithAxes):
1832
2264
 
1833
2265
 
1834
2266
  class DataActuator(DataRaw):
1835
- """Specialized DataWithAxes set with source as 'raw'. To be used for raw data generated by actuator plugins"""
2267
+ """Specialized DataWithAxes set with source as 'raw'.
2268
+ To be used for raw data generated by actuator plugins"""
1836
2269
  def __init__(self, *args, **kwargs):
1837
2270
  if len(args) == 0 and 'name' not in kwargs:
1838
2271
  args = ['actuator']
@@ -1844,44 +2277,67 @@ class DataActuator(DataRaw):
1844
2277
 
1845
2278
  def __repr__(self):
1846
2279
  if self.dim.name == 'Data0D':
1847
- return f'{self.__class__.__name__} <{self.data[0][0]}>'
2280
+ return f'<{self.__class__.__name__} ({self.data[0][0]})>'
1848
2281
  else:
1849
- return f'{self.__class__.__name__} <{self.shape}>>'
2282
+ return f'<{self.__class__.__name__} ({self.shape})>'
1850
2283
 
1851
- def value(self):
1852
- """Returns the underlying float value if this data holds only a float otherwise returns a mean of the
1853
- underlying data"""
2284
+ def value(self) -> float:
2285
+ """Returns the underlying float value (of the first element in the data list) if this data
2286
+ holds only a float otherwise returns a mean of the underlying data"""
1854
2287
  if self.length == 1 and self.size == 1:
1855
2288
  return float(self.data[0][0])
1856
2289
  else:
1857
2290
  return float(np.mean(self.data))
1858
2291
 
2292
+ def values(self) -> List[float]:
2293
+ """Returns the underlying float value (for each data array in the data list) if this data
2294
+ holds only a float otherwise returns a mean of the underlying data"""
2295
+ if self.length == 1 and self.size == 1:
2296
+ return [float(data_array[0]) for data_array in self.data]
2297
+ else:
2298
+ return [float(np.mean(data_array)) for data_array in self.data]
2299
+
1859
2300
 
1860
2301
  class DataFromPlugins(DataRaw):
1861
2302
  """Specialized DataWithAxes set with source as 'raw'. To be used for raw data generated by Detector plugins
1862
2303
 
1863
- It introduces by default to extra attributes, plot and save. Their presence can be checked in the
2304
+ It introduces by default two extra attributes, do_plot and do_save. Their presence can be checked in the
1864
2305
  extra_attributes list.
1865
2306
 
1866
2307
  Parameters
1867
2308
  ----------
1868
- plot: bool
2309
+ do_plot: bool
1869
2310
  If True the underlying data will be plotted in the DAQViewer
1870
- save: bool
2311
+ do_save: bool
1871
2312
  If True the underlying data will be saved
1872
2313
 
1873
2314
  Attributes
1874
2315
  ----------
1875
- plot: bool
2316
+ do_plot: bool
1876
2317
  If True the underlying data will be plotted in the DAQViewer
1877
- save: bool
2318
+ do_save: bool
1878
2319
  If True the underlying data will be saved
1879
2320
  """
1880
2321
  def __init__(self, *args, **kwargs):
1881
- if 'plot' not in kwargs:
1882
- kwargs['plot'] = True
1883
- if 'save' not in kwargs:
1884
- kwargs['save'] = True
2322
+
2323
+ ##### for backcompatibility
2324
+ if 'plot' in kwargs:
2325
+ deprecation_msg("'plot' should not be used anymore as extra_attribute, "
2326
+ "please use 'do_plot'")
2327
+ do_plot = kwargs.pop('plot')
2328
+ kwargs['do_plot'] = do_plot
2329
+
2330
+ if 'save' in kwargs:
2331
+ deprecation_msg("'save' should not be used anymore as extra_attribute, "
2332
+ "please use 'do_save'")
2333
+ do_save = kwargs.pop('save')
2334
+ kwargs['do_save'] = do_save
2335
+ #######
2336
+
2337
+ if 'do_plot' not in kwargs:
2338
+ kwargs['do_plot'] = True
2339
+ if 'do_save' not in kwargs:
2340
+ kwargs['do_save'] = True
1885
2341
  super().__init__(*args, **kwargs)
1886
2342
 
1887
2343
 
@@ -1938,6 +2394,10 @@ class DataToExport(DataLowLevel):
1938
2394
  for key in kwargs:
1939
2395
  setattr(self, key, kwargs[key])
1940
2396
 
2397
+ def plot(self, plotter_backend: str = config('plotting', 'backend'), *args, **kwargs):
2398
+ """ Call a plotter factory and its plot method over the actual data"""
2399
+ return plotter_factory.get(plotter_backend).plot(self, *args, **kwargs)
2400
+
1941
2401
  def affect_name_to_origin_if_none(self):
1942
2402
  """Affect self.name to all DataWithAxes children's attribute origin if this origin is not defined"""
1943
2403
  for dat in self.data:
@@ -2000,20 +2460,30 @@ class DataToExport(DataLowLevel):
2000
2460
  raise TypeError(f'Could not average a {other.__class__.__name__} with a {self.__class__.__name__} '
2001
2461
  f'of a different length')
2002
2462
 
2003
- def merge_as_dwa(self, dim: DataDim, name: str = None) -> DataRaw:
2004
- """ attempt to merge all dwa into one
2463
+ def merge_as_dwa(self, dim: Union[str, DataDim], name: str = None) -> DataRaw:
2464
+ """ attempt to merge filtered dwa into one
2005
2465
 
2006
- Only possible if all dwa and underlying data have same shape
2466
+ Only possible if all filtered dwa and underlying data have same shape
2467
+
2468
+ Parameters
2469
+ ----------
2470
+ dim: DataDim or str
2471
+ will only try to merge dwa having this dimensionality
2472
+ name: str
2473
+ The new name of the returned dwa
2007
2474
  """
2008
2475
  dim = enum_checker(DataDim, dim)
2009
- if name is None:
2010
- name = self.name
2476
+
2011
2477
  filtered_data = self.get_data_from_dim(dim)
2012
- ndarrays = []
2013
- for dwa in filtered_data:
2014
- ndarrays.extend(dwa.data)
2015
- dwa = DataRaw(name, dim=dim, data=ndarrays)
2016
- return dwa
2478
+ if len(filtered_data) != 0:
2479
+ dwa = filtered_data[0].deepcopy()
2480
+ for dwa_tmp in filtered_data[1:]:
2481
+ if dwa_tmp.shape == dwa.shape and dwa_tmp.distribution == dwa.distribution:
2482
+ dwa.append(dwa_tmp)
2483
+ if name is None:
2484
+ name = self.name
2485
+ dwa.name = name
2486
+ return dwa
2017
2487
 
2018
2488
  def __repr__(self):
2019
2489
  repr = f'{self.__class__.__name__}: {self.name} <len:{len(self)}>\n'
@@ -2085,6 +2555,27 @@ class DataToExport(DataLowLevel):
2085
2555
  else:
2086
2556
  return [data.get_full_name() for data in self.get_data_from_dim(dim).data]
2087
2557
 
2558
+ def get_origins(self, dim: DataDim = None):
2559
+ """Get the origins of the underlying data, optionally filtered by dim
2560
+
2561
+ Parameters
2562
+ ----------
2563
+ dim: DataDim or str
2564
+
2565
+ Returns
2566
+ -------
2567
+ list of str: the origins of the (filtered) DataWithAxes data
2568
+
2569
+ Examples
2570
+ --------
2571
+ d0 = DataWithAxes(name='datafromdet0', origin='det0')
2572
+ """
2573
+ if dim is None:
2574
+ return list({dwa.origin for dwa in self.data})
2575
+ else:
2576
+ return list({dwa.origin for dwa in self.get_data_from_dim(dim).data})
2577
+
2578
+
2088
2579
  def get_data_from_full_name(self, full_name: str, deepcopy=False) -> DataWithAxes:
2089
2580
  """Get the DataWithAxes with matching full name"""
2090
2581
  if deepcopy:
@@ -2294,8 +2785,8 @@ class DataToExport(DataLowLevel):
2294
2785
  def data(self, new_data: List[DataWithAxes]):
2295
2786
  for dat in new_data:
2296
2787
  self._check_data_type(dat)
2297
- self._data[:] = [dat for dat in new_data] # shallow copyto make sure that if the original list
2298
- # is changed, the change will not be applied in here
2788
+ self._data[:] = [dat for dat in new_data] # shallow copy to make sure that if the original
2789
+ # list is changed, the change will not be applied in here
2299
2790
 
2300
2791
  self.affect_name_to_origin_if_none()
2301
2792
 
@@ -2339,6 +2830,34 @@ class DataScan(DataToExport):
2339
2830
  super().__init__(name, data, **kwargs)
2340
2831
 
2341
2832
 
2833
+ class DataToActuators(DataToExport):
2834
+ """ Particular case of a DataToExport adding one named parameter to indicate what kind of change
2835
+ should be applied to the actuators, absolute or relative
2836
+
2837
+ Attributes
2838
+ ----------
2839
+ mode: str
2840
+ Adds an attribute called mode holding a string describing the type of change:
2841
+ relative or absolute
2842
+
2843
+ Parameters
2844
+ ----------
2845
+ mode: str
2846
+ either 'rel' or 'abs' for a relative or absolute change of the actuator's values
2847
+ """
2848
+
2849
+ def __init__(self, *args, mode='rel', **kwargs):
2850
+ if mode not in ['rel', 'abs']:
2851
+ warnings.warn('Incorrect mode for the actuators, switching to default relative mode: rel')
2852
+ mode = 'rel'
2853
+ kwargs.update({'mode': mode})
2854
+ super().__init__(*args, **kwargs)
2855
+
2856
+ def __repr__(self):
2857
+ return f'{super().__repr__()}: {self.mode}'
2858
+
2859
+
2860
+
2342
2861
  if __name__ == '__main__':
2343
2862
 
2344
2863