xradio 1.0.2__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. xradio/_utils/_casacore/casacore_from_casatools.py +75 -9
  2. xradio/_utils/dict_helpers.py +38 -7
  3. xradio/_utils/list_and_array.py +26 -3
  4. xradio/_utils/schema.py +44 -0
  5. xradio/_utils/xarray_helpers.py +63 -0
  6. xradio/_utils/zarr/common.py +4 -2
  7. xradio/image/__init__.py +4 -2
  8. xradio/image/_util/_casacore/common.py +2 -1
  9. xradio/image/_util/_casacore/xds_from_casacore.py +144 -92
  10. xradio/image/_util/_casacore/xds_to_casacore.py +118 -53
  11. xradio/image/_util/_fits/xds_from_fits.py +125 -37
  12. xradio/image/_util/_zarr/common.py +0 -1
  13. xradio/image/_util/casacore.py +183 -25
  14. xradio/image/_util/common.py +10 -8
  15. xradio/image/_util/image_factory.py +469 -27
  16. xradio/image/image.py +72 -100
  17. xradio/image/image_xds.py +262 -0
  18. xradio/image/schema.py +85 -0
  19. xradio/measurement_set/__init__.py +5 -4
  20. xradio/measurement_set/_utils/_msv2/_tables/read.py +4 -3
  21. xradio/measurement_set/_utils/_msv2/conversion.py +6 -9
  22. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +1 -0
  23. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +1 -1
  24. xradio/measurement_set/_utils/_utils/interpolate.py +5 -0
  25. xradio/measurement_set/_utils/_utils/partition_attrs.py +0 -1
  26. xradio/measurement_set/convert_msv2_to_processing_set.py +9 -9
  27. xradio/measurement_set/load_processing_set.py +2 -2
  28. xradio/measurement_set/measurement_set_xdt.py +83 -93
  29. xradio/measurement_set/open_processing_set.py +1 -1
  30. xradio/measurement_set/processing_set_xdt.py +33 -26
  31. xradio/schema/check.py +70 -19
  32. xradio/schema/common.py +0 -1
  33. xradio/testing/__init__.py +0 -0
  34. xradio/testing/_utils/__template__.py +58 -0
  35. xradio/testing/measurement_set/__init__.py +58 -0
  36. xradio/testing/measurement_set/checker.py +131 -0
  37. xradio/testing/measurement_set/io.py +22 -0
  38. xradio/testing/measurement_set/msv2_io.py +1854 -0
  39. {xradio-1.0.2.dist-info → xradio-1.1.1.dist-info}/METADATA +65 -23
  40. xradio-1.1.1.dist-info/RECORD +75 -0
  41. {xradio-1.0.2.dist-info → xradio-1.1.1.dist-info}/WHEEL +1 -1
  42. xradio-1.0.2.dist-info/RECORD +0 -66
  43. {xradio-1.0.2.dist-info → xradio-1.1.1.dist-info}/licenses/LICENSE.txt +0 -0
  44. {xradio-1.0.2.dist-info → xradio-1.1.1.dist-info}/top_level.txt +0 -0
@@ -4,8 +4,10 @@ import numpy as np
4
4
  import xarray as xr
5
5
  from typing import List, Union
6
6
  from .common import _c, _compute_world_sph_dims, _l_m_attr_notes
7
+ import toolviper.utils.logger as logger
7
8
  from xradio._utils.coord_math import _deg_to_rad
8
9
  from xradio._utils.dict_helpers import (
10
+ make_direction_location_dict,
9
11
  make_spectral_coord_reference_dict,
10
12
  make_quantity,
11
13
  make_skycoord_dict,
@@ -18,6 +20,23 @@ def _input_checks(
18
20
  image_size: Union[list, np.ndarray],
19
21
  cell_size: Union[list, np.ndarray],
20
22
  ) -> None:
23
+ """
24
+ Validate input parameters for image creation functions.
25
+
26
+ Parameters
27
+ ----------
28
+ phase_center : list or np.ndarray
29
+ Image phase center coordinates. Must have exactly 2 elements.
30
+ image_size : list or np.ndarray
31
+ Number of pixels along each axis. Must have exactly 2 elements.
32
+ cell_size : list or np.ndarray
33
+ Size of pixels along each axis. Must have exactly 2 elements.
34
+
35
+ Raises
36
+ ------
37
+ ValueError
38
+ If any parameter does not have exactly 2 elements.
39
+ """
21
40
  if len(image_size) != 2:
22
41
  raise ValueError("image_size must have exactly two elements")
23
42
  if len(phase_center) != 2:
@@ -27,18 +46,22 @@ def _input_checks(
27
46
 
28
47
 
29
48
  def _make_coords(
30
- chan_coords: Union[list, np.ndarray],
49
+ frequency_coords: Union[list, np.ndarray],
31
50
  time_coords: Union[list, np.ndarray],
32
51
  ) -> dict:
33
- if not isinstance(chan_coords, list) and not isinstance(chan_coords, np.ndarray):
34
- chan_coords = [chan_coords]
35
- chan_coords = np.array(chan_coords, dtype=np.float64)
36
- restfreq = chan_coords[len(chan_coords) // 2]
37
- vel_coords = (1 - chan_coords / restfreq) * _c.to("m/s").value
52
+ if not isinstance(frequency_coords, list) and not isinstance(
53
+ frequency_coords, np.ndarray
54
+ ):
55
+ frequency_coords = [frequency_coords]
56
+ frequency_coords = np.array(frequency_coords, dtype=np.float64)
57
+ restfreq = frequency_coords[len(frequency_coords) // 2]
58
+ vel_coords = (1 - frequency_coords / restfreq) * _c.to("m/s").value
38
59
  if not isinstance(time_coords, list) and not isinstance(time_coords, np.ndarray):
39
60
  time_coords = [time_coords]
40
61
  time_coords = np.array(time_coords, dtype=np.float64)
41
- return dict(chan=chan_coords, vel=vel_coords, time=time_coords, restfreq=restfreq)
62
+ return dict(
63
+ chan=frequency_coords, vel=vel_coords, time=time_coords, restfreq=restfreq
64
+ )
42
65
 
43
66
 
44
67
  def _add_common_attrs(
@@ -72,24 +95,26 @@ def _add_common_attrs(
72
95
  reference["attrs"].update({"equinox": "j2000.0"})
73
96
  xds.attrs = {
74
97
  "data_groups": {"base": {}},
75
- "direction": {
76
- "reference": reference,
77
- "lonpole": make_quantity(np.pi, "rad", ["l", "m"]),
78
- "latpole": make_quantity(0.0, "rad", ["l", "m"]),
79
- "pc": [[1.0, 0.0], [0.0, 1.0]],
98
+ "coordinate_system_info": {
99
+ "reference_direction": reference,
100
+ "native_pole_direction": make_direction_location_dict(
101
+ [np.pi, 0.0], "rad", "native_projection"
102
+ ),
103
+ "pixel_coordinate_transformation_matrix": [[1.0, 0.0], [0.0, 1.0]],
80
104
  "projection": projection,
81
105
  "projection_parameters": [0.0, 0.0],
82
106
  },
107
+ "type": "image",
83
108
  }
84
109
  return xds
85
110
 
86
111
 
87
112
  def _make_common_coords(
88
113
  pol_coords: Union[list, np.ndarray],
89
- chan_coords: Union[list, np.ndarray],
114
+ frequency_coords: Union[list, np.ndarray],
90
115
  time_coords: Union[list, np.ndarray],
91
116
  ) -> dict:
92
- some_coords = _make_coords(chan_coords, time_coords)
117
+ some_coords = _make_coords(frequency_coords, time_coords)
93
118
  return {
94
119
  "coords": {
95
120
  "time": some_coords["time"],
@@ -147,7 +172,7 @@ def _make_empty_sky_image(
147
172
  phase_center: Union[list, np.ndarray],
148
173
  image_size: Union[list, np.ndarray],
149
174
  cell_size: Union[list, np.ndarray],
150
- chan_coords: Union[list, np.ndarray],
175
+ frequency_coords: Union[list, np.ndarray],
151
176
  pol_coords: Union[list, np.ndarray],
152
177
  time_coords: Union[list, np.ndarray],
153
178
  direction_reference: str,
@@ -156,7 +181,7 @@ def _make_empty_sky_image(
156
181
  do_sky_coords: bool,
157
182
  ) -> xr.Dataset:
158
183
  _input_checks(phase_center, image_size, cell_size)
159
- cc = _make_common_coords(pol_coords, chan_coords, time_coords)
184
+ cc = _make_common_coords(pol_coords, frequency_coords, time_coords)
160
185
  coords = cc["coords"]
161
186
  lm_values = _make_lm_values(image_size, cell_size)
162
187
  coords.update(lm_values)
@@ -196,12 +221,8 @@ def _make_uv_values(
196
221
  ) -> dict:
197
222
  im_size_wave = 1 / np.array(sky_image_cell_size)
198
223
  uv_cell_size = im_size_wave / np.array(image_size)
199
- u_vals = [
200
- (i - image_size[0] // 2) * abs(uv_cell_size[0]) for i in range(image_size[0])
201
- ]
202
- v_vals = [
203
- (i - image_size[1] // 2) * abs(uv_cell_size[1]) for i in range(image_size[1])
204
- ]
224
+ u_vals = [(i - image_size[0] // 2) * uv_cell_size[0] for i in range(image_size[0])]
225
+ v_vals = [(i - image_size[1] // 2) * uv_cell_size[1] for i in range(image_size[1])]
205
226
  return {"u": u_vals, "v": v_vals}
206
227
 
207
228
 
@@ -209,7 +230,7 @@ def _make_empty_aperture_image(
209
230
  phase_center: Union[list, np.ndarray],
210
231
  image_size: Union[list, np.ndarray],
211
232
  sky_image_cell_size: Union[list, np.ndarray],
212
- chan_coords: Union[list, np.ndarray],
233
+ frequency_coords: Union[list, np.ndarray],
213
234
  pol_coords: Union[list, np.ndarray],
214
235
  time_coords: Union[list, np.ndarray],
215
236
  direction_reference: str,
@@ -217,7 +238,7 @@ def _make_empty_aperture_image(
217
238
  spectral_reference: str,
218
239
  ) -> xr.Dataset:
219
240
  _input_checks(phase_center, image_size, sky_image_cell_size)
220
- cc = _make_common_coords(pol_coords, chan_coords, time_coords)
241
+ cc = _make_common_coords(pol_coords, frequency_coords, time_coords)
221
242
  coords = cc["coords"]
222
243
  xds = xr.Dataset(coords=coords)
223
244
  xds = _make_uv_coords(xds, image_size, sky_image_cell_size)
@@ -235,14 +256,29 @@ def _make_empty_aperture_image(
235
256
 
236
257
 
237
258
  def _move_beam_param_dim_coord(xds: xr.Dataset) -> xr.Dataset:
238
- return xds.assign_coords(beam_param=("beam_param", ["major", "minor", "pa"]))
259
+ """
260
+ Add beam_params_label coordinate to an xarray Dataset.
261
+
262
+ Parameters
263
+ ----------
264
+ xds : xr.Dataset
265
+ Input Dataset to which beam parameters will be added.
266
+
267
+ Returns
268
+ -------
269
+ xr.Dataset
270
+ Dataset with beam_params_label coordinate containing ['major', 'minor', 'pa'].
271
+ """
272
+ return xds.assign_coords(
273
+ beam_params_label=("beam_params_label", ["major", "minor", "pa"])
274
+ )
239
275
 
240
276
 
241
277
  def _make_empty_lmuv_image(
242
278
  phase_center: Union[list, np.ndarray],
243
279
  image_size: Union[list, np.ndarray],
244
280
  sky_image_cell_size: Union[list, np.ndarray],
245
- chan_coords: Union[list, np.ndarray],
281
+ frequency_coords: Union[list, np.ndarray],
246
282
  pol_coords: Union[list, np.ndarray],
247
283
  time_coords: Union[list, np.ndarray],
248
284
  direction_reference: str,
@@ -254,7 +290,7 @@ def _make_empty_lmuv_image(
254
290
  phase_center,
255
291
  image_size,
256
292
  sky_image_cell_size,
257
- chan_coords,
293
+ frequency_coords,
258
294
  pol_coords,
259
295
  time_coords,
260
296
  direction_reference,
@@ -265,3 +301,409 @@ def _make_empty_lmuv_image(
265
301
  xds = _make_uv_coords(xds, image_size, sky_image_cell_size)
266
302
  xds = _move_beam_param_dim_coord(xds)
267
303
  return xds
304
+
305
+
306
+ def detect_store_type(store):
307
+ """
308
+ Detect the storage format type of an image store.
309
+
310
+ Parameters
311
+ ----------
312
+ store : str or dict
313
+ Path to the image store or a dictionary representation.
314
+
315
+ Returns
316
+ -------
317
+ str
318
+ The detected store type: 'fits', 'casa', or 'zarr'.
319
+
320
+ Raises
321
+ ------
322
+ ValueError
323
+ If the directory structure is unknown or the path does not exist.
324
+ """
325
+ import os
326
+
327
+ if isinstance(store, str):
328
+ if os.path.isfile(store):
329
+ store_type = "fits"
330
+ elif os.path.isdir(store):
331
+ if "table.info" in os.listdir(store):
332
+ store_type = "casa"
333
+ elif ".zattrs" in os.listdir(store):
334
+ store_type = "zarr"
335
+ else:
336
+ logger.error("Unknown directory structure.")
337
+ raise ValueError("Unknown directory structure." + str(store))
338
+ else:
339
+ logger.error("Path does not exist.")
340
+ raise ValueError(
341
+ "Path does not exist. The current path: "
342
+ + str(os.system("pwd"))
343
+ + " .The current casa directory: "
344
+ + str(os.system("ls 3c286_Band6_5chans_lsrk_robust_0.5_niter_99_casa"))
345
+ + ". The current fits directory: "
346
+ + str(os.system("ls 3c286_Band6_5chans_lsrk_robust_0.5_niter_99_fits"))
347
+ + " The given store "
348
+ + str(store)
349
+ )
350
+ else:
351
+ store_type = "zarr"
352
+
353
+ return store_type
354
+
355
+
356
def detect_image_type(store):
    """
    Detect the image type from the name of an image store.

    Only the final component of the path is inspected (after dropping any
    trailing separator), case-insensitively. Keywords are tried in this
    order and the first one found decides the type:

    - 'fits', 'image', 'sky' -> 'SKY'
    - 'psf' -> 'POINT_SPREAD_FUNCTION'
    - 'model' -> 'MODEL'
    - 'residual' -> 'RESIDUAL'
    - 'dirty' -> 'DIRTY'
    - 'pb' -> 'PRIMARY_BEAM'
    - 'aperture', 'uv' -> 'APERTURE'
    - 'visibility' -> 'VISIBILITY'
    - 'sumwt' -> 'VISIBILITY_NORMALIZATION'
    - 'zarr' -> 'ALL'
    - 'im' -> 'SKY'
    - 'mask' -> 'MASK_DECONVOLVE'

    Parameters
    ----------
    store : str or other
        Path to the image store. If not a string, 'ALL' is returned.

    Returns
    -------
    str
        The detected image type, 'UNKNOWN' if no keyword matches, or 'ALL'
        for a non-string ``store``.
    """
    import os

    if not isinstance(store, str):
        return "ALL"

    # Match against the last path component only, so keywords in parent
    # directory names (e.g. "/data/images/run.psf" or "/tmp/sim/run.mask")
    # cannot decide the type. normpath strips a trailing separator first so
    # "run.psf/" still yields "run.psf".
    name = os.path.basename(os.path.normpath(store)).lower()

    # Order matters: the first matching keyword wins. "uv" precedes "im" so
    # uv images aren't misclassified as SKY, and "im" precedes "mask" so
    # *mask*.im images aren't misclassified.
    keyword_to_type = (
        ("fits", "SKY"),
        ("image", "SKY"),
        ("sky", "SKY"),
        ("psf", "POINT_SPREAD_FUNCTION"),
        ("model", "MODEL"),
        ("residual", "RESIDUAL"),
        ("dirty", "DIRTY"),
        ("pb", "PRIMARY_BEAM"),
        ("aperture", "APERTURE"),
        ("uv", "APERTURE"),
        ("visibility", "VISIBILITY"),
        ("sumwt", "VISIBILITY_NORMALIZATION"),
        ("zarr", "ALL"),
        ("im", "SKY"),
        ("mask", "MASK_DECONVOLVE"),
    )
    for keyword, image_type in keyword_to_type:
        if keyword in name:
            return image_type
    return "UNKNOWN"
424
+
425
+
426
def create_store_dict(store_to_label):
    """
    Normalize a store specification into per-image-type store descriptions.

    Parameters
    ----------
    store_to_label : str, list, or dict
        Input store specification:
        - str: single store path (image type auto-detected)
        - list: store paths (image types auto-detected)
        - dict: mapping of image-type labels to store paths; integer keys
          are treated as unlabeled and auto-detected.

    Returns
    -------
    tuple of (dict, dict)
        ``store_dict``: upper-case image type -> ``{'store_type': ..., 'store': ...}``
        where ``store_type`` is 'casa', 'fits' or 'zarr'. A zarr store is
        always keyed 'ALL', since it may contain several data variables.
        ``data_groups``: data-group name -> dict. 'SKY' and 'APERTURE' go in
        the 'base' group (under 'sky' / 'aperture'); any other type
        containing 'SKY' gets its own group, named after the lower-cased
        type with 'sky_' removed.

    Raises
    ------
    ValueError
        If an image type cannot be detected, if two stores map to the same
        image type, or (from detect_store_type) if a path is missing or has an
        unknown layout.
    """
    if isinstance(store_to_label, str):
        store_to_label = [store_to_label]  # So can iterate over it.
    if isinstance(store_to_label, list):
        # Integer keys mark entries whose image type must be auto-detected.
        store_dict_to_label = dict(enumerate(store_to_label))
    else:
        store_dict_to_label = store_to_label

    store_dict = {}
    for image_type, store in store_dict_to_label.items():
        store_type = detect_store_type(store)

        if store_type == "zarr":
            # Zarr can have multiple data variables, so it is not tied to one
            # image type and needs no name-based detection.
            image_type = "ALL"
        else:
            if isinstance(image_type, int):
                image_type = detect_image_type(store)
            image_type = image_type.upper()

            if image_type == "UNKNOWN":
                logger.error(f"Could not detect image type for store {store}. ")
                example = "store={'sky': 'path/to/image.fits'}"
                raise ValueError(
                    f"Could not detect image type for store {store}. Please label the store with the image type explicitly. For example: {example}"
                )

        # Checked after the zarr relabelling so a second zarr store is
        # reported instead of silently replacing the first.
        if image_type in store_dict:
            logger.error(f"Duplicate image type {image_type} detected in store list.")
            raise ValueError(
                f"Duplicate image type {image_type} detected in store list. Please ensure each image type is unique. The store dict: "
                + str(store_dict)
            )

        store_dict[image_type] = {"store_type": store_type, "store": store}

    data_groups = {}
    for image_type in store_dict:
        image_type_lower = image_type.lower()
        if "sky" in image_type_lower:
            if image_type_lower == "sky":
                group_name = "base"
            else:
                group_name = image_type_lower.replace("sky_", "")
            data_groups.setdefault(group_name, {})["sky"] = image_type
        if image_type_lower == "aperture":
            # Merge instead of replacing the group, so a sky entry already in
            # "base" is kept.
            data_groups.setdefault("base", {})["aperture"] = image_type

    return store_dict, data_groups
506
+
507
+
508
def _add_flag_to_data_group(
    img_xds: xr.Dataset,
    flag_array: xr.DataArray,
    flag_name: str,
    data_group: dict,
) -> None:
    """
    Store a flag array in ``img_xds`` and register it in a data group.

    Parameters
    ----------
    img_xds : xr.Dataset
        Dataset being assembled. ``flag_array`` is stored in it (in place)
        under ``flag_name``, with ``attrs["type"] = "flag"``.
    flag_array : xr.DataArray
        Flag / mask values to store.
    flag_name : str
        Name of the data variable to create or overwrite.
    data_group : dict
        Data group whose ``"flag"`` entry is set to ``flag_name`` (in place).
    """
    # TODO remove this mask logic and everything that still makes it necessary
    img_xds[flag_name] = flag_array
    img_xds[flag_name].attrs["type"] = "flag"
    data_group["flag"] = flag_name


def create_image_xds_from_store(
    store: Union[list, dict, str],
    access_store_casa: callable,
    casa_kwargs: dict,
    access_store_fits: callable,
    fits_kwargs: dict,
    access_store_zarr: callable,
    zarr_kwargs: dict,
) -> xr.Dataset:
    """
    Create an xarray Dataset from one or more image stores.

    Reads CASA, FITS or zarr stores and combines them into one Dataset with
    data variables, data groups and attributes.

    Parameters
    ----------
    store : str, list, or dict
        Image store specification:
        - str: single store path
        - list: list of store paths (image types auto-detected)
        - dict: mapping of image types to store paths
    access_store_casa : callable
        Reader for CASA images, called as
        ``access_store_casa(path, **casa_kwargs, image_type=...)``.
    casa_kwargs : dict or None
        Keyword arguments for ``access_store_casa``. Not modified.
    access_store_fits : callable or None
        Reader for FITS images, called like ``access_store_casa``. May be
        None if FITS support is not needed.
    fits_kwargs : dict or None
        Keyword arguments for ``access_store_fits``. Not modified.
    access_store_zarr : callable
        Reader for a zarr store, called as
        ``access_store_zarr(path, **zarr_kwargs)``.
    zarr_kwargs : dict or None
        Keyword arguments for ``access_store_zarr``.

    Returns
    -------
    xr.Dataset
        For a zarr store, whatever ``access_store_zarr`` returns. Otherwise a
        Dataset with one data variable per image type (e.g. 'SKY', 'MODEL'),
        plus ``BEAM_FIT_PARAMS_<TYPE>`` and ``FLAG_<TYPE>`` variables where
        the source provides them, and attrs ``type='image_dataset'`` and
        ``data_groups``.

    Raises
    ------
    ValueError
        If no store is given, or a zarr store is combined with other stores.
    RuntimeError
        If FITS is requested but ``access_store_fits`` is None, or if an
        unrecognized store type is encountered.

    Notes
    -----
    - Each sky image gets its own data group; all other images are shared by
      every data group.
    - For sky images, flags are stored as ``FLAG_<IMAGE_TYPE>``. The source is
      ``FLAG_<IMAGE_TYPE>``, ``MASK_0`` or ``MASK`` from the read dataset;
      when several exist, the later one in that list wins.
    """
    store_dict, data_groups = create_store_dict(store)

    if not store_dict:
        logger.error("No image stores were specified.")
        raise ValueError("No image stores were specified.")

    if "ALL" in store_dict and len(store_dict) > 1:
        logger.error(
            "When using a zarr store with multiple data variables, no other stores can be specified."
        )
        raise ValueError(
            "When using a zarr store with multiple data variables, no other stores can be specified."
        )

    if "ALL" in store_dict:
        # Take the path recorded by create_store_dict: it is valid for str,
        # list and dict input alike (a list has no .values()).
        return access_store_zarr(store_dict["ALL"]["store"], **(zarr_kwargs or {}))

    img_xds = xr.Dataset()
    # Loop over all the input CASA and FITS images.
    for image_type, store_description in store_dict.items():
        store_type = store_description["store_type"]
        store_path = store_description["store"]

        # Build per-call kwargs rather than writing "image_type" into the
        # caller's dicts.
        if store_type == "casa":
            xds = access_store_casa(
                store_path, **{**(casa_kwargs or {}), "image_type": image_type}
            )
        elif store_type == "fits":
            if access_store_fits is None:
                logger.error("FITS not currently supported.")
                raise RuntimeError("FITS not currently supported.")
            xds = access_store_fits(
                store_path, **{**(fits_kwargs or {}), "image_type": image_type}
            )
        else:
            logger.error(
                f"Unrecognized image format for path {store_path}. Supported types are CASA, FITS, and zarr.\n"
            )
            raise RuntimeError(
                f"Unrecognized image format for path {store_path}. Supported types are CASA, FITS, and zarr.\n"
            )

        img_xds.attrs = img_xds.attrs | xds.attrs
        img_xds[image_type] = xds[image_type]
        img_xds[image_type].attrs["type"] = image_type.lower()

        image_type_lower = image_type.lower()
        beam_fit_params_name = "BEAM_FIT_PARAMS_" + image_type.upper()
        has_beam_fit_params = beam_fit_params_name in xds

        if "sky" in image_type_lower:
            # Sky image: attach beam fit params and flags to its own data group.
            active_data_group_name = None
            for data_group_name, data_group in data_groups.items():
                # .get(): a data group need not have a "sky" entry.
                if data_group.get("sky") == image_type:
                    active_data_group_name = data_group_name

            if has_beam_fit_params:
                img_xds[beam_fit_params_name] = xds[beam_fit_params_name]
                data_groups[active_data_group_name][
                    "beam_fit_params_sky"
                ] = beam_fit_params_name

            flag_name = "FLAG_" + image_type
            # Checked in this order; a later source overwrites an earlier one.
            for flag_source in (flag_name, "MASK_0", "MASK"):
                if flag_source in xds:
                    _add_flag_to_data_group(
                        img_xds,
                        xds[flag_source],
                        flag_name,
                        data_groups[active_data_group_name],
                    )

            img_xds[image_type].attrs["type"] = "sky"
        else:
            is_psf = "point_spread_function" in image_type_lower
            if is_psf and has_beam_fit_params:
                img_xds[beam_fit_params_name] = xds[beam_fit_params_name]

            # Each sky image has its own data group; every other image is
            # shared by all of them.
            for data_group in data_groups.values():
                data_group[image_type_lower] = image_type
                if is_psf and has_beam_fit_params:
                    data_group["beam_fit_params_point_spread_function"] = (
                        beam_fit_params_name
                    )

    # `image_type` is the last image read (store_dict is non-empty, so it is
    # bound). Unless the result holds only a VISIBILITY_NORMALIZATION image,
    # beam_params_label must be present even if no variable uses it.
    # assign_coords adds the dimension without reshaping existing variables
    # (the former `img_xds.expand_dims(...)` discarded its result, so it did
    # nothing).
    if (
        "visibility_normalization" not in image_type.lower()
        or len(img_xds.data_vars) > 1
    ) and "beam_params_label" not in img_xds.coords:
        img_xds = _move_beam_param_dim_coord(img_xds)

    img_xds.attrs["type"] = "image_dataset"
    img_xds.attrs["data_groups"] = data_groups
    return img_xds