xradio 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
- xradio/_utils/_casacore/casacore_from_casatools.py +1 -1
- xradio/_utils/dict_helpers.py +38 -7
- xradio/_utils/list_and_array.py +26 -3
- xradio/_utils/schema.py +44 -0
- xradio/_utils/xarray_helpers.py +63 -0
- xradio/_utils/zarr/common.py +4 -2
- xradio/image/__init__.py +4 -2
- xradio/image/_util/_casacore/common.py +2 -1
- xradio/image/_util/_casacore/xds_from_casacore.py +105 -51
- xradio/image/_util/_casacore/xds_to_casacore.py +117 -52
- xradio/image/_util/_fits/xds_from_fits.py +124 -36
- xradio/image/_util/_zarr/common.py +0 -1
- xradio/image/_util/casacore.py +133 -16
- xradio/image/_util/common.py +6 -5
- xradio/image/_util/image_factory.py +466 -27
- xradio/image/image.py +72 -100
- xradio/image/image_xds.py +262 -0
- xradio/image/schema.py +85 -0
- xradio/measurement_set/__init__.py +5 -4
- xradio/measurement_set/_utils/_msv2/_tables/read.py +7 -3
- xradio/measurement_set/_utils/_msv2/conversion.py +6 -9
- xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +1 -0
- xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +1 -1
- xradio/measurement_set/_utils/_utils/interpolate.py +5 -0
- xradio/measurement_set/_utils/_utils/partition_attrs.py +0 -1
- xradio/measurement_set/convert_msv2_to_processing_set.py +9 -9
- xradio/measurement_set/load_processing_set.py +2 -2
- xradio/measurement_set/measurement_set_xdt.py +83 -93
- xradio/measurement_set/open_processing_set.py +7 -3
- xradio/measurement_set/processing_set_xdt.py +33 -26
- xradio/schema/check.py +70 -19
- xradio/schema/common.py +0 -1
- xradio/testing/__init__.py +0 -0
- xradio/testing/_utils/__template__.py +58 -0
- xradio/testing/measurement_set/__init__.py +58 -0
- xradio/testing/measurement_set/checker.py +131 -0
- xradio/testing/measurement_set/io.py +22 -0
- xradio/testing/measurement_set/msv2_io.py +1854 -0
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/METADATA +64 -23
- xradio-1.1.0.dist-info/RECORD +75 -0
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/WHEEL +1 -1
- xradio-1.0.1.dist-info/RECORD +0 -66
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/licenses/LICENSE.txt +0 -0
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/top_level.txt +0 -0
xradio/measurement_set/__init__.py:

@@ -4,7 +4,7 @@ convert, and retrieve information from Processing Set and Measurement Sets nodes
 Processing Set DataTree
 """

-import
+import warnings

 from .processing_set_xdt import ProcessingSetXdt
 from .open_processing_set import open_processing_set
@@ -27,9 +27,10 @@ try:
         estimate_conversion_memory_and_cores,
     )
 except ModuleNotFoundError as exc:
-
-        "Could not import the function to convert from MSv2 to MSv4. "
-        f"That functionality will not be available. Details: {exc}"
+    warnings.warn(
+        f"Could not import the function to convert from MSv2 to MSv4. "
+        f"That functionality will not be available. Details: {exc}",
+        UserWarning,
     )
 else:
     __all__.extend(
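With the message now raised via warnings.warn(..., UserWarning) instead of the previous call, it flows through Python's standard warnings machinery and can be filtered by callers. A minimal sketch of silencing it (the filter itself is illustrative, not part of the package):

import warnings

with warnings.catch_warnings():
    # Ignore the optional-dependency warning emitted at import time when the
    # MSv2 -> MSv4 conversion dependencies are missing.
    warnings.filterwarnings(
        "ignore",
        message="Could not import the function to convert from MSv2 to MSv4",
        category=UserWarning,
    )
    import xradio.measurement_set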
xradio/measurement_set/_utils/_msv2/_tables/read.py:

@@ -93,9 +93,12 @@ def convert_mjd_time(rawtimes: np.ndarray) -> np.ndarray:
     np.ndarray
         times converted to pandas reference and datetime type
     """
+
+    print("^^^^^^^", rawtimes, MJD_DIF_UNIX, SECS_IN_DAY)
     times_reref = pd.to_datetime(
         (rawtimes - MJD_DIF_UNIX) * SECS_IN_DAY, unit="s"
     ).values
+    print("^^^^^^^", times_reref)

     return times_reref

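For context, the expression converts Modified Julian Date to the Unix epoch: MJD 40587 is 1970-01-01 and a day has 86400 seconds, which is presumably what the MJD_DIF_UNIX and SECS_IN_DAY module constants hold. A standalone check of the same arithmetic, with those values assumed:

import numpy as np
import pandas as pd

MJD_DIF_UNIX = 40587  # MJD of the Unix epoch, 1970-01-01 (assumed value)
SECS_IN_DAY = 86400

rawtimes = np.array([40587.0, 60000.5])  # MJD, in days
times_reref = pd.to_datetime((rawtimes - MJD_DIF_UNIX) * SECS_IN_DAY, unit="s").values
print(times_reref)
# ['1970-01-01T00:00:00.000000000' '2023-02-25T12:00:00.000000000']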
@@ -160,7 +163,7 @@ def make_taql_where_between_min_max(
     if min_max_range is None:
         taql = None
     else:
-
+        min_val, max_val = min_max_range
         taql = f"where {colname} >= {min_val} AND {colname} <= {max_val}"

     return taql
@@ -226,7 +229,7 @@ def find_projected_min_max_array(
    """Does the min/max checks and search for find_projected_min_max_table()"""

    sorted_array = np.sort(array)
-
+    range_min, range_max = min_max
    if len(sorted_array) < 2:
        tol = np.finfo(sorted_array.dtype).eps * 4
    else:
@@ -891,7 +894,8 @@ def raw_col_data_to_coords_vars(

     if col in timecols:
         if col == "MJD":
-            data = convert_mjd_time(data).astype("float64") / 1e9
+            # data = convert_mjd_time(data).astype("float64") / 1e9
+            data = convert_mjd_time(data).astype("datetime64[ns]").view("int64") / 1e9
         else:
             try:
                 data = convert_casacore_time(data, False)
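The rewritten line spells out the epoch arithmetic step by step: astype("datetime64[ns]") pins the timestamps to nanosecond resolution, view("int64") reinterprets them as raw nanoseconds since the Unix epoch, and / 1e9 converts to seconds. The idiom in isolation:

import numpy as np

stamps = np.array(["1970-01-01T00:00:01", "2000-01-01"], dtype="datetime64[ns]")
seconds = stamps.view("int64") / 1e9  # ns since epoch -> float seconds
print(seconds)  # 1.0 and 946684800.0, i.e. 2000-01-01 in Unix seconds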
xradio/measurement_set/_utils/_msv2/conversion.py:

@@ -1017,7 +1017,7 @@ def convert_and_write_partition(
     add_reshaping_indices: bool = False,
     storage_backend="zarr",
     parallel_mode: str = "none",
-
+    persistence_mode: str = "w-",
 ):
     """_summary_

@@ -1057,8 +1057,8 @@ def convert_and_write_partition(
         _description_, by default "zarr"
     parallel_mode : _type_, optional
         _description_
-
-        _description_, by default
+    persistence_mode: str = "w-",
+        _description_, by default "w-"

     Returns
     -------
@@ -1368,11 +1368,6 @@ def convert_and_write_partition(
         pathlib.Path(in_file).name.replace(".ms", "") + "_" + str(ms_v4_id),
     )

-    if overwrite:
-        mode = "w"
-    else:
-        mode = "w-"
-
     if is_single_dish:
         xds.attrs["type"] = "spectrum"
         xds = xds.drop_vars("UVW")
@@ -1416,7 +1411,9 @@ def convert_and_write_partition(
         ms_xdt["/phased_array_xds"] = phased_array_xds

     if storage_backend == "zarr":
-        ms_xdt.to_zarr(
+        ms_xdt.to_zarr(
+            store=os.path.join(out_file, ms_v4_name), mode=persistence_mode
+        )
     elif storage_backend == "netcdf":
         # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work
         raise
xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py:

@@ -285,6 +285,7 @@ def extract_ephemeris_info(
     # Metadata has to be fixed manually. Alternatively, issues like
     # UNIT/QuantumUnits issue could be handled in convert_generic_xds_to_xradio_schema,
     # but for now preferring not to pollute that function.
+
     time_ephemeris_dim = ["time_ephemeris"]
     to_new_data_variables = {
         # mandatory: SOURCE_RADIAL_VELOCITY
xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py:

@@ -27,7 +27,6 @@ from xradio.measurement_set._utils._msv2._tables.read import (
     table_has_column,
 )

-
 standard_time_coord_attrs = make_time_measure_attrs(time_format="unix")

@@ -81,6 +80,7 @@ def rename_and_interpolate_to_time(
     )

     # rename the time_* axis to time.
+
     time_coord = {"time": (time_initial_name, interp_time.data)}
     renamed_time_xds = interpolated_xds.assign_coords(time_coord)
     renamed_time_xds.coords["time"].attrs.update(standard_time_coord_attrs)
xradio/measurement_set/_utils/_utils/interpolate.py:

@@ -41,6 +41,10 @@ def interpolate_to_time(
         method = "linear"
     else:
         method = "nearest"
+
+    # print("xds before interp:",xds.NORTH_POLE_ANGULAR_DISTANCE.values, xds[time_name].values)
+    # print("interp_time data:",interp_time,interp_time.data)
+    # print("method:",method)
     xds = xds.interp(
         {time_name: interp_time.data}, method=method, assume_sorted=True
     )
@@ -56,5 +60,6 @@ def interpolate_to_time(
         f"{message_prefix}: interpolating the time coordinate "
         f"from {points_before} to {points_after} points"
     )
+    # print("xds after interp:",xds.NORTH_POLE_ANGULAR_DISTANCE.values, xds[time_name].values)

     return xds
xradio/measurement_set/convert_msv2_to_processing_set.py:

@@ -68,7 +68,7 @@ def convert_msv2_to_processing_set(
     add_reshaping_indices: bool = False,
     storage_backend: Literal["zarr", "netcdf"] = "zarr",
     parallel_mode: Literal["none", "partition", "time"] = "none",
-
+    persistence_mode: str = "w-",
 ):
     """Convert a Measurement Set v2 into a Processing Set of Measurement Set v4.

@@ -110,8 +110,11 @@ def convert_msv2_to_processing_set(
        Choose whether to use Dask to execute conversion in parallel, by default "none" and conversion occurs serially.
        The option "partition", parallelises the conversion over partitions specified by `partition_scheme`. The option "time" can only be used for phased array interferometers where there are no partitions
        in the MS v2; instead the MS v2 is parallelised along the time dimension and can be controlled by `main_chunksize`.
-
-
+    persistence_mode : str, optional
+        "w" means create (overwrite if exists);
+        "w-" means create (fail if exists);
+        "a" means override all existing variables including dimension coordinates (create if does not exist); Use this mode if you want to add to an existing Processing Set.
+        The default is "w-".
     """

     # Create empty data tree
@@ -122,10 +125,7 @@ def convert_msv2_to_processing_set(
     if not str(out_file).endswith("ps.zarr"):
         out_file += ".ps.zarr"

-
-        ps_dt.to_zarr(store=out_file, mode="w")
-    else:
-        ps_dt.to_zarr(store=out_file, mode="w-")
+    ps_dt.to_zarr(store=out_file, mode=persistence_mode)

     # Check `parallel_mode` is valid
     try:
@@ -192,7 +192,7 @@ def convert_msv2_to_processing_set(
                     add_reshaping_indices=add_reshaping_indices,
                     compressor=compressor,
                     parallel_mode=parallel_mode,
-
+                    persistence_mode=persistence_mode,
                 )
             )
     else:
@@ -214,7 +214,7 @@ def convert_msv2_to_processing_set(
                 add_reshaping_indices=add_reshaping_indices,
                 compressor=compressor,
                 parallel_mode=parallel_mode,
-
+                persistence_mode=persistence_mode,
             )
         end_time = time.time()
         logger.debug(
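Taken together, the hunks in this file replace the old boolean overwrite flag with a user-facing persistence_mode string that is handed straight to DataTree.to_zarr. A hedged usage sketch; file names are placeholders, other arguments are omitted, and parameter names besides persistence_mode are inferred from the identifiers visible in these hunks:

from xradio.measurement_set import convert_msv2_to_processing_set

# Default "w-": create the store, fail if it already exists.
convert_msv2_to_processing_set(in_file="obs.ms", out_file="obs.ps.zarr")

# "w": recreate, overwriting an existing store.
convert_msv2_to_processing_set(
    in_file="obs.ms", out_file="obs.ps.zarr", persistence_mode="w"
)

# "a": add partitions to an existing Processing Set.
convert_msv2_to_processing_set(
    in_file="obs_day2.ms", out_file="obs.ps.zarr", persistence_mode="a"
)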
xradio/measurement_set/load_processing_set.py:

@@ -1,8 +1,6 @@
 import os
 from typing import Dict, Union
-import dask
 import xarray as xr
-import s3fs


 def load_processing_set(
@@ -48,6 +46,8 @@ def load_processing_set(
         In memory representation of processing set using xr.DataTree.
     """
     from xradio._utils.zarr.common import _get_file_system_and_items
+    import dask
+    import s3fs

     file_system, ms_store_list = _get_file_system_and_items(ps_store)

xradio/measurement_set/measurement_set_xdt.py:

@@ -5,7 +5,8 @@ from typing import Any, Union
 import numpy as np
 import xarray as xr

-from xradio._utils.list_and_array import
+from xradio._utils.list_and_array import to_python_type
+from xradio._utils.xarray_helpers import get_data_group_name, create_new_data_group

 MS_DATASET_TYPES = {"visibility", "spectrum", "radiometer"}

@@ -151,11 +152,7 @@ class MeasurementSetXdt:
         if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
             raise InvalidAccessorLocation(f"{self._xdt.path} is not a MSv4 node.")

-
-        if "base" in self._xdt.attrs["data_groups"]:
-            data_group_name = "base"
-        else:
-            data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
+        data_group_name = get_data_group_name(self._xdt, data_group_name)

         field_and_source_xds_name = self._xdt.attrs["data_groups"][data_group_name][
             "field_and_source"
@@ -188,16 +185,12 @@ class MeasurementSetXdt:
                 f"{self._xdt.path} is not a MSv4 node (type {self._xdt.attrs.get('type')}."
             )

-
-        if "base" in self._xdt.attrs["data_groups"]:
-            data_group_name = "base"
-        else:
-            data_group_name = list(self._xdt.attrs["data_groups"].keys())[0]
+        data_group_name = get_data_group_name(self._xdt, data_group_name)

         field_and_source_xds = self._xdt.xr_ms.get_field_and_source_xds(data_group_name)

         if "line_name" in field_and_source_xds.coords:
-            line_name =
+            line_name = to_python_type(
                 np.unique(np.ravel(field_and_source_xds.line_name.values))
             )
         else:
@@ -218,10 +211,14 @@ class MeasurementSetXdt:
         partition_info = {
             "spectral_window_name": self._xdt.frequency.attrs["spectral_window_name"],
             "spectral_window_intents": spw_intent,
-            "field_name":
-
-
-            "
+            "field_name": to_python_type(
+                np.unique(field_and_source_xds.field_name.values)
+            ),
+            "polarization_setup": to_python_type(self._xdt.polarization.values),
+            "scan_name": to_python_type(np.unique(self._xdt.scan_name.values)),
+            "source_name": to_python_type(
+                np.unique(field_and_source_xds.source_name.values)
+            ),
             "scan_intents": scan_intents,
             "line_name": line_name,
             "data_group_name": data_group_name,
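Every entry in partition_info is now routed through to_python_type, presumably so numpy arrays and scalars are reduced to plain Python lists and scalars before landing in the attributes dictionary. A hypothetical minimal equivalent, for illustration only (the real helper lives in xradio._utils.list_and_array and is not shown in this diff):

import numpy as np

def to_python_type_sketch(value):
    # Hypothetical stand-in: unwrap numpy containers into built-in types.
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value

print(to_python_type_sketch(np.unique(np.array(["3C286", "3C286", "J1939-6342"]))))
# ['3C286', 'J1939-6342']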
@@ -232,13 +229,7 @@ class MeasurementSetXdt:
     def add_data_group(
         self,
         new_data_group_name: str,
-
-        weight: str = None,
-        flag: str = None,
-        uvw: str = None,
-        field_and_source_xds: str = None,
-        date_time: str = None,
-        description: str = None,
+        new_data_group: dict = {},
         data_group_dv_shared_with: str = None,
     ) -> xr.DataTree:
         """Adds a data group to the MSv4 DataTree, grouping the given data, weight, flag, etc. variables
@@ -248,20 +239,8 @@ class MeasurementSetXdt:
         ----------
         new_data_group_name : str
             _description_
-
-            _description_
-        weights : str, optional
-            _description_, by default None
-        flag : str, optional
-            _description_, by default None
-        uvw : str, optional
-            _description_, by default None
-        field_and_source_xds : str, optional
-            _description_, by default None
-        date_time : str, optional
-            _description_, by default None
-        description : str, optional
-            _description_, by default None
+        new_data_group : dict
+            _description_
         data_group_dv_shared_with : str, optional
             _description_, by default "base"

@@ -271,63 +250,74 @@ class MeasurementSetXdt:
         MSv4 DataTree with the new group added
         """

-        if
-
-
-
-
-
-
-
-
-
-        )
-
-        if weight is None:
-            weight = default_data_group["weight"]
-        new_data_group["weight"] = weight
-        assert (
-            weight in self._xdt.ds.data_vars
-        ), f"Data variable {weight} not found in dataset."
-
-        if flag is None:
-            flag = default_data_group["flag"]
-        new_data_group["flag"] = flag
-        assert (
-            flag in self._xdt.ds.data_vars
-        ), f"Data variable {flag} not found in dataset."
-
-        if self._xdt.attrs["type"] == "visibility":
-            if uvw is None:
-                uvw = default_data_group["uvw"]
-            new_data_group["uvw"] = uvw
-            assert (
-                uvw in self._xdt.ds.data_vars
-            ), f"Data variable {uvw} not found in dataset."
-
-        if field_and_source_xds is None:
-            field_and_source_xds = default_data_group["field_and_source"]
-        new_data_group["field_and_source"] = field_and_source_xds
-        assert (
-            field_and_source_xds in self._xdt.children
-        ), f"Data variable {field_and_source_xds} not found in dataset."
-
-        if date_time is None:
-            date_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
-        new_data_group["date"] = date_time
-
-        if description is None:
-            description = ""
-        new_data_group["description"] = description
+        if self._xdt.attrs.get("type") not in MS_DATASET_TYPES:
+            raise InvalidAccessorLocation(
+                f"{self._xdt.path} is not a MSv4 node (type {self._xdt.attrs.get('type')}."
+            )
+
+        new_data_group_name, new_data_group = create_new_data_group(
+            self._xdt,
+            "msv4",
+            new_data_group_name,
+            new_data_group,
+            data_group_dv_shared_with=data_group_dv_shared_with,
+        )

         self._xdt.attrs["data_groups"][new_data_group_name] = new_data_group
-
         return self._xdt

-
-
-
-
-
-
+    # data_group_dv_shared_with = get_data_group_name(
+    #     self._xdt, data_group_dv_shared_with
+    # )
+
+    # default_data_group = self._xdt.attrs["data_groups"][data_group_dv_shared_with]
+
+    # new_data_group = {}
+
+    # if correlated_data is None:
+    #     correlated_data = default_data_group["correlated_data"]
+    # new_data_group["correlated_data"] = correlated_data
+    # assert (
+    #     correlated_data in self._xdt.ds.data_vars
+    # ), f"Data variable {correlated_data} not found in dataset."
+
+    # if weight is None:
+    #     weight = default_data_group["weight"]
+    # new_data_group["weight"] = weight
+    # assert (
+    #     weight in self._xdt.ds.data_vars
+    # ), f"Data variable {weight} not found in dataset."
+
+    # if flag is None:
+    #     flag = default_data_group["flag"]
+    # new_data_group["flag"] = flag
+    # assert (
+    #     flag in self._xdt.ds.data_vars
+    # ), f"Data variable {flag} not found in dataset."
+
+    # if self._xdt.attrs["type"] == "visibility":
+    #     if uvw is None:
+    #         uvw = default_data_group["uvw"]
+    #     new_data_group["uvw"] = uvw
+    #     assert (
+    #         uvw in self._xdt.ds.data_vars
+    #     ), f"Data variable {uvw} not found in dataset."
+
+    # if field_and_source_xds is None:
+    #     field_and_source_xds = default_data_group["field_and_source"]
+    # new_data_group["field_and_source"] = field_and_source_xds
+    # assert (
+    #     field_and_source_xds in self._xdt.children
+    # ), f"Data variable {field_and_source_xds} not found in dataset."
+
+    # if date_time is None:
+    #     date_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
+    # new_data_group["date"] = date_time
+
+    # if description is None:
+    #     description = ""
+    # new_data_group["description"] = description
+
+    # self._xdt.attrs["data_groups"][new_data_group_name] = new_data_group
+
+    # return self._xdt
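The net effect is that add_data_group now takes a single new_data_group dict instead of one keyword per component, with defaulting and validation moved into the new create_new_data_group helper. A hedged sketch of a call under the new signature, using the component keys visible in the retained commented-out body; variable and data-variable names are placeholders, and which keys may be omitted depends on create_new_data_group, which this diff does not show:

# ms_xdt is an MSv4 node of a Processing Set DataTree.
ms_xdt = ms_xdt.xr_ms.add_data_group(
    new_data_group_name="calibrated",
    new_data_group={
        "correlated_data": "VISIBILITY_CORRECTED",
        "weight": "WEIGHT",
        "flag": "FLAG",
        "uvw": "UVW",
        "description": "Visibilities after calibration.",
    },
    data_group_dv_shared_with="base",
)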
xradio/measurement_set/open_processing_set.py:

@@ -1,5 +1,4 @@
 from xradio._utils.zarr.common import _get_file_system_and_items
-import s3fs
 import xarray as xr

@@ -25,12 +24,17 @@ def open_processing_set(
     """

     file_system, ms_store_list = _get_file_system_and_items(ps_store)
+    import s3fs

     if isinstance(file_system, s3fs.core.S3FileSystem):
         mapping = s3fs.S3Map(root=ps_store, s3=file_system, check=False)
-        ps_xdt = xr.open_datatree(
+        ps_xdt = xr.open_datatree(
+            mapping, engine="zarr", chunks={}, chunked_array_type="dask"
+        )
     else:
-        ps_xdt = xr.open_datatree(
+        ps_xdt = xr.open_datatree(
+            ps_store, engine="zarr", chunks={}, chunked_array_type="dask"
+        )

     # Future work is to add ASDM backend

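Both branches now pass chunks={} with chunked_array_type="dask", so the returned DataTree is opened lazily with dask-backed arrays, and s3fs is only imported once a store is actually being opened. A usage sketch (the store path is a placeholder):

from xradio.measurement_set import open_processing_set

# Local paths and s3:// URLs are both handled; data loads lazily.
ps_xdt = open_processing_set("obs.ps.zarr")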
xradio/measurement_set/processing_set_xdt.py:

@@ -2,6 +2,7 @@ import pandas as pd
 from xradio._utils.list_and_array import to_list
 import numpy as np
 import xarray as xr
+from xradio.measurement_set.measurement_set_xdt import get_data_group_name

 PS_DATASET_TYPES = {"processing_set"}

@@ -38,7 +39,7 @@ class ProcessingSetXdt:
         self.meta = {"summary": {}}

     def summary(
-        self,
+        self, data_group_name: str | None = None, first_columns: list[str] = None
     ) -> pd.DataFrame:
         """
         Generate and retrieve a summary of the Processing Set as a data frame.
@@ -53,7 +54,7 @@ class ProcessingSetXdt:

         Parameters
         ----------
-
+        data_group_name : str, optional
             The data group to summarize. By default the "base" group
             is used (if found), or otherwise the first group found.
         first_columns : list[str], optional
@@ -73,31 +74,33 @@ class ProcessingSetXdt:
             A DataFrame containing the summary information of the specified data group.
         """

-        def find_data_group_base_or_first(
+        def find_data_group_base_or_first(
+            data_group_name: str, xdt: xr.DataTree
+        ) -> str:
             first_msv4 = next(iter(xdt.values()))
             first_data_groups = first_msv4.attrs["data_groups"]
-            if
-
+            if data_group_name is None:
+                data_group_name = (
                     "base"
                     if "base" in first_data_groups
                     else next(iter(first_data_groups))
                 )
-            return
+            return data_group_name

         if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
             raise InvalidAccessorLocation(
                 f"{self._xdt.path} is not a processing set node."
             )

-
+        data_group_name = find_data_group_base_or_first(data_group_name, self._xdt)

-        if
-            summary = self.meta["summary"][
+        if data_group_name in self.meta["summary"]:
+            summary = self.meta["summary"][data_group_name]
         else:
-            self.meta["summary"][
-
-            )
-            summary = self.meta["summary"][
+            self.meta["summary"][data_group_name] = self._summary(
+                data_group_name
+            ).sort_values(by=["name"], ascending=True)
+            summary = self.meta["summary"][data_group_name]

         if first_columns:
             found_columns = [col for col in first_columns if col in summary.columns]
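summary() thus gains a data_group_name argument and caches one sorted DataFrame per group in self.meta["summary"], so repeated calls for the same group reuse the cached result. A hedged usage sketch: ps_xdt comes from open_processing_set, the accessor name xr_ps is an assumption (it does not appear in this diff), and the column names are taken from the _summary dict below:

df = ps_xdt.xr_ps.summary(
    data_group_name="base",
    first_columns=["name", "scan_intents"],
)
print(df.head())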
@@ -187,7 +190,7 @@ class ProcessingSetXdt:
             self.meta["freq_axis"] = freq_axis
         return self.meta["freq_axis"]

-    def _summary(self,
+    def _summary(self, data_group_name: str = None):
         summary_data = {
             "name": [],
             "scan_intents": [],
@@ -225,7 +228,7 @@ class ProcessingSetXdt:
             )
             summary_data["polarization"].append(value.polarization.values)
             summary_data["scan_name"].append(partition_info["scan_name"])
-            data_name = value.attrs["data_groups"][
+            data_name = value.attrs["data_groups"][data_group_name]["correlated_data"]

             if "VISIBILITY" in data_name:
                 center_name = "FIELD_PHASE_CENTER_DIRECTION"
@@ -253,7 +256,7 @@ class ProcessingSetXdt:
             )
             summary_data["end_frequency"].append(to_list(value["frequency"].values)[-1])

-            field_and_source_xds = value.xr_ms.get_field_and_source_xds(
+            field_and_source_xds = value.xr_ms.get_field_and_source_xds(data_group_name)

             if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
                 summary_data["field_coords"].append("Ephemeris")
@@ -377,14 +380,16 @@ class ProcessingSetXdt:

         return sub_ps_xdt

-    def get_combined_field_and_source_xds(
+    def get_combined_field_and_source_xds(
+        self, data_group_name: str = "base"
+    ) -> xr.Dataset:
         """
         Combine all non-ephemeris `field_and_source_xds` datasets from a Processing Set for a data group into a
         single dataset.

         Parameters
         ----------
-
+        data_group_name : str, optional
             The data group to process. Default is "base".

         Returns
@@ -405,7 +410,9 @@ class ProcessingSetXdt:

         combined_field_and_source_xds = xr.Dataset()
         for ms_name, ms_xdt in self._xdt.items():
-            field_and_source_xds = ms_xdt.xr_ms.get_field_and_source_xds(
+            field_and_source_xds = ms_xdt.xr_ms.get_field_and_source_xds(
+                data_group_name
+            )

             if not field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":

@@ -474,14 +481,14 @@ class ProcessingSetXdt:
         return combined_field_and_source_xds

     def get_combined_field_and_source_xds_ephemeris(
-        self,
+        self, data_group_name: str = "base"
     ) -> xr.Dataset:
         """
         Combine all ephemeris `field_and_source_xds` datasets from a Processing Set for a datagroup into a single dataset.

         Parameters
         ----------
-
+        data_group_name : str, optional
             The data group to process. Default is "base".

         Returns
@@ -503,7 +510,7 @@ class ProcessingSetXdt:
         combined_ephemeris_field_and_source_xds = xr.Dataset()
         for ms_name, ms_xdt in self._xdt.items():
             field_and_source_xds = field_and_source_xds = (
-                ms_xdt.xr_ms.get_field_and_source_xds(
+                ms_xdt.xr_ms.get_field_and_source_xds(data_group_name)
             )

             if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
@@ -599,7 +606,7 @@ class ProcessingSetXdt:
         return combined_ephemeris_field_and_source_xds

     def plot_phase_centers(
-        self, label_all_fields: bool = False,
+        self, label_all_fields: bool = False, data_group_name: str = "base"
     ):
         """
         Plot the phase center locations of all fields in the Processing Set.
|
|
|
612
619
|
----------
|
|
613
620
|
label_all_fields : bool, optional
|
|
614
621
|
If `True`, all fields will be labeled on the plot. Default is `False`.
|
|
615
|
-
|
|
622
|
+
data_group_name : str, optional
|
|
616
623
|
The data group to use for processing. Default is "base".
|
|
617
624
|
|
|
618
625
|
Returns
|
|
@@ -641,10 +648,10 @@ class ProcessingSetXdt:
         )

         combined_field_and_source_xds = self.get_combined_field_and_source_xds(
-
+            data_group_name
         )
         combined_ephemeris_field_and_source_xds = (
-            self.get_combined_field_and_source_xds_ephemeris(
+            self.get_combined_field_and_source_xds_ephemeris(data_group_name)
         )
         from matplotlib import pyplot as plt
