xradio 0.0.48__py3-none-any.whl → 0.0.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. xradio/__init__.py +1 -0
  2. xradio/_utils/dict_helpers.py +69 -2
  3. xradio/image/_util/__init__.py +0 -3
  4. xradio/image/_util/_casacore/common.py +0 -13
  5. xradio/image/_util/_casacore/xds_from_casacore.py +102 -97
  6. xradio/image/_util/_casacore/xds_to_casacore.py +36 -24
  7. xradio/image/_util/_fits/xds_from_fits.py +81 -36
  8. xradio/image/_util/_zarr/zarr_low_level.py +3 -3
  9. xradio/image/_util/casacore.py +7 -5
  10. xradio/image/_util/common.py +13 -26
  11. xradio/image/_util/image_factory.py +143 -191
  12. xradio/image/image.py +10 -59
  13. xradio/measurement_set/__init__.py +11 -6
  14. xradio/measurement_set/_utils/_msv2/_tables/read.py +187 -46
  15. xradio/measurement_set/_utils/_msv2/_tables/table_query.py +22 -0
  16. xradio/measurement_set/_utils/_msv2/conversion.py +351 -318
  17. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +20 -17
  18. xradio/measurement_set/convert_msv2_to_processing_set.py +46 -6
  19. xradio/measurement_set/load_processing_set.py +100 -53
  20. xradio/measurement_set/measurement_set_xdt.py +197 -0
  21. xradio/measurement_set/open_processing_set.py +122 -86
  22. xradio/measurement_set/processing_set_xdt.py +1552 -0
  23. xradio/measurement_set/schema.py +199 -94
  24. xradio/schema/bases.py +5 -1
  25. xradio/schema/check.py +97 -5
  26. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/METADATA +4 -4
  27. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/RECORD +30 -30
  28. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/WHEEL +1 -1
  29. xradio/measurement_set/measurement_set_xds.py +0 -117
  30. xradio/measurement_set/processing_set.py +0 -803
  31. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info/licenses}/LICENSE.txt +0 -0
  32. {xradio-0.0.48.dist-info → xradio-0.0.49.dist-info}/top_level.txt +0 -0
@@ -11,12 +11,14 @@ import xarray as xr
11
11
 
12
12
  import toolviper.utils.logger as logger
13
13
  from casacore import tables
14
+
14
15
  from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
15
16
  create_pointing_xds,
16
17
  create_system_calibration_xds,
17
18
  create_weather_xds,
18
19
  )
19
20
  from .msv4_info_dicts import create_info_dicts
21
+ from xradio.measurement_set.schema import MSV4_SCHEMA_VERSION
20
22
  from xradio.measurement_set._utils._msv2.create_antenna_xds import (
21
23
  create_antenna_xds,
22
24
  create_gain_curve_xds,
@@ -34,11 +36,12 @@ from .msv2_to_msv4_meta import (
34
36
 
35
37
  from .._zarr.encoding import add_encoding
36
38
  from .subtables import subt_rename_ids
37
- from ._tables.table_query import open_table_ro, open_query
39
+ from ._tables.table_query import open_table_ro, open_query, TableManager
38
40
  from ._tables.read import (
39
41
  convert_casacore_time,
40
42
  extract_table_attributes,
41
- read_col_conversion,
43
+ read_col_conversion_numpy,
44
+ read_col_conversion_dask,
42
45
  load_generic_table,
43
46
  )
44
47
  from ._tables.read_main_table import get_baselines, get_baseline_indices, get_utimes_tol
@@ -377,6 +380,10 @@ def calc_used_gb(
377
380
 
378
381
  # TODO: if the didxs are not used in read_col_conversion, remove didxs from here (and convert_and_write_partition)
379
382
  def calc_indx_for_row_split(tb_tool, taql_where):
383
+ # Allow TableManager object to be used
384
+ if isinstance(tb_tool, TableManager):
385
+ tb_tool = tb_tool.get_table()
386
+
380
387
  baselines = get_baselines(tb_tool)
381
388
  col_names = tb_tool.colnames()
382
389
  cshapes = [
@@ -560,10 +567,42 @@ def find_min_max_times(tb_tool: tables.table, taql_where: str) -> tuple:
560
567
 
561
568
 
562
569
  def create_data_variables(
563
- in_file, xds, tb_tool, time_baseline_shape, tidxs, bidxs, didxs, use_table_iter
570
+ in_file,
571
+ xds,
572
+ table_manager,
573
+ time_baseline_shape,
574
+ tidxs,
575
+ bidxs,
576
+ didxs,
577
+ use_table_iter,
578
+ parallel_mode,
579
+ main_chunksize,
564
580
  ):
581
+
582
+ # Get time chunks
583
+ time_chunksize = None
584
+ if parallel_mode == "time":
585
+ try:
586
+ time_chunksize = main_chunksize["time"]
587
+ except KeyError:
588
+ # If time isn't chunked then `read_col_conversion_dask` is slower than `read_col_conversion_numpy`
589
+ logger.warning(
590
+ "'time' isn't specified in `main_chunksize`. Defaulting to `parallel_mode = 'none'`."
591
+ )
592
+ parallel_mode = "none"
593
+
594
+ # Set read_col_conversion from value of `parallel_mode` argument
595
+ # TODO: To make this compatible with multi-node conversion, `read_col_conversion_dask` and TableManager must be pickled.
596
+ # Casacore will make this difficult
597
+ global read_col_conversion
598
+ if parallel_mode == "time":
599
+ read_col_conversion = read_col_conversion_dask
600
+ else:
601
+ read_col_conversion = read_col_conversion_numpy
602
+
565
603
  # Create Data Variables
566
- col_names = tb_tool.colnames()
604
+ with table_manager.get_table() as tb_tool:
605
+ col_names = tb_tool.colnames()
567
606
 
568
607
  main_table_attrs = extract_table_attributes(in_file)
569
608
  main_column_descriptions = main_table_attrs["column_descriptions"]
@@ -577,22 +616,24 @@ def create_data_variables(
577
616
  xds = get_weight(
578
617
  xds,
579
618
  col,
580
- tb_tool,
619
+ table_manager,
581
620
  time_baseline_shape,
582
621
  tidxs,
583
622
  bidxs,
584
623
  use_table_iter,
585
624
  main_column_descriptions,
625
+ time_chunksize,
586
626
  )
587
627
  else:
588
628
  xds[col_to_data_variable_names[col]] = xr.DataArray(
589
629
  read_col_conversion(
590
- tb_tool,
630
+ table_manager,
591
631
  col,
592
632
  time_baseline_shape,
593
633
  tidxs,
594
634
  bidxs,
595
635
  use_table_iter,
636
+ time_chunksize,
596
637
  ),
597
638
  dims=col_dims[col],
598
639
  )
@@ -613,12 +654,13 @@ def create_data_variables(
613
654
  xds = get_weight(
614
655
  xds,
615
656
  "WEIGHT",
616
- tb_tool,
657
+ table_manager,
617
658
  time_baseline_shape,
618
659
  tidxs,
619
660
  bidxs,
620
661
  use_table_iter,
621
662
  main_column_descriptions,
663
+ time_chunksize,
622
664
  )
623
665
 
624
666
 
@@ -651,22 +693,24 @@ def add_missing_data_var_attrs(xds):
651
693
  def get_weight(
652
694
  xds,
653
695
  col,
654
- tb_tool,
696
+ table_manager,
655
697
  time_baseline_shape,
656
698
  tidxs,
657
699
  bidxs,
658
700
  use_table_iter,
659
701
  main_column_descriptions,
702
+ time_chunksize,
660
703
  ):
661
704
  xds[col_to_data_variable_names[col]] = xr.DataArray(
662
705
  np.tile(
663
706
  read_col_conversion(
664
- tb_tool,
707
+ table_manager,
665
708
  col,
666
709
  time_baseline_shape,
667
710
  tidxs,
668
711
  bidxs,
669
712
  use_table_iter,
713
+ time_chunksize,
670
714
  )[:, :, None, :],
671
715
  (1, 1, xds.sizes["frequency"], 1),
672
716
  ),
@@ -931,6 +975,7 @@ def convert_and_write_partition(
931
975
  sys_cal_interpolate: bool = False,
932
976
  compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2),
933
977
  storage_backend="zarr",
978
+ parallel_mode: str = "none",
934
979
  overwrite: bool = False,
935
980
  ):
936
981
  """_summary_
@@ -967,6 +1012,8 @@ def convert_and_write_partition(
967
1012
  _description_, by default numcodecs.Zstd(level=2)
968
1013
  storage_backend : str, optional
969
1014
  _description_, by default "zarr"
1015
+ parallel_mode : _type_, optional
1016
+ _description_
970
1017
  overwrite : bool, optional
971
1018
  _description_, by default False
972
1019
 
@@ -976,363 +1023,347 @@ def convert_and_write_partition(
976
1023
  _description_
977
1024
  """
978
1025
 
1026
+ ms_xdt = xr.DataTree() # MSv4 as a Data Tree
1027
+
979
1028
  taql_where = create_taql_query_where(partition_info)
1029
+ table_manager = TableManager(in_file, taql_where)
980
1030
  ddi = partition_info["DATA_DESC_ID"][0]
981
1031
  intents = str(partition_info["OBS_MODE"][0])
982
1032
 
983
1033
  start = time.time()
984
- with open_table_ro(in_file) as mtable:
985
- taql_main = f"select * from $mtable {taql_where}"
986
- with open_query(mtable, taql_main) as tb_tool:
987
-
988
- if tb_tool.nrows() == 0:
989
- tb_tool.close()
990
- mtable.close()
991
- return xr.Dataset(), {}, {}
1034
+ with table_manager.get_table() as tb_tool:
1035
+ if tb_tool.nrows() == 0:
1036
+ tb_tool.close()
1037
+ return xr.Dataset(), {}, {}
1038
+
1039
+ logger.debug("Starting a real convert_and_write_partition")
1040
+ (
1041
+ tidxs,
1042
+ bidxs,
1043
+ didxs,
1044
+ baseline_ant1_id,
1045
+ baseline_ant2_id,
1046
+ utime,
1047
+ ) = calc_indx_for_row_split(tb_tool, taql_where)
1048
+ time_baseline_shape = (len(utime), len(baseline_ant1_id))
1049
+ logger.debug("Calc indx for row split " + str(time.time() - start))
1050
+
1051
+ observation_id = check_if_consistent(
1052
+ tb_tool.getcol("OBSERVATION_ID"), "OBSERVATION_ID"
1053
+ )
992
1054
 
993
- logger.debug("Starting a real convert_and_write_partition")
994
- (
995
- tidxs,
996
- bidxs,
997
- didxs,
998
- baseline_ant1_id,
999
- baseline_ant2_id,
1000
- utime,
1001
- ) = calc_indx_for_row_split(tb_tool, taql_where)
1002
- time_baseline_shape = (len(utime), len(baseline_ant1_id))
1003
- logger.debug("Calc indx for row split " + str(time.time() - start))
1004
-
1005
- observation_id = check_if_consistent(
1006
- tb_tool.getcol("OBSERVATION_ID"), "OBSERVATION_ID"
1055
+ def get_observation_info(in_file, observation_id, intents):
1056
+ generic_observation_xds = load_generic_table(
1057
+ in_file,
1058
+ "OBSERVATION",
1059
+ taql_where=f" where (ROWID() IN [{str(observation_id)}])",
1007
1060
  )
1008
1061
 
1009
- def get_observation_info(in_file, observation_id, intents):
1010
- generic_observation_xds = load_generic_table(
1011
- in_file,
1012
- "OBSERVATION",
1013
- taql_where=f" where (ROWID() IN [{str(observation_id)}])",
1014
- )
1015
-
1016
- if intents == "None":
1017
- intents = "obs_" + str(observation_id)
1062
+ if intents == "None":
1063
+ intents = "obs_" + str(observation_id)
1064
+
1065
+ return generic_observation_xds["TELESCOPE_NAME"].values[0], intents
1066
+
1067
+ telescope_name, intents = get_observation_info(in_file, observation_id, intents)
1068
+
1069
+ start = time.time()
1070
+ xds = xr.Dataset(
1071
+ attrs={
1072
+ "schema_version": MSV4_SCHEMA_VERSION,
1073
+ "creator": {
1074
+ "software_name": "xradio",
1075
+ "version": importlib.metadata.version("xradio"),
1076
+ },
1077
+ "creation_date": datetime.datetime.now(
1078
+ datetime.timezone.utc
1079
+ ).isoformat(),
1080
+ "type": "visibility",
1081
+ }
1082
+ )
1018
1083
 
1019
- return generic_observation_xds["TELESCOPE_NAME"].values[0], intents
1084
+ # interval = check_if_consistent(tb_tool.getcol("INTERVAL"), "INTERVAL")
1085
+ interval = tb_tool.getcol("INTERVAL")
1020
1086
 
1021
- telescope_name, intents = get_observation_info(
1022
- in_file, observation_id, intents
1023
- )
1024
-
1025
- start = time.time()
1026
- xds = xr.Dataset(
1027
- attrs={
1028
- "creation_date": datetime.datetime.now(
1029
- datetime.timezone.utc
1030
- ).isoformat(),
1031
- "xradio_version": importlib.metadata.version("xradio"),
1032
- "schema_version": "4.0.-9989",
1033
- "type": "visibility",
1034
- }
1087
+ interval_unique = unique_1d(interval)
1088
+ if len(interval_unique) > 1:
1089
+ logger.debug(
1090
+ "Integration time (interval) not consitent in partition, using median."
1035
1091
  )
1092
+ interval = np.median(interval)
1093
+ else:
1094
+ interval = interval_unique[0]
1095
+
1096
+ scan_id = np.full(time_baseline_shape, -42, dtype=int)
1097
+ scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER")
1098
+ scan_id = np.max(scan_id, axis=1)
1099
+
1100
+ xds = create_coordinates(
1101
+ xds,
1102
+ in_file,
1103
+ ddi,
1104
+ utime,
1105
+ interval,
1106
+ baseline_ant1_id,
1107
+ baseline_ant2_id,
1108
+ scan_id,
1109
+ )
1110
+ logger.debug("Time create coordinates " + str(time.time() - start))
1111
+
1112
+ start = time.time()
1113
+ main_chunksize = parse_chunksize(main_chunksize, "main", xds)
1114
+ create_data_variables(
1115
+ in_file,
1116
+ xds,
1117
+ table_manager,
1118
+ time_baseline_shape,
1119
+ tidxs,
1120
+ bidxs,
1121
+ didxs,
1122
+ use_table_iter,
1123
+ parallel_mode,
1124
+ main_chunksize,
1125
+ )
1036
1126
 
1037
- # interval = check_if_consistent(tb_tool.getcol("INTERVAL"), "INTERVAL")
1038
- interval = tb_tool.getcol("INTERVAL")
1127
+ # Add data_groups
1128
+ xds, is_single_dish = add_data_groups(xds)
1129
+ xds = add_missing_data_var_attrs(xds)
1039
1130
 
1040
- interval_unique = unique_1d(interval)
1041
- if len(interval_unique) > 1:
1042
- logger.debug(
1043
- "Integration time (interval) not consitent in partition, using median."
1131
+ if (
1132
+ "WEIGHT" not in xds.data_vars
1133
+ ): # Some single dish datasets don't have WEIGHT.
1134
+ if is_single_dish:
1135
+ xds["WEIGHT"] = xr.DataArray(
1136
+ np.ones(xds.SPECTRUM.shape, dtype=np.float64),
1137
+ dims=xds.SPECTRUM.dims,
1044
1138
  )
1045
- interval = np.median(interval)
1046
1139
  else:
1047
- interval = interval_unique[0]
1140
+ xds["WEIGHT"] = xr.DataArray(
1141
+ np.ones(xds.VISIBILITY.shape, dtype=np.float64),
1142
+ dims=xds.VISIBILITY.dims,
1143
+ )
1048
1144
 
1049
- scan_id = np.full(time_baseline_shape, -42, dtype=int)
1050
- scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER")
1051
- scan_id = np.max(scan_id, axis=1)
1145
+ logger.debug("Time create data variables " + str(time.time() - start))
1052
1146
 
1053
- xds = create_coordinates(
1054
- xds,
1055
- in_file,
1056
- ddi,
1057
- utime,
1058
- interval,
1059
- baseline_ant1_id,
1060
- baseline_ant2_id,
1061
- scan_id,
1062
- )
1063
- logger.debug("Time create coordinates " + str(time.time() - start))
1147
+ # To constrain the time range to load (in pointing, ephemerides, phase_cal data_vars)
1148
+ time_min_max = find_min_max_times(tb_tool, taql_where)
1064
1149
 
1065
- start = time.time()
1066
- create_data_variables(
1067
- in_file,
1068
- xds,
1069
- tb_tool,
1070
- time_baseline_shape,
1071
- tidxs,
1072
- bidxs,
1073
- didxs,
1074
- use_table_iter,
1150
+ # Create ant_xds
1151
+ start = time.time()
1152
+ feed_id = unique_1d(
1153
+ np.concatenate(
1154
+ [
1155
+ unique_1d(tb_tool.getcol("FEED1")),
1156
+ unique_1d(tb_tool.getcol("FEED2")),
1157
+ ]
1075
1158
  )
1159
+ )
1160
+ antenna_id = unique_1d(
1161
+ np.concatenate(
1162
+ [xds["baseline_antenna1_id"].data, xds["baseline_antenna2_id"].data]
1163
+ )
1164
+ )
1076
1165
 
1077
- # Add data_groups
1078
- xds, is_single_dish = add_data_groups(xds)
1079
- xds = add_missing_data_var_attrs(xds)
1080
-
1081
- if (
1082
- "WEIGHT" not in xds.data_vars
1083
- ): # Some single dish datasets don't have WEIGHT.
1084
- if is_single_dish:
1085
- xds["WEIGHT"] = xr.DataArray(
1086
- np.ones(xds.SPECTRUM.shape, dtype=np.float64),
1087
- dims=xds.SPECTRUM.dims,
1088
- )
1089
- else:
1090
- xds["WEIGHT"] = xr.DataArray(
1091
- np.ones(xds.VISIBILITY.shape, dtype=np.float64),
1092
- dims=xds.VISIBILITY.dims,
1093
- )
1094
-
1095
- logger.debug("Time create data variables " + str(time.time() - start))
1096
-
1097
- # To constrain the time range to load (in pointing, ephemerides, phase_cal data_vars)
1098
- time_min_max = find_min_max_times(tb_tool, taql_where)
1166
+ ant_xds = create_antenna_xds(
1167
+ in_file,
1168
+ xds.frequency.attrs["spectral_window_id"],
1169
+ antenna_id,
1170
+ feed_id,
1171
+ telescope_name,
1172
+ xds.polarization,
1173
+ )
1174
+ logger.debug("Time antenna xds " + str(time.time() - start))
1099
1175
 
1100
- # Create ant_xds
1101
- start = time.time()
1102
- feed_id = unique_1d(
1103
- np.concatenate(
1104
- [
1105
- unique_1d(tb_tool.getcol("FEED1")),
1106
- unique_1d(tb_tool.getcol("FEED2")),
1107
- ]
1108
- )
1109
- )
1110
- antenna_id = unique_1d(
1111
- np.concatenate(
1112
- [xds["baseline_antenna1_id"].data, xds["baseline_antenna2_id"].data]
1113
- )
1114
- )
1176
+ start = time.time()
1177
+ gain_curve_xds = create_gain_curve_xds(
1178
+ in_file, xds.frequency.attrs["spectral_window_id"], ant_xds
1179
+ )
1180
+ logger.debug("Time gain_curve xds " + str(time.time() - start))
1115
1181
 
1116
- ant_xds = create_antenna_xds(
1117
- in_file,
1118
- xds.frequency.attrs["spectral_window_id"],
1119
- antenna_id,
1120
- feed_id,
1121
- telescope_name,
1122
- xds.polarization,
1123
- )
1124
- logger.debug("Time antenna xds " + str(time.time() - start))
1182
+ start = time.time()
1183
+ if phase_cal_interpolate:
1184
+ phase_cal_interp_time = xds.time.values
1185
+ else:
1186
+ phase_cal_interp_time = None
1187
+ phase_calibration_xds = create_phase_calibration_xds(
1188
+ in_file,
1189
+ xds.frequency.attrs["spectral_window_id"],
1190
+ ant_xds,
1191
+ time_min_max,
1192
+ phase_cal_interp_time,
1193
+ )
1194
+ logger.debug("Time phase_calibration xds " + str(time.time() - start))
1125
1195
 
1126
- start = time.time()
1127
- gain_curve_xds = create_gain_curve_xds(
1128
- in_file, xds.frequency.attrs["spectral_window_id"], ant_xds
1129
- )
1130
- logger.debug("Time gain_curve xds " + str(time.time() - start))
1196
+ # Create system_calibration_xds
1197
+ start = time.time()
1198
+ if sys_cal_interpolate:
1199
+ sys_cal_interp_time = xds.time.values
1200
+ else:
1201
+ sys_cal_interp_time = None
1202
+ system_calibration_xds = create_system_calibration_xds(
1203
+ in_file,
1204
+ xds.frequency,
1205
+ ant_xds,
1206
+ sys_cal_interp_time,
1207
+ )
1208
+ logger.debug("Time system_calibation " + str(time.time() - start))
1131
1209
 
1210
+ # Change antenna_ids to antenna_names
1211
+ with_antenna_partitioning = "ANTENNA1" in partition_info
1212
+ xds = antenna_ids_to_names(
1213
+ xds, ant_xds, is_single_dish, with_antenna_partitioning
1214
+ )
1215
+ # but before, keep the name-id arrays, we need them for the pointing and weather xds
1216
+ ant_xds_name_ids = ant_xds["antenna_name"].set_xindex("antenna_id")
1217
+ ant_xds_station_name_ids = ant_xds["station"].set_xindex("antenna_id")
1218
+ # No longer needed after converting to name.
1219
+ ant_xds = ant_xds.drop_vars("antenna_id")
1220
+
1221
+ # Create weather_xds
1222
+ start = time.time()
1223
+ weather_xds = create_weather_xds(in_file, ant_xds_station_name_ids)
1224
+ logger.debug("Time weather " + str(time.time() - start))
1225
+
1226
+ # Create pointing_xds
1227
+ pointing_xds = xr.Dataset()
1228
+ if with_pointing:
1132
1229
  start = time.time()
1133
- if phase_cal_interpolate:
1134
- phase_cal_interp_time = xds.time.values
1230
+ if pointing_interpolate:
1231
+ pointing_interp_time = xds.time
1135
1232
  else:
1136
- phase_cal_interp_time = None
1137
- phase_calibration_xds = create_phase_calibration_xds(
1138
- in_file,
1139
- xds.frequency.attrs["spectral_window_id"],
1140
- ant_xds,
1141
- time_min_max,
1142
- phase_cal_interp_time,
1233
+ pointing_interp_time = None
1234
+ pointing_xds = create_pointing_xds(
1235
+ in_file, ant_xds_name_ids, time_min_max, pointing_interp_time
1143
1236
  )
1144
- logger.debug("Time phase_calibration xds " + str(time.time() - start))
1145
-
1146
- # Create system_calibration_xds
1147
- start = time.time()
1148
- if sys_cal_interpolate:
1149
- sys_cal_interp_time = xds.time.values
1150
- else:
1151
- sys_cal_interp_time = None
1152
- system_calibration_xds = create_system_calibration_xds(
1153
- in_file,
1154
- xds.frequency,
1155
- ant_xds,
1156
- sys_cal_interp_time,
1237
+ pointing_chunksize = parse_chunksize(
1238
+ pointing_chunksize, "pointing", pointing_xds
1157
1239
  )
1158
- logger.debug("Time system_calibation " + str(time.time() - start))
1159
-
1160
- # Change antenna_ids to antenna_names
1161
- with_antenna_partitioning = "ANTENNA1" in partition_info
1162
- xds = antenna_ids_to_names(
1163
- xds, ant_xds, is_single_dish, with_antenna_partitioning
1240
+ add_encoding(pointing_xds, compressor=compressor, chunks=pointing_chunksize)
1241
+ logger.debug(
1242
+ "Time pointing (with add compressor and chunking) "
1243
+ + str(time.time() - start)
1164
1244
  )
1165
- # but before, keep the name-id arrays, we need them for the pointing and weather xds
1166
- ant_xds_name_ids = ant_xds["antenna_name"].set_xindex("antenna_id")
1167
- ant_xds_station_name_ids = ant_xds["station"].set_xindex("antenna_id")
1168
- # No longer needed after converting to name.
1169
- ant_xds = ant_xds.drop_vars("antenna_id")
1170
-
1171
- # Create weather_xds
1172
- start = time.time()
1173
- weather_xds = create_weather_xds(in_file, ant_xds_station_name_ids)
1174
- logger.debug("Time weather " + str(time.time() - start))
1175
-
1176
- # Create pointing_xds
1177
- pointing_xds = xr.Dataset()
1178
- if with_pointing:
1179
- start = time.time()
1180
- if pointing_interpolate:
1181
- pointing_interp_time = xds.time
1182
- else:
1183
- pointing_interp_time = None
1184
- pointing_xds = create_pointing_xds(
1185
- in_file, ant_xds_name_ids, time_min_max, pointing_interp_time
1186
- )
1187
- pointing_chunksize = parse_chunksize(
1188
- pointing_chunksize, "pointing", pointing_xds
1189
- )
1190
- add_encoding(
1191
- pointing_xds, compressor=compressor, chunks=pointing_chunksize
1192
- )
1193
- logger.debug(
1194
- "Time pointing (with add compressor and chunking) "
1195
- + str(time.time() - start)
1196
- )
1197
1245
 
1198
- start = time.time()
1246
+ start = time.time()
1199
1247
 
1200
- # Time and frequency should always be increasing
1201
- if len(xds.frequency) > 1 and xds.frequency[1] - xds.frequency[0] < 0:
1202
- xds = xds.sel(frequency=slice(None, None, -1))
1248
+ # Time and frequency should always be increasing
1249
+ if len(xds.frequency) > 1 and xds.frequency[1] - xds.frequency[0] < 0:
1250
+ xds = xds.sel(frequency=slice(None, None, -1))
1203
1251
 
1204
- if len(xds.time) > 1 and xds.time[1] - xds.time[0] < 0:
1205
- xds = xds.sel(time=slice(None, None, -1))
1252
+ if len(xds.time) > 1 and xds.time[1] - xds.time[0] < 0:
1253
+ xds = xds.sel(time=slice(None, None, -1))
1206
1254
 
1207
- # Create field_and_source_xds (combines field, source and ephemeris data into one super dataset)
1208
- start = time.time()
1209
- if ephemeris_interpolate:
1210
- ephemeris_interp_time = xds.time.values
1211
- else:
1212
- ephemeris_interp_time = None
1213
-
1214
- # if "FIELD_ID" not in partition_scheme:
1215
- # field_id = np.full(time_baseline_shape, -42, dtype=int)
1216
- # field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID")
1217
- # field_id = np.max(field_id, axis=1)
1218
- # field_times = utime
1219
- # else:
1220
- # field_id = check_if_consistent(tb_tool.getcol("FIELD_ID"), "FIELD_ID")
1221
- # field_times = None
1222
-
1223
- field_id = np.full(
1224
- time_baseline_shape, -42, dtype=int
1225
- ) # -42 used for missing baselines
1226
- field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID")
1227
- field_id = np.max(field_id, axis=1)
1228
- field_times = xds.time.values
1229
-
1230
- # col_unique = unique_1d(col)
1231
- # assert len(col_unique) == 1, col_name + " is not consistent."
1232
- # return col_unique[0]
1233
-
1234
- field_and_source_xds, source_id, _num_lines, field_names = (
1235
- create_field_and_source_xds(
1236
- in_file,
1237
- field_id,
1238
- xds.frequency.attrs["spectral_window_id"],
1239
- field_times,
1240
- is_single_dish,
1241
- time_min_max,
1242
- ephemeris_interpolate,
1243
- )
1255
+ # Create field_and_source_xds (combines field, source and ephemeris data into one super dataset)
1256
+ start = time.time()
1257
+ if ephemeris_interpolate:
1258
+ ephemeris_interp_time = xds.time.values
1259
+ else:
1260
+ ephemeris_interp_time = None
1261
+
1262
+ # if "FIELD_ID" not in partition_scheme:
1263
+ # field_id = np.full(time_baseline_shape, -42, dtype=int)
1264
+ # field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID")
1265
+ # field_id = np.max(field_id, axis=1)
1266
+ # field_times = utime
1267
+ # else:
1268
+ # field_id = check_if_consistent(tb_tool.getcol("FIELD_ID"), "FIELD_ID")
1269
+ # field_times = None
1270
+
1271
+ field_id = np.full(
1272
+ time_baseline_shape, -42, dtype=int
1273
+ ) # -42 used for missing baselines
1274
+ field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID")
1275
+ field_id = np.max(field_id, axis=1)
1276
+ field_times = xds.time.values
1277
+
1278
+ # col_unique = unique_1d(col)
1279
+ # assert len(col_unique) == 1, col_name + " is not consistent."
1280
+ # return col_unique[0]
1281
+
1282
+ field_and_source_xds, source_id, _num_lines, field_names = (
1283
+ create_field_and_source_xds(
1284
+ in_file,
1285
+ field_id,
1286
+ xds.frequency.attrs["spectral_window_id"],
1287
+ field_times,
1288
+ is_single_dish,
1289
+ time_min_max,
1290
+ ephemeris_interpolate,
1244
1291
  )
1292
+ )
1245
1293
 
1246
- logger.debug("Time field_and_source_xds " + str(time.time() - start))
1247
-
1248
- xds = fix_uvw_frame(xds, field_and_source_xds, is_single_dish)
1249
- xds = xds.assign_coords({"field_name": ("time", field_names)})
1250
-
1251
- partition_info_misc_fields = {
1252
- "scan_name": xds.coords["scan_name"].data,
1253
- "intents": intents,
1254
- "taql_where": taql_where,
1255
- }
1256
- if with_antenna_partitioning:
1257
- partition_info_misc_fields["antenna_name"] = xds.coords[
1258
- "antenna_name"
1259
- ].data[0]
1260
- info_dicts = create_info_dicts(
1261
- in_file, xds, field_and_source_xds, partition_info_misc_fields, tb_tool
1262
- )
1263
- xds.attrs.update(info_dicts)
1294
+ logger.debug("Time field_and_source_xds " + str(time.time() - start))
1295
+
1296
+ xds = fix_uvw_frame(xds, field_and_source_xds, is_single_dish)
1297
+ xds = xds.assign_coords({"field_name": ("time", field_names)})
1298
+
1299
+ partition_info_misc_fields = {
1300
+ "scan_name": xds.coords["scan_name"].data,
1301
+ "intents": intents,
1302
+ "taql_where": taql_where,
1303
+ }
1304
+ if with_antenna_partitioning:
1305
+ partition_info_misc_fields["antenna_name"] = xds.coords[
1306
+ "antenna_name"
1307
+ ].data[0]
1308
+ info_dicts = create_info_dicts(
1309
+ in_file, xds, field_and_source_xds, partition_info_misc_fields, tb_tool
1310
+ )
1311
+ xds.attrs.update(info_dicts)
1264
1312
 
1265
- # xds ready, prepare to write
1266
- start = time.time()
1267
- main_chunksize = parse_chunksize(main_chunksize, "main", xds)
1268
- add_encoding(xds, compressor=compressor, chunks=main_chunksize)
1269
- logger.debug("Time add compressor and chunk " + str(time.time() - start))
1313
+ # xds ready, prepare to write
1314
+ start = time.time()
1315
+ add_encoding(xds, compressor=compressor, chunks=main_chunksize)
1316
+ logger.debug("Time add compressor and chunk " + str(time.time() - start))
1270
1317
 
1271
- file_name = os.path.join(
1272
- out_file,
1273
- pathlib.Path(in_file).name.replace(".ms", "") + "_" + str(ms_v4_id),
1274
- )
1318
+ file_name = os.path.join(
1319
+ out_file,
1320
+ pathlib.Path(in_file).name.replace(".ms", "") + "_" + str(ms_v4_id),
1321
+ )
1275
1322
 
1276
- if overwrite:
1277
- mode = "w"
1278
- else:
1279
- mode = "w-"
1323
+ if overwrite:
1324
+ mode = "w"
1325
+ else:
1326
+ mode = "w-"
1280
1327
 
1281
- if is_single_dish:
1282
- xds.attrs["type"] = "spectrum"
1283
- xds = xds.drop_vars(["UVW"])
1284
- del xds["uvw_label"]
1328
+ if is_single_dish:
1329
+ xds.attrs["type"] = "spectrum"
1330
+ xds = xds.drop_vars("UVW")
1331
+ xds = xds.drop_dims("uvw_label")
1332
+ else:
1333
+ if xds.attrs["processor_info"]["type"] == "RADIOMETER":
1334
+ xds.attrs["type"] = "radiometer"
1285
1335
  else:
1286
- if any("WVR" in s for s in intents):
1287
- xds.attrs["type"] = "wvr"
1288
- else:
1289
- xds.attrs["type"] = "visibility"
1336
+ xds.attrs["type"] = "visibility"
1290
1337
 
1291
- import sys
1338
+ start = time.time()
1339
+ ms_v4_name = pathlib.Path(in_file).name.replace(".ms", "") + "_" + str(ms_v4_id)
1340
+ ms_xdt.ds = xds
1292
1341
 
1293
- start = time.time()
1294
- if storage_backend == "zarr":
1295
- xds.to_zarr(store=os.path.join(file_name, "correlated_xds"), mode=mode)
1296
- ant_xds.to_zarr(store=os.path.join(file_name, "antenna_xds"), mode=mode)
1297
- for group_name in xds.attrs["data_groups"]:
1298
- field_and_source_xds.to_zarr(
1299
- store=os.path.join(
1300
- file_name, f"field_and_source_xds_{group_name}"
1301
- ),
1302
- mode=mode,
1303
- )
1342
+ ms_xdt["/antenna_xds"] = ant_xds
1343
+ for group_name in xds.attrs["data_groups"]:
1344
+ ms_xdt["/" + f"field_and_source_xds_{group_name}"] = field_and_source_xds
1304
1345
 
1305
- if with_pointing and len(pointing_xds.data_vars) > 0:
1306
- pointing_xds.to_zarr(
1307
- store=os.path.join(file_name, "pointing_xds"), mode=mode
1308
- )
1346
+ if with_pointing and len(pointing_xds.data_vars) > 0:
1347
+ ms_xdt["/pointing_xds"] = pointing_xds
1309
1348
 
1310
- if system_calibration_xds:
1311
- system_calibration_xds.to_zarr(
1312
- store=os.path.join(file_name, "system_calibration_xds"),
1313
- mode=mode,
1314
- )
1349
+ if system_calibration_xds:
1350
+ ms_xdt["/system_calibration_xds"] = system_calibration_xds
1315
1351
 
1316
- if gain_curve_xds:
1317
- gain_curve_xds.to_zarr(
1318
- store=os.path.join(file_name, "gain_curve_xds"), mode=mode
1319
- )
1352
+ if gain_curve_xds:
1353
+ ms_xdt["/gain_curve_xds"] = gain_curve_xds
1320
1354
 
1321
- if phase_calibration_xds:
1322
- phase_calibration_xds.to_zarr(
1323
- store=os.path.join(file_name, "phase_calibration_xds"),
1324
- mode=mode,
1325
- )
1355
+ if phase_calibration_xds:
1356
+ ms_xdt["/phase_calibration_xds"] = phase_calibration_xds
1326
1357
 
1327
- if weather_xds:
1328
- weather_xds.to_zarr(
1329
- store=os.path.join(file_name, "weather_xds"), mode=mode
1330
- )
1358
+ if weather_xds:
1359
+ ms_xdt["/weather_xds"] = weather_xds
1331
1360
 
1332
- elif storage_backend == "netcdf":
1333
- # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work
1334
- raise
1335
- logger.debug("Write data " + str(time.time() - start))
1361
+ if storage_backend == "zarr":
1362
+ ms_xdt.to_zarr(store=os.path.join(out_file, ms_v4_name))
1363
+ elif storage_backend == "netcdf":
1364
+ # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work
1365
+ raise
1366
+ logger.debug("Write data " + str(time.time() - start))
1336
1367
 
1337
1368
  # logger.info("Saved ms_v4 " + file_name + " in " + str(time.time() - start_with) + "s")
1338
1369
 
@@ -1437,6 +1468,8 @@ def add_group_to_data_groups(
1437
1468
  "correlated_data": correlated_data_name,
1438
1469
  "flag": "FLAG",
1439
1470
  "weight": "WEIGHT",
1471
+ "description": f"Data group derived from the data column '{correlated_data_name}' of an MSv2 converted to MSv4",
1472
+ "date": datetime.datetime.now(datetime.timezone.utc).isoformat(),
1440
1473
  }
1441
1474
  if uvw:
1442
1475
  data_groups[what_group]["uvw"] = "UVW"