xradio 1.0.2__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. xradio/_utils/_casacore/casacore_from_casatools.py +1 -1
  2. xradio/_utils/dict_helpers.py +38 -7
  3. xradio/_utils/list_and_array.py +26 -3
  4. xradio/_utils/schema.py +44 -0
  5. xradio/_utils/xarray_helpers.py +63 -0
  6. xradio/_utils/zarr/common.py +4 -2
  7. xradio/image/__init__.py +4 -2
  8. xradio/image/_util/_casacore/common.py +2 -1
  9. xradio/image/_util/_casacore/xds_from_casacore.py +105 -51
  10. xradio/image/_util/_casacore/xds_to_casacore.py +117 -52
  11. xradio/image/_util/_fits/xds_from_fits.py +124 -36
  12. xradio/image/_util/_zarr/common.py +0 -1
  13. xradio/image/_util/casacore.py +133 -16
  14. xradio/image/_util/common.py +6 -5
  15. xradio/image/_util/image_factory.py +466 -27
  16. xradio/image/image.py +72 -100
  17. xradio/image/image_xds.py +262 -0
  18. xradio/image/schema.py +85 -0
  19. xradio/measurement_set/__init__.py +5 -4
  20. xradio/measurement_set/_utils/_msv2/_tables/read.py +7 -3
  21. xradio/measurement_set/_utils/_msv2/conversion.py +6 -9
  22. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +1 -0
  23. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +1 -1
  24. xradio/measurement_set/_utils/_utils/interpolate.py +5 -0
  25. xradio/measurement_set/_utils/_utils/partition_attrs.py +0 -1
  26. xradio/measurement_set/convert_msv2_to_processing_set.py +9 -9
  27. xradio/measurement_set/load_processing_set.py +2 -2
  28. xradio/measurement_set/measurement_set_xdt.py +83 -93
  29. xradio/measurement_set/open_processing_set.py +1 -1
  30. xradio/measurement_set/processing_set_xdt.py +33 -26
  31. xradio/schema/check.py +70 -19
  32. xradio/schema/common.py +0 -1
  33. xradio/testing/__init__.py +0 -0
  34. xradio/testing/_utils/__template__.py +58 -0
  35. xradio/testing/measurement_set/__init__.py +58 -0
  36. xradio/testing/measurement_set/checker.py +131 -0
  37. xradio/testing/measurement_set/io.py +22 -0
  38. xradio/testing/measurement_set/msv2_io.py +1854 -0
  39. {xradio-1.0.2.dist-info → xradio-1.1.0.dist-info}/METADATA +64 -23
  40. xradio-1.1.0.dist-info/RECORD +75 -0
  41. {xradio-1.0.2.dist-info → xradio-1.1.0.dist-info}/WHEEL +1 -1
  42. xradio-1.0.2.dist-info/RECORD +0 -66
  43. {xradio-1.0.2.dist-info → xradio-1.1.0.dist-info}/licenses/LICENSE.txt +0 -0
  44. {xradio-1.0.2.dist-info → xradio-1.1.0.dist-info}/top_level.txt +0 -0
xradio/measurement_set/__init__.py

@@ -4,7 +4,7 @@ convert, and retrieve information from Processing Set and Measurement Sets nodes
 Processing Set DataTree
 """
 
-import toolviper.utils.logger as _logger
+import warnings
 
 from .processing_set_xdt import ProcessingSetXdt
 from .open_processing_set import open_processing_set
@@ -27,9 +27,10 @@ try:
         estimate_conversion_memory_and_cores,
     )
 except ModuleNotFoundError as exc:
-    _logger.warning(
-        "Could not import the function to convert from MSv2 to MSv4. "
-        f"That functionality will not be available. Details: {exc}"
+    warnings.warn(
+        f"Could not import the function to convert from MSv2 to MSv4. "
+        f"That functionality will not be available. Details: {exc}",
+        UserWarning,
     )
 else:
     __all__.extend(
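With this change the import-time notice is a plain `UserWarning` rather than a toolviper log record, so consumers can manage it with Python's standard warning filters. A minimal sketch using only the standard library (the message text is taken from the hunk above):

```python
import warnings

# Suppress the optional-dependency warning emitted when the MSv2 conversion
# machinery cannot be imported (e.g. python-casacore missing).
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="Could not import the function to convert from MSv2 to MSv4",
        category=UserWarning,
    )
    import xradio.measurement_set  # the warning, if any, fires at import time
```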
xradio/measurement_set/_utils/_msv2/_tables/read.py

@@ -93,9 +93,12 @@ def convert_mjd_time(rawtimes: np.ndarray) -> np.ndarray:
     np.ndarray
         times converted to pandas reference and datetime type
     """
+
+    print("^^^^^^^", rawtimes, MJD_DIF_UNIX, SECS_IN_DAY)
     times_reref = pd.to_datetime(
         (rawtimes - MJD_DIF_UNIX) * SECS_IN_DAY, unit="s"
     ).values
+    print("^^^^^^^", times_reref)
 
     return times_reref
 
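This hunk adds two `print` calls around the MJD conversion; they appear to be debug output that shipped in the released wheel. The conversion itself re-references Modified Julian Date day counts to the Unix epoch. A standalone sketch; the values of `MJD_DIF_UNIX` and `SECS_IN_DAY` are assumed here from the standard definitions (MJD 40587 is 1970-01-01), not read from the xradio source:

```python
import numpy as np
import pandas as pd

MJD_DIF_UNIX = 40587  # days between the MJD epoch (1858-11-17) and the Unix epoch
SECS_IN_DAY = 86400

rawtimes = np.array([59580.5])  # MJD in days; 59580.5 is 2022-01-01T12:00:00
times_reref = pd.to_datetime((rawtimes - MJD_DIF_UNIX) * SECS_IN_DAY, unit="s").values
print(times_reref)  # ['2022-01-01T12:00:00.000000000'] as datetime64[ns]
```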
@@ -160,7 +163,7 @@ def make_taql_where_between_min_max(
     if min_max_range is None:
         taql = None
     else:
-        (min_val, max_val) = min_max_range
+        min_val, max_val = min_max_range
         taql = f"where {colname} >= {min_val} AND {colname} <= {max_val}"
 
     return taql
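The parenthesized tuple unpacking is simplified; behavior is unchanged. For reference, a condensed sketch of the logic visible in this hunk (the real function may take additional arguments not shown here):

```python
def make_taql_where_between_min_max(min_max_range, colname):
    # Condensed from the hunk above: None means "no constraint".
    if min_max_range is None:
        return None
    min_val, max_val = min_max_range
    return f"where {colname} >= {min_val} AND {colname} <= {max_val}"

# e.g. make_taql_where_between_min_max((100.0, 200.0), "TIME")
# -> 'where TIME >= 100.0 AND TIME <= 200.0'
```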
@@ -226,7 +229,7 @@ def find_projected_min_max_array(
     """Does the min/max checks and search for find_projected_min_max_table()"""
 
     sorted_array = np.sort(array)
-    (range_min, range_max) = min_max
+    range_min, range_max = min_max
     if len(sorted_array) < 2:
         tol = np.finfo(sorted_array.dtype).eps * 4
     else:
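Same cosmetic unpacking change here. Worth noting from the context lines: when the array has fewer than two samples, the search tolerance falls back to a few machine epsilons of the array's dtype:

```python
import numpy as np

arr = np.array([1.0])              # single-element array
tol = np.finfo(arr.dtype).eps * 4  # ~8.9e-16 for float64
```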
@@ -891,7 +894,8 @@ def raw_col_data_to_coords_vars(
 
     if col in timecols:
         if col == "MJD":
-            data = convert_mjd_time(data).astype("float64") / 1e9
+            # data = convert_mjd_time(data).astype("float64") / 1e9
+            data = convert_mjd_time(data).astype("datetime64[ns]").view("int64") / 1e9
         else:
             try:
                 data = convert_casacore_time(data, False)
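The replacement line casts through `datetime64[ns]` and then reinterprets the underlying 64-bit integers instead of casting datetimes directly to `float64`. A `datetime64[ns]` array stores nanoseconds since the Unix epoch, so `.view("int64") / 1e9` yields float seconds:

```python
import numpy as np

t = np.array(["2022-01-01T12:00:00"], dtype="datetime64[ns]")
secs = t.view("int64") / 1e9  # reinterpret ns since the epoch, then scale
print(secs)                   # [1.6410384e+09]
```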
xradio/measurement_set/_utils/_msv2/conversion.py

@@ -1017,7 +1017,7 @@ def convert_and_write_partition(
     add_reshaping_indices: bool = False,
     storage_backend="zarr",
     parallel_mode: str = "none",
-    overwrite: bool = False,
+    persistence_mode: str = "w-",
 ):
     """_summary_
 
@@ -1057,8 +1057,8 @@
         _description_, by default "zarr"
     parallel_mode : _type_, optional
         _description_
-    overwrite : bool, optional
-        _description_, by default False
+    persistence_mode: str = "w-",
+        _description_, by default "w-"
 
     Returns
     -------
@@ -1368,11 +1368,6 @@
         pathlib.Path(in_file).name.replace(".ms", "") + "_" + str(ms_v4_id),
     )
 
-    if overwrite:
-        mode = "w"
-    else:
-        mode = "w-"
-
     if is_single_dish:
         xds.attrs["type"] = "spectrum"
         xds = xds.drop_vars("UVW")
@@ -1416,7 +1411,9 @@
         ms_xdt["/phased_array_xds"] = phased_array_xds
 
     if storage_backend == "zarr":
-        ms_xdt.to_zarr(store=os.path.join(out_file, ms_v4_name), mode=mode)
+        ms_xdt.to_zarr(
+            store=os.path.join(out_file, ms_v4_name), mode=persistence_mode
+        )
     elif storage_backend == "netcdf":
         # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work
         raise
xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py

@@ -285,6 +285,7 @@ def extract_ephemeris_info(
     # Metadata has to be fixed manually. Alternatively, issues like
     # UNIT/QuantumUnits issue could be handled in convert_generic_xds_to_xradio_schema,
     # but for now preferring not to pollute that function.
+
     time_ephemeris_dim = ["time_ephemeris"]
     to_new_data_variables = {
         # mandatory: SOURCE_RADIAL_VELOCITY
xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py

@@ -27,7 +27,6 @@ from xradio.measurement_set._utils._msv2._tables.read import (
     table_has_column,
 )
 
-
 standard_time_coord_attrs = make_time_measure_attrs(time_format="unix")
 
 
@@ -81,6 +80,7 @@ def rename_and_interpolate_to_time(
     )
 
     # rename the time_* axis to time.
+
     time_coord = {"time": (time_initial_name, interp_time.data)}
     renamed_time_xds = interpolated_xds.assign_coords(time_coord)
     renamed_time_xds.coords["time"].attrs.update(standard_time_coord_attrs)
xradio/measurement_set/_utils/_utils/interpolate.py

@@ -41,6 +41,10 @@ def interpolate_to_time(
         method = "linear"
     else:
         method = "nearest"
+
+    # print("xds before interp:",xds.NORTH_POLE_ANGULAR_DISTANCE.values, xds[time_name].values)
+    # print("interp_time data:",interp_time,interp_time.data)
+    # print("method:",method)
     xds = xds.interp(
         {time_name: interp_time.data}, method=method, assume_sorted=True
     )
@@ -56,5 +60,6 @@
         f"{message_prefix}: interpolating the time coordinate "
         f"from {points_before} to {points_after} points"
     )
+    # print("xds after interp:",xds.NORTH_POLE_ANGULAR_DISTANCE.values, xds[time_name].values)
 
     return xds
xradio/measurement_set/_utils/_utils/partition_attrs.py

@@ -2,7 +2,6 @@ from typing import Dict, TypedDict, Union
 
 import xarray as xr
 
-
 PartitionIds = TypedDict(
     "PartitionIds",
     {
xradio/measurement_set/convert_msv2_to_processing_set.py

@@ -68,7 +68,7 @@ def convert_msv2_to_processing_set(
     add_reshaping_indices: bool = False,
     storage_backend: Literal["zarr", "netcdf"] = "zarr",
     parallel_mode: Literal["none", "partition", "time"] = "none",
-    overwrite: bool = False,
+    persistence_mode: str = "w-",
 ):
     """Convert a Measurement Set v2 into a Processing Set of Measurement Set v4.
 
@@ -110,8 +110,11 @@
         Choose whether to use Dask to execute conversion in parallel, by default "none" and conversion occurs serially.
         The option "partition", parallelises the conversion over partitions specified by `partition_scheme`. The option "time" can only be used for phased array interferometers where there are no partitions
         in the MS v2; instead the MS v2 is parallelised along the time dimension and can be controlled by `main_chunksize`.
-    overwrite : bool, optional
-        Whether to overwrite an existing processing set, by default False.
+    persistence_mode : str, optional
+        "w" means create (overwrite if exists);
+        "w-" means create (fail if exists);
+        "a" means override all existing variables including dimension coordinates (create if does not exist); use this mode if you want to add to an existing Processing Set.
+        The default is "w-".
     """
 
     # Create empty data tree
@@ -122,10 +125,7 @@
     if not str(out_file).endswith("ps.zarr"):
         out_file += ".ps.zarr"
 
-    if overwrite:
-        ps_dt.to_zarr(store=out_file, mode="w")
-    else:
-        ps_dt.to_zarr(store=out_file, mode="w-")
+    ps_dt.to_zarr(store=out_file, mode=persistence_mode)
 
     # Check `parallel_mode` is valid
     try:
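Callers now pick the zarr persistence mode directly instead of a boolean `overwrite`. A usage sketch; the `in_file`/`out_file` parameter names are assumed from their use elsewhere in this diff rather than from the full signature, which is not shown:

```python
from xradio.measurement_set import convert_msv2_to_processing_set

# "w-" (the default) fails if the output store already exists; "w" overwrites it;
# "a" adds to an existing Processing Set.
convert_msv2_to_processing_set(
    in_file="my_data.ms",
    out_file="my_data",  # ".ps.zarr" is appended if missing
    persistence_mode="w",
)
```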
@@ -192,7 +192,7 @@
                 add_reshaping_indices=add_reshaping_indices,
                 compressor=compressor,
                 parallel_mode=parallel_mode,
-                overwrite=overwrite,
+                persistence_mode=persistence_mode,
             )
         )
     else:
@@ -214,7 +214,7 @@
             add_reshaping_indices=add_reshaping_indices,
             compressor=compressor,
             parallel_mode=parallel_mode,
-            overwrite=overwrite,
+            persistence_mode=persistence_mode,
         )
     end_time = time.time()
     logger.debug(
xradio/measurement_set/load_processing_set.py

@@ -1,8 +1,6 @@
 import os
 from typing import Dict, Union
-import dask
 import xarray as xr
-import s3fs
 
 
 def load_processing_set(
@@ -48,6 +46,8 @@ def load_processing_set(
         In memory representation of processing set using xr.DataTree.
     """
     from xradio._utils.zarr.common import _get_file_system_and_items
+    import dask
+    import s3fs
 
     file_system, ms_store_list = _get_file_system_and_items(ps_store)
xradio/measurement_set/measurement_set_xdt.py

@@ -5,7 +5,8 @@ from typing import Any, Union
 import numpy as np
 import xarray as xr
 
-from xradio._utils.list_and_array import to_list
+from xradio._utils.list_and_array import to_python_type
+from xradio._utils.xarray_helpers import get_data_group_name, create_new_data_group
 
 MS_DATASET_TYPES = {"visibility", "spectrum", "radiometer"}
 
@@ -151,11 +152,7 @@ class MeasurementSetXdt:
         if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
             raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")
 
-        if data_group_name is None:
-            if "base" in self._xdt.attrs["data_groups"]:
-                data_group_name = "base"
-            else:
-                data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
+        data_group_name = get_data_group_name(self._xdt, data_group_name)
 
         field_and_source_xds_name = self._xdt.attrs["data_groups"][data_group_name][
             "field_and_source"
@@ -188,16 +185,12 @@
             f"{self._xdt.path} is not a MSv4 node (type {self._xdt.attrs.get('type')}."
         )
 
-        if data_group_name is None:
-            if "base" in self._xdt.attrs["data_groups"]:
-                data_group_name = "base"
-            else:
-                data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
+        data_group_name = get_data_group_name(self._xdt, data_group_name)
 
         field_and_source_xds = self._xdt.xr_ms.get_field_and_source_xds(data_group_name)
 
         if "line_name" in field_and_source_xds.coords:
-            line_name = to_list(
+            line_name = to_python_type(
                 np.unique(np.ravel(field_and_source_xds.line_name.values))
             )
         else:
@@ -218,10 +211,14 @@
         partition_info = {
             "spectral_window_name": self._xdt.frequency.attrs["spectral_window_name"],
             "spectral_window_intents": spw_intent,
-            "field_name": to_list(np.unique(field_and_source_xds.field_name.values)),
-            "polarization_setup": to_list(self._xdt.polarization.values),
-            "scan_name": to_list(np.unique(self._xdt.scan_name.values)),
-            "source_name": to_list(np.unique(field_and_source_xds.source_name.values)),
+            "field_name": to_python_type(
+                np.unique(field_and_source_xds.field_name.values)
+            ),
+            "polarization_setup": to_python_type(self._xdt.polarization.values),
+            "scan_name": to_python_type(np.unique(self._xdt.scan_name.values)),
+            "source_name": to_python_type(
+                np.unique(field_and_source_xds.source_name.values)
+            ),
             "scan_intents": scan_intents,
             "line_name": line_name,
             "data_group_name": data_group_name,
@@ -232,13 +229,7 @@
     def add_data_group(
         self,
         new_data_group_name: str,
-        correlated_data: str = None,
-        weight: str = None,
-        flag: str = None,
-        uvw: str = None,
-        field_and_source_xds: str = None,
-        date_time: str = None,
-        description: str = None,
+        new_data_group: dict = {},
         data_group_dv_shared_with: str = None,
     ) -> xr.DataTree:
         """Adds a data group to the MSv4 DataTree, grouping the given data, weight, flag, etc. variables
@@ -248,20 +239,8 @@
         ----------
         new_data_group_name : str
             _description_
-        correlated_data : str, optional
-            _description_, by default None
-        weights : str, optional
-            _description_, by default None
-        flag : str, optional
-            _description_, by default None
-        uvw : str, optional
-            _description_, by default None
-        field_and_source_xds : str, optional
-            _description_, by default None
-        date_time : str, optional
-            _description_, by default None
-        description : str, optional
-            _description_, by default None
+        new_data_group : dict
+            _description_
         data_group_dv_shared_with : str, optional
             _description_, by default "base"
 
@@ -271,63 +250,74 @@
         MSv4 DataTree with the new group added
         """
 
-        if data_group_dv_shared_with is None:
-            data_group_dv_shared_with = self._xdt.xr_ms._get_default_data_group_name()
-        default_data_group = self._xdt.attrs["data_groups"][data_group_dv_shared_with]
-
-        new_data_group = {}
-
-        if correlated_data is None:
-            correlated_data = default_data_group["correlated_data"]
-        new_data_group["correlated_data"] = correlated_data
-        assert (
-            correlated_data in self._xdt.ds.data_vars
-        ), f"Data variable {correlated_data} not found in dataset."
-
-        if weight is None:
-            weight = default_data_group["weight"]
-        new_data_group["weight"] = weight
-        assert (
-            weight in self._xdt.ds.data_vars
-        ), f"Data variable {weight} not found in dataset."
-
-        if flag is None:
-            flag = default_data_group["flag"]
-        new_data_group["flag"] = flag
-        assert (
-            flag in self._xdt.ds.data_vars
-        ), f"Data variable {flag} not found in dataset."
-
-        if self._xdt.attrs["type"] == "visibility":
-            if uvw is None:
-                uvw = default_data_group["uvw"]
-            new_data_group["uvw"] = uvw
-            assert (
-                uvw in self._xdt.ds.data_vars
-            ), f"Data variable {uvw} not found in dataset."
-
-        if field_and_source_xds is None:
-            field_and_source_xds = default_data_group["field_and_source"]
-        new_data_group["field_and_source"] = field_and_source_xds
-        assert (
-            field_and_source_xds in self._xdt.children
-        ), f"Data variable {field_and_source_xds} not found in dataset."
-
-        if date_time is None:
-            date_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
-        new_data_group["date"] = date_time
-
-        if description is None:
-            description = ""
-        new_data_group["description"] = description
+        if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
+            raise InvalidAccessorLocation(
+                f"{self._xdt.path} is not a MSv4 node (type {self._xdt.attrs.get('type')}."
+            )
+
+        new_data_group_name, new_data_group = create_new_data_group(
+            self._xdt,
+            "msv4",
+            new_data_group_name,
+            new_data_group,
+            data_group_dv_shared_with=data_group_dv_shared_with,
+        )
 
         self._xdt.attrs["data_groups"][new_data_group_name] = new_data_group
-
         return self._xdt
 
-    def _get_default_data_group_name(self):
-        if "base" in self._xdt.attrs["data_groups"]:
-            data_group_name = "base"
-        else:
-            data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
-        return data_group_name
+        # data_group_dv_shared_with = get_data_group_name(
+        #     self._xdt, data_group_dv_shared_with
+        # )
+
+        # default_data_group = self._xdt.attrs["data_groups"][data_group_dv_shared_with]
+
+        # new_data_group = {}
+
+        # if correlated_data is None:
+        #     correlated_data = default_data_group["correlated_data"]
+        # new_data_group["correlated_data"] = correlated_data
+        # assert (
+        #     correlated_data in self._xdt.ds.data_vars
+        # ), f"Data variable {correlated_data} not found in dataset."
+
+        # if weight is None:
+        #     weight = default_data_group["weight"]
+        # new_data_group["weight"] = weight
+        # assert (
+        #     weight in self._xdt.ds.data_vars
+        # ), f"Data variable {weight} not found in dataset."
+
+        # if flag is None:
+        #     flag = default_data_group["flag"]
+        # new_data_group["flag"] = flag
+        # assert (
+        #     flag in self._xdt.ds.data_vars
+        # ), f"Data variable {flag} not found in dataset."
+
+        # if self._xdt.attrs["type"] == "visibility":
+        #     if uvw is None:
+        #         uvw = default_data_group["uvw"]
+        #     new_data_group["uvw"] = uvw
+        #     assert (
+        #         uvw in self._xdt.ds.data_vars
+        #     ), f"Data variable {uvw} not found in dataset."
+
+        # if field_and_source_xds is None:
+        #     field_and_source_xds = default_data_group["field_and_source"]
+        # new_data_group["field_and_source"] = field_and_source_xds
+        # assert (
+        #     field_and_source_xds in self._xdt.children
+        # ), f"Data variable {field_and_source_xds} not found in dataset."
+
+        # if date_time is None:
+        #     date_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
+        # new_data_group["date"] = date_time
+
+        # if description is None:
+        #     description = ""
+        # new_data_group["description"] = description
+
+        # self._xdt.attrs["data_groups"][new_data_group_name] = new_data_group
+
+        # return self._xdt
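The seven per-field keyword arguments collapse into a single `new_data_group` dict that `create_new_data_group` validates and completes. A hedged usage sketch; the dict keys follow the fields visible in the commented-out code above, and unspecified entries are presumably filled from the group named by `data_group_dv_shared_with`, as the old keyword defaults were:

```python
ms_xdt.xr_ms.add_data_group(
    "corrected",
    new_data_group={
        "correlated_data": "VISIBILITY_CORRECTED",
        "description": "calibrated visibilities",
    },
    data_group_dv_shared_with="base",  # source of weight/flag/uvw/field_and_source
)
```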
xradio/measurement_set/open_processing_set.py

@@ -1,5 +1,4 @@
 from xradio._utils.zarr.common import _get_file_system_and_items
-import s3fs
 import xarray as xr
 
 
@@ -25,6 +24,7 @@ def open_processing_set(
     """
 
     file_system, ms_store_list = _get_file_system_and_items(ps_store)
+    import s3fs
 
     if isinstance(file_system, s3fs.core.S3FileSystem):
         mapping = s3fs.S3Map(root=ps_store, s3=file_system, check=False)
xradio/measurement_set/processing_set_xdt.py

@@ -2,6 +2,7 @@ import pandas as pd
 from xradio._utils.list_and_array import to_list
 import numpy as np
 import xarray as xr
+from xradio.measurement_set.measurement_set_xdt import get_data_group_name
 
 PS_DATASET_TYPES = {"processing_set"}
 
@@ -38,7 +39,7 @@ class ProcessingSetXdt:
         self.meta = {"summary": {}}
 
     def summary(
-        self, data_group: str | None = None, first_columns: list[str] = None
+        self, data_group_name: str | None = None, first_columns: list[str] = None
     ) -> pd.DataFrame:
         """
         Generate and retrieve a summary of the Processing Set as a data frame.
@@ -53,7 +54,7 @@
 
         Parameters
         ----------
-        data_group : str, optional
+        data_group_name : str, optional
             The data group to summarize. By default the "base" group
             is used (if found), or otherwise the first group found.
         first_columns : list[str], optional
@@ -73,31 +74,33 @@
             A DataFrame containing the summary information of the specified data group.
         """
 
-        def find_data_group_base_or_first(data_group: str, xdt: xr.DataTree) -> str:
+        def find_data_group_base_or_first(
+            data_group_name: str, xdt: xr.DataTree
+        ) -> str:
             first_msv4 = next(iter(xdt.values()))
             first_data_groups = first_msv4.attrs["data_groups"]
-            if data_group is None:
-                data_group = (
+            if data_group_name is None:
+                data_group_name = (
                     "base"
                     if "base" in first_data_groups
                     else next(iter(first_data_groups))
                 )
-            return data_group
+            return data_group_name
 
         if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
             raise InvalidAccessorLocation(
                 f"{self._xdt.path} is not a processing set node."
             )
 
-        data_group = find_data_group_base_or_first(data_group, self._xdt)
+        data_group_name = find_data_group_base_or_first(data_group_name, self._xdt)
 
-        if data_group in self.meta["summary"]:
-            summary = self.meta["summary"][data_group]
+        if data_group_name in self.meta["summary"]:
+            summary = self.meta["summary"][data_group_name]
         else:
-            self.meta["summary"][data_group] = self._summary(data_group).sort_values(
-                by=["name"], ascending=True
-            )
-            summary = self.meta["summary"][data_group]
+            self.meta["summary"][data_group_name] = self._summary(
+                data_group_name
+            ).sort_values(by=["name"], ascending=True)
+            summary = self.meta["summary"][data_group_name]
 
         if first_columns:
             found_columns = [col for col in first_columns if col in summary.columns]
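Throughout `ProcessingSetXdt`, the `data_group` parameter is renamed to `data_group_name`, so keyword callers need updating. The accessor name `xr_ps` below is an assumption by analogy with the `xr_ms` accessor used earlier in this diff:

```python
df = ps_xdt.xr_ps.summary(data_group_name="base")  # 1.0.2: data_group="base"
```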
@@ -187,7 +190,7 @@ class ProcessingSetXdt:
             self.meta["freq_axis"] = freq_axis
         return self.meta["freq_axis"]
 
-    def _summary(self, data_group: str = None):
+    def _summary(self, data_group_name: str = None):
         summary_data = {
             "name": [],
             "scan_intents": [],
@@ -225,7 +228,7 @@ class ProcessingSetXdt:
             )
             summary_data["polarization"].append(value.polarization.values)
             summary_data["scan_name"].append(partition_info["scan_name"])
-            data_name = value.attrs["data_groups"][data_group]["correlated_data"]
+            data_name = value.attrs["data_groups"][data_group_name]["correlated_data"]
 
             if "VISIBILITY" in data_name:
                 center_name = "FIELD_PHASE_CENTER_DIRECTION"
@@ -253,7 +256,7 @@ class ProcessingSetXdt:
             )
             summary_data["end_frequency"].append(to_list(value["frequency"].values)[-1])
 
-            field_and_source_xds = value.xr_ms.get_field_and_source_xds(data_group)
+            field_and_source_xds = value.xr_ms.get_field_and_source_xds(data_group_name)
 
             if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
                 summary_data["field_coords"].append("Ephemeris")
@@ -377,14 +380,16 @@
 
         return sub_ps_xdt
 
-    def get_combined_field_and_source_xds(self, data_group: str = "base") -> xr.Dataset:
+    def get_combined_field_and_source_xds(
+        self, data_group_name: str = "base"
+    ) -> xr.Dataset:
         """
         Combine all non-ephemeris `field_and_source_xds` datasets from a Processing Set for a data group into a
         single dataset.
 
         Parameters
         ----------
-        data_group : str, optional
+        data_group_name : str, optional
             The data group to process. Default is "base".
 
         Returns
@@ -405,7 +410,9 @@
 
         combined_field_and_source_xds = xr.Dataset()
         for ms_name, ms_xdt in self._xdt.items():
-            field_and_source_xds = ms_xdt.xr_ms.get_field_and_source_xds(data_group)
+            field_and_source_xds = ms_xdt.xr_ms.get_field_and_source_xds(
+                data_group_name
+            )
 
             if not field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
 
@@ -474,14 +481,14 @@
         return combined_field_and_source_xds
 
     def get_combined_field_and_source_xds_ephemeris(
-        self, data_group: str = "base"
+        self, data_group_name: str = "base"
     ) -> xr.Dataset:
         """
         Combine all ephemeris `field_and_source_xds` datasets from a Processing Set for a datagroup into a single dataset.
 
         Parameters
        ----------
-        data_group : str, optional
+        data_group_name : str, optional
             The data group to process. Default is "base".
 
         Returns
@@ -503,7 +510,7 @@
         combined_ephemeris_field_and_source_xds = xr.Dataset()
         for ms_name, ms_xdt in self._xdt.items():
             field_and_source_xds = field_and_source_xds = (
-                ms_xdt.xr_ms.get_field_and_source_xds(data_group)
+                ms_xdt.xr_ms.get_field_and_source_xds(data_group_name)
             )
 
             if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
@@ -599,7 +606,7 @@
             return combined_ephemeris_field_and_source_xds
 
     def plot_phase_centers(
-        self, label_all_fields: bool = False, data_group: str = "base"
+        self, label_all_fields: bool = False, data_group_name: str = "base"
    ):
         """
         Plot the phase center locations of all fields in the Processing Set.
@@ -612,7 +619,7 @@
         ----------
         label_all_fields : bool, optional
             If `True`, all fields will be labeled on the plot. Default is `False`.
-        data_group : str, optional
+        data_group_name : str, optional
             The data group to use for processing. Default is "base".
 
         Returns
@@ -641,10 +648,10 @@
         )
 
         combined_field_and_source_xds = self.get_combined_field_and_source_xds(
-            data_group
+            data_group_name
         )
         combined_ephemeris_field_and_source_xds = (
-            self.get_combined_field_and_source_xds_ephemeris(data_group)
+            self.get_combined_field_and_source_xds_ephemeris(data_group_name)
         )
         from matplotlib import pyplot as plt