xarray-ms 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xarray_ms/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ __all__ = ["xds_from_zarr", "xdt_from_zarr", "xds_to_zarr", "xdt_to_zarr"]
2
+
3
+ from xarray_ms.core import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr
@@ -0,0 +1,36 @@
1
+ from typing import Mapping
2
+
3
+ from xarray import Dataset, Variable
4
+
5
+ from xarray_ms.backend.msv2.structure import MSv2StructureFactory
6
+
7
+
8
class AntennaDatasetFactory:
  """Builds an antenna dataset from the ANTENNA subtable of an MSv2 structure."""

  _structure_factory: MSv2StructureFactory

  def __init__(self, structure_factory: MSv2StructureFactory):
    self._structure_factory = structure_factory

  def get_dataset(self) -> Mapping[str, Variable]:
    """Return a Dataset of antenna positions, stations and mounts."""
    antenna_table = self._structure_factory()._ant

    # Imported lazily so pyarrow is only required when a dataset is built
    import pyarrow.compute as pac

    # Flatten the list-valued POSITION column into an (nant, 3) array
    positions = pac.list_flatten(antenna_table["POSITION"]).to_numpy().reshape(-1, 3)

    position_var = Variable(
      ("antenna_name", "cartesian_pos_label/ellipsoid_pos_label"), positions
    )
    names = Variable("antenna_name", antenna_table["NAME"].to_numpy())
    stations = Variable(
      "antenna_name",
      antenna_table["STATION"].to_numpy(),
      {"coordinates": "station"},
    )
    mounts = Variable(
      "antenna_name",
      antenna_table["MOUNT"].to_numpy(),
      {"coordinates": "mount"},
    )

    return Dataset(
      data_vars={"ANTENNA_POSITION": position_var},
      coords={"antenna_name": names, "station": stations, "mount": mounts},
    )
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Callable, Tuple
4
+
5
+ import numpy as np
6
+ from xarray.backends import BackendArray
7
+ from xarray.core.indexing import IndexingSupport, explicit_indexing_adapter
8
+
9
+ if TYPE_CHECKING:
10
+ import numpy.typing as npt
11
+
12
+ from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
13
+ from xarray_ms.backend.msv2.table_factory import TableFactory
14
+
15
+
16
def slice_length(s, max_len):
  """Return the number of elements selected by slice ``s`` over a
  sequence of length ``max_len``.

  Raises:
    NotImplementedError: if the slice has a step other than 1.
  """
  start, stop, step = s.indices(max_len)
  if step == 1:
    return stop - start
  raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported")
21
+
22
+
23
class MSv2Array(BackendArray):
  """Backend array containing functionality for reading an MSv2 column.

  Lazily reads ``column`` from the Measurement Set produced by
  ``table_factory``, mapping the leading (time, baseline) dimensions
  onto MSv2 row indices via the partition's row map.
  """

  _table_factory: TableFactory
  _structure_factory: MSv2StructureFactory
  _partition: PartitionKeyT
  _column: str
  _default: Any | None
  _transform: Callable[[npt.NDArray], npt.NDArray] | None
  # shape and dtype are public attributes required by the xarray
  # BackendArray interface (previously mis-annotated as _shape/_dtype)
  shape: Tuple[int, ...]
  dtype: np.dtype

  def __init__(
    self,
    table_factory: TableFactory,
    structure_factory: MSv2StructureFactory,
    partition: PartitionKeyT,
    column: str,
    shape: Tuple[int, ...],
    dtype: npt.DTypeLike,
    default: Any | None = None,
    transform: Callable[[npt.NDArray], npt.NDArray] | None = None,
  ):
    # Validate before assigning state; raise rather than assert so the
    # check survives python -O
    if len(shape) < 2:
      raise ValueError(f"shape {shape} must contain at least (time, baseline)")

    self._table_factory = table_factory
    self._structure_factory = structure_factory
    self._partition = partition
    self._column = column
    self._default = default
    self._transform = transform
    self.shape = shape
    self.dtype = np.dtype(dtype)

  def __getitem__(self, key) -> npt.NDArray:
    # Let xarray normalise the key into outer-indexing form before
    # delegating to _getitem
    return explicit_indexing_adapter(
      key, self.shape, IndexingSupport.OUTER, self._getitem
    )

  def _getitem(self, key) -> npt.NDArray:
    """Read ``self._column`` for the given outer-index ``key``."""
    assert len(key) == len(self.shape)
    expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
    # Map the (time, baseline) coordinates onto row indices
    rows = self._structure_factory()[self._partition].row_map[key[:2]]
    xkey = (rows.ravel(),) + key[2:]
    row_shape = (rows.size,) + expected_shape[2:]
    # Pre-fill with the default value so rows absent from the MS
    # retain a sensible value after the read
    result = np.full(row_shape, self._default, dtype=self.dtype)
    self._table_factory().getcol(self._column, xkey, result)
    result = result.reshape(rows.shape + expected_shape[2:])
    return self._transform(result) if self._transform else result

  def set_transform(self, transform: Callable[[npt.NDArray], npt.NDArray]):
    """Set a transform applied to data read by :meth:`_getitem`."""
    self._transform = transform
@@ -0,0 +1,199 @@
1
+ """
2
+ TODO(sjperkins): This implementation is incomplete and
3
+ will need refactoring.
4
+
5
+ In particular, the reference columns and codes for more complex Measures
6
+ are not yet handled.
7
+
8
+ This logic can be found here:
9
+
10
+ casacore/measures/TableMeasRefDesc/TableMeasRefDesc.cc
11
+ TableMeasRefDesc::TableMeasRefDesc
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import sys
17
+ from datetime import datetime
18
+ from typing import TYPE_CHECKING, Any, Dict, List
19
+
20
+ if TYPE_CHECKING:
21
+ if sys.version_info >= (3, 11):
22
+ from typing import Self
23
+ else:
24
+ from typing_extensions import Self
25
+
26
+
27
+ import numpy as np
28
+ import numpy.testing as npt
29
+ from xarray import Variable
30
+ from xarray.coding.variables import (
31
+ VariableCoder,
32
+ unpack_for_decoding,
33
+ unpack_for_encoding,
34
+ )
35
+
36
+ from xarray_ms.backend.msv2.array import MSv2Array
37
+ from xarray_ms.errors import MissingMeasuresInfo, MissingQuantumUnits
38
+
39
+ if TYPE_CHECKING:
40
+ from xarray.coding.variables import T_Name
41
+
42
+ from xarray_ms.casa_types import ColumnDesc
43
+
44
+
45
class CasaCoder(VariableCoder):
  """Base class for CASA Measures Coders.

  Provides access to the CASA column descriptor and its MEASINFO and
  QuantumUnits keywords for ``column``.
  """

  _column: str
  _column_descs: Dict[str, ColumnDesc]

  def __init__(self, column: str, column_descs: Dict[str, ColumnDesc]):
    assert column in column_descs
    self._column = column
    self._column_descs = column_descs

  @property
  def column(self) -> str:
    """Returns the column"""
    return self._column

  @property
  def column_descs(self) -> Dict[str, ColumnDesc]:
    """Returns the column descriptors"""
    return self._column_descs

  @property
  def column_desc(self) -> ColumnDesc:
    """Returns the column descriptor"""
    try:
      return self._column_descs[self._column]
    except KeyError as e:
      # Chain the original KeyError for debuggability
      raise KeyError(f"No Column Descriptor exists for {self.column}") from e

  @property
  def measinfo(self) -> Dict[str, Any]:
    """Returns the MEASINFO keyword in the column descriptor"""
    kw = self.column_desc.keywords
    try:
      return kw["MEASINFO"]
    except KeyError as e:
      raise MissingMeasuresInfo(f"No MEASINFO found for {self.column}") from e

  @property
  def quantum_units(self) -> List[str]:
    """Returns the QuantumUnits keyword in the column descriptor"""
    kw = self.column_desc.keywords
    try:
      return kw["QuantumUnits"]
    except KeyError as e:
      raise MissingQuantumUnits(f"No QuantumUnits found for {self.column}") from e
91
+
92
+
93
class QuantityCoder(CasaCoder):
  """Strips/restores CASA Quantity metadata on a variable."""

  def encode(self, variable: Variable, name: T_Name = None) -> Variable:
    """Remove quantity attributes prior to writing."""
    dims, data, attrs, encoding = unpack_for_encoding(variable)
    for attribute in ("type", "units"):
      attrs.pop(attribute, None)
    return Variable(dims, data, attrs, encoding, fastpath=True)

  def decode(self, variable: Variable, name: T_Name = None) -> Variable:
    """Attach quantity attributes when reading."""
    dims, data, attrs, encoding = unpack_for_decoding(variable)
    attrs.update(type="quantity", units=self.quantum_units)
    return Variable(dims, data, attrs, encoding, fastpath=True)
105
+
106
+
107
class TimeCoder(CasaCoder):
  """Dispatches encoding functionality to sub-classes"""

  @classmethod
  def from_time_coder(cls, time_coder: TimeCoder) -> Self:
    """Alternate constructor copying configuration from another TimeCoder."""
    return cls(time_coder._column, time_coder._column_descs)

  def dispatched_coder(self) -> TimeCoder:
    """Select the concrete coder for this column's epoch reference frame."""
    measures = self.measinfo
    assert measures["type"] == "epoch"
    reference = measures["Ref"]
    # Only UTC epochs are currently supported
    if reference.upper() != "UTC":
      raise NotImplementedError(reference)
    return UTCCoder.from_time_coder(self)

  def encode(self, variable: Variable, name: T_Name = None) -> Variable:
    return self.dispatched_coder().encode(variable, name)

  def decode(self, variable: Variable, name: T_Name = None) -> Variable:
    return self.dispatched_coder().decode(variable, name)
130
+
131
+
132
class UTCCoder(TimeCoder):
  """Encode MJD UTC"""

  # Offset, in seconds, between the Modified Julian Date epoch
  # (1858-11-17) and the unix epoch (1970-01-01)
  MJD_EPOCH: datetime = datetime(1858, 11, 17)
  UTC_EPOCH: datetime = datetime(1970, 1, 1)
  MJD_OFFSET_SECONDS: float = (UTC_EPOCH - MJD_EPOCH).total_seconds()

  def encode(self, variable: Variable, name: T_Name = None) -> Variable:
    """Convert UTC in seconds to Modified Julian Date"""
    dims, data, attrs, encoding = unpack_for_encoding(variable)
    for attribute in ("type", "units", "scale", "format"):
      attrs.pop(attribute, None)

    if isinstance(data, MSv2Array):
      # Lazy arrays apply the shift when the data is actually read
      data.set_transform(UTCCoder.encode_array)
    elif isinstance(data, np.ndarray):
      data = UTCCoder.encode_array(data)
    else:
      raise TypeError(f"Unknown data type {type(data)}")

    return Variable(dims, data, attrs, encoding, fastpath=True)

  def decode(self, variable: Variable, name: T_Name = None) -> Variable:
    """Convert Modified Julian Date in seconds to UTC in seconds"""
    dims, data, attrs, encoding = unpack_for_decoding(variable)
    attrs.update(type="time", units=self.quantum_units, scale="utc", format="unix")

    if isinstance(data, MSv2Array):
      data.set_transform(UTCCoder.decode_array)
    elif isinstance(data, np.ndarray):
      data = UTCCoder.decode_array(data)
    else:
      raise TypeError(f"Unknown data type {type(data)}")

    return Variable(dims, data, attrs, encoding, fastpath=True)

  @staticmethod
  def encode_array(data: npt.NDArray) -> npt.NDArray:
    # NOTE(review): the module imports ``numpy.testing as npt`` but these
    # annotations presumably intend ``numpy.typing`` -- harmless because
    # annotations are lazy here, but worth confirming upstream.
    return data + UTCCoder.MJD_OFFSET_SECONDS

  @staticmethod
  def decode_array(data: npt.NDArray) -> npt.NDArray:
    return data - UTCCoder.MJD_OFFSET_SECONDS
180
+
181
+
182
class SpectralCoordCoder(CasaCoder):
  """Encode Measures Spectral Coordinates"""

  def encode(self, variable: Variable, name: T_Name = None) -> Variable:
    """Strip the spectral coordinate type attribute prior to writing."""
    dims, data, attrs, encoding = unpack_for_encoding(variable)
    attrs.pop("type", None)
    return Variable(dims, data, attrs, encoding, fastpath=True)

  def decode(self, variable: Variable, name: T_Name = None) -> Variable:
    """Attach spectral coordinate attributes when reading."""
    dims, data, attrs, encoding = unpack_for_decoding(variable)
    assert self.measinfo["type"] == "frequency"
    # TODO(sjperkins): topo is hard-coded here and will almost
    # certainly need extra work to support other frames
    attrs.update(type="spectral_coord", frame="topo", units=self.quantum_units)
    return Variable(dims, data, attrs, encoding, fastpath=True)
@@ -0,0 +1,354 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import warnings
5
+ from datetime import datetime, timezone
6
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple
7
+ from uuid import uuid4
8
+
9
+ import xarray
10
+ from arcae.lib.arrow_tables import Table
11
+ from xarray.backends import BackendEntrypoint
12
+ from xarray.backends.common import AbstractWritableDataStore, _normalize_path
13
+ from xarray.backends.store import StoreBackendEntrypoint
14
+ from xarray.core.dataset import Dataset
15
+ from xarray.core.datatree import DataTree
16
+ from xarray.core.utils import try_read_magic_number_from_file_or_path
17
+
18
+ from xarray_ms.backend.msv2.antenna_dataset_factory import AntennaDatasetFactory
19
+ from xarray_ms.backend.msv2.main_dataset_factory import MainDatasetFactory
20
+ from xarray_ms.backend.msv2.structure import (
21
+ DEFAULT_PARTITION_COLUMNS,
22
+ MSv2Structure,
23
+ MSv2StructureFactory,
24
+ )
25
+ from xarray_ms.backend.msv2.table_factory import TableFactory
26
+ from xarray_ms.errors import InvalidPartitionKey
27
+ from xarray_ms.utils import format_docstring
28
+
29
+ if TYPE_CHECKING:
30
+ from io import BufferedIOBase
31
+
32
+ from xarray.backends.common import AbstractDataStore
33
+
34
+ from xarray_ms.backend.msv2.structure import DEFAULT_PARTITION_COLUMNS, PartitionKeyT
35
+
36
+
37
def promote_chunks(
  structure: MSv2Structure, chunks: Dict | None
) -> Dict[PartitionKeyT, Dict[str, int]] | None:
  """Promotes a chunks dictionary into a
  :code:`{partition_key: chunks}` dictionary.

  Args:
    structure: Measurement Set structure providing the partition keys
      and string-key resolution.
    chunks: A chunking specification. May be ``None`` (no chunking),
      a mapping of string keys resolvable to partition keys, a plain
      chunk mapping applied to every partition, or a mapping of
      partition keys to chunk mappings.

  Returns:
    A dictionary mapping each partition key to its chunking,
    or ``None`` if ``chunks`` is ``None``.
  """
  if chunks is None:
    return None

  # Base case, no chunking for any partition
  promoted: Dict[PartitionKeyT, Dict[str, int]] = {k: {} for k in structure.keys()}

  if all(isinstance(k, str) for k in chunks.keys()):
    # All keys are strings: try to promote them to partition keys.
    # Each string key may resolve to multiple partition keys
    try:
      resolved = [structure.resolve_key(sk) for sk in chunks.keys()]
    except InvalidPartitionKey:
      # Not resolvable: apply the single chunk dictionary to all partitions
      return {k: chunks for k in structure.keys()}
    else:
      for partition_keys, chunk in zip(resolved, chunks.values()):
        for pkey in partition_keys:
          promoted[pkey] = chunk
  else:
    # Keys are assumed to resolve directly to partition keys
    for chunk_key, chunk in chunks.items():
      for pkey in structure.resolve_key(chunk_key):
        promoted[pkey] = chunk

  return promoted
65
+
66
+
67
def initialise_default_args(
  ms: str,
  ninstances: int,
  auto_corrs: bool,
  epoch: str | None,
  table_factory: TableFactory | None,
  partition_columns: List[str] | None,
  structure_factory: MSv2StructureFactory | None,
) -> Tuple[str, TableFactory, List[str], MSv2StructureFactory]:
  """
  Ensures consistency when initialising default arguments from multiple locations
  """
  if not os.path.exists(ms):
    raise ValueError(f"MS {ms} does not exist")

  if not table_factory:
    # Read-only, lock-free tables for parallel access
    table_factory = TableFactory(
      Table.from_filename,
      ms,
      ninstances=ninstances,
      readonly=True,
      lockoptions="nolock",
    )

  if not epoch:
    # Short random identifier for this dataset's creation
    epoch = uuid4().hex[:8]

  if not partition_columns:
    partition_columns = DEFAULT_PARTITION_COLUMNS

  if not structure_factory:
    structure_factory = MSv2StructureFactory(
      table_factory, partition_columns, auto_corrs=auto_corrs
    )

  return epoch, table_factory, partition_columns, structure_factory
95
+
96
+
97
class MSv2Store(AbstractWritableDataStore):
  """Store for reading and writing MSv2 data"""

  __slots__ = (
    "_table_factory",
    "_structure_factory",
    "_partition_columns",
    "_partition_key",
    "_auto_corrs",
    "_ninstances",
    "_epoch",
  )

  # Annotations mirror the __slots__ entries (previously ``_partition``
  # and ``_autocorrs`` were annotated, which never existed as attributes)
  _table_factory: TableFactory
  _structure_factory: MSv2StructureFactory
  _partition_columns: List[str]
  _partition_key: PartitionKeyT
  _auto_corrs: bool
  _ninstances: int
  _epoch: str

  def __init__(
    self,
    table_factory: TableFactory,
    structure_factory: MSv2StructureFactory,
    partition_columns: List[str],
    partition_key: PartitionKeyT,
    auto_corrs: bool,
    ninstances: int,
    epoch: str,
  ):
    self._table_factory = table_factory
    self._structure_factory = structure_factory
    self._partition_columns = partition_columns
    self._partition_key = partition_key
    self._auto_corrs = auto_corrs
    self._ninstances = ninstances
    self._epoch = epoch

  @classmethod
  def open(
    cls,
    ms: str,
    drop_variables=None,
    partition_columns: List[str] | None = None,
    partition_key: PartitionKeyT | None = None,
    auto_corrs: bool = True,
    ninstances: int = 1,
    epoch: str | None = None,
    structure_factory: MSv2StructureFactory | None = None,
  ):
    """Creates an MSv2Store over a single partition of ``ms``.

    Args:
      ms: Path to the MSv2 CASA Measurement Set.
      drop_variables: Currently unused here -- NOTE(review): confirm
        whether variable dropping should be applied at this level.
      partition_columns: Columns used to partition the Measurement Set.
      partition_key: Key selecting a single partition. If ``None``,
        the first partition is selected and a warning is emitted.
      auto_corrs: Include/Exclude auto-correlations.
      ninstances: Number of table instances opened for parallel I/O.
      epoch: Unique string identifying the creation of this store.
      structure_factory: Factory creating MSv2Structure objects.

    Raises:
      ValueError: if ``ms`` is not a string, or ``partition_key``
        is not present in the Measurement Set structure.
    """
    if not isinstance(ms, str):
      raise ValueError("Measurement Sets paths must be strings")

    epoch, table_factory, partition_columns, structure_factory = (
      initialise_default_args(
        ms,
        ninstances,
        auto_corrs,
        epoch,
        None,
        partition_columns,
        structure_factory,
      )
    )

    structure = structure_factory()

    if partition_key is None:
      # Default to the first partition, but warn so the selection
      # is visible to the user
      partition_key = next(iter(structure.keys()))
      warnings.warn(
        f"No partition_key was supplied. Selected first partition {partition_key}"
      )
    elif partition_key not in structure:
      raise ValueError(f"{partition_key} not in {list(structure.keys())}")

    return cls(
      table_factory,
      structure_factory,
      partition_columns=partition_columns,
      partition_key=partition_key,
      auto_corrs=auto_corrs,
      ninstances=ninstances,
      epoch=epoch,
    )

  def close(self, **kwargs):
    # Table lifetimes are managed by the factories; nothing to release here
    pass

  def get_variables(self):
    """Returns the variables of the selected partition."""
    return MainDatasetFactory(
      self._partition_key, self._table_factory, self._structure_factory
    ).get_variables()

  def get_attrs(self):
    """Returns dataset attributes, including the antenna dataset."""
    # Extract the DATA_DESC_ID component of the partition key
    try:
      ddid = next(v for k, v in self._partition_key if k == "DATA_DESC_ID")
    except StopIteration:
      raise KeyError("DATA_DESC_ID not found in partition") from None

    antenna_factory = AntennaDatasetFactory(self._structure_factory)
    ds = antenna_factory.get_dataset()

    return {
      "antenna_xds": ds,
      "version": "0.0.1",
      "creation_date": datetime.now(timezone.utc).isoformat(),
      "data_description_id": ddid,
    }

  def get_dimensions(self):
    return None

  def get_encoding(self):
    return {}
212
+
213
+
214
class MSv2PartitionEntryPoint(BackendEntrypoint):
  # NOTE: a missing comma previously fused "partition_columns" and
  # "partition_key" into a single "partition_columnspartition_key" entry
  open_dataset_parameters = [
    "filename_or_obj",
    "partition_columns",
    "partition_key",
    "auto_corrs",
    "ninstances",
    "epoch",
    "structure_factory",
  ]
  description = "Opens v2 CASA Measurement Sets in Xarray"
  # TODO(review): placeholder URL -- point this at the real documentation
  url = "https://link_to/your_backend/documentation"

  def guess_can_open(
    self, filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore
  ) -> bool:
    """Return true if this is a CASA table"""
    if not isinstance(filename_or_obj, (str, os.PathLike)):
      return False

    # CASA Tables are directories containing a table.dat file
    table_path = os.path.join(_normalize_path(filename_or_obj), "table.dat")
    if not os.path.exists(table_path):
      return False

    # Check the magic number
    if magic := try_read_magic_number_from_file_or_path(table_path, count=4):
      return magic == b"\xbe\xbe\xbe\xbe"

    return False

  @format_docstring(DEFAULT_PARTITION_COLUMNS=DEFAULT_PARTITION_COLUMNS)
  def open_dataset(
    self,
    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
    *,
    drop_variables: str | Iterable[str] | None = None,
    partition_columns: List[str] | None = None,
    partition_key: PartitionKeyT | None = None,
    auto_corrs: bool = True,
    ninstances: int = 8,
    epoch: str | None = None,
    structure_factory: MSv2StructureFactory | None = None,
  ) -> Dataset:
    """Create a :class:`~xarray.Dataset` presenting an MSv4 view
    over a partition of a MSv2 CASA Measurement Set

    Args:
      filename_or_obj: The path to the MSv2 CASA Measurement Set file.
      drop_variables: Variables to drop from the dataset.
      partition_columns: The columns to use for partitioning the Measurement set.
        Defaults to :code:`{DEFAULT_PARTITION_COLUMNS}`.
      partition_key: A key corresponding to an individual partition.
        For example :code:`(('DATA_DESC_ID', 0), ('FIELD_ID', 0))`.
        If :code:`None`, the first partition will be opened.
      auto_corrs: Include/Exclude auto-correlations.
      ninstances: The number of Measurement Set instances to open for parallel I/O.
      epoch: A unique string identifying the creation of this Dataset.
        This should not normally need to be set by the user
      structure_factory: A factory for creating MSv2Structure objects.
        This should not normally need to be set by the user

    Returns:
      A :class:`~xarray.Dataset` referring to the unique
      partition specified by :code:`partition_columns` and :code:`partition_key`.
    """
    filename_or_obj = _normalize_path(filename_or_obj)
    store = MSv2Store.open(
      filename_or_obj,
      drop_variables=drop_variables,
      partition_columns=partition_columns,
      partition_key=partition_key,
      auto_corrs=auto_corrs,
      ninstances=ninstances,
      epoch=epoch,
      structure_factory=structure_factory,
    )
    store_entrypoint = StoreBackendEntrypoint()
    return store_entrypoint.open_dataset(store)

  @format_docstring(DEFAULT_PARTITION_COLUMNS=DEFAULT_PARTITION_COLUMNS)
  def open_datatree(
    self,
    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
    *,
    drop_variables: str | Iterable[str] | None = None,
    partition_columns: List[str] | None = None,
    auto_corrs: bool = True,
    ninstances: int = 8,
    epoch: str | None = None,
    **kwargs,
  ) -> DataTree:
    """Create a :class:`~xarray.core.datatree.DataTree` presenting an MSv4 view
    over multiple partitions of a MSv2 CASA Measurement Set.

    Args:
      filename_or_obj: The path to the MSv2 CASA Measurement Set file.
      drop_variables: Variables to drop from the dataset.
      partition_columns: The columns to use for partitioning the Measurement set.
        Defaults to :code:`{DEFAULT_PARTITION_COLUMNS}`.
      auto_corrs: Include/Exclude auto-correlations.
      ninstances: The number of Measurement Set instances to open for parallel I/O.
      epoch: A unique string identifying the creation of this Dataset.
        This should not normally need to be set by the user

    Returns:
      An xarray :class:`~xarray.core.datatree.DataTree`
    """
    if isinstance(filename_or_obj, os.PathLike):
      ms = str(filename_or_obj)
    elif isinstance(filename_or_obj, str):
      ms = filename_or_obj
    else:
      raise ValueError("Measurement Set paths must be strings")

    epoch, _, partition_columns, structure_factory = initialise_default_args(
      ms, ninstances, auto_corrs, epoch, None, partition_columns, None
    )

    structure = structure_factory()
    datasets = {}
    chunks = kwargs.pop("chunks", None)
    # Promote the single user-supplied chunks dict into per-partition chunks
    pchunks = promote_chunks(structure, chunks)

    # Open each partition as a Dataset, sharing the structure factory
    # so the Measurement Set structure is only derived once
    for partition_key in structure:
      ds = xarray.open_dataset(
        ms,
        drop_variables=drop_variables,
        partition_columns=partition_columns,
        partition_key=partition_key,
        auto_corrs=auto_corrs,
        ninstances=ninstances,
        epoch=epoch,
        structure_factory=structure_factory,
        chunks=None if pchunks is None else pchunks[partition_key],
        **kwargs,
      )

      key = ",".join(f"{k}={v}" for k, v in sorted(partition_key))
      datasets[key] = ds

    return DataTree.from_dict(datasets)