pymodaq_data 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymodaq_data/data.py ADDED
@@ -0,0 +1,2901 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created the 28/10/2022
4
+
5
+ @author: Sebastien Weber
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from abc import ABCMeta, abstractmethod, abstractproperty
10
+ import numbers
11
+ import numpy as np
12
+ from typing import List, Tuple, Union, Any, Callable
13
+ from typing import Iterable as IterableType
14
+ from collections.abc import Iterable
15
+ from collections import OrderedDict
16
+ import logging
17
+
18
+ import warnings
19
+ from time import time
20
+ import copy
21
+ import pint
22
+ from multipledispatch import dispatch
23
+
24
+ from pymodaq_utils.enums import BaseEnum, enum_checker
25
+ from pymodaq_utils.warnings import deprecation_msg
26
+ from pymodaq_utils.utils import find_objects_in_list_from_attr_name_val
27
+ from pymodaq_utils.logger import set_logger, get_module_name
28
+ from pymodaq_data.slicing import SpecialSlicersData
29
+ from pymodaq_utils import math_utils as mutils
30
+ from pymodaq_utils.config import Config
31
+ from pymodaq_data.plotting.plotter.plotter import PlotterFactory
32
+
33
+ from pymodaq_data import Q_, ureg, Unit
34
+
35
# Module-level objects shared by the definitions below
config = Config()  # configuration object from pymodaq_utils.config
plotter_factory = PlotterFactory()  # PlotterFactory instance from pymodaq_data.plotting
logger = set_logger(get_module_name(__file__))  # logger named after this module
38
+
39
+
40
def check_units(units: str):
    """Validate a unit string against the pint registry.

    Parameters
    ----------
    units: str
        The unit string to check, for instance 'mm' or 'V'

    Returns
    -------
    str: the input string if pint knows it, otherwise '' (dimensionless) after logging
    a warning
    """
    try:
        Unit(units)
    except pint.errors.UndefinedUnitError:
        # the two literals are concatenated: keep the trailing space of the first one
        logger.warning(f'The unit "{units}" is not defined in the pint registry, switching to '
                       f'dimensionless')
        return ''
    return units
48
+
49
+
50
def squeeze(data_array: np.ndarray, do_squeeze=True, squeeze_indexes: Tuple[int]=None) -> np.ndarray:
    """Remove length-1 dimensions (optionally only those in squeeze_indexes) when do_squeeze
    is True; the returned array is always at least 1D."""
    reduced = np.squeeze(data_array, axis=squeeze_indexes) if do_squeeze else data_array
    return np.atleast_1d(reduced)
56
+
57
+
58
class DataIndexWarning(Warning):
    """Warning category for data index related issues"""


class DataTypeWarning(Warning):
    """Warning category for data type related issues"""


class DataDimWarning(Warning):
    """Warning category for data dimensionality related issues"""


class DataSizeWarning(Warning):
    """Warning category for data size related issues"""
74
WARNINGS = [DataIndexWarning, DataTypeWarning, DataDimWarning, DataSizeWarning]

# Show the data warnings only when this module's logger level is DEBUG, silence them otherwise
for warning in WARNINGS:
    warnings.filterwarnings(
        'default' if logging.getLevelName(logger.level) == 'DEBUG' else 'ignore',
        category=warning)
82
+
83
+
84
class DataShapeError(Exception):
    """Raised when arrays do not have the expected shape"""


class DataLengthError(Exception):
    """Raised when the number of arrays is not the expected one"""


class DataDimError(Exception):
    """Raised when the dimensionality of some data is inconsistent"""


class DataUnitError(Exception):
    """Raised when units are incompatible for the requested operation"""
98
+
99
+
100
class DwaType(BaseEnum):
    """Enum naming the DataWithAxes flavours"""
    DataWithAxes = 0
    DataRaw = 1
    DataActuator = 2
    DataFromPlugins = 3
    DataCalculated = 4
106
+
107
+
108
class DataDim(BaseEnum):
    """Dimensionality of some data; members are ordered from Data0D up to DataND"""
    Data0D = 0
    Data1D = 1
    Data2D = 2
    DataND = 3

    # Ordering is done on the integer values; other_dim may be anything enum_checker accepts
    def __le__(self, other_dim: 'DataDim'):
        return self.value <= enum_checker(DataDim, other_dim).value

    def __lt__(self, other_dim: 'DataDim'):
        return self.value < enum_checker(DataDim, other_dim).value

    def __ge__(self, other_dim: 'DataDim'):
        return self.value >= enum_checker(DataDim, other_dim).value

    def __gt__(self, other_dim: 'DataDim'):
        return self.value > enum_checker(DataDim, other_dim).value

    @property
    def dim_index(self):
        """int: the integer value of this member"""
        return self.value

    @staticmethod
    def from_data_array(data_array: np.ndarray):
        """Infer the DataDim from an array's shape.

        A 1D array of size 1 is Data0D, a longer 1D array is Data1D, a 2D array is Data2D;
        anything else (including 0-d and empty 1D arrays) is DataND.
        """
        n_dims = len(data_array.shape)
        if n_dims == 1 and data_array.size == 1:
            key = 'Data0D'
        elif n_dims == 1 and data_array.size > 1:
            key = 'Data1D'
        elif n_dims == 2:
            key = 'Data2D'
        else:
            key = 'DataND'
        return DataDim[key]
145
+
146
+
147
class DataSource(BaseEnum):
    """Where the data comes from: raw acquisition or calculated (processed)"""
    raw = 0
    calculated = 1
151
+
152
+
153
class DataDistribution(BaseEnum):
    """How the data points are laid out: on a regular grid (uniform) or at arbitrary points (spread)"""
    uniform = 0
    spread = 1
157
+
158
+
159
class Axis:
    """Object holding info and data about a physical axis of some data

    When the axis values are evenly spaced, only ``offset``, ``scaling`` and ``size`` are kept
    and the values are rebuilt on demand as ``offset + scaling * arange(size)``; otherwise the
    1D array itself is stored.

    Parameters
    ----------
    label: str
        The label of the axis, for instance 'time' for a temporal axis
    units: str
        The units of the axis values, for instance 's' for seconds. Units unknown to the
        pint registry are replaced by '' (dimensionless)
    data: ndarray
        A 1D ndarray holding the values of the axis
    index: int
        The index of the dimension of the Data object this axis is related to
    scaling: float
        The step between two consecutive axis values
    offset: float
        The first value of the axis
    size: int
        The number of points of the axis (to be specified if data is None)
    spread_order: int
        An integer needed in the case where data has a spread DataDistribution. It refers to
        the index along the data's spread_index dimension

    Examples
    --------
    >>> axis = Axis('myaxis', units='seconds', data=np.array([1,2,3,4,5]), index=0)
    """

    base_type = 'Axis'

    def __init__(self, label: str = '', units: str = '', data: np.ndarray = None, index: int = 0,
                 scaling=None, offset=None, size=None, spread_order: int = 0):
        super().__init__()

        self.iaxis: Axis = SpecialSlicersData(self, False)

        self._size = size
        self._data = None
        self._index = None
        self._label = None
        self._units = None
        self._scaling = scaling
        self._offset = offset

        self.units = units
        self.label = label
        self.data = data
        self.index = index
        self.spread_order = spread_order
        if (scaling is None or offset is None or size is None) and data is not None:
            # linear data are compressed into scaling/offset (the stored array is dropped)
            self.get_scale_offset_from_data(data)

    def copy(self):
        """Shallow copy of this axis (a stored data array is shared, not copied)"""
        return copy.copy(self)

    def as_dwa(self) -> DataWithAxes:
        """Wrap the axis values into a DataRaw object named after the axis label"""
        dwa = DataRaw(self.label, data=[self.get_data()],
                      labels=[f'{self.label}_{self.units}'])
        dwa.create_missing_axes()
        return dwa

    @property
    def label(self) -> str:
        """str: get/set the label of this axis"""
        return self._label

    @label.setter
    def label(self, lab: str):
        if not isinstance(lab, str):
            raise TypeError('label for the Axis class should be a string')
        self._label = lab

    @property
    def units(self) -> str:
        """str: get/set the units for this axis"""
        return self._units

    @units.setter
    def units(self, units: str):
        if not isinstance(units, str):
            raise TypeError('units for the Axis class should be a string')
        self._units = check_units(units)

    @property
    def index(self) -> int:
        """int: get/set the index this axis corresponds to in a DataWithAxis object"""
        return self._index

    @index.setter
    def index(self, ind: int):
        self._check_index_valid(ind)
        self._index = ind

    @property
    def data(self):
        """np.ndarray: get/set the stored data of the Axis (None when stored as scaling/offset)"""
        return self._data

    @data.setter
    def data(self, data: np.ndarray):
        if data is not None:
            self._check_data_valid(data)
            self.get_scale_offset_from_data(data)
            self._size = data.size
        elif self.size is None:
            self._size = 0
        self._data = data

    def get_data(self) -> np.ndarray:
        """Axis values: the stored array if any, otherwise rebuilt from offset/scaling/size"""
        return self._data if self._data is not None else self._linear_data(self.size)

    def get_data_at(self, indexes: Union[int, IterableType, slice]) -> np.ndarray:
        """Get the axis values at the specified indexes

        Parameters
        ----------
        indexes: int, slice, ndarray or any iterable of int (converted to an ndarray)
        """
        if not isinstance(indexes, (np.ndarray, slice, int)):
            indexes = np.array(indexes)
        return self.get_data()[indexes]

    def get_scale_offset_from_data(self, data: np.ndarray = None):
        """Extract scaling and offset from data (or from the stored data if data is None)

        When the values are evenly spaced, scaling and offset are updated and the stored array
        is dropped. Nothing is done when there are no values to look at.
        """
        if data is None:
            data = self._data
        if data is None or len(data) == 0:
            # no value to infer a scaling/offset from
            return
        if self.is_axis_linear(data):
            self._scaling = 1 if len(data) == 1 else np.mean(np.diff(data))
            self._offset = data[0]
            self._data = None

    def is_axis_linear(self, data=None):
        """Tell whether the values (data, or the axis values if None) are evenly spaced"""
        if data is None:
            data = self.get_data()
        if data is None:
            return False
        if len(data) < 2:
            # zero or one point is trivially linear; skip np.mean on an empty diff (warning)
            return True
        steps = np.diff(data)
        return np.allclose(steps, np.mean(steps))

    @property
    def scaling(self):
        """float: step between two consecutive axis values"""
        return self._scaling

    @scaling.setter
    def scaling(self, _scaling: float):
        self._scaling = _scaling

    @property
    def offset(self):
        """float: first value of the axis"""
        return self._offset

    @offset.setter
    def offset(self, _offset: float):
        self._offset = _offset

    @property
    def size(self) -> int:
        """int: get/set the size/length of the 1D ndarray (setting is ignored if data is stored)"""
        return self._size

    @size.setter
    def size(self, _size: int):
        if self._data is None:
            self._size = _size

    @staticmethod
    def _check_index_valid(index: int):
        if not isinstance(index, int):
            raise TypeError('index for the Axis class should be a positive integer')
        elif index < 0:
            raise ValueError('index for the Axis class should be a positive integer')

    @staticmethod
    def _check_data_valid(data):
        if not isinstance(data, np.ndarray):
            raise TypeError('data for the Axis class should be a 1D numpy array')
        elif len(data.shape) != 1:
            raise ValueError('data for the Axis class should be a 1D numpy array')

    def _linear_data(self, nsteps: int):
        """create axis data with a linear version using scaling and offset"""
        return self._offset + self._scaling * np.linspace(0, nsteps-1, nsteps)

    def create_linear_data(self, nsteps: int):
        """replace the axis data with a linear version using scaling and offset"""
        self.data = self._linear_data(nsteps)

    @staticmethod
    def create_simple_linear_data(nsteps: int):
        return np.linspace(0, nsteps-1, nsteps)

    def __len__(self):
        return self.size

    def _compute_slices(self, slices, *ignored, **ignored_also):
        return slices

    def _slicer(self, _slice, *ignored, **ignored_also):
        """Return a sliced copy of the axis for an int, Ellipsis or slice"""
        ax: Axis = copy.deepcopy(self)
        if isinstance(_slice, int):
            ax.data = np.array([ax.get_data()[_slice]])
            return ax
        elif _slice is Ellipsis:
            return ax
        elif isinstance(_slice, slice):
            if ax._data is not None:
                ax.data = ax._data.__getitem__(_slice)
                return ax
            # normalize negative/None/out-of-range bounds and the step, like numpy would
            start, stop, step = _slice.indices(self.size)
            ax._offset = ax.offset + start * ax.scaling
            ax._scaling = ax.scaling * step
            ax._size = len(range(start, stop, step))
            return ax

    def __getitem__(self, item):
        if hasattr(self, item):
            # for when axis was a dict
            deprecation_msg('attributes from an Axis object should not be fetched using __getitem__')
            return getattr(self, item)

    def __repr__(self):
        return f'{self.__class__.__name__}: <label: {self.label}> - <units: {self.units}> - <index: {self.index}>'

    def __mul__(self, scale: numbers.Real):
        """Return a copy of the axis with its values multiplied by scale"""
        if not isinstance(scale, numbers.Real):
            return NotImplemented
        ax = copy.deepcopy(self)
        if self.data is not None:
            # not in-place: an integer array cannot hold a float product
            ax.data = ax.data * scale
        else:
            ax._offset *= scale
            ax._scaling *= scale
        return ax

    def __add__(self, offset: numbers.Real):
        """Return a copy of the axis with offset added to its values"""
        if not isinstance(offset, numbers.Real):
            return NotImplemented
        ax = copy.deepcopy(self)
        if self.data is not None:
            ax.data = ax.data + offset
        else:
            ax._offset += offset
        return ax

    def __eq__(self, other: Axis):
        if isinstance(other, Axis):
            eq = self.label == other.label
            eq = eq and (Unit(self.units).is_compatible_with(other.units))
            eq = eq and (self.index == other.index)
            if self.data is not None and other.data is not None:
                eq = eq and (np.allclose(Q_(self.data, self.units),
                                         Q_(other.data, other.units)))
            else:
                eq = eq and (np.allclose(Q_(self.offset, self.units),
                                         Q_(other.offset, other.units)))
                eq = eq and (np.allclose(Q_(self.scaling, self.units),
                                         Q_(other.scaling, other.units)))

            return eq
        else:
            return False

    def mean(self):
        """Mean of the axis values"""
        if self._data is not None:
            return np.mean(self._data)
        # values are offset + k * scaling for k in 0..size-1
        return self.offset + (self.size - 1) / 2 * self.scaling

    def min(self):
        """Smallest axis value"""
        if self._data is not None:
            return np.min(self._data)
        if self.scaling < 0:
            return self.offset + (self.size - 1) * self.scaling
        return self.offset

    def max(self):
        """Largest axis value"""
        if self._data is not None:
            return np.max(self._data)
        if self.scaling > 0:
            return self.offset + (self.size - 1) * self.scaling
        return self.offset

    def find_index(self, threshold: float) -> int:
        """Index of the axis value closest to threshold (clipped to the axis ends)"""
        if self._data is not None:
            if threshold < self.min():
                return int(np.argmin(self._data))
            elif threshold > self.max():
                return int(np.argmax(self._data))
            return mutils.find_index(self._data, threshold)[0][0]
        if not self.scaling or not self.size:
            return 0
        # nearest point, works for increasing and decreasing axes alike
        index = int(np.round((threshold - self.offset) / self.scaling))
        return int(np.clip(index, 0, len(self) - 1))

    def find_indexes(self, thresholds: IterableType[float]) -> IterableType[int]:
        """Apply find_index on each threshold (a single number is accepted too)"""
        if isinstance(thresholds, numbers.Number):
            thresholds = [thresholds]
        return [self.find_index(threshold) for threshold in thresholds]
470
+
471
+
472
class NavAxis(Axis):
    """Deprecated alias of Axis: same arguments, plus a deprecation message at creation"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # the two literals are concatenated: keep the space separating the sentences
        deprecation_msg('NavAxis should not be used anymore, please use Axis object with correct index. '
                        'The navigation index should be specified in the Data object')
477
+
478
+
479
class DataLowLevel:
    """Minimal base of all Data objects: a name plus a creation timestamp

    Parameters
    ----------
    name: str
        the identifier of the data

    Attributes
    ----------
    name: str
    timestamp: float
        Time in seconds since epoch, as returned by time.time() at creation
    """

    def __init__(self, name: str):
        self._name = name
        self._timestamp = time()

    @property
    def name(self):
        """Get/Set the identifier of the data"""
        return self._name

    @name.setter
    def name(self, other_name: str):
        self._name = other_name

    @property
    def timestamp(self):
        """Get/Set the time (seconds since epoch) at which the object was created"""
        return self._timestamp

    @timestamp.setter
    def timestamp(self, timestamp: float):
        self._timestamp = timestamp
516
+
517
+
518
+ class DataBase(DataLowLevel):
519
+ """Base object to store homogeneous data and metadata generated by pymodaq's objects.
520
+
521
+ To be inherited for real data
522
+
523
+ Parameters
524
+ ----------
525
+ name: str
526
+ the identifier of these data
527
+ source: DataSource or str
528
+ Enum specifying if data are raw or processed (for instance from roi)
529
+ dim: DataDim or str
530
+ The identifier of the data type
531
+ distribution: DataDistribution or str
532
+ The distribution type of the data: uniform if distributed on a regular grid or spread if on
533
+ specific unordered points
534
+ data: list of ndarray
535
+ The data the object is storing
536
+ labels: list of str
537
+ The labels of the data nd-arrays
538
+ origin: str
539
+ An identifier of the element where the data originated, for instance the DAQ_Viewer's name.
540
+ Used when appending DataToExport in DAQ_Scan to disintricate from which origin data comes
541
+ from when scanning multiple detectors.
542
+ units: str
543
+ A unit string identifier as specified in the UnitRegistry of the pint module
544
+
545
+ kwargs: named parameters
546
+ All other parameters are stored dynamically using the name/value pair. The name of these
547
+ extra parameters are added into the extra_attributes attribute
548
+
549
+ Attributes
550
+ ----------
551
+ name: str
552
+ the identifier of these data
553
+ source: DataSource or str
554
+ Enum specifying if data are raw or processed (for instance from roi)
555
+ dim: DataDim or str
556
+ The identifier of the data type
557
+ distribution: DataDistribution or str
558
+ The distribution type of the data: uniform if distributed on a regular grid or spread if on specific
559
+ unordered points
560
+ data: list of ndarray
561
+ The data the object is storing
562
+ labels: list of str
563
+ The labels of the data nd-arrays
564
+ origin: str
565
+ An identifier of the element where the data originated, for instance the DAQ_Viewer's name. Used when appending
566
+ DataToExport in DAQ_Scan to disintricate from which origin data comes from when scanning multiple detectors.
567
+ shape: Tuple[int]
568
+ The shape of the underlying data
569
+ size: int
570
+ The size of the ndarrays stored in the object
571
+ length: int
572
+ The number of ndarrays stored in the object
573
+ extra_attributes: List[str]
574
+ list of string giving identifiers of the attributes added dynamically at the initialization (for instance
575
+ to save extra metadata using the DataSaverLoader
576
+
577
+ See Also
578
+ --------
579
+ DataWithAxes, DataFromPlugins, DataRaw, DataSaverLoader
580
+
581
+ Examples
582
+ --------
583
+ >>> import numpy as np
584
+ >>> from pymodaq.utils.data import DataBase, DataSource, DataDim, DataDistribution
585
+ >>> data = DataBase('mydata', source=DataSource['raw'], dim=DataDim['Data1D'], \
586
+ distribution=DataDistribution['uniform'], data=[np.array([1.,2.,3.]), np.array([4.,5.,6.])],\
587
+ labels=['channel1', 'channel2'], origin='docutils code')
588
+ >>> data.dim
589
+ <DataDim.Data1D: 1>
590
+ >>> data.source
591
+ <DataSource.raw: 0>
592
+ >>> data.shape
593
+ (3,)
594
+ >>> data.length
595
+ 2
596
+ >>> data.size
597
+ 3
598
+ """
599
+
600
+ base_type = 'Data'
601
+
602
def __init__(self, name: str,
             source: DataSource = None, dim: DataDim = None,
             distribution: DataDistribution = DataDistribution['uniform'],
             data: List[np.ndarray] = None,
             labels: List[str] = None, origin: str = '',
             units: str = '',
             **kwargs):
    """Initialize the data container; extra keyword arguments become extra attributes"""
    super().__init__(name=name)
    # internal state, filled in when data is assigned below
    self._iter_index = 0
    self._shape = None
    self._size = None
    self._data = None
    self._length = None
    self._labels = None
    self._dim = dim
    self._units = check_units(units)
    self._errors = None
    self.origin = origin

    self._source = enum_checker(DataSource, source)
    self._distribution = enum_checker(DataDistribution, distribution)

    # dim consistency is actually checked within the data setter
    self.data = data

    self._check_labels(labels)
    self.extra_attributes = []
    self.add_extra_attribute(**kwargs)
633
+
634
@property
def units(self):
    """str: the unit string of the stored arrays"""
    units_string = self._units
    return units_string
637
+
638
@units.setter
def units(self, units: str):
    """Convert the stored arrays in place to the (validated) new units"""
    self.units_as(check_units(units), inplace=True)
642
+
643
def units_as(self, units: str, inplace=True) -> 'DataBase':
    """Express the data in new units (if dimensionally possible)

    Parameters
    ----------
    units: str
        The new unit to convert the data to
    inplace: bool
        default True.
        If True replace the data's arrays by arrays in the new units and return self
        If False, return a converted deep copy and leave self untouched

    Raises
    ------
    DataUnitError: if the units are not compatible with the current ones
    """
    try:
        converted = [self.quantities[index].m_as(units) for index in range(len(self))]
    except pint.errors.DimensionalityError as e:
        raise DataUnitError(
            f'Cannot change the Data units to {units} \n'
            f'{e}')

    target = self if inplace else copy.deepcopy(self)
    target.data = converted
    target._units = units
    return target
674
+
675
def as_dte(self, name: str = 'mydte') -> DataToExport:
    """Wrap this single data object into a DataToExport called name"""
    wrapper = DataToExport(name, data=[self])
    return wrapper
678
+
679
def add_extra_attribute(self, **kwargs):
    """Set each keyword argument as an attribute and record its name in extra_attributes"""
    for attribute_name, value in kwargs.items():
        if attribute_name not in self.extra_attributes:
            self.extra_attributes.append(attribute_name)
        setattr(self, attribute_name, value)
684
+
685
def get_full_name(self) -> str:
    """Get the data full name, including the origin attribute

    Returns
    -------
    str: the name of the data object built as: origin/name

    Examples
    --------
    An object with origin 'det0' and name 'datafromdet0' has the full name
    'det0/datafromdet0'
    """
    return '{}/{}'.format(self.origin, self.name)
697
+
698
def __repr__(self):
    """Class name, name, units, dim, source and shape of the data"""
    return '{} <{}> <u: {}> <{}> <{}> <{}>'.format(
        self.__class__.__name__, self.name, self.units, self.dim, self.source, self.shape)
702
+
703
def __len__(self):
    """Number of stored arrays (same as the length attribute)"""
    n_arrays = self.length
    return n_arrays
705
+
706
def __iter__(self):
    """Restart iteration over the stored arrays; the object is its own iterator"""
    self._iter_index = 0
    return self
709
+
710
def __next__(self):
    """Return the next stored array, or raise StopIteration after the last one"""
    if self._iter_index >= len(self):
        raise StopIteration
    current = self._iter_index
    self._iter_index = current + 1
    return self.data[current]
716
+
717
def __getitem__(self, item) -> np.ndarray:
    """Get a stored array by integer index, or a list of arrays by slice

    Any integral index is accepted, numpy integers included (e.g. from np.arange).

    Raises
    ------
    IndexError: if item is neither a slice nor an integer lower than the data length
    """
    if isinstance(item, slice) or (isinstance(item, numbers.Integral) and item < len(self)):
        return self.data[item]
    raise IndexError(f'The index should be an integer lower than the data length')
722
+
723
def __setitem__(self, key, value):
    """Replace the array at integer index key (numpy integers accepted) by value

    Raises
    ------
    IndexError: if key is not an integer lower than the data length or if value is not an
    ndarray having the same shape as the stored arrays
    """
    if (isinstance(key, numbers.Integral) and key < len(self) and isinstance(value, np.ndarray)
            and value.shape == self.shape):
        self.data[key] = value
    else:
        raise IndexError(f'The index should be an positive integer lower than the data length')
728
+
729
def __add__(self, other: object):
    """Element-wise sum with another data object of same length; result is in self's units

    Raises
    ------
    TypeError: if other is not a DataBase of the same length
    ValueError: if the array shapes differ
    DataUnitError: if the units are not compatible
    """
    if not (isinstance(other, DataBase) and len(other) == len(self)):
        raise TypeError(f'Could not add a {other.__class__.__name__} or a {self.__class__.__name__} '
                        f'of a different length')
    summed = copy.deepcopy(self)
    for index in range(len(summed)):
        if self[index].shape != other[index].shape:
            raise ValueError('The shapes of arrays stored into the data are not consistent')
        try:
            summed[index] = (Q_(self[index], self.units) +
                             Q_(other[index], other.units)).m_as(self.units)
        except pint.errors.DimensionalityError as e:
            raise DataUnitError(
                f'Cannot sum Data objects not having the same dimension: {e}')
    return summed
745
+
746
def __sub__(self, other: object):
    """Subtract other, implemented as the addition of other * -1"""
    negated = other * -1
    return self.__add__(negated)
761
+
762
def __mul__(self, other):
    """Multiply the data by a number, an array of the data shape or another data object

    When multiplying by another DataBase, both must hold the same number of arrays of the
    same shape; the result is expressed in pint base units, the unit string being computed
    from the first arrays.

    Raises
    ------
    TypeError: for any other operand, or a DataBase of a different shape or length
    """
    if (isinstance(other, numbers.Number) or
            (isinstance(other, np.ndarray) and other.shape == self._shape)):
        new_data = copy.deepcopy(self)
        for ind_array in range(len(new_data)):
            new_data[ind_array] = self[ind_array] * other
        return new_data
    elif (isinstance(other, DataBase) and other.shape == self._shape
          and len(other) == len(self)):
        # the length check avoids an IndexError (shorter other) or silently ignoring
        # extra arrays (longer other)
        new_data = copy.deepcopy(self)
        new_unit = str((Q_(self[0], self.units) *
                        Q_(other[0], other.units)).to_base_units().units)
        for ind_array in range(len(new_data)):
            new_data[ind_array] = \
                ((Q_(self[ind_array], self.units) * Q_(other[ind_array], other.units))
                 .to_base_units()).magnitude
        new_data._units = new_unit
        return new_data
    else:
        raise TypeError(f'Could not multiply a {other.__class__.__name__} and a {self.__class__.__name__} '
                        f'of a different length')
782
+
783
def __truediv__(self, other):
    """Divide the data by a number (multiplication by its inverse)"""
    if not isinstance(other, numbers.Number):
        raise TypeError(f'Could not divide a {other.__class__.__name__} and a {self.__class__.__name__} '
                        f'of a different length')
    inverse = 1 / other
    return self * inverse
789
+
790
def _comparison_common(self, other, operator='__eq__'):
    """Compare self with other using the given rich-comparison dunder name

    With another DataBase: False unless names, lengths, unit compatibility, dims and array
    shapes match; otherwise all arrays are compared as quantities. With a number: only the
    first array is compared. Extra attributes are not compared (they may hold module
    specific data).

    Raises
    ------
    TypeError: for any other type of operand
    """
    if isinstance(other, numbers.Number):
        return np.all(getattr(self[0], operator)(other))
    if not isinstance(other, DataBase):
        raise TypeError()

    same_header = (self.name == other.name and
                   len(self) == len(other) and
                   Unit(self.units).is_compatible_with(other.units))
    if not same_header or self.dim != other.dim:
        return False

    result = True
    for index in range(len(self)):
        if self[index].shape != other[index].shape:
            return False
        result = result and np.all(getattr(self.quantities[index], operator)(other.quantities[index]))
    return result
813
+
814
def __eq__(self, other):
    """Equality with another DataBase or a number (see _comparison_common)

    Any other kind of object (None, strings, ...) simply compares unequal instead of raising,
    as Axis.__eq__ does, so that membership tests in mixed containers keep working.
    """
    if not isinstance(other, (DataBase, numbers.Number)):
        return False
    return self._comparison_common(other, '__eq__')
816
+
817
def __le__(self, other):
    """Element-wise <= comparison (see _comparison_common)"""
    return self._comparison_common(other, operator='__le__')
819
+
820
def __lt__(self, other):
    """Element-wise < comparison (see _comparison_common)"""
    return self._comparison_common(other, operator='__lt__')
822
+
823
def __ge__(self, other):
    """Element-wise >= comparison (see _comparison_common)"""
    return self._comparison_common(other, operator='__ge__')
825
+
826
def __gt__(self, other):
    """Element-wise > comparison (see _comparison_common)"""
    return self._comparison_common(other, operator='__gt__')
828
+
829
def deepcopy(self):
    """Return a fully independent (deep) copy of this object"""
    duplicate = copy.deepcopy(self)
    return duplicate
831
+
832
def average(self, other: 'DataBase', weight: int) -> 'DataBase':
    """Compute the weighted average between self and other

    Parameters
    ----------
    other: DataBase
        data object of the same length as self
    weight: int
        The weight the 'other' holds with respect to self (self has a weight of 1)

    Returns
    -------
    DataBase: (other * weight + self) / (weight + 1)
    """
    if not (isinstance(other, DataBase) and len(other) == len(self) and isinstance(weight, numbers.Number)):
        raise TypeError(f'Could not average a {other.__class__.__name__} or a {self.__class__.__name__} '
                        f'of a different length')
    return (other * weight + self) / (weight + 1)
849
+
850
def abs(self):
    """Return a shallow copy of self holding the element-wise absolute values"""
    new_data = copy.copy(self)
    new_data.data = list(map(np.abs, new_data))
    return new_data
855
+
856
def angle(self):
    """Return a shallow copy of self holding the element-wise phase (np.angle, radians)"""
    new_data = copy.copy(self)
    new_data.data = list(map(np.angle, new_data))
    return new_data
861
+
862
def real(self):
    """Return a shallow copy of self holding the real part of each array"""
    new_data = copy.copy(self)
    new_data.data = list(map(np.real, new_data))
    return new_data
867
+
868
def imag(self):
    """Return a shallow copy of self holding the imaginary part of each array"""
    new_data = copy.copy(self)
    new_data.data = list(map(np.imag, new_data))
    return new_data
873
+
874
def flipud(self):
    """Return a shallow copy of self with each array reversed along axis 0 (up/down)"""
    new_data = copy.copy(self)
    new_data.data = list(map(np.flipud, new_data))
    return new_data
879
+
880
def fliplr(self):
    """Return a shallow copy of self with each array reversed along axis 1 (left/right)"""
    new_data = copy.copy(self)
    new_data.data = list(map(np.fliplr, new_data))
    return new_data
885
+
886
def append(self, data: DataWithAxes):
    """Append the arrays and labels of data to self

    The arrays must have the same shape as self's; they are converted to self's units.

    Raises
    ------
    DataShapeError: if one of the arrays of data does not have self's shape
    """
    for dat in data:
        if dat.shape != self.shape:
            raise DataShapeError('Cannot append those ndarrays, they don\'t have the same shape'
                                 ' as self')
    # Build new lists rather than extending in place: shallow copies (copy.copy, as done by
    # abs(), real()...) share the same list objects, which would otherwise be modified too.
    self.data = self.data + [Q_(data_array, data.units).m_as(self.units) for data_array in data.data]
    self.labels = self.labels + list(data.labels)
894
+
895
+ def pop(self, index: int) -> DataBase:
896
+ """ Returns a copy of self but with data taken at the specified index"""
897
+ dwa = self.deepcopy()
898
+ dwa.data = [dwa.data[index]]
899
+ dwa.labels = [dwa.labels[index]]
900
+ return dwa
901
+
902
@property
def shape(self):
    """The shape of each stored nd-array"""
    array_shape = self._shape
    return array_shape
906
+
907
+ def stack_as_array(self, axis=0, dtype=None) -> np.ndarray:
908
+ """ Stack all data arrays in a single numpy array
909
+
910
+ Parameters
911
+ ----------
912
+ axis: int
913
+ The new stack axis index, default 0
914
+ dtype: str or np.dtype
915
+ the dtype of the stacked array
916
+
917
+ Returns
918
+ -------
919
+ np.ndarray
920
+
921
+ See Also
922
+ --------
923
+ :meth:`np.stack`
924
+ """
925
+
926
+ return np.stack(self.data, axis=axis, dtype=dtype)
927
+
928
@property
def size(self):
    """The number of elements of each stored nd-array"""
    n_elements = self._size
    return n_elements
932
+
933
@property
def dim(self):
    """DataDim: the enum describing the dimensionality of the stored data"""
    dimensionality = self._dim
    return dimensionality
937
+
938
def set_dim(self, dim: Union[DataDim, str]):
    """Ad hoc modification of dim, independently of the real data shape.

    Should be used with extra care."""
    checked_dim = enum_checker(DataDim, dim)
    self._dim = checked_dim
942
+
943
@property
def source(self):
    """DataSource: the enum telling whether the data is raw or calculated"""
    data_source = self._source
    return data_source
947
+
948
@source.setter
def source(self, source_type: Union[str, DataSource]):
    """Set the DataSource, given as an enum member or as its name"""
    self._source = enum_checker(DataSource, source_type)
953
+
954
@property
def distribution(self):
    """DataDistribution: uniform (regular grid) or spread (arbitrary points)"""
    data_distribution = self._distribution
    return data_distribution
958
+
959
@property
def length(self):
    """The number of nd-arrays stored in the data list"""
    n_arrays = self._length
    return n_arrays
963
+
964
@property
def labels(self):
    """List[str]: one label per stored array"""
    current_labels = self._labels
    return current_labels
967
+
968
@labels.setter
def labels(self, labels: List['str']):
    """Set the labels (copied, and completed with default names up to the data length)"""
    self._check_labels(labels)
971
+
972
def _check_labels(self, labels: List['str']):
    """Store a copy of labels, padded with default 'CHxx' names up to the data length

    Parameters
    ----------
    labels: None, a single str, or any iterable of str (list, tuple...)
        A single string is taken as one label (not split into characters)
    """
    if labels is None:
        labels = []
    elif isinstance(labels, str):
        labels = [labels]
    else:
        # list() copies and also accepts tuples/generators, which have no append method
        labels = list(labels)
    while len(labels) < self.length:
        labels.append(f'CH{len(labels):02d}')
    self._labels = labels
980
+
981
+ def get_data_index(self, index: int = 0) -> np.ndarray:
982
+ """Get the data by its index in the list, same as self[index]"""
983
+ return self.data[index]
984
+
985
+ @staticmethod
986
+ def _check_data_type(data: List[np.ndarray]) -> List[np.ndarray]:
987
+ """make sure data is a list of nd-arrays"""
988
+ is_valid = True
989
+ if data is None:
990
+ is_valid = False
991
+ if not isinstance(data, list):
992
+ # try to transform the data to regular type
993
+ if isinstance(data, np.ndarray):
994
+ warnings.warn(DataTypeWarning(f'Your data should be a list of numpy arrays not just a single numpy'
995
+ f' array, wrapping them with a list'))
996
+ data = [data]
997
+ elif isinstance(data, numbers.Number):
998
+ warnings.warn(DataTypeWarning(f'Your data should be a list of numpy arrays not just a single numpy'
999
+ f' array, wrapping them with a list'))
1000
+ data = [np.array([data])]
1001
+ else:
1002
+ is_valid = False
1003
+ if isinstance(data, list):
1004
+ if len(data) == 0:
1005
+ is_valid = False
1006
+ elif not isinstance(data[0], np.ndarray):
1007
+ is_valid = False
1008
+ elif len(data[0].shape) == 0:
1009
+ is_valid = False
1010
+ if not is_valid:
1011
+ raise TypeError(f'Data should be an non-empty list of non-empty numpy arrays')
1012
+ return data
1013
+
1014
+ def check_shape_from_data(self, data: List[np.ndarray]):
1015
+ self._shape = data[0].shape
1016
+
1017
+ @staticmethod
1018
+ def _get_dim_from_data(data: List[np.ndarray]) -> DataDim:
1019
+ shape = data[0].shape
1020
+ size = data[0].size
1021
+ if len(shape) == 1 and size == 1:
1022
+ dim = DataDim['Data0D']
1023
+ elif len(shape) == 1 and size > 1:
1024
+ dim = DataDim['Data1D']
1025
+ elif len(shape) == 2:
1026
+ dim = DataDim['Data2D']
1027
+ else:
1028
+ dim = DataDim['DataND']
1029
+ return dim
1030
+
1031
+ def get_dim_from_data(self, data: List[np.ndarray]):
1032
+ """Get the dimensionality DataDim from data"""
1033
+ self.check_shape_from_data(data)
1034
+ self._size = data[0].size
1035
+ self._length = len(data)
1036
+ if len(self._shape) == 1 and self._size == 1:
1037
+ dim = DataDim['Data0D']
1038
+ elif len(self._shape) == 1 and self._size > 1:
1039
+ dim = DataDim['Data1D']
1040
+ elif len(self._shape) == 2:
1041
+ dim = DataDim['Data2D']
1042
+ else:
1043
+ dim = DataDim['DataND']
1044
+ return dim
1045
+
1046
+ def _check_shape_dim_consistency(self, data: List[np.ndarray]):
1047
+ """Process the dim from data or make sure data and DataDim are coherent"""
1048
+ dim = self.get_dim_from_data(data)
1049
+ if self._dim is None:
1050
+ self._dim = dim
1051
+ else:
1052
+ self._dim = enum_checker(DataDim, self._dim)
1053
+ if self._dim != dim:
1054
+ warnings.warn(
1055
+ DataDimWarning('The specified dimensionality is not coherent with the data '
1056
+ 'shape, replacing it'))
1057
+ self._dim = dim
1058
+
1059
+ def _check_same_shape(self, data: List[np.ndarray]):
1060
+ """Check that all nd-arrays have the same shape"""
1061
+ for dat in data:
1062
+ if dat.shape != self.shape:
1063
+ raise DataShapeError('The shape of the ndarrays in data is not the same')
1064
+
1065
+ @property
1066
+ def quantities(self) -> list[Q_]:
1067
+ """ Get the arrays as pint quantities (with units)"""
1068
+ return [Q_(array, self.units) for array in self.data]
1069
+
1070
+ @property
1071
+ def data(self) -> List[np.ndarray]:
1072
+ """List[np.ndarray]: get/set (and check) the data the object is storing"""
1073
+ return self._data
1074
+
1075
+ @data.setter
1076
+ def data(self, data: List[np.ndarray]):
1077
+ data = self._check_data_type(data)
1078
+ self._check_shape_dim_consistency(data)
1079
+ self._check_same_shape(data)
1080
+ self._data = data
1081
+
1082
+ def to_dict(self):
1083
+ """ Get the data arrays into dictionary whose keys are the labels"""
1084
+ data_dict = OrderedDict([])
1085
+ for ind in range(len(self)):
1086
+ data_dict[self.labels[ind]] = self[ind]
1087
+ return data_dict
1088
+
1089
+ def to_dB(self) -> DataBase:
1090
+ """ Get a new data object in decibels
1091
+
1092
+ new in 4.3.0
1093
+ """
1094
+ new_data = copy.deepcopy(self)
1095
+ for ind_array in range(len(new_data)):
1096
+ new_data[ind_array] = 10 * np.log10(self[ind_array] / self[ind_array].max())
1097
+ new_data._units = 'dB'
1098
+ return new_data
1099
+
1100
+
1101
+ class AxesManagerBase:
1102
+ def __init__(self, data_shape: Tuple[int], axes: List[Axis], nav_indexes=None, sig_indexes=None, **kwargs):
1103
+ self._data_shape = data_shape[:] # initial shape needed for self._check_axis
1104
+ self._axes = axes[:]
1105
+ self._nav_indexes = nav_indexes
1106
+ self._sig_indexes = sig_indexes if sig_indexes is not None else self.compute_sig_indexes()
1107
+
1108
+ self._check_axis(self._axes)
1109
+ self._manage_named_axes(self._axes, **kwargs)
1110
+
1111
+ @property
1112
+ def axes(self):
1113
+ return self._axes
1114
+
1115
+ @axes.setter
1116
+ def axes(self, axes: List[Axis]):
1117
+ self._axes = axes[:]
1118
+ self._check_axis(self._axes)
1119
+
1120
+ @abstractmethod
1121
+ def _check_axis(self, axes):
1122
+ ...
1123
+
1124
+ @abstractmethod
1125
+ def get_sorted_index(self, axis_index: int = 0, spread_index=0) -> Tuple[np.ndarray, Tuple[slice]]:
1126
+ """ Get the index to sort the specified axis
1127
+
1128
+ Parameters
1129
+ ----------
1130
+ axis_index: int
1131
+ The index along which one should sort the data
1132
+ spread_index: int
1133
+ for spread data only, specifies which spread axis to use
1134
+
1135
+ Returns
1136
+ -------
1137
+ np.ndarray: the sorted index from the specified axis
1138
+ tuple of slice:
1139
+ used to slice the underlying data
1140
+ """
1141
+ ...
1142
+
1143
+ @abstractmethod
1144
+ def get_axis_from_index_spread(self, index: int, spread_order: int) -> Axis:
1145
+ """in spread mode, different nav axes have the same index (but not
1146
+ the same spread_order integer value)
1147
+
1148
+ """
1149
+ ...
1150
+
1151
+ def compute_sig_indexes(self):
1152
+ _shape = list(self._data_shape)
1153
+ indexes = list(np.arange(len(self._data_shape)))
1154
+ for index in self.nav_indexes:
1155
+ if index in indexes:
1156
+ indexes.pop(indexes.index(index))
1157
+ return tuple(indexes)
1158
+
1159
+ def _has_get_axis_from_index(self, index: int):
1160
+ """Check if the axis referred by a given data dimensionality index is present
1161
+
1162
+ Returns
1163
+ -------
1164
+ bool: True if the axis has been found else False
1165
+ Axis or None: return the axis instance if has the axis else None
1166
+ """
1167
+ if index > len(self._data_shape) or index < 0:
1168
+ raise IndexError('The specified index does not correspond to any data dimension')
1169
+ for axis in self.axes:
1170
+ if axis.index == index:
1171
+ return True, axis
1172
+ return False, None
1173
+
1174
+ def _manage_named_axes(self, axes, x_axis=None, y_axis=None, nav_x_axis=None, nav_y_axis=None):
1175
+ """This method make sur old style Data is still compatible, especially when using x_axis or y_axis parameters"""
1176
+ modified = False
1177
+ if x_axis is not None:
1178
+ modified = True
1179
+ index = 0
1180
+ if len(self._data_shape) == 1 and not self._has_get_axis_from_index(0)[0]:
1181
+ # in case of Data1D the x_axis corresponds to the first data dim
1182
+ index = 0
1183
+ elif len(self._data_shape) == 2 and not self._has_get_axis_from_index(1)[0]:
1184
+ # in case of Data2D the x_axis corresponds to the second data dim (columns)
1185
+ index = 1
1186
+ axes.append(Axis(x_axis.label, x_axis.units, x_axis.data, index=index))
1187
+
1188
+ if y_axis is not None:
1189
+
1190
+ if len(self._data_shape) == 2 and not self._has_get_axis_from_index(0)[0]:
1191
+ modified = True
1192
+ # in case of Data2D the y_axis corresponds to the first data dim (lines)
1193
+ axes.append(Axis(y_axis.label, y_axis.units, y_axis.data, index=0))
1194
+
1195
+ if nav_x_axis is not None:
1196
+ if len(self.nav_indexes) > 0:
1197
+ modified = True
1198
+ # in case of DataND the y_axis corresponds to the first data dim (lines)
1199
+ axes.append(Axis(nav_x_axis.label, nav_x_axis.units, nav_x_axis.data, index=self._nav_indexes[0]))
1200
+
1201
+ if nav_y_axis is not None:
1202
+ if len(self.nav_indexes) > 1:
1203
+ modified = True
1204
+ # in case of Data2D the y_axis corresponds to the first data dim (lines)
1205
+ axes.append(Axis(nav_y_axis.label, nav_y_axis.units, nav_y_axis.data, index=self._nav_indexes[1]))
1206
+
1207
+ if modified:
1208
+ self._check_axis(axes)
1209
+
1210
+ @property
1211
+ def shape(self) -> Tuple[int]:
1212
+ # self._data_shape = self.compute_shape_from_axes()
1213
+ return self._data_shape
1214
+
1215
+ @abstractmethod
1216
+ def compute_shape_from_axes(self):
1217
+ ...
1218
+
1219
+ @property
1220
+ def sig_shape(self) -> tuple:
1221
+ return tuple([self.shape[ind] for ind in self.sig_indexes])
1222
+
1223
+ @property
1224
+ def nav_shape(self) -> tuple:
1225
+ return tuple([self.shape[ind] for ind in self.nav_indexes])
1226
+
1227
+ def append_axis(self, axis: Axis):
1228
+ self._axes.append(axis)
1229
+ self._check_axis([axis])
1230
+
1231
+ @property
1232
+ def nav_indexes(self) -> IterableType[int]:
1233
+ return self._nav_indexes
1234
+
1235
+ @nav_indexes.setter
1236
+ def nav_indexes(self, nav_indexes: IterableType[int]):
1237
+ if isinstance(nav_indexes, Iterable):
1238
+ nav_indexes = tuple(nav_indexes)
1239
+ valid = True
1240
+ for index in nav_indexes:
1241
+ if index not in self.get_axes_index():
1242
+ logger.warning('Could not set the corresponding nav_index into the data object, not enough'
1243
+ ' Axis declared')
1244
+ valid = False
1245
+ break
1246
+ if valid:
1247
+ self._nav_indexes = nav_indexes
1248
+ else:
1249
+ logger.warning('Could not set the corresponding sig_indexes into the data object, should be an iterable')
1250
+ self.sig_indexes = self.compute_sig_indexes()
1251
+ self.shape
1252
+
1253
+ @property
1254
+ def sig_indexes(self) -> IterableType[int]:
1255
+ return self._sig_indexes
1256
+
1257
+ @sig_indexes.setter
1258
+ def sig_indexes(self, sig_indexes: IterableType[int]):
1259
+ if isinstance(sig_indexes, Iterable):
1260
+ sig_indexes = tuple(sig_indexes)
1261
+ valid = True
1262
+ for index in sig_indexes:
1263
+ if index in self._nav_indexes:
1264
+ logger.warning('Could not set the corresponding sig_index into the axis manager object, '
1265
+ 'the axis is already affected to the navigation axis')
1266
+ valid = False
1267
+ break
1268
+ if index not in self.get_axes_index():
1269
+ logger.warning('Could not set the corresponding nav_index into the data object, not enough'
1270
+ ' Axis declared')
1271
+ valid = False
1272
+ break
1273
+ if valid:
1274
+ self._sig_indexes = sig_indexes
1275
+ else:
1276
+ logger.warning('Could not set the corresponding sig_indexes into the data object, should be an iterable')
1277
+
1278
+ @property
1279
+ def nav_axes(self) -> List[int]:
1280
+ deprecation_msg('nav_axes parameter should not be used anymore, use nav_indexes')
1281
+ return self._nav_indexes
1282
+
1283
+ @nav_axes.setter
1284
+ def nav_axes(self, nav_indexes: List[int]):
1285
+ deprecation_msg('nav_axes parameter should not be used anymore, use nav_indexes')
1286
+ self.nav_indexes = nav_indexes
1287
+
1288
+ def is_axis_signal(self, axis: Axis) -> bool:
1289
+ """Check if an axis is considered signal or navigation"""
1290
+ return axis.index in self._nav_indexes
1291
+
1292
+ def is_axis_navigation(self, axis: Axis) -> bool:
1293
+ """Check if an axis is considered signal or navigation"""
1294
+ return axis.index not in self._nav_indexes
1295
+
1296
+ @abstractmethod
1297
+ def get_shape_from_index(self, index: int) -> int:
1298
+ """Get the data shape at the given index"""
1299
+ ...
1300
+
1301
+ def get_axes_index(self) -> List[int]:
1302
+ """Get the index list from the axis objects"""
1303
+ return [axis.index for axis in self._axes]
1304
+
1305
+ @abstractmethod
1306
+ def get_axis_from_index(self, index: int, create: bool = False) -> List[Axis]:
1307
+ ...
1308
+
1309
+ def get_axis_from_index_spread(self, index: int, spread_order: int) -> Axis:
1310
+ """Only valid for Spread data"""
1311
+ ...
1312
+
1313
+ def get_nav_axes(self) -> List[Axis]:
1314
+ """Get the navigation axes corresponding to the data
1315
+
1316
+ Use get_axis_from_index for all index in self.nav_indexes, but in spread distribution, one index may
1317
+ correspond to multiple nav axes, see Spread data distribution
1318
+
1319
+
1320
+ """
1321
+ return list(mutils.flatten([copy.copy(self.get_axis_from_index(index, create=True))
1322
+ for index in self.nav_indexes]))
1323
+
1324
+ def get_signal_axes(self):
1325
+ if self.sig_indexes is None:
1326
+ self._sig_indexes = tuple([int(axis.index) for axis in self.axes if axis.index not in self.nav_indexes])
1327
+ axes = []
1328
+ for index in self._sig_indexes:
1329
+ axes_tmp = copy.copy(self.get_axis_from_index(index, create=True))
1330
+ for ax in axes_tmp:
1331
+ if ax.size > 1:
1332
+ axes.append(ax)
1333
+ return axes
1334
+
1335
+ def is_axis_signal(self, axis: Axis) -> bool:
1336
+ """Check if an axis is considered signal or navigation"""
1337
+ return axis.index in self._nav_indexes
1338
+
1339
+ def is_axis_navigation(self, axis: Axis) -> bool:
1340
+ """Check if an axis is considered signal or navigation"""
1341
+ return axis.index not in self._nav_indexes
1342
+
1343
+ def __repr__(self):
1344
+ return self._get_dimension_str()
1345
+
1346
+ @abstractmethod
1347
+ def _get_dimension_str(self):
1348
+ ...
1349
+
1350
+
1351
+ class AxesManagerUniform(AxesManagerBase):
1352
+ def __init__(self, *args, **kwargs):
1353
+ super().__init__(*args, **kwargs)
1354
+
1355
+ def compute_shape_from_axes(self):
1356
+ if len(self.axes) != 0:
1357
+ shape = []
1358
+ for ind in range(len(self.axes)):
1359
+ shape.append(len(self.get_axis_from_index(ind, create=True)[0]))
1360
+ else:
1361
+ shape = self._data_shape
1362
+ return tuple(shape)
1363
+
1364
+ def get_shape_from_index(self, index: int) -> int:
1365
+ """Get the data shape at the given index"""
1366
+ if index > len(self._data_shape) or index < 0:
1367
+ raise IndexError('The specified index does not correspond to any data dimension')
1368
+ return self._data_shape[index]
1369
+
1370
+ def _check_axis(self, axes: List[Axis]):
1371
+ """Check all axis to make sure of their type and make sure their data are properly referring to the data index
1372
+
1373
+ See Also
1374
+ --------
1375
+ :py:meth:`Axis.create_linear_data`
1376
+ """
1377
+ for ind, axis in enumerate(axes):
1378
+ if not isinstance(axis, Axis):
1379
+ raise TypeError(f'An axis of {self.__class__.__name__} should be an Axis object')
1380
+ if self.get_shape_from_index(axis.index) != axis.size:
1381
+ warnings.warn(DataSizeWarning('The size of the axis is not coherent with the shape of the data. '
1382
+ 'Replacing it with a linspaced version: np.array([0, 1, 2, ...])'))
1383
+ axis.size = self.get_shape_from_index(axis.index)
1384
+ axis.scaling = 1
1385
+ axis.offset = 0
1386
+ axes[ind] = axis
1387
+ self._axes = axes
1388
+
1389
+ def get_axis_from_index(self, index: int, create: bool = False) -> List[Axis]:
1390
+ """Get the axis referred by a given data dimensionality index
1391
+
1392
+ If the axis is absent, create a linear one to fit the data shape if parameter create is True
1393
+
1394
+ Parameters
1395
+ ----------
1396
+ index: int
1397
+ The index referring to the data ndarray shape
1398
+ create: bool
1399
+ If True and the axis referred by index has not been found in axes, create one
1400
+
1401
+ Returns
1402
+ -------
1403
+ List[Axis] or None: return the list of axis instance if Data has the axis (or it has been created) else None
1404
+
1405
+ See Also
1406
+ --------
1407
+ :py:meth:`Axis.create_linear_data`
1408
+ """
1409
+ index = int(index)
1410
+ has_axis, axis = self._has_get_axis_from_index(index)
1411
+ if not has_axis:
1412
+ if create:
1413
+ warnings.warn(DataIndexWarning(f'The axis requested with index {index} is not present, '
1414
+ f'creating a linear one...'))
1415
+ axis = Axis(index=index, offset=0, scaling=1)
1416
+ axis.size = self.get_shape_from_index(index)
1417
+ else:
1418
+ warnings.warn(DataIndexWarning(f'The axis requested with index {index} is not present, returning None'))
1419
+ return [axis]
1420
+
1421
+ def get_axis_from_index_spread(self, index: int, spread_order: int) -> Axis:
1422
+ """in spread mode, different nav axes have the same index (but not
1423
+ the same spread_order integer value)
1424
+
1425
+ """
1426
+ return None
1427
+
1428
+ def get_sorted_index(self, axis_index: int = 0, spread_index=0) -> Tuple[np.ndarray, Tuple[slice]]:
1429
+ """ Get the index to sort the specified axis
1430
+
1431
+ Parameters
1432
+ ----------
1433
+ axis_index: int
1434
+ The index along which one should sort the data
1435
+ spread_index: int
1436
+ for spread data only, specifies which spread axis to use
1437
+
1438
+ Returns
1439
+ -------
1440
+ np.ndarray: the sorted index from the specified axis
1441
+ tuple of slice:
1442
+ used to slice the underlying data
1443
+ """
1444
+
1445
+ axes = self.get_axis_from_index(axis_index)
1446
+ if axes[0] is not None:
1447
+ sorted_index = np.argsort(axes[0].get_data())
1448
+ axes[0].data = axes[0].get_data()[sorted_index]
1449
+ slices = []
1450
+ for ind in range(len(self.shape)):
1451
+ if ind == axis_index:
1452
+ slices.append(sorted_index)
1453
+ else:
1454
+ slices.append(Ellipsis)
1455
+ slices = tuple(slices)
1456
+ return sorted_index, slices
1457
+ else:
1458
+ return None, None
1459
+
1460
+ def _get_dimension_str(self):
1461
+ string = "("
1462
+ for nav_index in self.nav_indexes:
1463
+ string += str(self._data_shape[nav_index]) + ", "
1464
+ string = string.rstrip(", ")
1465
+ string += "|"
1466
+ for sig_index in self.sig_indexes:
1467
+ string += str(self._data_shape[sig_index]) + ", "
1468
+ string = string.rstrip(", ")
1469
+ string += ")"
1470
+ return string
1471
+
1472
+
1473
+ class AxesManagerSpread(AxesManagerBase):
1474
+ """For this particular data category, some explanation is needed, see example below:
1475
+
1476
+ Examples
1477
+ --------
1478
+ One take images data (20x30) as a function of 2 parameters, say xaxis and yaxis non-linearly spaced on a regular
1479
+ grid.
1480
+
1481
+ data.shape = (150, 20, 30)
1482
+ data.nav_indexes = (0,)
1483
+
1484
+ The first dimension (150) corresponds to the navigation (there are 150 non uniform data points taken)
1485
+ The second and third could correspond to signal data, here an image of size (20x30)
1486
+ so:
1487
+ * nav_indexes is (0, )
1488
+ * sig_indexes are (1, 2)
1489
+
1490
+ xaxis = Axis(name=xaxis, index=0, data...) length 150
1491
+ yaxis = Axis(name=yaxis, index=0, data...) length 150
1492
+
1493
+ In fact from such a data shape the number of navigation axes in unknown . In our example, they are 2. To somehow
1494
+ keep track of some ordering in these navigation axes, one adds an attribute to the Axis object: the spread_order
1495
+ xaxis = Axis(name=xaxis, index=0, spread_order=0, data...) length 150
1496
+ yaxis = Axis(name=yaxis, index=0, spread_order=1, data...) length 150
1497
+ """
1498
+
1499
+ def __init__(self, *args, **kwargs):
1500
+ super().__init__(*args, **kwargs)
1501
+
1502
+ def _check_axis(self, axes: List[Axis]):
1503
+ """Check all axis to make sure of their type and make sure their data are properly referring to the data index
1504
+
1505
+ """
1506
+ for axis in axes:
1507
+ if not isinstance(axis, Axis):
1508
+ raise TypeError(f'An axis of {self.__class__.__name__} should be an Axis object')
1509
+ elif len(self.nav_indexes) != 1:
1510
+ raise ValueError('Spread data should have only one specified index in self.nav_indexes')
1511
+ elif axis.index in self.nav_indexes:
1512
+ if axis.size != 1 and (axis.size != self._data_shape[self.nav_indexes[0]]):
1513
+ raise DataLengthError('all navigation axes should have the same size')
1514
+
1515
+ def compute_shape_from_axes(self):
1516
+ """Get data shape from axes
1517
+
1518
+ First get the nav length from one of the navigation axes
1519
+ Then check for signal axes
1520
+ """
1521
+ if len(self.axes) != 0:
1522
+
1523
+ axes = sorted(self.axes, key=lambda axis: axis.index)
1524
+
1525
+ shape = []
1526
+ for axis in axes:
1527
+ if axis.index in self.nav_indexes:
1528
+ shape.append(axis.size)
1529
+ break
1530
+ for axis in axes:
1531
+ if axis.index not in self.nav_indexes:
1532
+ shape.append(axis.size)
1533
+ else:
1534
+ shape = self._data_shape
1535
+ return tuple(shape)
1536
+
1537
+ def get_shape_from_index(self, index: int) -> int:
1538
+ """Get the data shape at the given index"""
1539
+ if index > len(self._data_shape) or index < 0:
1540
+ raise IndexError('The specified index does not correspond to any data dimension')
1541
+ return self._data_shape[index]
1542
+
1543
+ def get_axis_from_index(self, index: int, create: bool = False) -> List[Axis]:
1544
+ """in spread mode, different nav axes have the same index (but not
1545
+ the same spread_order integer value) so may return multiple axis
1546
+
1547
+ No possible "linear" creation in this mode except if the index is a signal index
1548
+
1549
+ """
1550
+ if index in self.nav_indexes:
1551
+ axes = []
1552
+ for axis in self.axes:
1553
+ if axis.index == index:
1554
+ axes.append(axis)
1555
+ return axes
1556
+ else:
1557
+ index = int(index)
1558
+ try:
1559
+ has_axis, axis = self._has_get_axis_from_index(index)
1560
+ except IndexError:
1561
+ axis = [None]
1562
+ has_axis = False
1563
+ return axis
1564
+
1565
+ if not has_axis and index in self.sig_indexes:
1566
+ if create:
1567
+ warnings.warn(DataIndexWarning(f'The axis requested with index {index} is not present, '
1568
+ f'creating a linear one...'))
1569
+ axis = Axis(index=index, offset=0, scaling=1)
1570
+ axis.size = self.get_shape_from_index(index)
1571
+ else:
1572
+ warnings.warn(DataIndexWarning(f'The axis requested with index {index} is not present, returning None'))
1573
+
1574
+ return [axis]
1575
+
1576
+ def get_axis_from_index_spread(self, index: int, spread_order: int) -> Axis:
1577
+ """in spread mode, different nav axes have the same index (but not
1578
+ the same spread_order integer value)
1579
+
1580
+ """
1581
+ for axis in self.axes:
1582
+ if axis.index == index and axis.spread_order == spread_order:
1583
+ return axis
1584
+
1585
+ def get_sorted_index(self, axis_index: int = 0, spread_index=0) -> Tuple[np.ndarray, Tuple[slice]]:
1586
+ """ Get the index to sort the specified axis
1587
+
1588
+ Parameters
1589
+ ----------
1590
+ axis_index: int
1591
+ The index along which one should sort the data
1592
+ spread_index: int
1593
+ for spread data only, specifies which spread axis to use
1594
+
1595
+ Returns
1596
+ -------
1597
+ np.ndarray: the sorted index from the specified axis
1598
+ tuple of slice:
1599
+ used to slice the underlying data
1600
+ """
1601
+
1602
+ if axis_index in self.nav_indexes:
1603
+ axis = self.get_axis_from_index_spread(axis_index, spread_index)
1604
+ else:
1605
+ axis = self.get_axis_from_index(axis_index)[0]
1606
+
1607
+ if axis is not None:
1608
+ sorted_index = np.argsort(axis.get_data())
1609
+ slices = []
1610
+ for ind in range(len(self.shape)):
1611
+ if ind == axis_index:
1612
+ slices.append(sorted_index)
1613
+ else:
1614
+ if slices[-1] is Ellipsis: # only one ellipsis
1615
+ slices.append(Ellipsis)
1616
+ slices = tuple(slices)
1617
+
1618
+ for nav_index in self.nav_indexes:
1619
+ for axis in self.get_axis_from_index(nav_index):
1620
+ axis.data = axis.get_data()[sorted_index]
1621
+
1622
+ return sorted_index, slices
1623
+ else:
1624
+ return None, None
1625
+
1626
+ def get_axis_from_index_spread(self, index: int, spread_order: int) -> Axis:
1627
+ for axis in self.axes:
1628
+ if axis.index == index and axis.spread_order == spread_order:
1629
+ return axis
1630
+
1631
+ def _get_dimension_str(self):
1632
+ try:
1633
+ string = "("
1634
+ for nav_index in self.nav_indexes:
1635
+ string += str(self._data_shape[nav_index]) + ", "
1636
+ break
1637
+ string = string.rstrip(", ")
1638
+ string += "|"
1639
+ for sig_index in self.sig_indexes:
1640
+ string += str(self._data_shape[sig_index]) + ", "
1641
+ string = string.rstrip(", ")
1642
+ string += ")"
1643
+ except Exception as e:
1644
+ string = f'({self._data_shape})'
1645
+ finally:
1646
+ return string
1647
+
1648
+
1649
+ class DataWithAxes(DataBase):
1650
+ """Data object with Axis objects corresponding to underlying data nd-arrays
1651
+
1652
+ Parameters
1653
+ ----------
1654
+ axes: list of Axis
1655
+ the list of Axis object for proper plotting, calibration ...
1656
+ nav_indexes: tuple of int
1657
+ highlight which Axis in axes is Signal or Navigation axis depending on the content:
1658
+ For instance, nav_indexes = (2,), means that the axis with index 2 in a at least 3D ndarray data is the first
1659
+ navigation axis
1660
+ For instance, nav_indexes = (3,2), means that the axis with index 3 in a at least 4D ndarray data is the first
1661
+ navigation axis while the axis with index 2 is the second navigation Axis. Axes with index 0 and 1 are signal
1662
+ axes of 2D ndarray data
1663
+ errors: list of ndarray.
1664
+ The list should match the length of the data attribute while the ndarrays
1665
+ should match the data ndarray
1666
+ """
1667
+
1668
+ def __init__(self, *args, axes: List[Axis] = [],
1669
+ nav_indexes: Tuple[int] = (),
1670
+ errors: Iterable[np.ndarray] = None,
1671
+ **kwargs):
1672
+
1673
+ if 'nav_axes' in kwargs:
1674
+ deprecation_msg('nav_axes parameter should not be used anymore, use nav_indexes')
1675
+ nav_indexes = kwargs.pop('nav_axes')
1676
+
1677
+ x_axis = kwargs.pop('x_axis') if 'x_axis' in kwargs else None
1678
+ y_axis = kwargs.pop('y_axis') if 'y_axis' in kwargs else None
1679
+
1680
+ nav_x_axis = kwargs.pop('nav_x_axis') if 'nav_x_axis' in kwargs else None
1681
+ nav_y_axis = kwargs.pop('nav_y_axis') if 'nav_y_axis' in kwargs else None
1682
+
1683
+ super().__init__(*args, **kwargs)
1684
+
1685
+ self._axes = axes
1686
+
1687
+ other_kwargs = dict(x_axis=x_axis, y_axis=y_axis, nav_x_axis=nav_x_axis, nav_y_axis=nav_y_axis)
1688
+
1689
+ self.set_axes_manager(self.shape, axes=axes, nav_indexes=nav_indexes, **other_kwargs)
1690
+
1691
+ self.inav: Iterable[DataWithAxes] = SpecialSlicersData(self, True)
1692
+ self.isig: Iterable[DataWithAxes] = SpecialSlicersData(self, False)
1693
+
1694
+ self.get_dim_from_data_axes() # in DataBase, dim is processed from the shape of data, but if axes are provided
1695
+ #then use get_dim_from axes
1696
+ self._check_errors(errors)
1697
+
1698
+ def _check_errors(self, errors: Iterable[np.ndarray]):
1699
+ """ Make sure the errors object is adapted to the len/shape of the dwa object
1700
+
1701
+ new in 4.2.0
1702
+ """
1703
+ check = False
1704
+ if errors is None:
1705
+ self._errors = None
1706
+ return
1707
+ if isinstance(errors, (tuple, list)) and len(errors) == len(self):
1708
+ if np.all([isinstance(error, np.ndarray) for error in errors]):
1709
+ if np.all([error_array.shape == self.shape for error_array in errors]):
1710
+ check = True
1711
+ else:
1712
+ logger.warning(f'All error objects should have the same shape as the data'
1713
+ f'objects')
1714
+ else:
1715
+ logger.warning(f'All error objects should be np.ndarray')
1716
+
1717
+ if not check:
1718
+ logger.warning('the errors field is incompatible with the structure of the data')
1719
+ self._errors = None
1720
+ else:
1721
+ self._errors = errors
1722
+
1723
+ @property
1724
+ def errors(self):
1725
+ """ Get/Set the errors bar values as a list of np.ndarray
1726
+
1727
+ new in 4.2.0
1728
+ """
1729
+ return self._errors
1730
+
1731
+ @errors.setter
1732
+ def errors(self, errors: Iterable[np.ndarray]):
1733
+ self._check_errors(errors)
1734
+
1735
+ def get_error(self, index):
1736
+ """ Get a particular error ndarray at the given index in the list
1737
+
1738
+ new in 4.2.0
1739
+ """
1740
+ if self._errors is not None: #because to the initial check we know it is a list of ndarrays
1741
+ return self._errors[index]
1742
+ else:
1743
+ return np.array([0]) # this could be added to any numpy array of any shape
1744
+
1745
+ def errors_as_dwa(self):
1746
+ """ Get a dwa from self replacing the data content with the error attribute (if not None)
1747
+
1748
+ New in 4.2.0
1749
+ """
1750
+ if self.errors is not None:
1751
+ dwa = self.deepcopy_with_new_data(self.errors)
1752
+ dwa.name = f'{self.name}_errors'
1753
+ dwa.errors = None
1754
+ return dwa
1755
+ else:
1756
+ raise ValueError(f'Cannot create a dwa from a None, should be a list of ndarray')
1757
+
1758
+ def plot(self, plotter_backend: str = config('plotting', 'backend'), *args, viewer=None,
1759
+ **kwargs):
1760
+ """ Call a plotter factory and its plot method over the actual data"""
1761
+ return plotter_factory.get(plotter_backend).plot(self, *args, viewer=viewer, **kwargs)
1762
+
1763
+ def set_axes_manager(self, data_shape, axes, nav_indexes, **kwargs):
1764
+ if self.distribution.name == 'uniform' or len(nav_indexes) == 0:
1765
+ self._distribution = DataDistribution['uniform']
1766
+ self.axes_manager = AxesManagerUniform(data_shape=data_shape, axes=axes,
1767
+ nav_indexes=nav_indexes,
1768
+ **kwargs)
1769
+ elif self.distribution.name == 'spread':
1770
+ self.axes_manager = AxesManagerSpread(data_shape=data_shape, axes=axes,
1771
+ nav_indexes=nav_indexes,
1772
+ **kwargs)
1773
+ else:
1774
+ raise ValueError(f'Such a data distribution ({data.distribution}) has no AxesManager')
1775
+
1776
+ def __eq__(self, other):
1777
+ is_equal = super().__eq__(other)
1778
+ if not is_equal:
1779
+ return is_equal
1780
+ if isinstance(other, DataWithAxes):
1781
+ for ind in list(self.nav_indexes) + list(self.sig_indexes):
1782
+ axes_self = self.get_axis_from_index(ind)
1783
+ axes_other = other.get_axis_from_index(ind)
1784
+ if len(axes_other) != len(axes_self):
1785
+ return False
1786
+ for ind_ax in range(len(axes_self)):
1787
+ if axes_self[ind_ax] != axes_other[ind_ax]:
1788
+ return False
1789
+ if self.errors is None:
1790
+ is_equal = is_equal and other.errors is None
1791
+ else:
1792
+ for ind_error in range(len(self.errors)):
1793
+ if not np.allclose(self.errors[ind_error], other.errors[ind_error]):
1794
+ return False
1795
+ return is_equal
1796
+
1797
+ def __repr__(self):
1798
+ return (f'<{self.__class__.__name__}: {self.name} '
1799
+ f'<u: {self.units}> '
1800
+ f'<len:{self.length}> {self._am}>')
1801
+
1802
+ def sort_data(self, axis_index: int = 0, spread_index=0, inplace=False) -> DataWithAxes:
1803
+ """ Sort data along a given axis, default is 0
1804
+
1805
+ Parameters
1806
+ ----------
1807
+ axis_index: int
1808
+ The index along which one should sort the data
1809
+ spread_index: int
1810
+ for spread data only, specifies which spread axis to use
1811
+ inplace: bool
1812
+ modify in place or not the data (and its axes)
1813
+
1814
+ Returns
1815
+ -------
1816
+ DataWithAxes
1817
+ """
1818
+ if inplace:
1819
+ data = self
1820
+ else:
1821
+ data = self.deepcopy()
1822
+ sorted_index, slices = data._am.get_sorted_index(axis_index, spread_index)
1823
+ if sorted_index is not None:
1824
+ for ind in range(len(data)):
1825
+ data.data[ind] = data.data[ind][slices]
1826
+ return data
1827
+
1828
+ def transpose(self):
1829
+ """replace the data by their transposed version
1830
+
1831
+ Valid only for 2D data
1832
+ """
1833
+ if self.dim == 'Data2D':
1834
+ self.data[:] = [data.T for data in self.data]
1835
+ for axis in self.axes:
1836
+ axis.index = 0 if axis.index == 1 else 1
1837
+
1838
+ def crop_at_along(self, coordinates_tuple: Tuple):
1839
+ slices = []
1840
+ for coordinates in coordinates_tuple:
1841
+ axis = self.get_axis_from_index(0)[0]
1842
+ indexes = axis.find_indexes(coordinates)
1843
+ slices.append(slice(indexes))
1844
+
1845
+ return self._slicer(slices, False)
1846
+
1847
+ def mean(self, axis: int = 0) -> DataWithAxes:
1848
+ """Process the mean of the data on the specified axis and returns the new data
1849
+
1850
+ Parameters
1851
+ ----------
1852
+ axis: int
1853
+
1854
+ Returns
1855
+ -------
1856
+ DataWithAxes
1857
+ """
1858
+ dat_mean = []
1859
+ for dat in self.data:
1860
+ mean = np.mean(dat, axis=axis)
1861
+ if isinstance(mean, numbers.Number):
1862
+ mean = np.array([mean])
1863
+ dat_mean.append(mean)
1864
+ return self.deepcopy_with_new_data(dat_mean, remove_axes_index=axis)
1865
+
1866
+ def sum(self, axis: int = 0) -> DataWithAxes:
1867
+ """Process the sum of the data on the specified axis and returns the new data
1868
+
1869
+ Parameters
1870
+ ----------
1871
+ axis: int
1872
+
1873
+ Returns
1874
+ -------
1875
+ DataWithAxes
1876
+ """
1877
+ dat_sum = []
1878
+ for dat in self.data:
1879
+ dat_sum.append(np.sum(dat, axis=axis))
1880
+ return self.deepcopy_with_new_data(dat_sum, remove_axes_index=axis)
1881
+
1882
+ def interp(self, new_axis_data: Union[Axis, np.ndarray], **kwargs) -> DataWithAxes:
1883
+ """Performs linear interpolation for 1D data only.
1884
+
1885
+ For more complex ones, see :py:meth:`scipy.interpolate`
1886
+
1887
+ Parameters
1888
+ ----------
1889
+ new_axis_data: Union[Axis, np.ndarray]
1890
+ The coordinates over which to do the interpolation
1891
+ kwargs: dict
1892
+ extra named parameters to be passed to the :py:meth:`~numpy.interp` method
1893
+
1894
+ Returns
1895
+ -------
1896
+ DataWithAxes
1897
+
1898
+ See Also
1899
+ --------
1900
+ :py:meth:`~numpy.interp`
1901
+ :py:meth:`~scipy.interpolate`
1902
+ """
1903
+ if self.dim != DataDim['Data1D']:
1904
+ raise ValueError('For basic interpolation, only 1D data are supported')
1905
+
1906
+ data_interpolated = []
1907
+ axis_obj = self.get_axis_from_index(0)[0]
1908
+ if isinstance(new_axis_data, np.ndarray):
1909
+ new_axis_data = Axis(axis_obj.label, axis_obj.units, data=new_axis_data)
1910
+
1911
+ for dat in self.data:
1912
+ data_interpolated.append(np.interp(new_axis_data.get_data(), axis_obj.get_data(), dat,
1913
+ **kwargs))
1914
+ new_data = DataCalculated(f'{self.name}_interp', data=data_interpolated,
1915
+ axes=[new_axis_data],
1916
+ labels=self.labels)
1917
+ return new_data
1918
+
1919
+ def ft(self, axis: int = 0) -> DataWithAxes:
1920
+ """Process the Fourier Transform of the data on the specified axis and returns the new data
1921
+
1922
+ Parameters
1923
+ ----------
1924
+ axis: int
1925
+
1926
+ Returns
1927
+ -------
1928
+ DataWithAxes
1929
+
1930
+ See Also
1931
+ --------
1932
+ :py:meth:`~pymodaq.utils.math_utils.ft`, :py:meth:`~numpy.fft.fft`
1933
+ """
1934
+ dat_ft = []
1935
+ axis_obj = self.get_axis_from_index(axis)[0].copy()
1936
+ omega_grid, time_grid = mutils.ftAxis_time(len(axis_obj),
1937
+ np.abs(axis_obj.max() - axis_obj.min()))
1938
+ for dat in self.data:
1939
+ dat_ft.append(mutils.ft(dat, dim=axis))
1940
+ new_data = self.deepcopy_with_new_data(dat_ft)
1941
+ axis_obj = new_data.get_axis_from_index(axis)[0]
1942
+ axis_obj.data = omega_grid
1943
+ axis_obj.label = f'ft({axis_obj.label})'
1944
+ axis_obj.units = f'rad/{axis_obj.units}'
1945
+ return new_data
1946
+
1947
+ def ift(self, axis: int = 0) -> DataWithAxes:
1948
+ """Process the inverse Fourier Transform of the data on the specified axis and returns the
1949
+ new data
1950
+
1951
+ Parameters
1952
+ ----------
1953
+ axis: int
1954
+
1955
+ Returns
1956
+ -------
1957
+ DataWithAxes
1958
+
1959
+ See Also
1960
+ --------
1961
+ :py:meth:`~pymodaq.utils.math_utils.ift`, :py:meth:`~numpy.fft.ifft`
1962
+ """
1963
+ dat_ift = []
1964
+ axis_obj = self.get_axis_from_index(axis)[0].copy()
1965
+ omega_grid, time_grid = mutils.ftAxis_time(len(axis_obj),
1966
+ np.abs(axis_obj.max() - axis_obj.min()))
1967
+ for dat in self.data:
1968
+ dat_ift.append(mutils.ift(dat, dim=axis))
1969
+ new_data = self.deepcopy_with_new_data(dat_ift)
1970
+ axis_obj = new_data.get_axis_from_index(axis)[0]
1971
+ axis_obj.data = omega_grid
1972
+ axis_obj.label = f'ift({axis_obj.label})'
1973
+ axis_obj.units = str(Unit(f'rad/({axis_obj.units})'))
1974
+ return new_data
1975
+
1976
+ def fit(self, function: Callable, initial_guess: IterableType, data_index: int = None,
1977
+ axis_index: int = 0, **kwargs) -> DataCalculated:
1978
+ """ Apply 1D curve fitting using the scipy optimization package
1979
+
1980
+ Parameters
1981
+ ----------
1982
+ function: Callable
1983
+ a callable to be used for the fit
1984
+ initial_guess: Iterable
1985
+ The initial parameters for the fit
1986
+ data_index: int
1987
+ The index of the data over which to do the fit, if None apply the fit to all
1988
+ axis_index: int
1989
+ the axis index to use for the fit (if multiple) but there should be only one
1990
+ kwargs: dict
1991
+ extra named parameters applied to the curve_fit scipy method
1992
+
1993
+ Returns
1994
+ -------
1995
+ DataCalculated containing the evaluation of the fit on the specified axis
1996
+
1997
+ See Also
1998
+ --------
1999
+ :py:meth:`~scipy.optimize.curve_fit`
2000
+ """
2001
+ import scipy.optimize as opt
2002
+ if self.dim != DataDim['Data1D']:
2003
+ raise ValueError('Integrated fitting only works for 1D data')
2004
+ axis = self.get_axis_from_index(axis_index)[0].copy()
2005
+ axis_array = axis.get_data()
2006
+ if data_index is None:
2007
+ datalist_to_fit = self.data
2008
+ labels = [f'{label}_fit' for label in self.labels]
2009
+ else:
2010
+ datalist_to_fit = [self.data[data_index]]
2011
+ labels = [f'{self.labels[data_index]}_fit']
2012
+
2013
+ datalist_fitted = []
2014
+ fit_coeffs = []
2015
+ for data_array in datalist_to_fit:
2016
+ popt, pcov = opt.curve_fit(function, axis_array, data_array, p0=initial_guess, **kwargs)
2017
+ datalist_fitted.append(function(axis_array, *popt))
2018
+ fit_coeffs.append(popt)
2019
+
2020
+ return DataCalculated(f'{self.name}_fit', data=datalist_fitted,
2021
+ labels=labels,
2022
+ axes=[axis], fit_coeffs=fit_coeffs)
2023
+
2024
+ def find_peaks(self, height=None, threshold=None, **kwargs) -> DataToExport:
2025
+ """ Apply the scipy find_peaks method to 1D data
2026
+
2027
+ Parameters
2028
+ ----------
2029
+ height: number or ndarray or sequence, optional
2030
+ threshold: number or ndarray or sequence, optional
2031
+ kwargs: dict
2032
+ extra named parameters applied to the find_peaks scipy method
2033
+
2034
+ Returns
2035
+ -------
2036
+ DataCalculated
2037
+
2038
+ See Also
2039
+ --------
2040
+ :py:meth:`~scipy.optimize.find_peaks`
2041
+ """
2042
+ if self.dim != DataDim['Data1D']:
2043
+ raise ValueError('Finding peaks only works for 1D data')
2044
+ from scipy.signal import find_peaks
2045
+ peaks_indices = []
2046
+ dte = DataToExport('peaks')
2047
+ for ind in range(len(self)):
2048
+ peaks, properties = find_peaks(self[ind], height, threshold, **kwargs)
2049
+ peaks_indices.append(peaks)
2050
+
2051
+ dte.append(DataCalculated(f'{self.labels[ind]}',
2052
+ data=[self[ind][peaks_indices[-1]],
2053
+ peaks_indices[-1]
2054
+ ],
2055
+ labels=['peak value', 'peak indexes'],
2056
+ axes=[Axis('peak position', self.axes[0].units,
2057
+ data=self.axes[0].get_data_at(peaks_indices[-1]))])
2058
+ )
2059
+ return dte
2060
+
2061
+ def get_dim_from_data_axes(self) -> DataDim:
2062
+ """Get the dimensionality DataDim from data taking into account nav indexes
2063
+ """
2064
+ if len(self.axes) != len(self.shape):
2065
+ self._dim = self.get_dim_from_data(self.data)
2066
+ else:
2067
+ if len(self.nav_indexes) > 0:
2068
+ self._dim = DataDim['DataND']
2069
+ else:
2070
+ if len(self.axes) == 0:
2071
+ self._dim = DataDim['Data0D']
2072
+ elif len(self.axes) == 1:
2073
+ self._dim = DataDim['Data1D']
2074
+ elif len(self.axes) == 2:
2075
+ self._dim = DataDim['Data2D']
2076
+ if len(self.nav_indexes) > 0:
2077
+ self._dim = DataDim['DataND']
2078
+ return self._dim
2079
+
2080
+ @property
2081
+ def n_axes(self):
2082
+ """Get the number of axes (even if not specified)"""
2083
+ return len(self.axes)
2084
+
2085
+ @property
2086
+ def axes(self):
2087
+ """convenience property to fetch attribute from axis_manager"""
2088
+ return self._am.axes
2089
+
2090
+ @axes.setter
2091
+ def axes(self, axes: List[Axis]):
2092
+ """convenience property to set attribute from axis_manager"""
2093
+ self.set_axes_manager(self.shape, axes=axes, nav_indexes=self.nav_indexes)
2094
+
2095
+ def axes_limits(self, axes_indexes: List[int] = None) -> List[Tuple[float, float]]:
2096
+ """Get the limits of specified axes (all if axes_indexes is None)"""
2097
+ if axes_indexes is None:
2098
+ return [(axis.min(), axis.max()) for axis in self.axes]
2099
+ else:
2100
+ return [(axis.min(), axis.max()) for axis in self.axes if axis.index in axes_indexes]
2101
+
2102
+ @property
2103
+ def sig_indexes(self):
2104
+ """convenience property to fetch attribute from axis_manager"""
2105
+ return self._am.sig_indexes
2106
+
2107
+ @property
2108
+ def nav_indexes(self):
2109
+ """convenience property to fetch attribute from axis_manager"""
2110
+ return self._am.nav_indexes
2111
+
2112
+ @nav_indexes.setter
2113
+ def nav_indexes(self, indexes: List[int]):
2114
+ """create new axis manager with new navigation indexes"""
2115
+ self.set_axes_manager(self.shape, axes=self.axes, nav_indexes=indexes)
2116
+ self.get_dim_from_data_axes()
2117
+
2118
+ def get_nav_axes(self) -> List[Axis]:
2119
+ return self._am.get_nav_axes()
2120
+
2121
+ def get_sig_index(self) -> List[Axis]:
2122
+ return self._am.get_signal_axes()
2123
+
2124
+ def get_nav_axes_with_data(self) -> List[Axis]:
2125
+ """Get the data's navigation axes making sure there is data in the data field"""
2126
+ axes = self.get_nav_axes()
2127
+ for axis in axes:
2128
+ if axis.get_data() is None:
2129
+ axis.create_linear_data(self.shape[axis.index])
2130
+ return axes
2131
+
2132
+ def get_axis_indexes(self) -> List[int]:
2133
+ """Get all present different axis indexes"""
2134
+ return sorted(list(set([axis.index for axis in self.axes])))
2135
+
2136
+ def get_axis_from_index(self, index, create=False):
2137
+ return self._am.get_axis_from_index(index, create)
2138
+
2139
+ def get_axis_from_index_spread(self, index: int, spread: int):
2140
+ return self._am.get_axis_from_index_spread(index, spread)
2141
+
2142
+ def get_axis_from_label(self, label: str) -> Axis:
2143
+ """Get the axis referred by a given label
2144
+
2145
+ Parameters
2146
+ ----------
2147
+ label: str
2148
+ The label of the axis
2149
+
2150
+ Returns
2151
+ -------
2152
+ Axis or None: return the axis instance if it has the right label else None
2153
+ """
2154
+ for axis in self.axes:
2155
+ if axis.label == label:
2156
+ return axis
2157
+
2158
+ def create_missing_axes(self):
2159
+ """Check if given the data shape, some axes are missing to properly define the data
2160
+ (especially for plotting)"""
2161
+ axes = self.axes[:]
2162
+ for index in self.nav_indexes + self.sig_indexes:
2163
+ if (len(self.get_axis_from_index(index)) != 0 and
2164
+ self.get_axis_from_index(index)[0] is None):
2165
+ axes_tmp = self.get_axis_from_index(index, create=True)
2166
+ for ax in axes_tmp:
2167
+ if ax.size > 1:
2168
+ axes.append(ax)
2169
+ self.axes = axes
2170
+
2171
    def _compute_slices(self, slices, is_navigation=True):
        """Compute the total slice to apply to the data

        Filling in Ellipsis when no slicing should be done

        Parameters
        ----------
        slices: number, slice or iterable of those
            one slicing item per navigation index (if is_navigation) or per signal index
            (otherwise), consumed in increasing dimension order
        is_navigation: bool
            if True the items are matched with the navigation indexes, else with the
            signal ones

        Returns
        -------
        tuple: an indexing tuple usable directly on each data array

        Raises
        ------
        IndexError: if fewer slicing items than matched indexes are given (pop on empty list)
        """
        # a single number or slice is promoted to a one-element list
        if isinstance(slices, numbers.Number) or isinstance(slices, slice):
            slices = [slices]
        # the array dimensions that will receive one of the given slicing items
        if is_navigation:
            indexes = self._am.nav_indexes
        else:
            indexes = self._am.sig_indexes
        total_slices = []
        slices = list(slices)  # new list, so pop(0) below does not mutate the caller's sequence
        for ind in range(len(self.shape)):
            if ind in indexes:
                # sliced dimension: consume the next user-provided item
                total_slices.append(slices.pop(0))
            elif len(total_slices) == 0:
                # leading untouched dimension(s): covered by a single Ellipsis
                total_slices.append(Ellipsis)
            elif not (Ellipsis in total_slices and total_slices[-1] is Ellipsis):
                # untouched dimension after a sliced one: keep it whole explicitly.
                # Untouched dimensions directly following the leading Ellipsis are skipped
                # because that Ellipsis already spans them.
                total_slices.append(slice(None))
        total_slices = tuple(total_slices)
        return total_slices
2193
+
2194
+ def check_squeeze(self, total_slices: List[slice], is_navigation: bool):
2195
+
2196
+ do_squeeze = True
2197
+ if 1 in self.data[0][total_slices].shape:
2198
+ if not is_navigation and self.data[0][total_slices].shape.index(1) in self.nav_indexes:
2199
+ do_squeeze = False
2200
+ elif is_navigation and self.data[0][total_slices].shape.index(1) in self.sig_indexes:
2201
+ do_squeeze = False
2202
+ return do_squeeze
2203
+
2204
    def _slicer(self, slices, is_navigation=True):
        """Apply a given slice to the data either navigation or signal dimension

        Parameters
        ----------
        slices: tuple of slice or int
            the slices to apply to the data
        is_navigation: bool
            if True apply the slices to the navigation dimension else to the signal ones

        Returns
        -------
        DataWithAxes
            Object of the same type as the initial data, derived from DataWithAxes. But with lower
            data size due to the slicing and with eventually less axes.
        """

        if isinstance(slices, numbers.Number) or isinstance(slices, slice):
            slices = [slices]
        # full indexing tuple (with Ellipsis / slice(None) for the untouched dimensions)
        total_slices = self._compute_slices(slices, is_navigation)

        do_squeeze = self.check_squeeze(total_slices, is_navigation)
        new_arrays_data = [squeeze(dat[total_slices], do_squeeze) for dat in self.data]
        # axes of the dimensions that are NOT sliced: they are kept as deep copies
        tmp_axes = self._am.get_signal_axes() if is_navigation else self._am.get_nav_axes()
        axes_to_append = [copy.deepcopy(axis) for axis in tmp_axes]

        # axes_to_append are the axes to append to the new produced data
        # (basically the ones to keep)

        indexes_to_get = self.nav_indexes if is_navigation else self.sig_indexes
        # indexes_to_get are the indexes of the axes where the slice should be applied

        _indexes = list(self.nav_indexes)
        _indexes.extend(self.sig_indexes)
        lower_indexes = dict(zip(_indexes, [0 for _ in range(len(_indexes))]))
        # lower_indexes stores, for each *axis index*, how much that index must be reduced
        # because data dimensions located before it have been removed by the slicing

        axes = []
        # when slicing navigation, surviving navigation indexes are collected below;
        # when slicing signal, the navigation indexes are kept as they are
        nav_indexes = [] if is_navigation else list(self._am.nav_indexes)
        for ind_slice, _slice in enumerate(slices):
            # slicing items beyond the number of matched indexes are ignored
            if ind_slice < len(indexes_to_get):
                ax = self._am.get_axis_from_index(indexes_to_get[ind_slice])
                # dimensions without an Axis are skipped here (their data is still sliced)
                if len(ax) != 0 and ax[0] is not None:
                    for ind in range(len(ax)):
                        ax[ind] = ax[ind].iaxis[_slice]

                    if not(ax[0] is None or ax[0].size <= 1):  # means the slice kept part of the axis
                        if is_navigation:
                            nav_indexes.append(self._am.nav_indexes[ind_slice])
                        axes.extend(ax)
                    else:
                        for axis in axes_to_append:  # means we removed one of the axes (and data dim),
                            # hence axis index above current index should be lowered by 1
                            # NOTE(review): this counts once per *axis*; if several kept axes
                            # share an index (e.g. spread data) that index is lowered several
                            # times - confirm whether this is intended
                            if axis.index > indexes_to_get[ind_slice]:
                                lower_indexes[axis.index] += 1
                        # the remaining sliced indexes located after this one shift down too
                        for index in indexes_to_get[ind_slice+1:]:
                            lower_indexes[index] += 1

        # re-index every kept axis and navigation index after dimension removals
        axes.extend(axes_to_append)
        for axis in axes:
            axis.index -= lower_indexes[axis.index]
        for ind in range(len(nav_indexes)):
            nav_indexes[ind] -= lower_indexes[nav_indexes[ind]]

        # without navigation dimension left, the result is considered uniform data
        if len(nav_indexes) != 0:
            distribution = self.distribution
        else:
            distribution = DataDistribution['uniform']

        data = DataWithAxes(self.name, data=new_arrays_data, nav_indexes=tuple(nav_indexes),
                            axes=axes,
                            source='calculated', origin=self.origin,
                            labels=self.labels[:],
                            distribution=distribution)
        return data
2281
+
2282
+ def deepcopy_with_new_data(self, data: List[np.ndarray] = None,
2283
+ remove_axes_index: Union[int, List[int]] = None,
2284
+ source: DataSource = 'calculated',
2285
+ keep_dim=False) -> DataWithAxes:
2286
+ """deepcopy without copying the initial data (saving memory)
2287
+
2288
+ The new data, may have some axes stripped as specified in remove_axes_index
2289
+
2290
+ Parameters
2291
+ ----------
2292
+ data: list of numpy ndarray
2293
+ The new data
2294
+ remove_axes_index: tuple of int
2295
+ indexes of the axis to be removed
2296
+ source: DataSource
2297
+ keep_dim: bool
2298
+ if False (the default) will calculate the new dim based on the data shape
2299
+ else keep the same (be aware it could lead to issues)
2300
+
2301
+ Returns
2302
+ -------
2303
+ DataWithAxes
2304
+ """
2305
+ try:
2306
+ old_data = self.data
2307
+ self._data = None
2308
+ new_data = self.deepcopy()
2309
+ new_data._data = data
2310
+ new_data.get_dim_from_data(data)
2311
+
2312
+ if source is not None:
2313
+ source = enum_checker(DataSource, source)
2314
+ new_data._source = source
2315
+
2316
+ if remove_axes_index is not None:
2317
+ if not isinstance(remove_axes_index, Iterable):
2318
+ remove_axes_index = [remove_axes_index]
2319
+
2320
+ lower_indexes = dict(zip(new_data.get_axis_indexes(),
2321
+ [0 for _ in range(len(new_data.get_axis_indexes()))]))
2322
+ # lower_indexes will store for each *axis index* how much the index should be reduced because one axis has
2323
+ # been removed
2324
+
2325
+ nav_indexes = list(new_data.nav_indexes)
2326
+ sig_indexes = list(new_data.sig_indexes)
2327
+ for index in remove_axes_index:
2328
+ for axis in new_data.get_axis_from_index(index):
2329
+ if axis is not None:
2330
+ new_data.axes.remove(axis)
2331
+
2332
+ if index in new_data.nav_indexes:
2333
+ nav_indexes.pop(nav_indexes.index(index))
2334
+ if index in new_data.sig_indexes:
2335
+ sig_indexes.pop(sig_indexes.index(index))
2336
+
2337
+ # for ind, nav_ind in enumerate(nav_indexes):
2338
+ # if nav_ind > index and nav_ind not in remove_axes_index:
2339
+ # nav_indexes[ind] -= 1
2340
+
2341
+ # for ind, sig_ind in enumerate(sig_indexes):
2342
+ # if sig_ind > index:
2343
+ # sig_indexes[ind] -= 1
2344
+ for axis in new_data.axes:
2345
+ if axis.index > index and axis.index not in remove_axes_index:
2346
+ lower_indexes[axis.index] += 1
2347
+
2348
+ for axis in new_data.axes:
2349
+ axis.index -= lower_indexes[axis.index]
2350
+ for ind in range(len(nav_indexes)):
2351
+ nav_indexes[ind] -= lower_indexes[nav_indexes[ind]]
2352
+
2353
+ new_data.nav_indexes = tuple(nav_indexes)
2354
+ # new_data._am.sig_indexes = tuple(sig_indexes)
2355
+
2356
+ new_data._shape = data[0].shape
2357
+ if not keep_dim:
2358
+ new_data._dim = self._get_dim_from_data(data)
2359
+ return new_data
2360
+
2361
+ except Exception as e:
2362
+ pass
2363
+ finally:
2364
+ self._data = old_data
2365
+
2366
+ @property
2367
+ def _am(self) -> AxesManagerBase:
2368
+ return self.axes_manager
2369
+
2370
+ def get_data_dimension(self) -> str:
2371
+ return str(self._am)
2372
+
2373
+ def get_data_as_dwa(self, index: int = 0) -> DataWithAxes:
2374
+ """ Get the underlying data selected from the list at index, returned as a DataWithAxes"""
2375
+ return self.deepcopy_with_new_data([self[index]])
2376
+
2377
+
2378
class DataRaw(DataWithAxes):
    """DataWithAxes whose source is always 'raw'. To be used for raw (unprocessed) data.

    Any ``source`` keyword given by the caller is discarded.
    """
    def __init__(self, *args, **kwargs):
        kwargs.pop('source', None)
        super().__init__(*args, source=DataSource['raw'], **kwargs)
2384
+
2385
+
2386
class DataCalculated(DataWithAxes):
    """DataWithAxes whose source is always 'calculated'. To be used for processed/calculated data.

    Any ``source`` keyword given by the caller is discarded.
    """
    def __init__(self, *args, axes=None, **kwargs):
        kwargs.pop('source', None)
        # a fresh list per instance: a mutable default would be shared by every instance
        if axes is None:
            axes = []
        super().__init__(*args, source=DataSource['calculated'], axes=axes, **kwargs)
2392
+
2393
+
2394
class DataFromRoi(DataCalculated):
    """DataCalculated (source 'calculated') for data processed from a region of interest."""
    def __init__(self, *args, axes=None, **kwargs):
        # a fresh list per instance: a mutable default would be shared by every instance
        if axes is None:
            axes = []
        super().__init__(*args, axes=axes, **kwargs)
2398
+
2399
+
2400
+ class DataToExport(DataLowLevel):
2401
+ """Object to store all raw and calculated DataWithAxes data for later exporting, saving, sending signal...
2402
+
2403
+ Includes methods to retrieve data from dim, source...
2404
+ Stored data have a unique identifier their name. If some data is appended with an existing name, it will replace
2405
+ the existing data. So if you want to append data that has the same name
2406
+
2407
+ Parameters
2408
+ ----------
2409
+ name: str
2410
+ The identifier of the exporting object
2411
+ data: list of DataWithAxes
2412
+ All the raw and calculated data to be exported
2413
+
2414
+ Attributes
2415
+ ----------
2416
+ name
2417
+ timestamp
2418
+ data
2419
+ """
2420
+
2421
+ def __init__(self, name: str, data: List[DataWithAxes] = [], **kwargs):
2422
+ """
2423
+
2424
+ Parameters
2425
+ ----------
2426
+ name
2427
+ data
2428
+ """
2429
+ super().__init__(name)
2430
+ if not isinstance(data, list):
2431
+ raise TypeError('Data stored in a DataToExport object should be as a list of objects'
2432
+ ' inherited from DataWithAxis')
2433
+ self._data = []
2434
+
2435
+ self.data = data
2436
+ for key in kwargs:
2437
+ setattr(self, key, kwargs[key])
2438
+
2439
+ def plot(self, plotter_backend: str = config('plotting', 'backend'), *args, **kwargs):
2440
+ """ Call a plotter factory and its plot method over the actual data"""
2441
+ return plotter_factory.get(plotter_backend).plot(self, *args, **kwargs)
2442
+
2443
+ def affect_name_to_origin_if_none(self):
2444
+ """Affect self.name to all DataWithAxes children's attribute origin if this origin is not defined"""
2445
+ for dat in self.data:
2446
+ if dat.origin is None or dat.origin == '':
2447
+ dat.origin = self.name
2448
+
2449
+ def __sub__(self, other: object):
2450
+ if isinstance(other, DataToExport) and len(other) == len(self):
2451
+ new_data = copy.deepcopy(self)
2452
+ for ind_dfp in range(len(self)):
2453
+ new_data[ind_dfp] = self[ind_dfp] - other[ind_dfp]
2454
+ return new_data
2455
+ else:
2456
+ raise TypeError(f'Could not substract a {other.__class__.__name__} or a {self.__class__.__name__} '
2457
+ f'of a different length')
2458
+
2459
+ def __add__(self, other: object):
2460
+ if isinstance(other, DataToExport) and len(other) == len(self):
2461
+ new_data = copy.deepcopy(self)
2462
+ for ind_dfp in range(len(self)):
2463
+ new_data[ind_dfp] = self[ind_dfp] + other[ind_dfp]
2464
+ return new_data
2465
+ else:
2466
+ raise TypeError(f'Could not add a {other.__class__.__name__} or a {self.__class__.__name__} '
2467
+ f'of a different length')
2468
+
2469
+ def __mul__(self, other: object):
2470
+ if isinstance(other, numbers.Number):
2471
+ new_data = copy.deepcopy(self)
2472
+ for ind_dfp in range(len(self)):
2473
+ new_data[ind_dfp] = self[ind_dfp] * other
2474
+ return new_data
2475
+ else:
2476
+ raise TypeError(f'Could not multiply a {other.__class__.__name__} with a {self.__class__.__name__} '
2477
+ f'of a different length')
2478
+
2479
+ def __truediv__(self, other: object):
2480
+ if isinstance(other, numbers.Number):
2481
+ return self * (1 / other)
2482
+ else:
2483
+ raise TypeError(f'Could not divide a {other.__class__.__name__} with a {self.__class__.__name__} '
2484
+ f'of a different length')
2485
+
2486
+ def average(self, other: DataToExport, weight: int) -> DataToExport:
2487
+ """ Compute the weighted average between self and other DataToExport and attributes it to self
2488
+
2489
+ Parameters
2490
+ ----------
2491
+ other: DataToExport
2492
+ weight: int
2493
+ The weight the 'other_data' holds with respect to self
2494
+
2495
+ """
2496
+ if isinstance(other, DataToExport) and len(other) == len(self):
2497
+ new_data = copy.copy(self)
2498
+ for ind_dfp in range(len(self)):
2499
+ new_data[ind_dfp] = self[ind_dfp].average(other[ind_dfp], weight)
2500
+ return new_data
2501
+ else:
2502
+ raise TypeError(f'Could not average a {other.__class__.__name__} with a {self.__class__.__name__} '
2503
+ f'of a different length')
2504
+
2505
+ def merge_as_dwa(self, dim: Union[str, DataDim], name: str = None) -> DataRaw:
2506
+ """ attempt to merge filtered dwa into one
2507
+
2508
+ Only possible if all filtered dwa and underlying data have same shape
2509
+
2510
+ Parameters
2511
+ ----------
2512
+ dim: DataDim or str
2513
+ will only try to merge dwa having this dimensionality
2514
+ name: str
2515
+ The new name of the returned dwa
2516
+ """
2517
+ dim = enum_checker(DataDim, dim)
2518
+
2519
+ filtered_data = self.get_data_from_dim(dim)
2520
+ if len(filtered_data) != 0:
2521
+ dwa = filtered_data[0].deepcopy()
2522
+ for dwa_tmp in filtered_data[1:]:
2523
+ if dwa_tmp.shape == dwa.shape and dwa_tmp.distribution == dwa.distribution:
2524
+ dwa.append(dwa_tmp)
2525
+ if name is None:
2526
+ name = self.name
2527
+ dwa.name = name
2528
+ return dwa
2529
+
2530
+ def __repr__(self):
2531
+ repr = f'{self.__class__.__name__}: {self.name} <len:{len(self)}>\n'
2532
+ for dwa in self:
2533
+ repr += f' * {str(dwa)}\n'
2534
+ return repr
2535
+
2536
+ def __len__(self):
2537
+ return len(self.data)
2538
+
2539
+ def __iter__(self):
2540
+ self._iter_index = 0
2541
+ return self
2542
+
2543
+ def __next__(self) -> DataWithAxes:
2544
+ if self._iter_index < len(self):
2545
+ self._iter_index += 1
2546
+ return self.data[self._iter_index-1]
2547
+ else:
2548
+ raise StopIteration
2549
+
2550
+ def __getitem__(self, item) -> Union[DataWithAxes, DataToExport]:
2551
+ if isinstance(item, int) and 0 <= item < len(self):
2552
+ return self.data[item]
2553
+ elif isinstance(item, slice):
2554
+ return DataToExport(self.name, data=[self[ind] for ind in list(range(len(self))[item])])
2555
+ else:
2556
+ raise IndexError(f'The index should be a positive integer lower than the data length')
2557
+
2558
+ def __setitem__(self, key, value: DataWithAxes):
2559
+ if isinstance(key, int) and 0 <= key < len(self) and isinstance(value, DataWithAxes):
2560
+ self.data[key] = value
2561
+ else:
2562
+ raise IndexError(f'The index should be a positive integer lower than the data length')
2563
+
2564
+ def get_names(self, dim: DataDim = None) -> List[str]:
2565
+ """Get the names of the stored DataWithAxes, eventually filtered by dim
2566
+
2567
+ Parameters
2568
+ ----------
2569
+ dim: DataDim or str
2570
+
2571
+ Returns
2572
+ -------
2573
+ list of str: the names of the (filtered) DataWithAxes data
2574
+ """
2575
+ if dim is None:
2576
+ return [data.name for data in self.data]
2577
+ else:
2578
+ return [data.name for data in self.get_data_from_dim(dim).data]
2579
+
2580
+ def get_full_names(self, dim: DataDim = None):
2581
+ """Get the ful names including the origin attribute into the returned value, eventually filtered by dim
2582
+
2583
+ Parameters
2584
+ ----------
2585
+ dim: DataDim or str
2586
+
2587
+ Returns
2588
+ -------
2589
+ list of str: the names of the (filtered) DataWithAxes data constructed as : origin/name
2590
+
2591
+ Examples
2592
+ --------
2593
+ d0 = DataWithAxes(name='datafromdet0', origin='det0')
2594
+ """
2595
+ if dim is None:
2596
+ return [data.get_full_name() for data in self.data]
2597
+ else:
2598
+ return [data.get_full_name() for data in self.get_data_from_dim(dim).data]
2599
+
2600
+ def get_origins(self, dim: DataDim = None):
2601
+ """Get the origins of the underlying data into the returned value, eventually filtered by dim
2602
+
2603
+ Parameters
2604
+ ----------
2605
+ dim: DataDim or str
2606
+
2607
+ Returns
2608
+ -------
2609
+ list of str: the origins of the (filtered) DataWithAxes data
2610
+
2611
+ Examples
2612
+ --------
2613
+ d0 = DataWithAxes(name='datafromdet0', origin='det0')
2614
+ """
2615
+ if dim is None:
2616
+ return list({dwa.origin for dwa in self.data})
2617
+ else:
2618
+ return list({dwa.origin for dwa in self.get_data_from_dim(dim).data})
2619
+
2620
+
2621
+ def get_data_from_full_name(self, full_name: str, deepcopy=False) -> DataWithAxes:
2622
+ """Get the DataWithAxes with matching full name"""
2623
+ if deepcopy:
2624
+ data = self.get_data_from_name_origin(full_name.split('/')[1], full_name.split('/')[0]).deepcopy()
2625
+ else:
2626
+ data = self.get_data_from_name_origin(full_name.split('/')[1], full_name.split('/')[0])
2627
+ return data
2628
+
2629
+ def get_data_from_full_names(self, full_names: List[str], deepcopy=False) -> DataToExport:
2630
+ data = [self.get_data_from_full_name(full_name, deepcopy) for full_name in full_names]
2631
+ return DataToExport(name=self.name, data=data)
2632
+
2633
+ def get_dim_presents(self) -> List[str]:
2634
+ dims = []
2635
+ for dim in DataDim.names():
2636
+ if len(self.get_data_from_dim(dim)) != 0:
2637
+ dims.append(dim)
2638
+
2639
+ return dims
2640
+
2641
+ def get_data_from_source(self, source: DataSource, deepcopy=False) -> DataToExport:
2642
+ """Get the data matching the given DataSource
2643
+
2644
+ Returns
2645
+ -------
2646
+ DataToExport: filtered with data matching the dimensionality
2647
+ """
2648
+ source = enum_checker(DataSource, source)
2649
+ return self.get_data_from_attribute('source', source, deepcopy=deepcopy)
2650
+
2651
+ def get_data_from_missing_attribute(self, attribute: str, deepcopy=False) -> DataToExport:
2652
+ """ Get the data matching a given attribute value
2653
+
2654
+ Parameters
2655
+ ----------
2656
+ attribute: str
2657
+ a string of a possible attribute
2658
+ deepcopy: bool
2659
+ if True the returned DataToExport will contain deepcopies of the DataWithAxes
2660
+ Returns
2661
+ -------
2662
+ DataToExport: filtered with data missing the given attribute
2663
+ """
2664
+ if deepcopy:
2665
+ return DataToExport(self.name, data=[dwa.deepcopy() for dwa in self if not hasattr(dwa, attribute)])
2666
+ else:
2667
+ return DataToExport(self.name, data=[dwa for dwa in self if not hasattr(dwa, attribute)])
2668
+
2669
+ def get_data_from_attribute(self, attribute: str, attribute_value: Any, deepcopy=False) -> DataToExport:
2670
+ """Get the data matching a given attribute value
2671
+
2672
+ Returns
2673
+ -------
2674
+ DataToExport: filtered with data matching the attribute presence and value
2675
+ """
2676
+ selection = find_objects_in_list_from_attr_name_val(self.data, attribute, attribute_value,
2677
+ return_first=False)
2678
+ selection.sort(key=lambda elt: elt[0].name)
2679
+ if deepcopy:
2680
+ data = [sel[0].deepcopy() for sel in selection]
2681
+ else:
2682
+ data = [sel[0] for sel in selection]
2683
+ return DataToExport(name=self.name, data=data)
2684
+
2685
+ def get_data_from_dim(self, dim: DataDim, deepcopy=False) -> DataToExport:
2686
+ """Get the data matching the given DataDim
2687
+
2688
+ Returns
2689
+ -------
2690
+ DataToExport: filtered with data matching the dimensionality
2691
+ """
2692
+ dim = enum_checker(DataDim, dim)
2693
+ return self.get_data_from_attribute('dim', dim, deepcopy=deepcopy)
2694
+
2695
+ def get_data_from_dims(self, dims: List[DataDim], deepcopy=False) -> DataToExport:
2696
+ """Get the data matching the given DataDim
2697
+
2698
+ Returns
2699
+ -------
2700
+ DataToExport: filtered with data matching the dimensionality
2701
+ """
2702
+ data = DataToExport(name=self.name)
2703
+ for dim in dims:
2704
+ data.append(self.get_data_from_dim(dim, deepcopy=deepcopy))
2705
+ return data
2706
+
2707
+ def get_data_from_sig_axes(self, Naxes: int, deepcopy: bool = False) -> DataToExport:
2708
+ """Get the data matching the given number of signal axes
2709
+
2710
+ Parameters
2711
+ ----------
2712
+ Naxes: int
2713
+ Number of signal axes in the DataWithAxes objects
2714
+
2715
+ Returns
2716
+ -------
2717
+ DataToExport: filtered with data matching the number of signal axes
2718
+ """
2719
+ data = DataToExport(name=self.name)
2720
+ for _data in self:
2721
+ if len(_data.sig_indexes) == Naxes:
2722
+ if deepcopy:
2723
+ data.append(_data.deepcopy())
2724
+ else:
2725
+ data.append(_data)
2726
+ return data
2727
+
2728
def get_data_from_Naxes(self, Naxes: int, deepcopy: bool = False) -> DataToExport:
    """Get the data matching the given number of axes

    Parameters
    ----------
    Naxes: int
        Number of axes, i.e. `len(shape)`, of the DataWithAxes objects.
    deepcopy: bool
        Kept for backward compatibility. The result is filled through `append`, which already
        stores a deep copy of every DataWithAxes it receives, so the returned data are always
        copies, whatever the value of this flag.

    Returns
    -------
    DataToExport: filtered with data matching the number of axes
    """
    data = DataToExport(name=self.name)
    for dwa in self:
        if len(dwa.shape) == Naxes:
            # No explicit deepcopy here: `append` deep-copies the object itself, so copying
            # beforehand (the previous behaviour when deepcopy=True) only duplicated the work.
            data.append(dwa)
    return data
2748
+
2749
def get_data_with_naxes_lower_than(self, n_axes=2, deepcopy: bool = False) -> DataToExport:
    """Get the data whose number of axes is lower than or equal to the given number

    Parameters
    ----------
    n_axes: int
        Maximum (inclusive) value of the `n_axes` attribute of the selected DataWithAxes
        objects. Defaults to 2.
    deepcopy: bool
        Kept for backward compatibility. The result is filled through `append`, which already
        stores a deep copy of every DataWithAxes it receives, so the returned data are always
        copies, whatever the value of this flag.

    Returns
    -------
    DataToExport: filtered with data having at most n_axes axes
    """
    data = DataToExport(name=self.name)
    for dwa in self:
        # inclusive comparison: data with exactly n_axes axes are kept
        if dwa.n_axes <= n_axes:
            # No explicit deepcopy here: `append` deep-copies the object itself, so copying
            # beforehand (the previous behaviour when deepcopy=True) only duplicated the work.
            data.append(dwa)
    return data
2769
+
2770
def get_data_from_name(self, name: str) -> DataWithAxes:
    """Return the first contained DataWithAxes whose `name` attribute equals name

    Returns
    -------
    DataWithAxes or None: the object found by `find_objects_in_list_from_attr_name_val`.
        Its first returned element is None when nothing matches.
    """
    found_dwa, _position = find_objects_in_list_from_attr_name_val(self.data, 'name', name,
                                                                   return_first=True)
    return found_dwa
2774
+
2775
def get_data_from_names(self, names: List[str]) -> DataToExport:
    """Build a DataToExport, named like this object, from the data whose name is in names

    Parameters
    ----------
    names: List[str]
        The names to keep. Membership is tested with `in`.

    Returns
    -------
    DataToExport: the selected data, in the order they appear in this object
    """
    kept = list(filter(lambda dwa: dwa.name in names, self))
    return DataToExport(self.name, data=kept)
2777
+
2778
def get_data_from_name_origin(self, name: str, origin: str = '') -> DataWithAxes:
    """Return the contained DataWithAxes matching both name and origin

    Parameters
    ----------
    name: str
        Value of the `name` attribute to look for.
    origin: str
        Value of the `origin` attribute to look for. With the default empty string, the first
        object matching `name` is returned whatever its origin.

    Returns
    -------
    DataWithAxes or None: the matching object, None when nothing matches
    """
    if origin == '':
        first_match, _ = find_objects_in_list_from_attr_name_val(self.data, 'name', name,
                                                                 return_first=True)
        return first_match

    # Narrow down to the objects sharing the name, then pick the one with the right origin.
    name_matches = find_objects_in_list_from_attr_name_val(self.data, 'name', name,
                                                           return_first=False)
    candidates = [pair[0] for pair in name_matches]
    origin_match, _ = find_objects_in_list_from_attr_name_val(candidates, 'origin', origin)
    return origin_match
2787
+
2788
def index(self, data: DataWithAxes):
    """Return the position of data within the contained list (`list.index` semantics)

    Raises
    ------
    ValueError: if data is not in the list
    """
    stored = self.data
    return stored.index(data)
2790
+
2791
def index_from_name_origin(self, name: str, origin: str = '') -> int:
    """Get the index, within self.data, of the DataWithAxes matching name and origin

    Parameters
    ----------
    name: str
        Value of the `name` attribute to look for.
    origin: str
        Value of the `origin` attribute to look for. With the default empty string, the first
        object matching `name` is used whatever its origin.

    Returns
    -------
    int: a position within self.data, usable with `pop`. When nothing matches, both branches
        return the "not found" index produced by `find_objects_in_list_from_attr_name_val`.

    See Also
    --------
    pop
    """
    if origin == '':
        _, index = find_objects_in_list_from_attr_name_val(self.data, 'name', name, return_first=True)
        return index

    selection = find_objects_in_list_from_attr_name_val(self.data, 'name', name, return_first=False)
    data_selection = [sel[0] for sel in selection]
    index_selection = [sel[1] for sel in selection]
    obj, index = find_objects_in_list_from_attr_name_val(data_selection, 'origin', origin)
    if obj is None:
        # No object with this origin: `index` does not designate an element of data_selection.
        # Mapping it through index_selection would select an unrelated element (e.g. if the
        # not-found value is -1) or raise IndexError on an empty selection, so the helper's
        # not-found value is returned, as in the origin == '' branch.
        return index
    # `index` is relative to data_selection: map it back to a position within self.data.
    return index_selection[index]
2803
+
2804
def pop(self, index: int) -> DataWithAxes:
    """Remove the DataWithAxes stored at the given position and return it

    Parameters
    ----------
    index: int
        A position within the contained list, for instance as returned by
        self.index_from_name_origin.

    See Also
    --------
    index_from_name_origin
    """
    removed = self.data.pop(index)
    return removed
2817
+
2818
def remove(self, dwa: DataWithAxes):
    """Remove dwa from the contained list and return it

    The position is located with `list.index` and the removal is delegated to `self.pop`.
    """
    position = self.data.index(dwa)
    return self.pop(position)
2820
+
2821
@property
def data(self) -> List[DataWithAxes]:
    """List[DataWithAxes]: the data held by this object (the internal list itself, not a copy)"""
    internal_list = self._data
    return internal_list
2825
+
2826
@data.setter
def data(self, new_data: List[DataWithAxes]):
    """Replace the held data with the elements of new_data

    Every element is first checked to be a DataWithAxes (TypeError otherwise), so the held
    data are left untouched if any element is invalid. The elements are stored in the
    existing internal list (shallow copy), so later changes to the caller's list are not
    reflected here. Finally, `affect_name_to_origin_if_none` is called.
    """
    # Materialize once: an iterator/generator would otherwise be exhausted by the type-check
    # loop below, and an empty list would then be stored.
    new_data = list(new_data)
    for dat in new_data:
        self._check_data_type(dat)
    # In-place slice assignment of a fresh list: a shallow copy, so that changes to the
    # original list are not applied in here.
    self._data[:] = new_data

    self.affect_name_to_origin_if_none()
2834
+
2835
@staticmethod
def _check_data_type(data: DataWithAxes):
    """Make sure data is a DataWithAxes object or inherited

    Raises
    ------
    TypeError: if data is not an instance of DataWithAxes (or of one of its subclasses)
    """
    if not isinstance(data, DataWithAxes):
        raise TypeError('Data stored in a DataToExport object should be objects inherited from DataWithAxes')
2840
+
2841
def deepcopy(self):
    """Return a new DataToExport named 'Copy' holding a deep copy of every contained DataWithAxes"""
    copies = []
    for dwa in self:
        copies.append(dwa.deepcopy())
    return DataToExport('Copy', data=copies)
2843
+
2844
@dispatch(list)
def append(self, data_list: List[DataWithAxes]):
    """Append every element of data_list, one at a time, through the dispatched `append`"""
    for element in data_list:
        self.append(element)

@dispatch(DataWithAxes)
def append(self, dwa: DataWithAxes):
    """Append/replace DataWithAxes object to the data attribute

    A deep copy of dwa is stored. Any object already held and found by
    `get_data_from_name_origin(dwa.name, dwa.origin)` is removed first, so only one
    DataWithAxes with a given name is kept, unless they have different origin identifiers.
    """
    new_dwa = dwa.deepcopy()
    self._check_data_type(new_dwa)
    existing = self.get_data_from_name_origin(new_dwa.name, new_dwa.origin)
    if existing is not None:
        # NOTE(review): list.index compares with ==. If DataWithAxes overrides __eq__, an
        # earlier equal-but-distinct object could be removed instead of `existing` — confirm.
        self._data.pop(self.data.index(existing))
    self._data.append(new_dwa)

@dispatch(object)
def append(self, dte: DataToExport):
    """Append the content of another DataToExport; any other object is silently ignored"""
    if not isinstance(dte, DataToExport):
        return
    self.append(dte.data)
2867
+
2868
+
2869
if __name__ == '__main__':
    # Manual smoke test of the data objects defined in this module.
    d = DataRaw('hjk', units='m', data=[np.array([0, 1, 2])])
    dm = DataRaw('hjk', units='mm', data=[np.array([0, 1, 2])])
    d + d
    d - d

    d1 = DataFromRoi(name='Hlineout_', data=[np.zeros((24,))],
                     x_axis=Axis(data=np.zeros((24,)), units='myunits', label='mylabel1'))
    d2 = DataFromRoi(name='Hlineout_', data=[np.zeros((12,))],
                     x_axis=Axis(data=np.zeros((12,)),
                                 units='myunits2',
                                 label='mylabel2'))

    Nsig = 200
    Nnav = 10
    x = np.linspace(-Nsig / 2, Nsig / 2 - 1, Nsig)

    # one row of the (Nnav, Nsig) array per navigation index
    dat = np.zeros((Nnav, Nsig))
    for ind in range(Nnav):
        dat[ind] = mutils.gauss1D(x, 50 * (ind - Nnav / 2), 25 / np.sqrt(2))

    data = DataRaw('mydata', data=[dat], nav_indexes=(0,),
                   axes=[Axis('nav', data=np.linspace(0, Nnav - 1, Nnav), index=0),
                         Axis('sig', data=x, index=1)])

    data + data

    data2 = copy.copy(data)

    # sum over the signal axis (index 1) and drop the corresponding axis
    data3 = data.deepcopy_with_new_data([np.sum(dat, 1)], remove_axes_index=(1,))

    print('done')
2901
+