xradio 0.0.56__py3-none-any.whl → 0.0.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. xradio/__init__.py +2 -2
  2. xradio/_utils/_casacore/casacore_from_casatools.py +12 -2
  3. xradio/_utils/_casacore/tables.py +1 -0
  4. xradio/_utils/coord_math.py +22 -23
  5. xradio/_utils/dict_helpers.py +76 -11
  6. xradio/_utils/schema.py +5 -2
  7. xradio/_utils/zarr/common.py +1 -73
  8. xradio/image/_util/_casacore/xds_from_casacore.py +49 -33
  9. xradio/image/_util/_casacore/xds_to_casacore.py +41 -14
  10. xradio/image/_util/_fits/xds_from_fits.py +146 -35
  11. xradio/image/_util/casacore.py +4 -3
  12. xradio/image/_util/common.py +4 -4
  13. xradio/image/_util/image_factory.py +8 -8
  14. xradio/image/image.py +45 -5
  15. xradio/measurement_set/__init__.py +19 -9
  16. xradio/measurement_set/_utils/__init__.py +1 -3
  17. xradio/measurement_set/_utils/_msv2/__init__.py +0 -0
  18. xradio/measurement_set/_utils/_msv2/_tables/read.py +17 -76
  19. xradio/measurement_set/_utils/_msv2/_tables/read_main_table.py +2 -685
  20. xradio/measurement_set/_utils/_msv2/conversion.py +123 -145
  21. xradio/measurement_set/_utils/_msv2/create_antenna_xds.py +9 -16
  22. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +125 -221
  23. xradio/measurement_set/_utils/_msv2/msv2_to_msv4_meta.py +1 -2
  24. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +8 -7
  25. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +27 -72
  26. xradio/measurement_set/_utils/_msv2/partition_queries.py +1 -261
  27. xradio/measurement_set/_utils/_msv2/subtables.py +0 -107
  28. xradio/measurement_set/_utils/_utils/interpolate.py +60 -0
  29. xradio/measurement_set/_utils/_zarr/encoding.py +2 -7
  30. xradio/measurement_set/convert_msv2_to_processing_set.py +0 -2
  31. xradio/measurement_set/load_processing_set.py +2 -2
  32. xradio/measurement_set/measurement_set_xdt.py +14 -14
  33. xradio/measurement_set/open_processing_set.py +1 -3
  34. xradio/measurement_set/processing_set_xdt.py +41 -835
  35. xradio/measurement_set/schema.py +95 -122
  36. xradio/schema/check.py +91 -97
  37. xradio/schema/dataclass.py +159 -22
  38. xradio/schema/export.py +99 -0
  39. xradio/schema/metamodel.py +51 -16
  40. xradio/schema/typing.py +5 -5
  41. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/METADATA +2 -1
  42. xradio-0.0.58.dist-info/RECORD +65 -0
  43. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/WHEEL +1 -1
  44. xradio/image/_util/fits.py +0 -13
  45. xradio/measurement_set/_utils/_msv2/_tables/load.py +0 -66
  46. xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py +0 -490
  47. xradio/measurement_set/_utils/_msv2/_tables/read_subtables.py +0 -398
  48. xradio/measurement_set/_utils/_msv2/_tables/write.py +0 -323
  49. xradio/measurement_set/_utils/_msv2/_tables/write_exp_api.py +0 -388
  50. xradio/measurement_set/_utils/_msv2/chunks.py +0 -115
  51. xradio/measurement_set/_utils/_msv2/descr.py +0 -165
  52. xradio/measurement_set/_utils/_msv2/msv2_msv3.py +0 -7
  53. xradio/measurement_set/_utils/_msv2/partitions.py +0 -392
  54. xradio/measurement_set/_utils/_utils/cds.py +0 -40
  55. xradio/measurement_set/_utils/_utils/xds_helper.py +0 -404
  56. xradio/measurement_set/_utils/_zarr/read.py +0 -263
  57. xradio/measurement_set/_utils/_zarr/write.py +0 -329
  58. xradio/measurement_set/_utils/msv2.py +0 -106
  59. xradio/measurement_set/_utils/zarr.py +0 -133
  60. xradio-0.0.56.dist-info/RECORD +0 -78
  61. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/licenses/LICENSE.txt +0 -0
  62. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/top_level.txt +0 -0
xradio/measurement_set/_utils/_msv2/_tables/read_main_table.py
@@ -1,9 +1,6 @@
-import toolviper.utils.logger as logger
-from typing import Any, Dict, List, Tuple, Union
+from typing import Tuple
 
-import dask, dask.array
 import numpy as np
-import xarray as xr
 import pandas as pd
 
 try:
@@ -11,324 +8,14 @@ try:
 except ImportError:
     import xradio._utils._casacore.casacore_from_casatools as tables
 
-from .read import (
-    read_flat_col_chunk,
-    read_col_chunk,
-    convert_casacore_time,
-    extract_table_attributes,
-    add_units_measures,
-)
 
-from .table_query import open_table_ro, open_query
+from .table_query import open_query
 from xradio._utils.list_and_array import (
     unique_1d,
     pairing_function,
     inverse_pairing_function,
 )
 
-rename_msv2_cols = {
-    "ANTENNA1": "antenna1_id",
-    "ANTENNA2": "antenna2_id",
-    "FEED1": "feed1_id",
-    "FEED2": "feed2_id",
-    # optional cols:
-    "WEIGHT_SPECTRUM": "WEIGHT",
-    "CORRECTED_DATA": "VIS_CORRECTED",
-    "DATA": "VIS",
-    "MODEL_DATA": "VIS_MODEL",
-    "FLOAT_DATA": "AUTOCORR",
-}
-
-
-def rename_vars(mvars: Dict[str, xr.DataArray]) -> Dict[str, xr.DataArray]:
-    """
-    Apply rename rules. Also preserve ordering of data_vars
-
-    Note: not using xr.DataArray.rename because we have optional
-    column renames and rename complains if some of the names passed
-    are not present in the dataset
-
-    Parameters
-    ----------
-    mvars : Dict[str, xr.DataArray]
-        dictionary of data_vars to be used to create an xr.Dataset
-
-    Returns
-    -------
-    Dict[str, xr.DataArray]
-        similar dictionary after applying MSv2 => MSv3/ngCASA renaming rules
-    """
-    renamed = {
-        rename_msv2_cols[name] if name in rename_msv2_cols else name: var
-        for name, var in mvars.items()
-    }
-
-    return renamed
-
-
-def redim_id_data_vars(mvars: Dict[str, xr.DataArray]) -> Dict[str, xr.DataArray]:
-    """
-    Changes:
-    Several id data variables to drop its baseline dim
-    The antenna id data vars:
-    From MS (antenna1_id(time, baseline), antenna2_id(time,baseline)
-    To cds (baseline_ant1_id(baseline), baseline_ant2_id(baseline)
-
-    Parameters
-    ----------
-    mvars : Dict[str, xr.DataArray]
-        data variables being prepared for a partition xds
-
-    Returns
-    -------
-    Dict[str, xr.DataArray]
-        data variables with the ant id ones modified to cds type
-    """
-    # Vars to drop baseline dim
-    var_names = [
-        "ARRAY_ID",
-        "OBSERVATION_ID",
-        "PROCESSOR_ID",
-        "SCAN_NUMBER",
-        "STATE_ID",
-    ]
-    for vname in var_names:
-        if "baseline" in mvars[vname].coords:
-            mvars[vname] = mvars[vname].sel(baseline=0, drop=True)
-
-    for idx in ["1", "2"]:
-        new_name = f"baseline_ant{idx}_id"
-        mvars[new_name] = mvars.pop(f"antenna{idx}_id")
-        if "time" in mvars[new_name].coords:
-            mvars[new_name] = mvars[new_name].sel(time=0, drop=True)
-
-    return mvars
-
-
-def get_partition_ids(mtable: tables.table, taql_where: str) -> Dict:
-    """
-    Get some of the partition IDs that we have to retrieve from some
-    of the top level ID/sorting cols of the main table of the MS.
-
-    Parameters
-    ----------
-    mtable : tables.table
-        MS main table
-    taql_where : str
-        where part that defines the partition in TaQL
-
-    Returns
-    -------
-    Dict
-        ids of array, observation, and processor
-    """
-
-    taql_ids = f"select DISTINCT ARRAY_ID, OBSERVATION_ID, PROCESSOR_ID from $mtable {taql_where}"
-    with open_query(mtable, taql_ids) as query:
-        # array_id, observation_id, processor_id
-        array_id = unique_1d(query.getcol("ARRAY_ID"))
-        obs_id = unique_1d(query.getcol("OBSERVATION_ID"))
-        proc_id = unique_1d(query.getcol("PROCESSOR_ID"))
-        check_vars = [
-            (array_id, "array_id"),
-            (obs_id, "observation_id"),
-            (proc_id, "processor_id"),
-        ]
-        for var, var_name in check_vars:
-            if len(var) != 1:
-                logger.warning(
-                    f"Did not get exactly one {var_name} (got {var} for this partition. TaQL: {taql_where}"
-                )
-
-    pids = {
-        "array_id": list(array_id),
-        "observation_id": list(obs_id),
-        "processor_id": list(proc_id),
-    }
-    return pids
-
-
-def read_expanded_main_table(
-    infile: str,
-    ddi: int = 0,
-    scan_state: Union[Tuple[int, int], None] = None,
-    ignore_msv2_cols: Union[list, None] = None,
-    chunks: Tuple[int, ...] = (400, 200, 100, 2),
-) -> Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]:
-    """
-    Reads one partition from the main table, all columns.
-    This is the expanded version (time, baseline) dims.
-
-    Chunk tuple: (time, baseline, freq, pol)
-
-    Parameters
-    ----------
-    infile : str
-
-    ddi : int (Default value = 0)
-
-    scan_state : Union[Tuple[int, int], None] (Default value = None)
-
-    ignore_msv2_cols: Union[list, None] (Default value = None)
-
-    chunks: Tuple[int, ...] (Default value = (400, 200, 100, 2))
-
-
-    Returns
-    -------
-    Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]
-    """
-    if ignore_msv2_cols is None:
-        ignore_msv2_cols = []
-
-    taql_where = f"where DATA_DESC_ID = {ddi}"
-    if scan_state:
-        # get partitions by scan/state
-        scan, state = scan_state
-        if type(state) == np.ndarray:
-            state_ids_or = " OR STATE_ID = ".join(np.char.mod("%d", state))
-            taql_where += f" AND (STATE_ID = {state_ids_or})"
-        elif state:
-            taql_where += f" AND SCAN_NUMBER = {scan} AND STATE_ID = {state}"
-        elif scan:
-            # scan can also be None, when partition_scheme='intent'
-            # but the STATE table is empty!
-            taql_where += f" AND SCAN_NUMBER = {scan}"
-
-    with open_table_ro(infile) as mtable:
-        # one partition, select just the specified ddi (+ scan/subscan)
-        taql_main = f"select * from $mtable {taql_where}"
-        with open_query(mtable, taql_main) as tb_tool:
-            if tb_tool.nrows() == 0:
-                tb_tool.close()
-                mtable.close()
-                return xr.Dataset(), {}, {}
-
-            xds, attrs = read_main_table_chunks(
-                infile, tb_tool, taql_where, ignore_msv2_cols, chunks
-            )
-            part_ids = get_partition_ids(tb_tool, taql_where)
-
-    return xds, part_ids, attrs
-
-
-def read_main_table_chunks(
-    infile: str,
-    tb_tool: tables.table,
-    taql_where: str,
-    ignore_msv2_cols: Union[list, None] = None,
-    chunks: Tuple[int, ...] = (400, 200, 100, 2),
-) -> Tuple[xr.Dataset, Dict[str, Any]]:
-    """
-    Iterates through the time,baseline chunks and reads slices from
-    all the data columns.
-
-    Parameters
-    ----------
-    infile : str
-
-    tb_tool : tables.table
-
-    taql_where : str
-
-    ignore_msv2_cols : Union[list, None] (Default value = None)
-
-    chunks: Tuple[int, ...] (Default value = (400, 200, 100, 2))
-
-
-    Returns
-    -------
-    Tuple[xr.Dataset, Dict[str, Any]]
-    """
-    baselines = get_baselines(tb_tool)
-
-    col_names = tb_tool.colnames()
-    cshapes = [
-        np.array(tb_tool.getcell(col, 0)).shape
-        for col in col_names
-        if tb_tool.iscelldefined(col, 0)
-    ]
-    chan_cnt, pol_cnt = [(cc[0], cc[1]) for cc in cshapes if len(cc) == 2][0]
-
-    unique_times, tol = get_utimes_tol(tb_tool, taql_where)
-
-    tvars = {}
-    n_baselines = len(baselines)
-    n_unique_times = len(unique_times)
-    n_time_chunks = chunks[0]
-    n_baseline_chunks = chunks[1]
-    # loop over time chunks
-    for time_chunk in range(0, n_unique_times, n_time_chunks):
-        time_start = unique_times[time_chunk] - tol
-        time_end = (
-            unique_times[min(n_unique_times, time_chunk + n_time_chunks) - 1] + tol
-        )
-
-        # chunk time length
-        ctlen = min(n_unique_times, time_chunk + n_time_chunks) - time_chunk
-
-        bvars = {}
-        # loop over baseline chunks
-        for baseline_chunk in range(0, n_baselines, n_baseline_chunks):
-            cblen = min(n_baselines - baseline_chunk, n_baseline_chunks)
-
-            # read the specified chunk of data
-            # def read_chunk(infile, ddi, times, blines, chans, pols):
-            ttql = f"TIME BETWEEN {time_start} and {time_end}"
-            ant1_start = baselines[baseline_chunk][0]
-            ant1_end = baselines[cblen + baseline_chunk - 1][0]
-            atql = f"ANTENNA1 BETWEEN {ant1_start} and {ant1_end}"
-            ts_taql = f"select * from $mtable {taql_where} AND {ttql} AND {atql}"
-            with open_query(None, ts_taql) as query_times_ants:
-                tidxs = (
-                    np.searchsorted(
-                        unique_times, query_times_ants.getcol("TIME", 0, -1)
-                    )
-                    - time_chunk
-                )
-                ts_ant1, ts_ant2 = (
-                    query_times_ants.getcol("ANTENNA1", 0, -1),
-                    query_times_ants.getcol("ANTENNA2", 0, -1),
-                )
-
-            ts_bases = np.column_stack((ts_ant1, ts_ant2))
-
-            bidxs = get_baseline_indices(baselines, ts_bases) - baseline_chunk
-
-            # some antenna 2's will be out of bounds for this chunk, store rows that are in bounds
-            didxs = np.where(
-                (bidxs >= 0) & (bidxs < min(chunks[1], n_baselines - baseline_chunk))
-            )[0]
-
-            delayed_params = (infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs)
-
-            read_all_cols_bvars(
-                tb_tool, chunks, chan_cnt, ignore_msv2_cols, delayed_params, bvars
-            )
-
-        concat_bvars_update_tvars(bvars, tvars)
-
-    dims = ["time", "baseline", "freq", "pol"]
-    mvars = concat_tvars_to_mvars(dims, tvars, pol_cnt, chan_cnt)
-
-    mcoords = {
-        "time": xr.DataArray(convert_casacore_time(unique_times), dims=["time"]),
-        "baseline": xr.DataArray(np.arange(n_baselines), dims=["baseline"]),
-    }
-
-    # add xds global attributes
-    cc_attrs = extract_table_attributes(infile)
-    attrs = {"other": {"msv2": {"ctds_attrs": cc_attrs, "bad_cols": ignore_msv2_cols}}}
-    # add per data var attributes
-    mvars = add_units_measures(mvars, cc_attrs)
-    mcoords = add_units_measures(mcoords, cc_attrs)
-
-    mvars = rename_vars(mvars)
-    mvars = redim_id_data_vars(mvars)
-    xds = xr.Dataset(mvars, coords=mcoords)
-
-    return xds, attrs
-
 
 def get_utimes_tol(mtable: tables.table, taql_where: str) -> Tuple[np.ndarray, float]:
     taql_utimes = f"select DISTINCT TIME from $mtable {taql_where}"
@@ -413,373 +100,3 @@ def get_baseline_indices(
     baseline_indices = unique_baselines_sorted[sorted_indices]
 
     return baseline_indices
-
-
-def read_all_cols_bvars(
-    tb_tool: tables.table,
-    chunks: Tuple[int, ...],
-    chan_cnt: int,
-    ignore_msv2_cols: bool,
-    delayed_params: Tuple,
-    bvars: Dict[str, xr.DataArray],
-) -> None:
-    """
-    Loops over each column and create delayed dask arrays
-
-    Parameters
-    ----------
-    tb_tool : tables.table
-
-    chunks : Tuple[int, ...]
-
-    chan_cnt : int
-
-    ignore_msv2_cols : bool
-
-    delayed_params : Tuple
-
-    bvars : Dict[str, xr.DataArray]
-
-
-    Returns
-    -------
-
-    """
-
-    col_names = tb_tool.colnames()
-    for col in col_names:
-        if (col in ignore_msv2_cols + ["TIME"]) or (not tb_tool.iscelldefined(col, 0)):
-            continue
-        if col not in bvars:
-            bvars[col] = []
-
-        cdata = tb_tool.getcol(col, 0, 1)[0]
-
-        if len(cdata.shape) == 0:
-            infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs = delayed_params
-            cshape = (ctlen, cblen)
-            delayed_col = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-            delayed_array = dask.delayed(read_col_chunk)(*delayed_col, None, None)
-            bvars[col] += [dask.array.from_delayed(delayed_array, cshape, cdata.dtype)]
-
-        elif col == "UVW":
-            infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs = delayed_params
-            cshape = (ctlen, cblen, 3)
-            delayed_3 = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-            delayed_array = dask.delayed(read_col_chunk)(*delayed_3, None, None)
-            bvars[col] += [dask.array.from_delayed(delayed_array, cshape, cdata.dtype)]
-
-        elif len(cdata.shape) == 1:
-            pol_list = []
-            dd = 2 if cdata.shape == chan_cnt else 3
-            for pc in range(0, cdata.shape[0], chunks[dd]):
-                pols = (pc, min(cdata.shape[0], pc + chunks[dd]) - 1)
-                infile, ts_taql, (ctlen, cblen), tidxs, bidxs, didxs = delayed_params
-                cshape = (
-                    ctlen,
-                    cblen,
-                ) + (pols[1] - pols[0] + 1,)
-                delayed_cs = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-                delayed_array = dask.delayed(read_col_chunk)(*delayed_cs, pols, None)
-                pol_list += [
-                    dask.array.from_delayed(delayed_array, cshape, cdata.dtype)
-                ]
-            bvars[col] += [dask.array.concatenate(pol_list, axis=2)]
-
-        elif len(cdata.shape) == 2:
-            chan_list = []
-            for cc in range(0, cdata.shape[0], chunks[2]):
-                chans = (cc, min(cdata.shape[0], cc + chunks[2]) - 1)
-                pol_list = []
-                for pc in range(0, cdata.shape[1], chunks[3]):
-                    pols = (pc, min(cdata.shape[1], pc + chunks[3]) - 1)
-                    (
-                        infile,
-                        ts_taql,
-                        (ctlen, cblen),
-                        tidxs,
-                        bidxs,
-                        didxs,
-                    ) = delayed_params
-                    cshape = (
-                        ctlen,
-                        cblen,
-                    ) + (chans[1] - chans[0] + 1, pols[1] - pols[0] + 1)
-                    delayed_cs = infile, ts_taql, col, cshape, tidxs, bidxs, didxs
-                    delayed_array = dask.delayed(read_col_chunk)(
-                        *delayed_cs, chans, pols
-                    )
-                    pol_list += [
-                        dask.array.from_delayed(delayed_array, cshape, cdata.dtype)
-                    ]
-                chan_list += [dask.array.concatenate(pol_list, axis=3)]
-            bvars[col] += [dask.array.concatenate(chan_list, axis=2)]
-
-
-def concat_bvars_update_tvars(
-    bvars: Dict[str, xr.DataArray], tvars: Dict[str, xr.DataArray]
-) -> None:
-    """
-    concats all the dask chunks from each baseline. This is intended to
-    be called iteratively, for every time chunk iteration, once all the
-    baseline chunks have been read.
-
-    Parameters
-    ----------
-    bvars: Dict[str, xr.DataArray]
-
-    tvars: Dict[str, xr.DataArray]
-
-
-    Returns
-    -------
-
-    """
-    for kk in bvars.keys():
-        if len(bvars[kk]) == 0:
-            continue
-        if kk not in tvars:
-            tvars[kk] = []
-        tvars[kk] += [dask.array.concatenate(bvars[kk], axis=1)]
-
-
-def concat_tvars_to_mvars(
-    dims: List[str], tvars: Dict[str, xr.DataArray], pol_cnt: int, chan_cnt: int
-) -> Dict[str, xr.DataArray]:
-    """
-    Concat into a single dask array all the dask arrays from each time
-    chunk to make the final arrays of the xds.
-
-    Parameters
-    ----------
-    dims : List[str]
-        dimension names
-    tvars : Dict[str, xr.DataArray]
-        variables as lists of dask arrays per time chunk
-    pol_cnt : int
-        len of pol axis/dim
-    chan_cnt : int
-        len of freq axis/dim (chan indices)
-
-    Returns
-    -------
-    Dict[str, xr.DataArray]
-        variables as concated dask arrays
-    """
-
-    mvars = {}
-    for tvr in tvars.keys():
-        data_var = tvr
-        if tvr == "UVW":
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0),
-                dims=dims[:2] + ["uvw_coords"],
-            )
-        elif len(tvars[tvr][0].shape) == 3 and (tvars[tvr][0].shape[-1] == pol_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0), dims=dims[:2] + ["pol"]
-            )
-        elif len(tvars[tvr][0].shape) == 3 and (tvars[tvr][0].shape[-1] == chan_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0), dims=dims[:2] + ["freq"]
-            )
-        else:
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(tvars[tvr], axis=0),
-                dims=dims[: len(tvars[tvr][0].shape)],
-            )
-
-    return mvars
-
-
-def read_flat_main_table(
-    infile: str,
-    ddi: Union[int, None] = None,
-    scan_state: Union[Tuple[int, int], None] = None,
-    rowidxs=None,
-    ignore_msv2_cols: Union[List, None] = None,
-    chunks: Tuple[int, ...] = (22000, 512, 2),
-) -> Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]:
-    """
-    Read main table using flat structure: no baseline dimension. Vis
-    dimensions are: row, freq, pol
-    (experimental, perhaps to be deprecated/removed). Works but some
-    features may be missing and/or flaky.
-
-    Chunk tuple: (row, freq, pol)
-
-    Parameters
-    ----------
-    infile : str
-
-    ddi : Union[int, None] (Default value = None)
-
-    scan_state : Union[Tuple[int, int], None] (Default value = None)
-
-    rowidxs : np.ndarray (Default value = None)
-
-    ignore_msv2_cols : Union[List, None] (Default value = None)
-
-    chunks : Tuple[int, ...] (Default value = (22000, 512, 2))
-
-    Returns
-    -------
-    Tuple[xr.Dataset, Dict[str, Any], Dict[str, Any]]
-    """
-    taql_where = f"where DATA_DESC_ID = {ddi}"
-    if scan_state:
-        # TODO: support additional intent/scan/subscan conditions if
-        # we keep this read_flat functionality
-        scans, states = scan_state
-        # get row indices relative to full main table
-        if type(states) == np.ndarray:
-            state_ids_or = " OR STATE_ID = ".join(np.char.mod("%d", states))
-            taql_where += f" AND (STATE_ID = {state_ids_or})"
-        elif states is not None:
-            taql_where += f" AND STATE_ID = {states}"
-        elif scans is not None:
-            taql_where += f" AND SCAN_NUMBER = {scans}"
-
-    mtable = tables.table(
-        infile, readonly=True, lockoptions={"option": "usernoread"}, ack=False
-    )
-
-    # get row indices relative to full main table
-    if rowidxs is None:
-        taql_rowid = f"select rowid() as ROWS from $mtable {taql_where}"
-        with open_query(mtable, taql_rowid) as query_rows:
-            rowidxs = query_rows.getcol("ROWS")
-
-    nrows = len(rowidxs)
-    if nrows == 0:
-        mtable.close()
-        return xr.Dataset(), {}, {}
-
-    part_ids = get_partition_ids(mtable, taql_where)
-
-    taql_cols = f"select * from $mtable {taql_where}"
-    with open_query(mtable, taql_cols) as query_cols:
-        cols = query_cols.colnames()
-        ignore = [
-            col
-            for col in cols
-            if (not query_cols.iscelldefined(col, 0))
-            or (query_cols.coldatatype(col) == "record")
-        ]
-        cdata = dict(
-            [
-                (col, query_cols.getcol(col, 0, 1))
-                for col in cols
-                if (col not in ignore)
-                and not (ignore_msv2_cols and col in ignore_msv2_cols)
-            ]
-        )
-        chan_cnt, pol_cnt = [
-            (cdata[cc].shape[1], cdata[cc].shape[2])
-            for cc in cdata
-            if len(cdata[cc].shape) == 3
-        ][0]
-
-    mtable.close()
-
-    mvars, mcoords, bvars, xds = {}, {}, {}, xr.Dataset()
-    # loop over row chunks
-    for rc in range(0, nrows, chunks[0]):
-        crlen = min(chunks[0], nrows - rc)  # chunk row length
-        rcidxs = rowidxs[rc : rc + chunks[0]]
-
-        # loop over each column and create delayed dask arrays
-        for col in cdata.keys():
-            if col not in bvars:
-                bvars[col] = []
-
-            if len(cdata[col].shape) == 1:
-                delayed_array = dask.delayed(read_flat_col_chunk)(
-                    infile, col, (crlen,), rcidxs, None, None
-                )
-                bvars[col] += [
-                    dask.array.from_delayed(delayed_array, (crlen,), cdata[col].dtype)
-                ]
-
-            elif col == "UVW":
-                delayed_array = dask.delayed(read_flat_col_chunk)(
-                    infile, col, (crlen, 3), rcidxs, None, None
-                )
-                bvars[col] += [
-                    dask.array.from_delayed(delayed_array, (crlen, 3), cdata[col].dtype)
-                ]
-
-            elif len(cdata[col].shape) == 2:
-                pol_list = []
-                dd = 1 if cdata[col].shape[1] == chan_cnt else 2
-                for pc in range(0, cdata[col].shape[1], chunks[dd]):
-                    plen = min(chunks[dd], cdata[col].shape[1] - pc)
-                    delayed_array = dask.delayed(read_flat_col_chunk)(
-                        infile, col, (crlen, plen), rcidxs, None, pc
-                    )
-                    pol_list += [
-                        dask.array.from_delayed(
-                            delayed_array, (crlen, plen), cdata[col].dtype
-                        )
-                    ]
-                bvars[col] += [dask.array.concatenate(pol_list, axis=1)]
-
-            elif len(cdata[col].shape) == 3:
-                chan_list = []
-                for cc in range(0, chan_cnt, chunks[1]):
-                    clen = min(chunks[1], chan_cnt - cc)
-                    pol_list = []
-                    for pc in range(0, cdata[col].shape[2], chunks[2]):
-                        plen = min(chunks[2], cdata[col].shape[2] - pc)
-                        delayed_array = dask.delayed(read_flat_col_chunk)(
-                            infile, col, (crlen, clen, plen), rcidxs, cc, pc
-                        )
-                        pol_list += [
-                            dask.array.from_delayed(
-                                delayed_array, (crlen, clen, plen), cdata[col].dtype
-                            )
-                        ]
-                    chan_list += [dask.array.concatenate(pol_list, axis=2)]
-                bvars[col] += [dask.array.concatenate(chan_list, axis=1)]
-
-    # now concat all the dask chunks from each time to make the xds
-    mvars = {}
-    for kk in bvars.keys():
-        data_var = kk
-        if len(bvars[kk]) == 0:
-            ignore += [kk]
-            continue
-        if kk == "UVW":
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0), dims=["row", "uvw_coords"]
-            )
-        elif len(bvars[kk][0].shape) == 2 and (bvars[kk][0].shape[-1] == pol_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0), dims=["row", "pol"]
-            )
-        elif len(bvars[kk][0].shape) == 2 and (bvars[kk][0].shape[-1] == chan_cnt):
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0), dims=["row", "chan"]
-            )
-        else:
-            mvars[data_var] = xr.DataArray(
-                dask.array.concatenate(bvars[kk], axis=0),
-                dims=["row", "freq", "pol"][: len(bvars[kk][0].shape)],
-            )
-
-    mvars["time"] = xr.DataArray(
-        convert_casacore_time(mvars["TIME"].values), dims=["row"]
-    ).chunk({"row": chunks[0]})
-
-    # add xds global attributes
-    cc_attrs = extract_table_attributes(infile)
-    attrs = {"other": {"msv2": {"ctds_attrs": cc_attrs, "bad_cols": ignore}}}
-    # add per data var attributes
-    mvars = add_units_measures(mvars, cc_attrs)
-    mcoords = add_units_measures(mcoords, cc_attrs)
-
-    mvars = rename_vars(mvars)
-    mvars = redim_id_data_vars(mvars)
-    xds = xr.Dataset(mvars, coords=mcoords)
-
-    return xds, part_ids, attrs