xradio 0.0.31__py3-none-any.whl → 0.0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. xradio/_utils/list_and_array.py +5 -3
  2. xradio/vis/__init__.py +3 -5
  3. xradio/vis/_processing_set.py +3 -3
  4. xradio/vis/_vis_utils/_ms/_tables/load_main_table.py +4 -4
  5. xradio/vis/_vis_utils/_ms/_tables/read.py +57 -41
  6. xradio/vis/_vis_utils/_ms/_tables/read_main_table.py +17 -18
  7. xradio/vis/_vis_utils/_ms/_tables/read_subtables.py +5 -5
  8. xradio/vis/_vis_utils/_ms/_tables/write.py +2 -4
  9. xradio/vis/_vis_utils/_ms/_tables/write_exp_api.py +19 -13
  10. xradio/vis/_vis_utils/_ms/chunks.py +5 -72
  11. xradio/vis/_vis_utils/_ms/conversion.py +238 -55
  12. xradio/vis/_vis_utils/_ms/{_tables/create_field_and_source_xds.py → create_field_and_source_xds.py} +114 -85
  13. xradio/vis/_vis_utils/_ms/descr.py +8 -8
  14. xradio/vis/_vis_utils/_ms/msv4_sub_xdss.py +249 -77
  15. xradio/vis/_vis_utils/_ms/partition_queries.py +19 -185
  16. xradio/vis/_vis_utils/_ms/partitions.py +18 -22
  17. xradio/vis/_vis_utils/_ms/subtables.py +2 -2
  18. xradio/vis/_vis_utils/_utils/partition_attrs.py +2 -2
  19. xradio/vis/_vis_utils/_utils/xds_helper.py +12 -12
  20. xradio/vis/_vis_utils/ms.py +1 -43
  21. xradio/vis/_vis_utils/zarr.py +0 -1
  22. xradio/vis/convert_msv2_to_processing_set.py +8 -1
  23. xradio/vis/load_processing_set.py +0 -3
  24. xradio/vis/read_processing_set.py +2 -2
  25. {xradio-0.0.31.dist-info → xradio-0.0.34.dist-info}/METADATA +1 -1
  26. {xradio-0.0.31.dist-info → xradio-0.0.34.dist-info}/RECORD +29 -31
  27. {xradio-0.0.31.dist-info → xradio-0.0.34.dist-info}/WHEEL +1 -1
  28. xradio/vis/_vis_utils/ms_column_descriptions_dicts.py +0 -1360
  29. xradio/vis/vis_io.py +0 -146
  30. {xradio-0.0.31.dist-info → xradio-0.0.34.dist-info}/LICENSE.txt +0 -0
  31. {xradio-0.0.31.dist-info → xradio-0.0.34.dist-info}/top_level.txt +0 -0
xradio/_utils/list_and_array.py CHANGED
@@ -25,9 +25,9 @@ def check_if_consistent(array: np.ndarray, array_name: str) -> np.ndarray:
 
     Parameters
     ----------
-    col : _type_
+    array : _type_
         _description_
-    col_name : _type_
+    array_name : _type_
         _description_
 
     Returns
@@ -68,7 +68,9 @@ def unique_1d(array: np.ndarray) -> np.ndarray:
     if array.ndim == 0:
         return np.array([array.item()])
 
-    return np.sort(pd.unique(array))
+    return np.sort(
+        pd.unique(array)
+    )  # Don't remove the sort! It will cause errors that are very difficult to detect. Specifically create_field_info_and_check_ephemeris has a TaQL query that requires this.
 
 
 def pairing_function(antenna_pairs: np.ndarray) -> np.ndarray:
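Editor's note on the `unique_1d` hunk above: `pd.unique` returns values in order of first appearance, while downstream code (per the in-line comment, a TaQL query in `create_field_info_and_check_ephemeris`) requires them sorted, hence the explicit `np.sort`. A minimal standalone illustration (not package code):

```python
import numpy as np
import pandas as pd

arr = np.array([3, 1, 3, 2, 1])

print(pd.unique(arr))           # [3 1 2] -- first-appearance order
print(np.sort(pd.unique(arr)))  # [1 2 3] -- the order unique_1d guarantees
print(np.unique(arr))           # [1 2 3] -- same result, but sorts the whole array
```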
xradio/vis/__init__.py CHANGED
@@ -2,14 +2,12 @@ from .read_processing_set import read_processing_set
 from .load_processing_set import load_processing_set
 from .convert_msv2_to_processing_set import convert_msv2_to_processing_set
 
-from .vis_io import read_vis, load_vis_block, write_vis
-
 from .schema import VisibilityXds
 
 __all__ = [
-    "read_vis",
-    "load_vis_block",
-    "write_vis",
+    "read_processing_set",
+    "load_processing_set",
+    "convert_msv2_to_processing_set",
     "VisibilityXds",
     "PointingXds",
     "AntennaXds",
xradio/vis/_processing_set.py CHANGED
@@ -33,7 +33,7 @@ class processing_set(dict):
             "obs_mode": [],
             "shape": [],
             "polarization": [],
-            "spw_id": [],
+            "spw_name": [],
             # "field_id": [],
             "field_name": [],
             # "source_id": [],
@@ -48,8 +48,8 @@ class processing_set(dict):
         for key, value in self.items():
             summary_data["name"].append(key)
             summary_data["obs_mode"].append(value.attrs["partition_info"]["obs_mode"])
-            summary_data["spw_id"].append(
-                value.attrs["partition_info"]["spectral_window_id"]
+            summary_data["spw_name"].append(
+                value.attrs["partition_info"]["spectral_window_name"]
             )
             summary_data["polarization"].append(value.polarization.values)
 
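Callers of `processing_set.summary()` see this rename directly: the summary table now carries a `spw_name` column (fed from `partition_info["spectral_window_name"]`) where 0.0.31 carried `spw_id`. Illustrative only, assuming `summary()` still builds a pandas DataFrame from `summary_data`:

```python
df = ps.summary()  # ps as opened in the sketch above
print(df[["name", "obs_mode", "spw_name", "polarization"]])  # "spw_id" is gone
```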
xradio/vis/_vis_utils/_ms/_tables/load_main_table.py CHANGED
@@ -187,7 +187,7 @@ def load_ddi_cols_chunk(
         cell_shape = cdata.shape
         if len(cell_shape) == 0:
             col_dims = dims[:2]
-            mvars[col.lower()] = xr.DataArray(
+            mvars[col] = xr.DataArray(
                 load_col_chunk(
                     tb_tool, col, (ctlen, cblen), tidxs, bidxs, didxs, None, None
                 ),
@@ -196,7 +196,7 @@ def load_ddi_cols_chunk(
 
         elif col == "UVW":
             col_dims = dims[:2] + ["uvw_coords"]
-            mvars[col.lower()] = xr.DataArray(
+            mvars[col] = xr.DataArray(
                 load_col_chunk(
                     tb_tool, col, (ctlen, cblen, 3), tidxs, bidxs, didxs, None, None
                 ),
@@ -206,7 +206,7 @@ def load_ddi_cols_chunk(
         elif len(cell_shape) == 1:
             pols, col_dims = get_col_1d_pols(cell_shape, dims, chan_cnt, pol_cnt, chunk)
             cshape = (ctlen, cblen) + (pols[1] - pols[0] + 1,)
-            mvars[col.lower()] = xr.DataArray(
+            mvars[col] = xr.DataArray(
                 load_col_chunk(tb_tool, col, cshape, tidxs, bidxs, didxs, pols, None),
                 dims=col_dims,
             )
@@ -215,7 +215,7 @@ def load_ddi_cols_chunk(
             chans, pols = get_col_2d_chans_pols(cell_shape, chan_cnt, pol_cnt, chunk)
             cshape = (ctlen, cblen) + (chans[1] - chans[0] + 1, pols[1] - pols[0] + 1)
             col_dims = dims
-            mvars[col.lower()] = xr.DataArray(
+            mvars[col] = xr.DataArray(
                 load_col_chunk(tb_tool, col, cshape, tidxs, bidxs, didxs, chans, pols),
                 dims=col_dims,
             )
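All four hunks above are one change: loaded data variables now keep their uppercase MSv2 column names rather than being lowercased. Downstream indexing must follow suit, e.g. (the `chunk_xds` variable is hypothetical; `UVW` is taken from the hunk):

```python
# 0.0.31: column names were lowercased on load
uvw = chunk_xds["uvw"]

# 0.0.34: variables keep the MSv2 casing
uvw = chunk_xds["UVW"]
```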
xradio/vis/_vis_utils/_ms/_tables/read.py CHANGED
@@ -279,7 +279,10 @@ def add_units_measures(
     col_descrs = cc_attrs["column_descriptions"]
     # TODO: Should probably loop the other way around, over mvars
     for col in col_descrs:
-        var_name = col.lower()
+        if col == "TIME":
+            var_name = "time"
+        else:
+            var_name = col
         if var_name in mvars and "keywords" in col_descrs[col]:
             if "QuantumUnits" in col_descrs[col]["keywords"]:
                 cc_units = col_descrs[col]["keywords"]["QuantumUnits"]
@@ -364,12 +367,12 @@ def make_freq_attrs(spw_xds: xr.Dataset, spw_id: int) -> Dict[str, Any]:
     ctds_cols = spw_xds.attrs["other"]["msv2"]["ctds_attrs"]["column_descriptions"]
     cfreq = ctds_cols["CHAN_FREQ"]
 
-    cf_attrs = spw_xds.chan_freq.attrs
+    cf_attrs = spw_xds.data_vars["CHAN_FREQ"].attrs
     if "MEASINFO" in cfreq["keywords"] and "VarRefCol" in cfreq["keywords"]["MEASINFO"]:
         fattrs = cfreq["keywords"]["MEASINFO"]
         var_ref_col = fattrs["VarRefCol"]
         # This should point to the SPW/MEAS_FREQ_REF col
-        meas_freq_ref_idx = spw_xds.data_vars[var_ref_col.lower()].values[spw_id]
+        meas_freq_ref_idx = spw_xds.data_vars[var_ref_col].values[spw_id]
 
         if "TabRefCodes" not in fattrs or "TabRefTypes" not in fattrs:
             # Datasets like vla/ic2233_1.ms say "VarRefCol" but "TabRefTypes" is missing
@@ -384,8 +387,8 @@ def make_freq_attrs(spw_xds: xr.Dataset, spw_id: int) -> Dict[str, Any]:
         }
 
     # Also set the 'VarRefCol' for CHAN_FREQ and REF_FREQUENCEY
-    spw_xds.data_vars["chan_freq"].attrs.update(cf_attrs)
-    spw_xds.data_vars["ref_frequency"].attrs.update(cf_attrs)
+    spw_xds.data_vars["CHAN_FREQ"].attrs.update(cf_attrs)
+    spw_xds.data_vars["REF_FREQUENCY"].attrs.update(cf_attrs)
 
     return cf_attrs
 
@@ -440,18 +443,18 @@ def redimension_ms_subtable(xds: xr.Dataset, subt_name: str) -> xr.Dataset:
     (one dimension for every columns)
     """
     subt_key_cols = {
-        "DOPPLER": ["doppler_id", "source_id"],
+        "DOPPLER": ["DOPPLER_ID", "SOURCE_ID"],
         "FREQ_OFFSET": [
-            "antenna1",
-            "antenna2",
-            "feed_id",
-            "spectral_window_id",
-            "time",
+            "ANTENNA1",
+            "ANTENNA2",
+            "FEED_ID",
+            "SPECTRAL_WINDOW_ID",
+            "TIME",
         ],
-        "POINTING": ["time", "antenna_id"],
-        "SOURCE": ["source_id", "time", "spectral_window_id"],
-        "SYSCAL": ["antenna_id", "feed_id", "spectral_window_id", "time"],
-        "WEATHER": ["antenna_id", "time"],
+        "POINTING": ["TIME", "ANTENNA_ID"],
+        "SOURCE": ["SOURCE_ID", "TIME", "SPECTRAL_WINDOW_ID"],
+        "SYSCAL": ["ANTENNA_ID", "FEED_ID", "SPECTRAL_WINDOW_ID", "TIME"],
+        "WEATHER": ["ANTENNA_ID", "TIME"],
         # added tables (MSv3 but not preent in MSv2). Build it from "EPHEMi_... tables
         # Not clear what to do about 'time' var/dim: , "time"],
         "EPHEMERIDES": ["ephemeris_row_id", "ephemeris_id"],
@@ -476,10 +479,13 @@ def redimension_ms_subtable(xds: xr.Dataset, subt_name: str) -> xr.Dataset:
         # we need to reset to the original type.
         for var in rxds.data_vars:
             if rxds[var].dtype != xds[var].dtype:
-                rxds[var] = rxds[var].astype(xds[var].dtype)
+                # beware of gaps/empty==nan values when redimensioning
+                with np.errstate(invalid="ignore"):
+                    rxds[var] = rxds[var].astype(xds[var].dtype)
     except Exception as exc:
         logger.warning(
-            f"Cannot expand rows to {key_dims}, possibly duplicate values in those coordinates. Exception: {exc}"
+            f"Cannot expand rows in table {subt_name} to {key_dims}, possibly duplicate values in those coordinates. "
+            f"Exception: {exc}"
         )
         rxds = xds.copy()
 
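The `np.errstate(invalid="ignore")` guard exists because `unstack` fills gaps with NaN, and casting a NaN-bearing float array back to the original integer dtype triggers NumPy's "invalid value encountered in cast" warning. A minimal illustration:

```python
import numpy as np

gappy = np.array([1.0, np.nan, 3.0])  # NaN marks a missing row after unstack

with np.errstate(invalid="ignore"):
    as_int = gappy.astype(np.int32)  # NaN maps to an unspecified integer

print(as_int)
```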
@@ -500,9 +506,9 @@ def add_ephemeris_vars(tname: str, xds: xr.Dataset) -> xr.Dataset:
     ephem_id = 0
 
     xds["ephemeris_id"] = np.uint32(ephem_id) * xr.ones_like(
-        xds["mjd"], dtype=np.uint32
+        xds["MJD"], dtype=np.uint32
     )
-    xds = xds.rename({"mjd": "time"})
+    xds = xds.rename({"MJD": "time"})
     xds["ephemeris_row_id"] = (
         xr.zeros_like(xds["time"], dtype=np.uint32) + xds["row"].values
     )
@@ -529,7 +535,7 @@ def is_nested_ms(attrs: Dict) -> bool:
     )
 
 
-def read_generic_table(
+def load_generic_table(
     inpath: str,
     tname: str,
     timecols: Union[List[str], None] = None,
@@ -574,7 +580,7 @@ def read_generic_table(
     infile = str(infile.expanduser())
     if not os.path.isdir(infile):
         raise ValueError(
-            f"invalid input filename to read_generic_table: {infile} table {tname}"
+            f"invalid input filename to load_generic_table: {infile} table {tname}"
         )
 
     cc_attrs = extract_table_attributes(infile)
@@ -632,7 +638,14 @@ def read_generic_table(
             )
         )
 
-    if tname in ["DOPPLER", "FREQ_OFFSET", "POINTING", "SOURCE", "SYSCAL", "WEATHER"]:
+    if tname in [
+        "DOPPLER",
+        "FREQ_OFFSET",
+        "POINTING",
+        "SOURCE",
+        "SYSCAL",
+        "WEATHER",
+    ]:
         xds = redimension_ms_subtable(xds, tname)
 
     if is_ephem_subtable(tname):
@@ -788,9 +801,9 @@ def load_generic_cols(
             inpath, tb_tool, col, data, timecols
         )
         if array_type == "coord":
-            mcoords[col.lower()] = array_data
+            mcoords[col] = array_data
         elif array_type == "data_var":
-            mvars[col.lower()] = array_data
+            mvars[col] = array_data
 
     return mcoords, mvars
 
@@ -849,9 +862,9 @@ def load_fixed_size_cols(
             inpath, tb_tool, col, data, timecols
         )
         if array_type == "coord":
-            mcoords[col.lower()] = array_data
+            mcoords[col] = array_data
         elif array_type == "data_var":
-            mvars[col.lower()] = array_data
+            mvars[col] = array_data
 
     return mcoords, mvars
 
@@ -924,7 +937,6 @@ def raw_col_data_to_coords_vars(
     # Almost sure that when TIME is present (in a standard MS subt) it
     # is part of the key. But what about non-std subtables, ASDM subts?
     subts_with_time_key = (
-        "FEED",
         "FLAG_CMD",
         "FREQ_OFFSET",
         "HISTORY",
@@ -1176,6 +1188,7 @@ def read_col_conversion(
     cshape: Tuple[int],
     tidxs: np.ndarray,
     bidxs: np.ndarray,
+    use_table_iter: bool,
 ) -> np.ndarray:
     """
     Function to perform delayed reads from table columns when converting
@@ -1233,23 +1246,26 @@ def read_col_conversion(
     data = np.full(cshape + extra_dimensions, np.nan, dtype=col_dtype)
 
     # Use built-in casacore table iterator to populate the data column by unique times.
-    start_row = 0
-    for ts in tb_tool.iter("TIME", sort=False):
-        num_rows = ts.nrows()
+    if use_table_iter:
+        start_row = 0
+        for ts in tb_tool.iter("TIME", sort=False):
+            num_rows = ts.nrows()
 
-        # Create small temporary array to store the partial column
-        tmp_arr = np.full((num_rows,) + extra_dimensions, np.nan, dtype=col_dtype)
+            # Create small temporary array to store the partial column
+            tmp_arr = np.full((num_rows,) + extra_dimensions, np.nan, dtype=col_dtype)
 
-        # Note we don't use `getcol()` because it's less safe. See:
-        # https://github.com/casacore/python-casacore/issues/130#issuecomment-463202373
-        ts.getcolnp(col, tmp_arr)
+            # Note we don't use `getcol()` because it's less safe. See:
+            # https://github.com/casacore/python-casacore/issues/130#issuecomment-463202373
+            ts.getcolnp(col, tmp_arr)
 
-        # Get the slice of rows contained in `tmp_arr`.
-        # Used to get the relevant integer indexes from `tidxs` and `bidxs`
-        tmp_slice = slice(start_row, start_row + num_rows)
+            # Get the slice of rows contained in `tmp_arr`.
+            # Used to get the relevant integer indexes from `tidxs` and `bidxs`
+            tmp_slice = slice(start_row, start_row + num_rows)
 
-        # Copy `tmp_arr` into correct elements of `tmp_arr`
-        data[tidxs[tmp_slice], bidxs[tmp_slice]] = tmp_arr
-        start_row += num_rows
+            # Copy `tmp_arr` into correct elements of `tmp_arr`
+            data[tidxs[tmp_slice], bidxs[tmp_slice]] = tmp_arr
+            start_row += num_rows
+    else:
+        data[tidxs, bidxs] = tb_tool.getcol(col)
 
     return data
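The new `use_table_iter` flag chooses between the batched TIME-iterator path (preallocated buffers filled via `getcolnp`) and a single whole-column `tb_tool.getcol(col)` read scattered through the `tidxs`/`bidxs` index arrays. A schematic of that fallback branch in plain NumPy, with the casacore call replaced by a stub array:

```python
import numpy as np

col_values = np.array([10.0, 11.0, 12.0, 13.0])  # stand-in for tb_tool.getcol(col)

# tidxs/bidxs map each MS row to its (time, baseline) cell
tidxs = np.array([0, 0, 1, 1])
bidxs = np.array([0, 1, 0, 1])

# Preallocated (time, baseline) output; NaN marks rows absent from the MS
data = np.full((2, 2), np.nan)

# The use_table_iter=False branch is a single fancy-indexed scatter
data[tidxs, bidxs] = col_values
print(data)
```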
xradio/vis/_vis_utils/_ms/_tables/read_main_table.py CHANGED
@@ -24,16 +24,16 @@ from xradio._utils.list_and_array import (
 )
 
 rename_msv2_cols = {
-    "antenna1": "antenna1_id",
-    "antenna2": "antenna2_id",
-    "feed1": "feed1_id",
-    "feed2": "feed2_id",
+    "ANTENNA1": "antenna1_id",
+    "ANTENNA2": "antenna2_id",
+    "FEED1": "feed1_id",
+    "FEED2": "feed2_id",
     # optional cols:
-    "weight_spectrum": "weight",
-    "corrected_data": "vis_corrected",
-    "data": "vis",
-    "model_data": "vis_model",
-    "float_data": "autocorr",
+    "WEIGHT_SPECTRUM": "WEIGHT",
+    "CORRECTED_DATA": "VIS_CORRECTED",
+    "DATA": "VIS",
+    "MODEL_DATA": "VIS_MODEL",
+    "FLOAT_DATA": "AUTOCORR",
 }
 
 
@@ -83,11 +83,11 @@ def redim_id_data_vars(mvars: Dict[str, xr.DataArray]) -> Dict[str, xr.DataArray
     """
     # Vars to drop baseline dim
     var_names = [
-        "array_id",
-        "observation_id",
-        "processor_id",
-        "scan_number",
-        "state_id",
+        "ARRAY_ID",
+        "OBSERVATION_ID",
+        "PROCESSOR_ID",
+        "SCAN_NUMBER",
+        "STATE_ID",
     ]
     for vname in var_names:
         if "baseline" in mvars[vname].coords:
@@ -566,7 +566,7 @@ def concat_tvars_to_mvars(
 
     mvars = {}
     for tvr in tvars.keys():
-        data_var = tvr.lower()
+        data_var = tvr
         if tvr == "UVW":
             mvars[data_var] = xr.DataArray(
                 dask.array.concatenate(tvars[tvr], axis=0),
@@ -742,8 +742,7 @@ def read_flat_main_table(
     # now concat all the dask chunks from each time to make the xds
     mvars = {}
     for kk in bvars.keys():
-        # from uppercase MS col names to lowercase xds var names:
-        data_var = kk.lower()
+        data_var = kk
         if len(bvars[kk]) == 0:
             ignore += [kk]
             continue
@@ -766,7 +765,7 @@ def read_flat_main_table(
     )
 
     mvars["time"] = xr.DataArray(
-        convert_casacore_time(mvars["time"].values), dims=["row"]
+        convert_casacore_time(mvars["TIME"].values), dims=["row"]
     ).chunk({"row": chunks[0]})
 
     # add xds global attributes
xradio/vis/_vis_utils/_ms/_tables/read_subtables.py CHANGED
@@ -16,7 +16,7 @@ from .read import (
     extract_table_attributes,
     add_units_measures,
     table_exists,
-    read_generic_table,
+    load_generic_table,
 )
 from .write import revert_time
 from xradio._utils.list_and_array import unique_1d
@@ -49,7 +49,7 @@ def read_ephemerides(
         logger.debug(f"Reading ephemerides info from: FIELD / {sdir.name}")
         # One "EPHEM_*.tab" (each with a difference ephemeris_id) to concatenate
         ephem.append(
-            read_generic_table(infile, str(Path(*sdir.parts[-2:])), timecols=["MJD"])
+            load_generic_table(infile, str(Path(*sdir.parts[-2:])), timecols=["MJD"])
         )
 
     if ephem:
@@ -339,7 +339,7 @@ def read_delayed_pointing_chunks(
         ):
             continue
         if col not in bvars:
-            bvars[col.lower()] = []
+            bvars[col] = []
 
         cdata = tb_tool.getcol(col, 0, 1)[0]
         if isinstance(cdata, str):
@@ -356,7 +356,7 @@ def read_delayed_pointing_chunks(
                 None,
                 None,
             )
-            bvars[col.lower()] += [
+            bvars[col] += [
                 dask.array.from_delayed(
                     delayed_array, (ctlen, cblen), cdata.dtype
                 )
@@ -390,6 +390,6 @@ def read_delayed_pointing_chunks(
                 )
             ]
             d1_list += [dask.array.concatenate(d2_list, axis=3)]
-        bvars[col.lower()] += [dask.array.concatenate(d1_list, axis=2)]
+        bvars[col] += [dask.array.concatenate(d1_list, axis=2)]
 
     return bvars
xradio/vis/_vis_utils/_ms/_tables/write.py CHANGED
@@ -6,8 +6,6 @@ import xarray as xr
 
 from casacore import tables
 
-from ..msv2_msv3 import ignore_msv2_cols
-
 
 def revert_time(datetimes: np.ndarray) -> np.ndarray:
     """
@@ -72,9 +70,9 @@ def create_table(
 
     if cols is None:
         if ctds_attrs and "column_descriptions" in ctds_attrs:
-            cols = {col: col.lower() for col in ctds_attrs["column_descriptions"]}
+            cols = {col: col for col in ctds_attrs["column_descriptions"]}
         else:
-            cols = {var.upper(): var for var in xds.data_vars}
+            cols = {var: var for var in xds.data_vars}
             # Would add all xds data vars regardless of description availability
             # +
             # list(xds.data_vars) +
xradio/vis/_vis_utils/_ms/_tables/write_exp_api.py CHANGED
@@ -15,16 +15,16 @@ from casacore import tables
 # TODO: this should be consolidated with the equivalent in read_main_table,
 # if we keep this mapping
 rename_to_msv2_cols = {
-    "antenna1_id": "antenna1",
-    "antenna2_id": "antenna2",
-    "feed1_id": "feed1",
-    "feed2_id": "feed2",
+    "antenna1_id": "ANTENNA1",
+    "antenna2_id": "ANTENNA2",
+    "feed1_id": "FEED1",
+    "feed2_id": "FEED2",
     # optional cols:
-    # "weight": "weight_spectrum",
-    "vis_corrected": "corrected_data",
-    "vis": "data",
-    "vis_model": "model_data",
-    "autocorr": "float_data",
+    # "WEIGHT": "WEIGHT_SPECTRUM",
+    "VIS_CORRECTED": "CORRECTED_DATA",
+    "VIS": "DATA",
+    "VIS_MODEL": "MODEL_DATA",
+    "AUTOCORR": "FLOAT_DATA",
 }
 # cols added in xds not in MSv2
 cols_not_in_msv2 = ["baseline_ant1_id", "baseline_ant2_id"]
@@ -214,9 +214,15 @@ def write_ms(
                 continue
 
             col_chunk_size = np.prod([kk[0] for kk in txds[col].chunks])
-            col_rows = (
-                int(np.ceil(max_chunk_size / col_chunk_size)) * txds[col].chunks[0][0]
-            )
+            if max_chunk_size <= 0:
+                max_chunk_size = 19200
+            if col_chunk_size <= 0:
+                col_rows = max_chunk_size
+            else:
+                col_rows = (
+                    int(np.ceil(max_chunk_size / col_chunk_size))
+                    * txds[col].chunks[0][0]
+                )
             for rr in range(0, txds[col].row.shape[0], col_rows):
                 txda = txds[col].isel(row=slice(rr, rr + col_rows))
                 delayed_writes += [
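The added guards prevent a division by zero when `col_chunk_size` evaluates to zero and substitute a default budget when `max_chunk_size` is not positive (the 19200 value comes from the hunk itself). The row-batch arithmetic, restated as a standalone function:

```python
import numpy as np

def rows_per_write(max_chunk_size: int, col_chunk_size: int, chunk_rows: int) -> int:
    if max_chunk_size <= 0:
        max_chunk_size = 19200  # default taken from the diff above
    if col_chunk_size <= 0:
        return max_chunk_size  # degenerate chunking: write fixed-size row batches
    # Scale the per-chunk row count so each delayed write stays within budget
    return int(np.ceil(max_chunk_size / col_chunk_size)) * chunk_rows

print(rows_per_write(100_000, 2_500, 128))  # ceil(40.0) * 128 = 5120
```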
@@ -338,7 +344,7 @@ def write_ms_serial(
         print(f"Exception writing subtable {subtable}: {exc}")
 
     part_key0 = next(iter(mxds.partitions))
-    vis_data_shape = mxds.partitions[part_key0].vis.shape
+    vis_data_shape = mxds.partitions[part_key0].VIS.shape
     rows_chunk_size = calc_optimal_ms_chunk_shape(
         memory_available_in_bytes, vis_data_shape, 16, "DATA"
     )
xradio/vis/_vis_utils/_ms/chunks.py CHANGED
@@ -2,18 +2,11 @@ from pathlib import Path
 from typing import Dict, Tuple
 
 import xarray as xr
-import numpy as np
 
 
-from .msv2_msv3 import ignore_msv2_cols
-from .partition_queries import make_partition_ids_by_ddi_scan
 from .subtables import subt_rename_ids
-from ._tables.load_main_table import load_expanded_main_table_chunk
-from ._tables.read import read_generic_table, make_freq_attrs
+from ._tables.read import load_generic_table
 from ._tables.read_subtables import read_delayed_pointing_table
-from .._utils.partition_attrs import add_partition_attrs
-from .._utils.xds_helper import make_coords
-from xradio._utils.list_and_array import unique_1d
 
 
 def read_spw_ddi_ant_pol(inpath: str) -> Tuple[xr.Dataset]:
@@ -30,81 +23,21 @@ def read_spw_ddi_ant_pol(inpath: str) -> Tuple[xr.Dataset]:
     Tuple[xr.Dataset]
         tuple with antenna, ddi, spw, and polarization setup subtables info
     """
-    spw_xds = read_generic_table(
+    spw_xds = load_generic_table(
         inpath,
         "SPECTRAL_WINDOW",
         rename_ids=subt_rename_ids["SPECTRAL_WINDOW"],
     )
-    ddi_xds = read_generic_table(inpath, "DATA_DESCRIPTION")
-    ant_xds = read_generic_table(
+    ddi_xds = load_generic_table(inpath, "DATA_DESCRIPTION")
+    ant_xds = load_generic_table(
         inpath, "ANTENNA", rename_ids=subt_rename_ids["ANTENNA"]
     )
-    pol_xds = read_generic_table(
+    pol_xds = load_generic_table(
         inpath, "POLARIZATION", rename_ids=subt_rename_ids["POLARIZATION"]
     )
     return ant_xds, ddi_xds, spw_xds, pol_xds
 
 
-def load_main_chunk(
-    infile: str, chunk: Dict[str, slice]
-) -> Dict[Tuple[int, int], xr.Dataset]:
-    """
-    Loads a chunk of visibility data. For every DDI, a separate
-    dataset is produced.
-    This is very loosely equivalent to the
-    partitions.read_*_partitions functions, but in a load (not lazy)
-    fashion and with an implicit single partition wrt. anything but
-    DDIs.
-    Metainfo (sub)tables) are not loaded, and the result is one or more
-    Xarray datasets. It produces one dataset per DDI found within the
-    chunk slice of time/baseline.
-
-    Parameters
-    ----------
-    infile : str
-        MS path (main table)
-    chunk : Dict[str, slice]
-        specification of chunk to load
-
-    Returns
-    -------
-    Dict[Tuple[int, int], xr.Dataset]
-        dictionary of chunk datasets (keys are spw and pol_setup IDs)
-    """
-
-    chunk_dims = ["time", "baseline", "freq", "pol"]
-    if not all(key in chunk_dims for key in chunk):
-        raise ValueError(f"chunks dict has unknown keys. Accepted ones: {chunk_dims}")
-
-    ant_xds, ddi_xds, spw_xds, pol_xds = read_spw_ddi_ant_pol(infile)
-
-    # TODO: constrain this better/ properly
-    data_desc_id, scan_number, state_id = make_partition_ids_by_ddi_scan(infile, False)
-
-    all_xdss = {}
-    data_desc_id = unique_1d(data_desc_id)
-    for ddi in data_desc_id:
-        xds, part_ids, attrs = load_expanded_main_table_chunk(
-            infile, ddi, chunk, ignore_msv2_cols=ignore_msv2_cols
-        )
-
-        coords = make_coords(xds, ddi, (ant_xds, ddi_xds, spw_xds, pol_xds))
-        xds = xds.assign_coords(coords)
-        xds = add_partition_attrs(xds, ddi, ddi_xds, part_ids, other_attrs={})
-
-        # freq dim needs to pull its units/measure info from the SPW subtable
-        spw_id = xds.attrs["partition_ids"]["spw_id"]
-        xds.freq.attrs.update(make_freq_attrs(spw_xds, spw_id))
-        pol_setup_id = ddi_xds.polarization_id.values[ddi]
-
-        chunk_ddi_key = (spw_id, pol_setup_id)
-        all_xdss[chunk_ddi_key] = xds
-
-    chunk_xdss = finalize_chunks(infile, all_xdss, chunk)
-
-    return chunk_xdss
-
-
 def finalize_chunks(
     infile: str, chunks: Dict[str, xr.Dataset], chunk_spec: Dict[str, slice]
 ) -> Dict[Tuple[int, int], xr.Dataset]: