xradio 0.0.47__py3-none-any.whl → 0.0.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xradio/__init__.py +1 -0
- xradio/_utils/dict_helpers.py +69 -2
- xradio/_utils/list_and_array.py +3 -1
- xradio/_utils/schema.py +3 -1
- xradio/image/_util/__init__.py +0 -3
- xradio/image/_util/_casacore/common.py +0 -13
- xradio/image/_util/_casacore/xds_from_casacore.py +102 -97
- xradio/image/_util/_casacore/xds_to_casacore.py +36 -24
- xradio/image/_util/_fits/xds_from_fits.py +81 -36
- xradio/image/_util/_zarr/zarr_low_level.py +3 -3
- xradio/image/_util/casacore.py +7 -5
- xradio/image/_util/common.py +13 -26
- xradio/image/_util/image_factory.py +143 -191
- xradio/image/image.py +10 -59
- xradio/measurement_set/__init__.py +11 -6
- xradio/measurement_set/_utils/_msv2/_tables/read.py +187 -46
- xradio/measurement_set/_utils/_msv2/_tables/table_query.py +22 -0
- xradio/measurement_set/_utils/_msv2/conversion.py +347 -299
- xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +233 -150
- xradio/measurement_set/_utils/_msv2/descr.py +1 -1
- xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +20 -13
- xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +21 -22
- xradio/measurement_set/convert_msv2_to_processing_set.py +46 -6
- xradio/measurement_set/load_processing_set.py +100 -52
- xradio/measurement_set/measurement_set_xdt.py +197 -0
- xradio/measurement_set/open_processing_set.py +122 -86
- xradio/measurement_set/processing_set_xdt.py +1552 -0
- xradio/measurement_set/schema.py +375 -197
- xradio/schema/bases.py +5 -1
- xradio/schema/check.py +97 -5
- xradio/sphinx/schema_table.py +12 -0
- {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/METADATA +4 -4
- {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/RECORD +36 -36
- {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/WHEEL +1 -1
- xradio/measurement_set/measurement_set_xds.py +0 -117
- xradio/measurement_set/processing_set.py +0 -777
- {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info/licenses}/LICENSE.txt +0 -0
- {xradio-0.0.47.dist-info → xradio-0.0.49.dist-info}/top_level.txt +0 -0
|
@@ -127,9 +127,9 @@ def interpolate_to_time(
|
|
|
127
127
|
xds = xds.interp(
|
|
128
128
|
{time_name: interp_time.data}, method=method, assume_sorted=True
|
|
129
129
|
)
|
|
130
|
-
#
|
|
131
|
-
if "
|
|
132
|
-
xds = xds.drop_vars("
|
|
130
|
+
# scan_name sneaks in as a coordinate of the main time axis, drop it
|
|
131
|
+
if "scan_name" in xds.coords:
|
|
132
|
+
xds = xds.drop_vars("scan_name")
|
|
133
133
|
points_after = xds[time_name].size
|
|
134
134
|
logger.debug(
|
|
135
135
|
f"{message_prefix}: interpolating the time coordinate "
|
|
@@ -497,7 +497,7 @@ def prepare_generic_sys_cal_xds(generic_sys_cal_xds: xr.Dataset) -> xr.Dataset:
|
|
|
497
497
|
def create_system_calibration_xds(
|
|
498
498
|
in_file: str,
|
|
499
499
|
main_xds_frequency: xr.DataArray,
|
|
500
|
-
|
|
500
|
+
ant_xds: xr.DataArray,
|
|
501
501
|
sys_cal_interp_time: Union[xr.DataArray, None] = None,
|
|
502
502
|
):
|
|
503
503
|
"""
|
|
@@ -510,8 +510,8 @@ def create_system_calibration_xds(
|
|
|
510
510
|
main_xds_frequency: xr.DataArray
|
|
511
511
|
frequency array of the main xds (MSv4), containing among other things
|
|
512
512
|
spectral_window_id and measures metadata
|
|
513
|
-
|
|
514
|
-
|
|
513
|
+
ant_xds : xr.Dataset
|
|
514
|
+
The antenna_xds that has information such as names, stations, etc., for coordinates
|
|
515
515
|
sys_cal_interp_time: Union[xr.DataArray, None] = None,
|
|
516
516
|
Time axis to interpolate the data vars to (usually main MSv4 time)
|
|
517
517
|
|
|
@@ -529,7 +529,7 @@ def create_system_calibration_xds(
|
|
|
529
529
|
rename_ids=subt_rename_ids["SYSCAL"],
|
|
530
530
|
taql_where=(
|
|
531
531
|
f" where (SPECTRAL_WINDOW_ID = {spectral_window_id})"
|
|
532
|
-
f" AND (ANTENNA_ID IN [{','.join(map(str,
|
|
532
|
+
f" AND (ANTENNA_ID IN [{','.join(map(str, ant_xds.antenna_id.values))}])"
|
|
533
533
|
),
|
|
534
534
|
)
|
|
535
535
|
except ValueError as _exc:
|
|
@@ -541,14 +541,14 @@ def create_system_calibration_xds(
|
|
|
541
541
|
|
|
542
542
|
generic_sys_cal_xds = prepare_generic_sys_cal_xds(generic_sys_cal_xds)
|
|
543
543
|
|
|
544
|
-
mandatory_dimensions = ["antenna_name", "
|
|
544
|
+
mandatory_dimensions = ["antenna_name", "time_system_cal", "receptor_label"]
|
|
545
545
|
if "frequency" not in generic_sys_cal_xds.sizes:
|
|
546
546
|
dims_all = mandatory_dimensions
|
|
547
547
|
else:
|
|
548
|
-
dims_all = mandatory_dimensions + ["
|
|
548
|
+
dims_all = mandatory_dimensions + ["frequency_system_cal"]
|
|
549
549
|
|
|
550
550
|
to_new_data_variables = {
|
|
551
|
-
"PHASE_DIFF": ["PHASE_DIFFERENCE", ["antenna_name", "
|
|
551
|
+
"PHASE_DIFF": ["PHASE_DIFFERENCE", ["antenna_name", "time_system_cal"]],
|
|
552
552
|
"TCAL": ["TCAL", dims_all],
|
|
553
553
|
"TCAL_SPECTRUM": ["TCAL", dims_all],
|
|
554
554
|
"TRX": ["TRX", dims_all],
|
|
@@ -564,27 +564,26 @@ def create_system_calibration_xds(
|
|
|
564
564
|
}
|
|
565
565
|
|
|
566
566
|
to_new_coords = {
|
|
567
|
-
"TIME": ["
|
|
567
|
+
"TIME": ["time_system_cal", ["time_system_cal"]],
|
|
568
568
|
"receptor": ["receptor_label", ["receptor_label"]],
|
|
569
|
-
"frequency": ["
|
|
569
|
+
"frequency": ["frequency_system_cal", ["frequency_system_cal"]],
|
|
570
570
|
}
|
|
571
571
|
|
|
572
572
|
sys_cal_xds = xr.Dataset(attrs={"type": "system_calibration"})
|
|
573
|
-
|
|
574
|
-
"antenna_name":
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
"receptor_label": generic_sys_cal_xds.coords["receptor"].data,
|
|
573
|
+
ant_borrowed_coords = {
|
|
574
|
+
"antenna_name": ant_xds.coords["antenna_name"],
|
|
575
|
+
"receptor_label": ant_xds.coords["receptor_label"],
|
|
576
|
+
"polarization_type": ant_xds.coords["polarization_type"],
|
|
578
577
|
}
|
|
579
|
-
sys_cal_xds = sys_cal_xds.assign_coords(
|
|
578
|
+
sys_cal_xds = sys_cal_xds.assign_coords(ant_borrowed_coords)
|
|
580
579
|
sys_cal_xds = convert_generic_xds_to_xradio_schema(
|
|
581
580
|
generic_sys_cal_xds, sys_cal_xds, to_new_data_variables, to_new_coords
|
|
582
581
|
)
|
|
583
582
|
|
|
584
583
|
# Add frequency coord and its measures data, if present
|
|
585
|
-
if "
|
|
584
|
+
if "frequency_system_cal" in dims_all:
|
|
586
585
|
frequency_coord = {
|
|
587
|
-
"
|
|
586
|
+
"frequency_system_cal": generic_sys_cal_xds.coords["frequency"].data
|
|
588
587
|
}
|
|
589
588
|
sys_cal_xds = sys_cal_xds.assign_coords(frequency_coord)
|
|
590
589
|
frequency_measure = {
|
|
@@ -592,10 +591,10 @@ def create_system_calibration_xds(
|
|
|
592
591
|
"units": main_xds_frequency.attrs["units"],
|
|
593
592
|
"observer": main_xds_frequency.attrs["observer"],
|
|
594
593
|
}
|
|
595
|
-
sys_cal_xds.coords["
|
|
594
|
+
sys_cal_xds.coords["frequency_system_cal"].attrs.update(frequency_measure)
|
|
596
595
|
|
|
597
596
|
sys_cal_xds = rename_and_interpolate_to_time(
|
|
598
|
-
sys_cal_xds, "
|
|
597
|
+
sys_cal_xds, "time_system_cal", sys_cal_interp_time, "system_calibration_xds"
|
|
599
598
|
)
|
|
600
599
|
|
|
601
600
|
# correct expected types
|
|
@@ -18,6 +18,7 @@ def estimate_conversion_memory_and_cores(
|
|
|
18
18
|
"""
|
|
19
19
|
Given an MSv2 and a partition_scheme to use when converting it to MSv4,
|
|
20
20
|
estimates:
|
|
21
|
+
|
|
21
22
|
- memory (in the sense of the amount expected to be enough to convert)
|
|
22
23
|
- cores (in the sense of the recommended/optimal number of cores to use to convert)
|
|
23
24
|
|
|
@@ -36,7 +37,7 @@ def estimate_conversion_memory_and_cores(
|
|
|
36
37
|
Partition scheme as used in the function convert_msv2_to_processing_set()
|
|
37
38
|
|
|
38
39
|
Returns
|
|
39
|
-
|
|
40
|
+
-------
|
|
40
41
|
tuple
|
|
41
42
|
estimated maximum memory required for one partition,
|
|
42
43
|
maximum number of cores it makes sense to use (number of partitions),
|
|
@@ -62,7 +63,7 @@ def convert_msv2_to_processing_set(
|
|
|
62
63
|
use_table_iter: bool = False,
|
|
63
64
|
compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2),
|
|
64
65
|
storage_backend: str = "zarr",
|
|
65
|
-
|
|
66
|
+
parallel_mode: str = "none",
|
|
66
67
|
overwrite: bool = False,
|
|
67
68
|
):
|
|
68
69
|
"""Convert a Measurement Set v2 into a Processing Set of Measurement Set v4.
|
|
@@ -99,14 +100,45 @@ def convert_msv2_to_processing_set(
|
|
|
99
100
|
The Blosc compressor to use when saving the converted data to disk using Zarr, by default numcodecs.Zstd(level=2).
|
|
100
101
|
storage_backend : {"zarr", "netcdf"}, optional
|
|
101
102
|
The on-disk format to use. "netcdf" is not yet implemented.
|
|
102
|
-
|
|
103
|
-
|
|
103
|
+
parallel_mode : {"none", "partition", "time"}, optional
|
|
104
|
+
Choose whether to use Dask to execute conversion in parallel, by default "none" and conversion occurs serially.
|
|
105
|
+
The option "partition" parallelises the conversion over the partitions specified by `partition_scheme`. The option "time" can only be used for phased array interferometers where there are no partitions
|
|
106
|
+
in the MS v2; instead the MS v2 is parallelised along the time dimension and can be controlled by `main_chunksize`.
|
|
104
107
|
overwrite : bool, optional
|
|
105
108
|
Whether to overwrite an existing processing set, by default False.
|
|
106
109
|
"""
|
|
107
110
|
|
|
111
|
+
# Create empty data tree
|
|
112
|
+
import xarray as xr
|
|
113
|
+
|
|
114
|
+
ps_dt = xr.DataTree()
|
|
115
|
+
|
|
116
|
+
if not str(out_file).endswith("ps.zarr"):
|
|
117
|
+
out_file += ".ps.zarr"
|
|
118
|
+
|
|
119
|
+
print("Output file: ", out_file)
|
|
120
|
+
|
|
121
|
+
if overwrite:
|
|
122
|
+
ps_dt.to_zarr(store=out_file, mode="w")
|
|
123
|
+
else:
|
|
124
|
+
ps_dt.to_zarr(store=out_file, mode="w-")
|
|
125
|
+
|
|
126
|
+
# Check `parallel_mode` is valid
|
|
127
|
+
try:
|
|
128
|
+
assert parallel_mode in ["none", "partition", "time"]
|
|
129
|
+
except AssertionError:
|
|
130
|
+
logger.warning(
|
|
131
|
+
f"`parallel_mode` {parallel_mode} not recognosed. Defauling to 'none'."
|
|
132
|
+
)
|
|
133
|
+
parallel_mode = "none"
|
|
134
|
+
|
|
108
135
|
partitions = create_partitions(in_file, partition_scheme=partition_scheme)
|
|
109
136
|
logger.info("Number of partitions: " + str(len(partitions)))
|
|
137
|
+
if parallel_mode == "time":
|
|
138
|
+
assert (
|
|
139
|
+
len(partitions) == 1
|
|
140
|
+
), "MS v2 contains more than one partition. `parallel_mode = 'time'` not valid."
|
|
141
|
+
|
|
110
142
|
delayed_list = []
|
|
111
143
|
|
|
112
144
|
for ms_v4_id, partition_info in enumerate(partitions):
|
|
@@ -132,7 +164,7 @@ def convert_msv2_to_processing_set(
|
|
|
132
164
|
|
|
133
165
|
# prepend '0' to ms_v4_id as needed
|
|
134
166
|
ms_v4_id = f"{ms_v4_id:0>{len(str(len(partitions) - 1))}}"
|
|
135
|
-
if
|
|
167
|
+
if parallel_mode == "partition":
|
|
136
168
|
delayed_list.append(
|
|
137
169
|
dask.delayed(convert_and_write_partition)(
|
|
138
170
|
in_file,
|
|
@@ -149,6 +181,7 @@ def convert_msv2_to_processing_set(
|
|
|
149
181
|
phase_cal_interpolate=phase_cal_interpolate,
|
|
150
182
|
sys_cal_interpolate=sys_cal_interpolate,
|
|
151
183
|
compressor=compressor,
|
|
184
|
+
parallel_mode=parallel_mode,
|
|
152
185
|
overwrite=overwrite,
|
|
153
186
|
)
|
|
154
187
|
)
|
|
@@ -168,8 +201,15 @@ def convert_msv2_to_processing_set(
|
|
|
168
201
|
phase_cal_interpolate=phase_cal_interpolate,
|
|
169
202
|
sys_cal_interpolate=sys_cal_interpolate,
|
|
170
203
|
compressor=compressor,
|
|
204
|
+
parallel_mode=parallel_mode,
|
|
171
205
|
overwrite=overwrite,
|
|
172
206
|
)
|
|
173
207
|
|
|
174
|
-
if
|
|
208
|
+
if parallel_mode == "partition":
|
|
175
209
|
dask.compute(delayed_list)
|
|
210
|
+
|
|
211
|
+
import zarr
|
|
212
|
+
|
|
213
|
+
root_group = zarr.open(out_file, mode="r+") # Open in read/write mode
|
|
214
|
+
root_group.attrs["type"] = "processing_set" # Replace
|
|
215
|
+
zarr.convenience.consolidate_metadata(root_group.store)
|
|
@@ -1,79 +1,115 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from xradio.measurement_set import ProcessingSet
|
|
3
2
|
from typing import Dict, Union
|
|
3
|
+
import dask
|
|
4
|
+
import xarray as xr
|
|
5
|
+
import s3fs
|
|
4
6
|
|
|
5
7
|
|
|
6
8
|
def load_processing_set(
|
|
7
9
|
ps_store: str,
|
|
8
|
-
sel_parms: dict,
|
|
9
|
-
|
|
10
|
+
sel_parms: dict = None,
|
|
11
|
+
data_group_name: str = None,
|
|
12
|
+
include_variables: Union[list, None] = None,
|
|
13
|
+
drop_variables: Union[list, None] = None,
|
|
10
14
|
load_sub_datasets: bool = True,
|
|
11
|
-
) ->
|
|
15
|
+
) -> xr.DataTree:
|
|
12
16
|
"""Loads a processing set into memory.
|
|
13
17
|
|
|
14
18
|
Parameters
|
|
15
19
|
----------
|
|
16
20
|
ps_store : str
|
|
17
21
|
String of the path and name of the processing set. For example '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr' for a file stored on a local file system, or 's3://viper-test-data/Antennae_North.cal.lsrk.split.vis.zarr/' for a file in AWS object storage.
|
|
18
|
-
sel_parms : dict
|
|
19
|
-
A dictionary where the keys are the names of the
|
|
22
|
+
sel_parms : dict, optional
|
|
23
|
+
A dictionary where the keys are the names of the ms_xdt's (measurement set xarray data trees) and the values are slice_dicts.
|
|
20
24
|
slice_dicts: A dictionary where the keys are the dimension names and the values are slices.
|
|
25
|
+
|
|
21
26
|
For example::
|
|
22
27
|
|
|
23
28
|
{
|
|
29
|
+
|
|
24
30
|
'ms_v4_name_1': {'frequency': slice(0, 160, None),'time':slice(0,100)},
|
|
25
31
|
...
|
|
26
32
|
'ms_v4_name_n': {'frequency': slice(0, 160, None),'time':slice(0,100)},
|
|
27
33
|
}
|
|
28
34
|
|
|
29
|
-
|
|
35
|
+
By default None, which loads all ms_xdts.
|
|
36
|
+
data_group_name : str, optional
|
|
37
|
+
The name of the data group to select. By default None, which loads all data groups.
|
|
38
|
+
include_variables : Union[list, None], optional
|
|
30
39
|
The list of data variables to load into memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load all data variables into memory.
|
|
40
|
+
drop_variables : Union[list, None], optional
|
|
41
|
+
The list of data variables to drop from memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop any data variables from memory.
|
|
31
42
|
load_sub_datasets : bool, optional
|
|
32
43
|
If true sub-datasets (for example weather_xds, antenna_xds, pointing_xds, system_calibration_xds ...) will be loaded into memory, by default True.
|
|
33
44
|
|
|
34
45
|
Returns
|
|
35
46
|
-------
|
|
36
|
-
|
|
37
|
-
In memory representation of processing set
|
|
47
|
+
xarray.DataTree
|
|
48
|
+
In memory representation of processing set using xr.DataTree.
|
|
38
49
|
"""
|
|
39
|
-
from xradio._utils.zarr.common import
|
|
50
|
+
from xradio._utils.zarr.common import _get_file_system_and_items
|
|
40
51
|
|
|
41
52
|
file_system, ms_store_list = _get_file_system_and_items(ps_store)
|
|
42
53
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
54
|
+
with dask.config.set(
|
|
55
|
+
scheduler="synchronous"
|
|
56
|
+
): # serial scheduler, critical so that this can be used within delayed functions.
|
|
57
|
+
ps_xdt = xr.DataTree()
|
|
58
|
+
|
|
59
|
+
if sel_parms:
|
|
60
|
+
for ms_name, ms_xds_isel in sel_parms.items():
|
|
61
|
+
ms_store = os.path.join(ps_store, ms_name)
|
|
62
|
+
|
|
63
|
+
if isinstance(file_system, s3fs.core.S3FileSystem):
|
|
64
|
+
ms_store = s3fs.S3Map(root=ps_store, s3=file_system, check=False)
|
|
65
|
+
|
|
66
|
+
if ms_xds_isel:
|
|
67
|
+
ms_xdt = (
|
|
68
|
+
xr.open_datatree(
|
|
69
|
+
ms_store, engine="zarr", drop_variables=drop_variables
|
|
70
|
+
)
|
|
71
|
+
.isel(ms_xds_isel)
|
|
72
|
+
.xr_ms.sel(data_group_name=data_group_name)
|
|
73
|
+
)
|
|
74
|
+
else:
|
|
75
|
+
ms_xdt = xr.open_datatree(
|
|
76
|
+
ms_store, engine="zarr", drop_variables=drop_variables
|
|
77
|
+
).xr_ms.sel(data_group_name=data_group_name)
|
|
78
|
+
|
|
79
|
+
if include_variables is not None:
|
|
80
|
+
for data_vars in ms_xdt.ds.data_vars:
|
|
81
|
+
if data_vars not in include_variables:
|
|
82
|
+
ms_xdt.ds = ms_xdt.ds.drop_vars(data_vars)
|
|
83
|
+
|
|
84
|
+
ps_xdt[ms_name] = ms_xdt
|
|
85
|
+
|
|
86
|
+
ps_xdt.attrs["type"] = "processing_set"
|
|
87
|
+
else:
|
|
88
|
+
ps_xdt = xr.open_datatree(
|
|
89
|
+
ps_store, engine="zarr", drop_variables=drop_variables
|
|
62
90
|
)
|
|
63
91
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
92
|
+
if (include_variables is not None) or data_group_name:
|
|
93
|
+
for ms_name, ms_xdt in ps_xdt.items():
|
|
94
|
+
|
|
95
|
+
ms_xdt = ms_xdt.xr_ms.sel(data_group_name=data_group_name)
|
|
96
|
+
|
|
97
|
+
if include_variables is not None:
|
|
98
|
+
for data_vars in ms_xdt.ds.data_vars:
|
|
99
|
+
if data_vars not in include_variables:
|
|
100
|
+
ms_xdt.ds = ms_xdt.ds.drop_vars(data_vars)
|
|
101
|
+
ps_xdt[ms_name] = ms_xdt
|
|
69
102
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
103
|
+
if not load_sub_datasets:
|
|
104
|
+
for ms_xdt in ps_xdt.children.values():
|
|
105
|
+
ms_xdt_names = list(ms_xdt.keys())
|
|
106
|
+
for sub_xds_name in ms_xdt_names:
|
|
107
|
+
if "xds" in sub_xds_name:
|
|
108
|
+
del ms_xdt[sub_xds_name]
|
|
73
109
|
|
|
74
|
-
|
|
110
|
+
ps_xdt = ps_xdt.load()
|
|
75
111
|
|
|
76
|
-
return
|
|
112
|
+
return ps_xdt
|
|
77
113
|
|
|
78
114
|
|
|
79
115
|
class ProcessingSetIterator:
|
|
@@ -81,8 +117,10 @@ class ProcessingSetIterator:
|
|
|
81
117
|
self,
|
|
82
118
|
sel_parms: dict,
|
|
83
119
|
input_data_store: str,
|
|
84
|
-
input_data: Union[Dict,
|
|
85
|
-
|
|
120
|
+
input_data: Union[Dict, xr.DataTree, None] = None,
|
|
121
|
+
data_group_name: str = None,
|
|
122
|
+
include_variables: Union[list, None] = None,
|
|
123
|
+
drop_variables: Union[list, None] = None,
|
|
86
124
|
load_sub_datasets: bool = True,
|
|
87
125
|
):
|
|
88
126
|
"""An iterator that will go through a processing set one MS v4 at a time.
|
|
@@ -101,10 +139,16 @@ class ProcessingSetIterator:
|
|
|
101
139
|
}
|
|
102
140
|
input_data_store : str
|
|
103
141
|
String of the path and name of the processing set. For example '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr'.
|
|
104
|
-
input_data : Union[Dict,
|
|
142
|
+
input_data : Union[Dict, xr.DataTree, None], optional
|
|
105
143
|
If the processing set is in memory already it can be supplied here. By default None which will make the iterator load data using the supplied input_data_store.
|
|
106
|
-
|
|
144
|
+
data_group_name : str, optional
|
|
145
|
+
The name of the data group to select. By default None, which loads all data groups.
|
|
146
|
+
data_group_name : str, optional
|
|
147
|
+
The name of the data group to select. By default None, which loads all data groups.
|
|
148
|
+
include_variables : Union[list, None], optional
|
|
107
149
|
The list of data variables to load into memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load all data variables into memory.
|
|
150
|
+
drop_variables : Union[list, None], optional
|
|
151
|
+
The list of data variables to drop from memory for example ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop any data variables from memory.
|
|
108
152
|
load_sub_datasets : bool, optional
|
|
109
153
|
If true sub-datasets (for example weather_xds, antenna_xds, pointing_xds, system_calibration_xds ...) will be loaded into memory, by default True.
|
|
110
154
|
"""
|
|
@@ -113,7 +157,9 @@ class ProcessingSetIterator:
|
|
|
113
157
|
self.input_data_store = input_data_store
|
|
114
158
|
self.sel_parms = sel_parms
|
|
115
159
|
self.xds_name_iter = iter(sel_parms.keys())
|
|
116
|
-
self.
|
|
160
|
+
self.data_group_name = data_group_name
|
|
161
|
+
self.include_variables = include_variables
|
|
162
|
+
self.drop_variables = drop_variables
|
|
117
163
|
self.load_sub_datasets = load_sub_datasets
|
|
118
164
|
|
|
119
165
|
def __iter__(self):
|
|
@@ -121,20 +167,22 @@ class ProcessingSetIterator:
|
|
|
121
167
|
|
|
122
168
|
def __next__(self):
|
|
123
169
|
try:
|
|
124
|
-
|
|
170
|
+
sub_xds_name = next(self.xds_name_iter)
|
|
125
171
|
except Exception as e:
|
|
126
172
|
raise StopIteration
|
|
127
173
|
|
|
128
174
|
if self.input_data is None:
|
|
129
|
-
slice_description = self.sel_parms[
|
|
130
|
-
|
|
175
|
+
slice_description = self.sel_parms[sub_xds_name]
|
|
176
|
+
ps_xdt = load_processing_set(
|
|
131
177
|
ps_store=self.input_data_store,
|
|
132
|
-
sel_parms={
|
|
133
|
-
|
|
178
|
+
sel_parms={sub_xds_name: slice_description},
|
|
179
|
+
data_group_name=self.data_group_name,
|
|
180
|
+
include_variables=self.include_variables,
|
|
181
|
+
drop_variables=self.drop_variables,
|
|
134
182
|
load_sub_datasets=self.load_sub_datasets,
|
|
135
183
|
)
|
|
136
|
-
|
|
184
|
+
sub_xdt = ps_xdt.get(0)
|
|
137
185
|
else:
|
|
138
|
-
|
|
186
|
+
sub_xdt = self.input_data[sub_xds_name] # In memory
|
|
139
187
|
|
|
140
|
-
return
|
|
188
|
+
return sub_xdt
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from xradio._utils.list_and_array import to_list
|
|
3
|
+
import xarray as xr
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numbers
|
|
6
|
+
import os
|
|
7
|
+
from collections.abc import Mapping, Iterable
|
|
8
|
+
from typing import Any, Union
|
|
9
|
+
|
|
10
|
+
MS_DATASET_TYPES = {"visibility", "spectrum", "radiometer"}
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InvalidAccessorLocation(ValueError):
|
|
14
|
+
"""
|
|
15
|
+
Raised by MeasurementSetXdt accessor functions called on a wrong DataTree node (not MSv4).
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@xr.register_datatree_accessor("xr_ms")
|
|
22
|
+
class MeasurementSetXdt:
|
|
23
|
+
"""Accessor to the Measurement Set DataTree node. Provides MSv4 specific functionality
|
|
24
|
+
such as:
|
|
25
|
+
|
|
26
|
+
- get_partition_info(): produce an info dict with a general MSv4 description including
|
|
27
|
+
intents, SPW name, field and source names, etc.
|
|
28
|
+
- get_field_and_source_xds() to retrieve the field_and_source_xds for a given data
|
|
29
|
+
group.
|
|
30
|
+
- sel(): select data by dimension labels, for example by data group and polarization
|
|
31
|
+
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
_xdt: xr.DataTree
|
|
35
|
+
|
|
36
|
+
def __init__(self, datatree: xr.DataTree):
|
|
37
|
+
"""
|
|
38
|
+
Initialize the MeasurementSetXdt instance.
|
|
39
|
+
|
|
40
|
+
Parameters
|
|
41
|
+
----------
|
|
42
|
+
datatree: xarray.DataTree
|
|
43
|
+
The MSv4 DataTree node to construct a MeasurementSetXdt accessor.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
self._xdt = datatree
|
|
47
|
+
self.meta = {"summary": {}}
|
|
48
|
+
|
|
49
|
+
def sel(
|
|
50
|
+
self,
|
|
51
|
+
indexers: Union[Mapping[Any, Any], None] = None,
|
|
52
|
+
method: Union[str, None] = None,
|
|
53
|
+
tolerance: Union[int, float, Iterable[Union[int, float]], None] = None,
|
|
54
|
+
drop: bool = False,
|
|
55
|
+
**indexers_kwargs: Any,
|
|
56
|
+
) -> xr.DataTree:
|
|
57
|
+
"""
|
|
58
|
+
Select data along dimension(s) by label. Alternative to `xarray.Dataset.sel <https://xarray.pydata.org/en/stable/generated/xarray.Dataset.sel.html>`__ so that a data group can be selected by name by using the `data_group_name` parameter.
|
|
59
|
+
For more information on data groups see `Data Groups <https://xradio.readthedocs.io/en/latest/measurement_set_overview.html#Data-Groups>`__ section. See `xarray.Dataset.sel <https://xarray.pydata.org/en/stable/generated/xarray.Dataset.sel.html>`__ for parameter descriptions.
|
|
60
|
+
|
|
61
|
+
Returns
|
|
62
|
+
-------
|
|
63
|
+
xarray.DataTree
|
|
64
|
+
xarray DataTree with MeasurementSetXdt accessors
|
|
65
|
+
|
|
66
|
+
Examples
|
|
67
|
+
--------
|
|
68
|
+
>>> # Select data group 'corrected' and polarization 'XX'.
|
|
69
|
+
>>> selected_ms_xdt = ms_xdt.xr_ms.sel(data_group_name='corrected', polarization='XX')
|
|
70
|
+
|
|
71
|
+
>>> # Select data group 'corrected' and polarization 'XX' using a dict.
|
|
72
|
+
>>> selected_ms_xdt = ms_xdt.xr_ms.sel({'data_group_name':'corrected', 'polarization':'XX'})
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
|
|
76
|
+
raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")
|
|
77
|
+
|
|
78
|
+
assert self._xdt.attrs["type"] in [
|
|
79
|
+
"visibility",
|
|
80
|
+
"spectrum",
|
|
81
|
+
"radiometer",
|
|
82
|
+
], "The type of the xdt must be 'visibility', 'spectrum' or 'radiometer'."
|
|
83
|
+
|
|
84
|
+
if "data_group_name" in indexers_kwargs:
|
|
85
|
+
data_group_name = indexers_kwargs["data_group_name"]
|
|
86
|
+
del indexers_kwargs["data_group_name"]
|
|
87
|
+
elif (indexers is not None) and ("data_group_name" in indexers):
|
|
88
|
+
data_group_name = indexers["data_group_name"]
|
|
89
|
+
del indexers["data_group_name"]
|
|
90
|
+
else:
|
|
91
|
+
data_group_name = None
|
|
92
|
+
|
|
93
|
+
if data_group_name is not None:
|
|
94
|
+
sel_data_group_set = set(
|
|
95
|
+
self._xdt.attrs["data_groups"][data_group_name].values()
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
data_variables_to_drop = []
|
|
99
|
+
for dg in self._xdt.attrs["data_groups"].values():
|
|
100
|
+
temp_set = set(dg.values()) - sel_data_group_set
|
|
101
|
+
data_variables_to_drop.extend(list(temp_set))
|
|
102
|
+
|
|
103
|
+
data_variables_to_drop = list(set(data_variables_to_drop))
|
|
104
|
+
|
|
105
|
+
sel_ms_xdt = self._xdt
|
|
106
|
+
|
|
107
|
+
sel_corr_xds = self._xdt.ds.sel(
|
|
108
|
+
indexers, method, tolerance, drop, **indexers_kwargs
|
|
109
|
+
).drop_vars(data_variables_to_drop)
|
|
110
|
+
|
|
111
|
+
sel_ms_xdt.ds = sel_corr_xds
|
|
112
|
+
|
|
113
|
+
sel_ms_xdt.attrs["data_groups"] = {
|
|
114
|
+
data_group_name: self._xdt.attrs["data_groups"][data_group_name]
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
return sel_ms_xdt
|
|
118
|
+
else:
|
|
119
|
+
return self._xdt.sel(indexers, method, tolerance, drop, **indexers_kwargs)
|
|
120
|
+
|
|
121
|
+
def get_field_and_source_xds(self, data_group_name: str = None) -> xr.Dataset:
|
|
122
|
+
"""Get the field_and_source_xds associated with data group `data_group_name`.
|
|
123
|
+
|
|
124
|
+
Parameters
|
|
125
|
+
----------
|
|
126
|
+
data_group_name : str, optional
|
|
127
|
+
The data group to process. Defaults to "base" or, if that is not found, to the first data group.
|
|
128
|
+
|
|
129
|
+
Returns
|
|
130
|
+
-------
|
|
131
|
+
xarray.Dataset
|
|
132
|
+
field_and_source_xds associated with the data group.
|
|
133
|
+
"""
|
|
134
|
+
if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
|
|
135
|
+
raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")
|
|
136
|
+
|
|
137
|
+
if data_group_name is None:
|
|
138
|
+
if "base" in self._xdt.attrs["data_groups"].keys():
|
|
139
|
+
data_group_name = "base"
|
|
140
|
+
else:
|
|
141
|
+
data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
|
|
142
|
+
|
|
143
|
+
return self._xdt[f"field_and_source_xds_{data_group_name}"].ds
|
|
144
|
+
|
|
145
|
+
def get_partition_info(self, data_group_name: str = None) -> dict:
|
|
146
|
+
"""
|
|
147
|
+
Generate a partition info dict for an MSv4, with general MSv4 description including
|
|
148
|
+
information such as field and source names, SPW name, scan name, the intents string,
|
|
149
|
+
etc.
|
|
150
|
+
|
|
151
|
+
The information is gathered from various coordinates, secondary datasets, and info
|
|
152
|
+
dicts of the MSv4. For example, the SPW name comes from the attributes of the
|
|
153
|
+
frequency coordinate, whereas field and source related information such as field and
|
|
154
|
+
source names come from the field_and_source_xds (base) dataset of the MSv4.
|
|
155
|
+
|
|
156
|
+
Parameters
|
|
157
|
+
----------
|
|
158
|
+
data_group_name : str, optional
|
|
159
|
+
The data group to process. Defaults to "base" or, if that is not found, to the first data group.
|
|
160
|
+
|
|
161
|
+
Returns
|
|
162
|
+
-------
|
|
163
|
+
dict
|
|
164
|
+
Partition info dict for the MSv4
|
|
165
|
+
"""
|
|
166
|
+
if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
|
|
167
|
+
raise InvalidAccessorLocation(
|
|
168
|
+
f"{self._xdt.path} is not a MSv4 node (type {self._xdt.attrs.get('type')}."
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
if data_group_name is None:
|
|
172
|
+
if "base" in self._xdt.attrs["data_groups"].keys():
|
|
173
|
+
data_group_name = "base"
|
|
174
|
+
else:
|
|
175
|
+
data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
|
|
176
|
+
|
|
177
|
+
field_and_source_xds = self._xdt.xr_ms.get_field_and_source_xds(data_group_name)
|
|
178
|
+
|
|
179
|
+
if "line_name" in field_and_source_xds.coords:
|
|
180
|
+
line_name = to_list(
|
|
181
|
+
np.unique(np.ravel(field_and_source_xds.line_name.values))
|
|
182
|
+
)
|
|
183
|
+
else:
|
|
184
|
+
line_name = []
|
|
185
|
+
|
|
186
|
+
partition_info = {
|
|
187
|
+
"spectral_window_name": self._xdt.frequency.attrs["spectral_window_name"],
|
|
188
|
+
"field_name": to_list(np.unique(field_and_source_xds.field_name.values)),
|
|
189
|
+
"polarization_setup": to_list(self._xdt.polarization.values),
|
|
190
|
+
"scan_name": to_list(np.unique(self._xdt.scan_name.values)),
|
|
191
|
+
"source_name": to_list(np.unique(field_and_source_xds.source_name.values)),
|
|
192
|
+
"intents": self._xdt.observation_info["intents"],
|
|
193
|
+
"line_name": line_name,
|
|
194
|
+
"data_group_name": data_group_name,
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
return partition_info
|