xradio 0.0.47__py3-none-any.whl → 0.0.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -127,9 +127,9 @@ def interpolate_to_time(
127
127
  xds = xds.interp(
128
128
  {time_name: interp_time.data}, method=method, assume_sorted=True
129
129
  )
130
- # scan_number sneaks in as a coordinate of the main time axis, drop it
131
- if "scan_number" in xds.coords:
132
- xds = xds.drop_vars("scan_number")
130
+ # scan_name sneaks in as a coordinate of the main time axis, drop it
131
+ if "scan_name" in xds.coords:
132
+ xds = xds.drop_vars("scan_name")
133
133
  points_after = xds[time_name].size
134
134
  logger.debug(
135
135
  f"{message_prefix}: interpolating the time coordinate "
@@ -497,7 +497,7 @@ def prepare_generic_sys_cal_xds(generic_sys_cal_xds: xr.Dataset) -> xr.Dataset:
497
497
  def create_system_calibration_xds(
498
498
  in_file: str,
499
499
  main_xds_frequency: xr.DataArray,
500
- ant_xds_name_ids: xr.DataArray,
500
+ ant_xds: xr.DataArray,
501
501
  sys_cal_interp_time: Union[xr.DataArray, None] = None,
502
502
  ):
503
503
  """
@@ -510,8 +510,8 @@ def create_system_calibration_xds(
510
510
  main_xds_frequency: xr.DataArray
511
511
  frequency array of the main xds (MSv4), containing among other things
512
512
  spectral_window_id and measures metadata
513
- ant_xds_name_ids : xr.Dataset
514
- antenna_name data array from antenna_xds, with name/id information
513
+ ant_xds : xr.Dataset
514
+ The antenna_xds that has information such as names, stations, etc., for coordinates
515
515
  sys_cal_interp_time: Union[xr.DataArray, None] = None,
516
516
  Time axis to interpolate the data vars to (usually main MSv4 time)
517
517
 
@@ -529,7 +529,7 @@ def create_system_calibration_xds(
529
529
  rename_ids=subt_rename_ids["SYSCAL"],
530
530
  taql_where=(
531
531
  f" where (SPECTRAL_WINDOW_ID = {spectral_window_id})"
532
- f" AND (ANTENNA_ID IN [{','.join(map(str, ant_xds_name_ids.antenna_id.values))}])"
532
+ f" AND (ANTENNA_ID IN [{','.join(map(str, ant_xds.antenna_id.values))}])"
533
533
  ),
534
534
  )
535
535
  except ValueError as _exc:
@@ -541,14 +541,14 @@ def create_system_calibration_xds(
541
541
 
542
542
  generic_sys_cal_xds = prepare_generic_sys_cal_xds(generic_sys_cal_xds)
543
543
 
544
- mandatory_dimensions = ["antenna_name", "time_cal", "receptor_label"]
544
+ mandatory_dimensions = ["antenna_name", "time_system_cal", "receptor_label"]
545
545
  if "frequency" not in generic_sys_cal_xds.sizes:
546
546
  dims_all = mandatory_dimensions
547
547
  else:
548
- dims_all = mandatory_dimensions + ["frequency_cal"]
548
+ dims_all = mandatory_dimensions + ["frequency_system_cal"]
549
549
 
550
550
  to_new_data_variables = {
551
- "PHASE_DIFF": ["PHASE_DIFFERENCE", ["antenna_name", "time_cal"]],
551
+ "PHASE_DIFF": ["PHASE_DIFFERENCE", ["antenna_name", "time_system_cal"]],
552
552
  "TCAL": ["TCAL", dims_all],
553
553
  "TCAL_SPECTRUM": ["TCAL", dims_all],
554
554
  "TRX": ["TRX", dims_all],
@@ -564,27 +564,26 @@ def create_system_calibration_xds(
564
564
  }
565
565
 
566
566
  to_new_coords = {
567
- "TIME": ["time_cal", ["time_cal"]],
567
+ "TIME": ["time_system_cal", ["time_system_cal"]],
568
568
  "receptor": ["receptor_label", ["receptor_label"]],
569
- "frequency": ["frequency_cal", ["frequency_cal"]],
569
+ "frequency": ["frequency_system_cal", ["frequency_system_cal"]],
570
570
  }
571
571
 
572
572
  sys_cal_xds = xr.Dataset(attrs={"type": "system_calibration"})
573
- coords = {
574
- "antenna_name": ant_xds_name_ids.sel(
575
- antenna_id=generic_sys_cal_xds["ANTENNA_ID"]
576
- ).data,
577
- "receptor_label": generic_sys_cal_xds.coords["receptor"].data,
573
+ ant_borrowed_coords = {
574
+ "antenna_name": ant_xds.coords["antenna_name"],
575
+ "receptor_label": ant_xds.coords["receptor_label"],
576
+ "polarization_type": ant_xds.coords["polarization_type"],
578
577
  }
579
- sys_cal_xds = sys_cal_xds.assign_coords(coords)
578
+ sys_cal_xds = sys_cal_xds.assign_coords(ant_borrowed_coords)
580
579
  sys_cal_xds = convert_generic_xds_to_xradio_schema(
581
580
  generic_sys_cal_xds, sys_cal_xds, to_new_data_variables, to_new_coords
582
581
  )
583
582
 
584
583
  # Add frequency coord and its measures data, if present
585
- if "frequency_cal" in dims_all:
584
+ if "frequency_system_cal" in dims_all:
586
585
  frequency_coord = {
587
- "frequency_cal": generic_sys_cal_xds.coords["frequency"].data
586
+ "frequency_system_cal": generic_sys_cal_xds.coords["frequency"].data
588
587
  }
589
588
  sys_cal_xds = sys_cal_xds.assign_coords(frequency_coord)
590
589
  frequency_measure = {
@@ -592,10 +591,10 @@ def create_system_calibration_xds(
592
591
  "units": main_xds_frequency.attrs["units"],
593
592
  "observer": main_xds_frequency.attrs["observer"],
594
593
  }
595
- sys_cal_xds.coords["frequency_cal"].attrs.update(frequency_measure)
594
+ sys_cal_xds.coords["frequency_system_cal"].attrs.update(frequency_measure)
596
595
 
597
596
  sys_cal_xds = rename_and_interpolate_to_time(
598
- sys_cal_xds, "time_cal", sys_cal_interp_time, "system_calibration_xds"
597
+ sys_cal_xds, "time_system_cal", sys_cal_interp_time, "system_calibration_xds"
599
598
  )
600
599
 
601
600
  # correct expected types
@@ -37,6 +37,7 @@ def load_processing_set(
37
37
  In memory representation of processing set (data is represented by Dask.arrays).
38
38
  """
39
39
  from xradio._utils.zarr.common import _open_dataset, _get_file_system_and_items
40
+ from xradio.measurement_set import MeasurementSetXds
40
41
 
41
42
  file_system, ms_store_list = _get_file_system_and_items(ps_store)
42
43
 
@@ -71,7 +72,7 @@ def load_processing_set(
71
72
  "field_and_source_xds"
72
73
  ] = field_and_source_xds_dict[data_group_name]
73
74
 
74
- ps[ms_name] = xds
75
+ ps[ms_name] = MeasurementSetXds(xds)
75
76
 
76
77
  return ps
77
78
 
@@ -110,7 +110,7 @@ class ProcessingSet(dict):
110
110
  "intents": [],
111
111
  "shape": [],
112
112
  "polarization": [],
113
- "scan_number": [],
113
+ "scan_name": [],
114
114
  "spw_name": [],
115
115
  "field_name": [],
116
116
  "source_name": [],
@@ -129,9 +129,7 @@ class ProcessingSet(dict):
129
129
  value.attrs["partition_info"]["spectral_window_name"]
130
130
  )
131
131
  summary_data["polarization"].append(value.polarization.values)
132
- summary_data["scan_number"].append(
133
- value.attrs["partition_info"]["scan_number"]
134
- )
132
+ summary_data["scan_name"].append(value.attrs["partition_info"]["scan_name"])
135
133
  data_name = value.attrs["data_groups"][data_group]["correlated_data"]
136
134
 
137
135
  if "VISIBILITY" in data_name:
@@ -156,16 +154,27 @@ class ProcessingSet(dict):
156
154
  )
157
155
  summary_data["end_frequency"].append(to_list(value["frequency"].values)[-1])
158
156
 
159
- if value[data_name].attrs["field_and_source_xds"].is_ephemeris:
157
+ if (
158
+ value[data_name].attrs["field_and_source_xds"].attrs["type"]
159
+ == "field_and_source_ephemeris"
160
+ ):
160
161
  summary_data["field_coords"].append("Ephemeris")
162
+ # elif (
163
+ # "time"
164
+ # in value[data_name].attrs["field_and_source_xds"][center_name].coords
165
+ # ):
161
166
  elif (
162
- "time"
163
- in value[data_name].attrs["field_and_source_xds"][center_name].coords
167
+ value[data_name]
168
+ .attrs["field_and_source_xds"][center_name]["field_name"]
169
+ .size
170
+ > 1
164
171
  ):
165
172
  summary_data["field_coords"].append("Multi-Phase-Center")
166
173
  else:
167
174
  ra_dec_rad = (
168
- value[data_name].attrs["field_and_source_xds"][center_name].values
175
+ value[data_name]
176
+ .attrs["field_and_source_xds"][center_name]
177
+ .values[0, :]
169
178
  )
170
179
  frame = (
171
180
  value[data_name]
@@ -383,10 +392,7 @@ class ProcessingSet(dict):
383
392
 
384
393
  def get_combined_field_and_source_xds(self, data_group="base"):
385
394
  """
386
- Combine the `field_and_source_xds` datasets from all Measurement Sets into a single dataset.
387
-
388
- The combined `xarray.Dataset` will have a new dimension 'field_name', consolidating data from
389
- each Measurement Set. Ephemeris data is handled separately.
395
+ Combine all non-ephemeris `field_and_source_xds` datasets from a Processing Set for a datagroup into a single dataset.
390
396
 
391
397
  Parameters
392
398
  ----------
@@ -395,22 +401,17 @@ class ProcessingSet(dict):
395
401
 
396
402
  Returns
397
403
  -------
398
- tuple of xarray.Dataset
399
- A tuple containing two `xarray.Dataset` objects:
400
- - combined_field_and_source_xds: Combined dataset for standard fields.
401
- - combined_ephemeris_field_and_source_xds: Combined dataset for ephemeris fields.
404
+ xarray.Dataset
405
+ combined_field_and_source_xds: Combined dataset for standard fields.
402
406
 
403
407
  Raises
404
408
  ------
405
409
  ValueError
406
410
  If the `field_and_source_xds` attribute is missing or improperly formatted in any Measurement Set.
407
411
  """
408
- df = self.summary(data_group)
409
412
 
410
413
  combined_field_and_source_xds = xr.Dataset()
411
- combined_ephemeris_field_and_source_xds = xr.Dataset()
412
414
  for ms_name, ms_xds in self.items():
413
-
414
415
  correlated_data_name = ms_xds.attrs["data_groups"][data_group][
415
416
  "correlated_data"
416
417
  ]
@@ -421,75 +422,18 @@ class ProcessingSet(dict):
421
422
  .copy(deep=True)
422
423
  )
423
424
 
424
- if (
425
- "line_name" in field_and_source_xds.coords
426
- ): # Not including line info since it is a function of spw.
427
- field_and_source_xds = field_and_source_xds.drop_vars(
428
- ["LINE_REST_FREQUENCY", "LINE_SYSTEMIC_VELOCITY"], errors="ignore"
429
- )
430
- del field_and_source_xds["line_name"]
431
- del field_and_source_xds["line_label"]
432
-
433
- if "time" in field_and_source_xds.coords:
434
- if "time" not in field_and_source_xds.field_name.dims:
435
- field_names = np.array(
436
- [field_and_source_xds.field_name.values.item()]
437
- * len(field_and_source_xds.time.values)
438
- )
439
- source_names = np.array(
440
- [field_and_source_xds.source_name.values.item()]
441
- * len(field_and_source_xds.time.values)
442
- )
443
- del field_and_source_xds["field_name"]
444
- del field_and_source_xds["source_name"]
445
- field_and_source_xds = field_and_source_xds.assign_coords(
446
- field_name=("time", field_names)
447
- )
448
- field_and_source_xds = field_and_source_xds.assign_coords(
449
- source_name=("time", source_names)
450
- )
451
- field_and_source_xds = field_and_source_xds.swap_dims(
452
- {"time": "field_name"}
453
- )
454
- del field_and_source_xds["time"]
455
- elif "time_ephemeris" in field_and_source_xds.coords:
456
- if "time_ephemeris" not in field_and_source_xds.field_name.dims:
457
- field_names = np.array(
458
- [field_and_source_xds.field_name.values.item()]
459
- * len(field_and_source_xds.time_ephemeris.values)
460
- )
461
- source_names = np.array(
462
- [field_and_source_xds.source_name.values.item()]
463
- * len(field_and_source_xds.time_ephemeris.values)
464
- )
465
- del field_and_source_xds["field_name"]
466
- del field_and_source_xds["source_name"]
467
- field_and_source_xds = field_and_source_xds.assign_coords(
468
- field_name=("time_ephemeris", field_names)
469
- )
470
- field_and_source_xds = field_and_source_xds.assign_coords(
471
- source_name=("time_ephemeris", source_names)
472
- )
473
- field_and_source_xds = field_and_source_xds.swap_dims(
474
- {"time_ephemeris": "field_name"}
475
- )
476
- del field_and_source_xds["time_ephemeris"]
477
- else:
478
- for dv_names in field_and_source_xds.data_vars:
479
- if "field_name" not in field_and_source_xds[dv_names].dims:
480
- field_and_source_xds[dv_names] = field_and_source_xds[
481
- dv_names
482
- ].expand_dims("field_name")
425
+ if not field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
483
426
 
484
- if field_and_source_xds.is_ephemeris:
485
- if len(combined_ephemeris_field_and_source_xds.data_vars) == 0:
486
- combined_ephemeris_field_and_source_xds = field_and_source_xds
487
- else:
488
- combined_ephemeris_field_and_source_xds = xr.concat(
489
- [combined_ephemeris_field_and_source_xds, field_and_source_xds],
490
- dim="field_name",
427
+ if (
428
+ "line_name" in field_and_source_xds.coords
429
+ ): # Not including line info since it is a function of spw.
430
+ field_and_source_xds = field_and_source_xds.drop_vars(
431
+ ["LINE_REST_FREQUENCY", "LINE_SYSTEMIC_VELOCITY"],
432
+ errors="ignore",
491
433
  )
492
- else:
434
+ del field_and_source_xds["line_name"]
435
+ del field_and_source_xds["line_label"]
436
+
493
437
  if len(combined_field_and_source_xds.data_vars) == 0:
494
438
  combined_field_and_source_xds = field_and_source_xds
495
439
  else:
@@ -541,12 +485,86 @@ class ProcessingSet(dict):
541
485
  combined_field_and_source_xds.field_name[min_index].values
542
486
  )
543
487
 
488
+ return combined_field_and_source_xds
489
+
490
+ def get_combined_field_and_source_xds_ephemeris(self, data_group="base"):
491
+ """
492
+ Combine all ephemeris `field_and_source_xds` datasets from a Processing Set for a datagroup into a single dataset.
493
+
494
+ Parameters
495
+ ----------
496
+ data_group : str, optional
497
+ The data group to process. Default is "base".
498
+
499
+ Returns
500
+ -------
501
+ xarray.Dataset
502
+ - combined_ephemeris_field_and_source_xds: Combined dataset for ephemeris fields.
503
+
504
+ Raises
505
+ ------
506
+ ValueError
507
+ If the `field_and_source_xds` attribute is missing or improperly formatted in any Measurement Set.
508
+ """
509
+
510
+ combined_ephemeris_field_and_source_xds = xr.Dataset()
511
+ for ms_name, ms_xds in self.items():
512
+
513
+ correlated_data_name = ms_xds.attrs["data_groups"][data_group][
514
+ "correlated_data"
515
+ ]
516
+
517
+ field_and_source_xds = (
518
+ ms_xds[correlated_data_name]
519
+ .attrs["field_and_source_xds"]
520
+ .copy(deep=True)
521
+ )
522
+
523
+ if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
524
+
525
+ if (
526
+ "line_name" in field_and_source_xds.coords
527
+ ): # Not including line info since it is a function of spw.
528
+ field_and_source_xds = field_and_source_xds.drop_vars(
529
+ ["LINE_REST_FREQUENCY", "LINE_SYSTEMIC_VELOCITY"],
530
+ errors="ignore",
531
+ )
532
+ del field_and_source_xds["line_name"]
533
+ del field_and_source_xds["line_label"]
534
+
535
+ from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
536
+ interpolate_to_time,
537
+ )
538
+
539
+ if "time_ephemeris" in field_and_source_xds:
540
+ field_and_source_xds = interpolate_to_time(
541
+ field_and_source_xds,
542
+ field_and_source_xds.time,
543
+ "field_and_source_xds",
544
+ "time_ephemeris",
545
+ )
546
+ del field_and_source_xds["time_ephemeris"]
547
+ field_and_source_xds = field_and_source_xds.rename(
548
+ {"time_ephemeris": "time"}
549
+ )
550
+
551
+ if "OBSERVER_POSITION" in field_and_source_xds:
552
+ field_and_source_xds = field_and_source_xds.drop_vars(
553
+ ["OBSERVER_POSITION"], errors="ignore"
554
+ )
555
+
556
+ if len(combined_ephemeris_field_and_source_xds.data_vars) == 0:
557
+ combined_ephemeris_field_and_source_xds = field_and_source_xds
558
+ else:
559
+
560
+ combined_ephemeris_field_and_source_xds = xr.concat(
561
+ [combined_ephemeris_field_and_source_xds, field_and_source_xds],
562
+ dim="time",
563
+ )
564
+
544
565
  if (len(combined_ephemeris_field_and_source_xds.data_vars) > 0) and (
545
566
  "FIELD_PHASE_CENTER" in combined_ephemeris_field_and_source_xds
546
567
  ):
547
- combined_ephemeris_field_and_source_xds = (
548
- combined_ephemeris_field_and_source_xds.drop_duplicates("field_name")
549
- )
550
568
 
551
569
  from xradio._utils.coord_math import wrap_to_pi
552
570
 
@@ -556,7 +574,7 @@ class ProcessingSet(dict):
556
574
  )
557
575
  combined_ephemeris_field_and_source_xds["FIELD_OFFSET"] = xr.DataArray(
558
576
  wrap_to_pi(offset.sel(sky_pos_label=["ra", "dec"])).values,
559
- dims=["field_name", "sky_dir_label"],
577
+ dims=["time", "sky_dir_label"],
560
578
  )
561
579
  combined_ephemeris_field_and_source_xds["FIELD_OFFSET"].attrs = (
562
580
  combined_ephemeris_field_and_source_xds["FIELD_PHASE_CENTER"].attrs
@@ -589,7 +607,7 @@ class ProcessingSet(dict):
589
607
  combined_ephemeris_field_and_source_xds.field_name[min_index].values
590
608
  )
591
609
 
592
- return combined_field_and_source_xds, combined_ephemeris_field_and_source_xds
610
+ return combined_ephemeris_field_and_source_xds
593
611
 
594
612
  def plot_phase_centers(self, label_all_fields=False, data_group="base"):
595
613
  """
@@ -615,8 +633,11 @@ class ProcessingSet(dict):
615
633
  ValueError
616
634
  If the combined datasets are empty or improperly formatted.
617
635
  """
618
- combined_field_and_source_xds, combined_ephemeris_field_and_source_xds = (
619
- self.get_combined_field_and_source_xds(data_group)
636
+ combined_field_and_source_xds = self.get_combined_field_and_source_xds(
637
+ data_group
638
+ )
639
+ combined_ephemeris_field_and_source_xds = (
640
+ self.get_combined_field_and_source_xds_ephemeris(data_group)
620
641
  )
621
642
  from matplotlib import pyplot as plt
622
643
 
@@ -669,6 +690,11 @@ class ProcessingSet(dict):
669
690
  center_field_name = combined_ephemeris_field_and_source_xds.attrs[
670
691
  "center_field_name"
671
692
  ]
693
+
694
+ combined_ephemeris_field_and_source_xds = (
695
+ combined_ephemeris_field_and_source_xds.set_xindex("field_name")
696
+ )
697
+
672
698
  center_field = combined_ephemeris_field_and_source_xds.sel(
673
699
  field_name=center_field_name
674
700
  )