xradio 0.0.48__py3-none-any.whl → 0.0.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. xradio/__init__.py +1 -0
  2. xradio/_utils/dict_helpers.py +69 -2
  3. xradio/image/_util/__init__.py +0 -3
  4. xradio/image/_util/_casacore/common.py +0 -13
  5. xradio/image/_util/_casacore/xds_from_casacore.py +102 -97
  6. xradio/image/_util/_casacore/xds_to_casacore.py +36 -24
  7. xradio/image/_util/_fits/xds_from_fits.py +81 -36
  8. xradio/image/_util/_zarr/zarr_low_level.py +3 -3
  9. xradio/image/_util/casacore.py +7 -5
  10. xradio/image/_util/common.py +13 -26
  11. xradio/image/_util/image_factory.py +143 -191
  12. xradio/image/image.py +10 -59
  13. xradio/measurement_set/__init__.py +11 -6
  14. xradio/measurement_set/_utils/_msv2/_tables/read.py +187 -46
  15. xradio/measurement_set/_utils/_msv2/_tables/table_query.py +22 -0
  16. xradio/measurement_set/_utils/_msv2/conversion.py +351 -318
  17. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +20 -17
  18. xradio/measurement_set/convert_msv2_to_processing_set.py +46 -6
  19. xradio/measurement_set/load_processing_set.py +100 -53
  20. xradio/measurement_set/measurement_set_xdt.py +197 -0
  21. xradio/measurement_set/open_processing_set.py +122 -86
  22. xradio/measurement_set/processing_set_xdt.py +1552 -0
  23. xradio/measurement_set/schema.py +199 -94
  24. xradio/schema/bases.py +5 -1
  25. xradio/schema/check.py +97 -5
  26. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/METADATA +4 -4
  27. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/RECORD +30 -30
  28. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/WHEEL +1 -1
  29. xradio/measurement_set/measurement_set_xds.py +0 -117
  30. xradio/measurement_set/processing_set.py +0 -803
  31. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info/licenses}/LICENSE.txt +0 -0
  32. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/top_level.txt +0 -0
@@ -49,28 +49,31 @@ def create_info_dicts(
49
49
  line_name = []
50
50
 
51
51
  info_dicts = {}
52
- info_dicts["partition_info"] = {
53
- # "spectral_window_id": xds.frequency.attrs["spectral_window_id"],
54
- "spectral_window_name": xds.frequency.attrs["spectral_window_name"],
55
- # "field_id": to_list(unique_1d(field_id)),
56
- "field_name": to_list(np.unique(field_and_source_xds.field_name.values)),
57
- "polarization_setup": to_list(xds.polarization.values),
58
- "scan_name": to_list(np.unique(partition_info_misc_fields["scan_name"])),
59
- "source_name": to_list(np.unique(field_and_source_xds.source_name.values)),
60
- # "source_id": to_list(unique_1d(source_id)),
61
- "intents": partition_info_misc_fields["intents"].split(","),
62
- "taql": partition_info_misc_fields["taql_where"],
63
- "line_name": line_name,
64
- }
65
- if "antenna_name" in partition_info_misc_fields:
66
- info_dicts["partition_info"]["antenna_name"] = partition_info_misc_fields[
67
- "antenna_name"
68
- ]
52
+ # info_dicts["partition_info"] = {
53
+ # # "spectral_window_id": xds.frequency.attrs["spectral_window_id"],
54
+ # "spectral_window_name": xds.frequency.attrs["spectral_window_name"],
55
+ # # "field_id": to_list(unique_1d(field_id)),
56
+ # "field_name": to_list(np.unique(field_and_source_xds.field_name.values)),
57
+ # "polarization_setup": to_list(xds.polarization.values),
58
+ # "scan_name": to_list(np.unique(partition_info_misc_fields["scan_name"])),
59
+ # "source_name": to_list(np.unique(field_and_source_xds.source_name.values)),
60
+ # # "source_id": to_list(unique_1d(source_id)),
61
+ # "intents": partition_info_misc_fields["intents"].split(","),
62
+ # "taql": partition_info_misc_fields["taql_where"],
63
+ # "line_name": line_name,
64
+ # }
65
+ # if "antenna_name" in partition_info_misc_fields:
66
+ # info_dicts["partition_info"]["antenna_name"] = partition_info_misc_fields[
67
+ # "antenna_name"
68
+ # ]
69
69
 
70
70
  observation_id = check_if_consistent(
71
71
  tb_tool.getcol("OBSERVATION_ID"), "OBSERVATION_ID"
72
72
  )
73
73
  info_dicts["observation_info"] = create_observation_info(in_file, observation_id)
74
+ info_dicts["observation_info"]["intents"] = partition_info_misc_fields[
75
+ "intents"
76
+ ].split(",")
74
77
 
75
78
  processor_id = check_if_consistent(tb_tool.getcol("PROCESSOR_ID"), "PROCESSOR_ID")
76
79
  info_dicts["processor_info"] = create_processor_info(in_file, processor_id)
@@ -18,6 +18,7 @@ def estimate_conversion_memory_and_cores(
18
18
  """
19
19
  Given an MSv2 and a partition_scheme to use when converting it to MSv4,
20
20
  estimates:
21
+
21
22
  - memory (in the sense of the amount expected to be enough to convert)
22
23
  - cores (in the sense of the recommended/optimal number of cores to use to convert)
23
24
 
@@ -36,7 +37,7 @@ def estimate_conversion_memory_and_cores(
36
37
  Partition scheme as used in the function convert_msv2_to_processing_set()
37
38
 
38
39
  Returns
39
- ----------
40
+ -------
40
41
  tuple
41
42
  estimated maximum memory required for one partition,
42
43
  maximum number of cores it makes sense to use (number of partitions),
@@ -62,7 +63,7 @@ def convert_msv2_to_processing_set(
62
63
  use_table_iter: bool = False,
63
64
  compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2),
64
65
  storage_backend: str = "zarr",
65
- parallel: bool = False,
66
+ parallel_mode: str = "none",
66
67
  overwrite: bool = False,
67
68
  ):
68
69
  """Convert a Measurement Set v2 into a Processing Set of Measurement Set v4.
@@ -99,14 +100,45 @@ def convert_msv2_to_processing_set(
99
100
  The Blosc compressor to use when saving the converted data to disk using Zarr, by default numcodecs.Zstd(level=2).
100
101
  storage_backend : {"zarr", "netcdf"}, optional
101
102
  The on-disk format to use. "netcdf" is not yet implemented.
102
- parallel : bool, optional
103
- Makes use of Dask to execute conversion in parallel, by default False.
103
+ parallel_mode : {"none", "partition", "time"}, optional
104
+ Choose whether to use Dask to execute conversion in parallel, by default "none" and conversion occurs serially.
105
+ The option "partition", parallelises the conversion over partitions specified by `partition_scheme`. The option "time" can only be used for phased array interferometers where there are no partitions
106
+ in the MS v2; instead the MS v2 is parallelised along the time dimension and can be controlled by `main_chunksize`.
104
107
  overwrite : bool, optional
105
108
  Whether to overwrite an existing processing set, by default False.
106
109
  """
107
110
 
111
+ # Create empty data tree
112
+ import xarray as xr
113
+
114
+ ps_dt = xr.DataTree()
115
+
116
+ if not str(out_file).endswith("ps.zarr"):
117
+ out_file += ".ps.zarr"
118
+
119
+ print("Output file: ", out_file)
120
+
121
+ if overwrite:
122
+ ps_dt.to_zarr(store=out_file, mode="w")
123
+ else:
124
+ ps_dt.to_zarr(store=out_file, mode="w-")
125
+
126
+ # Check `parallel_mode` is valid
127
+ try:
128
+ assert parallel_mode in ["none", "partition", "time"]
129
+ except AssertionError:
130
+ logger.warning(
131
+ f"`parallel_mode` {parallel_mode} not recognized. Defaulting to 'none'."
132
+ )
133
+ parallel_mode = "none"
134
+
108
135
  partitions = create_partitions(in_file, partition_scheme=partition_scheme)
109
136
  logger.info("Number of partitions: " + str(len(partitions)))
137
+ if parallel_mode == "time":
138
+ assert (
139
+ len(partitions) == 1
140
+ ), "MS v2 contains more than one partition. `parallel_mode = 'time'` not valid."
141
+
110
142
  delayed_list = []
111
143
 
112
144
  for ms_v4_id, partition_info in enumerate(partitions):
@@ -132,7 +164,7 @@ def convert_msv2_to_processing_set(
132
164
 
133
165
  # prepend '0' to ms_v4_id as needed
134
166
  ms_v4_id = f"{ms_v4_id:0>{len(str(len(partitions) - 1))}}"
135
- if parallel:
167
+ if parallel_mode == "partition":
136
168
  delayed_list.append(
137
169
  dask.delayed(convert_and_write_partition)(
138
170
  in_file,
@@ -149,6 +181,7 @@ def convert_msv2_to_processing_set(
149
181
  phase_cal_interpolate=phase_cal_interpolate,
150
182
  sys_cal_interpolate=sys_cal_interpolate,
151
183
  compressor=compressor,
184
+ parallel_mode=parallel_mode,
152
185
  overwrite=overwrite,
153
186
  )
154
187
  )
@@ -168,8 +201,15 @@ def convert_msv2_to_processing_set(
168
201
  phase_cal_interpolate=phase_cal_interpolate,
169
202
  sys_cal_interpolate=sys_cal_interpolate,
170
203
  compressor=compressor,
204
+ parallel_mode=parallel_mode,
171
205
  overwrite=overwrite,
172
206
  )
173
207
 
174
- if parallel:
208
+ if parallel_mode == "partition":
175
209
  dask.compute(delayed_list)
210
+
211
+ import zarr
212
+
213
+ root_group = zarr.open(out_file, mode="r+") # Open in read/write mode
214
+ root_group.attrs["type"] = "processing_set" # Replace
215
+ zarr.convenience.consolidate_metadata(root_group.store)
@@ -1,80 +1,115 @@
1
1
  import os
2
- from xradio.measurement_set import ProcessingSet
3
2
  from typing import Dict, Union
3
+ import dask
4
+ import xarray as xr
5
+ import s3fs
4
6
 
5
7
 
6
8
def load_processing_set(
    ps_store: str,
    sel_parms: dict = None,
    data_group_name: str = None,
    include_variables: Union[list, None] = None,
    drop_variables: Union[list, None] = None,
    load_sub_datasets: bool = True,
) -> xr.DataTree:
    """Loads a processing set into memory.

    Parameters
    ----------
    ps_store : str
        String of the path and name of the processing set. For example
        '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr' for a file stored on a
        local file system, or 's3://viper-test-data/Antennae_North.cal.lsrk.split.vis.zarr/'
        for a file in AWS object storage.
    sel_parms : dict, optional
        A dictionary where the keys are the names of the ms_xdt's (measurement set xarray
        data trees) and the values are slice_dicts.
        slice_dicts: A dictionary where the keys are the dimension names and the values are
        slices.

        For example::

            {
                'ms_v4_name_1': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
                ...
                'ms_v4_name_n': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
            }

        By default None, which loads all ms_xdts.
    data_group_name : str, optional
        The name of the data group to select. By default None, which loads all data groups.
    include_variables : Union[list, None], optional
        The list of data variables to load into memory, for example
        ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load all data
        variables into memory.
    drop_variables : Union[list, None], optional
        The list of data variables to drop from memory, for example
        ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop any data
        variables from memory.
    load_sub_datasets : bool, optional
        If true sub-datasets (for example weather_xds, antenna_xds, pointing_xds,
        system_calibration_xds ...) will be loaded into memory, by default True.

    Returns
    -------
    xarray.DataTree
        In memory representation of processing set using xr.DataTree.
    """
    from xradio._utils.zarr.common import _get_file_system_and_items

    file_system, ms_store_list = _get_file_system_and_items(ps_store)

    # Force the serial scheduler: critical so that this function can itself be used
    # inside dask.delayed functions without nesting schedulers.
    with dask.config.set(scheduler="synchronous"):
        ps_xdt = xr.DataTree()

        if sel_parms:
            # Load only the named MSv4s, each with its own (optional) isel slicing.
            for ms_name, ms_xds_isel in sel_parms.items():
                ms_store = os.path.join(ps_store, ms_name)

                if isinstance(file_system, s3fs.core.S3FileSystem):
                    # Map the individual MS store, not the whole processing set.
                    # NOTE(review): was `root=ps_store`, which mapped the entire
                    # processing set for every entry of sel_parms — confirm intent.
                    ms_store = s3fs.S3Map(root=ms_store, s3=file_system, check=False)

                ms_xdt = xr.open_datatree(
                    ms_store, engine="zarr", drop_variables=drop_variables
                )
                if ms_xds_isel:
                    ms_xdt = ms_xdt.isel(ms_xds_isel)
                ms_xdt = ms_xdt.xr_ms.sel(data_group_name=data_group_name)

                if include_variables is not None:
                    # Materialize the name list first: drop_vars rebinds ms_xdt.ds
                    # while we iterate.
                    for var_name in list(ms_xdt.ds.data_vars):
                        if var_name not in include_variables:
                            ms_xdt.ds = ms_xdt.ds.drop_vars(var_name)

                ps_xdt[ms_name] = ms_xdt

            ps_xdt.attrs["type"] = "processing_set"
        else:
            # No selection: open the whole processing set in one go.
            ps_xdt = xr.open_datatree(
                ps_store, engine="zarr", drop_variables=drop_variables
            )

            if (include_variables is not None) or data_group_name:
                for ms_name, ms_xdt in ps_xdt.items():

                    ms_xdt = ms_xdt.xr_ms.sel(data_group_name=data_group_name)

                    if include_variables is not None:
                        for var_name in list(ms_xdt.ds.data_vars):
                            if var_name not in include_variables:
                                ms_xdt.ds = ms_xdt.ds.drop_vars(var_name)
                    ps_xdt[ms_name] = ms_xdt

        if not load_sub_datasets:
            # Strip every sub-dataset (child node whose name contains 'xds')
            # from each MSv4 node before loading.
            for ms_xdt in ps_xdt.children.values():
                for sub_xds_name in list(ms_xdt.keys()):
                    if "xds" in sub_xds_name:
                        del ms_xdt[sub_xds_name]

        ps_xdt = ps_xdt.load()

    return ps_xdt
78
113
 
79
114
 
80
115
  class ProcessingSetIterator:
@@ -82,8 +117,10 @@ class ProcessingSetIterator:
82
117
  self,
83
118
  sel_parms: dict,
84
119
  input_data_store: str,
85
- input_data: Union[Dict, ProcessingSet, None] = None,
86
- data_variables: list = None,
120
+ input_data: Union[Dict, xr.DataTree, None] = None,
121
+ data_group_name: str = None,
122
+ include_variables: Union[list, None] = None,
123
+ drop_variables: Union[list, None] = None,
87
124
  load_sub_datasets: bool = True,
88
125
  ):
89
126
  """An iterator that will go through a processing set one MS v4 at a time.
@@ -102,10 +139,16 @@ class ProcessingSetIterator:
102
139
  }
103
140
  input_data_store : str
104
141
  String of the path and name of the processing set. For example '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr'.
105
- input_data : Union[Dict, processing_set, None], optional
142
+ input_data : Union[Dict, xr.DataTree, None], optional
106
143
  If the processing set is in memory already it can be supplied here. By default None which will make the iterator load data using the supplied input_data_store.
107
- data_variables : list, optional
144
+ data_group_name : str, optional
145
+ The name of the data group to select. By default None, which loads all data groups.
146
+ data_group_name : str, optional
147
+ The name of the data group to select. By default None, which loads all data groups.
148
+ include_variables : Union[list, None], optional
108
149
  The list of data variables to load into memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load all data variables into memory.
150
+ drop_variables : Union[list, None], optional
151
+ The list of data variables to drop from memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop any data variables from memory.
109
152
  load_sub_datasets : bool, optional
110
153
  If true sub-datasets (for example weather_xds, antenna_xds, pointing_xds, system_calibration_xds ...) will be loaded into memory, by default True.
111
154
  """
@@ -114,7 +157,9 @@ class ProcessingSetIterator:
114
157
  self.input_data_store = input_data_store
115
158
  self.sel_parms = sel_parms
116
159
  self.xds_name_iter = iter(sel_parms.keys())
117
- self.data_variables = data_variables
160
+ self.data_group_name = data_group_name
161
+ self.include_variables = include_variables
162
+ self.drop_variables = drop_variables
118
163
  self.load_sub_datasets = load_sub_datasets
119
164
 
120
165
  def __iter__(self):
@@ -122,20 +167,22 @@ class ProcessingSetIterator:
122
167
 
123
168
  def __next__(self):
124
169
  try:
125
- xds_name = next(self.xds_name_iter)
170
+ sub_xds_name = next(self.xds_name_iter)
126
171
  except Exception as e:
127
172
  raise StopIteration
128
173
 
129
174
  if self.input_data is None:
130
- slice_description = self.sel_parms[xds_name]
131
- ps = load_processing_set(
175
+ slice_description = self.sel_parms[sub_xds_name]
176
+ ps_xdt = load_processing_set(
132
177
  ps_store=self.input_data_store,
133
- sel_parms={xds_name: slice_description},
134
- data_variables=self.data_variables,
178
+ sel_parms={sub_xds_name: slice_description},
179
+ data_group_name=self.data_group_name,
180
+ include_variables=self.include_variables,
181
+ drop_variables=self.drop_variables,
135
182
  load_sub_datasets=self.load_sub_datasets,
136
183
  )
137
- xds = ps.get(0)
184
+ sub_xdt = ps_xdt.get(0)
138
185
  else:
139
- xds = self.input_data[xds_name] # In memory
186
+ sub_xdt = self.input_data[sub_xds_name] # In memory
140
187
 
141
- return xds
188
+ return sub_xdt
# xradio/measurement_set/measurement_set_xdt.py — new file in 0.0.49 (+197 lines),
# reconstructed from the diff with review fixes applied.
import pandas as pd
from xradio._utils.list_and_array import to_list
import xarray as xr
import numpy as np
import numbers
import os
from collections.abc import Mapping, Iterable
from typing import Any, Union

# Dataset types whose presence in attrs["type"] marks a DataTree node as an MSv4 node.
MS_DATASET_TYPES = {"visibility", "spectrum", "radiometer"}


class InvalidAccessorLocation(ValueError):
    """
    Raised by MeasurementSetXdt accessor functions called on a wrong DataTree node (not MSv4).
    """

    pass


@xr.register_datatree_accessor("xr_ms")
class MeasurementSetXdt:
    """Accessor to the Measurement Set DataTree node. Provides MSv4 specific functionality
    such as:

    - get_partition_info(): produce an info dict with a general MSv4 description including
      intents, SPW name, field and source names, etc.
    - get_field_and_source_xds() to retrieve the field_and_source_xds for a given data
      group.
    - sel(): select data by dimension labels, for example by data group and polarization.

    """

    _xdt: xr.DataTree

    def __init__(self, datatree: xr.DataTree):
        """
        Initialize the MeasurementSetXdt instance.

        Parameters
        ----------
        datatree: xarray.DataTree
            The MSv4 DataTree node to construct a MeasurementSetXdt accessor.
        """

        self._xdt = datatree
        self.meta = {"summary": {}}

    def sel(
        self,
        indexers: Union[Mapping[Any, Any], None] = None,
        method: Union[str, None] = None,
        tolerance: Union[int, float, Iterable[Union[int, float]], None] = None,
        drop: bool = False,
        **indexers_kwargs: Any,
    ) -> xr.DataTree:
        """
        Select data along dimension(s) by label. Alternative to `xarray.Dataset.sel <https://xarray.pydata.org/en/stable/generated/xarray.Dataset.sel.html>`__ so that a data group can be selected by name by using the `data_group_name` parameter.
        For more information on data groups see `Data Groups <https://xradio.readthedocs.io/en/latest/measurement_set_overview.html#Data-Groups>`__ section. See `xarray.Dataset.sel <https://xarray.pydata.org/en/stable/generated/xarray.Dataset.sel.html>`__ for parameter descriptions.

        Returns
        -------
        xarray.DataTree
            xarray DataTree with MeasurementSetXdt accessors

        Examples
        --------
        >>> # Select data group 'corrected' and polarization 'XX'.
        >>> selected_ms_xdt = ms_xdt.xr_ms.sel(data_group_name='corrected', polarization='XX')

        >>> # Select data group 'corrected' and polarization 'XX' using a dict.
        >>> selected_ms_xdt = ms_xdt.xr_ms.sel({'data_group_name': 'corrected', 'polarization': 'XX'})
        """

        # A single membership check is sufficient; the previous duplicated `assert`
        # (stripped under -O) has been removed.
        if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
            raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")

        # Extract data_group_name without mutating the caller's mapping.
        if "data_group_name" in indexers_kwargs:
            data_group_name = indexers_kwargs.pop("data_group_name")
        elif (indexers is not None) and ("data_group_name" in indexers):
            indexers = dict(indexers)  # copy: `del` on the caller's dict was a side effect
            data_group_name = indexers.pop("data_group_name")
        else:
            data_group_name = None

        if data_group_name is not None:
            # Variables belonging to the selected data group are kept; variables
            # that belong only to other data groups are dropped.
            sel_data_group_set = set(
                self._xdt.attrs["data_groups"][data_group_name].values()
            )

            data_variables_to_drop = []
            for dg in self._xdt.attrs["data_groups"].values():
                data_variables_to_drop.extend(set(dg.values()) - sel_data_group_set)
            data_variables_to_drop = list(set(data_variables_to_drop))

            # NOTE(review): sel_ms_xdt aliases self._xdt, so the selection is applied
            # to the accessed tree in place — presumably intentional since callers
            # rebind the result; confirm.
            sel_ms_xdt = self._xdt

            sel_corr_xds = self._xdt.ds.sel(
                indexers, method, tolerance, drop, **indexers_kwargs
            ).drop_vars(data_variables_to_drop)

            sel_ms_xdt.ds = sel_corr_xds

            sel_ms_xdt.attrs["data_groups"] = {
                data_group_name: self._xdt.attrs["data_groups"][data_group_name]
            }

            return sel_ms_xdt
        else:
            return self._xdt.sel(indexers, method, tolerance, drop, **indexers_kwargs)

    def get_field_and_source_xds(self, data_group_name: str = None) -> xr.Dataset:
        """Get the field_and_source_xds associated with data group `data_group_name`.

        Parameters
        ----------
        data_group_name : str, optional
            The data group to process. Default is "base" or if not found to first data group.

        Returns
        -------
        xarray.Dataset
            field_and_source_xds associated with the data group.
        """
        if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
            raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")

        if data_group_name is None:
            # Prefer the "base" data group; otherwise fall back to the first one.
            if "base" in self._xdt.attrs["data_groups"].keys():
                data_group_name = "base"
            else:
                data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]

        return self._xdt[f"field_and_source_xds_{data_group_name}"].ds

    def get_partition_info(self, data_group_name: str = None) -> dict:
        """
        Generate a partition info dict for an MSv4, with general MSv4 description including
        information such as field and source names, SPW name, scan name, the intents string,
        etc.

        The information is gathered from various coordinates, secondary datasets, and info
        dicts of the MSv4. For example, the SPW name comes from the attributes of the
        frequency coordinate, whereas field and source related information such as field and
        source names come from the field_and_source_xds (base) dataset of the MSv4.

        Parameters
        ----------
        data_group_name : str, optional
            The data group to process. Default is "base" or if not found to first data group.

        Returns
        -------
        dict
            Partition info dict for the MSv4
        """
        if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
            # Fixed: the f-string was missing the closing parenthesis after the type.
            raise InvalidAccessorLocation(
                f"{self._xdt.path} is not a MSv4 node (type {self._xdt.attrs.get('type')})."
            )

        if data_group_name is None:
            if "base" in self._xdt.attrs["data_groups"].keys():
                data_group_name = "base"
            else:
                data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]

        field_and_source_xds = self._xdt.xr_ms.get_field_and_source_xds(data_group_name)

        if "line_name" in field_and_source_xds.coords:
            line_name = to_list(
                np.unique(np.ravel(field_and_source_xds.line_name.values))
            )
        else:
            line_name = []

        partition_info = {
            "spectral_window_name": self._xdt.frequency.attrs["spectral_window_name"],
            "field_name": to_list(np.unique(field_and_source_xds.field_name.values)),
            "polarization_setup": to_list(self._xdt.polarization.values),
            "scan_name": to_list(np.unique(self._xdt.scan_name.values)),
            "source_name": to_list(np.unique(field_and_source_xds.source_name.values)),
            "intents": self._xdt.observation_info["intents"],
            "line_name": line_name,
            "data_group_name": data_group_name,
        }

        return partition_info