xradio 0.0.55__py3-none-any.whl → 0.0.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. xradio/__init__.py +2 -2
  2. xradio/_utils/_casacore/casacore_from_casatools.py +1001 -0
  3. xradio/_utils/_casacore/tables.py +6 -1
  4. xradio/_utils/coord_math.py +22 -23
  5. xradio/_utils/dict_helpers.py +76 -11
  6. xradio/_utils/schema.py +5 -2
  7. xradio/_utils/zarr/common.py +1 -73
  8. xradio/image/_util/_casacore/common.py +11 -3
  9. xradio/image/_util/_casacore/xds_from_casacore.py +59 -35
  10. xradio/image/_util/_casacore/xds_to_casacore.py +47 -16
  11. xradio/image/_util/_fits/xds_from_fits.py +172 -77
  12. xradio/image/_util/casacore.py +9 -4
  13. xradio/image/_util/common.py +4 -4
  14. xradio/image/_util/image_factory.py +8 -8
  15. xradio/image/image.py +45 -5
  16. xradio/measurement_set/__init__.py +19 -9
  17. xradio/measurement_set/_utils/__init__.py +1 -3
  18. xradio/measurement_set/_utils/_msv2/__init__.py +0 -0
  19. xradio/measurement_set/_utils/_msv2/_tables/read.py +35 -90
  20. xradio/measurement_set/_utils/_msv2/_tables/read_main_table.py +6 -686
  21. xradio/measurement_set/_utils/_msv2/_tables/table_query.py +13 -3
  22. xradio/measurement_set/_utils/_msv2/conversion.py +129 -145
  23. xradio/measurement_set/_utils/_msv2/create_antenna_xds.py +9 -16
  24. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +125 -221
  25. xradio/measurement_set/_utils/_msv2/msv2_to_msv4_meta.py +1 -2
  26. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +13 -8
  27. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +27 -72
  28. xradio/measurement_set/_utils/_msv2/partition_queries.py +5 -262
  29. xradio/measurement_set/_utils/_msv2/subtables.py +0 -107
  30. xradio/measurement_set/_utils/_utils/interpolate.py +60 -0
  31. xradio/measurement_set/_utils/_zarr/encoding.py +2 -7
  32. xradio/measurement_set/convert_msv2_to_processing_set.py +0 -2
  33. xradio/measurement_set/load_processing_set.py +2 -2
  34. xradio/measurement_set/measurement_set_xdt.py +14 -14
  35. xradio/measurement_set/open_processing_set.py +1 -3
  36. xradio/measurement_set/processing_set_xdt.py +41 -835
  37. xradio/measurement_set/schema.py +96 -123
  38. xradio/schema/check.py +91 -97
  39. xradio/schema/dataclass.py +159 -22
  40. xradio/schema/export.py +99 -0
  41. xradio/schema/metamodel.py +51 -16
  42. xradio/schema/typing.py +5 -5
  43. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/METADATA +43 -11
  44. xradio-0.0.58.dist-info/RECORD +65 -0
  45. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/WHEEL +1 -1
  46. xradio/image/_util/fits.py +0 -13
  47. xradio/measurement_set/_utils/_msv2/_tables/load.py +0 -63
  48. xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py +0 -487
  49. xradio/measurement_set/_utils/_msv2/_tables/read_subtables.py +0 -395
  50. xradio/measurement_set/_utils/_msv2/_tables/write.py +0 -320
  51. xradio/measurement_set/_utils/_msv2/_tables/write_exp_api.py +0 -385
  52. xradio/measurement_set/_utils/_msv2/chunks.py +0 -115
  53. xradio/measurement_set/_utils/_msv2/descr.py +0 -165
  54. xradio/measurement_set/_utils/_msv2/msv2_msv3.py +0 -7
  55. xradio/measurement_set/_utils/_msv2/partitions.py +0 -392
  56. xradio/measurement_set/_utils/_utils/cds.py +0 -40
  57. xradio/measurement_set/_utils/_utils/xds_helper.py +0 -404
  58. xradio/measurement_set/_utils/_zarr/read.py +0 -263
  59. xradio/measurement_set/_utils/_zarr/write.py +0 -329
  60. xradio/measurement_set/_utils/msv2.py +0 -106
  61. xradio/measurement_set/_utils/zarr.py +0 -133
  62. xradio-0.0.55.dist-info/RECORD +0 -77
  63. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/licenses/LICENSE.txt +0 -0
  64. {xradio-0.0.55.dist-info → xradio-0.0.58.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,11 @@ import dask.array as da
5
5
  import numpy as np
6
6
  import xarray as xr
7
7
  from astropy.coordinates import Angle
8
- from casacore import tables
8
+
9
+ try:
10
+ from casacore import tables
11
+ except ImportError:
12
+ import xradio._utils._casacore.casacore_from_casatools as tables
9
13
 
10
14
  from .common import _active_mask, _create_new_image, _object_name, _pointing_center
11
15
  from ..common import _aperture_or_sky, _compute_sky_reference_pixel, _doppler_types
@@ -28,7 +32,10 @@ def _compute_direction_dict(xds: xr.Dataset) -> dict:
28
32
  direction["system"] = "J2000"
29
33
  direction["projection"] = xds_dir["projection"]
30
34
  direction["projection_parameters"] = xds_dir["projection_parameters"]
31
- direction["units"] = xds_dir["reference"]["attrs"]["units"]
35
+ direction["units"] = [
36
+ xds_dir["reference"]["attrs"]["units"],
37
+ xds_dir["reference"]["attrs"]["units"],
38
+ ]
32
39
  direction["crval"] = np.array(xds_dir["reference"]["data"])
33
40
  direction["cdelt"] = np.array((xds.l[1] - xds.l[0], xds.m[1] - xds.m[0]))
34
41
  direction["crpix"] = _compute_sky_reference_pixel(xds)
@@ -39,7 +46,7 @@ def _compute_direction_dict(xds: xr.Dataset) -> dict:
39
46
  m = "lonpole" if s == "longpole" else s
40
47
  # lonpole, latpole are numerical values in degrees in casa images
41
48
  direction[s] = float(
42
- Angle(str(xds_dir[m]["data"]) + xds_dir[m]["attrs"]["units"][0]).deg
49
+ Angle(str(xds_dir[m]["data"]) + xds_dir[m]["attrs"]["units"]).deg
43
50
  )
44
51
  return direction
45
52
 
@@ -83,20 +90,22 @@ def _compute_spectral_dict(xds: xr.Dataset) -> dict:
83
90
  spec["restfreq"] = xds.frequency.attrs["rest_frequency"]["data"]
84
91
  # spec["restfreqs"] = copy.deepcopy(xds.frequency.attrs["restfreqs"]["value"])
85
92
  spec["restfreqs"] = np.array([spec["restfreq"]])
86
- spec["system"] = xds.frequency.attrs["reference_value"]["attrs"]["observer"].upper()
87
- u = xds.frequency.attrs["reference_value"]["attrs"]["units"]
88
- spec["unit"] = u if isinstance(u, str) else u[0]
93
+ spec["system"] = xds.frequency.attrs["reference_frequency"]["attrs"][
94
+ "observer"
95
+ ].upper()
96
+ u = xds.frequency.attrs["reference_frequency"]["attrs"]["units"]
97
+ spec["unit"] = u
89
98
  spec["velType"] = _doppler_types.index(xds.velocity.attrs["doppler_type"])
90
99
  u = xds.velocity.attrs["units"]
91
100
  spec["version"] = 2
92
101
  # vel unit is a list[str] in the xds but needs to be a str in the casa image
93
- spec["velUnit"] = xds.velocity.attrs["units"][0]
102
+ spec["velUnit"] = xds.velocity.attrs["units"]
94
103
  # wave unit is a list[str] in the xds but needs to be a str in the casa image
95
- spec["waveUnit"] = xds.frequency.attrs["wave_unit"][0]
104
+ spec["waveUnit"] = xds.frequency.attrs["wave_units"]
96
105
  wcs = {}
97
106
  wcs["ctype"] = "FREQ"
98
107
  wcs["pc"] = 1.0
99
- wcs["crval"] = float(xds.frequency.attrs["reference_value"]["data"])
108
+ wcs["crval"] = float(xds.frequency.attrs["reference_frequency"]["data"])
100
109
  wcs["cdelt"] = float(xds.frequency.values[1] - xds.frequency.values[0])
101
110
  wcs["crpix"] = float((wcs["crval"] - xds.frequency.values[0]) / wcs["cdelt"])
102
111
  spec["wcs"] = wcs
@@ -110,26 +119,46 @@ def _coord_dict_from_xds(xds: xr.Dataset) -> dict:
110
119
  tel = xds[sky_ap].attrs["telescope"]
111
120
  if "name" in tel:
112
121
  coord["telescope"] = xds[sky_ap].attrs["telescope"]["name"]
113
- if "location" in tel:
114
- xds_telloc = tel["location"]
122
+
123
+ if "direction" in tel:
124
+ xds_telloc = tel["direction"]
115
125
  telloc = {}
116
126
  telloc["refer"] = xds_telloc["attrs"]["frame"]
117
127
  if telloc["refer"] == "GRS80":
118
128
  telloc["refer"] = "ITRF"
119
- for i in range(3):
129
+ for i in range(2):
120
130
  telloc[f"m{i}"] = {
121
- "unit": xds_telloc["attrs"]["units"][i],
131
+ "unit": xds_telloc["attrs"]["units"],
122
132
  "value": xds_telloc["data"][i],
123
133
  }
134
+ telloc[f"m{2}"] = {
135
+ "unit": tel["distance"]["attrs"]["units"],
136
+ "value": tel["distance"]["data"][0],
137
+ }
138
+
124
139
  telloc["type"] = "position"
125
140
  coord["telescopeposition"] = telloc
141
+
142
+ # if "location" in tel:
143
+ # xds_telloc = tel["location"]
144
+ # telloc = {}
145
+ # telloc["refer"] = xds_telloc["attrs"]["frame"]
146
+ # if telloc["refer"] == "GRS80":
147
+ # telloc["refer"] = "ITRF"
148
+ # for i in range(3):
149
+ # telloc[f"m{i}"] = {
150
+ # "unit": xds_telloc["attrs"]["units"],
151
+ # "value": xds_telloc["data"][i],
152
+ # }
153
+ # telloc["type"] = "position"
154
+ # coord["telescopeposition"] = telloc
126
155
  if "observer" in xds[sky_ap].attrs:
127
156
  coord["observer"] = xds[sky_ap].attrs["observer"]
128
157
  obsdate = {}
129
158
  obsdate["refer"] = xds.coords["time"].attrs["scale"]
130
159
  obsdate["type"] = "epoch"
131
160
  obsdate["m0"] = {}
132
- obsdate["m0"]["unit"] = xds.coords["time"].attrs["units"][0]
161
+ obsdate["m0"]["unit"] = xds.coords["time"].attrs["units"]
133
162
  obsdate["m0"]["value"] = float(xds.coords["time"].values[0])
134
163
  coord["obsdate"] = obsdate
135
164
  if _pointing_center in xds[sky_ap].attrs:
@@ -197,7 +226,7 @@ def _history_from_xds(xds: xr.Dataset, image: str) -> None:
197
226
  def _imageinfo_dict_from_xds(xds: xr.Dataset) -> dict:
198
227
  ii = {}
199
228
  ap_sky = _aperture_or_sky(xds)
200
- ii["image_type"] = (
229
+ ii["imagetype"] = (
201
230
  xds[ap_sky].attrs["image_type"] if "image_type" in xds[ap_sky].attrs else ""
202
231
  )
203
232
  ii["objectname"] = (
@@ -208,6 +237,7 @@ def _imageinfo_dict_from_xds(xds: xr.Dataset) -> dict:
208
237
  pp = {}
209
238
  pp["nChannels"] = xds.sizes["frequency"]
210
239
  pp["nStokes"] = xds.sizes["polarization"]
240
+
211
241
  bu = xds.BEAM.attrs["units"]
212
242
  chan = 0
213
243
  polarization = 0
@@ -225,6 +255,7 @@ def _imageinfo_dict_from_xds(xds: xr.Dataset) -> dict:
225
255
  chan = 0
226
256
  polarization += 1
227
257
  ii["perplanebeams"] = pp
258
+
228
259
  """
229
260
  elif "beam" in xds.attrs and xds.attrs["beam"]:
230
261
  # do nothing if xds.attrs['beam'] is None
@@ -232,7 +263,7 @@ def _imageinfo_dict_from_xds(xds: xr.Dataset) -> dict:
232
263
  for k in ["major", "minor", "pa"]:
233
264
  # print("*** ", k, ii["restoringbeam"][k])
234
265
  del ii["restoringbeam"][k]["dims"]
235
- ii["restoringbeam"][k]["unit"] = ii["restoringbeam"][k]["attrs"]["units"][0]
266
+ ii["restoringbeam"][k]["units"] = ii["restoringbeam"][k]["attrs"]["units"]
236
267
  del ii["restoringbeam"][k]["attrs"]
237
268
  ii["restoringbeam"][k]["value"] = ii["restoringbeam"][k]["data"]
238
269
  del ii["restoringbeam"][k]["data"]
@@ -1,8 +1,26 @@
1
- import astropy as ap
1
+ import copy
2
+ import re
3
+ from typing import Union
4
+
5
+ import dask
6
+ import dask.array as da
7
+ import numpy as np
8
+ import psutil
9
+ import xarray as xr
2
10
  from astropy import units as u
3
11
  from astropy.io import fits
4
12
  from astropy.time import Time
5
- from ..common import (
13
+
14
+ from xradio._utils.coord_math import _deg_to_rad
15
+ from xradio._utils.dict_helpers import (
16
+ make_quantity,
17
+ make_spectral_coord_reference_dict,
18
+ make_skycoord_dict,
19
+ make_time_measure_dict,
20
+ )
21
+
22
+ from xradio.measurement_set._utils._utils.stokes_types import stokes_types
23
+ from xradio.image._util.common import (
6
24
  _compute_linear_world_values,
7
25
  _compute_velocity_values,
8
26
  _compute_world_sph_dims,
@@ -15,37 +33,34 @@ from ..common import (
15
33
  _image_type,
16
34
  _l_m_attr_notes,
17
35
  )
18
- from xradio._utils.coord_math import _deg_to_rad
19
- from xradio._utils.dict_helpers import (
20
- make_frequency_reference_dict,
21
- make_quantity,
22
- make_skycoord_dict,
23
- make_time_measure_dict,
24
- )
25
- import copy
26
- import dask
27
- import dask.array as da
28
- import numpy as np
29
- import re
30
- from typing import Union
31
- import xarray as xr
32
36
 
33
37
 
34
38
  def _fits_image_to_xds(
35
- img_full_path: str, chunks: dict, verbose: bool, do_sky_coords: bool
39
+ img_full_path: str,
40
+ chunks: dict,
41
+ verbose: bool,
42
+ do_sky_coords: bool,
43
+ compute_mask: bool,
36
44
  ) -> dict:
37
45
  """
46
+ compute_mask : bool, optional
47
+ If True (default), compute and attach valid data masks to the xds.
48
+ If False, skip mask generation for performance. It is solely the responsibility
49
+ of the user to ensure downstream apps can handle NaN values; do not
50
+ ask package developers to add this non-standard behavior.
51
+
38
52
  TODO: complete documentation
39
53
  Create an xds without any pixel data from metadata from the specified FITS image
40
54
  """
41
- # memmap = True allows only part of data to be loaded into memory
42
55
  # may also need to pass mode='denywrite'
43
56
  # https://stackoverflow.com/questions/35759713/astropy-io-fits-read-row-from-large-fits-file-with-mutliple-hdus
44
- hdulist = fits.open(img_full_path, memmap=True)
45
- attrs, helpers, header = _fits_header_to_xds_attrs(hdulist)
46
- hdulist.close()
47
- # avoid keeping reference to mem-mapped fits file
48
- del hdulist
57
+ try:
58
+ hdulist = fits.open(img_full_path, memmap=True)
59
+ attrs, helpers, header = _fits_header_to_xds_attrs(hdulist, compute_mask)
60
+ finally:
61
+ hdulist.close()
62
+ # avoid keeping reference to mem-mapped fits file
63
+ del hdulist
49
64
  xds = _create_coords(helpers, header, do_sky_coords)
50
65
  sphr_dims = helpers["sphr_dims"]
51
66
  ary = _read_image_array(img_full_path, chunks, helpers, verbose)
@@ -83,10 +98,10 @@ def _add_freq_attrs(xds: xr.Dataset, helpers: dict) -> xr.Dataset:
83
98
  meta["rest_frequency"] = make_quantity(helpers["restfreq"], "Hz")
84
99
  meta["rest_frequencies"] = [meta["rest_frequency"]]
85
100
  meta["type"] = "frequency"
86
- meta["wave_unit"] = ["mm"]
101
+ meta["wave_units"] = "mm"
87
102
  freq_axis = helpers["freq_axis"]
88
- meta["reference_value"] = make_frequency_reference_dict(
89
- helpers["crval"][freq_axis], ["Hz"], helpers["specsys"]
103
+ meta["reference_frequency"] = make_spectral_coord_reference_dict(
104
+ helpers["crval"][freq_axis], "Hz", helpers["specsys"]
90
105
  )
91
106
  # meta["cdelt"] = helpers["cdelt"][freq_axis]
92
107
  if not meta:
@@ -99,7 +114,7 @@ def _add_freq_attrs(xds: xr.Dataset, helpers: dict) -> xr.Dataset:
99
114
 
100
115
  def _add_vel_attrs(xds: xr.Dataset, helpers: dict) -> xr.Dataset:
101
116
  vel_coord = xds.coords["velocity"]
102
- meta = {"units": ["m/s"]}
117
+ meta = {"units": "m/s"}
103
118
  if helpers["has_freq"]:
104
119
  meta["doppler_type"] = helpers.get("doppler", "RADIO")
105
120
  else:
@@ -156,9 +171,7 @@ def _xds_direction_attrs_from_header(helpers: dict, header) -> dict:
156
171
  helpers["ref_sys"] = ref_sys
157
172
  helpers["ref_eqx"] = ref_eqx
158
173
  # fits does not support conversion frames
159
- direction["reference"] = make_skycoord_dict(
160
- [0.0, 0.0], units=["rad", "rad"], frame=ref_sys
161
- )
174
+ direction["reference"] = make_skycoord_dict([0.0, 0.0], units="rad", frame=ref_sys)
162
175
  dir_axes = helpers["dir_axes"]
163
176
  ddata = []
164
177
  dunits = []
@@ -236,16 +249,43 @@ def _get_telescope_metadata(helpers: dict, header) -> dict:
236
249
  r = np.sqrt(np.sum(xyz * xyz))
237
250
  lat = np.arcsin(z / r)
238
251
  long = np.arctan2(y, x)
239
- tel["location"] = {
252
+ tel["direction"] = {
240
253
  "attrs": {
241
254
  "coordinate_system": "geocentric",
242
255
  # I haven't seen a FITS keyword for reference frame of telescope position
243
256
  "frame": "ITRF",
244
257
  "origin_object_name": "earth",
245
258
  "type": "location",
246
- "units": ["rad", "rad", "m"],
259
+ "units": "rad",
260
+ },
261
+ "data": np.array([long, lat]),
262
+ "dims": ["ellipsoid_dir_label"],
263
+ "coords": {
264
+ "ellipsoid_dir_label": {
265
+ "dims": ["ellipsoid_dir_label"],
266
+ "data": ["lon", "lat"],
267
+ }
268
+ },
269
+ }
270
+ tel["distance"] = {
271
+ "attrs": {
272
+ "coordinate_system": "geocentric",
273
+ # I haven't seen a FITS keyword for reference frame of telescope position
274
+ "frame": "ITRF",
275
+ "origin_object_name": "earth",
276
+ "type": "location",
277
+ "units": "m",
278
+ },
279
+ "data": np.array([r]),
280
+ "dims": ["ellipsoid_dis_label"],
281
+ "coords": {
282
+ "ellipsoid_dis_label": {
283
+ "dims": ["ellipsoid_dis_label"],
284
+ "data": [
285
+ "dist",
286
+ ],
287
+ }
247
288
  },
248
- "data": np.array([long, lat, r]),
249
289
  }
250
290
  return tel
251
291
 
@@ -263,9 +303,7 @@ def _compute_pointing_center(helpers: dict, header) -> dict:
263
303
  pc_lat = float(header[f"CRVAL{t_axes[1]}"]) * unit[1]
264
304
  pc_long = pc_long.to(u.rad).value
265
305
  pc_lat = pc_lat.to(u.rad).value
266
- return make_skycoord_dict(
267
- [pc_long, pc_lat], units=["rad", "rad"], frame=helpers["ref_sys"]
268
- )
306
+ return make_skycoord_dict([pc_long, pc_lat], units="rad", frame=helpers["ref_sys"])
269
307
 
270
308
 
271
309
  def _user_attrs_from_header(header) -> dict:
@@ -364,12 +402,41 @@ def _create_dim_map(helpers: dict, header) -> dict:
364
402
  return dim_map
365
403
 
366
404
 
367
- def _fits_header_to_xds_attrs(hdulist: fits.hdu.hdulist.HDUList) -> tuple:
405
+ def _fits_header_to_xds_attrs(
406
+ hdulist: fits.hdu.hdulist.HDUList, compute_mask: bool
407
+ ) -> tuple:
408
+ # First: Guard for unsupported compressed images
409
+ for i, hdu in enumerate(hdulist):
410
+ if isinstance(hdu, fits.CompImageHDU):
411
+ raise RuntimeError(
412
+ f"HDU {i}, name={hdu.name} is a CompImageHDU, which is not supported "
413
+ "for memory-mapping. "
414
+ "Cannot memory-map compressed FITS image (CompImageHDU). "
415
+ "Workaround: decompress the FITS using tools like `funpack`, `cfitsio`, "
416
+ "or Astropy's `.scale()`/`.copy()` workflows"
417
+ )
368
418
  primary = None
419
+ # FIXME beams is set but never actually used in this function. What's up with that?
369
420
  beams = None
370
421
  for hdu in hdulist:
371
422
  if hdu.name == "PRIMARY":
372
423
  primary = hdu
424
+ # Memory map support check
425
+ # avoid possibly non-existent hdu.scale_type attribute check and check header instead
426
+ header = hdu.header
427
+ scale = hdu.header.get("BSCALE", 1.0)
428
+ zero = hdu.header.get("BZERO", 0.0)
429
+ if not (scale == 1.0 and zero == 0.0):
430
+ raise RuntimeError(
431
+ "Cannot memory-map scaled FITS data (BSCALE/BZERO set). "
432
+ f"BZERO={zero}, BSCALE={scale}. "
433
+ "Workaround: remove scaling with Astropy's"
434
+ " `HDU.data = HDU.data * BSCALE + BZERO` and save a new file"
435
+ )
436
+ # NOTE: check for primary.data size being too large removed, since
437
+ # data is read in chunks, so no danger of exhausting memory
438
+ # NOTE: sanity-check for ndarray type has been removed to avoid
439
+ # forcing eager memory load of possibly very large data array.
373
440
  elif hdu.name == "BEAMS":
374
441
  beams = hdu
375
442
  else:
@@ -399,13 +466,57 @@ def _fits_header_to_xds_attrs(hdulist: fits.hdu.hdulist.HDUList) -> tuple:
399
466
  raise RuntimeError("Could not find both direction axes")
400
467
  if dir_axes is not None:
401
468
  attrs["direction"] = _xds_direction_attrs_from_header(helpers, header)
402
- # FIXME read fits data in chunks in case all data too large to hold in memory
403
- helpers["has_mask"] = da.any(da.isnan(primary.data)).compute()
469
+ helpers["has_mask"] = False
470
+ if compute_mask:
471
+ # 🧠 Why the primary.data reference here is safe (does not cause
472
+ # an eager read of entire data array)
473
+ # primary.data is a memory-mapped array (because fits.open(..., memmap=True)
474
+ # is used upstream)
475
+ # da.from_array(...) wraps this without reading it immediately
476
+ # The actual read occurs inside:
477
+ # .map_blocks(...).any().compute()
478
+ # ...and that triggers blockwise loading via Dask → safe and parallel
479
+ # 💡 Gotcha
480
+ # What would be dangerous:
481
+ # arr = np.isnan(primary.data).any()
482
+ # That would pull the whole array into memory. But we're not doing that.
483
+ data_dask = da.from_array(primary.data, chunks="auto")
484
+ # The following code block has corner case exposure, although the guard should
485
+ # eliminate it. But there is a cleaner, dask-y way that should work that we implement
486
+ # next, with cautions
487
+ # def chunk_has_nan(block):
488
+ # if not isinstance(block, np.ndarray) or block.size == 0:
489
+ # return False
490
+ # return np.isnan(block).any()
491
+ # helpers["has_mask"] = data_dask.map_blocks(chunk_has_nan, dtype=bool).any().compute()
492
+ # ✅ Option: np.isnan(data_dask).any().compute()
493
+ # 🔒 Pros:
494
+ # Cleaner and shorter (no custom function)
495
+ # Handles all chunk shapes robustly — no risk of empty inputs
496
+ # Uses Dask’s own optimized blockwise operations under the hood
497
+ # ⚠️ Cons:
498
+ # Might trigger more eager computation if Dask can't optimize well:
499
+ # If chunks are misaligned or small, Dask might combine many or materialize more blocks than needed
500
+ # Especially on large images, it could bump memory pressure slightly
501
+ # But since we already call .compute(), we will load some block data no matter
502
+ # what — this just changes how much and how smartly.
503
+ # ✅ Verdict for compute_mask
504
+ # Because this is explicitly for computing a global has-NaN flag (not building the
505
+ # dataset), recommend:
506
+ # helpers["has_mask"] = np.isnan(data_dask).any().compute()
507
+ # It's concise, robust to shape edge cases, and still parallelized.
508
+ # We can always revisit it later if perf becomes a concern — and even then,
509
+ # it's likely a matter of tuning chunks= manually rather than the expression itself.
510
+ #
511
+ # This compute will normally be done in parallel
512
+ helpers["has_mask"] = np.isnan(data_dask).any().compute()
404
513
  beam = _beam_attr_from_header(helpers, header)
405
514
  if beam != "mb":
406
515
  helpers["beam"] = beam
407
516
  if "BITPIX" in header:
408
517
  v = abs(header["BITPIX"])
518
+ if v == 16:
519
+ helpers["dtype"] = "int16"
409
520
  if v == 32:
410
521
  helpers["dtype"] = "float32"
411
522
  elif v == 64:
@@ -487,8 +598,8 @@ def _create_coords(
487
598
  cdelt=pick(helpers["cdelt"]),
488
599
  cunit=pick(helpers["cunit"]),
489
600
  )
601
+ helpers["cunit"] = my_ret["units"]
490
602
  for j, i in enumerate(dir_axes):
491
- helpers["cunit"][i] = my_ret["unit"][j]
492
603
  helpers["crval"][i] = my_ret["ref_val"][j]
493
604
  helpers["cdelt"][i] = my_ret["inc"][j]
494
605
  coords[my_ret["axis_name"][0]] = (["l", "m"], my_ret["value"][0])
@@ -507,32 +618,7 @@ def _get_time_values(helpers):
507
618
 
508
619
 
509
620
  def _get_pol_values(helpers):
510
- # as mapped in casacore Stokes.h
511
- stokes_map = [
512
- "Undefined",
513
- "I",
514
- "Q",
515
- "U",
516
- "V",
517
- "RR",
518
- "RL",
519
- "LR",
520
- "LL",
521
- "XX",
522
- "XY",
523
- "YX",
524
- "YY",
525
- "RX",
526
- "RY",
527
- "LX",
528
- "LY",
529
- "XR",
530
- "XL",
531
- "YR",
532
- "YL",
533
- "PP",
534
- "PQ",
535
- ]
621
+
536
622
  idx = helpers["ctype"].index("STOKES")
537
623
  if idx >= 0:
538
624
  vals = []
@@ -542,7 +628,13 @@ def _get_pol_values(helpers):
542
628
  stokes_start_idx = crval - cdelt * crpix
543
629
  for i in range(helpers["shape"][idx]):
544
630
  stokes_idx = (stokes_start_idx + i) * cdelt
545
- vals.append(stokes_map[stokes_idx])
631
+ if 0 <= stokes_idx < len(stokes_types):
632
+ # stokes_types provides the index-label mapping from casacore Stokes.h
633
+ vals.append(stokes_types[stokes_idx])
634
+ else:
635
+ raise RuntimeError(
636
+ "Can't find the Stokes type using the FITS header index"
637
+ )
546
638
  return vals
547
639
  else:
548
640
  return ["I"]
@@ -582,9 +674,9 @@ def _get_freq_values(helpers: dict) -> list:
582
674
  freq, vel = _freq_from_vel(
583
675
  crval, cdelt, crpix, cunit, "Z", helpers["shape"][v_idx], restfreq
584
676
  )
585
- helpers["velocity"] = vel["value"] * u.Unit(vel["unit"])
586
- helpers["crval"][v_idx] = (freq["crval"] * u.Unit(freq["unit"])).to(u.Hz).value
587
- helpers["cdelt"][v_idx] = (freq["cdelt"] * u.Unit(freq["unit"])).to(u.Hz).value
677
+ helpers["velocity"] = vel["value"] * u.Unit(vel["units"])
678
+ helpers["crval"][v_idx] = (freq["crval"] * u.Unit(freq["units"])).to(u.Hz).value
679
+ helpers["cdelt"][v_idx] = (freq["cdelt"] * u.Unit(freq["units"])).to(u.Hz).value
588
680
  return list(freq["value"])
589
681
  else:
590
682
  return [1420e6]
@@ -603,6 +695,9 @@ def _get_velocity_values(helpers: dict) -> list:
603
695
  return v
604
696
 
605
697
 
698
+ # FIXME change name: even if there is only a single beam, we make a
699
+ # multi beam array using it. If we have a beam, it will always be
700
+ # "multibeam", so the name is redundant and confusing
606
701
  def _do_multibeam(xds: xr.Dataset, imname: str) -> xr.Dataset:
607
702
  """Only run if we are sure there are multiple beams"""
608
703
  hdulist = fits.open(imname)
@@ -837,12 +932,12 @@ def _get_transpose_list(helpers: dict) -> tuple:
837
932
 
838
933
  def _read_image_chunk(img_full_path, shapes: tuple, starts: tuple) -> np.ndarray:
839
934
  hdulist = fits.open(img_full_path, memmap=True)
840
- s = []
841
- for start, length in zip(starts, shapes):
842
- s.append(slice(start, start + length))
843
- t = tuple(s)
844
- z = hdulist[0].data[t]
935
+ hdu = hdulist[0]
936
+ # Chunk slice
937
+ slices = tuple(
938
+ slice(start, start + length) for start, length in zip(starts, shapes)
939
+ )
940
+ chunk = hdu.data[slices]
845
941
  hdulist.close()
846
- # delete to avoid having a reference to a mem-mapped hdulist
847
942
  del hdulist
848
- return z
943
+ return chunk
@@ -10,7 +10,11 @@ from typing import Union
10
10
 
11
11
  import xarray as xr
12
12
 
13
- from casacore import tables
13
+ try:
14
+ from casacore import tables
15
+ except ImportError:
16
+ import xradio._utils._casacore.casacore_from_casatools as tables
17
+
14
18
  from ._casacore.common import _open_image_ro
15
19
  from ._casacore.xds_from_casacore import (
16
20
  _add_mask,
@@ -42,8 +46,8 @@ def _load_casa_image_block(infile: str, block_des: dict, do_sky_coords) -> xr.Da
42
46
  cshape = casa_image.shape()
43
47
  ret = _casa_image_to_xds_coords(image_full_path, False, do_sky_coords)
44
48
  xds = ret["xds"].isel(block_des)
45
- nchan = ret["xds"].dims["frequency"]
46
- npol = ret["xds"].dims["polarization"]
49
+ nchan = ret["xds"].sizes["frequency"]
50
+ npol = ret["xds"].sizes["polarization"]
47
51
  starts, shapes, slices = _get_starts_shapes_slices(block_des, coords, cshape)
48
52
  dimorder = _get_xds_dim_order(ret["sphr_dims"])
49
53
  transpose_list, new_axes = _get_transpose_list(coords)
@@ -101,7 +105,7 @@ def _read_casa_image(
101
105
  xds = _add_mask(xds, m.upper(), ary, dimorder)
102
106
  xds.attrs = _casa_image_to_xds_attrs(img_full_path)
103
107
  beam = _get_beam(
104
- img_full_path, xds.dims["frequency"], xds.dims["polarization"], True
108
+ img_full_path, xds.sizes["frequency"], xds.sizes["polarization"], True
105
109
  )
106
110
  if beam is not None:
107
111
  xds["BEAM"] = beam
@@ -129,6 +133,7 @@ def _xds_to_casa_image(xds: xr.Dataset, imagename: str) -> None:
129
133
  lockoptions={"option": "permanentwait"},
130
134
  ack=False,
131
135
  )
136
+
132
137
  tb.putkeyword("coords", coord)
133
138
  tb.putkeyword("imageinfo", ii)
134
139
  if units:
@@ -110,7 +110,7 @@ def _numpy_arrayize_dv(xds: xr.Dataset) -> xr.Dataset:
110
110
  def _default_freq_info() -> dict:
111
111
  return {
112
112
  "rest_frequency": make_quantity(1420405751.7860003, "Hz"),
113
- "type": "frequency",
113
+ "type": "spectral_coord",
114
114
  "frame": "lsrk",
115
115
  "units": "Hz",
116
116
  "waveUnit": "mm",
@@ -141,7 +141,7 @@ def _freq_from_vel(
141
141
  vel = vel * u.Unit(cunit)
142
142
  v_dict = {
143
143
  "value": vel.value,
144
- "unit": cunit,
144
+ "units": cunit,
145
145
  "crval": crval,
146
146
  "cdelt": cdelt,
147
147
  "crpix": crpix,
@@ -154,7 +154,7 @@ def _freq_from_vel(
154
154
  fcdelt = -restfreq / _c / (crval * vel.unit / _c + 1) ** 2 * cdelt * vel.unit
155
155
  f_dict = {
156
156
  "value": freq.value,
157
- "unit": "Hz",
157
+ "units": "Hz",
158
158
  "crval": fcrval.to(u.Hz).value,
159
159
  "cdelt": fcdelt.to(u.Hz).value,
160
160
  "crpix": crpix,
@@ -180,7 +180,7 @@ def _compute_world_sph_dims(
180
180
  "axis_name": [None, None],
181
181
  "ref_val": [None, None],
182
182
  "inc": [None, None],
183
- "unit": ["rad", "rad"],
183
+ "units": "rad",
184
184
  "value": [None, None],
185
185
  }
186
186
  for i in range(2):
@@ -6,7 +6,7 @@ from typing import List, Union
6
6
  from .common import _c, _compute_world_sph_dims, _l_m_attr_notes
7
7
  from xradio._utils.coord_math import _deg_to_rad
8
8
  from xradio._utils.dict_helpers import (
9
- make_frequency_reference_dict,
9
+ make_spectral_coord_reference_dict,
10
10
  make_quantity,
11
11
  make_skycoord_dict,
12
12
  make_time_coord_attrs,
@@ -50,24 +50,24 @@ def _add_common_attrs(
50
50
  cell_size: Union[List[float], np.ndarray],
51
51
  projection: str,
52
52
  ) -> xr.Dataset:
53
- xds.time.attrs = make_time_coord_attrs(units=["d"], scale="utc", time_format="mjd")
53
+ xds.time.attrs = make_time_coord_attrs(units="d", scale="utc", time_format="mjd")
54
54
  freq_vals = np.array(xds.frequency)
55
55
  xds.frequency.attrs = {
56
56
  "observer": spectral_reference.lower(),
57
- "reference_value": make_frequency_reference_dict(
57
+ "reference_frequency": make_spectral_coord_reference_dict(
58
58
  value=freq_vals[len(freq_vals) // 2].item(),
59
- units=["Hz"],
59
+ units="Hz",
60
60
  observer=spectral_reference.lower(),
61
61
  ),
62
62
  "rest_frequencies": make_quantity(restfreq, "Hz"),
63
63
  "rest_frequency": make_quantity(restfreq, "Hz"),
64
- "type": "frequency",
65
- "units": ["Hz"],
66
- "wave_unit": ["mm"],
64
+ "type": "spectral_coord",
65
+ "units": "Hz",
66
+ "wave_units": "mm",
67
67
  }
68
68
  xds.velocity.attrs = {"doppler_type": "radio", "type": "doppler", "units": "m/s"}
69
69
  reference = make_skycoord_dict(
70
- data=phase_center, units=["rad", "rad"], frame=direction_reference
70
+ data=phase_center, units="rad", frame=direction_reference
71
71
  )
72
72
  reference["attrs"].update({"equinox": "j2000.0"})
73
73
  xds.attrs = {