xarray-ms 0.2.4__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/PKG-INFO +1 -1
  2. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/pyproject.toml +2 -2
  3. xarray_ms-0.2.6/xarray_ms/backend/msv2/array.py +170 -0
  4. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/antenna.py +4 -3
  5. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/correlated.py +84 -36
  6. xarray_ms-0.2.6/xarray_ms/backend/msv2/imputation.py +95 -0
  7. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/structure.py +45 -22
  8. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/errors.py +9 -0
  9. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/testing/simulator.py +15 -16
  10. xarray_ms-0.2.4/xarray_ms/backend/msv2/array.py +0 -87
  11. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/LICENSE +0 -0
  12. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/README.rst +0 -0
  13. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/__init__.py +0 -0
  14. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/encoders.py +0 -0
  15. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/entrypoint.py +0 -0
  16. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/entrypoint_utils.py +0 -0
  17. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/__init__.py +0 -0
  18. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/partition.py +0 -0
  19. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/casa_types.py +0 -0
  20. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/msv4_types.py +0 -0
  21. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/multiton.py +0 -0
  22. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/query.py +0 -0
  23. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/testing/__init__.py +0 -0
  24. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/testing/utils.py +0 -0
  25. {xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/utils.py +0 -0

{xarray_ms-0.2.4 → xarray_ms-0.2.6}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: xarray-ms
- Version: 0.2.4
+ Version: 0.2.6
  Summary: xarray MSv4 views over MSv2 Measurement Sets
  Author: Simon Perkins
  Author-email: simon.perkins@gmail.com

{xarray_ms-0.2.4 → xarray_ms-0.2.6}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "xarray-ms"
- version = "0.2.4"
+ version = "0.2.6"
  description = "xarray MSv4 views over MSv2 Measurement Sets"
  authors = ["Simon Perkins <simon.perkins@gmail.com>"]
  readme = "README.rst"
@@ -58,7 +58,7 @@ build-backend = "poetry.core.masonry.api"
  # github_url = "https://github.com/<user or organization>/<project>/"

  [tool.tbump.version]
- current = "0.2.4"
+ current = "0.2.6"

  # Example of a semver regexp.
  # Make sure this matches current_version before

xarray_ms-0.2.6/xarray_ms/backend/msv2/array.py
@@ -0,0 +1,170 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Callable, Tuple
+
+ import numpy as np
+ from xarray.backends import BackendArray
+ from xarray.core.indexing import IndexingSupport, explicit_indexing_adapter
+
+ if TYPE_CHECKING:
+   import numpy.typing as npt
+
+   from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
+   from xarray_ms.multiton import Multiton
+
+   TransformerT = Callable[[npt.NDArray], npt.NDArray]
+
+
+ def slice_length(s: npt.NDArray | slice, max_len) -> int:
+   if isinstance(s, np.ndarray):
+     if s.ndim != 1:
+       raise NotImplementedError("Slicing with non-1D numpy arrays")
+     return len(s)
+
+   start, stop, step = s.indices(max_len)
+   if step != 1:
+     raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported")
+   return stop - start
+
+
+ class MSv2Array(BackendArray):
+   """Base MSv2Array backend array class,
+   containing required shape and data type"""
+
+   __slots__ = ("shape", "dtype")
+
+   shape: Tuple[int, ...]
+   dtype: npt.DTypeLike
+
+   def __init__(self, shape: Tuple[int, ...], dtype: npt.DTypeLike):
+     self.shape = shape
+     self.dtype = dtype
+
+   def __getitem__(self, key) -> npt.NDArray:
+     raise NotImplementedError
+
+   @property
+   def transform(self) -> TransformerT | None:
+     raise NotImplementedError
+
+   @transform.setter
+   def transform(self, value: TransformerT) -> None:
+     raise NotImplementedError
+
+
+ class MainMSv2Array(MSv2Array):
+   """Backend array containing functionality for reading an MSv2 column
+   from the MAIN table. Columns are assumed to have ("time", "baseline_id")
+   as the first dimensions. These are mapped onto the "row" dimension
+   via the partition row map"""
+
+   __slots__ = (
+     "_table_factory",
+     "_structure_factory",
+     "_partition",
+     "_column",
+     "_default",
+     "_transform",
+   )
+
+   _table_factory: Multiton
+   _structure_factory: MSv2StructureFactory
+   _partition: PartitionKeyT
+   _column: str
+   _default: Any | None
+   _transform: TransformerT | None
+
+   def __init__(
+     self,
+     table_factory: Multiton,
+     structure_factory: MSv2StructureFactory,
+     partition: PartitionKeyT,
+     column: str,
+     shape: Tuple[int, ...],
+     dtype: npt.DTypeLike,
+     default: Any | None = None,
+     transform: TransformerT | None = None,
+   ):
+     super().__init__(shape, dtype)
+     self._table_factory = table_factory
+     self._structure_factory = structure_factory
+     self._partition = partition
+     self._column = column
+     self._default = default
+     self._transform = transform
+
+     assert len(shape) >= 2, "(time, baseline_ids) required"
+
+   def __getitem__(self, key) -> npt.NDArray:
+     return explicit_indexing_adapter(
+       key, self.shape, IndexingSupport.OUTER, self._getitem
+     )
+
+   def _getitem(self, key) -> npt.NDArray:
+     assert len(key) == len(self.shape)
+     expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
+     # Map the (time, baseline_id) coordinates onto row indices
+     rows = self._structure_factory.instance[self._partition].row_map[key[:2]]
+     row_key = (rows.ravel(),) + key[2:]
+     row_shape = (rows.size,) + expected_shape[2:]
+     result = np.full(row_shape, self._default, dtype=self.dtype)
+     self._table_factory.instance.getcol(self._column, row_key, result)
+     result = result.reshape(rows.shape + expected_shape[2:])
+     return self._transform(result) if self._transform else result
+
+   @property
+   def transform(self) -> TransformerT | None:
+     return self._transform
+
+   @transform.setter
+   def transform(self, value: TransformerT) -> None:
+     self._transform = value
+
+
+ class BroadcastMSv2Array(MSv2Array):
+   """Broadcasts a MAIN table MSv2 Column up to an
+   MSv4 column. This can be inefficient for example,
+   if multiple frequency chunks are read for the same
+   ("time", "baseline_id") range as the same
+   low resolution data can be read multiple times.
+
+   However, this should be no worse than reading the
+   data for a full resolution column.
+
+   This is primarily useful for falling back to the
+   WEIGHT column when WEIGHT_SPECTRUM is missing, or
+   FLAG_ROW if FLAG is missing.
+   """
+
+   __slots__ = ("_low_res_array", "_low_res_index")
+
+   _low_res_array: MSv2Array
+   _low_res_index: Tuple[slice | None, ...]
+   shape: Tuple[int, ...]
+
+   def __init__(
+     self,
+     low_res_array: MSv2Array,
+     low_res_index: Tuple[slice | None, ...],
+     high_res_shape: Tuple[int, ...],
+   ):
+     self._low_res_array = low_res_array
+     self._low_res_index = low_res_index
+     self.shape = high_res_shape
+
+   @property
+   def dtype(self):
+     return self._low_res_array.dtype
+
+   @property
+   def transform(self) -> TransformerT | None:
+     return self._low_res_array.transform
+
+   @transform.setter
+   def transform(self, value: TransformerT) -> None:
+     self._low_res_array.transform = value
+
+   def __getitem__(self, key) -> npt.NDArray:
+     low_res_data = self._low_res_array.__getitem__(key)
+     low_res_data = low_res_data[self._low_res_index]
+     return np.broadcast_to(low_res_data, self.shape)
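
A note on the mechanism above: MainMSv2Array._getitem gathers MAIN table rows
through the partition row map, with -1 marking (time, baseline_id) slots that
have no backing row. A minimal numpy sketch of that gather, with the arcae
getcol call stubbed out and all values invented:

import numpy as np

# Hypothetical partition: 3 times x 2 baselines mapped onto 5 MAIN rows,
# with one (time, baseline_id) slot absent (-1)
row_map = np.array([[0, 1], [2, -1], [3, 4]])

key = (slice(0, 2), slice(0, 2))  # outer indices over (time, baseline_id)
rows = row_map[key]               # shape (2, 2)
row_key = rows.ravel()            # row numbers to request from the MAIN table

# Stand-in for table.getcol(column, row_key, result): pre-fill the default,
# then overwrite only the rows that exist
result = np.full(row_key.shape, np.nan)
valid = row_key >= 0
result[valid] = row_key[valid] * 10.0  # fake column data keyed on row number
print(result.reshape(rows.shape))      # the absent slot keeps the default
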

{xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/antenna.py
@@ -3,6 +3,7 @@ from typing import Dict, Mapping
  import numpy as np
  from xarray import Dataset, Variable

+ from xarray_ms.backend.msv2.imputation import maybe_impute_observation_table
  from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
  from xarray_ms.errors import InvalidMeasurementSet
  from xarray_ms.multiton import Multiton
@@ -26,13 +27,13 @@ class AntennaDatasetFactory:
      self._subtable_factories = subtable_factories

    def get_dataset(self) -> Mapping[str, Variable]:
-     structure = self._structure_factory.instance
-     partition = structure[self._partition_key]
+     partition = self._structure_factory.instance[self._partition_key]
      ants = self._subtable_factories["ANTENNA"].instance
      feeds = self._subtable_factories["FEED"].instance
      obs = self._subtable_factories["OBSERVATION"].instance

-     telescope_name = obs["TELESCOPE_NAME"][partition.obs_id].as_py()
+     obs = maybe_impute_observation_table(obs, [partition.obs_id])
+     telescope_name = obs["TELESCOPE_NAME"][0].as_py()

      import pyarrow.compute as pac

{xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/factories/correlated.py
@@ -8,12 +8,20 @@ from xarray.coding.variables import unpack_for_decoding
  from xarray.core.indexing import LazilyIndexedArray
  from xarray.core.utils import FrozenDict

- from xarray_ms.backend.msv2.array import MSv2Array
+ from xarray_ms.backend.msv2.array import (
+   BroadcastMSv2Array,
+   MainMSv2Array,
+   MSv2Array,
+ )
  from xarray_ms.backend.msv2.encoders import (
    CasaCoder,
    QuantityCoder,
    TimeCoder,
  )
+ from xarray_ms.backend.msv2.imputation import (
+   maybe_impute_field_table,
+   maybe_impute_observation_table,
+ )
  from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
  from xarray_ms.casa_types import ColumnDesc, FrequencyMeasures, Polarisations
  from xarray_ms.errors import IrregularGridWarning
@@ -26,6 +34,7 @@ class MSv2ColumnSchema:
    dims: Tuple[str, ...]
    default: Any = None
    coder: Type[CasaCoder] | None = None
+   low_res_dims: Tuple[str, ...] | None = None


  MSV4_to_MSV2_COLUMN_SCHEMAS = {
@@ -34,10 +43,20 @@ MSV4_to_MSV2_COLUMN_SCHEMAS = {
    "TIME_CENTROID": MSv2ColumnSchema("TIME_CENTROID", (), np.nan, TimeCoder),
    "EFFECTIVE_INTEGRATION_TIME": MSv2ColumnSchema("EXPOSURE", (), np.nan, QuantityCoder),
    "UVW": MSv2ColumnSchema("UVW", ("uvw_label",), np.nan, None),
+   "FLAG_ROW": MSv2ColumnSchema(
+     "FLAG_ROW", ("frequency", "polarization"), 1, None, low_res_dims=()
+   ),
    "FLAG": MSv2ColumnSchema("FLAG", ("frequency", "polarization"), 1, None),
    "VISIBILITY": MSv2ColumnSchema(
      "DATA", ("frequency", "polarization"), np.nan + np.nan * 1j, None
    ),
+   "WEIGHT_ROW": MSv2ColumnSchema(
+     "WEIGHT",
+     ("frequency", "polarization"),
+     np.nan,
+     None,
+     low_res_dims=("polarization",),
+   ),
    "WEIGHT": MSv2ColumnSchema(
      "WEIGHT_SPECTRUM", ("frequency", "polarization"), np.nan, None
    ),
@@ -78,11 +97,8 @@ class CorrelatedDatasetFactory:
      c: ColumnDesc.from_descriptor(c, ms_table_desc) for c in ms.columns()
    }

- def _variable_from_column(self, column: str) -> Variable:
+ def _variable_from_column(self, column: str, dim_sizes: Dict[str, int]) -> Variable:
    """Derive an xarray Variable from the MSv2 column descriptor and schemas"""
-   structure = self._structure_factory.instance
-   partition = structure[self._partition_key]
-
    try:
      schema = MSV4_to_MSV2_COLUMN_SCHEMAS[column]
    except KeyError:
@@ -93,20 +109,6 @@ class CorrelatedDatasetFactory:
    except KeyError:
      raise KeyError(f"No Column Descriptor exist for {schema.name}")

-   spw = self._subtable_factories["SPECTRAL_WINDOW"].instance
-   pol = self._subtable_factories["POLARIZATION"].instance
-
-   chan_freq = spw["CHAN_FREQ"][partition.spw_id].as_py()
-   corr_type = pol["CORR_TYPE"][partition.pol_id].as_py()
-
-   dim_sizes = {
-     "time": len(partition.time),
-     "baseline_id": partition.nbl,
-     "frequency": len(chan_freq),
-     "polarization": len(corr_type),
-     **FIXED_DIMENSION_SIZES,
-   }
-
    dims = ("time", "baseline_id") + schema.dims

    try:
@@ -116,7 +118,20 @@

    default = column_desc.dtype.type(schema.default)

-   data = MSv2Array(
+   high_res_shape = shape
+   low_res_index: Tuple[slice | None, ...] = tuple(slice(None) for _ in shape)
+
+   if schema.low_res_dims:
+     low_res_dims = ("time", "baseline_id") + schema.low_res_dims
+     high_res_shape = shape
+     try:
+       shape_map = {d: dim_sizes[d] for d in low_res_dims}
+     except KeyError as e:
+       raise KeyError(f"No dimension size found for {e.args[0]}")
+     low_res_index = tuple(slice(None) if d in shape_map else None for d in dims)
+     shape = tuple(shape_map.values())
+
+   array: MSv2Array = MainMSv2Array(
      self._ms_factory,
      self._structure_factory,
      self._partition_key,
@@ -126,7 +141,10 @@
      default,
    )

-   var = Variable(dims, data, fastpath=True)
+   if schema.low_res_dims:
+     array = BroadcastMSv2Array(array, low_res_index, high_res_shape)
+
+   var = Variable(dims, array, fastpath=True)

    # Apply any measures encoding
    if schema.coder:
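
To make the low resolution path above concrete: for WEIGHT_ROW,
schema.low_res_dims is ("polarization",), so the MSv2 WEIGHT column is read at
(time, baseline_id, polarization) resolution and low_res_index leaves a None
(np.newaxis) in the frequency slot, which BroadcastMSv2Array expands to the
full shape. A small numpy sketch with invented sizes:

import numpy as np

dims = ("time", "baseline_id", "frequency", "polarization")
dim_sizes = {"time": 10, "baseline_id": 6, "frequency": 64, "polarization": 4}
low_res_dims = ("time", "baseline_id") + ("polarization",)  # the WEIGHT_ROW case

shape_map = {d: dim_sizes[d] for d in low_res_dims}
low_res_index = tuple(slice(None) if d in shape_map else None for d in dims)
low_res_shape = tuple(shape_map.values())
high_res_shape = tuple(dim_sizes[d] for d in dims)

weight = np.random.random(low_res_shape)          # stands in for the WEIGHT read
expanded = weight[low_res_index]                  # shape (10, 6, 1, 4)
full = np.broadcast_to(expanded, high_res_shape)  # (10, 6, 64, 4), no copy
assert (full[:, :, 0, :] == full[:, :, -1, :]).all()
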
@@ -144,8 +162,7 @@
    structure = self._structure_factory.instance
    partition = structure[self._partition_key]
    ant1, ant2 = partition.antenna_pairs
-   nbl = partition.nbl
-   assert (nbl,) == ant1.shape
+   assert (partition.nbl,) == ant1.shape

    antenna = self._subtable_factories["ANTENNA"].instance
    ant_names = antenna["NAME"].to_numpy()
@@ -156,6 +173,7 @@
    pol_id = partition.pol_id
    spw = self._subtable_factories["SPECTRAL_WINDOW"].instance
    pol = self._subtable_factories["POLARIZATION"].instance
+   field = self._subtable_factories["FIELD"].instance

    chan_freq = spw["CHAN_FREQ"][spw_id].as_py()
    uchan_width = np.unique(spw["CHAN_WIDTH"][spw_id].as_py())
@@ -167,6 +185,14 @@

    corr_type = Polarisations.from_values(pol["CORR_TYPE"][pol_id].as_py()).to_str()

+   dim_sizes = {
+     "time": len(partition.time),
+     "baseline_id": partition.nbl,
+     "frequency": len(chan_freq),
+     "polarization": len(corr_type),
+     **FIXED_DIMENSION_SIZES,
+   }
+
    row_map = partition.row_map
    missing = np.count_nonzero(row_map == -1)
    if missing > 0:
@@ -181,17 +207,28 @@
      )

    data_vars = [
-     (n, self._variable_from_column(n))
+     (n, self._variable_from_column(n, dim_sizes))
      for n in (
        "TIME_CENTROID",
        "EFFECTIVE_INTEGRATION_TIME",
        "UVW",
        "VISIBILITY",
-       "FLAG",
-       "WEIGHT",
      )
    ]

+   if "FLAG" in self._main_column_descs:
+     data_vars.append(("FLAG", self._variable_from_column("FLAG", dim_sizes)))
+   else:
+     data_vars.append(("FLAG", self._variable_from_column("FLAG_ROW", dim_sizes)))
+
+   if "WEIGHT_SPECTRUM" in self._main_column_descs:
+     data_vars.append(("WEIGHT", self._variable_from_column("WEIGHT", dim_sizes)))
+   else:
+     data_vars.append(("WEIGHT", self._variable_from_column("WEIGHT_ROW", dim_sizes)))
+
+   field = maybe_impute_field_table(field, partition.field_ids)
+   field_names = field.take(partition.field_ids)["NAME"].to_numpy()
+
    # Add coordinates indexing coordinates
    coordinates = [
      (
@@ -208,6 +245,12 @@
      ),
      ("polarization", (("polarization",), corr_type, None)),
      ("uvw_label", (("uvw_label",), ["u", "v", "w"], None)),
+     ("field_name", ("time", field_names, {"coordinates": "field_name"})),
+     ("scan_number", ("time", partition.scan_numbers, {"coordinates": "scan_number"})),
+     (
+       "sub_scan_number",
+       ("time", partition.sub_scan_numbers, {"coordinates": "sub_scan_number"}),
+     ),
    ]

    e = {"preferred_chunks": self._preferred_chunks} if self._preferred_chunks else None
@@ -217,8 +260,14 @@
    time_coder = TimeCoder("TIME", self._main_column_descs)

    if partition.interval.size == 1:
+     # Single unique value
      time_attrs = {"integration_time": partition.interval.item()}
+   elif np.allclose(partition.interval[:, None], partition.interval[None, :]):
+     # Tolerate some jitter in the unique values
+     time_attrs = {"integration_time": np.mean(partition.interval)}
    else:
+     # There are multiple unique interval values,
+     # a regular grid isn't possible
      warnings.warn(
        f"Missing/Multiple intervals {partition.interval} "
        f"found in partition {self._partition_key}. "
@@ -230,8 +279,11 @@
      time_attrs = {"integration_time": np.nan}
    data_vars.extend(
      [
-       ("TIME", self._variable_from_column("TIME")),
-       ("INTEGRATION_TIME", self._variable_from_column("INTEGRATION_TIME")),
+       ("TIME", self._variable_from_column("TIME", dim_sizes)),
+       (
+         "INTEGRATION_TIME",
+         self._variable_from_column("INTEGRATION_TIME", dim_sizes),
+       ),
      ]
    )

@@ -272,19 +324,15 @@
    return FrozenDict(sorted(data_vars + coordinates))

  def _observation_info(self) -> Dict[str, Any]:
-   structure = self._structure_factory.instance
-   partition = structure[self._partition_key]
+   partition = self._structure_factory.instance[self._partition_key]
    obs = self._subtable_factories["OBSERVATION"].instance
-   observer = obs["OBSERVER"][partition.obs_id].as_py()
-   project = obs["PROJECT"][partition.obs_id].as_py()
-   # TODO: A Measures conversions is needed here
-   release_date = obs["RELEASE_DATE"][partition.obs_id].as_py()  # noqa: F841
+   obs = maybe_impute_observation_table(obs, [partition.obs_id])

    return dict(
      sorted(
        {
-         "observer": observer,
-         "project": project,
+         "observer": obs["OBSERVER"][partition.obs_id].as_py(),
+         "project": obs["PROJECT"][partition.obs_id].as_py(),
        }.items()
      )
    )

xarray_ms-0.2.6/xarray_ms/backend/msv2/imputation.py
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ import warnings
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import numpy.typing as npt
+
+ from xarray_ms.errors import ImputedMetadataWarning
+
+ if TYPE_CHECKING:
+   import pyarrow as pa
+
+
+ def _maybe_return_table_or_max_id(
+   table: pa.Table, table_name: str, ids: npt.NDArray[np.int32], id_column_name: str
+ ) -> pa.Table | int:
+   """Returns the existing table if a row entry exists,
+   else returns the maximum id"""
+   max_id = np.max(ids)
+
+   if max_id < len(table):
+     return table
+
+   warnings.warn(
+     f"No row exists in the {table_name} table of length {len(table)} "
+     f"for {id_column_name}={max_id}. "
+     f"Artificial metadata will be substituted.",
+     ImputedMetadataWarning,
+   )
+
+   return max_id
+
+
+ def maybe_impute_field_table(
+   field: pa.Table, field_id: npt.NDArray[np.int32]
+ ) -> pa.Table:
+   """Generates a FIELD subtable if there are no row ids
+   associated with the given FIELD_ID values"""
+
+   import pyarrow as pa
+
+   result = _maybe_return_table_or_max_id(field, "FIELD", field_id, "FIELD_ID")
+   if isinstance(result, pa.Table):
+     return result
+
+   return pa.Table.from_pydict(
+     {
+       "NAME": np.array([f"UNKNOWN-{i}" for i in range(result + 1)], dtype=object),
+       "SOURCE_ID": np.zeros(result + 1, np.int32),
+     }
+   )
+
+
+ def maybe_impute_state_table(
+   state: pa.Table, state_id: npt.NDArray[np.int32]
+ ) -> pa.Table:
+   """Generates a STATE subtable if there are no row ids
+   associated with the given STATE_ID values"""
+   import pyarrow as pa
+
+   result = _maybe_return_table_or_max_id(state, "STATE", state_id, "STATE_ID")
+   if isinstance(result, pa.Table):
+     return result
+
+   return pa.Table.from_pydict(
+     {
+       "OBS_MODE": np.array(["UNSPECIFIED"] * (result + 1), dtype=object),
+       "SUB_SCAN": np.zeros(result + 1, np.int32),
+     }
+   )
+
+
+ def maybe_impute_observation_table(
+   observation: pa.Table, observation_id: npt.NDArray[np.int32]
+ ) -> pa.Table:
+   """Generates an OBSERVATION table if there are no row ids
+   associated with the given OBSERVATION_ID values"""
+   import pyarrow as pa
+
+   result = _maybe_return_table_or_max_id(
+     observation, "OBSERVATION", observation_id, "OBSERVATION_ID"
+   )
+   if isinstance(result, pa.Table):
+     return result
+
+   unknown = np.array(["unknown"] * (result + 1), dtype=object)
+
+   return pa.Table.from_pydict(
+     {
+       "OBSERVER": unknown,
+       "PROJECT": unknown,
+       "TELESCOPE_NAME": unknown,
+     }
+   )
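
A sketch of the imputation behaviour when the MAIN table references an
OBSERVATION_ID with no corresponding OBSERVATION row (assumes xarray-ms 0.2.6
and pyarrow are installed; the table contents are invented):

import numpy as np
import pyarrow as pa

from xarray_ms.backend.msv2.imputation import maybe_impute_observation_table

# An OBSERVATION subtable with a single row
obs = pa.Table.from_pydict(
  {"OBSERVER": ["A. Observer"], "PROJECT": ["P1"], "TELESCOPE_NAME": ["MeerKAT"]}
)

# OBSERVATION_ID=2 has no row: an ImputedMetadataWarning is emitted and a
# replacement table covering ids 0..2 is returned instead
imputed = maybe_impute_observation_table(obs, np.array([2], dtype=np.int32))
print(imputed["TELESCOPE_NAME"].to_pylist())  # ['unknown', 'unknown', 'unknown']
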

{xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/backend/msv2/structure.py
@@ -27,6 +27,10 @@ import pyarrow as pa
  from arcae.lib.arrow_tables import Table
  from cacheout import Cache

+ from xarray_ms.backend.msv2.imputation import (
+   maybe_impute_field_table,
+   maybe_impute_state_table,
+ )
  from xarray_ms.backend.msv2.partition import PartitionKeyT, TablePartitioner
  from xarray_ms.errors import (
    InvalidMeasurementSet,
@@ -170,16 +174,16 @@ class PartitionData:
    spw_id: int  # unique from DATA_DESC_ID
    pol_id: int  # unique from DATA_DESC_ID
    # Multiple values per partition
-   antenna_ids: List[int]
-   feed_ids: List[int]
-   field_ids: List[int]
-   state_ids: List[int]
-   scan_numbers: List[int]
+   antenna_ids: npt.NDArray[np.int32]
+   feed_ids: npt.NDArray[np.int32]
+   field_ids: npt.NDArray[np.int32]
+   state_ids: npt.NDArray[np.int32]
+   scan_numbers: npt.NDArray[np.int32]
    # FIELD subtable
-   source_ids: List[int]
+   source_ids: npt.NDArray[np.int32]
    # STATE subtable
    obs_mode: str  # unique from STATE::OBS_MODE
-   sub_scan_numbers: List[int]
+   sub_scan_numbers: npt.NDArray[np.int32]

    # Row to baseline map
    row_map: npt.NDArray[np.int64]
@@ -233,7 +237,7 @@ class MSv2StructureFactory:
    _epoch: str
    _auto_corrs: bool
    _STRUCTURE_CACHE: ClassVar[Cache] = Cache(
-     maxsize=100, ttl=60, on_get=on_get_keep_alive
+     maxsize=100, ttl=5 * 60, on_get=on_get_keep_alive
    )

    def __init__(
@@ -388,6 +392,7 @@ class MSv2Structure(Mapping):
  ) -> npt.NDArray[np.int32]:
    """Constructs a SOURCE_ID array from MAIN.FIELD_ID
    broadcast against FIELD.SOURCE_ID"""
+   field = maybe_impute_field_table(field, field_id)
    field_source_id = field["SOURCE_ID"].to_numpy()
    source_id = np.empty_like(field_id)
    chunk = (len(source_id) + ncpus - 1) // ncpus
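
Chunking and the thread pool aside, the SOURCE_ID construction above is a
numpy gather: each MAIN row's FIELD_ID selects a SOURCE_ID from the (possibly
imputed) FIELD subtable. An equivalent single-threaded sketch with invented ids:

import numpy as np

field_source_id = np.array([5, 7, 9], dtype=np.int32)  # FIELD.SOURCE_ID
field_id = np.array([0, 0, 2, 1, 2], dtype=np.int32)   # MAIN.FIELD_ID per row

source_id = field_source_id[field_id]  # what the chunked loop computes
print(source_id)                       # [5 5 9 7 9]
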
@@ -411,6 +416,7 @@
  ) -> npt.NDArray[np.int32]:
    """Constructs a SUB_SCAN_NUMBER array from MAIN.STATE_ID
    broadcast against STATE.SUB_SCAN_NUMBER"""
+   state = maybe_impute_state_table(state, state_id)
    state_ssn = state["SUB_SCAN"].to_numpy()
    subscan_nr = np.empty_like(state_id)
    chunk = (len(state_id) + ncpus - 1) // ncpus
@@ -434,6 +440,8 @@
  ) -> Tuple[npt.NDArray[np.int32], Dict[str, List[int]]]:
    """Constructs an OBS_MODE_ID array from MAIN.STATE_ID broadcast
    against unique entries in STATE.OBS_MODE"""
+
+   state = maybe_impute_state_table(state, state_id)
    obs_mode = state["OBS_MODE"].to_numpy()

    # Map unique observation modes to state_ids
@@ -637,14 +645,24 @@
      """Return the group that the subtable column should be assigned to"""
      return partition_columns if s in subtable_columns else other_columns

-   def get_uid_column(column, dkey, ids) -> List[Any]:
+   def get_uid_column(column, dkey, ids) -> npt.NDArray:
      """Get the unique values for the given column, preferably from the
      partition key or failing that, from `ids`. Generally should be used with
      ID columns"""
      try:
-       return [dkey[column]]
+       return np.array([dkey[column]])
+     except KeyError:
+       return self.par_unique(pool, ncpus, ids)
+
+   def time_coord(column, dkey, ids, utime, time_ids) -> npt.NDArray:
+     try:
+       value = dkey[column]
      except KeyError:
-       return self.par_unique(pool, ncpus, ids).tolist()
+       result = np.empty(utime.shape, dtype=ids.dtype)
+       result[time_ids] = ids
+       return result
+     else:
+       return np.full(utime.shape, value)

    # Broadcast and add FIELD.SOURCE_ID column
    field_id = arrow_table["FIELD_ID"].to_numpy()
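
time_coord above scatters a per-row column onto the unique-time grid using the
inverse indices from np.unique, or simply repeats the value when the partition
key already fixes it. A compact illustration with invented rows:

import numpy as np

time = np.array([10.0, 10.0, 20.0, 20.0, 30.0])   # MAIN.TIME per row
scan = np.array([1, 1, 1, 2, 2], dtype=np.int32)  # MAIN.SCAN_NUMBER per row

utime, time_ids = np.unique(time, return_inverse=True)

# Scatter row values onto the unique time grid; later rows win any ties
result = np.empty(utime.shape, dtype=scan.dtype)
result[time_ids] = scan
print(utime)   # [10. 20. 30.]
print(result)  # [1 2 2], one scan number per unique timestep
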
@@ -696,21 +714,25 @@
    antenna2 = partition["ANTENNA2"]
    interval = partition["INTERVAL"]
    rows = partition["row"]
-   chunk = (len(rows) + ncpus - 1) // ncpus
+
+   # Unique sorting/other column values
+   utime, time_ids = self.par_unique(
+     pool, ncpus, partition["TIME"], return_inverse=True
+   )

    # Unique partition key values
-   ufield_ids = get_uid_column("FIELD_ID", dkey, partition["FIELD_ID"])
-   usubscan_nrs = get_uid_column(
-     "SUB_SCAN_NUMBER", dkey, partition["SUB_SCAN_NUMBER"]
+   ufield_ids = time_coord(
+     "FIELD_ID", dkey, partition["FIELD_ID"], utime, time_ids
+   )
+   usubscan_nrs = time_coord(
+     "SUB_SCAN_NUMBER", dkey, partition["SUB_SCAN_NUMBER"], utime, time_ids
+   )
+   uscan_nrs = time_coord(
+     "SCAN_NUMBER", dkey, partition["SCAN_NUMBER"], utime, time_ids
    )
-   uscan_nrs = get_uid_column("SCAN_NUMBER", dkey, partition["SCAN_NUMBER"])
    ustate_ids = get_uid_column("STATE_ID", dkey, partition["STATE_ID"])
    usource_ids = get_uid_column("SOURCE_ID", dkey, partition["SOURCE_ID"])

-   # Unique sorting/other column values
-   utime, time_ids = self.par_unique(
-     pool, ncpus, partition["TIME"], return_inverse=True
-   )
    uantenna1 = self.par_unique(pool, ncpus, antenna1)
    uantenna2 = self.par_unique(pool, ncpus, antenna2)
    uantennas = np.union1d(uantenna1, uantenna2)
@@ -730,6 +752,7 @@

    na = len(feed_antennas)
    nbl = nr_of_baselines(na, auto_corrs)
+   chunk = (len(rows) + ncpus - 1) // ncpus

    # Populate row map and interval grids
    row_map = np.full(utime.size * nbl, -1, dtype=np.int64)
@@ -794,8 +817,8 @@
      obs_id=obs_id,
      spw_id=spw_id,
      pol_id=pol_id,
-     antenna_ids=feed_antennas.tolist(),
-     feed_ids=ufeeds.tolist(),
+     antenna_ids=feed_antennas,
+     feed_ids=ufeeds,
      field_ids=ufield_ids,
      scan_numbers=uscan_nrs,
      source_ids=usource_ids,

{xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/errors.py
@@ -3,6 +3,15 @@ class IrregularGridWarning(UserWarning):
  with each timestep are not homogenous"""


+ class MissingMetadataWarning(UserWarning):
+   """Warning raised when metadata is missing"""
+
+
+ class ImputedMetadataWarning(MissingMetadataWarning):
+   """Warning raised when metadata is imputed
+   if the original metadata is missing"""
+
+
  class InvalidMeasurementSet(ValueError):
    """Raised when the Measurement Set foreign key indexing is invalid"""

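Because ImputedMetadataWarning subclasses MissingMetadataWarning, a single
filter on the base class silences both. A standalone sketch (the classes are
redeclared locally so the snippet runs without the package):

import warnings

class MissingMetadataWarning(UserWarning):
  """Warning raised when metadata is missing"""

class ImputedMetadataWarning(MissingMetadataWarning):
  """Warning raised when metadata is imputed"""

with warnings.catch_warnings():
  # Ignoring the base class also ignores the imputation subclass
  warnings.simplefilter("ignore", MissingMetadataWarning)
  warnings.warn("substituted metadata", ImputedMetadataWarning)
print("warning suppressed")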

{xarray_ms-0.2.4 → xarray_ms-0.2.6}/xarray_ms/testing/simulator.py
@@ -24,8 +24,8 @@ FIRST_FEB_2023_MJDS = 2459976.50000 * 86400
  # Default simulation parameters
  DEFAULT_SIM_PARAMS = {"ntime": 5, "data_description": [(8, ["XX", "XY", "YX", "YY"])]}

- # Additional Columns to add
- ADDITIONAL_COLUMNS = {
+ # Standard DATA Columns
+ STANDARD_DATA_COLUMNS = {
    "DATA": {
      "_c_order": True,
      "comment": "DATA column",
@@ -91,6 +91,8 @@ class PartitionDescriptor:

  DDIDArgType = List[Tuple[npt.NDArray[np.float64], List[str]]]
  PartitionDataType = Dict[str, Tuple[Tuple[str, ...], npt.NDArray]]
+ ChunkDescriptorTransformerT = Callable[[PartitionDescriptor], PartitionDescriptor]
+ DataTransformerT = Callable[[PartitionDescriptor, PartitionDataType], PartitionDataType]


  class MSStructureSimulator:
@@ -113,12 +115,11 @@
    partition_names: List[str]
    partition_indices: npt.NDArray[np.int32]
    simulate_data: bool
+   table_desc: Dict[str, Any]
    model: Dict[str, Any]
    data_description: DataDescription
-   transform_desc: Callable[[PartitionDescriptor], PartitionDescriptor] | None
-   transform_data: (
-     Callable[[PartitionDescriptor, PartitionDataType], PartitionDataType] | None
-   )
+   transform_chunk_desc: ChunkDescriptorTransformerT | None
+   transform_data: DataTransformerT | None

    def __init__(
      self,
@@ -134,11 +135,9 @@
      partition: Tuple[str, ...] = ("OBSERVATION_ID", "FIELD_ID", "DATA_DESC_ID"),
      auto_corrs: bool = True,
      simulate_data: bool = True,
-     transform_desc: Callable[[PartitionDescriptor], PartitionDescriptor] | None = None,
-     transform_data: Callable[
-       [PartitionDescriptor, PartitionDataType], PartitionDataType
-     ]
-     | None = None,
+     table_desc: Dict[str, Any] | None = None,
+     transform_chunk_desc: ChunkDescriptorTransformerT | None = None,
+     transform_data: DataTransformerT | None = None,
    ):
      assert ntime >= 1
      assert time_chunks > 0
@@ -194,9 +193,10 @@
    self.time_chunks = time_chunks
    self.time_start = time_start
    self.simulate_data = simulate_data
+   self.table_desc = STANDARD_DATA_COLUMNS if table_desc is None else table_desc
    self.partition_names = cbp_names
    self.partition_indices = bcbp_indices
-   self.transform_desc = transform_desc
+   self.transform_chunk_desc = transform_chunk_desc
    self.transform_data = transform_data
    self.model = {
      "data_description": self.data_description,
@@ -211,17 +211,16 @@

  def simulate_ms(self, output_ms: str) -> None:
    """Simulate data into the given measurement set name"""
-   table_desc = ADDITIONAL_COLUMNS if self.simulate_data else {}

    # Generate descriptors, create simulated data from the descriptors
    # and write simulated data to the main Measurement Set
-   with Table.ms_from_descriptor(output_ms, "MAIN", table_desc) as T:
+   with Table.ms_from_descriptor(output_ms, "MAIN", self.table_desc) as T:
      startrow = 0

      for chunk_desc in self.generate_descriptors():
        # Apply any chunk descriptor transforms
-       if self.transform_desc is not None:
-         chunk_desc = self.transform_desc(chunk_desc)
+       if self.transform_chunk_desc is not None:
+         chunk_desc = self.transform_chunk_desc(chunk_desc)

        # Generate the chunk data
        data_dict = self.data_factory(chunk_desc)

xarray_ms-0.2.4/xarray_ms/backend/msv2/array.py
@@ -1,87 +0,0 @@
- from __future__ import annotations
-
- from typing import TYPE_CHECKING, Any, Callable, Tuple
-
- import numpy as np
- from xarray.backends import BackendArray
- from xarray.core.indexing import IndexingSupport, explicit_indexing_adapter
-
- if TYPE_CHECKING:
-   import numpy.typing as npt
-
-   from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
-   from xarray_ms.multiton import Multiton
-
-   TransformerT = Callable[[npt.NDArray], npt.NDArray] | None
-
-
- def slice_length(s, max_len):
-   if isinstance(s, np.ndarray):
-     if s.ndim != 1:
-       raise NotImplementedError("Slicing with non-1D numpy arrays")
-     return len(s)
-
-   start, stop, step = s.indices(max_len)
-   if step != 1:
-     raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported")
-   return stop - start
-
-
- class MSv2Array(BackendArray):
-   """Backend array containing functionality for reading an MSv2 column"""
-
-   _table_factory: Multiton
-   _structure_factory: MSv2StructureFactory
-   _partition: PartitionKeyT
-   _column: str
-   _shape: Tuple[int, ...]
-   _dtype: npt.DTypeLike
-   _default: Any | None
-   _transform: TransformerT
-
-   def __init__(
-     self,
-     table_factory: Multiton,
-     structure_factory: MSv2StructureFactory,
-     partition: PartitionKeyT,
-     column: str,
-     shape: Tuple[int, ...],
-     dtype: npt.DTypeLike,
-     default: Any | None = None,
-     transform: TransformerT = None,
-   ):
-     self._table_factory = table_factory
-     self._structure_factory = structure_factory
-     self._partition = partition
-     self._column = column
-     self._default = default
-     self._transform = transform
-     self.shape = shape
-     self.dtype = np.dtype(dtype)
-
-     assert len(shape) >= 2, "(time, baseline_ids) required"
-
-   def __getitem__(self, key) -> npt.NDArray:
-     return explicit_indexing_adapter(
-       key, self.shape, IndexingSupport.OUTER, self._getitem
-     )
-
-   def _getitem(self, key) -> npt.NDArray:
-     assert len(key) == len(self.shape)
-     expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
-     # Map the (time, baseline_id) coordinates onto row indices
-     rows = self._structure_factory.instance[self._partition].row_map[key[:2]]
-     xkey = (rows.ravel(),) + key[2:]
-     row_shape = (rows.size,) + expected_shape[2:]
-     result = np.full(row_shape, self._default, dtype=self.dtype)
-     self._table_factory.instance.getcol(self._column, xkey, result)
-     result = result.reshape(rows.shape + expected_shape[2:])
-     return self._transform(result) if self._transform else result
-
-   @property
-   def transform(self) -> TransformerT:
-     return self._transform
-
-   @transform.setter
-   def transform(self, value: TransformerT):
-     self._transform = value