xradio 0.0.47__py3-none-any.whl → 0.0.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. xradio/__init__.py +1 -0
  2. xradio/_utils/dict_helpers.py +69 -2
  3. xradio/_utils/list_and_array.py +3 -1
  4. xradio/_utils/schema.py +3 -1
  5. xradio/image/_util/__init__.py +0 -3
  6. xradio/image/_util/_casacore/common.py +0 -13
  7. xradio/image/_util/_casacore/xds_from_casacore.py +102 -97
  8. xradio/image/_util/_casacore/xds_to_casacore.py +36 -24
  9. xradio/image/_util/_fits/xds_from_fits.py +81 -36
  10. xradio/image/_util/_zarr/zarr_low_level.py +3 -3
  11. xradio/image/_util/casacore.py +7 -5
  12. xradio/image/_util/common.py +13 -26
  13. xradio/image/_util/image_factory.py +143 -191
  14. xradio/image/image.py +10 -59
  15. xradio/measurement_set/__init__.py +11 -6
  16. xradio/measurement_set/_utils/_msv2/_tables/read.py +187 -46
  17. xradio/measurement_set/_utils/_msv2/_tables/table_query.py +22 -0
  18. xradio/measurement_set/_utils/_msv2/conversion.py +347 -299
  19. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +233 -150
  20. xradio/measurement_set/_utils/_msv2/descr.py +1 -1
  21. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +20 -13
  22. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +21 -22
  23. xradio/measurement_set/convert_msv2_to_processing_set.py +46 -6
  24. xradio/measurement_set/load_processing_set.py +100 -52
  25. xradio/measurement_set/measurement_set_xdt.py +197 -0
  26. xradio/measurement_set/open_processing_set.py +122 -86
  27. xradio/measurement_set/processing_set_xdt.py +1552 -0
  28. xradio/measurement_set/schema.py +375 -197
  29. xradio/schema/bases.py +5 -1
  30. xradio/schema/check.py +97 -5
  31. xradio/sphinx/schema_table.py +12 -0
  32. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/METADATA +4 -4
  33. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/RECORD +36 -36
  34. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/WHEEL +1 -1
  35. xradio/measurement_set/measurement_set_xds.py +0 -117
  36. xradio/measurement_set/processing_set.py +0 -777
  37. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info/licenses}/LICENSE.txt +0 -0
  38. {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1552 @@
1
+ import pandas as pd
2
+ from xradio._utils.list_and_array import to_list
3
+ import numbers
4
+ import numpy as np
5
+ import toolviper.utils.logger as logger
6
+ import xarray as xr
7
+
8
# Values of a DataTree node's "type" attribute that the ProcessingSetXdt
# accessor methods accept as identifying a processing set node.
PS_DATASET_TYPES = {"processing_set"}


class InvalidAccessorLocation(ValueError):
    """
    Raised by Processing Set accessor functions called on a wrong DataTree node
    (i.e. a node that is not a processing set).

    Subclasses ``ValueError`` so that callers catching ``ValueError`` also
    catch this exception.
    """
17
+
18
+
19
+ @xr.register_datatree_accessor("xr_ps")
20
+ class ProcessingSetXdt:
21
+ """
22
+ Accessor to Processing Set DataTree nodes. Provides Processing Set specific functionality such
23
+ as producing a summary of the processing set (with information from all its MSv4s), or retrieving
24
+ combined antenna or field_and_source datasets.
25
+ """
26
+
27
+ _xdt: xr.DataTree
28
+
29
+ def __init__(self, datatree: xr.DataTree):
30
+ """
31
+ Initialize the ProcessingSetXdt instance.
32
+
33
+ Parameters
34
+ ----------
35
+ datatree: xarray.DataTree
36
+ The Processing Set DataTree node to construct a ProcessingSetXdt accessor.
37
+ """
38
+
39
+ self._xdt = datatree
40
+ self.meta = {"summary": {}}
41
+
42
+ def summary(self, data_group: str = "base") -> pd.DataFrame:
43
+ """
44
+ Generate and retrieve a summary of the Processing Set.
45
+
46
+ The summary includes information such as the names of the Measurement Sets,
47
+ their intents, polarizations, spectral window names, field names, source names,
48
+ field coordinates, start frequencies, and end frequencies.
49
+
50
+ Parameters
51
+ ----------
52
+ data_group : str, optional
53
+ The data group to summarize. Default is "base".
54
+
55
+ Returns
56
+ -------
57
+ pandas.DataFrame
58
+ A DataFrame containing the summary information of the specified data group.
59
+ """
60
+
61
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
62
+ raise InvalidAccessorLocation(
63
+ f"{self._xdt.path} is not a processing set node."
64
+ )
65
+
66
+ if data_group in self.meta["summary"]:
67
+ return self.meta["summary"][data_group]
68
+ else:
69
+ self.meta["summary"][data_group] = self._summary(data_group).sort_values(
70
+ by=["name"], ascending=True
71
+ )
72
+ return self.meta["summary"][data_group]
73
+
74
+ def get_max_dims(self) -> dict[str, int]:
75
+ """
76
+ Determine the maximum dimensions across all Measurement Sets in the Processing Set.
77
+
78
+ This method examines each Measurement Set's dimensions and computes the maximum
79
+ size for each dimension across the entire Processing Set.
80
+
81
+ For example, if the Processing Set contains two MSs with dimensions (50, 20, 30) and (10, 30, 40),
82
+ the maximum dimensions will be (50, 30, 40).
83
+
84
+ Returns
85
+ -------
86
+ dict
87
+ A dictionary containing the maximum dimensions of the Processing Set, with dimension names as keys
88
+ and their maximum sizes as values.
89
+ """
90
+
91
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
92
+ raise InvalidAccessorLocation(
93
+ f"{self._xdt.path} is not a processing set node."
94
+ )
95
+
96
+ if "max_dims" in self.meta:
97
+ return self.meta["max_dims"]
98
+ else:
99
+ max_dims = None
100
+ for ms_xdt in self._xdt.values():
101
+ if max_dims is None:
102
+ max_dims = dict(ms_xdt.sizes)
103
+ else:
104
+ for dim_name, size in ms_xdt.sizes.items():
105
+ if dim_name in max_dims:
106
+ if max_dims[dim_name] < size:
107
+ max_dims[dim_name] = size
108
+ else:
109
+ max_dims[dim_name] = size
110
+ self.meta["max_dims"] = max_dims
111
+ return self.meta["max_dims"]
112
+
113
+ def get_freq_axis(self) -> xr.DataArray:
114
+ """
115
+ Combine the frequency axes of all Measurement Sets in the Processing Set.
116
+
117
+ This method aggregates the frequency information from each Measurement Set to create
118
+ a unified frequency axis for the entire Processing Set.
119
+
120
+ Returns
121
+ -------
122
+ xarray.DataArray
123
+ The combined frequency axis of the Processing Set.
124
+ """
125
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
126
+ raise InvalidAccessorLocation(
127
+ f"{self._xdt.path} is not a processing set node."
128
+ )
129
+
130
+ if "freq_axis" in self.meta:
131
+ return self.meta["freq_axis"]
132
+ else:
133
+ spw_ids = []
134
+ freq_axis_list = []
135
+ frame = self._xdt[next(iter(self._xdt.children))].frequency.attrs[
136
+ "observer"
137
+ ]
138
+ for ms_xdt in self._xdt.values():
139
+ assert (
140
+ frame == ms_xdt.frequency.attrs["observer"]
141
+ ), "Frequency reference frame not consistent in Processing Set."
142
+ if ms_xdt.frequency.attrs["spectral_window_id"] not in spw_ids:
143
+ spw_ids.append(ms_xdt.frequency.attrs["spectral_window_id"])
144
+ freq_axis_list.append(ms_xdt.frequency)
145
+
146
+ freq_axis = xr.concat(freq_axis_list, dim="frequency").sortby("frequency")
147
+ self.meta["freq_axis"] = freq_axis
148
+ return self.meta["freq_axis"]
149
+
150
+ def _summary(self, data_group: str = "base"):
151
+ summary_data = {
152
+ "name": [],
153
+ "intents": [],
154
+ "shape": [],
155
+ "polarization": [],
156
+ "scan_name": [],
157
+ "spw_name": [],
158
+ "field_name": [],
159
+ "source_name": [],
160
+ "line_name": [],
161
+ "field_coords": [],
162
+ "start_frequency": [],
163
+ "end_frequency": [],
164
+ }
165
+ from astropy.coordinates import SkyCoord
166
+ import astropy.units as u
167
+
168
+ for key, value in self._xdt.items():
169
+ partition_info = value.xr_ms.get_partition_info()
170
+
171
+ summary_data["name"].append(key)
172
+ summary_data["intents"].append(partition_info["intents"])
173
+ summary_data["spw_name"].append(partition_info["spectral_window_name"])
174
+ summary_data["polarization"].append(value.polarization.values)
175
+ summary_data["scan_name"].append(partition_info["scan_name"])
176
+ data_name = value.attrs["data_groups"][data_group]["correlated_data"]
177
+
178
+ if "VISIBILITY" in data_name:
179
+ center_name = "FIELD_PHASE_CENTER"
180
+
181
+ if "SPECTRUM" in data_name:
182
+ center_name = "FIELD_REFERENCE_CENTER"
183
+
184
+ summary_data["shape"].append(value[data_name].shape)
185
+
186
+ summary_data["field_name"].append(partition_info["field_name"])
187
+ summary_data["source_name"].append(partition_info["source_name"])
188
+
189
+ summary_data["line_name"].append(partition_info["line_name"])
190
+
191
+ summary_data["start_frequency"].append(
192
+ to_list(value["frequency"].values)[0]
193
+ )
194
+ summary_data["end_frequency"].append(to_list(value["frequency"].values)[-1])
195
+
196
+ field_and_source_xds = value["field_and_source_xds_" + data_group]
197
+
198
+ if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
199
+ summary_data["field_coords"].append("Ephemeris")
200
+ elif field_and_source_xds[center_name]["field_name"].size > 1:
201
+ summary_data["field_coords"].append("Multi-Phase-Center")
202
+ else:
203
+ ra_dec_rad = field_and_source_xds[center_name].values[0, :]
204
+ frame = field_and_source_xds[center_name].attrs["frame"].lower()
205
+
206
+ coord = SkyCoord(
207
+ ra=ra_dec_rad[0] * u.rad, dec=ra_dec_rad[1] * u.rad, frame=frame
208
+ )
209
+
210
+ summary_data["field_coords"].append(
211
+ [
212
+ frame,
213
+ coord.ra.to_string(unit=u.hour, precision=2),
214
+ coord.dec.to_string(unit=u.deg, precision=2),
215
+ ]
216
+ )
217
+
218
+ summary_df = pd.DataFrame(summary_data)
219
+ return summary_df
220
+
221
+ def query(
222
+ self, string_exact_match: bool = True, query: str = None, **kwargs
223
+ ) -> xr.DataTree:
224
+ """
225
+ Select a subset of the Processing Set based on specified criteria.
226
+
227
+ This method allows filtering the Processing Set by matching column names and values
228
+ or by applying a Pandas query string. The selection criteria can target various
229
+ attributes of the Measurement Sets such as intents, polarization, spectral window names, etc.
230
+
231
+ A data group can be selected by name by using the `data_group_name` parameter. This is applied to each Measurement Set in the Processing Set.
232
+
233
+ Note
234
+ ----
235
+ This selection does not modify the actual data within the Measurement Sets. For example, if
236
+ a Measurement Set has `field_name=['field_0','field_10','field_08']` and `ps.query(field_name='field_0')`
237
+ is invoked, the resulting subset will still contain the original list `['field_0','field_10','field_08']`.
238
+ The exception is data group selection, using `data_group_name`, that will select data variables only associated with the specified data group in the Measurement Set.
239
+
240
+ Parameters
241
+ ----------
242
+ string_exact_match : bool, optional
243
+ If `True`, string matching will require exact matches for string and string list columns.
244
+ If `False`, partial matches are allowed. Default is `True`.
245
+ query : str, optional
246
+ A Pandas query string to apply additional filtering. Default is `None`.
247
+ **kwargs : dict
248
+ Keyword arguments representing column names and their corresponding values to filter the Processing Set.
249
+
250
+ Returns
251
+ -------
252
+ xr.DataTree
253
+ A new Processing Set DataTree instance containing only the Measurement Sets that match the selection criteria.
254
+
255
+ Examples
256
+ --------
257
+ >>> # Select all MSs with intents 'OBSERVE_TARGET#ON_SOURCE' and polarization 'RR' or 'LL'
258
+ >>> selected_ps = ps.query(intents='OBSERVE_TARGET#ON_SOURCE', polarization=['RR', 'LL'])
259
+
260
+ >>> # Select all MSs with start_frequency greater than 100 GHz and less than 200 GHz
261
+ >>> selected_ps = ps.query(query='start_frequency > 100e9 AND end_frequency < 200e9')
262
+ """
263
+
264
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
265
+ raise InvalidAccessorLocation(
266
+ f"{self._xdt.path} is not a processing set node."
267
+ )
268
+
269
+ def select_rows(df, col, sel_vals, string_exact_match):
270
+ def check_selection(row_val):
271
+ row_val = to_list(
272
+ row_val
273
+ ) # make sure that it is a list so that we can iterate over it.
274
+
275
+ for rw in row_val:
276
+ for s in sel_vals:
277
+ if string_exact_match:
278
+ if rw == s:
279
+ return True
280
+ else:
281
+ if s in rw:
282
+ return True
283
+ return False
284
+
285
+ return df[df[col].apply(check_selection)]
286
+
287
+ summary_table = self.summary()
288
+ data_group_name = None
289
+ for key, value in kwargs.items():
290
+
291
+ if "data_group_name" == key:
292
+ data_group_name = value
293
+ else:
294
+ value = to_list(value) # make sure value is a list.
295
+
296
+ if len(value) == 1 and isinstance(value[0], slice):
297
+ summary_table = summary_table[
298
+ summary_table[key].between(value[0].start, value[0].stop)
299
+ ]
300
+ else:
301
+ summary_table = select_rows(
302
+ summary_table, key, value, string_exact_match
303
+ )
304
+
305
+ if query is not None:
306
+ summary_table = summary_table.query(query)
307
+
308
+ sub_ps_xdt = xr.DataTree()
309
+ for key, val in self._xdt.items():
310
+ if key in summary_table["name"].values:
311
+ if data_group_name is not None:
312
+ sub_ps_xdt[key] = val.xr_ms.sel(data_group_name=data_group_name)
313
+ else:
314
+ sub_ps_xdt[key] = val
315
+
316
+ sub_ps_xdt.attrs = self._xdt.attrs
317
+
318
+ return sub_ps_xdt
319
+
320
+ def get_combined_field_and_source_xds(self, data_group: str = "base") -> xr.Dataset:
321
+ """
322
+ Combine all non-ephemeris `field_and_source_xds` datasets from a Processing Set for a data group into a
323
+ single dataset.
324
+
325
+ Parameters
326
+ ----------
327
+ data_group : str, optional
328
+ The data group to process. Default is "base".
329
+
330
+ Returns
331
+ -------
332
+ xarray.Dataset
333
+ combined_field_and_source_xds: Combined dataset for standard (non-ephemeris) fields.
334
+
335
+ Raises
336
+ ------
337
+ ValueError
338
+ If the `field_and_source_xds` attribute is missing or improperly formatted in any Measurement Set.
339
+ """
340
+
341
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
342
+ raise InvalidAccessorLocation(
343
+ f"{self._xdt.path} is not a processing set node."
344
+ )
345
+
346
+ combined_field_and_source_xds = xr.Dataset()
347
+ for ms_name, ms_xdt in self._xdt.items():
348
+ correlated_data_name = ms_xdt.attrs["data_groups"][data_group][
349
+ "correlated_data"
350
+ ]
351
+
352
+ field_and_source_xds = ms_xdt["field_and_source_xds_" + data_group].ds
353
+
354
+ if not field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
355
+
356
+ if (
357
+ "line_name" in field_and_source_xds.coords
358
+ ): # Not including line info since it is a function of spw.
359
+ field_and_source_xds = field_and_source_xds.drop_vars(
360
+ ["LINE_REST_FREQUENCY", "LINE_SYSTEMIC_VELOCITY"],
361
+ errors="ignore",
362
+ )
363
+ del field_and_source_xds["line_name"]
364
+ del field_and_source_xds["line_label"]
365
+
366
+ if len(combined_field_and_source_xds.data_vars) == 0:
367
+ combined_field_and_source_xds = field_and_source_xds
368
+ else:
369
+ combined_field_and_source_xds = xr.concat(
370
+ [combined_field_and_source_xds, field_and_source_xds],
371
+ dim="field_name",
372
+ )
373
+
374
+ if (len(combined_field_and_source_xds.data_vars) > 0) and (
375
+ "FIELD_PHASE_CENTER" in combined_field_and_source_xds
376
+ ):
377
+ combined_field_and_source_xds = (
378
+ combined_field_and_source_xds.drop_duplicates("field_name")
379
+ )
380
+
381
+ combined_field_and_source_xds["MEAN_PHASE_CENTER"] = (
382
+ combined_field_and_source_xds["FIELD_PHASE_CENTER"].mean(
383
+ dim=["field_name"]
384
+ )
385
+ )
386
+
387
+ ra1 = (
388
+ combined_field_and_source_xds["FIELD_PHASE_CENTER"]
389
+ .sel(sky_dir_label="ra")
390
+ .values
391
+ )
392
+ dec1 = (
393
+ combined_field_and_source_xds["FIELD_PHASE_CENTER"]
394
+ .sel(sky_dir_label="dec")
395
+ .values
396
+ )
397
+ ra2 = (
398
+ combined_field_and_source_xds["MEAN_PHASE_CENTER"]
399
+ .sel(sky_dir_label="ra")
400
+ .values
401
+ )
402
+ dec2 = (
403
+ combined_field_and_source_xds["MEAN_PHASE_CENTER"]
404
+ .sel(sky_dir_label="dec")
405
+ .values
406
+ )
407
+
408
+ from xradio._utils.coord_math import haversine
409
+
410
+ distance = haversine(ra1, dec1, ra2, dec2)
411
+ min_index = distance.argmin()
412
+
413
+ combined_field_and_source_xds.attrs["center_field_name"] = (
414
+ combined_field_and_source_xds.field_name[min_index].values
415
+ )
416
+
417
+ return combined_field_and_source_xds
418
+
419
+ def get_combined_field_and_source_xds_ephemeris(
420
+ self, data_group: str = "base"
421
+ ) -> xr.Dataset:
422
+ """
423
+ Combine all ephemeris `field_and_source_xds` datasets from a Processing Set for a datagroup into a single dataset.
424
+
425
+ Parameters
426
+ ----------
427
+ data_group : str, optional
428
+ The data group to process. Default is "base".
429
+
430
+ Returns
431
+ -------
432
+ xarray.Dataset
433
+ combined_ephemeris_field_and_source_xds: Combined dataset for ephemeris fields.
434
+
435
+ Raises
436
+ ------
437
+ ValueError
438
+ If the `field_and_source_xds` attribute is missing or improperly formatted in any Measurement Set.
439
+ """
440
+
441
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
442
+ raise InvalidAccessorLocation(
443
+ f"{self._xdt.path} is not a processing set node."
444
+ )
445
+
446
+ combined_ephemeris_field_and_source_xds = xr.Dataset()
447
+ for ms_name, ms_xdt in self._xdt.items():
448
+
449
+ correlated_data_name = ms_xdt.attrs["data_groups"][data_group][
450
+ "correlated_data"
451
+ ]
452
+
453
+ field_and_source_xds = field_and_source_xds = ms_xdt[
454
+ "field_and_source_xds_" + data_group
455
+ ].ds
456
+
457
+ if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
458
+
459
+ if (
460
+ "line_name" in field_and_source_xds.coords
461
+ ): # Not including line info since it is a function of spw.
462
+ field_and_source_xds = field_and_source_xds.drop_vars(
463
+ ["LINE_REST_FREQUENCY", "LINE_SYSTEMIC_VELOCITY"],
464
+ errors="ignore",
465
+ )
466
+ del field_and_source_xds["line_name"]
467
+ del field_and_source_xds["line_label"]
468
+
469
+ from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
470
+ interpolate_to_time,
471
+ )
472
+
473
+ if "time_ephemeris" in field_and_source_xds:
474
+ field_and_source_xds = interpolate_to_time(
475
+ field_and_source_xds,
476
+ field_and_source_xds.time,
477
+ "field_and_source_xds",
478
+ "time_ephemeris",
479
+ )
480
+ del field_and_source_xds["time_ephemeris"]
481
+ field_and_source_xds = field_and_source_xds.rename(
482
+ {"time_ephemeris": "time"}
483
+ )
484
+
485
+ if "OBSERVER_POSITION" in field_and_source_xds:
486
+ field_and_source_xds = field_and_source_xds.drop_vars(
487
+ ["OBSERVER_POSITION"], errors="ignore"
488
+ )
489
+
490
+ if len(combined_ephemeris_field_and_source_xds.data_vars) == 0:
491
+ combined_ephemeris_field_and_source_xds = field_and_source_xds
492
+ else:
493
+
494
+ combined_ephemeris_field_and_source_xds = xr.concat(
495
+ [combined_ephemeris_field_and_source_xds, field_and_source_xds],
496
+ dim="time",
497
+ )
498
+
499
+ if (len(combined_ephemeris_field_and_source_xds.data_vars) > 0) and (
500
+ "FIELD_PHASE_CENTER" in combined_ephemeris_field_and_source_xds
501
+ ):
502
+
503
+ from xradio._utils.coord_math import wrap_to_pi
504
+
505
+ offset = (
506
+ combined_ephemeris_field_and_source_xds["FIELD_PHASE_CENTER"]
507
+ - combined_ephemeris_field_and_source_xds["SOURCE_LOCATION"]
508
+ )
509
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"] = xr.DataArray(
510
+ wrap_to_pi(offset.sel(sky_pos_label=["ra", "dec"])).values,
511
+ dims=["time", "sky_dir_label"],
512
+ )
513
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].attrs = (
514
+ combined_ephemeris_field_and_source_xds["FIELD_PHASE_CENTER"].attrs
515
+ )
516
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].attrs["units"] = (
517
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].attrs["units"][
518
+ :2
519
+ ]
520
+ )
521
+
522
+ ra1 = (
523
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"]
524
+ .sel(sky_dir_label="ra")
525
+ .values
526
+ )
527
+ dec1 = (
528
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"]
529
+ .sel(sky_dir_label="dec")
530
+ .values
531
+ )
532
+ ra2 = 0.0
533
+ dec2 = 0.0
534
+
535
+ from xradio._utils.coord_math import haversine
536
+
537
+ distance = haversine(ra1, dec1, ra2, dec2)
538
+ min_index = distance.argmin()
539
+
540
+ combined_ephemeris_field_and_source_xds.attrs["center_field_name"] = (
541
+ combined_ephemeris_field_and_source_xds.field_name[min_index].values
542
+ )
543
+
544
+ return combined_ephemeris_field_and_source_xds
545
+
546
+ def plot_phase_centers(
547
+ self, label_all_fields: bool = False, data_group: str = "base"
548
+ ):
549
+ """
550
+ Plot the phase center locations of all fields in the Processing Set.
551
+
552
+ This method is primarily used for visualizing mosaics. It generates scatter plots of
553
+ the phase center coordinates for both standard and ephemeris fields. The central field
554
+ is highlighted in red based on the closest phase center calculation.
555
+
556
+ Parameters
557
+ ----------
558
+ label_all_fields : bool, optional
559
+ If `True`, all fields will be labeled on the plot. Default is `False`.
560
+ data_group : str, optional
561
+ The data group to use for processing. Default is "base".
562
+
563
+ Returns
564
+ -------
565
+ None
566
+
567
+ Raises
568
+ ------
569
+ ValueError
570
+ If the combined datasets are empty or improperly formatted.
571
+ """
572
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
573
+ raise InvalidAccessorLocation(
574
+ f"{self._xdt.path} is not a processing set node."
575
+ )
576
+
577
+ combined_field_and_source_xds = self.get_combined_field_and_source_xds(
578
+ data_group
579
+ )
580
+ combined_ephemeris_field_and_source_xds = (
581
+ self.get_combined_field_and_source_xds_ephemeris(data_group)
582
+ )
583
+ from matplotlib import pyplot as plt
584
+
585
+ if (len(combined_field_and_source_xds.data_vars) > 0) and (
586
+ "FIELD_PHASE_CENTER" in combined_field_and_source_xds
587
+ ):
588
+ plt.figure()
589
+ plt.title("Field Phase Center Locations")
590
+ plt.scatter(
591
+ combined_field_and_source_xds["FIELD_PHASE_CENTER"].sel(
592
+ sky_dir_label="ra"
593
+ ),
594
+ combined_field_and_source_xds["FIELD_PHASE_CENTER"].sel(
595
+ sky_dir_label="dec"
596
+ ),
597
+ )
598
+
599
+ center_field_name = combined_field_and_source_xds.attrs["center_field_name"]
600
+ center_field = combined_field_and_source_xds.sel(
601
+ field_name=center_field_name
602
+ )
603
+ plt.scatter(
604
+ center_field["FIELD_PHASE_CENTER"].sel(sky_dir_label="ra"),
605
+ center_field["FIELD_PHASE_CENTER"].sel(sky_dir_label="dec"),
606
+ color="red",
607
+ label=center_field_name,
608
+ )
609
+ plt.xlabel("RA (rad)")
610
+ plt.ylabel("DEC (rad)")
611
+ plt.legend()
612
+ plt.show()
613
+
614
+ if (len(combined_ephemeris_field_and_source_xds.data_vars) > 0) and (
615
+ "FIELD_PHASE_CENTER" in combined_ephemeris_field_and_source_xds
616
+ ):
617
+
618
+ plt.figure()
619
+ plt.title(
620
+ "Offset of Field Phase Center from Source Location (Ephemeris Data)"
621
+ )
622
+ plt.scatter(
623
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].sel(
624
+ sky_dir_label="ra"
625
+ ),
626
+ combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].sel(
627
+ sky_dir_label="dec"
628
+ ),
629
+ )
630
+
631
+ center_field_name = combined_ephemeris_field_and_source_xds.attrs[
632
+ "center_field_name"
633
+ ]
634
+
635
+ combined_ephemeris_field_and_source_xds = (
636
+ combined_ephemeris_field_and_source_xds.set_xindex("field_name")
637
+ )
638
+
639
+ center_field = combined_ephemeris_field_and_source_xds.sel(
640
+ field_name=center_field_name
641
+ )
642
+ plt.scatter(
643
+ center_field["FIELD_OFFSET"].sel(sky_dir_label="ra"),
644
+ center_field["FIELD_OFFSET"].sel(sky_dir_label="dec"),
645
+ color="red",
646
+ label=center_field_name,
647
+ )
648
+ plt.xlabel("RA Offset (rad)")
649
+ plt.ylabel("DEC Offset (rad)")
650
+ plt.legend()
651
+ plt.show()
652
+
653
+ def get_combined_antenna_xds(self) -> xr.Dataset:
654
+ """
655
+ Combine the `antenna_xds` datasets from all Measurement Sets into a single dataset.
656
+
657
+ This method concatenates the antenna datasets from each Measurement Set along the 'antenna_name' dimension.
658
+
659
+ Returns
660
+ -------
661
+ xarray.Dataset
662
+ A combined `xarray.Dataset` containing antenna information from all Measurement Sets.
663
+
664
+ Raises
665
+ ------
666
+ ValueError
667
+ If antenna datasets are missing required variables or improperly formatted.
668
+ """
669
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
670
+ raise InvalidAccessorLocation(
671
+ f"{self._xdt.path} is not a processing set node."
672
+ )
673
+
674
+ combined_antenna_xds = xr.Dataset()
675
+ for cor_name, ms_xdt in self._xdt.items():
676
+ antenna_xds = ms_xdt.antenna_xds.ds
677
+
678
+ if len(combined_antenna_xds.data_vars) == 0:
679
+ combined_antenna_xds = antenna_xds
680
+ else:
681
+ combined_antenna_xds = xr.concat(
682
+ [combined_antenna_xds, antenna_xds],
683
+ dim="antenna_name",
684
+ data_vars="minimal",
685
+ coords="minimal",
686
+ )
687
+
688
+ # ALMA WVR antenna_xds data has a NaN value for the antenna receptor angle.
689
+ if "ANTENNA_RECEPTOR_ANGLE" in combined_antenna_xds.data_vars:
690
+ combined_antenna_xds = combined_antenna_xds.dropna("antenna_name")
691
+
692
+ combined_antenna_xds = combined_antenna_xds.drop_duplicates("antenna_name")
693
+
694
+ return combined_antenna_xds
695
+
696
+ def plot_antenna_positions(self):
697
+ """
698
+ Plot the antenna positions of all antennas in the Processing Set.
699
+
700
+ This method generates three scatter plots displaying the antenna positions in different planes:
701
+ - X vs Y
702
+ - X vs Z
703
+ - Y vs Z
704
+
705
+ Parameters
706
+ ----------
707
+ None
708
+
709
+ Returns
710
+ -------
711
+ None
712
+
713
+ Raises
714
+ ------
715
+ ValueError
716
+ If the combined antenna dataset is empty or missing required coordinates.
717
+ """
718
+ if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
719
+ raise InvalidAccessorLocation(
720
+ f"{self._xdt.path} is not a processing set node."
721
+ )
722
+
723
+ combined_antenna_xds = self.get_combined_antenna_xds()
724
+ from matplotlib import pyplot as plt
725
+
726
+ plt.figure()
727
+ plt.title("Antenna Positions")
728
+ plt.scatter(
729
+ combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="x"),
730
+ combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="y"),
731
+ )
732
+ plt.xlabel("x (m)")
733
+ plt.ylabel("y (m)")
734
+ plt.show()
735
+
736
+ plt.figure()
737
+ plt.title("Antenna Positions")
738
+ plt.scatter(
739
+ combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="x"),
740
+ combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="z"),
741
+ )
742
+ plt.xlabel("x (m)")
743
+ plt.ylabel("z (m)")
744
+ plt.show()
745
+
746
+ plt.figure()
747
+ plt.title("Antenna Positions")
748
+ plt.scatter(
749
+ combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="y"),
750
+ combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="z"),
751
+ )
752
+ plt.xlabel("y (m)")
753
+ plt.ylabel("z (m)")
754
+ plt.show()
755
+
756
+
757
+ # class ProcessingSet(dict):
758
+ # """
759
+ # A dictionary subclass representing a Processing Set (PS) containing Measurement Sets v4 (MS).
760
+
761
+ # This class extends the built-in `dict` class to provide additional methods for
762
+ # manipulating and selecting subsets of the Processing Set. It includes functionality
763
+ # for summarizing metadata, selecting subsets based on various criteria, and
764
+ # exporting the data to storage formats.
765
+
766
+ # Parameters
767
+ # ----------
768
+ # *args : dict, optional
769
+ # Variable length argument list passed to the base `dict` class.
770
+ # **kwargs : dict, optional
771
+ # Arbitrary keyword arguments passed to the base `dict` class.
772
+ # """
773
+
774
+ # def __init__(self, *args, **kwargs):
775
+ # """
776
+ # Initialize the ProcessingSet instance.
777
+
778
+ # Parameters
779
+ # ----------
780
+ # *args : dict, optional
781
+ # Variable length argument list passed to the base `dict` class.
782
+ # **kwargs : dict, optional
783
+ # Arbitrary keyword arguments passed to the base `dict` class.
784
+ # """
785
+ # super().__init__(*args, **kwargs)
786
+ # self.meta = {"summary": {}}
787
+
788
+ # def summary(self, data_group="base"):
789
+ # """
790
+ # Generate and retrieve a summary of the Processing Set.
791
+
792
+ # The summary includes information such as the names of the Measurement Sets,
793
+ # their intents, polarizations, spectral window names, field names, source names,
794
+ # field coordinates, start frequencies, and end frequencies.
795
+
796
+ # Parameters
797
+ # ----------
798
+ # data_group : str, optional
799
+ # The data group to summarize. Default is "base".
800
+
801
+ # Returns
802
+ # -------
803
+ # pandas.DataFrame
804
+ # A DataFrame containing the summary information of the specified data group.
805
+ # """
806
+
807
+ # if data_group in self.meta["summary"]:
808
+ # return self.meta["summary"][data_group]
809
+ # else:
810
+ # self.meta["summary"][data_group] = self._summary(data_group).sort_values(
811
+ # by=["name"], ascending=True
812
+ # )
813
+ # return self.meta["summary"][data_group]
814
+
815
+ # def get_ps_max_dims(self):
816
+ # """
817
+ # Determine the maximum dimensions across all Measurement Sets in the Processing Set.
818
+
819
+ # This method examines each Measurement Set's dimensions and computes the maximum
820
+ # size for each dimension across the entire Processing Set.
821
+
822
+ # For example, if the Processing Set contains two MSs with dimensions (50, 20, 30) and (10, 30, 40),
823
+ # the maximum dimensions will be (50, 30, 40).
824
+
825
+ # Returns
826
+ # -------
827
+ # dict
828
+ # A dictionary containing the maximum dimensions of the Processing Set, with dimension names as keys
829
+ # and their maximum sizes as values.
830
+ # """
831
+ # if "max_dims" in self.meta:
832
+ # return self.meta["max_dims"]
833
+ # else:
834
+ # self.meta["max_dims"] = self._get_ps_max_dims()
835
+ # return self.meta["max_dims"]
836
+
837
+ # def get_ps_freq_axis(self):
838
+ # """
839
+ # Combine the frequency axes of all Measurement Sets in the Processing Set.
840
+
841
+ # This method aggregates the frequency information from each Measurement Set to create
842
+ # a unified frequency axis for the entire Processing Set.
843
+
844
+ # Returns
845
+ # -------
846
+ # xarray.DataArray
847
+ # The combined frequency axis of the Processing Set.
848
+ # """
849
+ # if "freq_axis" in self.meta:
850
+ # return self.meta["freq_axis"]
851
+ # else:
852
+ # self.meta["freq_axis"] = self._get_ps_freq_axis()
853
+ # return self.meta["freq_axis"]
854
+
855
+ # def _summary(self, data_group="base"):
856
+ # summary_data = {
857
+ # "name": [],
858
+ # "intents": [],
859
+ # "shape": [],
860
+ # "polarization": [],
861
+ # "scan_name": [],
862
+ # "spw_name": [],
863
+ # "field_name": [],
864
+ # "source_name": [],
865
+ # "line_name": [],
866
+ # "field_coords": [],
867
+ # "start_frequency": [],
868
+ # "end_frequency": [],
869
+ # }
870
+ # from astropy.coordinates import SkyCoord
871
+ # import astropy.units as u
872
+
873
+ # for key, value in self.items():
874
+ # summary_data["name"].append(key)
875
+ # summary_data["intents"].append(partition_info["intents"])
876
+ # summary_data["spw_name"].append(
877
+ # partition_info["spectral_window_name"]
878
+ # )
879
+ # summary_data["polarization"].append(value.polarization.values)
880
+ # summary_data["scan_name"].append(partition_info["scan_name"])
881
+ # data_name = value.attrs["data_groups"][data_group]["correlated_data"]
882
+
883
+ # if "VISIBILITY" in data_name:
884
+ # center_name = "FIELD_PHASE_CENTER"
885
+
886
+ # if "SPECTRUM" in data_name:
887
+ # center_name = "FIELD_REFERENCE_CENTER"
888
+
889
+ # summary_data["shape"].append(value[data_name].shape)
890
+
891
+ # summary_data["field_name"].append(
892
+ # partition_info["field_name"]
893
+ # )
894
+ # summary_data["source_name"].append(
895
+ # partition_info["source_name"]
896
+ # )
897
+
898
+ # summary_data["line_name"].append(partition_info["line_name"])
899
+
900
+ # summary_data["start_frequency"].append(
901
+ # to_list(value["frequency"].values)[0]
902
+ # )
903
+ # summary_data["end_frequency"].append(to_list(value["frequency"].values)[-1])
904
+
905
+ # if (
906
+ # value[data_name].attrs["field_and_source_xds"].attrs["type"]
907
+ # == "field_and_source_ephemeris"
908
+ # ):
909
+ # summary_data["field_coords"].append("Ephemeris")
910
+ # # elif (
911
+ # # "time"
912
+ # # in value[data_name].attrs["field_and_source_xds"][center_name].coords
913
+ # # ):
914
+ # elif (
915
+ # value[data_name]
916
+ # .attrs["field_and_source_xds"][center_name]["field_name"]
917
+ # .size
918
+ # > 1
919
+ # ):
920
+ # summary_data["field_coords"].append("Multi-Phase-Center")
921
+ # else:
922
+ # ra_dec_rad = (
923
+ # value[data_name]
924
+ # .attrs["field_and_source_xds"][center_name]
925
+ # .values[0, :]
926
+ # )
927
+ # frame = (
928
+ # value[data_name]
929
+ # .attrs["field_and_source_xds"][center_name]
930
+ # .attrs["frame"]
931
+ # .lower()
932
+ # )
933
+
934
+ # coord = SkyCoord(
935
+ # ra=ra_dec_rad[0] * u.rad, dec=ra_dec_rad[1] * u.rad, frame=frame
936
+ # )
937
+
938
+ # summary_data["field_coords"].append(
939
+ # [
940
+ # frame,
941
+ # coord.ra.to_string(unit=u.hour, precision=2),
942
+ # coord.dec.to_string(unit=u.deg, precision=2),
943
+ # ]
944
+ # )
945
+
946
+ # summary_df = pd.DataFrame(summary_data)
947
+ # return summary_df
948
+
949
+ # def _get_ps_freq_axis(self):
950
+
951
+ # spw_ids = []
952
+ # freq_axis_list = []
953
+ # frame = self.get(0).frequency.attrs["observer"]
954
+ # for ms_xds in self.values():
955
+ # assert (
956
+ # frame == ms_xds.frequency.attrs["observer"]
957
+ # ), "Frequency reference frame not consistent in Processing Set."
958
+ # if ms_xds.frequency.attrs["spectral_window_id"] not in spw_ids:
959
+ # spw_ids.append(ms_xds.frequency.attrs["spectral_window_id"])
960
+ # freq_axis_list.append(ms_xds.frequency)
961
+
962
+ # freq_axis = xr.concat(freq_axis_list, dim="frequency").sortby("frequency")
963
+ # return freq_axis
964
+
965
+ # def _get_ps_max_dims(self):
966
+ # max_dims = None
967
+ # for ms_xds in self.values():
968
+ # if max_dims is None:
969
+ # max_dims = dict(ms_xds.sizes)
970
+ # else:
971
+ # for dim_name, size in ms_xds.sizes.items():
972
+ # if dim_name in max_dims:
973
+ # if max_dims[dim_name] < size:
974
+ # max_dims[dim_name] = size
975
+ # else:
976
+ # max_dims[dim_name] = size
977
+ # return max_dims
978
+
979
+ # def get(self, id):
980
+ # return self[list(self.keys())[id]]
981
+
982
+ # def sel(self, string_exact_match: bool = True, query: str = None, **kwargs):
983
+ # """
984
+ # Select a subset of the Processing Set based on specified criteria.
985
+
986
+ # This method allows filtering the Processing Set by matching column names and values
987
+ # or by applying a Pandas query string. The selection criteria can target various
988
+ # attributes of the Measurement Sets such as intents, polarization, spectral window names, etc.
989
+
990
+ # Note
991
+ # ----
992
+ # This selection does not modify the actual data within the Measurement Sets. For example, if
993
+ # a Measurement Set has `field_name=['field_0','field_10','field_08']` and `ps.query(field_name='field_0')`
994
+ # is invoked, the resulting subset will still contain the original list `['field_0','field_10','field_08']`.
995
+
996
+ # Parameters
997
+ # ----------
998
+ # string_exact_match : bool, optional
999
+ # If `True`, string matching will require exact matches for string and string list columns.
1000
+ # If `False`, partial matches are allowed. Default is `True`.
1001
+ # query : str, optional
1002
+ # A Pandas query string to apply additional filtering. Default is `None`.
1003
+ # **kwargs : dict
1004
+ # Keyword arguments representing column names and their corresponding values to filter the Processing Set.
1005
+
1006
+ # Returns
1007
+ # -------
1008
+ # ProcessingSet
1009
+ # A new `ProcessingSet` instance containing only the Measurement Sets that match the selection criteria.
1010
+
1011
+ # Examples
1012
+ # --------
1013
+ # >>> # Select all MSs with intents 'OBSERVE_TARGET#ON_SOURCE' and polarization 'RR' or 'LL'
1014
+ # >>> selected_ps = ps.query(intents='OBSERVE_TARGET#ON_SOURCE', polarization=['RR', 'LL'])
1015
+
1016
+ # >>> # Select all MSs with start_frequency greater than 100 GHz and end_frequency less than 200 GHz
1016
+ # >>> selected_ps = ps.query(query='start_frequency > 100e9 and end_frequency < 200e9')
1018
+ # """
1019
+ # import numpy as np
1020
+
1021
+ # def select_rows(df, col, sel_vals, string_exact_match):
1022
+ # def check_selection(row_val):
1023
+ # row_val = to_list(
1024
+ # row_val
1025
+ # ) # make sure that it is a list so that we can iterate over it.
1026
+
1027
+ # for rw in row_val:
1028
+ # for s in sel_vals:
1029
+ # if string_exact_match:
1030
+ # if rw == s:
1031
+ # return True
1032
+ # else:
1033
+ # if s in rw:
1034
+ # return True
1035
+ # return False
1036
+
1037
+ # return df[df[col].apply(check_selection)]
1038
+
1039
+ # summary_table = self.summary()
1040
+ # for key, value in kwargs.items():
1041
+ # value = to_list(value) # make sure value is a list.
1042
+
1043
+ # if len(value) == 1 and isinstance(value[0], slice):
1044
+ # summary_table = summary_table[
1045
+ # summary_table[key].between(value[0].start, value[0].stop)
1046
+ # ]
1047
+ # else:
1048
+ # summary_table = select_rows(
1049
+ # summary_table, key, value, string_exact_match
1050
+ # )
1051
+
1052
+ # if query is not None:
1053
+ # summary_table = summary_table.query(query)
1054
+
1055
+ # sub_ps = ProcessingSet()
1056
+ # for key, val in self.items():
1057
+ # if key in summary_table["name"].values:
1058
+ # sub_ps[key] = val
1059
+
1060
+ # return sub_ps
1061
+
1062
+ # def ms_sel(self, **kwargs):
1063
+ # """
1064
+ # Select a subset of the Processing Set by applying the `xarray.Dataset.sel` method to each Measurement Set.
1065
+
1066
+ # This method allows for selection based on label-based indexing for each dimension of the datasets.
1067
+
1068
+ # Parameters
1069
+ # ----------
1070
+ # **kwargs : dict
1071
+ # Keyword arguments representing dimension names and the labels to select along those dimensions.
1072
+ # These are passed directly to the `xarray.Dataset.sel <https://docs.xarray.dev/en/latest/generated/xarray.Dataset.sel.html>`__ method.
1073
+
1074
+ # Returns
1075
+ # -------
1076
+ # ProcessingSet
1077
+ # A new `ProcessingSet` instance containing the selected subsets of each Measurement Set.
1078
+ # """
1079
+ # sub_ps = ProcessingSet()
1080
+ # for key, val in self.items():
1081
+ # sub_ps[key] = val.sel(kwargs)
1082
+ # return sub_ps
1083
+
1084
+ # def ms_isel(self, **kwargs):
1085
+ # """
1086
+ # Select a subset of the Processing Set by applying the `isel` method to each Measurement Set.
1087
+
1088
+ # This method allows for selection based on integer-based indexing for each dimension of the datasets.
1089
+
1090
+ # Parameters
1091
+ # ----------
1092
+ # **kwargs : dict
1093
+ # Keyword arguments representing dimension names and the integer indices to select along those dimensions.
1094
+ # These are passed directly to the `xarray.Dataset.isel <https://docs.xarray.dev/en/latest/generated/xarray.Dataset.isel.html>`__ method.
1095
+
1096
+ # Returns
1097
+ # -------
1098
+ # ProcessingSet
1099
+ # A new `ProcessingSet` instance containing the selected subsets of each Measurement Set.
1100
+ # """
1101
+ # sub_ps = ProcessingSet()
1102
+ # for key, val in self.items():
1103
+ # sub_ps[key] = val.isel(kwargs)
1104
+ # return sub_ps
1105
+
1106
+ # def to_store(self, store, **kwargs):
1107
+ # """
1108
+ # Write the Processing Set to a Zarr store.
1109
+
1110
+ # This method serializes each Measurement Set within the Processing Set to a separate Zarr group
1111
+ # within the specified store directory. Note that writing to cloud storage is not supported yet.
1112
+
1113
+ # Parameters
1114
+ # ----------
1115
+ # store : str
1116
+ # The filesystem path to the Zarr store directory where the data will be saved.
1117
+ # **kwargs : dict, optional
1118
+ # Additional keyword arguments to be passed to the `xarray.Dataset.to_zarr` method.
1119
+ # Refer to the `xarray.Dataset.to_zarr <https://docs.xarray.dev/en/latest/generated/xarray.Dataset.to_zarr.html>`__
1120
+ # for available options.
1121
+
1122
+ # Returns
1123
+ # -------
1124
+ # None
1125
+
1126
+ # Raises
1127
+ # ------
1128
+ # OSError
1129
+ # If the specified store path is invalid or not writable.
1130
+
1131
+ # Examples
1132
+ # --------
1133
+ # >>> # Save the Processing Set to a local Zarr store
1134
+ # >>> ps.to_store('/path/to/zarr_store')
1135
+ # """
1136
+ # import os
1137
+
1138
+ # for key, value in self.items():
1139
+ # value.to_store(os.path.join(store, key), **kwargs)
1140
+
1141
+ # def get_combined_field_and_source_xds(self, data_group="base"):
1142
+ # """
1143
+ # Combine all non-ephemeris `field_and_source_xds` datasets from a Processing Set for a data group into a single dataset.
1144
+
1145
+ # Parameters
1146
+ # ----------
1147
+ # data_group : str, optional
1148
+ # The data group to process. Default is "base".
1149
+
1150
+ # Returns
1151
+ # -------
1152
+ # xarray.Dataset
1153
+ # combined_field_and_source_xds: Combined dataset for standard fields.
1154
+
1155
+ # Raises
1156
+ # ------
1157
+ # ValueError
1158
+ # If the `field_and_source_xds` attribute is missing or improperly formatted in any Measurement Set.
1159
+ # """
1160
+
1161
+ # combined_field_and_source_xds = xr.Dataset()
1162
+ # for ms_name, ms_xds in self.items():
1163
+ # correlated_data_name = ms_xds.attrs["data_groups"][data_group][
1164
+ # "correlated_data"
1165
+ # ]
1166
+
1167
+ # field_and_source_xds = (
1168
+ # ms_xds[correlated_data_name]
1169
+ # .attrs["field_and_source_xds"]
1170
+ # .copy(deep=True)
1171
+ # )
1172
+
1173
+ # if not field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
1174
+
1175
+ # if (
1176
+ # "line_name" in field_and_source_xds.coords
1177
+ # ): # Not including line info since it is a function of spw.
1178
+ # field_and_source_xds = field_and_source_xds.drop_vars(
1179
+ # ["LINE_REST_FREQUENCY", "LINE_SYSTEMIC_VELOCITY"],
1180
+ # errors="ignore",
1181
+ # )
1182
+ # del field_and_source_xds["line_name"]
1183
+ # del field_and_source_xds["line_label"]
1184
+
1185
+ # if len(combined_field_and_source_xds.data_vars) == 0:
1186
+ # combined_field_and_source_xds = field_and_source_xds
1187
+ # else:
1188
+ # combined_field_and_source_xds = xr.concat(
1189
+ # [combined_field_and_source_xds, field_and_source_xds],
1190
+ # dim="field_name",
1191
+ # )
1192
+
1193
+ # if (len(combined_field_and_source_xds.data_vars) > 0) and (
1194
+ # "FIELD_PHASE_CENTER" in combined_field_and_source_xds
1195
+ # ):
1196
+ # combined_field_and_source_xds = (
1197
+ # combined_field_and_source_xds.drop_duplicates("field_name")
1198
+ # )
1199
+
1200
+ # combined_field_and_source_xds["MEAN_PHASE_CENTER"] = (
1201
+ # combined_field_and_source_xds["FIELD_PHASE_CENTER"].mean(
1202
+ # dim=["field_name"]
1203
+ # )
1204
+ # )
1205
+
1206
+ # ra1 = (
1207
+ # combined_field_and_source_xds["FIELD_PHASE_CENTER"]
1208
+ # .sel(sky_dir_label="ra")
1209
+ # .values
1210
+ # )
1211
+ # dec1 = (
1212
+ # combined_field_and_source_xds["FIELD_PHASE_CENTER"]
1213
+ # .sel(sky_dir_label="dec")
1214
+ # .values
1215
+ # )
1216
+ # ra2 = (
1217
+ # combined_field_and_source_xds["MEAN_PHASE_CENTER"]
1218
+ # .sel(sky_dir_label="ra")
1219
+ # .values
1220
+ # )
1221
+ # dec2 = (
1222
+ # combined_field_and_source_xds["MEAN_PHASE_CENTER"]
1223
+ # .sel(sky_dir_label="dec")
1224
+ # .values
1225
+ # )
1226
+
1227
+ # from xradio._utils.coord_math import haversine
1228
+
1229
+ # distance = haversine(ra1, dec1, ra2, dec2)
1230
+ # min_index = distance.argmin()
1231
+
1232
+ # combined_field_and_source_xds.attrs["center_field_name"] = (
1233
+ # combined_field_and_source_xds.field_name[min_index].values
1234
+ # )
1235
+
1236
+ # return combined_field_and_source_xds
1237
+
1238
+
1239
+ # def get_combined_field_and_source_xds_ephemeris(self, data_group="base"):
1240
+ # """
1241
+ # Combine all ephemeris `field_and_source_xds` datasets from a Processing Set for a data group into a single dataset.
1242
+
1243
+ # Parameters
1244
+ # ----------
1245
+ # data_group : str, optional
1246
+ # The data group to process. Default is "base".
1247
+
1248
+ # Returns
1249
+ # -------
1250
+ # xarray.Dataset
1251
+ # - combined_ephemeris_field_and_source_xds: Combined dataset for ephemeris fields.
1252
+
1253
+ # Raises
1254
+ # ------
1255
+ # ValueError
1256
+ # If the `field_and_source_xds` attribute is missing or improperly formatted in any Measurement Set.
1257
+ # """
1258
+
1259
+ # combined_ephemeris_field_and_source_xds = xr.Dataset()
1260
+ # for ms_name, ms_xds in self.items():
1261
+
1262
+ # correlated_data_name = ms_xds.attrs["data_groups"][data_group][
1263
+ # "correlated_data"
1264
+ # ]
1265
+
1266
+ # field_and_source_xds = (
1267
+ # ms_xds[correlated_data_name]
1268
+ # .attrs["field_and_source_xds"]
1269
+ # .copy(deep=True)
1270
+ # )
1271
+
1272
+ # if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
1273
+
1274
+ # if (
1275
+ # "line_name" in field_and_source_xds.coords
1276
+ # ): # Not including line info since it is a function of spw.
1277
+ # field_and_source_xds = field_and_source_xds.drop_vars(
1278
+ # ["LINE_REST_FREQUENCY", "LINE_SYSTEMIC_VELOCITY"],
1279
+ # errors="ignore",
1280
+ # )
1281
+ # del field_and_source_xds["line_name"]
1282
+ # del field_and_source_xds["line_label"]
1283
+
1284
+ # from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
1285
+ # interpolate_to_time,
1286
+ # )
1287
+
1288
+ # if "time_ephemeris" in field_and_source_xds:
1289
+ # field_and_source_xds = interpolate_to_time(
1290
+ # field_and_source_xds,
1291
+ # field_and_source_xds.time,
1292
+ # "field_and_source_xds",
1293
+ # "time_ephemeris",
1294
+ # )
1295
+ # del field_and_source_xds["time_ephemeris"]
1296
+ # field_and_source_xds = field_and_source_xds.rename(
1297
+ # {"time_ephemeris": "time"}
1298
+ # )
1299
+
1300
+ # if "OBSERVER_POSITION" in field_and_source_xds:
1301
+ # field_and_source_xds = field_and_source_xds.drop_vars(
1302
+ # ["OBSERVER_POSITION"], errors="ignore"
1303
+ # )
1304
+
1305
+ # if len(combined_ephemeris_field_and_source_xds.data_vars) == 0:
1306
+ # combined_ephemeris_field_and_source_xds = field_and_source_xds
1307
+ # else:
1308
+
1309
+ # combined_ephemeris_field_and_source_xds = xr.concat(
1310
+ # [combined_ephemeris_field_and_source_xds, field_and_source_xds],
1311
+ # dim="time",
1312
+ # )
1313
+
1314
+ # if (len(combined_ephemeris_field_and_source_xds.data_vars) > 0) and (
1315
+ # "FIELD_PHASE_CENTER" in combined_ephemeris_field_and_source_xds
1316
+ # ):
1317
+
1318
+ # from xradio._utils.coord_math import wrap_to_pi
1319
+
1320
+ # offset = (
1321
+ # combined_ephemeris_field_and_source_xds["FIELD_PHASE_CENTER"]
1322
+ # - combined_ephemeris_field_and_source_xds["SOURCE_LOCATION"]
1323
+ # )
1324
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"] = xr.DataArray(
1325
+ # wrap_to_pi(offset.sel(sky_pos_label=["ra", "dec"])).values,
1326
+ # dims=["time", "sky_dir_label"],
1327
+ # )
1328
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].attrs = (
1329
+ # combined_ephemeris_field_and_source_xds["FIELD_PHASE_CENTER"].attrs
1330
+ # )
1331
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].attrs["units"] = (
1332
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].attrs["units"][
1333
+ # :2
1334
+ # ]
1335
+ # )
1336
+
1337
+ # ra1 = (
1338
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"]
1339
+ # .sel(sky_dir_label="ra")
1340
+ # .values
1341
+ # )
1342
+ # dec1 = (
1343
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"]
1344
+ # .sel(sky_dir_label="dec")
1345
+ # .values
1346
+ # )
1347
+ # ra2 = 0.0
1348
+ # dec2 = 0.0
1349
+
1350
+ # from xradio._utils.coord_math import haversine
1351
+
1352
+ # distance = haversine(ra1, dec1, ra2, dec2)
1353
+ # min_index = distance.argmin()
1354
+
1355
+ # combined_ephemeris_field_and_source_xds.attrs["center_field_name"] = (
1356
+ # combined_ephemeris_field_and_source_xds.field_name[min_index].values
1357
+ # )
1358
+
1359
+ # return combined_ephemeris_field_and_source_xds
1360
+
1361
+ # def plot_phase_centers(self, label_all_fields=False, data_group="base"):
1362
+ # """
1363
+ # Plot the phase center locations of all fields in the Processing Set.
1364
+
1365
+ # This method is primarily used for visualizing mosaics. It generates scatter plots of
1366
+ # the phase center coordinates for both standard and ephemeris fields. The central field
1367
+ # is highlighted in red based on the closest phase center calculation.
1368
+
1369
+ # Parameters
1370
+ # ----------
1371
+ # label_all_fields : bool, optional
1372
+ # If `True`, all fields will be labeled on the plot. Default is `False`.
1373
+ # data_group : str, optional
1374
+ # The data group to use for processing. Default is "base".
1375
+
1376
+ # Returns
1377
+ # -------
1378
+ # None
1379
+
1380
+ # Raises
1381
+ # ------
1382
+ # ValueError
1383
+ # If the combined datasets are empty or improperly formatted.
1384
+ # """
1385
+ # combined_field_and_source_xds = self.get_combined_field_and_source_xds(
1386
+ # data_group
1387
+ # )
1388
+ # combined_ephemeris_field_and_source_xds = (
1389
+ # self.get_combined_field_and_source_xds_ephemeris(data_group)
1390
+ # )
1391
+ # from matplotlib import pyplot as plt
1392
+
1393
+ # if (len(combined_field_and_source_xds.data_vars) > 0) and (
1394
+ # "FIELD_PHASE_CENTER" in combined_field_and_source_xds
1395
+ # ):
1396
+ # plt.figure()
1397
+ # plt.title("Field Phase Center Locations")
1398
+ # plt.scatter(
1399
+ # combined_field_and_source_xds["FIELD_PHASE_CENTER"].sel(
1400
+ # sky_dir_label="ra"
1401
+ # ),
1402
+ # combined_field_and_source_xds["FIELD_PHASE_CENTER"].sel(
1403
+ # sky_dir_label="dec"
1404
+ # ),
1405
+ # )
1406
+
1407
+ # center_field_name = combined_field_and_source_xds.attrs["center_field_name"]
1408
+ # center_field = combined_field_and_source_xds.sel(
1409
+ # field_name=center_field_name
1410
+ # )
1411
+ # plt.scatter(
1412
+ # center_field["FIELD_PHASE_CENTER"].sel(sky_dir_label="ra"),
1413
+ # center_field["FIELD_PHASE_CENTER"].sel(sky_dir_label="dec"),
1414
+ # color="red",
1415
+ # label=center_field_name,
1416
+ # )
1417
+ # plt.xlabel("RA (rad)")
1418
+ # plt.ylabel("DEC (rad)")
1419
+ # plt.legend()
1420
+ # plt.show()
1421
+
1422
+ # if (len(combined_ephemeris_field_and_source_xds.data_vars) > 0) and (
1423
+ # "FIELD_PHASE_CENTER" in combined_ephemeris_field_and_source_xds
1424
+ # ):
1425
+
1426
+ # plt.figure()
1427
+ # plt.title(
1428
+ # "Offset of Field Phase Center from Source Location (Ephemeris Data)"
1429
+ # )
1430
+ # plt.scatter(
1431
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].sel(
1432
+ # sky_dir_label="ra"
1433
+ # ),
1434
+ # combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].sel(
1435
+ # sky_dir_label="dec"
1436
+ # ),
1437
+ # )
1438
+
1439
+ # center_field_name = combined_ephemeris_field_and_source_xds.attrs[
1440
+ # "center_field_name"
1441
+ # ]
1442
+
1443
+ # combined_ephemeris_field_and_source_xds = (
1444
+ # combined_ephemeris_field_and_source_xds.set_xindex("field_name")
1445
+ # )
1446
+
1447
+ # center_field = combined_ephemeris_field_and_source_xds.sel(
1448
+ # field_name=center_field_name
1449
+ # )
1450
+ # plt.scatter(
1451
+ # center_field["FIELD_OFFSET"].sel(sky_dir_label="ra"),
1452
+ # center_field["FIELD_OFFSET"].sel(sky_dir_label="dec"),
1453
+ # color="red",
1454
+ # label=center_field_name,
1455
+ # )
1456
+ # plt.xlabel("RA Offset (rad)")
1457
+ # plt.ylabel("DEC Offset (rad)")
1458
+ # plt.legend()
1459
+ # plt.show()
1460
+
1461
+ # def get_combined_antenna_xds(self):
1462
+ # """
1463
+ # Combine the `antenna_xds` datasets from all Measurement Sets into a single dataset.
1464
+
1465
+ # This method concatenates the antenna datasets from each Measurement Set along the 'antenna_name' dimension.
1466
+
1467
+ # Returns
1468
+ # -------
1469
+ # xarray.Dataset
1470
+ # A combined `xarray.Dataset` containing antenna information from all Measurement Sets.
1471
+
1472
+ # Raises
1473
+ # ------
1474
+ # ValueError
1475
+ # If antenna datasets are missing required variables or improperly formatted.
1476
+ # """
1477
+ # combined_antenna_xds = xr.Dataset()
1478
+ # for cor_name, ms_xds in self.items():
1479
+ # antenna_xds = ms_xds.antenna_xds.copy(deep=True)
1480
+
1481
+ # if len(combined_antenna_xds.data_vars) == 0:
1482
+ # combined_antenna_xds = antenna_xds
1483
+ # else:
1484
+ # combined_antenna_xds = xr.concat(
1485
+ # [combined_antenna_xds, antenna_xds],
1486
+ # dim="antenna_name",
1487
+ # data_vars="minimal",
1488
+ # coords="minimal",
1489
+ # )
1490
+
1491
+ # # ALMA WVR antenna_xds data has a NaN value for the antenna receptor angle.
1492
+ # if "ANTENNA_RECEPTOR_ANGLE" in combined_antenna_xds.data_vars:
1493
+ # combined_antenna_xds = combined_antenna_xds.dropna("antenna_name")
1494
+
1495
+ # combined_antenna_xds = combined_antenna_xds.drop_duplicates("antenna_name")
1496
+
1497
+ # return combined_antenna_xds
1498
+
1499
+ # def plot_antenna_positions(self):
1500
+ # """
1501
+ # Plot the antenna positions of all antennas in the Processing Set.
1502
+
1503
+ # This method generates three scatter plots displaying the antenna positions in different planes:
1504
+ # - X vs Y
1505
+ # - X vs Z
1506
+ # - Y vs Z
1507
+
1508
+ # Parameters
1509
+ # ----------
1510
+ # None
1511
+
1512
+ # Returns
1513
+ # -------
1514
+ # None
1515
+
1516
+ # Raises
1517
+ # ------
1518
+ # ValueError
1519
+ # If the combined antenna dataset is empty or missing required coordinates.
1520
+ # """
1521
+ # combined_antenna_xds = self.get_combined_antenna_xds()
1522
+ # from matplotlib import pyplot as plt
1523
+
1524
+ # plt.figure()
1525
+ # plt.title("Antenna Positions")
1526
+ # plt.scatter(
1527
+ # combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="x"),
1528
+ # combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="y"),
1529
+ # )
1530
+ # plt.xlabel("x (m)")
1531
+ # plt.ylabel("y (m)")
1532
+ # plt.show()
1533
+
1534
+ # plt.figure()
1535
+ # plt.title("Antenna Positions")
1536
+ # plt.scatter(
1537
+ # combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="x"),
1538
+ # combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="z"),
1539
+ # )
1540
+ # plt.xlabel("x (m)")
1541
+ # plt.ylabel("z (m)")
1542
+ # plt.show()
1543
+
1544
+ # plt.figure()
1545
+ # plt.title("Antenna Positions")
1546
+ # plt.scatter(
1547
+ # combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="y"),
1548
+ # combined_antenna_xds["ANTENNA_POSITION"].sel(cartesian_pos_label="z"),
1549
+ # )
1550
+ # plt.xlabel("y (m)")
1551
+ # plt.ylabel("z (m)")
1552
+ # plt.show()