xradio 0.0.55__py3-none-any.whl → 0.0.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. xradio/__init__.py +2 -2
  2. xradio/_utils/_casacore/casacore_from_casatools.py +1001 -0
  3. xradio/_utils/_casacore/tables.py +6 -1
  4. xradio/_utils/coord_math.py +22 -23
  5. xradio/_utils/dict_helpers.py +76 -11
  6. xradio/_utils/schema.py +5 -2
  7. xradio/_utils/zarr/common.py +1 -73
  8. xradio/image/_util/_casacore/common.py +11 -3
  9. xradio/image/_util/_casacore/xds_from_casacore.py +59 -35
  10. xradio/image/_util/_casacore/xds_to_casacore.py +47 -16
  11. xradio/image/_util/_fits/xds_from_fits.py +172 -77
  12. xradio/image/_util/casacore.py +9 -4
  13. xradio/image/_util/common.py +4 -4
  14. xradio/image/_util/image_factory.py +8 -8
  15. xradio/image/image.py +45 -5
  16. xradio/measurement_set/__init__.py +19 -9
  17. xradio/measurement_set/_utils/__init__.py +1 -3
  18. xradio/measurement_set/_utils/_msv2/__init__.py +0 -0
  19. xradio/measurement_set/_utils/_msv2/_tables/read.py +35 -90
  20. xradio/measurement_set/_utils/_msv2/_tables/read_main_table.py +6 -686
  21. xradio/measurement_set/_utils/_msv2/_tables/table_query.py +13 -3
  22. xradio/measurement_set/_utils/_msv2/conversion.py +129 -145
  23. xradio/measurement_set/_utils/_msv2/create_antenna_xds.py +9 -16
  24. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +125 -221
  25. xradio/measurement_set/_utils/_msv2/msv2_to_msv4_meta.py +1 -2
  26. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +13 -8
  27. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +27 -72
  28. xradio/measurement_set/_utils/_msv2/partition_queries.py +5 -262
  29. xradio/measurement_set/_utils/_msv2/subtables.py +0 -107
  30. xradio/measurement_set/_utils/_utils/interpolate.py +60 -0
  31. xradio/measurement_set/_utils/_zarr/encoding.py +2 -7
  32. xradio/measurement_set/convert_msv2_to_processing_set.py +0 -2
  33. xradio/measurement_set/load_processing_set.py +2 -2
  34. xradio/measurement_set/measurement_set_xdt.py +14 -14
  35. xradio/measurement_set/open_processing_set.py +1 -3
  36. xradio/measurement_set/processing_set_xdt.py +41 -835
  37. xradio/measurement_set/schema.py +96 -123
  38. xradio/schema/check.py +91 -97
  39. xradio/schema/dataclass.py +159 -22
  40. xradio/schema/export.py +99 -0
  41. xradio/schema/metamodel.py +51 -16
  42. xradio/schema/typing.py +5 -5
  43. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/METADATA +43 -11
  44. xradio-0.0.58.dist-info/RECORD +65 -0
  45. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/WHEEL +1 -1
  46. xradio/image/_util/fits.py +0 -13
  47. xradio/measurement_set/_utils/_msv2/_tables/load.py +0 -63
  48. xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py +0 -487
  49. xradio/measurement_set/_utils/_msv2/_tables/read_subtables.py +0 -395
  50. xradio/measurement_set/_utils/_msv2/_tables/write.py +0 -320
  51. xradio/measurement_set/_utils/_msv2/_tables/write_exp_api.py +0 -385
  52. xradio/measurement_set/_utils/_msv2/chunks.py +0 -115
  53. xradio/measurement_set/_utils/_msv2/descr.py +0 -165
  54. xradio/measurement_set/_utils/_msv2/msv2_msv3.py +0 -7
  55. xradio/measurement_set/_utils/_msv2/partitions.py +0 -392
  56. xradio/measurement_set/_utils/_utils/cds.py +0 -40
  57. xradio/measurement_set/_utils/_utils/xds_helper.py +0 -404
  58. xradio/measurement_set/_utils/_zarr/read.py +0 -263
  59. xradio/measurement_set/_utils/_zarr/write.py +0 -329
  60. xradio/measurement_set/_utils/msv2.py +0 -106
  61. xradio/measurement_set/_utils/zarr.py +0 -133
  62. xradio-0.0.55.dist-info/RECORD +0 -77
  63. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/licenses/LICENSE.txt +0 -0
  64. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/top_level.txt +0 -0

xradio/measurement_set/_utils/_msv2/_tables/read_main_table.py
@@ -1,331 +1,21 @@
-import toolviper.utils.logger as logger
-from typing import Any, Dict, List, Tuple, Union
+from typing import Tuple
 
-import dask, dask.array
 import numpy as np
-import xarray as xr
 import pandas as pd
 
-from casacore import tables
+try:
+    from casacore import tables
+except ImportError:
+    import xradio._utils._casacore.casacore_from_casatools as tables
 
-from .read import (
-    read_flat_col_chunk,
-    read_col_chunk,
-    convert_casacore_time,
-    extract_table_attributes,
-    add_units_measures,
-)
 
-from .table_query import open_table_ro, open_query
+from .table_query import open_query
 from xradio._utils.list_and_array import (
     unique_1d,
     pairing_function,
     inverse_pairing_function,
 )
 
-rename_msv2_cols = {
-    "ANTENNA1": "antenna1_id",
-    "ANTENNA2": "antenna2_id",
-    "FEED1": "feed1_id",
-    "FEED2": "feed2_id",
-    # optional cols:
-    "WEIGHT_SPECTRUM": "WEIGHT",
-    "CORRECTED_DATA": "VIS_CORRECTED",
-    "DATA": "VIS",
-    "MODEL_DATA": "VIS_MODEL",
-    "FLOAT_DATA": "AUTOCORR",
-}
-
-
-def rename_vars(mvars: Dict[str, xr.DataArray]) -> Dict[str, xr.DataArray]:
-    """
-    Apply rename rules. Also preserve ordering of data_vars
-
-    Note: not using xr.DataArray.rename because we have optional
-    column renames and rename complains if some of the names passed
-    are not present in the dataset
-
-    Parameters
-    ----------
-    mvars : Dict[str, xr.DataArray]
-        dictionary of data_vars to be used to create an xr.Dataset
-
-    Returns
-    -------
-    Dict[str, xr.DataArray]
-        similar dictionary after applying MSv2 => MSv3/ngCASA renaming rules
-    """
-    renamed = {
-        rename_msv2_cols[name] if name in rename_msv2_cols else name: var
-        for name, var in mvars.items()
-    }
-
-    return renamed
-
-
-def redim_id_data_vars(mvars: Dict[str, xr.DataArray]) -> Dict[str, xr.DataArray]:
-    """
-    Changes:
-    Several id data variables to drop its baseline dim
-    The antenna id data vars:
-    From MS (antenna1_id(time, baseline), antenna2_id(time,baseline)
-    To cds (baseline_ant1_id(baseline), baseline_ant2_id(baseline)
-
-    Parameters
-    ----------
-    mvars : Dict[str, xr.DataArray]
-        data variables being prepared for a partition xds
-
-    Returns
-    -------
-    Dict[str, xr.DataArray]
-        data variables with the ant id ones modified to cds type
-    """
-    # Vars to drop baseline dim
-    var_names = [
-        "ARRAY_ID",
-        "OBSERVATION_ID",
-        "PROCESSOR_ID",
-        "SCAN_NUMBER",
-        "STATE_ID",
-    ]
-    for vname in var_names:
-        if "baseline" in mvars[vname].coords:
-            mvars[vname] = mvars[vname].sel(baseline=0, drop=True)
-
-    for idx in ["1", "2"]:
-        new_name = f"baseline_ant{idx}_id"
-        mvars[new_name] = mvars.pop(f"antenna{idx}_id")
-        if "time" in mvars[new_name].coords:
-            mvars[new_name] = mvars[new_name].sel(time=0, drop=True)
-
-    return mvars
-
-
-def get_partition_ids(mtable: tables.table, taql_where: str) -> Dict:
-    """
-    Get some of the partition IDs that we have to retrieve from some
-    of the top level ID/sorting cols of the main table of the MS.
-
-    Parameters
-    ----------
-    mtable : tables.table
-        MS main table
-    taql_where : str
-        where part that defines the partition in TaQL
-
-    Returns
-    -------
-    Dict
-        ids of array, observation, and processor
-    """
-
-    taql_ids = f"select DISTINCT ARRAY_ID, OBSERVATION_ID, PROCESSOR_ID from $mtable {taql_where}"
-    with open_query(mtable, taql_ids) as query:
-        # array_id, observation_id, processor_id
-        array_id = unique_1d(query.getcol("ARRAY_ID"))
-        obs_id = unique_1d(query.getcol("OBSERVATION_ID"))
-        proc_id = unique_1d(query.getcol("PROCESSOR_ID"))
-        check_vars = [
-            (array_id, "array_id"),
-            (obs_id, "observation_id"),
-            (proc_id, "processor_id"),
-        ]
-        for var, var_name in check_vars:
-            if len(var) != 1:
-                logger.warning(
-                    f"Did not get exactly one {var_name} (got {var} for this partition. TaQL: {taql_where}"
-                )
-
-    pids = {
-        "array_id": list(array_id),
-        "observation_id": list(obs_id),
-        "processor_id": list(proc_id),
-    }
-    return pids
-
-
-def read_expanded_main_table(
-    infile: str,
-    ddi: int = 0,
-    scan_state: Union[Tuple[int, int], None] = None,
-    ignore_msv2_cols: Union[list, None] = None,
-    chunks: Tuple[int, ...] = (400, 200, 100, 2),
-) -> Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]:
-    """
-    Reads one partition from the main table, all columns.
-    This is the expanded version (time, baseline) dims.
-
-    Chunk tuple: (time, baseline, freq, pol)
-
-    Parameters
-    ----------
-    infile : str
-
-    ddi : int (Default value = 0)
-
-    scan_state : Union[Tuple[int, int], None] (Default value = None)
-
-    ignore_msv2_cols: Union[list, None] (Default value = None)
-
-    chunks: Tuple[int, ...] (Default value = (400, 200, 100, 2))
-
-
-    Returns
-    -------
-    Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]
-    """
-    if ignore_msv2_cols is None:
-        ignore_msv2_cols = []
-
-    taql_where = f"where DATA_DESC_ID = {ddi}"
-    if scan_state:
-        # get partitions by scan/state
-        scan, state = scan_state
-        if type(state) == np.ndarray:
-            state_ids_or = " OR STATE_ID = ".join(np.char.mod("%d", state))
-            taql_where += f" AND (STATE_ID = {state_ids_or})"
-        elif state:
-            taql_where += f" AND SCAN_NUMBER = {scan} AND STATE_ID = {state}"
-        elif scan:
-            # scan can also be None, when partition_scheme='intent'
-            # but the STATE table is empty!
-            taql_where += f" AND SCAN_NUMBER = {scan}"
-
-    with open_table_ro(infile) as mtable:
-        # one partition, select just the specified ddi (+ scan/subscan)
-        taql_main = f"select * from $mtable {taql_where}"
-        with open_query(mtable, taql_main) as tb_tool:
-            if tb_tool.nrows() == 0:
-                tb_tool.close()
-                mtable.close()
-                return xr.Dataset(), {}, {}
-
-            xds, attrs = read_main_table_chunks(
-                infile, tb_tool, taql_where, ignore_msv2_cols, chunks
-            )
-            part_ids = get_partition_ids(tb_tool, taql_where)
-
-    return xds, part_ids, attrs
-
-
-def read_main_table_chunks(
-    infile: str,
-    tb_tool: tables.table,
-    taql_where: str,
-    ignore_msv2_cols: Union[list, None] = None,
-    chunks: Tuple[int, ...] = (400, 200, 100, 2),
-) -> Tuple[xr.Dataset, Dict[str, Any]]:
-    """
-    Iterates through the time,baseline chunks and reads slices from
-    all the data columns.
-
-    Parameters
-    ----------
-    infile : str
-
-    tb_tool : tables.table
-
-    taql_where : str
-
-    ignore_msv2_cols : Union[list, None] (Default value = None)
-
-    chunks: Tuple[int, ...] (Default value = (400, 200, 100, 2))
-
-
-    Returns
-    -------
-    Tuple[xr.Dataset, Dict[str, Any]]
-    """
-    baselines = get_baselines(tb_tool)
-
-    col_names = tb_tool.colnames()
-    cshapes = [
-        np.array(tb_tool.getcell(col, 0)).shape
-        for col in col_names
-        if tb_tool.iscelldefined(col, 0)
-    ]
-    chan_cnt, pol_cnt = [(cc[0], cc[1]) for cc in cshapes if len(cc) == 2][0]
-
-    unique_times, tol = get_utimes_tol(tb_tool, taql_where)
-
-    tvars = {}
-    n_baselines = len(baselines)
-    n_unique_times = len(unique_times)
-    n_time_chunks = chunks[0]
-    n_baseline_chunks = chunks[1]
-    # loop over time chunks
-    for time_chunk in range(0, n_unique_times, n_time_chunks):
-        time_start = unique_times[time_chunk] - tol
-        time_end = (
-            unique_times[min(n_unique_times, time_chunk + n_time_chunks) - 1] + tol
-        )
-
-        # chunk time length
-        ctlen = min(n_unique_times, time_chunk + n_time_chunks) - time_chunk
-
-        bvars = {}
-        # loop over baseline chunks
-        for baseline_chunk in range(0, n_baselines, n_baseline_chunks):
-            cblen = min(n_baselines - baseline_chunk, n_baseline_chunks)
-
-            # read the specified chunk of data
-            # def read_chunk(infile, ddi, times, blines, chans, pols):
-            ttql = f"TIME BETWEEN {time_start} and {time_end}"
-            ant1_start = baselines[baseline_chunk][0]
-            ant1_end = baselines[cblen + baseline_chunk - 1][0]
-            atql = f"ANTENNA1 BETWEEN {ant1_start} and {ant1_end}"
-            ts_taql = f"select * from $mtable {taql_where} AND {ttql} AND {atql}"
-            with open_query(None, ts_taql) as query_times_ants:
-                tidxs = (
-                    np.searchsorted(
-                        unique_times, query_times_ants.getcol("TIME", 0, -1)
-                    )
-                    - time_chunk
-                )
-                ts_ant1, ts_ant2 = (
-                    query_times_ants.getcol("ANTENNA1", 0, -1),
-                    query_times_ants.getcol("ANTENNA2", 0, -1),
-                )
-
-            ts_bases = np.column_stack((ts_ant1, ts_ant2))
-
-            bidxs = get_baseline_indices(baselines, ts_bases) - baseline_chunk
-
-            # some antenna 2's will be out of bounds for this chunk, store rows that are in bounds
-            didxs = np.where(
-                (bidxs >= 0) & (bidxs < min(chunks[1], n_baselines - baseline_chunk))
-            )[0]
-
-            delayed_params = (infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs)
-
-            read_all_cols_bvars(
-                tb_tool, chunks, chan_cnt, ignore_msv2_cols, delayed_params, bvars
-            )
-
-        concat_bvars_update_tvars(bvars, tvars)
-
-    dims = ["time", "baseline", "freq", "pol"]
-    mvars = concat_tvars_to_mvars(dims, tvars, pol_cnt, chan_cnt)
-
-    mcoords = {
-        "time": xr.DataArray(convert_casacore_time(unique_times), dims=["time"]),
-        "baseline": xr.DataArray(np.arange(n_baselines), dims=["baseline"]),
-    }
-
-    # add xds global attributes
-    cc_attrs = extract_table_attributes(infile)
-    attrs = {"other": {"msv2": {"ctds_attrs": cc_attrs, "bad_cols": ignore_msv2_cols}}}
-    # add per data var attributes
-    mvars = add_units_measures(mvars, cc_attrs)
-    mcoords = add_units_measures(mcoords, cc_attrs)
-
-    mvars = rename_vars(mvars)
-    mvars = redim_id_data_vars(mvars)
-    xds = xr.Dataset(mvars, coords=mcoords)
-
-    return xds, attrs
-
 
 def get_utimes_tol(mtable: tables.table, taql_where: str) -> Tuple[np.ndarray, float]:
     taql_utimes = f"select DISTINCT TIME from $mtable {taql_where}"
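
The one functional addition in this hunk is the guarded casacore import: python-casacore becomes an optional dependency, and the new casacore_from_casatools shim (file 2 in the list above) is bound to the same `tables` name when it is missing. Everything else is deletion of the expanded/flat main-table readers. A minimal sketch of how such a fallback keeps downstream code backend-agnostic; the count_rows helper is illustrative, not part of xradio:

try:
    from casacore import tables  # preferred backend: python-casacore
except ImportError:
    # casatools-backed shim exposing the same tables API
    import xradio._utils._casacore.casacore_from_casatools as tables


def count_rows(ms_path: str) -> int:
    # Callers open tables identically with either backend.
    t = tables.table(ms_path, readonly=True, ack=False)
    try:
        return t.nrows()
    finally:
        t.close()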
@@ -410,373 +100,3 @@ def get_baseline_indices(
     baseline_indices = unique_baselines_sorted[sorted_indices]
 
     return baseline_indices
-
-
-def read_all_cols_bvars(
-    tb_tool: tables.table,
-    chunks: Tuple[int, ...],
-    chan_cnt: int,
-    ignore_msv2_cols: bool,
-    delayed_params: Tuple,
-    bvars: Dict[str, xr.DataArray],
-) -> None:
-    """
-    Loops over each column and create delayed dask arrays
-
-    Parameters
-    ----------
-    tb_tool : tables.table
-
-    chunks : Tuple[int, ...]
-
-    chan_cnt : int
-
-    ignore_msv2_cols : bool
-
-    delayed_params : Tuple
-
-    bvars : Dict[str, xr.DataArray]
-
-
-    Returns
-    -------
-
-    """
-
-    col_names = tb_tool.colnames()
-    for col in col_names:
-        if (col in ignore_msv2_cols + ["TIME"]) or (not tb_tool.iscelldefined(col, 0)):
-            continue
-        if col not in bvars:
-            bvars[col] = []
-
-        cdata = tb_tool.getcol(col, 0, 1)[0]
-
-        if len(cdata.shape) == 0:
-            infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs = delayed_params
-            cshape = (ctlen, cblen)
-            delayed_col = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-            delayed_array = dask.delayed(read_col_chunk)(*delayed_col, None, None)
-            bvars[col] += [dask.array.from_delayed(delayed_array, cshape, cdata.dtype)]
-
-        elif col == "UVW":
-            infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs = delayed_params
-            cshape = (ctlen, cblen, 3)
-            delayed_3 = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-            delayed_array = dask.delayed(read_col_chunk)(*delayed_3, None, None)
-            bvars[col] += [dask.array.from_delayed(delayed_array, cshape, cdata.dtype)]
-
-        elif len(cdata.shape) == 1:
-            pol_list = []
-            dd = 2 if cdata.shape == chan_cnt else 3
-            for pc in range(0, cdata.shape[0], chunks[dd]):
-                pols = (pc, min(cdata.shape[0], pc + chunks[dd]) - 1)
-                infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs = delayed_params
-                cshape = (
-                    ctlen,
-                    cblen,
-                ) + (pols[1] - pols[0] + 1,)
-                delayed_cs = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-                delayed_array = dask.delayed(read_col_chunk)(*delayed_cs, pols, None)
-                pol_list += [
-                    dask.array.from_delayed(delayed_array, cshape, cdata.dtype)
-                ]
-            bvars[col] += [dask.array.concatenate(pol_list, axis=2)]
-
-        elif len(cdata.shape) == 2:
-            chan_list = []
-            for cc in range(0, cdata.shape[0], chunks[2]):
-                chans = (cc, min(cdata.shape[0], cc + chunks[2]) - 1)
-                pol_list = []
-                for pc in range(0, cdata.shape[1], chunks[3]):
-                    pols = (pc, min(cdata.shape[1], pc + chunks[3]) - 1)
-                    (
-                        infile,
-                        ts_taql,
-                        (ctlen, cblen),
-                        tidxs,
-                        bidxs,
-                        didxs,
-                    ) = delayed_params
-                    cshape = (
-                        ctlen,
-                        cblen,
-                    ) + (chans[1] - chans[0] + 1, pols[1] - pols[0] + 1)
-                    delayed_cs = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-                    delayed_array = dask.delayed(read_col_chunk)(
-                        *delayed_cs, chans, pols
-                    )
-                    pol_list += [
-                        dask.array.from_delayed(delayed_array, cshape, cdata.dtype)
-                    ]
-                chan_list += [dask.array.concatenate(pol_list, axis=3)]
-            bvars[col] += [dask.array.concatenate(chan_list, axis=2)]
-
-
-def concat_bvars_update_tvars(
-    bvars: Dict[str, xr.DataArray], tvars: Dict[str, xr.DataArray]
-) -> None:
-    """
-    concats all the dask chunks from each baseline. This is intended to
-    be called iteratively, for every time chunk iteration, once all the
-    baseline chunks have been read.
-
-    Parameters
-    ----------
-    bvars: Dict[str, xr.DataArray]
-
-    tvars: Dict[str, xr.DataArray]
-
-
-    Returns
-    -------
-
-    """
-    for kk in bvars.keys():
-        if len(bvars[kk]) == 0:
-            continue
-        if kk not in tvars:
-            tvars[kk] = []
-        tvars[kk] += [dask.array.concatenate(bvars[kk], axis=1)]
-
-
-def concat_tvars_to_mvars(
-    dims: List[str], tvars: Dict[str, xr.DataArray], pol_cnt: int, chan_cnt: int
-) -> Dict[str, xr.DataArray]:
-    """
-    Concat into a single dask array all the dask arrays from each time
-    chunk to make the final arrays of the xds.
-
-    Parameters
-    ----------
-    dims : List[str]
-        dimension names
-    tvars : Dict[str, xr.DataArray]
-        variables as lists of dask arrays per time chunk
-    pol_cnt : int
-        len of pol axis/dim
-    chan_cnt : int
-        len of freq axis/dim (chan indices)
-
-    Returns
-    -------
-    Dict[str, xr.DataArray]
-        variables as concated dask arrays
-    """
-
-    mvars = {}
-    for tvr in tvars.keys():
-        data_var = tvr
-        if tvr == "UVW":
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0),
-                dims=dims[:2] + ["uvw_coords"],
-            )
-        elif len(tvars[tvr][0].shape) == 3 and (tvars[tvr][0].shape[-1] == pol_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0), dims=dims[:2] + ["pol"]
-            )
-        elif len(tvars[tvr][0].shape) == 3 and (tvars[tvr][0].shape[-1] == chan_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0), dims=dims[:2] + ["freq"]
-            )
-        else:
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0),
-                dims=dims[: len(tvars[tvr][0].shape)],
-            )
-
-    return mvars
-
-
-def read_flat_main_table(
-    infile: str,
-    ddi: Union[int, None] = None,
-    scan_state: Union[Tuple[int, int], None] = None,
-    rowidxs=None,
-    ignore_msv2_cols: Union[List, None] = None,
-    chunks: Tuple[int, ...] = (22000, 512, 2),
-) -> Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]:
-    """
-    Read main table using flat structure: no baseline dimension. Vis
-    dimensions are: row, freq, pol
-    (experimental, perhaps to be deprecated/removed). Works but some
-    features may be missing and/or flaky.
-
-    Chunk tuple: (row, freq, pol)
-
-    Parameters
-    ----------
-    infile : str
-
-    ddi : Union[int, None] (Default value = None)
-
-    scan_state : Union[Tuple[int, int], None] (Default value = None)
-
-    rowidxs : np.ndarray (Default value = None)
-
-    ignore_msv2_cols : Union[List, None] (Default value = None)
-
-    chunks : Tuple[int, ...] (Default value = (22000, 512, 2))
-
-    Returns
-    -------
-    Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]
-    """
-    taql_where = f"where DATA_DESC_ID = {ddi}"
-    if scan_state:
-        # TODO: support additional intent/scan/subscan conditions if
-        # we keep this read_flat functionality
-        scans, states = scan_state
-        # get row indices relative to full main table
-        if type(states) == np.ndarray:
-            state_ids_or = " OR STATE_ID = ".join(np.char.mod("%d", states))
-            taql_where += f" AND (STATE_ID = {state_ids_or})"
-        elif states is not None:
-            taql_where += f" AND STATE_ID = {states}"
-        elif scans is not None:
-            taql_where += f" AND SCAN_NUMBER = {scans}"
-
-    mtable = tables.table(
-        infile, readonly=True, lockoptions={"option": "usernoread"}, ack=False
-    )
-
-    # get row indices relative to full main table
-    if rowidxs is None:
-        taql_rowid = f"select rowid() as ROWS from $mtable {taql_where}"
-        with open_query(mtable, taql_rowid) as query_rows:
-            rowidxs = query_rows.getcol("ROWS")
-
-    nrows = len(rowidxs)
-    if nrows == 0:
-        mtable.close()
-        return xr.Dataset(), {}, {}
-
-    part_ids = get_partition_ids(mtable, taql_where)
-
-    taql_cols = f"select * from $mtable {taql_where}"
-    with open_query(mtable, taql_cols) as query_cols:
-        cols = query_cols.colnames()
-        ignore = [
-            col
-            for col in cols
-            if (not query_cols.iscelldefined(col, 0))
-            or (query_cols.coldatatype(col) == "record")
-        ]
-        cdata = dict(
-            [
-                (col, query_cols.getcol(col, 0, 1))
-                for col in cols
-                if (col not in ignore)
-                and not (ignore_msv2_cols and col in ignore_msv2_cols)
-            ]
-        )
-        chan_cnt, pol_cnt = [
-            (cdata[cc].shape[1], cdata[cc].shape[2])
-            for cc in cdata
-            if len(cdata[cc].shape) == 3
-        ][0]
-
-    mtable.close()
-
-    mvars, mcoords, bvars, xds = {}, {}, {}, xr.Dataset()
-    # loop over row chunks
-    for rc in range(0, nrows, chunks[0]):
-        crlen = min(chunks[0], nrows - rc)  # chunk row length
-        rcidxs = rowidxs[rc : rc + chunks[0]]
-
-        # loop over each column and create delayed dask arrays
-        for col in cdata.keys():
-            if col not in bvars:
-                bvars[col] = []
-
-            if len(cdata[col].shape) == 1:
-                delayed_array = dask.delayed(read_flat_col_chunk)(
-                    infile, col, (crlen,), rcidxs, None, None
-                )
-                bvars[col] += [
-                    dask.array.from_delayed(delayed_array, (crlen,), cdata[col].dtype)
-                ]
-
-            elif col == "UVW":
-                delayed_array = dask.delayed(read_flat_col_chunk)(
-                    infile, col, (crlen, 3), rcidxs, None, None
-                )
-                bvars[col] += [
-                    dask.array.from_delayed(delayed_array, (crlen, 3), cdata[col].dtype)
-                ]
-
-            elif len(cdata[col].shape) == 2:
-                pol_list = []
-                dd = 1 if cdata[col].shape[1] == chan_cnt else 2
-                for pc in range(0, cdata[col].shape[1], chunks[dd]):
-                    plen = min(chunks[dd], cdata[col].shape[1] - pc)
-                    delayed_array = dask.delayed(read_flat_col_chunk)(
-                        infile, col, (crlen, plen), rcidxs, None, pc
-                    )
-                    pol_list += [
-                        dask.array.from_delayed(
-                            delayed_array, (crlen, plen), cdata[col].dtype
-                        )
-                    ]
-                bvars[col] += [dask.array.concatenate(pol_list, axis=1)]
-
-            elif len(cdata[col].shape) == 3:
-                chan_list = []
-                for cc in range(0, chan_cnt, chunks[1]):
-                    clen = min(chunks[1], chan_cnt - cc)
-                    pol_list = []
-                    for pc in range(0, cdata[col].shape[2], chunks[2]):
-                        plen = min(chunks[2], cdata[col].shape[2] - pc)
-                        delayed_array = dask.delayed(read_flat_col_chunk)(
-                            infile, col, (crlen, clen, plen), rcidxs, cc, pc
-                        )
-                        pol_list += [
-                            dask.array.from_delayed(
-                                delayed_array, (crlen, clen, plen), cdata[col].dtype
-                            )
-                        ]
-                    chan_list += [dask.array.concatenate(pol_list, axis=2)]
-                bvars[col] += [dask.array.concatenate(chan_list, axis=1)]
-
-    # now concat all the dask chunks from each time to make the xds
-    mvars = {}
-    for kk in bvars.keys():
-        data_var = kk
-        if len(bvars[kk]) == 0:
-            ignore += [kk]
-            continue
-        if kk == "UVW":
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0), dims=["row", "uvw_coords"]
-            )
-        elif len(bvars[kk][0].shape) == 2 and (bvars[kk][0].shape[-1] == pol_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0), dims=["row", "pol"]
-            )
-        elif len(bvars[kk][0].shape) == 2 and (bvars[kk][0].shape[-1] == chan_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0), dims=["row", "chan"]
-            )
-        else:
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0),
-                dims=["row", "freq", "pol"][: len(bvars[kk][0].shape)],
-            )
-
-    mvars["time"] = xr.DataArray(
-        convert_casacore_time(mvars["TIME"].values), dims=["row"]
-    ).chunk({"row": chunks[0]})
-
-    # add xds global attributes
-    cc_attrs = extract_table_attributes(infile)
-    attrs = {"other": {"msv2": {"ctds_attrs": cc_attrs, "bad_cols": ignore}}}
-    # add per data var attributes
-    mvars = add_units_measures(mvars, cc_attrs)
-    mcoords = add_units_measures(mcoords, cc_attrs)
-
-    mvars = rename_vars(mvars)
-    mvars = redim_id_data_vars(mvars)
-    xds = xr.Dataset(mvars, coords=mcoords)
-
-    return xds, part_ids, attrs
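
All of the readers deleted in this hunk were built on one recurring idiom, still visible above: wrap a per-chunk read in dask.delayed, declare its shape and dtype via dask.array.from_delayed, then concatenate the lazy chunks along each axis. A self-contained sketch of that idiom, with read_chunk as an illustrative stand-in for the removed read_col_chunk / read_flat_col_chunk:

import dask
import dask.array as da
import numpy as np


def read_chunk(row_start: int, n_rows: int, n_chan: int) -> np.ndarray:
    # Stand-in: in the removed code this issued a TaQL query and
    # returned one (rows, channels) slice of an MS column.
    return np.zeros((n_rows, n_chan), dtype=np.complex64)


n_rows_total, n_chan, row_chunk = 1000, 64, 256
pieces = []
for start in range(0, n_rows_total, row_chunk):
    n = min(row_chunk, n_rows_total - start)
    delayed = dask.delayed(read_chunk)(start, n, n_chan)
    # shape/dtype must be promised up front; nothing is read yet
    pieces.append(da.from_delayed(delayed, shape=(n, n_chan), dtype=np.complex64))

column = da.concatenate(pieces, axis=0)  # lazy (1000, 64) dask array
print(column.shape, column.dtype)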