xradio 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xradio/_utils/_casacore/casacore_from_casatools.py +1 -1
- xradio/_utils/dict_helpers.py +38 -7
- xradio/_utils/list_and_array.py +26 -3
- xradio/_utils/schema.py +44 -0
- xradio/_utils/xarray_helpers.py +63 -0
- xradio/_utils/zarr/common.py +4 -2
- xradio/image/__init__.py +4 -2
- xradio/image/_util/_casacore/common.py +2 -1
- xradio/image/_util/_casacore/xds_from_casacore.py +105 -51
- xradio/image/_util/_casacore/xds_to_casacore.py +117 -52
- xradio/image/_util/_fits/xds_from_fits.py +124 -36
- xradio/image/_util/_zarr/common.py +0 -1
- xradio/image/_util/casacore.py +133 -16
- xradio/image/_util/common.py +6 -5
- xradio/image/_util/image_factory.py +466 -27
- xradio/image/image.py +72 -100
- xradio/image/image_xds.py +262 -0
- xradio/image/schema.py +85 -0
- xradio/measurement_set/__init__.py +5 -4
- xradio/measurement_set/_utils/_msv2/_tables/read.py +7 -3
- xradio/measurement_set/_utils/_msv2/conversion.py +6 -9
- xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +1 -0
- xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +1 -1
- xradio/measurement_set/_utils/_utils/interpolate.py +5 -0
- xradio/measurement_set/_utils/_utils/partition_attrs.py +0 -1
- xradio/measurement_set/convert_msv2_to_processing_set.py +9 -9
- xradio/measurement_set/load_processing_set.py +2 -2
- xradio/measurement_set/measurement_set_xdt.py +83 -93
- xradio/measurement_set/open_processing_set.py +7 -3
- xradio/measurement_set/processing_set_xdt.py +33 -26
- xradio/schema/check.py +70 -19
- xradio/schema/common.py +0 -1
- xradio/testing/__init__.py +0 -0
- xradio/testing/_utils/__template__.py +58 -0
- xradio/testing/measurement_set/__init__.py +58 -0
- xradio/testing/measurement_set/checker.py +131 -0
- xradio/testing/measurement_set/io.py +22 -0
- xradio/testing/measurement_set/msv2_io.py +1854 -0
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/METADATA +64 -23
- xradio-1.1.0.dist-info/RECORD +75 -0
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/WHEEL +1 -1
- xradio-1.0.1.dist-info/RECORD +0 -66
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/licenses/LICENSE.txt +0 -0
- {xradio-1.0.1.dist-info → xradio-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -4,8 +4,10 @@ import numpy as np
|
|
|
4
4
|
import xarray as xr
|
|
5
5
|
from typing import List, Union
|
|
6
6
|
from .common import _c, _compute_world_sph_dims, _l_m_attr_notes
|
|
7
|
+
import toolviper.utils.logger as logger
|
|
7
8
|
from xradio._utils.coord_math import _deg_to_rad
|
|
8
9
|
from xradio._utils.dict_helpers import (
|
|
10
|
+
make_direction_location_dict,
|
|
9
11
|
make_spectral_coord_reference_dict,
|
|
10
12
|
make_quantity,
|
|
11
13
|
make_skycoord_dict,
|
|
@@ -18,6 +20,23 @@ def _input_checks(
|
|
|
18
20
|
image_size: Union[list, np.ndarray],
|
|
19
21
|
cell_size: Union[list, np.ndarray],
|
|
20
22
|
) -> None:
|
|
23
|
+
"""
|
|
24
|
+
Validate input parameters for image creation functions.
|
|
25
|
+
|
|
26
|
+
Parameters
|
|
27
|
+
----------
|
|
28
|
+
phase_center : list or np.ndarray
|
|
29
|
+
Image phase center coordinates. Must have exactly 2 elements.
|
|
30
|
+
image_size : list or np.ndarray
|
|
31
|
+
Number of pixels along each axis. Must have exactly 2 elements.
|
|
32
|
+
cell_size : list or np.ndarray
|
|
33
|
+
Size of pixels along each axis. Must have exactly 2 elements.
|
|
34
|
+
|
|
35
|
+
Raises
|
|
36
|
+
------
|
|
37
|
+
ValueError
|
|
38
|
+
If any parameter does not have exactly 2 elements.
|
|
39
|
+
"""
|
|
21
40
|
if len(image_size) != 2:
|
|
22
41
|
raise ValueError("image_size must have exactly two elements")
|
|
23
42
|
if len(phase_center) != 2:
|
|
@@ -27,18 +46,22 @@ def _input_checks(
|
|
|
27
46
|
|
|
28
47
|
|
|
29
48
|
def _make_coords(
    frequency_coords: Union[list, np.ndarray],
    time_coords: Union[list, np.ndarray],
) -> dict:
    """
    Build frequency, velocity and time coordinate values for an image.

    Parameters
    ----------
    frequency_coords : list, np.ndarray or scalar
        Channel frequencies. A value that is neither a list nor an ndarray is
        wrapped so it is treated as a single channel.
    time_coords : list, np.ndarray or scalar
        Time values. A value that is neither a list nor an ndarray is wrapped
        so it is treated as a single time sample.

    Returns
    -------
    dict
        ``chan``: float64 frequencies; ``vel``: ``(1 - chan / restfreq)``
        multiplied by the speed of light in m/s; ``time``: float64 times;
        ``restfreq``: the frequency at index ``len(chan) // 2``.
    """

    def _to_float64_array(values):
        # Anything that is not already a list/ndarray is wrapped in a list so
        # the resulting array holds it as a single element.
        if not isinstance(values, (list, np.ndarray)):
            values = [values]
        return np.array(values, dtype=np.float64)

    chan = _to_float64_array(frequency_coords)
    # The middle channel doubles as the reference (rest) frequency.
    restfreq = chan[len(chan) // 2]
    vel = (1 - chan / restfreq) * _c.to("m/s").value
    time = _to_float64_array(time_coords)
    return {"chan": chan, "vel": vel, "time": time, "restfreq": restfreq}
|
|
42
65
|
|
|
43
66
|
|
|
44
67
|
def _add_common_attrs(
|
|
@@ -72,24 +95,26 @@ def _add_common_attrs(
|
|
|
72
95
|
reference["attrs"].update({"equinox": "j2000.0"})
|
|
73
96
|
xds.attrs = {
|
|
74
97
|
"data_groups": {"base": {}},
|
|
75
|
-
"
|
|
76
|
-
"
|
|
77
|
-
"
|
|
78
|
-
|
|
79
|
-
|
|
98
|
+
"coordinate_system_info": {
|
|
99
|
+
"reference_direction": reference,
|
|
100
|
+
"native_pole_direction": make_direction_location_dict(
|
|
101
|
+
[np.pi, 0.0], "rad", "native_projection"
|
|
102
|
+
),
|
|
103
|
+
"pixel_coordinate_transformation_matrix": [[1.0, 0.0], [0.0, 1.0]],
|
|
80
104
|
"projection": projection,
|
|
81
105
|
"projection_parameters": [0.0, 0.0],
|
|
82
106
|
},
|
|
107
|
+
"type": "image",
|
|
83
108
|
}
|
|
84
109
|
return xds
|
|
85
110
|
|
|
86
111
|
|
|
87
112
|
def _make_common_coords(
|
|
88
113
|
pol_coords: Union[list, np.ndarray],
|
|
89
|
-
|
|
114
|
+
frequency_coords: Union[list, np.ndarray],
|
|
90
115
|
time_coords: Union[list, np.ndarray],
|
|
91
116
|
) -> dict:
|
|
92
|
-
some_coords = _make_coords(
|
|
117
|
+
some_coords = _make_coords(frequency_coords, time_coords)
|
|
93
118
|
return {
|
|
94
119
|
"coords": {
|
|
95
120
|
"time": some_coords["time"],
|
|
@@ -147,7 +172,7 @@ def _make_empty_sky_image(
|
|
|
147
172
|
phase_center: Union[list, np.ndarray],
|
|
148
173
|
image_size: Union[list, np.ndarray],
|
|
149
174
|
cell_size: Union[list, np.ndarray],
|
|
150
|
-
|
|
175
|
+
frequency_coords: Union[list, np.ndarray],
|
|
151
176
|
pol_coords: Union[list, np.ndarray],
|
|
152
177
|
time_coords: Union[list, np.ndarray],
|
|
153
178
|
direction_reference: str,
|
|
@@ -156,7 +181,7 @@ def _make_empty_sky_image(
|
|
|
156
181
|
do_sky_coords: bool,
|
|
157
182
|
) -> xr.Dataset:
|
|
158
183
|
_input_checks(phase_center, image_size, cell_size)
|
|
159
|
-
cc = _make_common_coords(pol_coords,
|
|
184
|
+
cc = _make_common_coords(pol_coords, frequency_coords, time_coords)
|
|
160
185
|
coords = cc["coords"]
|
|
161
186
|
lm_values = _make_lm_values(image_size, cell_size)
|
|
162
187
|
coords.update(lm_values)
|
|
@@ -196,12 +221,8 @@ def _make_uv_values(
|
|
|
196
221
|
) -> dict:
|
|
197
222
|
im_size_wave = 1 / np.array(sky_image_cell_size)
|
|
198
223
|
uv_cell_size = im_size_wave / np.array(image_size)
|
|
199
|
-
u_vals = [
|
|
200
|
-
|
|
201
|
-
]
|
|
202
|
-
v_vals = [
|
|
203
|
-
(i - image_size[1] // 2) * abs(uv_cell_size[1]) for i in range(image_size[1])
|
|
204
|
-
]
|
|
224
|
+
u_vals = [(i - image_size[0] // 2) * uv_cell_size[0] for i in range(image_size[0])]
|
|
225
|
+
v_vals = [(i - image_size[1] // 2) * uv_cell_size[1] for i in range(image_size[1])]
|
|
205
226
|
return {"u": u_vals, "v": v_vals}
|
|
206
227
|
|
|
207
228
|
|
|
@@ -209,7 +230,7 @@ def _make_empty_aperture_image(
|
|
|
209
230
|
phase_center: Union[list, np.ndarray],
|
|
210
231
|
image_size: Union[list, np.ndarray],
|
|
211
232
|
sky_image_cell_size: Union[list, np.ndarray],
|
|
212
|
-
|
|
233
|
+
frequency_coords: Union[list, np.ndarray],
|
|
213
234
|
pol_coords: Union[list, np.ndarray],
|
|
214
235
|
time_coords: Union[list, np.ndarray],
|
|
215
236
|
direction_reference: str,
|
|
@@ -217,7 +238,7 @@ def _make_empty_aperture_image(
|
|
|
217
238
|
spectral_reference: str,
|
|
218
239
|
) -> xr.Dataset:
|
|
219
240
|
_input_checks(phase_center, image_size, sky_image_cell_size)
|
|
220
|
-
cc = _make_common_coords(pol_coords,
|
|
241
|
+
cc = _make_common_coords(pol_coords, frequency_coords, time_coords)
|
|
221
242
|
coords = cc["coords"]
|
|
222
243
|
xds = xr.Dataset(coords=coords)
|
|
223
244
|
xds = _make_uv_coords(xds, image_size, sky_image_cell_size)
|
|
@@ -235,14 +256,29 @@ def _make_empty_aperture_image(
|
|
|
235
256
|
|
|
236
257
|
|
|
237
258
|
def _move_beam_param_dim_coord(xds: xr.Dataset) -> xr.Dataset:
    """
    Attach a ``beam_params_label`` coordinate to *xds*.

    Despite the function name, nothing is moved: the coordinate is simply
    assigned, indexed by a dimension of the same name, with the three labels
    ``"major"``, ``"minor"`` and ``"pa"``.

    Parameters
    ----------
    xds : xr.Dataset
        Dataset that should receive the coordinate.

    Returns
    -------
    xr.Dataset
        The Dataset returned by ``assign_coords`` carrying the labelled
        coordinate.
    """
    beam_param_names = ["major", "minor", "pa"]
    new_coord = {"beam_params_label": ("beam_params_label", beam_param_names)}
    return xds.assign_coords(new_coord)
|
|
239
275
|
|
|
240
276
|
|
|
241
277
|
def _make_empty_lmuv_image(
|
|
242
278
|
phase_center: Union[list, np.ndarray],
|
|
243
279
|
image_size: Union[list, np.ndarray],
|
|
244
280
|
sky_image_cell_size: Union[list, np.ndarray],
|
|
245
|
-
|
|
281
|
+
frequency_coords: Union[list, np.ndarray],
|
|
246
282
|
pol_coords: Union[list, np.ndarray],
|
|
247
283
|
time_coords: Union[list, np.ndarray],
|
|
248
284
|
direction_reference: str,
|
|
@@ -254,7 +290,7 @@ def _make_empty_lmuv_image(
|
|
|
254
290
|
phase_center,
|
|
255
291
|
image_size,
|
|
256
292
|
sky_image_cell_size,
|
|
257
|
-
|
|
293
|
+
frequency_coords,
|
|
258
294
|
pol_coords,
|
|
259
295
|
time_coords,
|
|
260
296
|
direction_reference,
|
|
@@ -265,3 +301,406 @@ def _make_empty_lmuv_image(
|
|
|
265
301
|
xds = _make_uv_coords(xds, image_size, sky_image_cell_size)
|
|
266
302
|
xds = _move_beam_param_dim_coord(xds)
|
|
267
303
|
return xds
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def detect_store_type(store):
    """
    Detect the storage format type of an image store.

    Parameters
    ----------
    store : str or other
        Path to the image store. Anything that is not a ``str`` (for example a
        dict) is assumed to be a zarr store.

    Returns
    -------
    str
        The detected store type:

        - ``'fits'`` if ``store`` is a regular file;
        - ``'casa'`` if ``store`` is a directory containing ``table.info``;
        - ``'zarr'`` if ``store`` is a directory containing zarr metadata
          (``.zattrs`` or ``.zgroup`` for zarr v2, ``zarr.json`` for zarr v3),
          or if ``store`` is not a string.

    Raises
    ------
    ValueError
        If the directory structure is unknown or the path does not exist.
    """
    import os

    if not isinstance(store, str):
        return "zarr"

    if os.path.isfile(store):
        return "fits"

    if os.path.isdir(store):
        # List the directory once and test membership against the set.
        entries = set(os.listdir(store))
        if "table.info" in entries:
            return "casa"
        if entries & {".zattrs", ".zgroup", "zarr.json"}:
            return "zarr"
        logger.error("Unknown directory structure.")
        raise ValueError(f"Unknown directory structure: {store}")

    logger.error("Path does not exist.")
    # Report the working directory (useful for relative paths) without
    # shelling out to external commands.
    raise ValueError(
        f"Path does not exist: {store} (current working directory: {os.getcwd()})"
    )
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def detect_image_type(store):
    """
    Detect the image type from the store name or path.

    Matching is a case-insensitive substring test and the first matching
    rule wins, so rule order matters. The short, generic ``'im'`` pattern is
    tried last so that it cannot shadow more specific names such as
    ``'mask'`` (e.g. ``sim.mask`` contains ``'im'``).

    Parameters
    ----------
    store : str or other
        Path to the image store. If not a string, returns 'ALL'.

    Returns
    -------
    str
        The detected image type. Patterns in precedence order:

        - 'fits', 'image', 'sky' -> 'SKY'
        - 'psf' -> 'POINT_SPREAD_FUNCTION'
        - 'model' -> 'MODEL'
        - 'residual' -> 'RESIDUAL'
        - 'dirty' -> 'DIRTY'
        - 'pb' -> 'PRIMARY_BEAM'
        - 'aperture', 'uv' -> 'APERTURE'
        - 'visibility' -> 'VISIBILITY'
        - 'sumwt' -> 'VISIBILITY_NORMALIZATION'
        - 'zarr' -> 'ALL'
        - 'mask' -> 'MASK_DECONVOLVE'
        - 'im' -> 'SKY'
        - 'UNKNOWN' if no pattern matches
        - 'ALL' if ``store`` is not a string
    """
    if not isinstance(store, str):
        return "ALL"

    # (substring, image type) pairs, checked in order; first match wins.
    rules = (
        ("fits", "SKY"),
        ("image", "SKY"),
        ("sky", "SKY"),
        ("psf", "POINT_SPREAD_FUNCTION"),
        ("model", "MODEL"),
        ("residual", "RESIDUAL"),
        ("dirty", "DIRTY"),
        ("pb", "PRIMARY_BEAM"),
        ("aperture", "APERTURE"),
        # Must precede the "im" check so uv images aren't misclassified as SKY.
        ("uv", "APERTURE"),
        ("visibility", "VISIBILITY"),
        ("sumwt", "VISIBILITY_NORMALIZATION"),
        ("zarr", "ALL"),
        # Must precede the generic "im" check, otherwise any mask path that
        # happens to contain "im" would be classified as SKY.
        ("mask", "MASK_DECONVOLVE"),
        # Generic catch-all for names like "*.im"; keep it last.
        ("im", "SKY"),
    )

    name = store.lower()
    return next(
        (image_type for pattern, image_type in rules if pattern in name), "UNKNOWN"
    )
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def create_store_dict(store_to_label):
    """
    Create a standardized dictionary mapping image types to their store information.

    Converts various input formats (string, list/tuple, or dict) into a
    consistent dictionary format where keys are image types and values contain
    store metadata, and derives the data groups for those images.

    Parameters
    ----------
    store_to_label : str, list, tuple, or dict
        Input store specification. Can be:
        - str: Single store path (image type auto-detected)
        - list/tuple: Store paths (image types will be auto-detected)
        - dict: Mapping of image types to store paths. Integer keys are
          treated like positional entries and auto-detected.

    Returns
    -------
    tuple of (dict, dict)
        ``store_dict``: image types (upper case) as keys, each value a dict with
            - 'store_type': str, the format ('casa', 'fits', or 'zarr')
            - 'store': the store as given
          Zarr stores are always keyed ``'ALL'``.
        ``data_groups``: one group per sky image type (``'SKY'`` -> ``'base'``,
          ``'SKY_<NAME>'`` -> ``'<name>'``) holding ``{'sky': image_type}``; an
          ``'APERTURE'`` image is added to the ``'base'`` group as
          ``{'aperture': 'APERTURE'}``.

    Raises
    ------
    ValueError
        If an image type cannot be detected, if two stores resolve to the same
        image type, or if ``detect_store_type`` rejects a path.
    """
    if isinstance(store_to_label, str):
        store_dict_to_label = {0: store_to_label}
    elif isinstance(store_to_label, (list, tuple)):
        store_dict_to_label = dict(enumerate(store_to_label))
    else:
        store_dict_to_label = store_to_label

    store_dict = {}
    for label, store in store_dict_to_label.items():
        # Integer keys come from positional input: infer the type from the name.
        image_type = detect_image_type(store) if isinstance(label, int) else label
        image_type = image_type.upper()

        store_type = detect_store_type(store)

        if store_type == "zarr":
            # Zarr can hold multiple data variables, so its name is irrelevant.
            image_type = "ALL"
        elif image_type == "UNKNOWN":
            logger.error(f"Could not detect image type for store {store}. ")
            example = "store={'sky': 'path/to/image.fits'}"
            raise ValueError(
                f"Could not detect image type for store {store}. Please label the "
                f"store with the image type explicitly. For example: {example}"
            )

        # Checked after the zarr remapping so a second zarr store cannot
        # silently overwrite the first under the shared "ALL" key.
        if image_type in store_dict:
            logger.error(f"Duplicate image type {image_type} detected in store list.")
            raise ValueError(
                f"Duplicate image type {image_type} detected in store list. Please "
                f"ensure each image type is unique. The store dict: {store_dict}"
            )

        store_dict[image_type] = {"store_type": store_type, "store": store}

    data_groups = {}
    for image_type in store_dict:
        lowered = image_type.lower()
        if "sky" in lowered:
            group_name = "base" if lowered == "sky" else lowered.replace("sky_", "")
            # setdefault merges with an existing group (e.g. one holding an
            # aperture image) instead of replacing it.
            data_groups.setdefault(group_name, {})["sky"] = image_type
        elif lowered == "aperture":
            data_groups.setdefault("base", {})["aperture"] = image_type

    return store_dict, data_groups
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def _assign_flag_to_data_group(
    img_xds: xr.Dataset,
    flag_array: xr.DataArray,
    flag_name: str,
    data_group: dict,
) -> None:
    """Store *flag_array* in *img_xds* as *flag_name*, tag it as a flag, and register it in *data_group*."""
    img_xds[flag_name] = flag_array
    img_xds[flag_name].attrs["type"] = "flag"
    data_group["flag"] = flag_name


def _find_or_create_sky_data_group(data_groups: dict, image_type: str) -> str:
    """Return the name of the data group whose ``'sky'`` entry is *image_type*, creating it if absent."""
    for group_name, group in data_groups.items():
        # Groups that only describe non-sky images (e.g. aperture) have no "sky" key.
        if group.get("sky") == image_type:
            return group_name
    # Same naming rule as create_store_dict: SKY -> base, SKY_<NAME> -> <name>.
    lowered = image_type.lower()
    group_name = "base" if lowered == "sky" else lowered.replace("sky_", "")
    data_groups.setdefault(group_name, {})["sky"] = image_type
    return group_name


def create_image_xds_from_store(
    store: Union[list, dict, str],
    access_store_casa: callable,
    casa_kwargs: dict,
    access_store_fits: callable,
    fits_kwargs: dict,
    access_store_zarr: callable,
    zarr_kwargs: dict,
) -> xr.Dataset:
    """
    Create an xarray Dataset from one or more image stores.

    This function reads image data from CASA, FITS, or zarr format stores and
    combines them into a single xarray Dataset with appropriate metadata and
    data variables.

    Parameters
    ----------
    store : str, list, or dict
        Image store specification:
        - str: Single store path
        - list: List of store paths
        - dict: Mapping of image types to store paths
    access_store_casa : callable
        Function to read CASA format images. Called with the store path,
        ``casa_kwargs`` and an ``image_type`` keyword; must return an
        xr.Dataset.
    casa_kwargs : dict
        Keyword arguments to pass to access_store_casa. Not modified.
    access_store_fits : callable or None
        Function to read FITS format images. Called like ``access_store_casa``
        with ``fits_kwargs``. Can be None if FITS support is not needed.
    fits_kwargs : dict
        Keyword arguments to pass to access_store_fits. Not modified.
    access_store_zarr : callable
        Function to read zarr format images. Called with the store and
        ``zarr_kwargs``; must return an xr.Dataset.
    zarr_kwargs : dict
        Keyword arguments to pass to access_store_zarr.

    Returns
    -------
    xr.Dataset
        An xarray Dataset containing the image data and metadata. For a zarr
        store this is whatever ``access_store_zarr`` returns. Otherwise it
        includes:
        - Data variables for each image type (e.g., 'SKY', 'MODEL', 'RESIDUAL')
        - Coordinates shared across all images, plus ``beam_params_label``
        - Attributes including 'type' and 'data_groups'

    Raises
    ------
    ValueError
        If zarr store with multiple data variables is combined with other
        stores, or if ``create_store_dict`` rejects the input.
    RuntimeError
        If FITS format is requested but access_store_fits is None, or if an
        unrecognized image format is encountered.

    Notes
    -----
    - Zarr stores can contain multiple data variables and will be returned as-is.
    - For other formats, data from multiple stores is combined into one Dataset.
    - Beam fit parameters of sky and point-spread-function images are copied
      as ``BEAM_FIT_PARAMS_<IMAGE_TYPE>`` and referenced from the data groups.
    - Internal masks of sky images (``MASK``, ``MASK_0`` or
      ``FLAG_<IMAGE_TYPE>``) are stored as ``FLAG_<IMAGE_TYPE>``.
    """
    store_dict, data_groups = create_store_dict(store)
    if "ALL" in store_dict and len(store_dict) > 1:
        logger.error(
            "When using a zarr store with multiple data variables, no other stores can be specified."
        )
        raise ValueError(
            "When using a zarr store with multiple data variables, no other stores can be specified."
        )

    if "ALL" in store_dict:
        # Use the resolved store so str, list and dict inputs all work.
        return access_store_zarr(store_dict["ALL"]["store"], **zarr_kwargs)

    img_xds = xr.Dataset()
    # Loop over all the input CASA and Fits images.
    for image_type, store_description in store_dict.items():
        store_type = store_description["store_type"]
        store_path = store_description["store"]

        # Build per-call kwargs so the caller's dicts are left untouched.
        if store_type == "casa":
            xds = access_store_casa(
                store_path, **dict(casa_kwargs, image_type=image_type)
            )
        elif store_type == "fits":
            if access_store_fits is None:
                logger.error("FITS not currently supported.")
                raise RuntimeError("FITS not currently supported.")
            xds = access_store_fits(
                store_path, **dict(fits_kwargs, image_type=image_type)
            )
        else:
            logger.error(
                f"Unrecognized image format for path {store_path}. Supported types are CASA, FITS, and zarr.\n"
            )
            raise RuntimeError(
                f"Unrecognized image format for path {store_path}. Supported types are CASA, FITS, and zarr.\n"
            )

        img_xds.attrs = img_xds.attrs | xds.attrs
        img_xds[image_type] = xds[image_type]
        img_xds[image_type].attrs["type"] = image_type.lower()

        lowered = image_type.lower()
        beam_fit_name = "BEAM_FIT_PARAMS_" + image_type.upper()

        # If sky image, handle internal masks and beam fit params.
        if "sky" in lowered:
            active_group = data_groups[
                _find_or_create_sky_data_group(data_groups, image_type)
            ]

            if beam_fit_name in xds:
                img_xds[beam_fit_name] = xds[beam_fit_name]
                active_group["beam_fit_params_sky"] = beam_fit_name

            # TODO remove this mask logic and everything that still makes it necessary
            flag_name = "FLAG_" + image_type
            # Later candidates take precedence: MASK over MASK_0 over FLAG_<type>.
            present = [n for n in (flag_name, "MASK_0", "MASK") if n in xds]
            if present:
                _assign_flag_to_data_group(
                    img_xds, xds[present[-1]], flag_name, active_group
                )

            img_xds[image_type].attrs["type"] = "sky"

        # If point spread function, handle beam fit params.
        has_psf_beam = "point_spread_function" in lowered and beam_fit_name in xds
        if has_psf_beam:
            img_xds[beam_fit_name] = xds[beam_fit_name]

        # Each sky image gets its own data group and shares all other images between them.
        if "sky" not in lowered:
            for data_group in data_groups.values():
                data_group[lowered] = image_type
                if has_psf_beam:
                    data_group["beam_fit_params_point_spread_function"] = beam_fit_name

    # The beam_params_label coordinate must be present even if unused;
    # assign_coords also creates the dimension when it is missing.
    if "beam_params_label" not in img_xds.coords:
        img_xds = _move_beam_param_dim_coord(img_xds)
    img_xds.attrs["type"] = "image_dataset"
    img_xds.attrs["data_groups"] = data_groups
    return img_xds
|