xradio 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. xradio/__init__.py +5 -4
  2. xradio/_utils/array.py +90 -0
  3. xradio/_utils/zarr/common.py +48 -3
  4. xradio/image/_util/_fits/xds_from_fits.py +10 -5
  5. xradio/image/_util/_zarr/zarr_low_level.py +27 -24
  6. xradio/image/_util/common.py +4 -1
  7. xradio/image/_util/zarr.py +4 -1
  8. xradio/schema/__init__.py +24 -6
  9. xradio/schema/bases.py +440 -2
  10. xradio/schema/check.py +96 -55
  11. xradio/schema/dataclass.py +123 -27
  12. xradio/schema/metamodel.py +21 -4
  13. xradio/schema/typing.py +33 -18
  14. xradio/vis/__init__.py +5 -2
  15. xradio/vis/_processing_set.py +30 -9
  16. xradio/vis/_vis_utils/_ms/_tables/create_field_and_source_xds.py +710 -0
  17. xradio/vis/_vis_utils/_ms/_tables/load.py +23 -10
  18. xradio/vis/_vis_utils/_ms/_tables/load_main_table.py +145 -64
  19. xradio/vis/_vis_utils/_ms/_tables/read.py +782 -156
  20. xradio/vis/_vis_utils/_ms/_tables/read_main_table.py +176 -45
  21. xradio/vis/_vis_utils/_ms/_tables/read_subtables.py +79 -28
  22. xradio/vis/_vis_utils/_ms/_tables/write.py +102 -45
  23. xradio/vis/_vis_utils/_ms/_tables/write_exp_api.py +127 -65
  24. xradio/vis/_vis_utils/_ms/chunks.py +58 -21
  25. xradio/vis/_vis_utils/_ms/conversion.py +536 -67
  26. xradio/vis/_vis_utils/_ms/descr.py +52 -20
  27. xradio/vis/_vis_utils/_ms/msv2_to_msv4_meta.py +70 -35
  28. xradio/vis/_vis_utils/_ms/msv4_infos.py +0 -59
  29. xradio/vis/_vis_utils/_ms/msv4_sub_xdss.py +76 -9
  30. xradio/vis/_vis_utils/_ms/optimised_functions.py +0 -46
  31. xradio/vis/_vis_utils/_ms/partition_queries.py +308 -119
  32. xradio/vis/_vis_utils/_ms/partitions.py +82 -25
  33. xradio/vis/_vis_utils/_ms/subtables.py +32 -14
  34. xradio/vis/_vis_utils/_utils/partition_attrs.py +30 -11
  35. xradio/vis/_vis_utils/_utils/xds_helper.py +136 -45
  36. xradio/vis/_vis_utils/_zarr/read.py +60 -22
  37. xradio/vis/_vis_utils/_zarr/write.py +83 -9
  38. xradio/vis/_vis_utils/ms.py +48 -29
  39. xradio/vis/_vis_utils/zarr.py +44 -20
  40. xradio/vis/convert_msv2_to_processing_set.py +106 -32
  41. xradio/vis/load_processing_set.py +38 -61
  42. xradio/vis/read_processing_set.py +62 -96
  43. xradio/vis/schema.py +687 -0
  44. xradio/vis/vis_io.py +75 -43
  45. {xradio-0.0.27.dist-info → xradio-0.0.29.dist-info}/LICENSE.txt +6 -1
  46. {xradio-0.0.27.dist-info → xradio-0.0.29.dist-info}/METADATA +10 -5
  47. xradio-0.0.29.dist-info/RECORD +73 -0
  48. {xradio-0.0.27.dist-info → xradio-0.0.29.dist-info}/WHEEL +1 -1
  49. xradio/vis/model.py +0 -497
  50. xradio-0.0.27.dist-info/RECORD +0 -71
  51. {xradio-0.0.27.dist-info → xradio-0.0.29.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,26 @@
1
1
  import numcodecs
2
+ import math
2
3
  import time
3
4
  from .._zarr.encoding import add_encoding
4
5
  from typing import Dict, Union
5
6
  import graphviper.utils.logger as logger
7
+ import os
6
8
 
7
9
  import numpy as np
8
10
  import xarray as xr
9
11
 
10
- from .msv4_infos import create_field_info
12
+ from casacore import tables
11
13
  from .msv4_sub_xdss import create_ant_xds, create_pointing_xds, create_weather_xds
14
+ from xradio.vis._vis_utils._ms._tables.create_field_and_source_xds import (
15
+ create_field_and_source_xds,
16
+ )
12
17
  from .msv2_to_msv4_meta import (
13
18
  column_description_casacore_to_msv4_measure,
14
19
  create_attribute_metadata,
15
20
  col_to_data_variable_names,
16
21
  col_dims,
17
22
  )
18
- from .partition_queries import create_taql_query_and_file_name
23
+
19
24
  from .subtables import subt_rename_ids
20
25
  from ._tables.table_query import open_table_ro, open_query
21
26
  from ._tables.read import (
@@ -26,28 +31,333 @@ from ._tables.read import (
26
31
  )
27
32
  from ._tables.read_main_table import get_baselines, get_baseline_indices, get_utimes_tol
28
33
  from .._utils.stokes_types import stokes_types
29
- from xradio.vis._vis_utils._ms.optimised_functions import unique_1d
34
+ from xradio._utils.array import check_if_consistent, unique_1d
30
35
 
31
36
 
32
- def check_if_consistent(col, col_name):
33
- """_summary_
37
def parse_chunksize(
    chunksize: Union[Dict, float, None], xds_type: str, xds: xr.Dataset
) -> Union[Dict[str, int], None]:
    """
    Normalize a user-given chunksize into per-dimension chunk sizes.

    Parameters
    ----------
    chunksize : Union[Dict, float, None]
        Desired maximum size of the chunks, either as a dict of per-dimension sizes or as
        an amount of memory (in GiB). An int amount of memory is accepted too.
    xds_type : str
        whether to use chunking logic for main or pointing datasets
    xds : xr.Dataset
        dataset to calculate best chunking

    Returns
    -------
    Union[Dict[str, int], None]
        dictionary of chunk sizes (as dim->size), or None if chunksize is None

    Raises
    ------
    ValueError
        If chunksize is not a dict, a number or None, or if a dict chunksize has
        dimension names not allowed for xds_type.
    """
    if isinstance(chunksize, dict):
        check_chunksize(chunksize, xds_type)
    # bool is a subclass of int, but True/False are never meaningful memory sizes
    elif isinstance(chunksize, (int, float)) and not isinstance(chunksize, bool):
        chunksize = mem_chunksize_to_dict(chunksize, xds_type, xds)
    elif chunksize is not None:
        raise ValueError(
            f"Chunk size expected as a dict or a number, got: "
            f"{chunksize} (of type {type(chunksize)})"
        )

    return chunksize
67
+
68
+
69
def check_chunksize(chunksize: dict, xds_type: str) -> None:
    """
    Rudimentary check of the chunksize parameters to catch obvious errors early before
    more work is done.

    Parameters
    ----------
    chunksize : dict
        Chunk sizes as dim_name->size
    xds_type : str
        "main" or "pointing"; selects the set of dimension names that are allowed

    Raises
    ------
    ValueError
        If xds_type is not a known dataset type, or if chunksize has dimension
        names that are not allowed for that dataset type.
    """
    # perhaps start using some TypedDict and/or a validator like pydantic?
    allowed_dims_by_type = {
        "main": [
            "time",
            "baseline_id",
            "antenna_id",
            "frequency",
            "polarization",
        ],
        "pointing": ["time", "antenna"],
    }
    if xds_type not in allowed_dims_by_type:
        # Previously this fell through with 'allowed_dims' unbound (UnboundLocalError)
        raise ValueError(f"Unexpected dataset type for chunksize check: {xds_type=}")
    allowed_dims = allowed_dims_by_type[xds_type]

    msg = ""
    for dim in chunksize:
        if dim not in allowed_dims:
            msg += f"dimension {dim} not allowed in {xds_type} dataset:\n"
    if msg:
        raise ValueError(f"Wrong keys found in chunksize: {msg}")
92
+
93
+
94
def mem_chunksize_to_dict(
    chunksize: float, xds_type: str, xds: xr.Dataset
) -> Dict[str, int]:
    """
    Turn a memory budget into chunk sizes for every dimension of an xds.

    'chunksize' is the desired maximum amount of memory per chunk, in GiB. The
    actual calculation is delegated to the main- or pointing-specific routine,
    depending on 'xds_type'.

    Parameters
    ----------
    chunksize : float
        Desired maximum size of the chunks
    xds_type : str
        "main" or "pointing", which chunking logic to use
    xds : xr.Dataset
        dataset whose dimensions get auto-calculated chunk sizes

    Returns
    -------
    Dict[str, int]
        dictionary of chunk sizes (as dim->size)

    Raises
    ------
    RuntimeError
        If xds_type is neither "main" nor "pointing".
    """
    if xds_type not in ("pointing", "main"):
        raise RuntimeError(f"Unexpected type: {xds_type=}")

    if xds_type == "pointing":
        calculate = mem_chunksize_to_dict_pointing
    else:
        calculate = mem_chunksize_to_dict_main
    return calculate(chunksize, xds)
124
+
125
+
126
# Number of bytes in one GiB (1024**3), used to express memory sizes in GiB
GiBYTES_TO_BYTES = 2**30
127
+
128
+
129
def mem_chunksize_to_dict_main(chunksize: float, xds: xr.Dataset) -> Dict[str, int]:
    """
    Calculate main-dataset chunk sizes that fit within a memory budget (in GiB).

    Checks the assumption that all polarizations can be held in memory, at least for one
    data point (one time, one baseline/antenna, one frequency channel).

    If the whole dataset fits within 'chunksize', a single chunk spanning every
    dimension is used. Otherwise it relies on the logic of
    mem_chunksize_to_dict_main_balanced() to find a balanced list of dimension sizes
    for the chunks.

    Assumes these relevant dims: (time, antenna_id/baseline_id, frequency,
    polarization).

    Parameters
    ----------
    chunksize : float
        Desired maximum size of the chunks, in GiB
    xds : xr.Dataset
        main dataset to auto-calculate chunking of its dimensions

    Returns
    -------
    Dict[str, int]
        dictionary of chunk sizes (as dim->size)

    Raises
    ------
    RuntimeError
        If 'chunksize' cannot hold even all the polarizations of one data point.
    """

    sizeof_vis = itemsize_vis_spec(xds)
    size_all_pols = sizeof_vis * xds.sizes["polarization"]
    if size_all_pols / GiBYTES_TO_BYTES > chunksize:
        # f-string: the original message printed '{chunksize}' literally
        raise RuntimeError(
            f"Cannot calculate chunk sizes when memory bound ({chunksize}) does not even "
            "allow all polarizations in one chunk"
        )

    baseline_or_antenna_id = find_baseline_or_antenna_var(xds)
    total_size = calc_used_gb(xds.sizes, baseline_or_antenna_id, sizeof_vis)

    ratio = chunksize / total_size
    chunked_dims = ["time", baseline_or_antenna_id, "frequency", "polarization"]
    if ratio >= 1:
        result = {dim: xds.sizes[dim] for dim in chunked_dims}
        logger.debug(
            f"{chunksize=} GiB is enough to fully hold {total_size=} GiB (for {xds.sizes=}) in memory in one chunk"
        )
    else:
        xds_dim_sizes = {k: xds.sizes[k] for k in chunked_dims}
        result = mem_chunksize_to_dict_main_balanced(
            chunksize, xds_dim_sizes, baseline_or_antenna_id, sizeof_vis
        )

    return result
165
+
166
+
167
def mem_chunksize_to_dict_main_balanced(
    chunksize: float, xds_dim_sizes: dict, baseline_or_antenna_id: str, sizeof_vis: int
) -> Dict[str, int]:
    """
    Assumes the ratio is <1 and all pols can fit in memory (from
    mem_chunksize_to_dict_main()).

    What is kept balanced is the fraction of the total size of every dimension included in a
    chunk. For example, time: 10, baseline: 100, freq: 1000, if we can afford about 33% in
    one chunk, the chunksize will be ~ time: 3, baseline: 33, freq: 333.
    The polarization axis is excluded from the calculations.
    Because this can leave a leftover (below or above the desired chunksize limit) an
    adjustment is done to get the final memory use below but as close as possible to
    'chunksize'. This adjustment alters the balance. It never makes the chunk size of a
    dimension larger than the full size of that dimension.

    Parameters
    ----------
    chunksize : float
        Desired maximum size of the chunks, in GiB
    xds_dim_sizes : dict
        Dataset dimension sizes as dim_name->size. The values are used positionally and
        are assumed to be in the order (time, baseline/antenna, frequency, polarization).
    baseline_or_antenna_id : str
        name of the baseline/antenna dimension, used as key in the result
    sizeof_vis : int
        Size in bytes of a data point (one visibility / spectrum value)

    Returns
    -------
    Dict[str, int]
        dictionary of chunk sizes (as dim->size)
    """

    dim_sizes = list(xds_dim_sizes.values())
    # Fix fourth dim (polarization) to all (not free to auto-calculate)
    free_dims_mask = np.array([True, True, True, False])

    total_size = np.prod(dim_sizes) * sizeof_vis / GiBYTES_TO_BYTES
    ratio = chunksize / total_size

    # Upper bounds for the refinement: a chunk never needs to exceed its dimension
    max_chunksizes = np.array(dim_sizes, dtype="int64")
    dim_chunksizes = max_chunksizes.copy()
    factor = ratio ** (1 / np.sum(free_dims_mask))
    dim_chunksizes[free_dims_mask] = np.maximum(
        dim_chunksizes[free_dims_mask] * factor, 1
    )
    used = np.prod(dim_chunksizes) * sizeof_vis / GiBYTES_TO_BYTES

    logger.debug(
        f"Auto-calculating main chunk sizes. First order approximation {dim_chunksizes=}, used total: {used} GiB (with {chunksize=} GiB)"
    )

    # Iterate through the free dims, starting from the dims with lower chunk size
    # (=bigger impact of a +1)
    # Note the use of floor, this iteration can either increase or decrease sizes,
    # if increasing sizes we want to keep mem use below the upper limit, floor(2.3) = +2
    # if decreasing sizes we want to take mem use below the upper limit, floor(-2.3) = -3
    # argsort gives positions within the free subset; map them back to full-array indices
    free_dims_idx = np.flatnonzero(free_dims_mask)
    indices = free_dims_idx[np.argsort(dim_chunksizes[free_dims_mask])]
    for idx in indices:
        left = chunksize - used
        other_dims_mask = np.ones(free_dims_mask.shape, dtype=bool)
        other_dims_mask[idx] = False
        delta = np.divide(
            left,
            np.prod(dim_chunksizes[other_dims_mask]) * sizeof_vis / GiBYTES_TO_BYTES,
        )
        int_delta = int(np.floor(delta))
        # Clamp growth to the full dimension size (a larger chunk would just mean the
        # memory budget is computed for data that does not exist)
        new_size = min(int(dim_chunksizes[idx]) + int_delta, int(max_chunksizes[idx]))
        if new_size != dim_chunksizes[idx] and new_size > 0:
            dim_chunksizes[idx] = new_size
            used = np.prod(dim_chunksizes) * sizeof_vis / GiBYTES_TO_BYTES

    chunked_dim_names = ["time", baseline_or_antenna_id, "frequency", "polarization"]
    dim_chunksizes_int = [int(v) for v in dim_chunksizes]
    result = dict(zip(chunked_dim_names, dim_chunksizes_int))

    logger.debug(
        f"Auto-calculated main chunk sizes with {chunksize=}, {total_size=} GiB (for {dim_sizes=}): {result=} which uses {used} GiB."
    )

    return result
244
+
245
+
246
def mem_chunksize_to_dict_pointing(chunksize: float, xds: xr.Dataset) -> Dict[str, int]:
    """
    Equivalent to mem_chunksize_to_dict_main adapted to pointing xdss.
    Assumes these relevant dims: (time, antenna, direction).

    When the whole dataset does not fit in one chunk of 'chunksize' GiB, the dataset is
    assumed to have exactly three dimensions in that order: the free-dimension mask is
    positional, and the third dimension is always kept whole. Chunk sizes are never
    grown beyond the full size of their dimension.

    Parameters
    ----------
    chunksize : float
        Desired maximum size of the chunks, in GiB
    xds : xr.Dataset
        pointing dataset to auto-calculate chunking of its dimensions

    Returns
    -------
    Dict[str, int]
        dictionary of chunk sizes (as dim->size); empty if xds has no dimensions
    """

    if not xds.sizes:
        return {}

    sizeof_pointing = itemsize_pointing_spec(xds)
    chunked_dim_names = list(xds.sizes.keys())
    dim_sizes = list(xds.sizes.values())
    total_size = np.prod(dim_sizes) * sizeof_pointing / GiBYTES_TO_BYTES

    # Fix third dim (direction) to all
    free_dims_mask = np.array([True, True, False])

    ratio = chunksize / total_size
    if ratio >= 1:
        logger.debug(
            f"Pointing chunsize: {chunksize=} GiB is enough to fully hold {total_size=} GiB (for {xds.sizes=}) in memory in one chunk"
        )
        dim_chunksizes = dim_sizes
    else:
        # balanced
        max_chunksizes = np.array(dim_sizes, dtype="int")
        dim_chunksizes = max_chunksizes.copy()
        factor = ratio ** (1 / np.sum(free_dims_mask))
        dim_chunksizes[free_dims_mask] = np.maximum(
            dim_chunksizes[free_dims_mask] * factor, 1
        )
        used = np.prod(dim_chunksizes) * sizeof_pointing / GiBYTES_TO_BYTES

        logger.debug(
            f"Auto-calculating pointing chunk sizes. First order approximation: {dim_chunksizes=}, used total: {used=} GiB (with {chunksize=} GiB"
        )

        # refine dim_chunksizes, starting from the free dims with lower chunk size;
        # argsort positions are within the free subset, so map them to full indices
        free_dims_idx = np.flatnonzero(free_dims_mask)
        indices = free_dims_idx[np.argsort(dim_chunksizes[free_dims_mask])]
        for idx in indices:
            left = chunksize - used
            other_dims_mask = np.ones(free_dims_mask.shape, dtype=bool)
            other_dims_mask[idx] = False
            delta = np.divide(
                left,
                np.prod(dim_chunksizes[other_dims_mask])
                * sizeof_pointing
                / GiBYTES_TO_BYTES,
            )
            int_delta = int(np.floor(delta))
            # Clamp growth to the full dimension size
            new_size = min(
                int(dim_chunksizes[idx]) + int_delta, int(max_chunksizes[idx])
            )
            if new_size != dim_chunksizes[idx] and new_size > 0:
                dim_chunksizes[idx] = new_size
                used = np.prod(dim_chunksizes) * sizeof_pointing / GiBYTES_TO_BYTES

    dim_chunksizes_int = [int(v) for v in dim_chunksizes]
    result = dict(zip(chunked_dim_names, dim_chunksizes_int))

    if ratio < 1:
        logger.debug(
            f"Auto-calculated pointing chunk sizes with {chunksize=}, {total_size=} GiB (for {xds.sizes=}): {result=} which uses {used} GiB."
        )

    return result
309
+
310
+
311
def find_baseline_or_antenna_var(xds: xr.Dataset) -> str:
    """
    Return the name of the baseline/antenna coordinate of a main dataset.

    "baseline_id" takes precedence over "antenna_id" when both are present.

    Parameters
    ----------
    xds : xr.Dataset
        main dataset

    Returns
    -------
    str
        "baseline_id" or "antenna_id"

    Raises
    ------
    ValueError
        If xds has neither a "baseline_id" nor an "antenna_id" coordinate
        (previously this surfaced as an UnboundLocalError).
    """
    for name in ("baseline_id", "antenna_id"):
        if name in xds.coords:
            return name

    raise ValueError(
        "Expected a 'baseline_id' or an 'antenna_id' coordinate in the dataset, "
        f"found coordinates: {list(xds.coords)}"
    )
318
+
319
+
320
def itemsize_vis_spec(xds: xr.Dataset) -> int:
    """
    Size in bytes of one visibility (or spectrum) value.

    The dtype of the SPECTRUM data variable is used if present, otherwise that of
    VISIBILITY; if neither exists, 8 bytes is assumed.
    """
    data_var_name = next(
        (name for name in ("SPECTRUM", "VISIBILITY") if name in xds.data_vars),
        None,
    )
    if data_var_name is None:
        return 8
    return np.dtype(xds.data_vars[data_var_name].dtype).itemsize
333
+
47
334
 
48
- col_unique = unique_1d(col)
49
- assert len(col_unique) == 1, col_name + " is not consistent."
50
- return col_unique[0]
335
def itemsize_pointing_spec(xds: xr.Dataset) -> int:
    """
    Size in bytes of one pointing value.

    Uses the dtype of the BEAM_POINTING data variable when present; otherwise
    8 bytes is assumed.
    """
    if "BEAM_POINTING" not in xds.data_vars:
        return 8
    return np.dtype(xds.data_vars["BEAM_POINTING"].dtype).itemsize
348
+
349
+
350
def calc_used_gb(
    chunksizes: dict, baseline_or_antenna_id: str, sizeof_vis: int
) -> float:
    """
    Memory, in GiB, taken by a block of visibilities of the given sizes.

    Parameters
    ----------
    chunksizes : dict
        sizes as dim_name->size; must contain "time", "frequency", "polarization"
        and the baseline/antenna dimension
    baseline_or_antenna_id : str
        name of the baseline/antenna dimension in 'chunksizes'
    sizeof_vis : int
        size in bytes of one visibility value

    Returns
    -------
    float
        time * baseline/antenna * frequency * polarization * sizeof_vis, in GiB
    """
    dims = ("time", baseline_or_antenna_id, "frequency", "polarization")
    total_bytes = math.prod([chunksizes[dim] for dim in dims] + [sizeof_vis])
    return total_bytes / GiBYTES_TO_BYTES
51
361
 
52
362
 
53
363
  # TODO: if the didxs are not used in read_col_conversion, remove didxs from here (and convert_and_write_partition)
@@ -103,15 +413,15 @@ def create_coordinates(
103
413
 
104
414
  ddi_xds = read_generic_table(in_file, "DATA_DESCRIPTION").sel(row=ddi)
105
415
  pol_setup_id = ddi_xds.polarization_id.values
106
- spw_id = ddi_xds.spectral_window_id.values
416
+ spectral_window_id = int(ddi_xds.spectral_window_id.values)
107
417
 
108
- spw_xds = read_generic_table(
418
+ spectral_window_xds = read_generic_table(
109
419
  in_file,
110
420
  "SPECTRAL_WINDOW",
111
421
  rename_ids=subt_rename_ids["SPECTRAL_WINDOW"],
112
- ).sel(spectral_window_id=spw_id)
113
- coords["frequency"] = spw_xds["chan_freq"].data[
114
- ~(np.isnan(spw_xds["chan_freq"].data))
422
+ ).sel(spectral_window_id=spectral_window_id)
423
+ coords["frequency"] = spectral_window_xds["chan_freq"].data[
424
+ ~(np.isnan(spectral_window_xds["chan_freq"].data))
115
425
  ]
116
426
 
117
427
  pol_xds = read_generic_table(
@@ -127,25 +437,27 @@ def create_coordinates(
127
437
  xds = xds.assign_coords(coords)
128
438
 
129
439
  ###### Create Frequency Coordinate ######
130
- freq_column_description = spw_xds.attrs["other"]["msv2"]["ctds_attrs"][
440
+ freq_column_description = spectral_window_xds.attrs["other"]["msv2"]["ctds_attrs"][
131
441
  "column_descriptions"
132
442
  ]
133
443
 
134
444
  msv4_measure = column_description_casacore_to_msv4_measure(
135
- freq_column_description["CHAN_FREQ"], ref_code=spw_xds["meas_freq_ref"].data
445
+ freq_column_description["CHAN_FREQ"],
446
+ ref_code=spectral_window_xds["meas_freq_ref"].data,
136
447
  )
137
448
  xds.frequency.attrs.update(msv4_measure)
138
449
 
139
- xds.frequency.attrs["spectral_window_name"] = str(spw_xds.name.values)
450
+ xds.frequency.attrs["spectral_window_name"] = str(spectral_window_xds.name.values)
140
451
  msv4_measure = column_description_casacore_to_msv4_measure(
141
- freq_column_description["REF_FREQUENCY"], ref_code=spw_xds["meas_freq_ref"].data
452
+ freq_column_description["REF_FREQUENCY"],
453
+ ref_code=spectral_window_xds["meas_freq_ref"].data,
142
454
  )
143
455
  xds.frequency.attrs["reference_frequency"] = {
144
- "dims": "",
145
- "data": float(spw_xds.ref_frequency.values),
456
+ "dims": [],
457
+ "data": float(spectral_window_xds.ref_frequency.values),
146
458
  "attrs": msv4_measure,
147
459
  }
148
- xds.frequency.attrs["spw_id"] = spw_id
460
+ xds.frequency.attrs["spectral_window_id"] = spectral_window_id
149
461
 
150
462
  # xds.frequency.attrs["effective_channel_width"] = "EFFECTIVE_CHANNEL_WIDTH"
151
463
  # Add if doppler table is present
@@ -153,20 +465,23 @@ def create_coordinates(
153
465
  # xds.frequency.attrs["doppler_type"] =
154
466
 
155
467
  unique_chan_width = unique_1d(
156
- spw_xds.chan_width.data[np.logical_not(np.isnan(spw_xds.chan_width.data))]
468
+ spectral_window_xds.chan_width.data[
469
+ np.logical_not(np.isnan(spectral_window_xds.chan_width.data))
470
+ ]
157
471
  )
158
- # assert len(unique_chan_width) == 1, "Channel width varies for spw."
159
- # xds.frequency.attrs["channel_width"] = spw_xds.chan_width.data[
160
- # ~(np.isnan(spw_xds.chan_width.data))
472
+ # assert len(unique_chan_width) == 1, "Channel width varies for spectral_window."
473
+ # xds.frequency.attrs["channel_width"] = spectral_window_xds.chan_width.data[
474
+ # ~(np.isnan(spectral_window_xds.chan_width.data))
161
475
  # ] # unique_chan_width[0]
162
476
  msv4_measure = column_description_casacore_to_msv4_measure(
163
- freq_column_description["CHAN_WIDTH"], ref_code=spw_xds["meas_freq_ref"].data
477
+ freq_column_description["CHAN_WIDTH"],
478
+ ref_code=spectral_window_xds["meas_freq_ref"].data,
164
479
  )
165
480
  if not msv4_measure:
166
481
  msv4_measure["type"] = "quantity"
167
482
  msv4_measure["units"] = ["Hz"]
168
483
  xds.frequency.attrs["channel_width"] = {
169
- "dims": "",
484
+ "dims": [],
170
485
  "data": np.abs(unique_chan_width[0]),
171
486
  "attrs": msv4_measure,
172
487
  }
@@ -186,7 +501,7 @@ def create_coordinates(
186
501
  msv4_measure["type"] = "quantity"
187
502
  msv4_measure["units"] = ["s"]
188
503
  xds.time.attrs["integration_time"] = {
189
- "dims": "",
504
+ "dims": [],
190
505
  "data": interval,
191
506
  "attrs": msv4_measure,
192
507
  }
@@ -194,6 +509,34 @@ def create_coordinates(
194
509
  return xds
195
510
 
196
511
 
512
def find_min_max_times(tb_tool: tables.table, taql_where: str) -> tuple:
    """
    Find the min/max times in an MSv4, for constraining pointing.

    To avoid numerical comparison issues (leaving out some times at the edges),
    a tolerance is subtracted from the minimum and added to the maximum. The
    tolerance is the one returned by get_utimes_tol().

    Parameters
    ----------
    tb_tool : tables.table
        table (query) opened with an MSv4 query
    taql_where : str
        TaQL where that defines the partition of this MSv4

    Returns
    -------
    tuple
        (min, max) times, in the raw time values of the MSv2 table, widened by the
        tolerance
    """
    unique_times, tolerance = get_utimes_tol(tb_tool, taql_where)
    lower = unique_times.min() - tolerance
    upper = unique_times.max() + tolerance
    return (lower, upper)
538
+
539
+
197
540
  def create_data_variables(
198
541
  in_file, xds, tb_tool, time_baseline_shape, tidxs, bidxs, didxs
199
542
  ):
@@ -242,6 +585,7 @@ def create_data_variables(
242
585
  )
243
586
  except:
244
587
  # logger.debug("Could not load column",col)
588
+ # print("Could not load column", col)
245
589
  continue
246
590
 
247
591
  xds[col_to_data_variable_names[col]].attrs.update(
@@ -249,15 +593,38 @@ def create_data_variables(
249
593
  )
250
594
 
251
595
 
596
def create_taql_query(partition_info):
    """
    Build the TaQL WHERE clause that selects the main-table rows of one partition.

    Parameters
    ----------
    partition_info : dict
        Maps MSv2 main-table column names to the sequences of values to select. Only
        the keys DATA_DESC_ID, STATE_ID, FIELD_ID and SCAN_NUMBER are used (in that
        order); any other keys (e.g. INTENT) are ignored.

    Returns
    -------
    str
        "WHERE (COL IN [v1,v2,...]) AND (...)" with one condition per used column,
        or an empty string (no row restriction) if none of those columns is present.
    """
    # Each column listed once (STATE_ID was previously duplicated, repeating its clause)
    main_par_table_cols = [
        "DATA_DESC_ID",
        "STATE_ID",
        "FIELD_ID",
        "SCAN_NUMBER",
    ]

    conditions = [
        f"({col_name} IN [{','.join(map(str, partition_info[col_name]))}])"
        for col_name in main_par_table_cols
        if col_name in partition_info
    ]
    if not conditions:
        # Slicing "WHERE " used to yield the invalid query fragment "WHE"
        return ""

    return "WHERE " + " AND ".join(conditions)
615
+
616
+
252
617
  def convert_and_write_partition(
253
618
  in_file: str,
254
619
  out_file: str,
255
- intent: str,
256
- ddi: int = 0,
257
- state_ids=None,
258
- field_id: int = None,
259
- ignore_msv2_cols: Union[list, None] = None,
260
- main_chunksize: Union[Dict, None] = None,
620
+ ms_v4_id: int,
621
+ partition_info: Dict,
622
+ partition_scheme: str = "ddi_intent_field",
623
+ main_chunksize: Union[Dict, float, None] = None,
624
+ with_pointing: bool = True,
625
+ pointing_chunksize: Union[Dict, float, None] = None,
626
+ pointing_interpolate: bool = False,
627
+ ephemeris_interpolate: bool = False,
261
628
  compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2),
262
629
  storage_backend="zarr",
263
630
  overwrite: bool = False,
@@ -278,9 +645,15 @@ def convert_and_write_partition(
278
645
  _description_, by default None
279
646
  field_id : int, optional
280
647
  _description_, by default None
281
- ignore_msv2_cols : Union[list, None], optional
648
+ main_chunksize : Union[Dict, float, None], optional
282
649
  _description_, by default None
283
- main_chunksize : Union[Dict, None], optional
650
+ with_pointing: bool, optional
651
+ _description_, by default True
652
+ pointing_chunksize : Union[Dict, float, None], optional
653
+ _description_, by default None
654
+ pointing_interpolate : bool, optional
655
+ _description_, by default None
656
+ ephemeris_interpolate : bool, optional
284
657
  _description_, by default None
285
658
  compressor : numcodecs.abc.Codec, optional
286
659
  _description_, by default numcodecs.Zstd(level=2)
@@ -294,17 +667,16 @@ def convert_and_write_partition(
294
667
  _type_
295
668
  _description_
296
669
  """
297
- if ignore_msv2_cols is None:
298
- ignore_msv2_cols = []
299
670
 
300
- taql_where, file_name = create_taql_query_and_file_name(
301
- out_file, intent, state_ids, field_id, ddi
302
- )
671
+ taql_where = create_taql_query(partition_info)
672
+ ddi = partition_info["DATA_DESC_ID"][0]
673
+ intent = str(partition_info["INTENT"][0])
303
674
 
304
675
  start = time.time()
305
676
  with open_table_ro(in_file) as mtable:
306
677
  taql_main = f"select * from $mtable {taql_where}"
307
678
  with open_query(mtable, taql_main) as tb_tool:
679
+
308
680
  if tb_tool.nrows() == 0:
309
681
  tb_tool.close()
310
682
  mtable.close()
@@ -329,7 +701,7 @@ def convert_and_write_partition(
329
701
 
330
702
  interval_unique = unique_1d(interval)
331
703
  if len(interval_unique) > 1:
332
- print(
704
+ logger.debug(
333
705
  "Integration time (interval) not consitent in partition, using median."
334
706
  )
335
707
  interval = np.median(interval)
@@ -347,12 +719,6 @@ def convert_and_write_partition(
347
719
  )
348
720
  logger.debug("Time create data variables " + str(time.time() - start))
349
721
 
350
- # Create field_info
351
- start = time.time()
352
- field_id = check_if_consistent(tb_tool.getcol("FIELD_ID"), "FIELD_ID")
353
- field_info = create_field_info(in_file, field_id)
354
- logger.debug("Time field info " + str(time.time() - start))
355
-
356
722
  # Create ant_xds
357
723
  start = time.time()
358
724
  ant_xds = create_ant_xds(in_file)
@@ -363,14 +729,30 @@ def convert_and_write_partition(
363
729
  weather_xds = create_weather_xds(in_file)
364
730
  logger.debug("Time weather " + str(time.time() - start))
365
731
 
366
- start = time.time()
367
- pointing_xds = create_pointing_xds(in_file)
368
- logger.debug("Time pointing " + str(time.time() - start))
732
+ # To constrain the time range to load (in pointing, ephemerides data_vars)
733
+ time_min_max = find_min_max_times(tb_tool, taql_where)
734
+
735
+ if with_pointing:
736
+ start = time.time()
737
+ if pointing_interpolate:
738
+ pointing_interp_time = xds.time
739
+ else:
740
+ pointing_interp_time = None
741
+ pointing_xds = create_pointing_xds(
742
+ in_file, time_min_max, pointing_interp_time
743
+ )
744
+ pointing_chunksize = parse_chunksize(
745
+ pointing_chunksize, "pointing", pointing_xds
746
+ )
747
+ add_encoding(
748
+ pointing_xds, compressor=compressor, chunks=pointing_chunksize
749
+ )
750
+ logger.debug(
751
+ "Time pointing (with add compressor and chunking) "
752
+ + str(time.time() - start)
753
+ )
369
754
 
370
755
  start = time.time()
371
- # Fix UVW frame
372
- # From CASA fixvis docs: clean and the im tool ignore the reference frame claimed by the UVW column (it is often mislabelled as ITRF when it is really FK5 (J2000)) and instead assume the (u, v, w)s are in the same frame as the phase tracking center. calcuvw does not yet force the UVW column and field centers to use the same reference frame! Blank = use the phase tracking frame of vis.
373
- xds.UVW.attrs["frame"] = field_info["phase_direction"]["attrs"]["frame"]
374
756
 
375
757
  xds.attrs["intent"] = intent
376
758
  xds.attrs["ddi"] = ddi
@@ -391,7 +773,6 @@ def convert_and_write_partition(
391
773
  "weight": "WEIGHT",
392
774
  "uvw": "UVW",
393
775
  }
394
- xds.VISIBILITY.attrs["field_info"] = field_info
395
776
 
396
777
  if "VISIBILITY_CORRECTED" in xds:
397
778
  xds.attrs["data_groups"]["corrected"] = {
@@ -400,8 +781,8 @@ def convert_and_write_partition(
400
781
  "weight": "WEIGHT",
401
782
  "uvw": "UVW",
402
783
  }
403
- xds.VISIBILITY_CORRECTED.attrs["field_info"] = field_info
404
784
 
785
+ is_single_dish = False
405
786
  if "SPECTRUM" in xds:
406
787
  xds.attrs["data_groups"]["base"] = {
407
788
  "spectrum": "SPECTRUM",
@@ -409,7 +790,7 @@ def convert_and_write_partition(
409
790
  "weight": "WEIGHT",
410
791
  "uvw": "UVW",
411
792
  }
412
- xds.SPECTRUM.attrs["field_info"] = field_info
793
+ is_single_dish = True
413
794
 
414
795
  if "SPECTRUM_CORRECTED" in xds:
415
796
  xds.attrs["data_groups"]["corrected"] = {
@@ -418,23 +799,111 @@ def convert_and_write_partition(
418
799
  "weight": "WEIGHT",
419
800
  "uvw": "UVW",
420
801
  }
421
- xds.SPECTRUM_CORRECTED.attrs["field_info"] = field_info
802
+ is_single_dish = True
803
+
804
+ # Create field_and_source_xds (combines field, source and ephemeris data into one super dataset)
805
+ start = time.time()
806
+ if ephemeris_interpolate:
807
+ ephemeris_interp_time = xds.time
808
+ else:
809
+ ephemeris_interp_time = None
810
+
811
+ scan_id = np.full(time_baseline_shape, -42, dtype=int)
812
+ scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER")
813
+ scan_id = np.max(scan_id, axis=1)
814
+
815
+ if (
816
+ partition_scheme == "ddi_intent_source"
817
+ or partition_scheme == "ddi_intent_scan"
818
+ ):
819
+ field_id = np.full(time_baseline_shape, -42, dtype=int)
820
+ field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID")
821
+ field_id = np.max(field_id, axis=1)
822
+ field_times = utime
823
+ else:
824
+ field_id = check_if_consistent(tb_tool.getcol("FIELD_ID"), "FIELD_ID")
825
+ field_times = None
826
+
827
+ # col_unique = unique_1d(col)
828
+ # assert len(col_unique) == 1, col_name + " is not consistent."
829
+ # return col_unique[0]
830
+
831
+ field_and_source_xds = create_field_and_source_xds(
832
+ in_file,
833
+ field_id,
834
+ xds.frequency.attrs["spectral_window_id"],
835
+ field_times,
836
+ is_single_dish,
837
+ time_min_max,
838
+ ephemeris_interp_time,
839
+ )
840
+ logger.debug("Time field_and_source_xds " + str(time.time() - start))
841
+
842
+ # Fix UVW frame
843
+ # From CASA fixvis docs: clean and the im tool ignore the reference frame claimed by the UVW column (it is often mislabelled as ITRF when it is really FK5 (J2000)) and instead assume the (u, v, w)s are in the same frame as the phase tracking center. calcuvw does not yet force the UVW column and field centers to use the same reference frame! Blank = use the phase tracking frame of vis.
844
+ # print('##################',field_and_source_xds)
845
+ if is_single_dish:
846
+ xds.UVW.attrs["frame"] = field_and_source_xds[
847
+ "FIELD_REFERENCE_CENTER"
848
+ ].attrs["frame"]
849
+ else:
850
+ xds.UVW.attrs["frame"] = field_and_source_xds[
851
+ "FIELD_PHASE_CENTER"
852
+ ].attrs["frame"]
422
853
 
423
854
  if overwrite:
424
855
  mode = "w"
425
856
  else:
426
857
  mode = "w-"
427
858
 
859
+ main_chunksize = parse_chunksize(main_chunksize, "main", xds)
428
860
  add_encoding(xds, compressor=compressor, chunks=main_chunksize)
429
861
  logger.debug("Time add compressor and chunk " + str(time.time() - start))
430
862
 
863
+ file_name = os.path.join(
864
+ out_file,
865
+ out_file.replace(".vis.zarr", "").replace(".zarr", "").split("/")[-1]
866
+ + "_"
867
+ + str(ms_v4_id),
868
+ )
869
+
870
+ if isinstance(field_id, np.ndarray):
871
+ field_id = "OTF"
872
+
873
+ xds.attrs["partition_info"] = {
874
+ "spectral_window_id": xds.frequency.attrs["spectral_window_id"],
875
+ "spectral_window_name": xds.frequency.attrs["spectral_window_name"],
876
+ "field_id": field_id,
877
+ "field_name": field_and_source_xds.attrs["field_name"],
878
+ "source_id": field_and_source_xds.attrs["source_id"],
879
+ "source_name": field_and_source_xds.attrs["source_name"],
880
+ "polarization_setup": list(xds.polarization.values),
881
+ "intent": intent,
882
+ "taql": taql_where,
883
+ }
884
+
885
+ # print(xds)
886
+
431
887
  start = time.time()
432
888
  if storage_backend == "zarr":
433
- xds.to_zarr(store=file_name + "/MAIN", mode=mode)
434
- ant_xds.to_zarr(store=file_name + "/ANTENNA", mode=mode)
435
- pointing_xds.to_zarr(store=file_name + "/POINTING", mode=mode)
889
+ xds.to_zarr(store=os.path.join(file_name, "MAIN"), mode=mode)
890
+ ant_xds.to_zarr(store=os.path.join(file_name, "ANTENNA"), mode=mode)
891
+ for group_name in xds.attrs["data_groups"]:
892
+ field_and_source_xds.to_zarr(
893
+ store=os.path.join(
894
+ file_name, f"FIELD_AND_SOURCE_{group_name.upper()}"
895
+ ),
896
+ mode=mode,
897
+ )
898
+
899
+ if with_pointing:
900
+ pointing_xds.to_zarr(store=file_name + "/POINTING", mode=mode)
901
+
436
902
  if weather_xds:
437
- weather_xds.to_zarr(store=file_name + "/WEATHER", mode=mode)
903
+ weather_xds.to_zarr(
904
+ store=os.path.join(file_name, "WEATHER"), mode=mode
905
+ )
906
+
438
907
  elif storage_backend == "netcdf":
439
908
  # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work
440
909
  raise