xarray-kat 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. xarray_kat/__init__.py +2 -0
  2. xarray_kat/array.py +308 -0
  3. xarray_kat/async_loop.py +88 -0
  4. xarray_kat/datatree_factory.py +511 -0
  5. xarray_kat/entrypoint.py +155 -0
  6. xarray_kat/errors.py +6 -0
  7. xarray_kat/jwt.py +23 -0
  8. xarray_kat/katdal_types.py +70 -0
  9. xarray_kat/meerkat_chunk_manager.py +259 -0
  10. xarray_kat/multiton.py +108 -0
  11. xarray_kat/py.typed +0 -0
  12. xarray_kat/stores/base_store.py +113 -0
  13. xarray_kat/stores/flag_store.py +63 -0
  14. xarray_kat/stores/http_store.py +48 -0
  15. xarray_kat/stores/vfw_store.py +156 -0
  16. xarray_kat/stores/vis_weight_flag_store_factory.py +228 -0
  17. xarray_kat/stores/visibility_stores.py +154 -0
  18. xarray_kat/stores/weight_store.py +102 -0
  19. xarray_kat/third_party/vendored/katdal/__init__.py +0 -0
  20. xarray_kat/third_party/vendored/katdal/applycal_minimal.py +744 -0
  21. xarray_kat/third_party/vendored/katdal/categorical.py +812 -0
  22. xarray_kat/third_party/vendored/katdal/dataset.py +1461 -0
  23. xarray_kat/third_party/vendored/katdal/datasources_minimal.py +543 -0
  24. xarray_kat/third_party/vendored/katdal/flags.py +54 -0
  25. xarray_kat/third_party/vendored/katdal/sensordata.py +1011 -0
  26. xarray_kat/third_party/vendored/katdal/spectral_window.py +222 -0
  27. xarray_kat/third_party/vendored/katdal/van_vleck.py +128 -0
  28. xarray_kat/third_party/vendored/katdal/vis_flags_weights_minimal.py +190 -0
  29. xarray_kat/third_party/vendored/katdal/visdatav4_minimal.py +917 -0
  30. xarray_kat/utils/__init__.py +140 -0
  31. xarray_kat/utils/chunk_selection.py +114 -0
  32. xarray_kat/utils/serialisation.py +85 -0
  33. xarray_kat/xkat_types.py +62 -0
  34. xarray_kat-0.0.1.dist-info/METADATA +89 -0
  35. xarray_kat-0.0.1.dist-info/RECORD +37 -0
  36. xarray_kat-0.0.1.dist-info/WHEEL +4 -0
  37. xarray_kat-0.0.1.dist-info/entry_points.txt +6 -0
xarray_kat/__init__.py ADDED
@@ -0,0 +1,2 @@
def hello() -> str:
    """Return a friendly greeting identifying this package."""
    greeting = "Hello from xarray-kat!"
    return greeting
xarray_kat/array.py ADDED
@@ -0,0 +1,308 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from numbers import Integral
5
+ from typing import TYPE_CHECKING, Tuple
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import tensorstore as ts
10
+ from xarray.backends import BackendArray
11
+ from xarray.core.indexing import (
12
+ ExplicitlyIndexedNDArrayMixin,
13
+ IndexingSupport,
14
+ OuterIndexer,
15
+ VectorizedIndexer,
16
+ expanded_indexer,
17
+ explicit_indexing_adapter,
18
+ )
19
+
20
+ if TYPE_CHECKING:
21
+ from xarray_kat.multiton import Multiton
22
+
23
+
class AbstractMeerkatArchiveArray(ABC, BackendArray):
    """Abstract base for MeerKAT archive-backed arrays.

    Subclasses must implement the ``dims`` and ``chunks`` properties.
    Note that xarray's internal API expects ``BackendArray``
    to provide ``shape`` and ``dtype`` attributes.
    """

    @property
    @abstractmethod
    def dims(self) -> Tuple[str, ...]:
        """Dimension names of the exposed array."""
        raise NotImplementedError

    @property
    @abstractmethod
    def chunks(self) -> Tuple[int, ...]:
        """Chunk sizes of the exposed array."""
        raise NotImplementedError
38
+
39
+
class DelayedTensorStore(ExplicitlyIndexedNDArrayMixin):
    """A wrapper for TensorStores that only produces new
    DelayedTensorStores when indexed.

    Each indexing operation wraps the narrowed store in a fresh
    ``DelayedTensorStore`` rather than reading any data.
    """

    __slots__ = ("array",)

    array: ts.TensorStore

    def __init__(self, array):
        self.array = array

    @property
    def dtype(self) -> npt.DTypeLike:
        return self.array.dtype.numpy_dtype

    def get_duck_array(self):
        return self.array

    async def async_get_duck_array(self):
        return self.array

    def _oindex_get(self, indexer: OuterIndexer):
        narrowed = self.array.oindex[indexer.tuple]
        return DelayedTensorStore(narrowed)

    def _vindex_get(self, indexer: VectorizedIndexer):
        narrowed = self.array.vindex[indexer.tuple]
        return DelayedTensorStore(narrowed)

    def __getitem__(self, key):
        narrowed = self.array[key.tuple]
        return DelayedTensorStore(narrowed)
69
+
70
+
class CorrProductMixin:
    """Mixin containing methods for reasoning about
    ``(time, frequency, corrprod)`` shaped MeerKAT archive data.

    Implements ``dims``, ``chunks`` and ``shape`` properties
    of ``AbstractMeerkatArchiveArray``.

    The ``meerkat_key`` method produces an index
    that, when applied to a ``(time, frequency, corrprod)`` array
    produces a ``(time, frequency, baseline_id, polarization)`` array.
    This can then be transposed into canonical MSv4 ordering.
    """

    __slots__ = ("_cp_argsort", "_msv4_shape", "_msv4_dims", "_msv4_chunks")

    # Sorts the corrprod axis into (baseline_id, polarization) order.
    _cp_argsort: npt.NDArray
    _msv4_shape: Tuple[int, int, int, int]
    _msv4_dims: Tuple[str, str, str, str]
    _msv4_chunks: Tuple[int, int, int, int]

    def __init__(
        self,
        meerkat_shape: Tuple[int, int, int],
        meerkat_dims: Tuple[str, str, str],
        meerkat_chunks: Tuple[int, int, int],
        cp_argsort: npt.NDArray,
        npol: int,
    ):
        """Constructs a CorrProductMixin

        Args:
            meerkat_shape: The shape of the meerkat array.
                Should be associated with the ``(time, frequency, corrprod)`` dimensions.
            meerkat_dims: The dimensions of the meerkat array.
                Should be ``(time, frequency, corrprod)``.
            meerkat_chunks: The chunking of the meerkat array.
                Should be associated with the ``(time, frequency, corrprod)`` dimensions.
            cp_argsort: An array sorting the ``corrprod`` dimension into a
                canonical ``(baseline_id, polarization)`` ordering.
            npol: Number of polarizations.

        Raises:
            ValueError: If the dims or shape are not in
                ``(time, frequency, corrprod)`` form, if ``cp_argsort``
                does not cover the corrprod axis, or if ``npol`` does
                not divide the number of correlation products exactly.
        """
        if meerkat_dims != ("time", "frequency", "corrprod"):
            raise ValueError(f"{meerkat_dims} should be (time, frequency, corrprod)")

        try:
            ntime, nfreq, ncorrprod = meerkat_shape
        except ValueError as e:
            raise ValueError(
                f"{meerkat_shape} should be (time, frequency, corrprod)"
            ) from e
        if len(cp_argsort) != ncorrprod:
            raise ValueError(f"{len(cp_argsort)} does not match corrprods {ncorrprod}")
        nbl, rem = divmod(len(cp_argsort), npol)
        # Validate before assigning any state so a failed construction
        # cannot leave a partially-initialized object behind.
        if rem != 0:
            raise ValueError(
                f"Number of polarizations {npol} must divide "
                f"the correlation product index {len(cp_argsort)} exactly."
            )
        self._cp_argsort = cp_argsort
        self._msv4_shape = (ntime, nbl, nfreq, npol)
        self._msv4_dims = (meerkat_dims[0], "baseline_id", meerkat_dims[1], "polarization")
        self._msv4_chunks = (meerkat_chunks[0], nbl, meerkat_chunks[1], npol)

    @property
    def dims(self) -> Tuple[str, ...]:
        """MSv4 dimension names: (time, baseline_id, frequency, polarization)."""
        return self._msv4_dims

    @property
    def chunks(self) -> Tuple[int, ...]:
        """Chunk sizes in MSv4 dimension order."""
        return self._msv4_chunks

    @property
    def shape(self) -> Tuple[int, ...]:
        """Array shape in MSv4 dimension order."""
        return self._msv4_shape

    @property
    def ndim(self) -> int:
        return len(self._msv4_shape)

    def _normalize_key_axis(
        self,
        key: Tuple[slice | npt.NDArray | Integral, ...],
        axis: int,
    ) -> npt.NDArray:
        """Normalises ``key[axis]`` into a numpy array of indices."""
        if isinstance(key_item := key[axis], slice):
            return np.arange(self.shape[axis])[key_item]
        elif isinstance(key_item, Integral):
            return np.array([key_item])
        elif isinstance(key_item, np.ndarray):
            return key_item
        else:
            raise NotImplementedError(f"key_item type {type(key_item)}")

    def meerkat_key(self, msv4_key: Tuple) -> Tuple:
        """Translates an MSv4 key into a MeerKAT key.

        MSv4 arrays have ``(time, baseline_id, frequency, polarization)``
        dimensions. This method translates keys referencing the above
        dimensions into keys which operate on MeerKAT archive data with
        ``(time, frequency, corrprod)`` dimensions.
        """
        msv4_key = expanded_indexer(msv4_key, self.ndim)
        assert isinstance(msv4_key, tuple) and len(msv4_key) == 4
        time_selection = msv4_key[0]
        bl_selection = self._normalize_key_axis(msv4_key, 1)
        frequency_selection = msv4_key[2]
        pol_selection = self._normalize_key_axis(msv4_key, 3)

        bl_grid, pol_grid = np.meshgrid(bl_selection, pol_selection, indexing="ij")
        # cp_selection has shape (nbl, npol). When used in an index,
        # it has the effect of splitting the corrprod dimension
        # into baseline and polarization
        npol = self.shape[3]
        cp_selection = self._cp_argsort[bl_grid * npol + pol_grid]
        return (time_selection, frequency_selection, cp_selection)

    @property
    def transpose_axes(self) -> Tuple[int, int, int, int]:
        """Transpose (time, frequency, baseline_id, polarization) to
        (time, baseline_id, frequency, polarization)"""
        return (0, 2, 1, 3)
191
+
192
+
class DelayedCorrProductArray(CorrProductMixin, AbstractMeerkatArchiveArray):
    """Wraps a ``(time, frequency, corrprod)`` TensorStore.

    Most data in the MeerKAT archive has dimension
    ``(time, frequency, corrprod)``.
    This class reshapes the underlying data into the
    ``(time, baseline_id, frequency, polarization)`` form.
    Indexing returns :class:`DelayedTensorStore` objects, so no data
    is read until explicitly requested.
    """

    # Fix: was a bare string ("_store"); use the tuple form for
    # consistency with the other classes in this module.
    __slots__ = ("_store",)

    _store: Multiton[ts.TensorStore]

    def __init__(
        self, store: Multiton[ts.TensorStore], cp_argsort: npt.NDArray, npol: int
    ):
        """Args:
            store: Multiton yielding the underlying TensorStore.
            cp_argsort: Sorts the corrprod axis into a canonical
                ``(baseline_id, polarization)`` ordering.
            npol: Number of polarizations.
        """
        CorrProductMixin.__init__(
            self,
            store.instance.shape,
            store.instance.domain.labels,
            store.instance.chunk_layout.read_chunk.shape,
            cp_argsort,
            npol,
        )
        self._store = store

    @property
    def dtype(self) -> npt.DTypeLike:
        return self._store.instance.dtype.numpy_dtype

    def __getitem__(self, key) -> DelayedTensorStore:
        # Let xarray decompose arbitrary indexers into outer indexing.
        return explicit_indexing_adapter(
            key, self.shape, IndexingSupport.OUTER, self._getitem
        )

    def _getitem(self, key) -> DelayedTensorStore:
        return DelayedTensorStore(
            self._store.instance[self.meerkat_key(key)].transpose(self.transpose_axes)
        )
232
+
233
+
class ImmediateCorrProductArray(DelayedCorrProductArray):
    """Eager variant of :class:`DelayedCorrProductArray`: indexing
    immediately reads the selected data into a numpy array.

    The redundant ``__init__`` override (which only forwarded identical
    arguments to ``super().__init__``) has been removed; the inherited
    constructor is equivalent.
    """

    def _getitem(self, key) -> npt.NDArray:
        # Resolve the delayed store and block until the read completes.
        return super()._getitem(key).get_duck_array().read().result()
242
+
243
+
class DelayedBackendArray(
    CorrProductMixin, ExplicitlyIndexedNDArrayMixin, AbstractMeerkatArchiveArray
):
    """Wraps a ``(time, frequency, corrprod)`` TensorStore.

    Most data in the MeerKAT archive has dimension
    ``(time, frequency, corrprod)``.
    This class reshapes the underlying data into the
    ``(time, baseline_id, frequency, polarization)`` form.
    """

    __slots__ = ("array",)

    array: Multiton[ts.TensorStore]

    def __init__(
        # Fix: annotation read ``ts.Tensorstore`` (wrong capitalization);
        # the tensorstore class is ``ts.TensorStore``.
        self, store: Multiton[ts.TensorStore], cp_argsort: npt.NDArray, npol: int
    ):
        """Args:
            store: Multiton yielding the underlying TensorStore.
            cp_argsort: Sorts the corrprod axis into a canonical
                ``(baseline_id, polarization)`` ordering.
            npol: Number of polarizations.
        """
        super().__init__(
            store.instance.shape,
            store.instance.domain.labels,
            store.instance.chunk_layout.read_chunk.shape,
            cp_argsort,
            npol,
        )
        self.array = store

    @property
    def dtype(self) -> npt.DTypeLike:
        return self.array.instance.dtype.numpy_dtype

    def get_duck_array(self):
        return self.array.instance

    async def async_get_duck_array(self):
        return self.array.instance

    def _oindex_get(self, indexer):
        key = self.meerkat_key(indexer.tuple)
        store = self.array.instance.oindex[key].transpose(self.transpose_axes)
        return DelayedTensorStore(store)

    def _vindex_get(self, indexer):
        # Vectorized indexing of the corrprod layout is not supported.
        raise NotImplementedError("vindex")

    def __getitem__(self, key):
        mkey = self.meerkat_key(key.tuple)
        store = self.array.instance[mkey].transpose(self.transpose_axes)
        return DelayedTensorStore(store)
293
+
294
+
class ImmediateBackendArray(DelayedBackendArray):
    """Eager variant of :class:`DelayedBackendArray`: each indexing
    operation immediately reads the selected data into a numpy array.

    The redundant ``__init__`` override (which only forwarded identical
    arguments to ``super().__init__`` and carried a ``ts.Tensorstore``
    annotation typo) has been removed; the inherited constructor is
    equivalent.
    """

    def _oindex_get(self, indexer):
        return super()._oindex_get(indexer).get_duck_array().read().result()

    def _vindex_get(self, indexer):
        # The parent raises NotImplementedError("vindex"), so this
        # propagates that error before any read is attempted.
        return super()._vindex_get(indexer).get_duck_array().read().result()

    def __getitem__(self, key):
        return super().__getitem__(key).get_duck_array().read().result()
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import threading
6
+ import weakref
7
+ from typing import Any, Dict
8
+
9
+ log = logging.getLogger(__name__)
10
+
11
+
class Singleton(type):
    """Metaclass caching a single instance per class.

    Uses double-checked locking: the lock-free membership test serves
    the common (already-created) case, and the check is repeated under
    the lock so concurrent first calls construct the instance once.
    """

    _instances: Dict[type, Any] = {}
    _instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls in cls._instances:
            return cls._instances[cls]
        with cls._instance_lock:
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
23
+
24
+
def _run_loop_in_thread(
    loop: asyncio.AbstractEventLoop, running: threading.Event
) -> None:
    """Thread target: run *loop* until it is stopped, then tear it down.

    Sets *running* once the loop is installed as this thread's event
    loop, and clears it as soon as ``run_forever`` returns.
    """

    def shutdown() -> None:
        # Teardown sequence: flag not-running, drain async generators
        # and the default executor, then close the loop.
        log.debug("Loop stops")
        running.clear()
        log.debug("Shutting down async generators")
        loop.run_until_complete(loop.shutdown_asyncgens())
        log.debug("Shutting down default executors")
        loop.run_until_complete(loop.shutdown_default_executor())
        log.debug("Closing the loop")
        loop.close()
        log.debug("Done")

    asyncio.set_event_loop(loop)
    running.set()
    try:
        loop.run_forever()
    finally:
        shutdown()
44
+
45
+
class AsyncLoopSingleton(metaclass=Singleton):
    """Owns a process-wide asyncio event loop on a background daemon thread.

    The ``Singleton`` metaclass ensures one instance — and therefore one
    loop — per process. The loop starts on construction and is stopped
    by :meth:`close`, which is also registered as a finalizer.
    """

    _loop: asyncio.AbstractEventLoop | None
    _thread: threading.Thread | None
    _lock: threading.Lock
    # Set by the loop thread once run_forever() is entered,
    # cleared when it returns (see _run_loop_in_thread).
    _running: threading.Event

    def __init__(self):
        self._loop = None
        self._thread = None
        self._lock = threading.Lock()
        self._running = threading.Event()
        # NOTE(review): self.close is a bound method and so holds a
        # strong reference to self; this finalizer therefore only fires
        # at interpreter shutdown — presumably intended for a singleton,
        # but worth confirming.
        weakref.finalize(self, self.close)
        self.start()

    @property
    def instance(self):
        """The managed event loop (``None`` before start / after close)."""
        return self._loop

    def start(self) -> None:
        """Create the loop and launch its thread; no-op if already alive."""
        with self._lock:
            if self._thread and self._thread.is_alive():
                return

            self._loop = asyncio.new_event_loop()
            self._thread = threading.Thread(
                target=_run_loop_in_thread,
                args=(self._loop, self._running),
                daemon=True,
                name="AsyncLoopThread",
            )
            self._thread.start()

    def close(self) -> None:
        """Stop the loop, join its thread and release both. Idempotent."""
        with self._lock:
            if not self._thread or not self._loop:
                return

            # Bug fix: if close() runs before the loop thread has entered
            # run_forever(), is_running() is still False, stop() is never
            # scheduled, and join() below would block forever. Wait
            # (bounded) for the loop thread to signal that it is running.
            if self._thread.is_alive():
                self._running.wait(timeout=5.0)

            if self._loop.is_running():
                self._loop.call_soon_threadsafe(self._loop.stop)

            self._thread.join()
            self._thread = None
            self._loop = None