xradio 0.0.47__py3-none-any.whl → 0.0.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. xradio/__init__.py +1 -0
  2. xradio/_utils/dict_helpers.py +69 -2
  3. xradio/_utils/list_and_array.py +3 -1
  4. xradio/_utils/schema.py +3 -1
  5. xradio/image/_util/__init__.py +0 -3
  6. xradio/image/_util/_casacore/common.py +0 -13
  7. xradio/image/_util/_casacore/xds_from_casacore.py +102 -97
  8. xradio/image/_util/_casacore/xds_to_casacore.py +36 -24
  9. xradio/image/_util/_fits/xds_from_fits.py +81 -36
  10. xradio/image/_util/_zarr/zarr_low_level.py +3 -3
  11. xradio/image/_util/casacore.py +7 -5
  12. xradio/image/_util/common.py +13 -26
  13. xradio/image/_util/image_factory.py +143 -191
  14. xradio/image/image.py +10 -59
  15. xradio/measurement_set/__init__.py +11 -6
  16. xradio/measurement_set/_utils/_msv2/_tables/read.py +187 -46
  17. xradio/measurement_set/_utils/_msv2/_tables/table_query.py +22 -0
  18. xradio/measurement_set/_utils/_msv2/conversion.py +347 -299
  19. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +233 -150
  20. xradio/measurement_set/_utils/_msv2/descr.py +1 -1
  21. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +20 -13
  22. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +21 -22
  23. xradio/measurement_set/convert_msv2_to_processing_set.py +46 -6
  24. xradio/measurement_set/load_processing_set.py +100 -52
  25. xradio/measurement_set/measurement_set_xdt.py +197 -0
  26. xradio/measurement_set/open_processing_set.py +122 -86
  27. xradio/measurement_set/processing_set_xdt.py +1552 -0
  28. xradio/measurement_set/schema.py +375 -197
  29. xradio/schema/bases.py +5 -1
  30. xradio/schema/check.py +97 -5
  31. xradio/sphinx/schema_table.py +12 -0
  32. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/METADATA +4 -4
  33. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/RECORD +36 -36
  34. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/WHEEL +1 -1
  35. xradio/measurement_set/measurement_set_xds.py +0 -117
  36. xradio/measurement_set/processing_set.py +0 -777
  37. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info/licenses}/LICENSE.txt +0 -0
  38. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/top_level.txt +0 -0
@@ -127,9 +127,9 @@ def interpolate_to_time(
127
127
  xds = xds.interp(
128
128
  {time_name: interp_time.data}, method=method, assume_sorted=True
129
129
  )
130
- # scan_number sneaks in as a coordinate of the main time axis, drop it
131
- if "scan_number" in xds.coords:
132
- xds = xds.drop_vars("scan_number")
130
+ # scan_name sneaks in as a coordinate of the main time axis, drop it
131
+ if "scan_name" in xds.coords:
132
+ xds = xds.drop_vars("scan_name")
133
133
  points_after = xds[time_name].size
134
134
  logger.debug(
135
135
  f"{message_prefix}: interpolating the time coordinate "
@@ -497,7 +497,7 @@ def prepare_generic_sys_cal_xds(generic_sys_cal_xds: xr.Dataset) -> xr.Dataset:
497
497
  def create_system_calibration_xds(
498
498
  in_file: str,
499
499
  main_xds_frequency: xr.DataArray,
500
- ant_xds_name_ids: xr.DataArray,
500
+ ant_xds: xr.DataArray,
501
501
  sys_cal_interp_time: Union[xr.DataArray, None] = None,
502
502
  ):
503
503
  """
@@ -510,8 +510,8 @@ def create_system_calibration_xds(
510
510
  main_xds_frequency: xr.DataArray
511
511
  frequency array of the main xds (MSv4), containing among other things
512
512
  spectral_window_id and measures metadata
513
- ant_xds_name_ids : xr.Dataset
514
- antenna_name data array from antenna_xds, with name/id information
513
+ ant_xds : xr.Dataset
514
+ The antenna_xds that has information such as names, stations, etc., for coordinates
515
515
  sys_cal_interp_time: Union[xr.DataArray, None] = None,
516
516
  Time axis to interpolate the data vars to (usually main MSv4 time)
517
517
 
@@ -529,7 +529,7 @@ def create_system_calibration_xds(
529
529
  rename_ids=subt_rename_ids["SYSCAL"],
530
530
  taql_where=(
531
531
  f" where (SPECTRAL_WINDOW_ID = {spectral_window_id})"
532
- f" AND (ANTENNA_ID IN [{','.join(map(str, ant_xds_name_ids.antenna_id.values))}])"
532
+ f" AND (ANTENNA_ID IN [{','.join(map(str, ant_xds.antenna_id.values))}])"
533
533
  ),
534
534
  )
535
535
  except ValueError as _exc:
@@ -541,14 +541,14 @@ def create_system_calibration_xds(
541
541
 
542
542
  generic_sys_cal_xds = prepare_generic_sys_cal_xds(generic_sys_cal_xds)
543
543
 
544
- mandatory_dimensions = ["antenna_name", "time_cal", "receptor_label"]
544
+ mandatory_dimensions = ["antenna_name", "time_system_cal", "receptor_label"]
545
545
  if "frequency" not in generic_sys_cal_xds.sizes:
546
546
  dims_all = mandatory_dimensions
547
547
  else:
548
- dims_all = mandatory_dimensions + ["frequency_cal"]
548
+ dims_all = mandatory_dimensions + ["frequency_system_cal"]
549
549
 
550
550
  to_new_data_variables = {
551
- "PHASE_DIFF": ["PHASE_DIFFERENCE", ["antenna_name", "time_cal"]],
551
+ "PHASE_DIFF": ["PHASE_DIFFERENCE", ["antenna_name", "time_system_cal"]],
552
552
  "TCAL": ["TCAL", dims_all],
553
553
  "TCAL_SPECTRUM": ["TCAL", dims_all],
554
554
  "TRX": ["TRX", dims_all],
@@ -564,27 +564,26 @@ def create_system_calibration_xds(
564
564
  }
565
565
 
566
566
  to_new_coords = {
567
- "TIME": ["time_cal", ["time_cal"]],
567
+ "TIME": ["time_system_cal", ["time_system_cal"]],
568
568
  "receptor": ["receptor_label", ["receptor_label"]],
569
- "frequency": ["frequency_cal", ["frequency_cal"]],
569
+ "frequency": ["frequency_system_cal", ["frequency_system_cal"]],
570
570
  }
571
571
 
572
572
  sys_cal_xds = xr.Dataset(attrs={"type": "system_calibration"})
573
- coords = {
574
- "antenna_name": ant_xds_name_ids.sel(
575
- antenna_id=generic_sys_cal_xds["ANTENNA_ID"]
576
- ).data,
577
- "receptor_label": generic_sys_cal_xds.coords["receptor"].data,
573
+ ant_borrowed_coords = {
574
+ "antenna_name": ant_xds.coords["antenna_name"],
575
+ "receptor_label": ant_xds.coords["receptor_label"],
576
+ "polarization_type": ant_xds.coords["polarization_type"],
578
577
  }
579
- sys_cal_xds = sys_cal_xds.assign_coords(coords)
578
+ sys_cal_xds = sys_cal_xds.assign_coords(ant_borrowed_coords)
580
579
  sys_cal_xds = convert_generic_xds_to_xradio_schema(
581
580
  generic_sys_cal_xds, sys_cal_xds, to_new_data_variables, to_new_coords
582
581
  )
583
582
 
584
583
  # Add frequency coord and its measures data, if present
585
- if "frequency_cal" in dims_all:
584
+ if "frequency_system_cal" in dims_all:
586
585
  frequency_coord = {
587
- "frequency_cal": generic_sys_cal_xds.coords["frequency"].data
586
+ "frequency_system_cal": generic_sys_cal_xds.coords["frequency"].data
588
587
  }
589
588
  sys_cal_xds = sys_cal_xds.assign_coords(frequency_coord)
590
589
  frequency_measure = {
@@ -592,10 +591,10 @@ def create_system_calibration_xds(
592
591
  "units": main_xds_frequency.attrs["units"],
593
592
  "observer": main_xds_frequency.attrs["observer"],
594
593
  }
595
- sys_cal_xds.coords["frequency_cal"].attrs.update(frequency_measure)
594
+ sys_cal_xds.coords["frequency_system_cal"].attrs.update(frequency_measure)
596
595
 
597
596
  sys_cal_xds = rename_and_interpolate_to_time(
598
- sys_cal_xds, "time_cal", sys_cal_interp_time, "system_calibration_xds"
597
+ sys_cal_xds, "time_system_cal", sys_cal_interp_time, "system_calibration_xds"
599
598
  )
600
599
 
601
600
  # correct expected types
@@ -18,6 +18,7 @@ def estimate_conversion_memory_and_cores(
18
18
  """
19
19
  Given an MSv2 and a partition_scheme to use when converting it to MSv4,
20
20
  estimates:
21
+
21
22
  - memory (in the sense of the amount expected to be enough to convert)
22
23
  - cores (in the sense of the recommended/optimal number of cores to use to convert)
23
24
 
@@ -36,7 +37,7 @@ def estimate_conversion_memory_and_cores(
36
37
  Partition scheme as used in the function convert_msv2_to_processing_set()
37
38
 
38
39
  Returns
39
- ----------
40
+ -------
40
41
  tuple
41
42
  estimated maximum memory required for one partition,
42
43
  maximum number of cores it makes sense to use (number of partitions),
@@ -62,7 +63,7 @@ def convert_msv2_to_processing_set(
62
63
  use_table_iter: bool = False,
63
64
  compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2),
64
65
  storage_backend: str = "zarr",
65
- parallel: bool = False,
66
+ parallel_mode: str = "none",
66
67
  overwrite: bool = False,
67
68
  ):
68
69
  """Convert a Measurement Set v2 into a Processing Set of Measurement Set v4.
@@ -99,14 +100,45 @@ def convert_msv2_to_processing_set(
99
100
  The Blosc compressor to use when saving the converted data to disk using Zarr, by default numcodecs.Zstd(level=2).
100
101
  storage_backend : {"zarr", "netcdf"}, optional
101
102
  The on-disk format to use. "netcdf" is not yet implemented.
102
- parallel : bool, optional
103
- Makes use of Dask to execute conversion in parallel, by default False.
103
+ parallel_mode : {"none", "partition", "time"}, optional
104
+ Choose whether to use Dask to execute conversion in parallel, by default "none" and conversion occurs serially.
105
+ The option "partition" parallelises the conversion over partitions specified by `partition_scheme`. The option "time" can only be used for phased array interferometers where there are no partitions
106
+ in the MS v2; instead the MS v2 is parallelised along the time dimension and can be controlled by `main_chunksize`.
104
107
  overwrite : bool, optional
105
108
  Whether to overwrite an existing processing set, by default False.
106
109
  """
107
110
 
111
+ # Create empty data tree
112
+ import xarray as xr
113
+
114
+ ps_dt = xr.DataTree()
115
+
116
+ if not str(out_file).endswith("ps.zarr"):
117
+ out_file += ".ps.zarr"
118
+
119
+ print("Output file: ", out_file)
120
+
121
+ if overwrite:
122
+ ps_dt.to_zarr(store=out_file, mode="w")
123
+ else:
124
+ ps_dt.to_zarr(store=out_file, mode="w-")
125
+
126
+ # Check `parallel_mode` is valid
127
+ try:
128
+ assert parallel_mode in ["none", "partition", "time"]
129
+ except AssertionError:
130
+ logger.warning(
131
f"`parallel_mode` {parallel_mode} not recognized. Defaulting to 'none'."
132
+ )
133
+ parallel_mode = "none"
134
+
108
135
  partitions = create_partitions(in_file, partition_scheme=partition_scheme)
109
136
  logger.info("Number of partitions: " + str(len(partitions)))
137
+ if parallel_mode == "time":
138
+ assert (
139
+ len(partitions) == 1
140
+ ), "MS v2 contains more than one partition. `parallel_mode = 'time'` not valid."
141
+
110
142
  delayed_list = []
111
143
 
112
144
  for ms_v4_id, partition_info in enumerate(partitions):
@@ -132,7 +164,7 @@ def convert_msv2_to_processing_set(
132
164
 
133
165
  # prepend '0' to ms_v4_id as needed
134
166
  ms_v4_id = f"{ms_v4_id:0>{len(str(len(partitions) - 1))}}"
135
- if parallel:
167
+ if parallel_mode == "partition":
136
168
  delayed_list.append(
137
169
  dask.delayed(convert_and_write_partition)(
138
170
  in_file,
@@ -149,6 +181,7 @@ def convert_msv2_to_processing_set(
149
181
  phase_cal_interpolate=phase_cal_interpolate,
150
182
  sys_cal_interpolate=sys_cal_interpolate,
151
183
  compressor=compressor,
184
+ parallel_mode=parallel_mode,
152
185
  overwrite=overwrite,
153
186
  )
154
187
  )
@@ -168,8 +201,15 @@ def convert_msv2_to_processing_set(
168
201
  phase_cal_interpolate=phase_cal_interpolate,
169
202
  sys_cal_interpolate=sys_cal_interpolate,
170
203
  compressor=compressor,
204
+ parallel_mode=parallel_mode,
171
205
  overwrite=overwrite,
172
206
  )
173
207
 
174
- if parallel:
208
+ if parallel_mode == "partition":
175
209
  dask.compute(delayed_list)
210
+
211
+ import zarr
212
+
213
+ root_group = zarr.open(out_file, mode="r+") # Open in read/write mode
214
+ root_group.attrs["type"] = "processing_set" # Replace
215
+ zarr.convenience.consolidate_metadata(root_group.store)
@@ -1,79 +1,115 @@
1
1
  import os
2
- from xradio.measurement_set import ProcessingSet
3
2
  from typing import Dict, Union
3
+ import dask
4
+ import xarray as xr
5
+ import s3fs
4
6
 
5
7
 
6
8
  def load_processing_set(
7
9
  ps_store: str,
8
- sel_parms: dict,
9
- data_variables: Union[list, None] = None,
10
+ sel_parms: dict = None,
11
+ data_group_name: str = None,
12
+ include_variables: Union[list, None] = None,
13
+ drop_variables: Union[list, None] = None,
10
14
  load_sub_datasets: bool = True,
11
- ) -> ProcessingSet:
15
+ ) -> xr.DataTree:
12
16
  """Loads a processing set into memory.
13
17
 
14
18
  Parameters
15
19
  ----------
16
20
  ps_store : str
17
21
  String of the path and name of the processing set. For example '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr' for a file stored on a local file system, or 's3://viper-test-data/Antennae_North.cal.lsrk.split.vis.zarr/' for a file in AWS object storage.
18
- sel_parms : dict
19
- A dictionary where the keys are the names of the ms_xds's and the values are slice_dicts.
22
+ sel_parms : dict, optional
23
+ A dictionary where the keys are the names of the ms_xdt's (measurement set xarray data trees) and the values are slice_dicts.
20
24
  slice_dicts: A dictionary where the keys are the dimension names and the values are slices.
25
+
21
26
  For example::
22
27
 
23
28
  {
29
+
24
30
  'ms_v4_name_1': {'frequency': slice(0, 160, None),'time':slice(0,100)},
25
31
  ...
26
32
  'ms_v4_name_n': {'frequency': slice(0, 160, None),'time':slice(0,100)},
27
33
  }
28
34
 
29
- data_variables : Union[list, None], optional
35
+ By default None, which loads all ms_xdts.
36
+ data_group_name : str, optional
37
+ The name of the data group to select. By default None, which loads all data groups.
38
+ include_variables : Union[list, None], optional
30
39
  The list of data variables to load into memory for example ['VISIBILITY', 'WEIGHT, 'FLAGS']. By default None which will load all data variables into memory.
40
+ drop_variables : Union[list, None], optional
41
+ The list of data variables to drop from memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop any data variables from memory.
31
42
  load_sub_datasets : bool, optional
32
43
  If true sub-datasets (for example weather_xds, antenna_xds, pointing_xds, system_calibration_xds ...) will be loaded into memory, by default True.
33
44
 
34
45
  Returns
35
46
  -------
36
- ProcessingSet
37
- In memory representation of processing set (data is represented by Dask.arrays).
47
+ xarray.DataTree
48
+ In memory representation of processing set using xr.DataTree.
38
49
  """
39
- from xradio._utils.zarr.common import _open_dataset, _get_file_system_and_items
50
+ from xradio._utils.zarr.common import _get_file_system_and_items
40
51
 
41
52
  file_system, ms_store_list = _get_file_system_and_items(ps_store)
42
53
 
43
- ps = ProcessingSet()
44
- for ms_name, ms_xds_isel in sel_parms.items():
45
- ms_store = os.path.join(ps_store, ms_name)
46
- correlated_store = os.path.join(ms_store, "correlated_xds")
47
-
48
- xds = _open_dataset(
49
- correlated_store,
50
- file_system,
51
- ms_xds_isel,
52
- data_variables,
53
- load=True,
54
- )
55
- data_groups = xds.attrs["data_groups"]
56
-
57
- if load_sub_datasets:
58
- from xradio.measurement_set.open_processing_set import _open_sub_xds
59
-
60
- sub_xds_dict, field_and_source_xds_dict = _open_sub_xds(
61
- ms_store, file_system=file_system, load=True, data_groups=data_groups
54
+ with dask.config.set(
55
+ scheduler="synchronous"
56
+ ): # serial scheduler, critical so that this can be used within delayed functions.
57
+ ps_xdt = xr.DataTree()
58
+
59
+ if sel_parms:
60
+ for ms_name, ms_xds_isel in sel_parms.items():
61
+ ms_store = os.path.join(ps_store, ms_name)
62
+
63
+ if isinstance(file_system, s3fs.core.S3FileSystem):
64
+ ms_store = s3fs.S3Map(root=ps_store, s3=file_system, check=False)
65
+
66
+ if ms_xds_isel:
67
+ ms_xdt = (
68
+ xr.open_datatree(
69
+ ms_store, engine="zarr", drop_variables=drop_variables
70
+ )
71
+ .isel(ms_xds_isel)
72
+ .xr_ms.sel(data_group_name=data_group_name)
73
+ )
74
+ else:
75
+ ms_xdt = xr.open_datatree(
76
+ ms_store, engine="zarr", drop_variables=drop_variables
77
+ ).xr_ms.sel(data_group_name=data_group_name)
78
+
79
+ if include_variables is not None:
80
+ for data_vars in ms_xdt.ds.data_vars:
81
+ if data_vars not in include_variables:
82
+ ms_xdt.ds = ms_xdt.ds.drop_vars(data_vars)
83
+
84
+ ps_xdt[ms_name] = ms_xdt
85
+
86
+ ps_xdt.attrs["type"] = "processing_set"
87
+ else:
88
+ ps_xdt = xr.open_datatree(
89
+ ps_store, engine="zarr", drop_variables=drop_variables
62
90
  )
63
91
 
64
- xds.attrs = {
65
- **xds.attrs,
66
- **sub_xds_dict,
67
- }
68
- for data_group_name, data_group_vals in data_groups.items():
92
+ if (include_variables is not None) or data_group_name:
93
+ for ms_name, ms_xdt in ps_xdt.items():
94
+
95
+ ms_xdt = ms_xdt.xr_ms.sel(data_group_name=data_group_name)
96
+
97
+ if include_variables is not None:
98
+ for data_vars in ms_xdt.ds.data_vars:
99
+ if data_vars not in include_variables:
100
+ ms_xdt.ds = ms_xdt.ds.drop_vars(data_vars)
101
+ ps_xdt[ms_name] = ms_xdt
69
102
 
70
- xds[data_group_vals["correlated_data"]].attrs[
71
- "field_and_source_xds"
72
- ] = field_and_source_xds_dict[data_group_name]
103
+ if not load_sub_datasets:
104
+ for ms_xdt in ps_xdt.children.values():
105
+ ms_xdt_names = list(ms_xdt.keys())
106
+ for sub_xds_name in ms_xdt_names:
107
+ if "xds" in sub_xds_name:
108
+ del ms_xdt[sub_xds_name]
73
109
 
74
- ps[ms_name] = xds
110
+ ps_xdt = ps_xdt.load()
75
111
 
76
- return ps
112
+ return ps_xdt
77
113
 
78
114
 
79
115
  class ProcessingSetIterator:
@@ -81,8 +117,10 @@ class ProcessingSetIterator:
81
117
  self,
82
118
  sel_parms: dict,
83
119
  input_data_store: str,
84
- input_data: Union[Dict, ProcessingSet, None] = None,
85
- data_variables: list = None,
120
+ input_data: Union[Dict, xr.DataTree, None] = None,
121
+ data_group_name: str = None,
122
+ include_variables: Union[list, None] = None,
123
+ drop_variables: Union[list, None] = None,
86
124
  load_sub_datasets: bool = True,
87
125
  ):
88
126
  """An iterator that will go through a processing set one MS v4 at a time.
@@ -101,10 +139,16 @@ class ProcessingSetIterator:
101
139
  }
102
140
  input_data_store : str
103
141
  String of the path and name of the processing set. For example '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr'.
104
- input_data : Union[Dict, processing_set, None], optional
142
+ input_data : Union[Dict, xr.DataTree, None], optional
105
143
  If the processing set is in memory already it can be supplied here. By default None which will make the iterator load data using the supplied input_data_store.
106
- data_variables : list, optional
144
+ data_group_name : str, optional
145
+ The name of the data group to select. By default None, which loads all data groups.
146
+ data_group_name : str, optional
147
+ The name of the data group to select. By default None, which loads all data groups.
148
+ include_variables : Union[list, None], optional
107
149
  The list of data variables to load into memory for example ['VISIBILITY', 'WEIGHT, 'FLAGS']. By default None which will load all data variables into memory.
150
+ drop_variables : Union[list, None], optional
151
+ The list of data variables to drop from memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop any data variables from memory.
108
152
  load_sub_datasets : bool, optional
109
153
  If true sub-datasets (for example weather_xds, antenna_xds, pointing_xds, system_calibration_xds ...) will be loaded into memory, by default True.
110
154
  """
@@ -113,7 +157,9 @@ class ProcessingSetIterator:
113
157
  self.input_data_store = input_data_store
114
158
  self.sel_parms = sel_parms
115
159
  self.xds_name_iter = iter(sel_parms.keys())
116
- self.data_variables = data_variables
160
+ self.data_group_name = data_group_name
161
+ self.include_variables = include_variables
162
+ self.drop_variables = drop_variables
117
163
  self.load_sub_datasets = load_sub_datasets
118
164
 
119
165
  def __iter__(self):
@@ -121,20 +167,22 @@ class ProcessingSetIterator:
121
167
 
122
168
  def __next__(self):
123
169
  try:
124
- xds_name = next(self.xds_name_iter)
170
+ sub_xds_name = next(self.xds_name_iter)
125
171
  except Exception as e:
126
172
  raise StopIteration
127
173
 
128
174
  if self.input_data is None:
129
- slice_description = self.sel_parms[xds_name]
130
- ps = load_processing_set(
175
+ slice_description = self.sel_parms[sub_xds_name]
176
+ ps_xdt = load_processing_set(
131
177
  ps_store=self.input_data_store,
132
- sel_parms={xds_name: slice_description},
133
- data_variables=self.data_variables,
178
+ sel_parms={sub_xds_name: slice_description},
179
+ data_group_name=self.data_group_name,
180
+ include_variables=self.include_variables,
181
+ drop_variables=self.drop_variables,
134
182
  load_sub_datasets=self.load_sub_datasets,
135
183
  )
136
- xds = ps.get(0)
184
+ sub_xdt = ps_xdt.get(0)
137
185
  else:
138
- xds = self.input_data[xds_name] # In memory
186
+ sub_xdt = self.input_data[sub_xds_name] # In memory
139
187
 
140
- return xds
188
+ return sub_xdt
@@ -0,0 +1,197 @@
1
+ import pandas as pd
2
+ from xradio._utils.list_and_array import to_list
3
+ import xarray as xr
4
+ import numpy as np
5
+ import numbers
6
+ import os
7
+ from collections.abc import Mapping, Iterable
8
+ from typing import Any, Union
9
+
10
+ MS_DATASET_TYPES = {"visibility", "spectrum", "radiometer"}
11
+
12
+
13
+ class InvalidAccessorLocation(ValueError):
14
+ """
15
+ Raised by MeasurementSetXdt accessor functions called on a wrong DataTree node (not MSv4).
16
+ """
17
+
18
+ pass
19
+
20
+
21
+ @xr.register_datatree_accessor("xr_ms")
22
+ class MeasurementSetXdt:
23
+ """Accessor to the Measurement Set DataTree node. Provides MSv4 specific functionality
24
+ such as:
25
+
26
+ - get_partition_info(): produce an info dict with a general MSv4 description including
27
+ intents, SPW name, field and source names, etc.
28
+ - get_field_and_source_xds() to retrieve the field_and_source_xds for a given data
29
+ group.
30
+ - sel(): select data by dimension labels, for example by data group and polarization
31
+
32
+ """
33
+
34
+ _xdt: xr.DataTree
35
+
36
+ def __init__(self, datatree: xr.DataTree):
37
+ """
38
+ Initialize the MeasurementSetXdt instance.
39
+
40
+ Parameters
41
+ ----------
42
+ datatree: xarray.DataTree
43
+ The MSv4 DataTree node to construct a MeasurementSetXdt accessor.
44
+ """
45
+
46
+ self._xdt = datatree
47
+ self.meta = {"summary": {}}
48
+
49
+ def sel(
50
+ self,
51
+ indexers: Union[Mapping[Any, Any], None] = None,
52
+ method: Union[str, None] = None,
53
+ tolerance: Union[int, float, Iterable[Union[int, float]], None] = None,
54
+ drop: bool = False,
55
+ **indexers_kwargs: Any,
56
+ ) -> xr.DataTree:
57
+ """
58
+ Select data along dimension(s) by label. Alternative to `xarray.Dataset.sel <https://xarray.pydata.org/en/stable/generated/xarray.Dataset.sel.html>`__ so that a data group can be selected by name by using the `data_group_name` parameter.
59
+ For more information on data groups see `Data Groups <https://xradio.readthedocs.io/en/latest/measurement_set_overview.html#Data-Groups>`__ section. See `xarray.Dataset.sel <https://xarray.pydata.org/en/stable/generated/xarray.Dataset.sel.html>`__ for parameter descriptions.
60
+
61
+ Returns
62
+ -------
63
+ xarray.DataTree
64
+ xarray DataTree with MeasurementSetXdt accessors
65
+
66
+ Examples
67
+ --------
68
+ >>> # Select data group 'corrected' and polarization 'XX'.
69
+ >>> selected_ms_xdt = ms_xdt.xr_ms.sel(data_group_name='corrected', polarization='XX')
70
+
71
+ >>> # Select data group 'corrected' and polarization 'XX' using a dict.
72
+ >>> selected_ms_xdt = ms_xdt.xr_ms.sel({'data_group_name':'corrected', 'polarization':'XX'})
73
+ """
74
+
75
+ if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
76
+ raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")
77
+
78
+ assert self._xdt.attrs["type"] in [
79
+ "visibility",
80
+ "spectrum",
81
+ "radiometer",
82
+ ], "The type of the xdt must be 'visibility', 'spectrum' or 'radiometer'."
83
+
84
+ if "data_group_name" in indexers_kwargs:
85
+ data_group_name = indexers_kwargs["data_group_name"]
86
+ del indexers_kwargs["data_group_name"]
87
+ elif (indexers is not None) and ("data_group_name" in indexers):
88
+ data_group_name = indexers["data_group_name"]
89
+ del indexers["data_group_name"]
90
+ else:
91
+ data_group_name = None
92
+
93
+ if data_group_name is not None:
94
+ sel_data_group_set = set(
95
+ self._xdt.attrs["data_groups"][data_group_name].values()
96
+ )
97
+
98
+ data_variables_to_drop = []
99
+ for dg in self._xdt.attrs["data_groups"].values():
100
+ temp_set = set(dg.values()) - sel_data_group_set
101
+ data_variables_to_drop.extend(list(temp_set))
102
+
103
+ data_variables_to_drop = list(set(data_variables_to_drop))
104
+
105
+ sel_ms_xdt = self._xdt
106
+
107
+ sel_corr_xds = self._xdt.ds.sel(
108
+ indexers, method, tolerance, drop, **indexers_kwargs
109
+ ).drop_vars(data_variables_to_drop)
110
+
111
+ sel_ms_xdt.ds = sel_corr_xds
112
+
113
+ sel_ms_xdt.attrs["data_groups"] = {
114
+ data_group_name: self._xdt.attrs["data_groups"][data_group_name]
115
+ }
116
+
117
+ return sel_ms_xdt
118
+ else:
119
+ return self._xdt.sel(indexers, method, tolerance, drop, **indexers_kwargs)
120
+
121
+ def get_field_and_source_xds(self, data_group_name: str = None) -> xr.Dataset:
122
+ """Get the field_and_source_xds associated with data group `data_group_name`.
123
+
124
+ Parameters
125
+ ----------
126
+ data_group_name : str, optional
127
+ The data group to process. Default is "base" or if not found to first data group.
128
+
129
+ Returns
130
+ -------
131
+ xarray.Dataset
132
+ field_and_source_xds associated with the data group.
133
+ """
134
+ if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
135
+ raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")
136
+
137
+ if data_group_name is None:
138
+ if "base" in self._xdt.attrs["data_groups"].keys():
139
+ data_group_name = "base"
140
+ else:
141
+ data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
142
+
143
+ return self._xdt[f"field_and_source_xds_{data_group_name}"].ds
144
+
145
+ def get_partition_info(self, data_group_name: str = None) -> dict:
146
+ """
147
+ Generate a partition info dict for an MSv4, with general MSv4 description including
148
+ information such as field and source names, SPW name, scan name, the intents string,
149
+ etc.
150
+
151
+ The information is gathered from various coordinates, secondary datasets, and info
152
+ dicts of the MSv4. For example, the SPW name comes from the attributes of the
153
+ frequency coordinate, whereas field and source related information such as field and
154
+ source names come from the field_and_source_xds (base) dataset of the MSv4.
155
+
156
+ Parameters
157
+ ----------
158
+ data_group_name : str, optional
159
+ The data group to process. Default is "base" or if not found to first data group.
160
+
161
+ Returns
162
+ -------
163
+ dict
164
+ Partition info dict for the MSv4
165
+ """
166
+ if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
167
+ raise InvalidAccessorLocation(
168
f"{self._xdt.path} is not a MSv4 node (type {self._xdt.attrs.get('type')})."
169
+ )
170
+
171
+ if data_group_name is None:
172
+ if "base" in self._xdt.attrs["data_groups"].keys():
173
+ data_group_name = "base"
174
+ else:
175
+ data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
176
+
177
+ field_and_source_xds = self._xdt.xr_ms.get_field_and_source_xds(data_group_name)
178
+
179
+ if "line_name" in field_and_source_xds.coords:
180
+ line_name = to_list(
181
+ np.unique(np.ravel(field_and_source_xds.line_name.values))
182
+ )
183
+ else:
184
+ line_name = []
185
+
186
+ partition_info = {
187
+ "spectral_window_name": self._xdt.frequency.attrs["spectral_window_name"],
188
+ "field_name": to_list(np.unique(field_and_source_xds.field_name.values)),
189
+ "polarization_setup": to_list(self._xdt.polarization.values),
190
+ "scan_name": to_list(np.unique(self._xdt.scan_name.values)),
191
+ "source_name": to_list(np.unique(field_and_source_xds.source_name.values)),
192
+ "intents": self._xdt.observation_info["intents"],
193
+ "line_name": line_name,
194
+ "data_group_name": data_group_name,
195
+ }
196
+
197
+ return partition_info