xradio 0.0.33__py3-none-any.whl → 0.0.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,17 +4,22 @@ from .._zarr.encoding import add_encoding
4
4
  from typing import Dict, Union
5
5
  import graphviper.utils.logger as logger
6
6
  import os
7
+ import pathlib
7
8
 
8
9
  import numpy as np
9
10
  import xarray as xr
10
11
 
11
12
  from casacore import tables
12
- from .msv4_sub_xdss import create_ant_xds, create_pointing_xds, create_weather_xds
13
+ from xradio.vis._vis_utils._ms.msv4_sub_xdss import (
14
+ create_pointing_xds,
15
+ create_weather_xds,
16
+ )
17
+ from xradio.vis._vis_utils._ms.create_antenna_xds import create_antenna_xds
13
18
  from xradio.vis._vis_utils._ms.create_field_and_source_xds import (
14
19
  create_field_and_source_xds,
15
20
  )
21
+ from xradio._utils.schema import column_description_casacore_to_msv4_measure
16
22
  from .msv2_to_msv4_meta import (
17
- column_description_casacore_to_msv4_measure,
18
23
  create_attribute_metadata,
19
24
  col_to_data_variable_names,
20
25
  col_dims,
@@ -400,7 +405,7 @@ def calc_indx_for_row_split(tb_tool, taql_where):
400
405
 
401
406
 
402
407
  def create_coordinates(
403
- xds, in_file, ddi, utime, interval, baseline_ant1_id, baseline_ant2_id
408
+ xds, in_file, ddi, utime, interval, baseline_ant1_id, baseline_ant2_id, scan_id
404
409
  ):
405
410
  coords = {
406
411
  "time": utime,
@@ -408,6 +413,7 @@ def create_coordinates(
408
413
  "baseline_antenna2_id": ("baseline_id", baseline_ant2_id),
409
414
  "uvw_label": ["u", "v", "w"],
410
415
  "baseline_id": np.arange(len(baseline_ant1_id)),
416
+ "scan_number": ("time", scan_id),
411
417
  }
412
418
 
413
419
  ddi_xds = load_generic_table(in_file, "DATA_DESCRIPTION").sel(row=ddi)
@@ -446,12 +452,12 @@ def create_coordinates(
446
452
  )
447
453
  xds.frequency.attrs.update(msv4_measure)
448
454
 
449
- if (spectral_window_xds.NAME.values.item() is None) or (
450
- spectral_window_xds.NAME.values.item() == "none"
451
- ):
455
+ spw_name = spectral_window_xds.NAME.values.item()
456
+ if (spw_name is None) or (spw_name == "none") or (spw_name == ""):
452
457
  spw_name = "spw_" + str(spectral_window_id)
453
458
  else:
454
- spw_name = spectral_window_xds.NAME.values.item()
459
+ # spw_name = spectral_window_xds.NAME.values.item()
460
+ spw_name = spw_name + "_" + str(spectral_window_id)
455
461
 
456
462
  xds.frequency.attrs["spectral_window_name"] = spw_name
457
463
  msv4_measure = column_description_casacore_to_msv4_measure(
@@ -558,21 +564,16 @@ def create_data_variables(
558
564
  try:
559
565
  start = time.time()
560
566
  if col == "WEIGHT":
561
- xds[col_to_data_variable_names[col]] = xr.DataArray(
562
- np.tile(
563
- read_col_conversion(
564
- tb_tool,
565
- col,
566
- time_baseline_shape,
567
- tidxs,
568
- bidxs,
569
- use_table_iter,
570
- )[:, :, None, :],
571
- (1, 1, xds.sizes["frequency"], 1),
572
- ),
573
- dims=col_dims[col],
567
+ xds = get_weight(
568
+ xds,
569
+ col,
570
+ tb_tool,
571
+ time_baseline_shape,
572
+ tidxs,
573
+ bidxs,
574
+ use_table_iter,
575
+ main_column_descriptions,
574
576
  )
575
-
576
577
  else:
577
578
  xds[col_to_data_variable_names[col]] = xr.DataArray(
578
579
  read_col_conversion(
@@ -585,20 +586,73 @@ def create_data_variables(
585
586
  ),
586
587
  dims=col_dims[col],
587
588
  )
588
- logger.debug(
589
- "Time to read column "
590
- + str(col)
591
- + " : "
592
- + str(time.time() - start)
589
+
590
+ xds[col_to_data_variable_names[col]].attrs.update(
591
+ create_attribute_metadata(col, main_column_descriptions)
592
+ )
593
+
594
+ logger.debug(
595
+ "Time to read column " + str(col) + " : " + str(time.time() - start)
596
+ )
597
+ except Exception as e:
598
+ logger.debug("Could not load column", col)
599
+
600
+ if ("WEIGHT_SPECTRUM" == col) and (
601
+ "WEIGHT" in col_names
602
+ ): # Bogus WEIGHT_SPECTRUM column, need to use WEIGHT.
603
+ xds = get_weight(
604
+ xds,
605
+ "WEIGHT",
606
+ tb_tool,
607
+ time_baseline_shape,
608
+ tidxs,
609
+ bidxs,
610
+ use_table_iter,
611
+ main_column_descriptions,
593
612
  )
594
- except:
595
- # logger.debug("Could not load column",col)
596
- # print("Could not load column", col)
597
- continue
598
613
 
599
- xds[col_to_data_variable_names[col]].attrs.update(
600
- create_attribute_metadata(col, main_column_descriptions)
601
- )
614
+
615
def add_missing_data_var_attrs(xds):
    """Fill in expected metadata attributes that the input MSv2 cannot supply.

    At present this only covers single-dish data. The ``SPECTRUM`` and
    ``SPECTRUM_CORRECTED`` data variables, whichever are present, get their
    ``units`` attribute set to ``["Jy"]``. Variables that are absent are
    skipped.

    Parameters
    ----------
    xds : xarray.Dataset
        Dataset to annotate. It is modified in place.

    Returns
    -------
    xarray.Dataset
        The same ``xds`` object, for call chaining.
    """
    single_dish_vars = ("SPECTRUM", "SPECTRUM_CORRECTED")
    present = (name for name in single_dish_vars if name in xds.data_vars)
    for name in present:
        xds.data_vars[name].attrs["units"] = ["Jy"]
    return xds
625
+
626
+
627
def get_weight(
    xds,
    col,
    tb_tool,
    time_baseline_shape,
    tidxs,
    bidxs,
    use_table_iter,
    main_column_descriptions,
):
    """Read a weight column and repeat it across the frequency axis.

    The column is read with ``read_col_conversion``. A new axis is inserted
    at position 2 and the values are tiled ``xds.sizes["frequency"]`` times
    along it. The result is stored in ``xds`` under
    ``col_to_data_variable_names[col]`` with dims ``col_dims[col]``, and
    decorated with the attributes from ``create_attribute_metadata``.

    NOTE(review): presumably the column read is (time, baseline, polarization)
    and the stored array is (time, baseline, frequency, polarization). The
    indexing ``[:, :, None, :]`` implies this but does not guarantee it;
    confirm against ``col_dims``.

    Parameters
    ----------
    xds : xarray.Dataset
        Target dataset. It must have a ``frequency`` dimension and is
        modified in place.
    col : str
        MSv2 column name to read (e.g. ``"WEIGHT"``).
    tb_tool, time_baseline_shape, tidxs, bidxs, use_table_iter
        Passed through unchanged to ``read_col_conversion``.
    main_column_descriptions : dict
        Column descriptions passed to ``create_attribute_metadata``.

    Returns
    -------
    xarray.Dataset
        The same ``xds`` object with the weight variable added.
    """
    raw_weight = read_col_conversion(
        tb_tool, col, time_baseline_shape, tidxs, bidxs, use_table_iter
    )
    n_channels = xds.sizes["frequency"]
    # Insert a frequency axis and replicate the weights once per channel.
    per_channel = np.tile(raw_weight[:, :, np.newaxis, :], (1, 1, n_channels, 1))
    weight_da = xr.DataArray(per_channel, dims=col_dims[col])

    var_name = col_to_data_variable_names[col]
    xds[var_name] = weight_da
    xds[var_name].attrs.update(
        create_attribute_metadata(col, main_column_descriptions)
    )
    return xds
602
656
 
603
657
 
604
658
  def create_taql_query(partition_info):
@@ -626,7 +680,7 @@ def create_taql_query(partition_info):
626
680
  def convert_and_write_partition(
627
681
  in_file: str,
628
682
  out_file: str,
629
- ms_v4_id: int,
683
+ ms_v4_id: Union[int, str],
630
684
  partition_info: Dict,
631
685
  use_table_iter: bool,
632
686
  partition_scheme: str = "ddi_intent_field",
@@ -635,6 +689,7 @@ def convert_and_write_partition(
635
689
  pointing_chunksize: Union[Dict, float, None] = None,
636
690
  pointing_interpolate: bool = False,
637
691
  ephemeris_interpolate: bool = False,
692
+ phase_cal_interpolate: bool = False,
638
693
  compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2),
639
694
  storage_backend="zarr",
640
695
  overwrite: bool = False,
@@ -739,8 +794,19 @@ def convert_and_write_partition(
739
794
  else:
740
795
  interval = interval_unique[0]
741
796
 
797
+ scan_id = np.full(time_baseline_shape, -42, dtype=int)
798
+ scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER")
799
+ scan_id = np.max(scan_id, axis=1)
800
+
742
801
  xds = create_coordinates(
743
- xds, in_file, ddi, utime, interval, baseline_ant1_id, baseline_ant2_id
802
+ xds,
803
+ in_file,
804
+ ddi,
805
+ utime,
806
+ interval,
807
+ baseline_ant1_id,
808
+ baseline_ant2_id,
809
+ scan_id,
744
810
  )
745
811
  logger.debug("Time create coordinates " + str(time.time() - start))
746
812
 
@@ -756,16 +822,30 @@ def convert_and_write_partition(
756
822
  use_table_iter,
757
823
  )
758
824
 
825
+ # Add data_groups
826
+ xds, is_single_dish = add_data_groups(xds)
827
+
828
+ xds = add_missing_data_var_attrs(xds)
829
+
759
830
  if (
760
831
  "WEIGHT" not in xds.data_vars
761
832
  ): # Some single dish datasets don't have WEIGHT.
762
- xds["WEIGHT"] = xr.DataArray(
763
- np.ones(xds.SPECTRUM.shape, dtype=np.float64),
764
- dims=xds.SPECTRUM.dims,
765
- )
833
+ if is_single_dish:
834
+ xds["WEIGHT"] = xr.DataArray(
835
+ np.ones(xds.SPECTRUM.shape, dtype=np.float64),
836
+ dims=xds.SPECTRUM.dims,
837
+ )
838
+ else:
839
+ xds["WEIGHT"] = xr.DataArray(
840
+ np.ones(xds.VISIBILITY.shape, dtype=np.float64),
841
+ dims=xds.VISIBILITY.dims,
842
+ )
766
843
 
767
844
  logger.debug("Time create data variables " + str(time.time() - start))
768
845
 
846
+ # To constrain the time range to load (in pointing, ephemerides, phase_cal data_vars)
847
+ time_min_max = find_min_max_times(tb_tool, taql_where)
848
+
769
849
  # Create ant_xds
770
850
  start = time.time()
771
851
  feed_id = unique_1d(
@@ -781,17 +861,26 @@ def convert_and_write_partition(
781
861
  [xds["baseline_antenna1_id"].data, xds["baseline_antenna2_id"].data]
782
862
  )
783
863
  )
864
+ if phase_cal_interpolate:
865
+ phase_cal_interp_time = xds.time.values
866
+ else:
867
+ phase_cal_interp_time = None
784
868
 
785
- ant_xds = create_ant_xds(
869
+ ant_xds = create_antenna_xds(
786
870
  in_file,
787
871
  xds.frequency.attrs["spectral_window_id"],
788
872
  antenna_id,
789
873
  feed_id,
790
874
  telescope_name,
875
+ time_min_max,
876
+ phase_cal_interp_time,
791
877
  )
792
878
 
793
879
  # Change antenna_ids to antenna_names
794
880
  xds = antenna_ids_to_names(xds, ant_xds)
881
+ ant_xds = ant_xds.drop_vars(
882
+ "antenna_id"
883
+ ) # No longer needed after converting to name.
795
884
 
796
885
  logger.debug("Time ant xds " + str(time.time() - start))
797
886
 
@@ -800,9 +889,7 @@ def convert_and_write_partition(
800
889
  weather_xds = create_weather_xds(in_file)
801
890
  logger.debug("Time weather " + str(time.time() - start))
802
891
 
803
- # To constrain the time range to load (in pointing, ephemerides data_vars)
804
- time_min_max = find_min_max_times(tb_tool, taql_where)
805
-
892
+ # Create pointing_xds
806
893
  if with_pointing:
807
894
  start = time.time()
808
895
  if pointing_interpolate:
@@ -824,6 +911,7 @@ def convert_and_write_partition(
824
911
  )
825
912
 
826
913
  start = time.time()
914
+ xds.attrs["type"] = "visibility"
827
915
 
828
916
  # Time and frequency should always be increasing
829
917
  if len(xds.frequency) > 1 and xds.frequency[1] - xds.frequency[0] < 0:
@@ -832,9 +920,6 @@ def convert_and_write_partition(
832
920
  if len(xds.time) > 1 and xds.time[1] - xds.time[0] < 0:
833
921
  xds = xds.sel(time=slice(None, None, -1))
834
922
 
835
- # Add data_groups and field_info
836
- xds, is_single_dish = add_data_groups(xds)
837
-
838
923
  # Create field_and_source_xds (combines field, source and ephemeris data into one super dataset)
839
924
  start = time.time()
840
925
  if ephemeris_interpolate:
@@ -842,10 +927,6 @@ def convert_and_write_partition(
842
927
  else:
843
928
  ephemeris_interp_time = None
844
929
 
845
- scan_id = np.full(time_baseline_shape, -42, dtype=int)
846
- scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER")
847
- scan_id = np.max(scan_id, axis=1)
848
-
849
930
  if "FIELD_ID" not in partition_scheme:
850
931
  field_id = np.full(time_baseline_shape, -42, dtype=int)
851
932
  field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID")
@@ -859,7 +940,7 @@ def convert_and_write_partition(
859
940
  # assert len(col_unique) == 1, col_name + " is not consistent."
860
941
  # return col_unique[0]
861
942
 
862
- field_and_source_xds, source_id = create_field_and_source_xds(
943
+ field_and_source_xds, source_id, num_lines = create_field_and_source_xds(
863
944
  in_file,
864
945
  field_id,
865
946
  xds.frequency.attrs["spectral_window_id"],
@@ -893,11 +974,20 @@ def convert_and_write_partition(
893
974
 
894
975
  file_name = os.path.join(
895
976
  out_file,
896
- out_file.replace(".vis.zarr", "").replace(".zarr", "").split("/")[-1]
977
+ pathlib.Path(out_file)
978
+ .name.replace(".vis.zarr", "")
979
+ .replace(".zarr", "")
897
980
  + "_"
898
981
  + str(ms_v4_id),
899
982
  )
900
983
 
984
+ if "line_name" in field_and_source_xds.coords:
985
+ line_name = to_list(
986
+ unique_1d(np.ravel(field_and_source_xds.line_name.values))
987
+ )
988
+ else:
989
+ line_name = []
990
+
901
991
  xds.attrs["partition_info"] = {
902
992
  # "spectral_window_id": xds.frequency.attrs["spectral_window_id"],
903
993
  "spectral_window_name": xds.frequency.attrs["spectral_window_name"],
@@ -906,11 +996,14 @@ def convert_and_write_partition(
906
996
  np.unique(field_and_source_xds.field_name.values)
907
997
  ),
908
998
  # "source_id": to_list(unique_1d(source_id)),
999
+ "line_name": line_name,
1000
+ "scan_number": to_list(np.unique(scan_id)),
909
1001
  "source_name": to_list(
910
1002
  np.unique(field_and_source_xds.source_name.values)
911
1003
  ),
912
1004
  "polarization_setup": to_list(xds.polarization.values),
913
- "obs_mode": obs_mode,
1005
+ "num_lines": num_lines,
1006
+ "obs_mode": obs_mode.split(","),
914
1007
  "taql": taql_where,
915
1008
  }
916
1009
 
@@ -926,7 +1019,7 @@ def convert_and_write_partition(
926
1019
  mode=mode,
927
1020
  )
928
1021
 
929
- if with_pointing:
1022
+ if with_pointing and len(pointing_xds.data_vars) > 1:
930
1023
  pointing_xds.to_zarr(
931
1024
  store=os.path.join(file_name, "POINTING"), mode=mode
932
1025
  )
@@ -945,69 +1038,26 @@ def convert_and_write_partition(
945
1038
 
946
1039
 
947
1040
def antenna_ids_to_names(xds, ant_xds):
    """Replace the antenna-id variables of ``xds`` with antenna names.

    Names are looked up in ``ant_xds["antenna_name"]`` by matching
    ``antenna_id``.

    - Interferometer data (``baseline_antenna1_id`` present in ``xds``): both
      ``baseline_antenna1_id`` and ``baseline_antenna2_id`` are converted and
      renamed to ``baseline_antenna1_name`` / ``baseline_antenna2_name``.
    - Otherwise (single dish): ``antenna_id`` is converted and renamed to
      ``antenna_name``.

    Only the looked-up name values (with their dims and attrs) are written
    into ``xds``. Coordinates attached to the ``.sel`` result, such as the
    ``antenna_id`` coordinate of ``ant_xds``, are therefore not merged into
    ``xds`` as stray coordinates.

    Parameters
    ----------
    xds : xarray.Dataset
        Main dataset. Its id variables are overwritten in place before the
        renamed dataset is returned.
    ant_xds : xarray.Dataset
        Antenna dataset with ``antenna_name`` and ``antenna_id`` variables.

    Returns
    -------
    xarray.Dataset
        ``xds`` with the antenna ids replaced by names and the variables renamed.
    """
    ant_xds = ant_xds.set_xindex(
        "antenna_id"
    )  # Allows for non-dimension coordinate selection.

    def _lookup_names(ids):
        # Look up the names for ``ids``, then rebuild a plain DataArray from
        # the values. The raw .sel() result carries ant_xds coordinates
        # (e.g. antenna_id), which dataset assignment would merge into xds.
        names = ant_xds["antenna_name"].sel(antenna_id=ids)
        return xr.DataArray(names.values, dims=names.dims, attrs=names.attrs)

    if "baseline_antenna1_id" in xds:  # Interferometer
        xds["baseline_antenna1_id"] = _lookup_names(xds["baseline_antenna1_id"])
        xds["baseline_antenna2_id"] = _lookup_names(xds["baseline_antenna2_id"])
        xds = xds.rename(
            {
                "baseline_antenna1_id": "baseline_antenna1_name",
                "baseline_antenna2_id": "baseline_antenna2_name",
            }
        )
    else:  # Single Dish
        xds["antenna_id"] = _lookup_names(xds["antenna_id"])
        xds = xds.rename({"antenna_id": "antenna_name"})

    return xds
1013
1063