xarray-ms 0.2.4__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/PKG-INFO +1 -1
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/pyproject.toml +2 -2
- xarray_ms-0.2.6/xarray_ms/backend/msv2/array.py +170 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/antenna.py +4 -3
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/correlated.py +84 -36
- xarray_ms-0.2.6/xarray_ms/backend/msv2/imputation.py +95 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/structure.py +45 -22
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/errors.py +9 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/testing/simulator.py +15 -16
- xarray_ms-0.2.4/xarray_ms/backend/msv2/array.py +0 -87
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/LICENSE +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/README.rst +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/__init__.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/encoders.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/entrypoint.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/entrypoint_utils.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/__init__.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/partition.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/casa_types.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/msv4_types.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/multiton.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/query.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/testing/__init__.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/testing/utils.py +0 -0
- {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "xarray-ms"
|
|
3
|
-
version = "0.2.4"
|
|
3
|
+
version = "0.2.6"
|
|
4
4
|
description = "xarray MSv4 views over MSv2 Measurement Sets"
|
|
5
5
|
authors = ["Simon Perkins <simon.perkins@gmail.com>"]
|
|
6
6
|
readme = "README.rst"
|
|
@@ -58,7 +58,7 @@ build-backend = "poetry.core.masonry.api"
|
|
|
58
58
|
# github_url = "https://github.com/<user or organization>/<project>/"
|
|
59
59
|
|
|
60
60
|
[tool.tbump.version]
|
|
61
|
-
current = "0.2.4"
|
|
61
|
+
current = "0.2.6"
|
|
62
62
|
|
|
63
63
|
# Example of a semver regexp.
|
|
64
64
|
# Make sure this matches current_version before
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from xarray.backends import BackendArray
|
|
7
|
+
from xarray.core.indexing import IndexingSupport, explicit_indexing_adapter
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import numpy.typing as npt
|
|
11
|
+
|
|
12
|
+
from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
|
|
13
|
+
from xarray_ms.multiton import Multiton
|
|
14
|
+
|
|
15
|
+
TransformerT = Callable[[npt.NDArray], npt.NDArray]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def slice_length(s: npt.NDArray | slice, max_len) -> int:
|
|
19
|
+
if isinstance(s, np.ndarray):
|
|
20
|
+
if s.ndim != 1:
|
|
21
|
+
raise NotImplementedError("Slicing with non-1D numpy arrays")
|
|
22
|
+
return len(s)
|
|
23
|
+
|
|
24
|
+
start, stop, step = s.indices(max_len)
|
|
25
|
+
if step != 1:
|
|
26
|
+
raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported")
|
|
27
|
+
return stop - start
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MSv2Array(BackendArray):
|
|
31
|
+
"""Base MSv2Array backend array class,
|
|
32
|
+
containing required shape and data type"""
|
|
33
|
+
|
|
34
|
+
__slots__ = ("shape", "dtype")
|
|
35
|
+
|
|
36
|
+
shape: Tuple[int, ...]
|
|
37
|
+
dtype: npt.DTypeLike
|
|
38
|
+
|
|
39
|
+
def __init__(self, shape: Tuple[int, ...], dtype: npt.DTypeLike):
|
|
40
|
+
self.shape = shape
|
|
41
|
+
self.dtype = dtype
|
|
42
|
+
|
|
43
|
+
def __getitem__(self, key) -> npt.NDArray:
|
|
44
|
+
raise NotImplementedError
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
def transform(self) -> TransformerT | None:
|
|
48
|
+
raise NotImplementedError
|
|
49
|
+
|
|
50
|
+
@transform.setter
|
|
51
|
+
def transform(self, value: TransformerT) -> None:
|
|
52
|
+
raise NotImplementedError
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class MainMSv2Array(MSv2Array):
|
|
56
|
+
"""Backend array containing functionality for reading an MSv2 column
|
|
57
|
+
from the MAIN table. Columns are assumed to have ("time", "baseline_id")
|
|
58
|
+
as the first dimensions. These are mapped onto the "row" dimension
|
|
59
|
+
via the partition row map"""
|
|
60
|
+
|
|
61
|
+
__slots__ = (
|
|
62
|
+
"_table_factory",
|
|
63
|
+
"_structure_factory",
|
|
64
|
+
"_partition",
|
|
65
|
+
"_column",
|
|
66
|
+
"_default",
|
|
67
|
+
"_transform",
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
_table_factory: Multiton
|
|
71
|
+
_structure_factory: MSv2StructureFactory
|
|
72
|
+
_partition: PartitionKeyT
|
|
73
|
+
_column: str
|
|
74
|
+
_default: Any | None
|
|
75
|
+
_transform: TransformerT | None
|
|
76
|
+
|
|
77
|
+
def __init__(
|
|
78
|
+
self,
|
|
79
|
+
table_factory: Multiton,
|
|
80
|
+
structure_factory: MSv2StructureFactory,
|
|
81
|
+
partition: PartitionKeyT,
|
|
82
|
+
column: str,
|
|
83
|
+
shape: Tuple[int, ...],
|
|
84
|
+
dtype: npt.DTypeLike,
|
|
85
|
+
default: Any | None = None,
|
|
86
|
+
transform: TransformerT | None = None,
|
|
87
|
+
):
|
|
88
|
+
super().__init__(shape, dtype)
|
|
89
|
+
self._table_factory = table_factory
|
|
90
|
+
self._structure_factory = structure_factory
|
|
91
|
+
self._partition = partition
|
|
92
|
+
self._column = column
|
|
93
|
+
self._default = default
|
|
94
|
+
self._transform = transform
|
|
95
|
+
|
|
96
|
+
assert len(shape) >= 2, "(time, baseline_ids) required"
|
|
97
|
+
|
|
98
|
+
def __getitem__(self, key) -> npt.NDArray:
|
|
99
|
+
return explicit_indexing_adapter(
|
|
100
|
+
key, self.shape, IndexingSupport.OUTER, self._getitem
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
def _getitem(self, key) -> npt.NDArray:
|
|
104
|
+
assert len(key) == len(self.shape)
|
|
105
|
+
expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
|
|
106
|
+
# Map the (time, baseline_id) coordinates onto row indices
|
|
107
|
+
rows = self._structure_factory.instance[self._partition].row_map[key[:2]]
|
|
108
|
+
row_key = (rows.ravel(),) + key[2:]
|
|
109
|
+
row_shape = (rows.size,) + expected_shape[2:]
|
|
110
|
+
result = np.full(row_shape, self._default, dtype=self.dtype)
|
|
111
|
+
self._table_factory.instance.getcol(self._column, row_key, result)
|
|
112
|
+
result = result.reshape(rows.shape + expected_shape[2:])
|
|
113
|
+
return self._transform(result) if self._transform else result
|
|
114
|
+
|
|
115
|
+
@property
|
|
116
|
+
def transform(self) -> TransformerT | None:
|
|
117
|
+
return self._transform
|
|
118
|
+
|
|
119
|
+
@transform.setter
|
|
120
|
+
def transform(self, value: TransformerT) -> None:
|
|
121
|
+
self._transform = value
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class BroadcastMSv2Array(MSv2Array):
|
|
125
|
+
"""Broadcasts a MAIN table MSv2 Column up to an
|
|
126
|
+
MSv4 column. This can be inefficient for example,
|
|
127
|
+
if multiple frequency chunks are read for the same
|
|
128
|
+
("time", "baseline_id") range as the same
|
|
129
|
+
low resolution data can be read multiple times.
|
|
130
|
+
|
|
131
|
+
However, this should be no worse than reading the
|
|
132
|
+
data for a full resolution column.
|
|
133
|
+
|
|
134
|
+
This is primarily useful for falling back to the
|
|
135
|
+
WEIGHT column when WEIGHT_SPECTRUM is missing, or
|
|
136
|
+
FLAG_ROW if FLAG is missing.
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
__slots__ = ("_low_res_array", "_low_res_index")
|
|
140
|
+
|
|
141
|
+
_low_res_array: MSv2Array
|
|
142
|
+
_low_res_index: Tuple[slice | None, ...]
|
|
143
|
+
shape: Tuple[int, ...]
|
|
144
|
+
|
|
145
|
+
def __init__(
|
|
146
|
+
self,
|
|
147
|
+
low_res_array: MSv2Array,
|
|
148
|
+
low_res_index: Tuple[slice | None, ...],
|
|
149
|
+
high_res_shape: Tuple[int, ...],
|
|
150
|
+
):
|
|
151
|
+
self._low_res_array = low_res_array
|
|
152
|
+
self._low_res_index = low_res_index
|
|
153
|
+
self.shape = high_res_shape
|
|
154
|
+
|
|
155
|
+
@property
|
|
156
|
+
def dtype(self):
|
|
157
|
+
return self._low_res_array.dtype
|
|
158
|
+
|
|
159
|
+
@property
|
|
160
|
+
def transform(self) -> TransformerT | None:
|
|
161
|
+
return self._low_res_array.transform
|
|
162
|
+
|
|
163
|
+
@transform.setter
|
|
164
|
+
def transform(self, value: TransformerT) -> None:
|
|
165
|
+
self._low_res_array.transform = value
|
|
166
|
+
|
|
167
|
+
def __getitem__(self, key) -> npt.NDArray:
|
|
168
|
+
low_res_data = self._low_res_array.__getitem__(key)
|
|
169
|
+
low_res_data = low_res_data[self._low_res_index]
|
|
170
|
+
return np.broadcast_to(low_res_data, self.shape)
|
|
@@ -3,6 +3,7 @@ from typing import Dict, Mapping
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
from xarray import Dataset, Variable
|
|
5
5
|
|
|
6
|
+
from xarray_ms.backend.msv2.imputation import maybe_impute_observation_table
|
|
6
7
|
from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
|
|
7
8
|
from xarray_ms.errors import InvalidMeasurementSet
|
|
8
9
|
from xarray_ms.multiton import Multiton
|
|
@@ -26,13 +27,13 @@ class AntennaDatasetFactory:
|
|
|
26
27
|
self._subtable_factories = subtable_factories
|
|
27
28
|
|
|
28
29
|
def get_dataset(self) -> Mapping[str, Variable]:
|
|
29
|
-
|
|
30
|
-
partition = structure[self._partition_key]
|
|
30
|
+
partition = self._structure_factory.instance[self._partition_key]
|
|
31
31
|
ants = self._subtable_factories["ANTENNA"].instance
|
|
32
32
|
feeds = self._subtable_factories["FEED"].instance
|
|
33
33
|
obs = self._subtable_factories["OBSERVATION"].instance
|
|
34
34
|
|
|
35
|
-
|
|
35
|
+
obs = maybe_impute_observation_table(obs, [partition.obs_id])
|
|
36
|
+
telescope_name = obs["TELESCOPE_NAME"][0].as_py()
|
|
36
37
|
|
|
37
38
|
import pyarrow.compute as pac
|
|
38
39
|
|
|
@@ -8,12 +8,20 @@ from xarray.coding.variables import unpack_for_decoding
|
|
|
8
8
|
from xarray.core.indexing import LazilyIndexedArray
|
|
9
9
|
from xarray.core.utils import FrozenDict
|
|
10
10
|
|
|
11
|
-
from xarray_ms.backend.msv2.array import
|
|
11
|
+
from xarray_ms.backend.msv2.array import (
|
|
12
|
+
BroadcastMSv2Array,
|
|
13
|
+
MainMSv2Array,
|
|
14
|
+
MSv2Array,
|
|
15
|
+
)
|
|
12
16
|
from xarray_ms.backend.msv2.encoders import (
|
|
13
17
|
CasaCoder,
|
|
14
18
|
QuantityCoder,
|
|
15
19
|
TimeCoder,
|
|
16
20
|
)
|
|
21
|
+
from xarray_ms.backend.msv2.imputation import (
|
|
22
|
+
maybe_impute_field_table,
|
|
23
|
+
maybe_impute_observation_table,
|
|
24
|
+
)
|
|
17
25
|
from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
|
|
18
26
|
from xarray_ms.casa_types import ColumnDesc, FrequencyMeasures, Polarisations
|
|
19
27
|
from xarray_ms.errors import IrregularGridWarning
|
|
@@ -26,6 +34,7 @@ class MSv2ColumnSchema:
|
|
|
26
34
|
dims: Tuple[str, ...]
|
|
27
35
|
default: Any = None
|
|
28
36
|
coder: Type[CasaCoder] | None = None
|
|
37
|
+
low_res_dims: Tuple[str, ...] | None = None
|
|
29
38
|
|
|
30
39
|
|
|
31
40
|
MSV4_to_MSV2_COLUMN_SCHEMAS = {
|
|
@@ -34,10 +43,20 @@ MSV4_to_MSV2_COLUMN_SCHEMAS = {
|
|
|
34
43
|
"TIME_CENTROID": MSv2ColumnSchema("TIME_CENTROID", (), np.nan, TimeCoder),
|
|
35
44
|
"EFFECTIVE_INTEGRATION_TIME": MSv2ColumnSchema("EXPOSURE", (), np.nan, QuantityCoder),
|
|
36
45
|
"UVW": MSv2ColumnSchema("UVW", ("uvw_label",), np.nan, None),
|
|
46
|
+
"FLAG_ROW": MSv2ColumnSchema(
|
|
47
|
+
"FLAG_ROW", ("frequency", "polarization"), 1, None, low_res_dims=()
|
|
48
|
+
),
|
|
37
49
|
"FLAG": MSv2ColumnSchema("FLAG", ("frequency", "polarization"), 1, None),
|
|
38
50
|
"VISIBILITY": MSv2ColumnSchema(
|
|
39
51
|
"DATA", ("frequency", "polarization"), np.nan + np.nan * 1j, None
|
|
40
52
|
),
|
|
53
|
+
"WEIGHT_ROW": MSv2ColumnSchema(
|
|
54
|
+
"WEIGHT",
|
|
55
|
+
("frequency", "polarization"),
|
|
56
|
+
np.nan,
|
|
57
|
+
None,
|
|
58
|
+
low_res_dims=("polarization",),
|
|
59
|
+
),
|
|
41
60
|
"WEIGHT": MSv2ColumnSchema(
|
|
42
61
|
"WEIGHT_SPECTRUM", ("frequency", "polarization"), np.nan, None
|
|
43
62
|
),
|
|
@@ -78,11 +97,8 @@ class CorrelatedDatasetFactory:
|
|
|
78
97
|
c: ColumnDesc.from_descriptor(c, ms_table_desc) for c in ms.columns()
|
|
79
98
|
}
|
|
80
99
|
|
|
81
|
-
def _variable_from_column(self, column: str) -> Variable:
|
|
100
|
+
def _variable_from_column(self, column: str, dim_sizes: Dict[str, int]) -> Variable:
|
|
82
101
|
"""Derive an xarray Variable from the MSv2 column descriptor and schemas"""
|
|
83
|
-
structure = self._structure_factory.instance
|
|
84
|
-
partition = structure[self._partition_key]
|
|
85
|
-
|
|
86
102
|
try:
|
|
87
103
|
schema = MSV4_to_MSV2_COLUMN_SCHEMAS[column]
|
|
88
104
|
except KeyError:
|
|
@@ -93,20 +109,6 @@ class CorrelatedDatasetFactory:
|
|
|
93
109
|
except KeyError:
|
|
94
110
|
raise KeyError(f"No Column Descriptor exist for {schema.name}")
|
|
95
111
|
|
|
96
|
-
spw = self._subtable_factories["SPECTRAL_WINDOW"].instance
|
|
97
|
-
pol = self._subtable_factories["POLARIZATION"].instance
|
|
98
|
-
|
|
99
|
-
chan_freq = spw["CHAN_FREQ"][partition.spw_id].as_py()
|
|
100
|
-
corr_type = pol["CORR_TYPE"][partition.pol_id].as_py()
|
|
101
|
-
|
|
102
|
-
dim_sizes = {
|
|
103
|
-
"time": len(partition.time),
|
|
104
|
-
"baseline_id": partition.nbl,
|
|
105
|
-
"frequency": len(chan_freq),
|
|
106
|
-
"polarization": len(corr_type),
|
|
107
|
-
**FIXED_DIMENSION_SIZES,
|
|
108
|
-
}
|
|
109
|
-
|
|
110
112
|
dims = ("time", "baseline_id") + schema.dims
|
|
111
113
|
|
|
112
114
|
try:
|
|
@@ -116,7 +118,20 @@ class CorrelatedDatasetFactory:
|
|
|
116
118
|
|
|
117
119
|
default = column_desc.dtype.type(schema.default)
|
|
118
120
|
|
|
119
|
-
|
|
121
|
+
high_res_shape = shape
|
|
122
|
+
low_res_index: Tuple[slice | None, ...] = tuple(slice(None) for _ in shape)
|
|
123
|
+
|
|
124
|
+
if schema.low_res_dims:
|
|
125
|
+
low_res_dims = ("time", "baseline_id") + schema.low_res_dims
|
|
126
|
+
high_res_shape = shape
|
|
127
|
+
try:
|
|
128
|
+
shape_map = {d: dim_sizes[d] for d in low_res_dims}
|
|
129
|
+
except KeyError as e:
|
|
130
|
+
raise KeyError(f"No dimension size found for {e.args[0]}")
|
|
131
|
+
low_res_index = tuple(slice(None) if d in shape_map else None for d in dims)
|
|
132
|
+
shape = tuple(shape_map.values())
|
|
133
|
+
|
|
134
|
+
array: MSv2Array = MainMSv2Array(
|
|
120
135
|
self._ms_factory,
|
|
121
136
|
self._structure_factory,
|
|
122
137
|
self._partition_key,
|
|
@@ -126,7 +141,10 @@ class CorrelatedDatasetFactory:
|
|
|
126
141
|
default,
|
|
127
142
|
)
|
|
128
143
|
|
|
129
|
-
|
|
144
|
+
if schema.low_res_dims:
|
|
145
|
+
array = BroadcastMSv2Array(array, low_res_index, high_res_shape)
|
|
146
|
+
|
|
147
|
+
var = Variable(dims, array, fastpath=True)
|
|
130
148
|
|
|
131
149
|
# Apply any measures encoding
|
|
132
150
|
if schema.coder:
|
|
@@ -144,8 +162,7 @@ class CorrelatedDatasetFactory:
|
|
|
144
162
|
structure = self._structure_factory.instance
|
|
145
163
|
partition = structure[self._partition_key]
|
|
146
164
|
ant1, ant2 = partition.antenna_pairs
|
|
147
|
-
nbl
|
|
148
|
-
assert (nbl,) == ant1.shape
|
|
165
|
+
assert (partition.nbl,) == ant1.shape
|
|
149
166
|
|
|
150
167
|
antenna = self._subtable_factories["ANTENNA"].instance
|
|
151
168
|
ant_names = antenna["NAME"].to_numpy()
|
|
@@ -156,6 +173,7 @@ class CorrelatedDatasetFactory:
|
|
|
156
173
|
pol_id = partition.pol_id
|
|
157
174
|
spw = self._subtable_factories["SPECTRAL_WINDOW"].instance
|
|
158
175
|
pol = self._subtable_factories["POLARIZATION"].instance
|
|
176
|
+
field = self._subtable_factories["FIELD"].instance
|
|
159
177
|
|
|
160
178
|
chan_freq = spw["CHAN_FREQ"][spw_id].as_py()
|
|
161
179
|
uchan_width = np.unique(spw["CHAN_WIDTH"][spw_id].as_py())
|
|
@@ -167,6 +185,14 @@ class CorrelatedDatasetFactory:
|
|
|
167
185
|
|
|
168
186
|
corr_type = Polarisations.from_values(pol["CORR_TYPE"][pol_id].as_py()).to_str()
|
|
169
187
|
|
|
188
|
+
dim_sizes = {
|
|
189
|
+
"time": len(partition.time),
|
|
190
|
+
"baseline_id": partition.nbl,
|
|
191
|
+
"frequency": len(chan_freq),
|
|
192
|
+
"polarization": len(corr_type),
|
|
193
|
+
**FIXED_DIMENSION_SIZES,
|
|
194
|
+
}
|
|
195
|
+
|
|
170
196
|
row_map = partition.row_map
|
|
171
197
|
missing = np.count_nonzero(row_map == -1)
|
|
172
198
|
if missing > 0:
|
|
@@ -181,17 +207,28 @@ class CorrelatedDatasetFactory:
|
|
|
181
207
|
)
|
|
182
208
|
|
|
183
209
|
data_vars = [
|
|
184
|
-
(n, self._variable_from_column(n))
|
|
210
|
+
(n, self._variable_from_column(n, dim_sizes))
|
|
185
211
|
for n in (
|
|
186
212
|
"TIME_CENTROID",
|
|
187
213
|
"EFFECTIVE_INTEGRATION_TIME",
|
|
188
214
|
"UVW",
|
|
189
215
|
"VISIBILITY",
|
|
190
|
-
"FLAG",
|
|
191
|
-
"WEIGHT",
|
|
192
216
|
)
|
|
193
217
|
]
|
|
194
218
|
|
|
219
|
+
if "FLAG" in self._main_column_descs:
|
|
220
|
+
data_vars.append(("FLAG", self._variable_from_column("FLAG", dim_sizes)))
|
|
221
|
+
else:
|
|
222
|
+
data_vars.append(("FLAG", self._variable_from_column("FLAG_ROW", dim_sizes)))
|
|
223
|
+
|
|
224
|
+
if "WEIGHT_SPECTRUM" in self._main_column_descs:
|
|
225
|
+
data_vars.append(("WEIGHT", self._variable_from_column("WEIGHT", dim_sizes)))
|
|
226
|
+
else:
|
|
227
|
+
data_vars.append(("WEIGHT", self._variable_from_column("WEIGHT_ROW", dim_sizes)))
|
|
228
|
+
|
|
229
|
+
field = maybe_impute_field_table(field, partition.field_ids)
|
|
230
|
+
field_names = field.take(partition.field_ids)["NAME"].to_numpy()
|
|
231
|
+
|
|
195
232
|
# Add coordinates indexing coordinates
|
|
196
233
|
coordinates = [
|
|
197
234
|
(
|
|
@@ -208,6 +245,12 @@ class CorrelatedDatasetFactory:
|
|
|
208
245
|
),
|
|
209
246
|
("polarization", (("polarization",), corr_type, None)),
|
|
210
247
|
("uvw_label", (("uvw_label",), ["u", "v", "w"], None)),
|
|
248
|
+
("field_name", ("time", field_names, {"coordinates": "field_name"})),
|
|
249
|
+
("scan_number", ("time", partition.scan_numbers, {"coordinates": "scan_number"})),
|
|
250
|
+
(
|
|
251
|
+
"sub_scan_number",
|
|
252
|
+
("time", partition.sub_scan_numbers, {"coordinates": "sub_scan_number"}),
|
|
253
|
+
),
|
|
211
254
|
]
|
|
212
255
|
|
|
213
256
|
e = {"preferred_chunks": self._preferred_chunks} if self._preferred_chunks else None
|
|
@@ -217,8 +260,14 @@ class CorrelatedDatasetFactory:
|
|
|
217
260
|
time_coder = TimeCoder("TIME", self._main_column_descs)
|
|
218
261
|
|
|
219
262
|
if partition.interval.size == 1:
|
|
263
|
+
# Single unique value
|
|
220
264
|
time_attrs = {"integration_time": partition.interval.item()}
|
|
265
|
+
elif np.allclose(partition.interval[:, None], partition.interval[None, :]):
|
|
266
|
+
# Tolerate some jitter in the unique values
|
|
267
|
+
time_attrs = {"integration_time": np.mean(partition.interval)}
|
|
221
268
|
else:
|
|
269
|
+
# There are multiple unique interval values,
|
|
270
|
+
# a regular grid isn't possible
|
|
222
271
|
warnings.warn(
|
|
223
272
|
f"Missing/Multiple intervals {partition.interval} "
|
|
224
273
|
f"found in partition {self._partition_key}. "
|
|
@@ -230,8 +279,11 @@ class CorrelatedDatasetFactory:
|
|
|
230
279
|
time_attrs = {"integration_time": np.nan}
|
|
231
280
|
data_vars.extend(
|
|
232
281
|
[
|
|
233
|
-
("TIME", self._variable_from_column("TIME")),
|
|
234
|
-
(
|
|
282
|
+
("TIME", self._variable_from_column("TIME", dim_sizes)),
|
|
283
|
+
(
|
|
284
|
+
"INTEGRATION_TIME",
|
|
285
|
+
self._variable_from_column("INTEGRATION_TIME", dim_sizes),
|
|
286
|
+
),
|
|
235
287
|
]
|
|
236
288
|
)
|
|
237
289
|
|
|
@@ -272,19 +324,15 @@ class CorrelatedDatasetFactory:
|
|
|
272
324
|
return FrozenDict(sorted(data_vars + coordinates))
|
|
273
325
|
|
|
274
326
|
def _observation_info(self) -> Dict[str, Any]:
|
|
275
|
-
|
|
276
|
-
partition = structure[self._partition_key]
|
|
327
|
+
partition = self._structure_factory.instance[self._partition_key]
|
|
277
328
|
obs = self._subtable_factories["OBSERVATION"].instance
|
|
278
|
-
|
|
279
|
-
project = obs["PROJECT"][partition.obs_id].as_py()
|
|
280
|
-
# TODO: A Measures conversions is needed here
|
|
281
|
-
release_date = obs["RELEASE_DATE"][partition.obs_id].as_py() # noqa: F841
|
|
329
|
+
obs = maybe_impute_observation_table(obs, [partition.obs_id])
|
|
282
330
|
|
|
283
331
|
return dict(
|
|
284
332
|
sorted(
|
|
285
333
|
{
|
|
286
|
-
"observer":
|
|
287
|
-
"project":
|
|
334
|
+
"observer": obs["OBSERVER"][partition.obs_id].as_py(),
|
|
335
|
+
"project": obs["PROJECT"][partition.obs_id].as_py(),
|
|
288
336
|
}.items()
|
|
289
337
|
)
|
|
290
338
|
)
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import numpy.typing as npt
|
|
8
|
+
|
|
9
|
+
from xarray_ms.errors import ImputedMetadataWarning
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
import pyarrow as pa
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _maybe_return_table_or_max_id(
|
|
16
|
+
table: pa.Table, table_name: str, ids: npt.NDArray[np.int32], id_column_name: str
|
|
17
|
+
) -> pa.Table | int:
|
|
18
|
+
"""Returns the existing table if a row entry exists,
|
|
19
|
+
else returns the maximum id"""
|
|
20
|
+
max_id = np.max(ids)
|
|
21
|
+
|
|
22
|
+
if max_id < len(table):
|
|
23
|
+
return table
|
|
24
|
+
|
|
25
|
+
warnings.warn(
|
|
26
|
+
f"No row exists in the {table_name} table of length {len(table)} "
|
|
27
|
+
f"for {id_column_name}={max_id}. "
|
|
28
|
+
f"Artificial metadata will be substituted.",
|
|
29
|
+
ImputedMetadataWarning,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
return max_id
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def maybe_impute_field_table(
|
|
36
|
+
field: pa.Table, field_id: npt.NDArray[np.int32]
|
|
37
|
+
) -> pa.Table:
|
|
38
|
+
"""Generates a FIELD subtable if there are no row ids
|
|
39
|
+
associated with the given FIELD_ID values"""
|
|
40
|
+
|
|
41
|
+
import pyarrow as pa
|
|
42
|
+
|
|
43
|
+
result = _maybe_return_table_or_max_id(field, "FIELD", field_id, "FIELD_ID")
|
|
44
|
+
if isinstance(result, pa.Table):
|
|
45
|
+
return result
|
|
46
|
+
|
|
47
|
+
return pa.Table.from_pydict(
|
|
48
|
+
{
|
|
49
|
+
"NAME": np.array([f"UNKNOWN-{i}" for i in range(result + 1)], dtype=object),
|
|
50
|
+
"SOURCE_ID": np.zeros(result + 1, np.int32),
|
|
51
|
+
}
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def maybe_impute_state_table(
|
|
56
|
+
state: pa.Table, state_id: npt.NDArray[np.int32]
|
|
57
|
+
) -> pa.Table:
|
|
58
|
+
"""Generates a STATE subtable if there are no row ids
|
|
59
|
+
associated with the given STATE_ID values"""
|
|
60
|
+
import pyarrow as pa
|
|
61
|
+
|
|
62
|
+
result = _maybe_return_table_or_max_id(state, "STATE", state_id, "STATE_ID")
|
|
63
|
+
if isinstance(result, pa.Table):
|
|
64
|
+
return result
|
|
65
|
+
|
|
66
|
+
return pa.Table.from_pydict(
|
|
67
|
+
{
|
|
68
|
+
"OBS_MODE": np.array(["UNSPECIFIED"] * (result + 1), dtype=object),
|
|
69
|
+
"SUB_SCAN": np.zeros(result + 1, np.int32),
|
|
70
|
+
}
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def maybe_impute_observation_table(
|
|
75
|
+
observation: pa.Table, observation_id: npt.NDArray[np.int32]
|
|
76
|
+
) -> pa.Table:
|
|
77
|
+
"""Generates an OBSERVATION table if there are no row ids
|
|
78
|
+
associated with the given OBSERVATION_ID values"""
|
|
79
|
+
import pyarrow as pa
|
|
80
|
+
|
|
81
|
+
result = _maybe_return_table_or_max_id(
|
|
82
|
+
observation, "OBSERVATION", observation_id, "OBSERVATION_ID"
|
|
83
|
+
)
|
|
84
|
+
if isinstance(result, pa.Table):
|
|
85
|
+
return result
|
|
86
|
+
|
|
87
|
+
unknown = np.array(["unknown"] * (result + 1), dtype=object)
|
|
88
|
+
|
|
89
|
+
return pa.Table.from_pydict(
|
|
90
|
+
{
|
|
91
|
+
"OBSERVER": unknown,
|
|
92
|
+
"PROJECT": unknown,
|
|
93
|
+
"TELESCOPE_NAME": unknown,
|
|
94
|
+
}
|
|
95
|
+
)
|
|
@@ -27,6 +27,10 @@ import pyarrow as pa
|
|
|
27
27
|
from arcae.lib.arrow_tables import Table
|
|
28
28
|
from cacheout import Cache
|
|
29
29
|
|
|
30
|
+
from xarray_ms.backend.msv2.imputation import (
|
|
31
|
+
maybe_impute_field_table,
|
|
32
|
+
maybe_impute_state_table,
|
|
33
|
+
)
|
|
30
34
|
from xarray_ms.backend.msv2.partition import PartitionKeyT, TablePartitioner
|
|
31
35
|
from xarray_ms.errors import (
|
|
32
36
|
InvalidMeasurementSet,
|
|
@@ -170,16 +174,16 @@ class PartitionData:
|
|
|
170
174
|
spw_id: int # unique from DATA_DESC_ID
|
|
171
175
|
pol_id: int # unique from DATA_DESC_ID
|
|
172
176
|
# Multiple values per partition
|
|
173
|
-
antenna_ids:
|
|
174
|
-
feed_ids:
|
|
175
|
-
field_ids:
|
|
176
|
-
state_ids:
|
|
177
|
-
scan_numbers:
|
|
177
|
+
antenna_ids: npt.NDArray[np.int32]
|
|
178
|
+
feed_ids: npt.NDArray[np.int32]
|
|
179
|
+
field_ids: npt.NDArray[np.int32]
|
|
180
|
+
state_ids: npt.NDArray[np.int32]
|
|
181
|
+
scan_numbers: npt.NDArray[np.int32]
|
|
178
182
|
# FIELD subtable
|
|
179
|
-
source_ids:
|
|
183
|
+
source_ids: npt.NDArray[np.int32]
|
|
180
184
|
# STATE subtable
|
|
181
185
|
obs_mode: str # unique from STATE::OBS_MODE
|
|
182
|
-
sub_scan_numbers:
|
|
186
|
+
sub_scan_numbers: npt.NDArray[np.int32]
|
|
183
187
|
|
|
184
188
|
# Row to baseline map
|
|
185
189
|
row_map: npt.NDArray[np.int64]
|
|
@@ -233,7 +237,7 @@ class MSv2StructureFactory:
|
|
|
233
237
|
_epoch: str
|
|
234
238
|
_auto_corrs: bool
|
|
235
239
|
_STRUCTURE_CACHE: ClassVar[Cache] = Cache(
|
|
236
|
-
maxsize=100, ttl=60, on_get=on_get_keep_alive
|
|
240
|
+
maxsize=100, ttl=5 * 60, on_get=on_get_keep_alive
|
|
237
241
|
)
|
|
238
242
|
|
|
239
243
|
def __init__(
|
|
@@ -388,6 +392,7 @@ class MSv2Structure(Mapping):
|
|
|
388
392
|
) -> npt.NDArray[np.int32]:
|
|
389
393
|
"""Constructs a SOURCE_ID array from MAIN.FIELD_ID
|
|
390
394
|
broadcast against FIELD.SOURCE_ID"""
|
|
395
|
+
field = maybe_impute_field_table(field, field_id)
|
|
391
396
|
field_source_id = field["SOURCE_ID"].to_numpy()
|
|
392
397
|
source_id = np.empty_like(field_id)
|
|
393
398
|
chunk = (len(source_id) + ncpus - 1) // ncpus
|
|
@@ -411,6 +416,7 @@ class MSv2Structure(Mapping):
|
|
|
411
416
|
) -> npt.NDArray[np.int32]:
|
|
412
417
|
"""Constructs a SUB_SCAN_NUMBER array from MAIN.STATE_ID
|
|
413
418
|
broadcast against STATE.SUB_SCAN_NUMBER"""
|
|
419
|
+
state = maybe_impute_state_table(state, state_id)
|
|
414
420
|
state_ssn = state["SUB_SCAN"].to_numpy()
|
|
415
421
|
subscan_nr = np.empty_like(state_id)
|
|
416
422
|
chunk = (len(state_id) + ncpus - 1) // ncpus
|
|
@@ -434,6 +440,8 @@ class MSv2Structure(Mapping):
|
|
|
434
440
|
) -> Tuple[npt.NDArray[np.int32], Dict[str, List[int]]]:
|
|
435
441
|
"""Constructs an OBS_MODE_ID array from MAIN.STATE_ID broadcast
|
|
436
442
|
against unique entries in STATE.OBS_MODE"""
|
|
443
|
+
|
|
444
|
+
state = maybe_impute_state_table(state, state_id)
|
|
437
445
|
obs_mode = state["OBS_MODE"].to_numpy()
|
|
438
446
|
|
|
439
447
|
# Map unique observation modes to state_ids
|
|
@@ -637,14 +645,24 @@ class MSv2Structure(Mapping):
|
|
|
637
645
|
"""Return the group that the subtable column should be assigned to"""
|
|
638
646
|
return partition_columns if s in subtable_columns else other_columns
|
|
639
647
|
|
|
640
|
-
def get_uid_column(column, dkey, ids) ->
|
|
648
|
+
def get_uid_column(column, dkey, ids) -> npt.NDArray:
|
|
641
649
|
"""Get the unique values for the given column, preferably from the
|
|
642
650
|
partition key or failing that, from `ids`. Generally should be used with
|
|
643
651
|
ID columns"""
|
|
644
652
|
try:
|
|
645
|
-
return [dkey[column]]
|
|
653
|
+
return np.array([dkey[column]])
|
|
654
|
+
except KeyError:
|
|
655
|
+
return self.par_unique(pool, ncpus, ids)
|
|
656
|
+
|
|
657
|
+
def time_coord(column, dkey, ids, utime, time_ids) -> npt.NDArray:
|
|
658
|
+
try:
|
|
659
|
+
value = dkey[column]
|
|
646
660
|
except KeyError:
|
|
647
|
-
|
|
661
|
+
result = np.empty(utime.shape, dtype=ids.dtype)
|
|
662
|
+
result[time_ids] = ids
|
|
663
|
+
return result
|
|
664
|
+
else:
|
|
665
|
+
return np.full(utime.shape, value)
|
|
648
666
|
|
|
649
667
|
# Broadcast and add FIELD.SOURCE_ID column
|
|
650
668
|
field_id = arrow_table["FIELD_ID"].to_numpy()
|
|
@@ -696,21 +714,25 @@ class MSv2Structure(Mapping):
|
|
|
696
714
|
antenna2 = partition["ANTENNA2"]
|
|
697
715
|
interval = partition["INTERVAL"]
|
|
698
716
|
rows = partition["row"]
|
|
699
|
-
|
|
717
|
+
|
|
718
|
+
# Unique sorting/other column values
|
|
719
|
+
utime, time_ids = self.par_unique(
|
|
720
|
+
pool, ncpus, partition["TIME"], return_inverse=True
|
|
721
|
+
)
|
|
700
722
|
|
|
701
723
|
# Unique partition key values
|
|
702
|
-
ufield_ids =
|
|
703
|
-
|
|
704
|
-
|
|
724
|
+
ufield_ids = time_coord(
|
|
725
|
+
"FIELD_ID", dkey, partition["FIELD_ID"], utime, time_ids
|
|
726
|
+
)
|
|
727
|
+
usubscan_nrs = time_coord(
|
|
728
|
+
"SUB_SCAN_NUMBER", dkey, partition["SUB_SCAN_NUMBER"], utime, time_ids
|
|
729
|
+
)
|
|
730
|
+
uscan_nrs = time_coord(
|
|
731
|
+
"SCAN_NUMBER", dkey, partition["SCAN_NUMBER"], utime, time_ids
|
|
705
732
|
)
|
|
706
|
-
uscan_nrs = get_uid_column("SCAN_NUMBER", dkey, partition["SCAN_NUMBER"])
|
|
707
733
|
ustate_ids = get_uid_column("STATE_ID", dkey, partition["STATE_ID"])
|
|
708
734
|
usource_ids = get_uid_column("SOURCE_ID", dkey, partition["SOURCE_ID"])
|
|
709
735
|
|
|
710
|
-
# Unique sorting/other column values
|
|
711
|
-
utime, time_ids = self.par_unique(
|
|
712
|
-
pool, ncpus, partition["TIME"], return_inverse=True
|
|
713
|
-
)
|
|
714
736
|
uantenna1 = self.par_unique(pool, ncpus, antenna1)
|
|
715
737
|
uantenna2 = self.par_unique(pool, ncpus, antenna2)
|
|
716
738
|
uantennas = np.union1d(uantenna1, uantenna2)
|
|
@@ -730,6 +752,7 @@ class MSv2Structure(Mapping):
|
|
|
730
752
|
|
|
731
753
|
na = len(feed_antennas)
|
|
732
754
|
nbl = nr_of_baselines(na, auto_corrs)
|
|
755
|
+
chunk = (len(rows) + ncpus - 1) // ncpus
|
|
733
756
|
|
|
734
757
|
# Populate row map and interval grids
|
|
735
758
|
row_map = np.full(utime.size * nbl, -1, dtype=np.int64)
|
|
@@ -794,8 +817,8 @@ class MSv2Structure(Mapping):
|
|
|
794
817
|
obs_id=obs_id,
|
|
795
818
|
spw_id=spw_id,
|
|
796
819
|
pol_id=pol_id,
|
|
797
|
-
antenna_ids=feed_antennas
|
|
798
|
-
feed_ids=ufeeds
|
|
820
|
+
antenna_ids=feed_antennas,
|
|
821
|
+
feed_ids=ufeeds,
|
|
799
822
|
field_ids=ufield_ids,
|
|
800
823
|
scan_numbers=uscan_nrs,
|
|
801
824
|
source_ids=usource_ids,
|
|
@@ -3,6 +3,15 @@ class IrregularGridWarning(UserWarning):
|
|
|
3
3
|
with each timestep are not homogenous"""
|
|
4
4
|
|
|
5
5
|
|
|
6
|
+
class MissingMetadataWarning(UserWarning):
|
|
7
|
+
"""Warning raised when metadata is missing"""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ImputedMetadataWarning(MissingMetadataWarning):
|
|
11
|
+
"""Warning raised when metadata is imputed
|
|
12
|
+
if the original metadata is missing"""
|
|
13
|
+
|
|
14
|
+
|
|
6
15
|
class InvalidMeasurementSet(ValueError):
|
|
7
16
|
"""Raised when the Measurement Set foreign key indexing is invalid"""
|
|
8
17
|
|
|
@@ -24,8 +24,8 @@ FIRST_FEB_2023_MJDS = 2459976.50000 * 86400
|
|
|
24
24
|
# Default simulation parameters
|
|
25
25
|
DEFAULT_SIM_PARAMS = {"ntime": 5, "data_description": [(8, ["XX", "XY", "YX", "YY"])]}
|
|
26
26
|
|
|
27
|
-
#
|
|
28
|
-
|
|
27
|
+
# Standard DATA Columns
|
|
28
|
+
STANDARD_DATA_COLUMNS = {
|
|
29
29
|
"DATA": {
|
|
30
30
|
"_c_order": True,
|
|
31
31
|
"comment": "DATA column",
|
|
@@ -91,6 +91,8 @@ class PartitionDescriptor:
|
|
|
91
91
|
|
|
92
92
|
DDIDArgType = List[Tuple[npt.NDArray[np.float64], List[str]]]
|
|
93
93
|
PartitionDataType = Dict[str, Tuple[Tuple[str, ...], npt.NDArray]]
|
|
94
|
+
ChunkDescriptorTransformerT = Callable[[PartitionDescriptor], PartitionDescriptor]
|
|
95
|
+
DataTransformerT = Callable[[PartitionDescriptor, PartitionDataType], PartitionDataType]
|
|
94
96
|
|
|
95
97
|
|
|
96
98
|
class MSStructureSimulator:
|
|
@@ -113,12 +115,11 @@ class MSStructureSimulator:
|
|
|
113
115
|
partition_names: List[str]
|
|
114
116
|
partition_indices: npt.NDArray[np.int32]
|
|
115
117
|
simulate_data: bool
|
|
118
|
+
table_desc: Dict[str, Any]
|
|
116
119
|
model: Dict[str, Any]
|
|
117
120
|
data_description: DataDescription
|
|
118
|
-
|
|
119
|
-
transform_data:
|
|
120
|
-
Callable[[PartitionDescriptor, PartitionDataType], PartitionDataType] | None
|
|
121
|
-
)
|
|
121
|
+
transform_chunk_desc: ChunkDescriptorTransformerT | None
|
|
122
|
+
transform_data: DataTransformerT | None
|
|
122
123
|
|
|
123
124
|
def __init__(
|
|
124
125
|
self,
|
|
@@ -134,11 +135,9 @@ class MSStructureSimulator:
|
|
|
134
135
|
partition: Tuple[str, ...] = ("OBSERVATION_ID", "FIELD_ID", "DATA_DESC_ID"),
|
|
135
136
|
auto_corrs: bool = True,
|
|
136
137
|
simulate_data: bool = True,
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
]
|
|
141
|
-
| None = None,
|
|
138
|
+
table_desc: Dict[str, Any] | None = None,
|
|
139
|
+
transform_chunk_desc: ChunkDescriptorTransformerT | None = None,
|
|
140
|
+
transform_data: DataTransformerT | None = None,
|
|
142
141
|
):
|
|
143
142
|
assert ntime >= 1
|
|
144
143
|
assert time_chunks > 0
|
|
@@ -194,9 +193,10 @@ class MSStructureSimulator:
|
|
|
194
193
|
self.time_chunks = time_chunks
|
|
195
194
|
self.time_start = time_start
|
|
196
195
|
self.simulate_data = simulate_data
|
|
196
|
+
self.table_desc = STANDARD_DATA_COLUMNS if table_desc is None else table_desc
|
|
197
197
|
self.partition_names = cbp_names
|
|
198
198
|
self.partition_indices = bcbp_indices
|
|
199
|
-
self.
|
|
199
|
+
self.transform_chunk_desc = transform_chunk_desc
|
|
200
200
|
self.transform_data = transform_data
|
|
201
201
|
self.model = {
|
|
202
202
|
"data_description": self.data_description,
|
|
@@ -211,17 +211,16 @@ class MSStructureSimulator:
|
|
|
211
211
|
|
|
212
212
|
def simulate_ms(self, output_ms: str) -> None:
|
|
213
213
|
"""Simulate data into the given measurement set name"""
|
|
214
|
-
table_desc = ADDITIONAL_COLUMNS if self.simulate_data else {}
|
|
215
214
|
|
|
216
215
|
# Generate descriptors, create simulated data from the descriptors
|
|
217
216
|
# and write simulated data to the main Measurement Set
|
|
218
|
-
with Table.ms_from_descriptor(output_ms, "MAIN", table_desc) as T:
|
|
217
|
+
with Table.ms_from_descriptor(output_ms, "MAIN", self.table_desc) as T:
|
|
219
218
|
startrow = 0
|
|
220
219
|
|
|
221
220
|
for chunk_desc in self.generate_descriptors():
|
|
222
221
|
# Apply any chunk descriptor transforms
|
|
223
|
-
if self.
|
|
224
|
-
chunk_desc = self.
|
|
222
|
+
if self.transform_chunk_desc is not None:
|
|
223
|
+
chunk_desc = self.transform_chunk_desc(chunk_desc)
|
|
225
224
|
|
|
226
225
|
# Generate the chunk data
|
|
227
226
|
data_dict = self.data_factory(chunk_desc)
|
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable, Tuple
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
from xarray.backends import BackendArray
|
|
7
|
-
from xarray.core.indexing import IndexingSupport, explicit_indexing_adapter
|
|
8
|
-
|
|
9
|
-
if TYPE_CHECKING:
|
|
10
|
-
import numpy.typing as npt
|
|
11
|
-
|
|
12
|
-
from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
|
|
13
|
-
from xarray_ms.multiton import Multiton
|
|
14
|
-
|
|
15
|
-
TransformerT = Callable[[npt.NDArray], npt.NDArray] | None
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def slice_length(s, max_len):
|
|
19
|
-
if isinstance(s, np.ndarray):
|
|
20
|
-
if s.ndim != 1:
|
|
21
|
-
raise NotImplementedError("Slicing with non-1D numpy arrays")
|
|
22
|
-
return len(s)
|
|
23
|
-
|
|
24
|
-
start, stop, step = s.indices(max_len)
|
|
25
|
-
if step != 1:
|
|
26
|
-
raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported")
|
|
27
|
-
return stop - start
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class MSv2Array(BackendArray):
|
|
31
|
-
"""Backend array containing functionality for reading an MSv2 column"""
|
|
32
|
-
|
|
33
|
-
_table_factory: Multiton
|
|
34
|
-
_structure_factory: MSv2StructureFactory
|
|
35
|
-
_partition: PartitionKeyT
|
|
36
|
-
_column: str
|
|
37
|
-
_shape: Tuple[int, ...]
|
|
38
|
-
_dtype: npt.DTypeLike
|
|
39
|
-
_default: Any | None
|
|
40
|
-
_transform: TransformerT
|
|
41
|
-
|
|
42
|
-
def __init__(
|
|
43
|
-
self,
|
|
44
|
-
table_factory: Multiton,
|
|
45
|
-
structure_factory: MSv2StructureFactory,
|
|
46
|
-
partition: PartitionKeyT,
|
|
47
|
-
column: str,
|
|
48
|
-
shape: Tuple[int, ...],
|
|
49
|
-
dtype: npt.DTypeLike,
|
|
50
|
-
default: Any | None = None,
|
|
51
|
-
transform: TransformerT = None,
|
|
52
|
-
):
|
|
53
|
-
self._table_factory = table_factory
|
|
54
|
-
self._structure_factory = structure_factory
|
|
55
|
-
self._partition = partition
|
|
56
|
-
self._column = column
|
|
57
|
-
self._default = default
|
|
58
|
-
self._transform = transform
|
|
59
|
-
self.shape = shape
|
|
60
|
-
self.dtype = np.dtype(dtype)
|
|
61
|
-
|
|
62
|
-
assert len(shape) >= 2, "(time, baseline_ids) required"
|
|
63
|
-
|
|
64
|
-
def __getitem__(self, key) -> npt.NDArray:
|
|
65
|
-
return explicit_indexing_adapter(
|
|
66
|
-
key, self.shape, IndexingSupport.OUTER, self._getitem
|
|
67
|
-
)
|
|
68
|
-
|
|
69
|
-
def _getitem(self, key) -> npt.NDArray:
|
|
70
|
-
assert len(key) == len(self.shape)
|
|
71
|
-
expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
|
|
72
|
-
# Map the (time, baseline_id) coordinates onto row indices
|
|
73
|
-
rows = self._structure_factory.instance[self._partition].row_map[key[:2]]
|
|
74
|
-
xkey = (rows.ravel(),) + key[2:]
|
|
75
|
-
row_shape = (rows.size,) + expected_shape[2:]
|
|
76
|
-
result = np.full(row_shape, self._default, dtype=self.dtype)
|
|
77
|
-
self._table_factory.instance.getcol(self._column, xkey, result)
|
|
78
|
-
result = result.reshape(rows.shape + expected_shape[2:])
|
|
79
|
-
return self._transform(result) if self._transform else result
|
|
80
|
-
|
|
81
|
-
@property
|
|
82
|
-
def transform(self) -> TransformerT:
|
|
83
|
-
return self._transform
|
|
84
|
-
|
|
85
|
-
@transform.setter
|
|
86
|
-
def transform(self, value: TransformerT):
|
|
87
|
-
self._transform = value
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|