pymodaq_data 5.2.0a1__tar.gz → 5.2.0a3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/PKG-INFO +6 -3
  2. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/pyproject.toml +4 -2
  3. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/__init__.py +1 -1
  4. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/data.py +334 -46
  5. pymodaq_data-5.2.0a3/src/pymodaq_data/h5modules/__init__.py +123 -0
  6. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/backends.py +158 -13
  7. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/browsing.py +11 -7
  8. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/data_saving.py +94 -34
  9. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/exporter.py +1 -1
  10. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/exporters/hyperspy.py +2 -2
  11. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/saving.py +96 -14
  12. pymodaq_data-5.2.0a3/src/pymodaq_data/h5modules/swmr.py +108 -0
  13. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/numpy_func.py +2 -2
  14. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/plotting/plotter/plotters/matplotlib_plotters.py +4 -3
  15. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/plotting/utils.py +3 -3
  16. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/resources/config_template.toml +11 -9
  17. pymodaq_data-5.2.0a1/src/pymodaq_data/h5modules/__init__.py +0 -6
  18. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/.gitignore +0 -0
  19. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/LICENSE +0 -0
  20. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/README.rst +0 -0
  21. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/config.py +0 -0
  22. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/exporters/__init__.py +0 -0
  23. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/exporters/base.py +0 -0
  24. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/exporters/flimj.py +0 -0
  25. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/h5modules/utils.py +0 -0
  26. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/icon.ico +0 -0
  27. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/plotting/__init__.py +0 -0
  28. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/plotting/plotter/plotter.py +0 -0
  29. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/plotting/plotter/plotters/__init__.py +0 -0
  30. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/post_treatment/__init__.py +0 -0
  31. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/post_treatment/process_to_scalar.py +0 -0
  32. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/resources/__init__.py +0 -0
  33. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/slicing.py +0 -0
  34. {pymodaq_data-5.2.0a1 → pymodaq_data-5.2.0a3}/src/pymodaq_data/splash.png +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pymodaq_data
3
- Version: 5.2.0a1
3
+ Version: 5.2.0a3
4
4
  Summary: Modular Data Acquisition with Python
5
5
  Project-URL: Homepage, http://pymodaq.cnrs.fr
6
6
  Project-URL: Source, https://github.com/PyMoDAQ/PyMoDAQ
@@ -43,16 +43,19 @@ Classifier: Topic :: Scientific/Engineering :: Visualization
43
43
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
44
44
  Classifier: Topic :: Software Development :: User Interfaces
45
45
  Requires-Python: >=3.8
46
+ Requires-Dist: h5py
46
47
  Requires-Dist: pymodaq-utils>=0.0.8
47
48
  Requires-Dist: scipy
48
- Requires-Dist: tables>=3.10
49
49
  Provides-Extra: dev
50
50
  Requires-Dist: flake8; extra == 'dev'
51
- Requires-Dist: h5py; extra == 'dev'
51
+ Requires-Dist: h5py>=2.5.0; extra == 'dev'
52
52
  Requires-Dist: hatch; extra == 'dev'
53
53
  Requires-Dist: pytest; extra == 'dev'
54
54
  Requires-Dist: pytest-cov; extra == 'dev'
55
55
  Requires-Dist: pytest-xdist; extra == 'dev'
56
+ Requires-Dist: tables>=3.10; extra == 'dev'
57
+ Provides-Extra: xarray
58
+ Requires-Dist: xarray>=2024.10; extra == 'xarray'
56
59
  Description-Content-Type: text/x-rst
57
60
 
58
61
  .. ############################################################
@@ -33,14 +33,16 @@ classifiers = [
33
33
  dependencies = [
34
34
  "pymodaq_utils>=0.0.8",
35
35
  "scipy",
36
- "tables>=3.10",
36
+ "h5py",
37
37
  ]
38
38
 
39
39
  [project.optional-dependencies]
40
+ xarray = ["xarray>=2024.10"]
40
41
  dev = [
41
42
  "hatch",
42
43
  "flake8",
43
- "h5py",
44
+ "h5py>=2.5.0",
45
+ "tables>=3.10",
44
46
  "pytest",
45
47
  "pytest-cov",
46
48
  "pytest-xdist",
@@ -13,7 +13,7 @@ try:
13
13
  try:
14
14
  logger = set_logger('pymodaq_data', add_handler=True, base_logger=True)
15
15
  except Exception:
16
- print("Couldn't create the local folder to store logs , presets...")
16
+ print("Couldn't create the local folder to store logs , experiments...")
17
17
 
18
18
  import pymodaq_data.config
19
19
  logger.info('************************')
@@ -7,8 +7,8 @@ Created the 28/10/2022
7
7
  from __future__ import annotations
8
8
 
9
9
  from abc import ABCMeta, abstractmethod
10
- import numbers
11
10
  from copy import deepcopy
11
+ import numbers
12
12
 
13
13
  import numpy as np
14
14
  from numpy.lib.mixins import NDArrayOperatorsMixin
@@ -184,19 +184,18 @@ class DataDistribution(BaseEnum):
184
184
 
185
185
  def _compute_slices_from_axis(axis: Axis, _slice, *ignored, is_index=True, **ignored_also):
186
186
  if not is_index:
187
- if isinstance(_slice, numbers.Number) or isinstance(_slice, Q_):
187
+ if isinstance(_slice, (numbers.Number, Q_)):
188
188
  if not is_index:
189
189
  _slice = axis.find_index(_slice)
190
190
  elif _slice is Ellipsis:
191
191
  return _slice
192
- elif isinstance(_slice, slice):
193
- if not (_slice.start is None and
194
- _slice.stop is None and _slice.step is None):
195
- start = axis.find_index(
196
- _slice.start if _slice.start is not None else axis.get_data()[0])
197
- stop = axis.find_index(
198
- _slice.stop if _slice.stop is not None else axis.get_data()[-1])
199
- _slice = slice(start, stop)
192
+ elif isinstance(_slice, slice) and not (_slice.start is None and
193
+ _slice.stop is None and _slice.step is None):
194
+ start = axis.find_index(
195
+ _slice.start if _slice.start is not None else axis.get_data()[0])
196
+ stop = axis.find_index(
197
+ _slice.stop if _slice.stop is not None else axis.get_data()[-1])
198
+ _slice = slice(start, stop)
200
199
  return _slice
201
200
 
202
201
 
@@ -437,8 +436,7 @@ class Axis(SerializableBase):
437
436
  ----------
438
437
  indexes:
439
438
  """
440
- if not (isinstance(indexes, np.ndarray) or isinstance(indexes, slice) or
441
- isinstance(indexes, int)):
439
+ if not (isinstance(indexes, (np.ndarray, slice, int))):
442
440
  indexes = np.array(indexes)
443
441
  return self.get_data()[indexes]
444
442
 
@@ -910,6 +908,16 @@ class DataBase(DataLowLevel, NDArrayOperatorsMixin):
910
908
  else:
911
909
  return [float(np.mean(data_array)) for data_array in self.data]
912
910
 
911
+ def equal_to(self, other: 'DataBase', epsilon: Union[float, Q_])-> bool:
912
+ """ Check if two data object are equal within epsilon """
913
+ if isinstance(epsilon, numbers.Number):
914
+ epsilon = Q_(epsilon, self.units)
915
+ try:
916
+ return bool(np.all([np.abs(self.quantities[ind] - other.quantities[ind])
917
+ <= epsilon for ind in range(len(self))]))
918
+ except pint.errors.DimensionalityError as e:
919
+ return False
920
+
913
921
  def as_dte(self, name: str = 'mydte') -> DataToExport:
914
922
  """Convenience method to wrap the DataWithAxes object into a DataToExport"""
915
923
  return DataToExport(name, data=[self])
@@ -918,11 +926,11 @@ class DataBase(DataLowLevel, NDArrayOperatorsMixin):
918
926
  """ Convenience method to split each ndarray into a DataWithAxes object """
919
927
  return DataToExport(name, data=[type(self)(self.labels[ind],
920
928
  source=self.source,
921
- dim = self.dim,
929
+ dim=self.dim,
922
930
  data=[array],
923
- labels = [self.labels[ind]],
924
- axes = deepcopy(self.axes),
925
- units = self.units,
931
+ labels=[self.labels[ind]],
932
+ axes=deepcopy(self.axes),
933
+ units=self.units,
926
934
  ) for ind, array in enumerate(self)])
927
935
 
928
936
  def add_extra_attribute(self, **kwargs):
@@ -1009,7 +1017,7 @@ class DataBase(DataLowLevel, NDArrayOperatorsMixin):
1009
1017
 
1010
1018
  def _comparison_common(self, other, operator='__eq__'):
1011
1019
  if isinstance(other, DataBase):
1012
- if not (# no more checking for name equality but take care ot the pop/remove methods
1020
+ if not ( # no more checking for name equality but take care ot the pop/remove methods
1013
1021
  len(self) == len(other) and
1014
1022
  Unit(self.units).is_compatible_with(other.units)):
1015
1023
  return False
@@ -1021,7 +1029,7 @@ class DataBase(DataLowLevel, NDArrayOperatorsMixin):
1021
1029
  eq = False
1022
1030
  break
1023
1031
  if operator == '__eq__':
1024
- eq = eq and np.allclose(self.quantities[ind], other.quantities[ind])
1032
+ eq = eq and np.allclose(self.quantities[ind], other.quantities[ind], equal_nan=True)
1025
1033
  else:
1026
1034
  eq = eq and np.all(getattr(self.quantities[ind], operator)(other.quantities[ind]))
1027
1035
  # extra attributes are not relevant as they may contain module specific data...
@@ -1211,38 +1219,37 @@ class DataBase(DataLowLevel, NDArrayOperatorsMixin):
1211
1219
  return self.data[index]
1212
1220
 
1213
1221
  def _check_data_type(self, data: List[Union[np.ndarray, Q_]]) -> List[np.ndarray]:
1214
- """make sure data is a list of nd-arrays"""
1215
- is_valid = True
1222
+ """Make sure data is a list of non-empty nd-arrays."""
1216
1223
  if data is None:
1217
- is_valid = False
1224
+ raise TypeError('Data should be a non-empty list of non-empty numpy arrays')
1225
+
1226
+ # Convert single Q_, ndarray, or Number to a list
1218
1227
  if not isinstance(data, list):
1219
- # try to transform the data to regular type
1220
1228
  if isinstance(data, Q_):
1221
1229
  self.force_units(str(data.units))
1222
1230
  data = [data.magnitude]
1223
1231
  elif isinstance(data, np.ndarray):
1224
- warnings.warn(DataTypeWarning(f'Your data should be a list of numpy arrays not just a single numpy'
1225
- f' array, wrapping them with a list'))
1232
+ warnings.warn(DataTypeWarning('Your data should be a list of numpy arrays, not just a single numpy array. Wrapping it in a list.'))
1226
1233
  data = [data]
1227
1234
  elif isinstance(data, numbers.Number):
1228
- warnings.warn(DataTypeWarning(f'Your data should be a list of numpy arrays not just a single numpy'
1229
- f' array, wrapping them with a list'))
1235
+ warnings.warn(DataTypeWarning('Your data should be a list of numpy arrays, not just a single number. Wrapping it in a list.'))
1230
1236
  data = [np.array([data])]
1231
1237
  else:
1232
- is_valid = False
1233
- if isinstance(data, list):
1234
- if len(data) == 0:
1235
- is_valid = False
1236
- elif not (isinstance(data[0], np.ndarray) or
1237
- isinstance(data[0], Q_)):
1238
- is_valid = False
1239
- elif len(data[0].shape) == 0:
1240
- is_valid = False
1241
- if not is_valid:
1242
- raise TypeError(f'Data should be an non-empty list of non-empty numpy arrays')
1238
+ raise TypeError('Data should be a non-empty list of non-empty numpy arrays')
1239
+
1240
+ # Validate the list
1241
+ if not data or not all(isinstance(item, (np.ndarray, Q_)) for item in data):
1242
+ raise TypeError('Data should be a non-empty list of non-empty numpy arrays')
1243
+
1244
+ # Check for non-empty arrays
1245
+ if any(len(item.shape) == 0 for item in data):
1246
+ raise TypeError('Data should be a non-empty list of non-empty numpy arrays')
1247
+
1248
+ # Convert Q_ to magnitude
1243
1249
  if isinstance(data[0], Q_):
1244
1250
  self.force_units(str(data[0].units))
1245
1251
  data = [array.magnitude for array in data]
1252
+
1246
1253
  return data
1247
1254
 
1248
1255
  def check_shape_from_data(self, data: List[np.ndarray]):
@@ -2265,7 +2272,7 @@ class DataWithAxes(DataBase, SerializableBase):
2265
2272
  dat_sum.append(np.atleast_1d(np.sum(dat, axis=axis)))
2266
2273
  return self.deepcopy_with_new_data(dat_sum, remove_axes_index=axis)
2267
2274
 
2268
- def interp(self, new_axis_data: Union[Axis, np.ndarray], **kwargs) -> DataWithAxes:
2275
+ def interp(self, new_axis_data: Union[Axis, np.ndarray], **kwargs) -> DataWithAxes:
2269
2276
  """Performs linear interpolation for 1D data only.
2270
2277
 
2271
2278
  For more complex ones, see :py:meth:`scipy.interpolate`
@@ -2477,11 +2484,11 @@ class DataWithAxes(DataBase, SerializableBase):
2477
2484
 
2478
2485
  dte.append(DataCalculated(f'{self.labels[ind]}',
2479
2486
  data=[self[ind][peaks_indices[-1]],
2480
- peaks_indices[-1]
2487
+ peaks_indices[-1],
2481
2488
  ],
2482
2489
  labels=['peak value', 'peak indexes'],
2483
2490
  axes=[Axis('peak position', self.axes[0].units,
2484
- data=self.axes[0].get_data_at(peaks_indices[-1]))])
2491
+ data=self.axes[0].get_data_at(peaks_indices[-1]))]),
2485
2492
  )
2486
2493
  return dte
2487
2494
 
@@ -2632,7 +2639,7 @@ class DataWithAxes(DataBase, SerializableBase):
2632
2639
  list(slice): a version as index of the input argument
2633
2640
  """
2634
2641
  _slices_as_index = []
2635
- if isinstance(slices, numbers.Number) or isinstance(slices, Q_) or isinstance(slices, slice):
2642
+ if isinstance(slices, (numbers.Number, Q_, slice)):
2636
2643
  slices = [slices]
2637
2644
  if is_navigation:
2638
2645
  indexes = self._am.nav_indexes
@@ -2694,7 +2701,7 @@ class DataWithAxes(DataBase, SerializableBase):
2694
2701
  Object of the same type as the initial data, derived from DataWithAxes. But with lower
2695
2702
  data size due to the slicing and with eventually less axes.
2696
2703
  """
2697
- if isinstance(slices, numbers.Number) or isinstance(slices, slice):
2704
+ if isinstance(slices, (numbers.Number, slice)):
2698
2705
  slices = [slices]
2699
2706
 
2700
2707
  total_slices, slices = self._compute_slices(slices, is_navigation, is_index=is_index)
@@ -2851,6 +2858,218 @@ class DataWithAxes(DataBase, SerializableBase):
2851
2858
  """ Get the underlying data selected from the list at index, returned as a DataWithAxes"""
2852
2859
  return self.deepcopy_with_new_data([self[index]])
2853
2860
 
2861
+ def to_xarray(self):
2862
+ """Convert this DataWithAxes to an xarray.Dataset.
2863
+
2864
+ Each array in self.data becomes a data variable (keyed by its label).
2865
+ Each Axis becomes a coordinate on the corresponding dimension.
2866
+ Error arrays (if present) are stored as ``<label>_error`` data variables.
2867
+
2868
+ Returns
2869
+ -------
2870
+ xr.Dataset
2871
+
2872
+ Raises
2873
+ ------
2874
+ ImportError
2875
+ If xarray is not installed.
2876
+ """
2877
+ try:
2878
+ import xarray as xr
2879
+ except ImportError:
2880
+ raise ImportError(
2881
+ "xarray is required for to_xarray(). "
2882
+ "Install it with: pip install 'pymodaq_data[xarray]'",
2883
+ )
2884
+
2885
+ ndim = len(self.shape)
2886
+
2887
+ # --- build dim names (one per shape dimension) ---
2888
+ dim_names = []
2889
+ seen_dim_names = {}
2890
+ for i in range(ndim):
2891
+ axes_at_i = self.get_axis_from_index(i)
2892
+ if axes_at_i and axes_at_i[0] is not None and axes_at_i[0].label:
2893
+ base = axes_at_i[0].label
2894
+ else:
2895
+ base = f'dim_{i}'
2896
+ # deduplicate
2897
+ if base in seen_dim_names:
2898
+ seen_dim_names[base] += 1
2899
+ name = f'{base}_{seen_dim_names[base]}'
2900
+ else:
2901
+ seen_dim_names[base] = 0
2902
+ name = base
2903
+ dim_names.append(name)
2904
+
2905
+ # --- build coordinates ---
2906
+ coords = {}
2907
+ spread_dim_names = []
2908
+ for axis in self.axes:
2909
+ dim_name = dim_names[axis.index]
2910
+ axis_data = axis.get_data()
2911
+ if axis_data is None:
2912
+ continue
2913
+ if axis.spread_order > 0:
2914
+ coord_name = f'{axis.label}_{axis.spread_order}' if axis.label else f'{dim_name}_{axis.spread_order}'
2915
+ coord_attrs = {
2916
+ 'units': axis.units,
2917
+ 'pymodaq_label': axis.label,
2918
+ 'spread_order': axis.spread_order,
2919
+ }
2920
+ if dim_name not in spread_dim_names:
2921
+ spread_dim_names.append(dim_name)
2922
+ else:
2923
+ coord_name = axis.label if axis.label else dim_name
2924
+ coord_attrs = {
2925
+ 'units': axis.units,
2926
+ 'pymodaq_label': axis.label,
2927
+ }
2928
+ coords[coord_name] = xr.Variable(dim_name, axis_data, attrs=coord_attrs)
2929
+
2930
+ # --- build data variables ---
2931
+ data_vars = {}
2932
+ label_list = list(self.labels)
2933
+ error_var_names = []
2934
+ for i, (array, label) in enumerate(zip(self.data, label_list)):
2935
+ var_name = label if label else f'data_{i}'
2936
+ data_vars[var_name] = xr.Variable(dim_names, array)
2937
+ if self.errors is not None:
2938
+ err_name = f'{var_name}_error'
2939
+ data_vars[err_name] = xr.Variable(dim_names, self.errors[i])
2940
+ error_var_names.append(err_name)
2941
+
2942
+ # --- dataset attrs ---
2943
+ attrs = {
2944
+ 'pymodaq_name': self.name,
2945
+ 'pymodaq_origin': self.origin,
2946
+ 'pymodaq_source': self.source.name,
2947
+ 'pymodaq_distribution': self.distribution.name,
2948
+ 'pymodaq_units': self.units,
2949
+ 'pymodaq_nav_indexes': list(self.nav_indexes),
2950
+ 'pymodaq_labels': label_list,
2951
+ }
2952
+ if error_var_names:
2953
+ attrs['pymodaq_error_vars'] = error_var_names
2954
+ if spread_dim_names:
2955
+ attrs['pymodaq_spread_dim_names'] = spread_dim_names
2956
+
2957
+ return xr.Dataset(data_vars, coords=coords, attrs=attrs)
2958
+
2959
+ @classmethod
2960
+ def from_xarray(cls, ds) -> 'DataWithAxes':
2961
+ """Construct a DataWithAxes from an xarray Dataset (or DataArray).
2962
+
2963
+ Parameters
2964
+ ----------
2965
+ ds : xr.Dataset or xr.DataArray
2966
+
2967
+ Returns
2968
+ -------
2969
+ DataWithAxes
2970
+
2971
+ Raises
2972
+ ------
2973
+ ImportError
2974
+ If xarray is not installed.
2975
+ """
2976
+ try:
2977
+ import xarray as xr
2978
+ except ImportError:
2979
+ raise ImportError(
2980
+ "xarray is required for from_xarray(). "
2981
+ "Install it with: pip install 'pymodaq_data[xarray]'",
2982
+ )
2983
+
2984
+ if isinstance(ds, xr.DataArray):
2985
+ var_name = ds.name if ds.name else 'data'
2986
+ ds = ds.to_dataset(name=var_name)
2987
+
2988
+ attrs = ds.attrs
2989
+ name = attrs.get('pymodaq_name', 'from_xarray')
2990
+ origin = attrs.get('pymodaq_origin', '')
2991
+ source_str = attrs.get('pymodaq_source', 'raw')
2992
+ distribution_str = attrs.get('pymodaq_distribution', 'uniform')
2993
+ units = attrs.get('pymodaq_units', '')
2994
+ nav_indexes = tuple(attrs.get('pymodaq_nav_indexes', []))
2995
+ stored_labels = list(attrs.get('pymodaq_labels', []))
2996
+ error_var_names = list(attrs.get('pymodaq_error_vars', []))
2997
+
2998
+ # separate error vars from regular data vars
2999
+ regular_vars = {k: v for k, v in ds.data_vars.items() if k not in error_var_names}
3000
+ error_vars = {k: v for k, v in ds.data_vars.items() if k in error_var_names}
3001
+
3002
+ # reconstruct data arrays and labels
3003
+ data_arrays = []
3004
+ labels = []
3005
+ for i, (var_name, var) in enumerate(regular_vars.items()):
3006
+ data_arrays.append(var.values)
3007
+ label = stored_labels[i] if i < len(stored_labels) else var_name
3008
+ labels.append(label)
3009
+
3010
+ # reconstruct error arrays (matched by position in error_var_names)
3011
+ errors = None
3012
+ if error_vars:
3013
+ errors = [error_vars[err_name].values for err_name in error_var_names]
3014
+
3015
+ # reconstruct Axis objects from dims and coords
3016
+ dim_names = list(ds.dims)
3017
+ axes = []
3018
+ for i, dim_name in enumerate(dim_names):
3019
+ # find coords that live on this dimension
3020
+ dim_coords = [
3021
+ (cname, cvar) for cname, cvar in ds.coords.items()
3022
+ if list(cvar.dims) == [dim_name]
3023
+ ]
3024
+ if dim_coords:
3025
+ # primary coord: spread_order == 0 (or missing) comes first
3026
+ primary = None
3027
+ secondary = []
3028
+ for cname, cvar in dim_coords:
3029
+ spread_order = int(cvar.attrs.get('spread_order', 0))
3030
+ axis_label = cvar.attrs.get('pymodaq_label', cname)
3031
+ axis_units = cvar.attrs.get('units', '')
3032
+ axis = Axis(
3033
+ label=axis_label,
3034
+ units=axis_units,
3035
+ data=cvar.values,
3036
+ index=i,
3037
+ spread_order=spread_order,
3038
+ )
3039
+ if spread_order == 0:
3040
+ primary = axis
3041
+ else:
3042
+ secondary.append(axis)
3043
+ if primary is not None:
3044
+ axes.append(primary)
3045
+ axes.extend(secondary)
3046
+ else:
3047
+ # no coord: create size-only axis
3048
+ dim_size = ds.sizes[dim_name]
3049
+ axes.append(Axis(label=dim_name, index=i, size=dim_size))
3050
+
3051
+ try:
3052
+ source = DataSource[source_str]
3053
+ except KeyError:
3054
+ source = DataSource.raw
3055
+ try:
3056
+ distribution = DataDistribution[distribution_str]
3057
+ except KeyError:
3058
+ distribution = DataDistribution.uniform
3059
+
3060
+ return cls(
3061
+ name=name,
3062
+ source=source,
3063
+ distribution=distribution,
3064
+ data=data_arrays,
3065
+ labels=labels,
3066
+ units=units,
3067
+ axes=axes,
3068
+ nav_indexes=nav_indexes,
3069
+ origin=origin,
3070
+ errors=errors,
3071
+ )
3072
+
2854
3073
 
2855
3074
  @ser_factory.register_decorator()
2856
3075
  class DataRaw(DataWithAxes):
@@ -2878,7 +3097,7 @@ class DataRaw(DataWithAxes):
2878
3097
  axes=axes,
2879
3098
  nav_indexes=nav_indexes,
2880
3099
  errors=errors,
2881
- **kwargs
3100
+ **kwargs,
2882
3101
  )
2883
3102
 
2884
3103
 
@@ -3252,9 +3471,9 @@ class DataToExport(DataLowLevel, SerializableBase):
3252
3471
  def get_data_from_full_name(self, full_name: str, deepcopy=False) -> DataWithAxes:
3253
3472
  """Get the DataWithAxes with matching full name"""
3254
3473
  if deepcopy:
3255
- data = self.get_data_from_name_origin(full_name.split('/')[1], full_name.split('/')[0]).deepcopy()
3474
+ data = self.get_data_from_name_origin('/'.join(full_name.split('/')[1:]), full_name.split('/')[0]).deepcopy()
3256
3475
  else:
3257
- data = self.get_data_from_name_origin(full_name.split('/')[1], full_name.split('/')[0])
3476
+ data = self.get_data_from_name_origin('/'.join(full_name.split('/')[1:]), full_name.split('/')[0])
3258
3477
  return data
3259
3478
 
3260
3479
  def get_data_from_full_names(self, full_names: List[str], deepcopy=False) -> DataToExport:
@@ -3530,6 +3749,75 @@ class DataToExport(DataLowLevel, SerializableBase):
3530
3749
  if isinstance(dte, DataToExport):
3531
3750
  self.append(dte.data)
3532
3751
 
3752
+ def to_xarray(self):
3753
+ """Convert this DataToExport to an xarray.DataTree.
3754
+
3755
+ The root node carries ``pymodaq_name`` in its attrs. Each DataWithAxes
3756
+ becomes a child node whose dataset is produced by
3757
+ ``DataWithAxes.to_xarray()``.
3758
+
3759
+ Returns
3760
+ -------
3761
+ xr.DataTree
3762
+
3763
+ Raises
3764
+ ------
3765
+ ImportError
3766
+ If xarray is not installed.
3767
+ """
3768
+ try:
3769
+ import xarray as xr
3770
+ except ImportError:
3771
+ raise ImportError(
3772
+ "xarray is required for to_xarray(). "
3773
+ "Install it with: pip install 'pymodaq_data[xarray]'",
3774
+ )
3775
+
3776
+ children = {dwa.name: xr.DataTree(dataset=dwa.to_xarray()) for dwa in self}
3777
+ root_ds = xr.Dataset(attrs={'pymodaq_name': self.name})
3778
+ return xr.DataTree(dataset=root_ds, children=children)
3779
+
3780
+ @classmethod
3781
+ def from_xarray(cls, dt, name: str = None) -> 'DataToExport':
3782
+ """Construct a DataToExport from an xarray.DataTree or a dict of Datasets.
3783
+
3784
+ Parameters
3785
+ ----------
3786
+ dt : xr.DataTree or dict[str, xr.Dataset]
3787
+ name : str, optional
3788
+ Override the name; if None, read from ``dt.attrs['pymodaq_name']``.
3789
+
3790
+ Returns
3791
+ -------
3792
+ DataToExport
3793
+
3794
+ Raises
3795
+ ------
3796
+ ImportError
3797
+ If xarray is not installed.
3798
+ """
3799
+ try:
3800
+ import xarray as xr
3801
+ except ImportError:
3802
+ raise ImportError(
3803
+ "xarray is required for from_xarray(). "
3804
+ "Install it with: pip install 'pymodaq_data[xarray]'",
3805
+ )
3806
+
3807
+ if isinstance(dt, xr.DataTree):
3808
+ root_attrs = dt.dataset.attrs if dt.dataset is not None else {}
3809
+ dte_name = name or root_attrs.get('pymodaq_name', 'from_xarray')
3810
+ data_list = [
3811
+ DataWithAxes.from_xarray(child.dataset)
3812
+ for child in dt.children.values()
3813
+ ]
3814
+ else:
3815
+ # dict[str, xr.Dataset] fallback
3816
+ dte_name = name or 'from_xarray'
3817
+ data_list = [DataWithAxes.from_xarray(ds) for ds in dt.values()]
3818
+
3819
+ return cls(name=dte_name, data=data_list)
3820
+
3533
3821
 
3534
3822
  if __name__ == '__main__':
3535
3823
  d = DataRaw('hjk', units='m', data=[np.array([0, 1, 2])])
@@ -3550,7 +3838,7 @@ if __name__ == '__main__':
3550
3838
 
3551
3839
  dat = np.zeros((Nnav, Nsig))
3552
3840
  for ind in range(Nnav):
3553
- dat[ind] = mutils.gauss1D(x, 50 * (ind -Nnav / 2), 25 / np.sqrt(2))
3841
+ dat[ind] = mutils.gauss1D(x, 50 * (ind -Nnav / 2), 25 / np.sqrt(2))
3554
3842
 
3555
3843
  data = DataRaw('mydata', data=[dat], nav_indexes=(0,),
3556
3844
  axes=[Axis('nav', data=np.linspace(0, Nnav-1, Nnav), index=0),