xradio 0.0.34__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,17 +4,22 @@ from .._zarr.encoding import add_encoding
4
4
  from typing import Dict, Union
5
5
  import graphviper.utils.logger as logger
6
6
  import os
7
+ import pathlib
7
8
 
8
9
  import numpy as np
9
10
  import xarray as xr
10
11
 
11
12
  from casacore import tables
12
- from .msv4_sub_xdss import create_ant_xds, create_pointing_xds, create_weather_xds
13
+ from xradio.vis._vis_utils._ms.msv4_sub_xdss import (
14
+ create_pointing_xds,
15
+ create_weather_xds,
16
+ )
17
+ from xradio.vis._vis_utils._ms.create_antenna_xds import create_antenna_xds
13
18
  from xradio.vis._vis_utils._ms.create_field_and_source_xds import (
14
19
  create_field_and_source_xds,
15
20
  )
21
+ from xradio._utils.schema import column_description_casacore_to_msv4_measure
16
22
  from .msv2_to_msv4_meta import (
17
- column_description_casacore_to_msv4_measure,
18
23
  create_attribute_metadata,
19
24
  col_to_data_variable_names,
20
25
  col_dims,
@@ -75,7 +80,7 @@ def check_chunksize(chunksize: dict, xds_type: str) -> None:
75
80
  allowed_dims = [
76
81
  "time",
77
82
  "baseline_id",
78
- "antenna_id",
83
+ "antenna_name",
79
84
  "frequency",
80
85
  "polarization",
81
86
  ]
@@ -133,7 +138,7 @@ def mem_chunksize_to_dict_main(chunksize: float, xds: xr.Dataset) -> Dict[str, i
133
138
  It presently relies on the logic of mem_chunksize_to_dict_main_balanced() to find a
134
139
  balanced list of dimension sizes for the chunks
135
140
 
136
- Assumes these relevant dims: (time, antenna_id/baseline_id, frequency,
141
+ Assumes these relevant dims: (time, antenna_name/baseline_id, frequency,
137
142
  polarization).
138
143
  """
139
144
 
@@ -144,11 +149,11 @@ def mem_chunksize_to_dict_main(chunksize: float, xds: xr.Dataset) -> Dict[str, i
144
149
  "Cannot calculate chunk sizes when memory bound ({chunksize}) does not even allow all polarizations in one chunk"
145
150
  )
146
151
 
147
- baseline_or_antenna_id = find_baseline_or_antenna_var(xds)
148
- total_size = calc_used_gb(xds.sizes, baseline_or_antenna_id, sizeof_vis)
152
+ baseline_or_antenna_name = find_baseline_or_antenna_var(xds)
153
+ total_size = calc_used_gb(xds.sizes, baseline_or_antenna_name, sizeof_vis)
149
154
 
150
155
  ratio = chunksize / total_size
151
- chunked_dims = ["time", baseline_or_antenna_id, "frequency", "polarization"]
156
+ chunked_dims = ["time", baseline_or_antenna_name, "frequency", "polarization"]
152
157
  if ratio >= 1:
153
158
  result = {dim: xds.sizes[dim] for dim in chunked_dims}
154
159
  logger.debug(
@@ -157,14 +162,17 @@ def mem_chunksize_to_dict_main(chunksize: float, xds: xr.Dataset) -> Dict[str, i
157
162
  else:
158
163
  xds_dim_sizes = {k: xds.sizes[k] for k in chunked_dims}
159
164
  result = mem_chunksize_to_dict_main_balanced(
160
- chunksize, xds_dim_sizes, baseline_or_antenna_id, sizeof_vis
165
+ chunksize, xds_dim_sizes, baseline_or_antenna_name, sizeof_vis
161
166
  )
162
167
 
163
168
  return result
164
169
 
165
170
 
166
171
  def mem_chunksize_to_dict_main_balanced(
167
- chunksize: float, xds_dim_sizes: dict, baseline_or_antenna_id: str, sizeof_vis: int
172
+ chunksize: float,
173
+ xds_dim_sizes: dict,
174
+ baseline_or_antenna_name: str,
175
+ sizeof_vis: int,
168
176
  ) -> Dict[str, int]:
169
177
  """
170
178
  Assumes the ratio is <1 and all pols can fit in memory (from
@@ -231,7 +239,7 @@ def mem_chunksize_to_dict_main_balanced(
231
239
  dim_chunksizes[idx] += int_delta
232
240
  used = np.prod(dim_chunksizes) * sizeof_vis / GiBYTES_TO_BYTES
233
241
 
234
- chunked_dim_names = ["time", baseline_or_antenna_id, "frequency", "polarization"]
242
+ chunked_dim_names = ["time", baseline_or_antenna_name, "frequency", "polarization"]
235
243
  dim_chunksizes_int = [int(v) for v in dim_chunksizes]
236
244
  result = dict(zip(chunked_dim_names, dim_chunksizes_int))
237
245
 
@@ -309,11 +317,11 @@ def mem_chunksize_to_dict_pointing(chunksize: float, xds: xr.Dataset) -> Dict[st
309
317
 
310
318
  def find_baseline_or_antenna_var(xds: xr.Dataset) -> str:
311
319
  if "baseline_id" in xds.coords:
312
- baseline_or_antenna_id = "baseline_id"
313
- elif "antenna_id" in xds.coords:
314
- baseline_or_antenna_id = "antenna_id"
320
+ baseline_or_antenna_name = "baseline_id"
321
+ elif "antenna_name" in xds.coords:
322
+ baseline_or_antenna_name = "antenna_name"
315
323
 
316
- return baseline_or_antenna_id
324
+ return baseline_or_antenna_name
317
325
 
318
326
 
319
327
  def itemsize_vis_spec(xds: xr.Dataset) -> int:
@@ -347,11 +355,11 @@ def itemsize_pointing_spec(xds: xr.Dataset) -> int:
347
355
 
348
356
 
349
357
  def calc_used_gb(
350
- chunksizes: dict, baseline_or_antenna_id: str, sizeof_vis: int
358
+ chunksizes: dict, baseline_or_antenna_name: str, sizeof_vis: int
351
359
  ) -> float:
352
360
  return (
353
361
  chunksizes["time"]
354
- * chunksizes[baseline_or_antenna_id]
362
+ * chunksizes[baseline_or_antenna_name]
355
363
  * chunksizes["frequency"]
356
364
  * chunksizes["polarization"]
357
365
  * sizeof_vis
@@ -400,7 +408,7 @@ def calc_indx_for_row_split(tb_tool, taql_where):
400
408
 
401
409
 
402
410
  def create_coordinates(
403
- xds, in_file, ddi, utime, interval, baseline_ant1_id, baseline_ant2_id
411
+ xds, in_file, ddi, utime, interval, baseline_ant1_id, baseline_ant2_id, scan_id
404
412
  ):
405
413
  coords = {
406
414
  "time": utime,
@@ -408,6 +416,7 @@ def create_coordinates(
408
416
  "baseline_antenna2_id": ("baseline_id", baseline_ant2_id),
409
417
  "uvw_label": ["u", "v", "w"],
410
418
  "baseline_id": np.arange(len(baseline_ant1_id)),
419
+ "scan_number": ("time", scan_id),
411
420
  }
412
421
 
413
422
  ddi_xds = load_generic_table(in_file, "DATA_DESCRIPTION").sel(row=ddi)
@@ -446,12 +455,12 @@ def create_coordinates(
446
455
  )
447
456
  xds.frequency.attrs.update(msv4_measure)
448
457
 
449
- if (spectral_window_xds.NAME.values.item() is None) or (
450
- spectral_window_xds.NAME.values.item() == "none"
451
- ):
458
+ spw_name = spectral_window_xds.NAME.values.item()
459
+ if (spw_name is None) or (spw_name == "none") or (spw_name == ""):
452
460
  spw_name = "spw_" + str(spectral_window_id)
453
461
  else:
454
- spw_name = spectral_window_xds.NAME.values.item()
462
+ # spw_name = spectral_window_xds.NAME.values.item()
463
+ spw_name = spw_name + "_" + str(spectral_window_id)
455
464
 
456
465
  xds.frequency.attrs["spectral_window_name"] = spw_name
457
466
  msv4_measure = column_description_casacore_to_msv4_measure(
@@ -588,7 +597,7 @@ def create_data_variables(
588
597
  logger.debug(
589
598
  "Time to read column " + str(col) + " : " + str(time.time() - start)
590
599
  )
591
- except:
600
+ except Exception as e:
592
601
  logger.debug("Could not load column", col)
593
602
 
594
603
  if ("WEIGHT_SPECTRUM" == col) and (
@@ -606,6 +615,18 @@ def create_data_variables(
606
615
  )
607
616
 
608
617
 
618
+ def add_missing_data_var_attrs(xds):
619
+ """Adds in attributes expected metadata that cannot be found
620
+ in the input MSv2. For now specifically for missing
621
+ single-dish/SPECTRUM metadata"""
622
+ data_var_names = ["SPECTRUM", "SPECTRUM_CORRECTED"]
623
+ for var_name in data_var_names:
624
+ if var_name in xds.data_vars:
625
+ xds.data_vars[var_name].attrs["units"] = ["Jy"]
626
+
627
+ return xds
628
+
629
+
609
630
  def get_weight(
610
631
  xds,
611
632
  col,
@@ -662,7 +683,7 @@ def create_taql_query(partition_info):
662
683
  def convert_and_write_partition(
663
684
  in_file: str,
664
685
  out_file: str,
665
- ms_v4_id: int,
686
+ ms_v4_id: Union[int, str],
666
687
  partition_info: Dict,
667
688
  use_table_iter: bool,
668
689
  partition_scheme: str = "ddi_intent_field",
@@ -671,6 +692,7 @@ def convert_and_write_partition(
671
692
  pointing_chunksize: Union[Dict, float, None] = None,
672
693
  pointing_interpolate: bool = False,
673
694
  ephemeris_interpolate: bool = False,
695
+ phase_cal_interpolate: bool = False,
674
696
  compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2),
675
697
  storage_backend="zarr",
676
698
  overwrite: bool = False,
@@ -715,7 +737,6 @@ def convert_and_write_partition(
715
737
  """
716
738
 
717
739
  taql_where = create_taql_query(partition_info)
718
- # print("taql_where", taql_where)
719
740
  ddi = partition_info["DATA_DESC_ID"][0]
720
741
  obs_mode = str(partition_info["OBS_MODE"][0])
721
742
 
@@ -775,8 +796,19 @@ def convert_and_write_partition(
775
796
  else:
776
797
  interval = interval_unique[0]
777
798
 
799
+ scan_id = np.full(time_baseline_shape, -42, dtype=int)
800
+ scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER")
801
+ scan_id = np.max(scan_id, axis=1)
802
+
778
803
  xds = create_coordinates(
779
- xds, in_file, ddi, utime, interval, baseline_ant1_id, baseline_ant2_id
804
+ xds,
805
+ in_file,
806
+ ddi,
807
+ utime,
808
+ interval,
809
+ baseline_ant1_id,
810
+ baseline_ant2_id,
811
+ scan_id,
780
812
  )
781
813
  logger.debug("Time create coordinates " + str(time.time() - start))
782
814
 
@@ -792,9 +824,11 @@ def convert_and_write_partition(
792
824
  use_table_iter,
793
825
  )
794
826
 
795
- # Add data_groups and field_info
827
+ # Add data_groups
796
828
  xds, is_single_dish = add_data_groups(xds)
797
829
 
830
+ xds = add_missing_data_var_attrs(xds)
831
+
798
832
  if (
799
833
  "WEIGHT" not in xds.data_vars
800
834
  ): # Some single dish datasets don't have WEIGHT.
@@ -811,6 +845,9 @@ def convert_and_write_partition(
811
845
 
812
846
  logger.debug("Time create data variables " + str(time.time() - start))
813
847
 
848
+ # To constrain the time range to load (in pointing, ephemerides, phase_cal data_vars)
849
+ time_min_max = find_min_max_times(tb_tool, taql_where)
850
+
814
851
  # Create ant_xds
815
852
  start = time.time()
816
853
  feed_id = unique_1d(
@@ -826,17 +863,27 @@ def convert_and_write_partition(
826
863
  [xds["baseline_antenna1_id"].data, xds["baseline_antenna2_id"].data]
827
864
  )
828
865
  )
866
+ if phase_cal_interpolate:
867
+ phase_cal_interp_time = xds.time.values
868
+ else:
869
+ phase_cal_interp_time = None
829
870
 
830
- ant_xds = create_ant_xds(
871
+ ant_xds = create_antenna_xds(
831
872
  in_file,
832
873
  xds.frequency.attrs["spectral_window_id"],
833
874
  antenna_id,
834
875
  feed_id,
835
876
  telescope_name,
877
+ time_min_max,
878
+ phase_cal_interp_time,
836
879
  )
837
880
 
838
881
  # Change antenna_ids to antenna_names
839
- xds = antenna_ids_to_names(xds, ant_xds)
882
+ xds = antenna_ids_to_names(xds, ant_xds, is_single_dish)
883
+ ant_xds_name_ids = ant_xds["antenna_name"].set_xindex("antenna_id")
884
+ ant_xds = ant_xds.drop_vars(
885
+ "antenna_id"
886
+ ) # No longer needed after converting to name.
840
887
 
841
888
  logger.debug("Time ant xds " + str(time.time() - start))
842
889
 
@@ -845,9 +892,7 @@ def convert_and_write_partition(
845
892
  weather_xds = create_weather_xds(in_file)
846
893
  logger.debug("Time weather " + str(time.time() - start))
847
894
 
848
- # To constrain the time range to load (in pointing, ephemerides data_vars)
849
- time_min_max = find_min_max_times(tb_tool, taql_where)
850
-
895
+ # Create pointing_xds
851
896
  if with_pointing:
852
897
  start = time.time()
853
898
  if pointing_interpolate:
@@ -855,7 +900,7 @@ def convert_and_write_partition(
855
900
  else:
856
901
  pointing_interp_time = None
857
902
  pointing_xds = create_pointing_xds(
858
- in_file, time_min_max, pointing_interp_time
903
+ in_file, ant_xds_name_ids, time_min_max, pointing_interp_time
859
904
  )
860
905
  pointing_chunksize = parse_chunksize(
861
906
  pointing_chunksize, "pointing", pointing_xds
@@ -869,6 +914,7 @@ def convert_and_write_partition(
869
914
  )
870
915
 
871
916
  start = time.time()
917
+ xds.attrs["type"] = "visibility"
872
918
 
873
919
  # Time and frequency should always be increasing
874
920
  if len(xds.frequency) > 1 and xds.frequency[1] - xds.frequency[0] < 0:
@@ -884,10 +930,6 @@ def convert_and_write_partition(
884
930
  else:
885
931
  ephemeris_interp_time = None
886
932
 
887
- scan_id = np.full(time_baseline_shape, -42, dtype=int)
888
- scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER")
889
- scan_id = np.max(scan_id, axis=1)
890
-
891
933
  if "FIELD_ID" not in partition_scheme:
892
934
  field_id = np.full(time_baseline_shape, -42, dtype=int)
893
935
  field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID")
@@ -901,7 +943,7 @@ def convert_and_write_partition(
901
943
  # assert len(col_unique) == 1, col_name + " is not consistent."
902
944
  # return col_unique[0]
903
945
 
904
- field_and_source_xds, source_id = create_field_and_source_xds(
946
+ field_and_source_xds, source_id, num_lines = create_field_and_source_xds(
905
947
  in_file,
906
948
  field_id,
907
949
  xds.frequency.attrs["spectral_window_id"],
@@ -914,7 +956,6 @@ def convert_and_write_partition(
914
956
 
915
957
  # Fix UVW frame
916
958
  # From CASA fixvis docs: clean and the im tool ignore the reference frame claimed by the UVW column (it is often mislabelled as ITRF when it is really FK5 (J2000)) and instead assume the (u, v, w)s are in the same frame as the phase tracking center. calcuvw does not yet force the UVW column and field centers to use the same reference frame! Blank = use the phase tracking frame of vis.
917
- # print('##################',field_and_source_xds)
918
959
  if is_single_dish:
919
960
  xds.UVW.attrs["frame"] = field_and_source_xds[
920
961
  "FIELD_REFERENCE_CENTER"
@@ -935,11 +976,20 @@ def convert_and_write_partition(
935
976
 
936
977
  file_name = os.path.join(
937
978
  out_file,
938
- out_file.replace(".vis.zarr", "").replace(".zarr", "").split("/")[-1]
979
+ pathlib.Path(out_file)
980
+ .name.replace(".vis.zarr", "")
981
+ .replace(".zarr", "")
939
982
  + "_"
940
983
  + str(ms_v4_id),
941
984
  )
942
985
 
986
+ if "line_name" in field_and_source_xds.coords:
987
+ line_name = to_list(
988
+ unique_1d(np.ravel(field_and_source_xds.line_name.values))
989
+ )
990
+ else:
991
+ line_name = []
992
+
943
993
  xds.attrs["partition_info"] = {
944
994
  # "spectral_window_id": xds.frequency.attrs["spectral_window_id"],
945
995
  "spectral_window_name": xds.frequency.attrs["spectral_window_name"],
@@ -948,11 +998,14 @@ def convert_and_write_partition(
948
998
  np.unique(field_and_source_xds.field_name.values)
949
999
  ),
950
1000
  # "source_id": to_list(unique_1d(source_id)),
1001
+ "line_name": line_name,
1002
+ "scan_number": to_list(np.unique(scan_id)),
951
1003
  "source_name": to_list(
952
1004
  np.unique(field_and_source_xds.source_name.values)
953
1005
  ),
954
1006
  "polarization_setup": to_list(xds.polarization.values),
955
- "obs_mode": obs_mode,
1007
+ "num_lines": num_lines,
1008
+ "obs_mode": obs_mode.split(","),
956
1009
  "taql": taql_where,
957
1010
  }
958
1011
 
@@ -968,7 +1021,7 @@ def convert_and_write_partition(
968
1021
  mode=mode,
969
1022
  )
970
1023
 
971
- if with_pointing:
1024
+ if with_pointing and len(pointing_xds.data_vars) > 1:
972
1025
  pointing_xds.to_zarr(
973
1026
  store=os.path.join(file_name, "POINTING"), mode=mode
974
1027
  )
@@ -986,70 +1039,43 @@ def convert_and_write_partition(
986
1039
  # logger.info("Saved ms_v4 " + file_name + " in " + str(time.time() - start_with) + "s")
987
1040
 
988
1041
 
989
- def antenna_ids_to_names(xds, ant_xds):
1042
+ def antenna_ids_to_names(
1043
+ xds: xr.Dataset, ant_xds: xr.Dataset, is_single_dish: bool
1044
+ ) -> xr.Dataset:
1045
+ ant_xds = ant_xds.set_xindex(
1046
+ "antenna_id"
1047
+ ) # Allows for non-dimension coordinate selection.
990
1048
 
991
- if ant_xds.attrs["overall_telescope_name"] in ["ALMA", "VLA", "NOEMA", "EVLA"]:
992
- moving_antennas = True
993
- else:
994
- moving_antennas = False
995
-
996
- if moving_antennas:
997
- if "baseline_antenna1_id" in xds: # Interferometer
998
-
999
- baseline_ant1_name = np.core.defchararray.add(
1000
- ant_xds["name"].sel(antenna_id=xds["baseline_antenna1_id"]).values, "_"
1001
- )
1002
- baseline_ant1_name = np.core.defchararray.add(
1003
- baseline_ant1_name,
1004
- ant_xds["station"].sel(antenna_id=xds["baseline_antenna1_id"]).values,
1005
- )
1006
- baseline_ant2_name = np.core.defchararray.add(
1007
- ant_xds["name"].sel(antenna_id=xds["baseline_antenna2_id"]).values, "_"
1008
- )
1009
- baseline_ant2_name = np.core.defchararray.add(
1010
- baseline_ant2_name,
1011
- ant_xds["station"].sel(antenna_id=xds["baseline_antenna2_id"]).values,
1012
- )
1013
-
1014
- xds["baseline_antenna1_id"] = xr.DataArray(
1015
- baseline_ant1_name, dims="baseline_id"
1016
- )
1017
- xds["baseline_antenna2_id"] = xr.DataArray(
1018
- baseline_ant2_name, dims="baseline_id"
1019
- )
1020
- xds = xds.rename(
1021
- {
1022
- "baseline_antenna1_id": "baseline_antenna1_name",
1023
- "baseline_antenna2_id": "baseline_antenna2_name",
1024
- }
1025
- )
1026
- else: # Single Dish
1027
- antenna_name = np.core.defchararray.add(
1028
- ant_xds["name"].sel(antenna_id=xds["antenna_id"]).values, "_"
1029
- )
1030
- antenna_name = np.core.defchararray.add(
1031
- antenna_name,
1032
- ant_xds["station"].sel(antenna_id=xds["antenna_id"]).values,
1033
- )
1034
- xds["antenna_id"] = xr.DataArray(antenna_name, dims="baseline_id")
1035
- xds = xds.rename({"antenna_id": "antenna_name"})
1049
+ if not is_single_dish: # Interferometer
1050
+ xds["baseline_antenna1_id"].data = ant_xds["antenna_name"].sel(
1051
+ antenna_id=xds["baseline_antenna1_id"].data
1052
+ )
1053
+ xds["baseline_antenna2_id"].data = ant_xds["antenna_name"].sel(
1054
+ antenna_id=xds["baseline_antenna2_id"].data
1055
+ )
1056
+ xds = xds.rename(
1057
+ {
1058
+ "baseline_antenna1_id": "baseline_antenna1_name",
1059
+ "baseline_antenna2_id": "baseline_antenna2_name",
1060
+ }
1061
+ )
1036
1062
  else:
1037
- if "baseline_antenna1_id" in xds: # Interferometer
1038
- xds["baseline_antenna1_id"] = ant_xds["name"].sel(
1039
- antenna_id=xds["baseline_antenna1_id"]
1040
- )
1041
- xds["baseline_antenna2_id"] = ant_xds["name"].sel(
1042
- antenna_id=xds["baseline_antenna2_id"]
1043
- )
1044
- xds = xds.rename(
1045
- {
1046
- "baseline_antenna1_id": "baseline_antenna1_name",
1047
- "baseline_antenna2_id": "baseline_antenna2_name",
1048
- }
1049
- )
1050
- else: # Single Dish
1051
- xds["antenna_id"] = ant_xds["name"].sel(antenna_id=xds["antenna_id"])
1052
- xds = xds.rename({"antenna_id": "antenna_name"})
1063
+ xds["baseline_id"] = ant_xds["antenna_name"].sel(antenna_id=xds["baseline_id"])
1064
+ unwanted_coords_from_ant_xds = [
1065
+ "antenna_id",
1066
+ "antenna_name",
1067
+ "mount",
1068
+ "station",
1069
+ ]
1070
+ for unwanted_coord in unwanted_coords_from_ant_xds:
1071
+ xds = xds.drop_vars(unwanted_coord)
1072
+ xds = xds.rename({"baseline_id": "antenna_name"})
1073
+
1074
+ # drop more vars that seem unwanted in main_sd_xds, but there shouuld be a better way
1075
+ # of not creating them in the first place
1076
+ unwanted_coords_sd = ["baseline_antenna1_id", "baseline_antenna2_id"]
1077
+ for unwanted_coord in unwanted_coords_sd:
1078
+ xds = xds.drop_vars(unwanted_coord)
1053
1079
 
1054
1080
  return xds
1055
1081