xarray-ms 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,219 @@
1
+ import dataclasses
2
+ import warnings
3
+ from typing import Any, Mapping, Tuple, Type
4
+
5
+ import numpy as np
6
+ from xarray import Variable
7
+ from xarray.coding.variables import unpack_for_decoding
8
+ from xarray.core.indexing import LazilyIndexedArray
9
+ from xarray.core.utils import FrozenDict
10
+
11
+ from xarray_ms.backend.msv2.array import MSv2Array
12
+ from xarray_ms.backend.msv2.encoders import (
13
+ CasaCoder,
14
+ QuantityCoder,
15
+ TimeCoder,
16
+ )
17
+ from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
18
+ from xarray_ms.backend.msv2.table_factory import TableFactory
19
+ from xarray_ms.errors import IrregularGridWarning
20
+
21
+
22
@dataclasses.dataclass()
class MSv2ColumnSchema:
    """Describe how one MSv4 variable is read from an MSv2 MAIN-table column.

    Instances are used as values of ``MSV4_to_MSV2_COLUMN_SCHEMAS``.
    """

    # Name of the MSv2 column that backs the variable
    name: str
    # Dimensions that follow the leading ("time", "baseline") dimensions
    dims: Tuple[str, ...]
    # Value cast to the column dtype and handed to the backing array
    default: Any = dataclasses.field(default=None)
    # Optional coder class used to decode the variable (None means no decoding)
    coder: Type[CasaCoder] | None = dataclasses.field(default=None)
28
+
29
+
30
# Map each MSv4 variable name to the MSv2 MAIN-table column it is read from,
# that column's trailing dimensions, its default value and an optional coder.
MSV4_to_MSV2_COLUMN_SCHEMAS = {
    "TIME": MSv2ColumnSchema(
        name="TIME", dims=(), default=np.nan, coder=TimeCoder
    ),
    "INTEGRATION_TIME": MSv2ColumnSchema(
        name="INTERVAL", dims=(), default=np.nan, coder=QuantityCoder
    ),
    "TIME_CENTROID": MSv2ColumnSchema(
        name="TIME_CENTROID", dims=(), default=np.nan, coder=TimeCoder
    ),
    "EFFECTIVE_INTEGRATION_TIME": MSv2ColumnSchema(
        name="EXPOSURE", dims=(), default=np.nan, coder=QuantityCoder
    ),
    "UVW": MSv2ColumnSchema(
        name="UVW", dims=("uvw_label",), default=np.nan, coder=None
    ),
    "FLAG": MSv2ColumnSchema(
        name="FLAG", dims=("frequency", "polarization"), default=1, coder=None
    ),
    "VISIBILITY": MSv2ColumnSchema(
        name="DATA",
        dims=("frequency", "polarization"),
        # Complex NaN: both real and imaginary parts are NaN
        default=complex(np.nan, np.nan),
        coder=None,
    ),
    "WEIGHT": MSv2ColumnSchema(
        name="WEIGHT_SPECTRUM",
        dims=("frequency", "polarization"),
        default=np.nan,
        coder=None,
    ),
}

# Dimensions whose size is the same for every partition
FIXED_DIMENSION_SIZES = dict(uvw_label=3)
46
+
47
+
48
class MainDatasetFactory:
    """Build the variables of an MSv4-style main dataset for one MSv2 partition.

    The factory stores a partition key together with the table and structure
    factories. :meth:`get_variables` derives the partition's data variables
    and coordinates from the MSv2 ``MAIN`` table and the partition structure.
    """

    _partition_key: PartitionKeyT
    _table_factory: TableFactory
    _structure_factory: MSv2StructureFactory

    def __init__(
        self,
        partition_key: PartitionKeyT,
        table_factory: TableFactory,
        structure_factory: MSv2StructureFactory,
    ):
        """Store the partition key and the factories used to build variables.

        Args:
          partition_key: Key selecting this partition from the structure
            returned by ``structure_factory()``.
          table_factory: Passed through to every :class:`MSv2Array` created.
          structure_factory: Callable returning the MSv2 structure.
        """
        self._partition_key = partition_key
        self._table_factory = table_factory
        self._structure_factory = structure_factory

    def _variable_from_column(self, column: str) -> Variable:
        """Derive an xarray Variable from the MSv2 column descriptor and schemas.

        Args:
          column: MSv4 variable name, looked up in
            ``MSV4_to_MSV2_COLUMN_SCHEMAS``.

        Returns:
          A Variable with dims ``("time", "baseline") + schema.dims``, backed
          by an :class:`MSv2Array` wrapped in a ``LazilyIndexedArray``. If the
          schema has a coder, the variable is decoded by it first.

        Raises:
          KeyError: if ``column`` has no schema, the mapped MSv2 column has no
            descriptor in the MAIN table, or a dimension size is unknown.
        """
        structure = self._structure_factory()
        partition = structure[self._partition_key]
        main_column_descs = structure.column_descs["MAIN"]

        try:
            schema = MSV4_to_MSV2_COLUMN_SCHEMAS[column]
        except KeyError as e:
            raise KeyError(f"No Column Schema exist for {column}") from e

        try:
            column_desc = main_column_descs[schema.name]
        except KeyError as e:
            raise KeyError(f"No Column Descriptor exist for {schema.name}") from e

        dim_sizes = {
            "time": len(partition.time),
            "baseline": structure.nbl,
            "frequency": len(partition.chan_freq),
            "polarization": len(partition.corr_type),
            **FIXED_DIMENSION_SIZES,
        }

        # Every column is laid out on the (time, baseline) grid, followed by
        # its column-specific trailing dimensions.
        dims = ("time", "baseline") + schema.dims

        try:
            shape = tuple(dim_sizes[d] for d in dims)
        except KeyError as e:
            raise KeyError(f"No dimension size found for {e.args[0]}") from e

        # Cast the schema default to the column's dtype before handing it to
        # MSv2Array. NOTE(review): presumably this is the fill value for grid
        # points with no backing row (see the padding warning in
        # get_variables) - confirm against MSv2Array.
        default = column_desc.dtype.type(schema.default)

        data = MSv2Array(
            self._table_factory,
            self._structure_factory,
            self._partition_key,
            schema.name,
            shape,
            column_desc.dtype,
            default,
        )

        var = Variable(dims, data)

        # Apply any measures encoding
        if schema.coder:
            coder = schema.coder(schema.name, main_column_descs)
            var = coder.decode(var)

        dims, data, attrs, encoding = unpack_for_decoding(var)
        return Variable(dims, LazilyIndexedArray(data), attrs, encoding, fastpath=True)

    def get_variables(self) -> Mapping[str, Variable]:
        """Return the data variables and coordinates of this partition.

        Returns:
          A FrozenDict mapping names to Variables, ordered by name. It holds
          the TIME_CENTROID, EFFECTIVE_INTEGRATION_TIME, UVW, VISIBILITY, FLAG
          and WEIGHT data variables and the baseline_id, antenna1_name,
          antenna2_name, polarization, time and frequency coordinates. TIME
          and INTEGRATION_TIME data variables are added when the partition
          has more than one interval.

        Warns:
          IrregularGridWarning: if rows are missing from the (time, baseline)
            grid, or if the partition does not have exactly one interval.

        Raises:
          NotImplementedError: if the partition does not have exactly one
            channel width. Full resolution CHANNEL_FREQUENCY / CHANNEL_WIDTH
            variables are not supported yet.
        """
        structure = self._structure_factory()
        partition = structure[self._partition_key]
        ant1, ant2 = structure.antenna_pairs
        nbl = structure.nbl
        # Internal invariant: one antenna pair per baseline
        assert (nbl,) == ant1.shape

        ant_names = structure._ant["NAME"].to_numpy()
        ant1_names = ant_names[ant1]
        ant2_names = ant_names[ant2]

        # -1 in the row map marks (time, baseline) grid points with no row
        row_map = partition.row_map
        missing = np.count_nonzero(row_map == -1)
        if missing > 0:
            warnings.warn(
                f"{missing} / {row_map.size} ({100. * missing / row_map.size:.1f}%) "
                "rows missing from the full (time, baseline) grid "
                f"in partition {self._partition_key}. "
                "Dataset variables will be padded",
                IrregularGridWarning,
            )

        data_vars = [
            (n, self._variable_from_column(n))
            for n in (
                "TIME_CENTROID",
                "EFFECTIVE_INTEGRATION_TIME",
                "UVW",
                "VISIBILITY",
                "FLAG",
                "WEIGHT",
            )
        ]

        # Coordinates along the baseline and polarization dimensions
        coord_specs = [
            (
                "baseline_id",
                ("baseline",),
                np.arange(len(ant1)),
                {"coordinates": "baseline_id"},
            ),
            ("antenna1_name", ("baseline",), ant1_names, {"coordinates": "antenna1_name"}),
            ("antenna2_name", ("baseline",), ant2_names, {"coordinates": "antenna2_name"}),
            ("polarization", ("polarization",), partition.corr_type, None),
        ]
        coordinates = [(n, Variable(d, v, a)) for n, d, v, a in coord_specs]

        # Add time coordinate
        time_coder = TimeCoder("TIME", structure.column_descs["MAIN"])

        if partition.interval.size == 1:
            time_attrs = {"integration_time": partition.interval.item()}
        else:
            warnings.warn(
                f"Multiple intervals {partition.interval} "
                f"found in partition {self._partition_key}. "
                'Setting time.attrs["integration_time"] = nan and '
                "adding full resolution TIME and INTERVAL columns. ",
                IrregularGridWarning,
            )
            time_attrs = {"integration_time": np.nan}
            # The single integration_time attribute cannot describe the data,
            # so expose the per-row TIME and INTERVAL columns as well.
            data_vars.extend(
                [
                    ("TIME", self._variable_from_column("TIME")),
                    ("INTEGRATION_TIME", self._variable_from_column("INTEGRATION_TIME")),
                ]
            )

        coordinates.append(
            ("time", time_coder.decode(Variable("time", partition.time, time_attrs)))
        )

        # Add frequency coordinate
        freq_attrs = {
            "type": "spectral_coord",
            "frame": partition.spw_frame,
            "units": ["Hz"],
            "spectral_window_name": partition.spw_name or "<Unknown>",
            "reference_frequency": partition.spw_ref_freq,
            # NOTE(review): this is a placeholder string, not a width value
            "effective_channel_width": "EFFECTIVE_CHANNEL_WIDTH",
        }

        if partition.spw_freq_group_name:
            freq_attrs["frequency_group_name"] = partition.spw_freq_group_name

        if partition.chan_width.size != 1:
            # TODO: once supported, set frequency.attrs["channel_width"] = nan
            # and add full resolution CHANNEL_FREQUENCY / CHANNEL_WIDTH
            # variables instead of failing.
            raise NotImplementedError(
                "Full resolution CHANNEL_FREQUENCY and CHANNEL_WIDTH columns "
                f"(multiple channel widths {partition.chan_width} "
                f"found in partition {self._partition_key})"
            )

        freq_attrs["channel_width"] = partition.chan_width.item()

        coordinates.append(
            ("frequency", Variable("frequency", partition.chan_freq, freq_attrs))
        )

        # Sort on the name only, so Variables themselves are never compared
        return FrozenDict(sorted(data_vars + coordinates, key=lambda item: item[0]))