xradio 0.0.28__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. xradio/__init__.py +5 -4
  2. xradio/_utils/array.py +90 -0
  3. xradio/_utils/zarr/common.py +48 -3
  4. xradio/image/_util/zarr.py +4 -1
  5. xradio/schema/__init__.py +24 -6
  6. xradio/schema/bases.py +440 -2
  7. xradio/schema/check.py +96 -55
  8. xradio/schema/dataclass.py +123 -27
  9. xradio/schema/metamodel.py +21 -4
  10. xradio/schema/typing.py +33 -18
  11. xradio/vis/__init__.py +5 -2
  12. xradio/vis/_processing_set.py +28 -20
  13. xradio/vis/_vis_utils/_ms/_tables/create_field_and_source_xds.py +710 -0
  14. xradio/vis/_vis_utils/_ms/_tables/load.py +23 -10
  15. xradio/vis/_vis_utils/_ms/_tables/load_main_table.py +145 -64
  16. xradio/vis/_vis_utils/_ms/_tables/read.py +747 -172
  17. xradio/vis/_vis_utils/_ms/_tables/read_main_table.py +173 -44
  18. xradio/vis/_vis_utils/_ms/_tables/read_subtables.py +79 -28
  19. xradio/vis/_vis_utils/_ms/_tables/write.py +102 -45
  20. xradio/vis/_vis_utils/_ms/_tables/write_exp_api.py +127 -65
  21. xradio/vis/_vis_utils/_ms/chunks.py +58 -21
  22. xradio/vis/_vis_utils/_ms/conversion.py +536 -67
  23. xradio/vis/_vis_utils/_ms/descr.py +52 -20
  24. xradio/vis/_vis_utils/_ms/msv2_to_msv4_meta.py +70 -35
  25. xradio/vis/_vis_utils/_ms/msv4_infos.py +0 -59
  26. xradio/vis/_vis_utils/_ms/msv4_sub_xdss.py +76 -9
  27. xradio/vis/_vis_utils/_ms/optimised_functions.py +0 -46
  28. xradio/vis/_vis_utils/_ms/partition_queries.py +308 -119
  29. xradio/vis/_vis_utils/_ms/partitions.py +82 -25
  30. xradio/vis/_vis_utils/_ms/subtables.py +32 -14
  31. xradio/vis/_vis_utils/_utils/partition_attrs.py +30 -11
  32. xradio/vis/_vis_utils/_utils/xds_helper.py +136 -45
  33. xradio/vis/_vis_utils/_zarr/read.py +60 -22
  34. xradio/vis/_vis_utils/_zarr/write.py +83 -9
  35. xradio/vis/_vis_utils/ms.py +48 -29
  36. xradio/vis/_vis_utils/zarr.py +44 -20
  37. xradio/vis/convert_msv2_to_processing_set.py +106 -32
  38. xradio/vis/load_processing_set.py +38 -61
  39. xradio/vis/read_processing_set.py +62 -96
  40. xradio/vis/schema.py +687 -0
  41. xradio/vis/vis_io.py +75 -43
  42. {xradio-0.0.28.dist-info → xradio-0.0.29.dist-info}/LICENSE.txt +6 -1
  43. {xradio-0.0.28.dist-info → xradio-0.0.29.dist-info}/METADATA +10 -5
  44. xradio-0.0.29.dist-info/RECORD +73 -0
  45. {xradio-0.0.28.dist-info → xradio-0.0.29.dist-info}/WHEEL +1 -1
  46. xradio/vis/model.py +0 -497
  47. xradio-0.0.28.dist-info/RECORD +0 -71
  48. {xradio-0.0.28.dist-info → xradio-0.0.29.dist-info}/top_level.txt +0 -0
xradio/__init__.py CHANGED
@@ -1,12 +1,13 @@
 import os
 from graphviper.utils.logger import setup_logger
 
-if not os.getenv("VIPER_LOGGER_NAME"):
-    os.environ["VIPER_LOGGER_NAME"] = "xradio"
+_logger_name = "xradio"
+if os.getenv("VIPER_LOGGER_NAME") != _logger_name:
+    os.environ["VIPER_LOGGER_NAME"] = _logger_name
 setup_logger(
     logger_name="xradio",
     log_to_term=True,
-    log_to_file=False,
+    log_to_file=False,  # True
     log_file="xradio-logfile",
-    log_level="INFO",
+    log_level="DEBUG",
 )
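
The net effect is that importing xradio now pins VIPER_LOGGER_NAME to "xradio" and raises the log level from INFO to DEBUG. A minimal sketch of how downstream code could pick up the same logger, assuming graphviper's setup_logger registers a standard Python logging logger under the configured name (an assumption, not stated in this diff):

    import logging
    import xradio  # importing xradio configures the "xradio" logger (DEBUG, terminal output)

    log = logging.getLogger("xradio")
    log.debug("starting conversion step")  # visible now that the level is DEBUG
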
xradio/_utils/array.py ADDED
@@ -0,0 +1,90 @@
+"""Contains optimised functions to be used within other modules."""
+
+import numpy as np
+import pandas as pd
+
+
+def check_if_consistent(array: np.ndarray, array_name: str) -> np.ndarray:
+    """Check that an array holds a single repeated value and return that value.
+
+    Parameters
+    ----------
+    array : np.ndarray
+        array whose values are expected to be identical.
+    array_name : str
+        name of the array, used in the assertion message.
+
+    Returns
+    -------
+    np.ndarray
+        the single unique value found in the array.
+    """
+    if array.ndim == 0:
+        return array.item()
+
+    array_unique = unique_1d(array)
+    assert len(array_unique) == 1, array_name + " is not consistent."
+    return array_unique[0]
+
+
+def unique_1d(array: np.ndarray) -> np.ndarray:
+    """
+    Optimised version of np.unique for 1D arrays.
+
+    Parameters
+    ----------
+    array : np.ndarray
+        a 1D array of values.
+
+    Returns
+    -------
+    np.ndarray
+        a sorted array of unique values.
+    """
+    return np.sort(pd.unique(array))
+
+
+def pairing_function(antenna_pairs: np.ndarray) -> np.ndarray:
+    """
+    Pairing function to convert each antenna pair to a single value.
+
+    This custom pairing function only works if the maximum value is less
+    than 2**20 (and less than 2,048 when using signed 32-bit integers).
+
+    Parameters
+    ----------
+    antenna_pairs : np.ndarray
+        a 2D array containing antenna 1 and antenna 2 ids,
+        each pair forming a baseline.
+
+    Returns
+    -------
+    np.ndarray
+        a 1D array of the paired values.
+    """
+    return antenna_pairs[:, 0] * 2**20 + antenna_pairs[:, 1]
+
+
+def inverse_pairing_function(paired_array: np.ndarray) -> np.ndarray:
+    """
+    Inverse pairing function to convert each paired value back to an antenna pair.
+
+    This is the inverse of the custom pairing function above.
+
+    Parameters
+    ----------
+    paired_array : np.ndarray
+        a 1D array of the paired values.
+
+    Returns
+    -------
+    np.ndarray
+        a 2D array containing antenna 1 and antenna 2 ids.
+    """
+    return np.column_stack(np.divmod(paired_array, 2**20))
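
The pairing scheme packs each (antenna1, antenna2) pair into a single integer so baselines can be deduplicated and sorted as scalars, then unpacked again. A short round-trip sketch with made-up antenna ids, assuming the module is importable as xradio._utils.array:

    import numpy as np
    from xradio._utils.array import pairing_function, inverse_pairing_function, unique_1d

    antenna_pairs = np.array([[0, 1], [0, 2], [0, 1], [3, 5]], dtype=np.int64)

    # Pack each pair into one integer: antenna1 * 2**20 + antenna2
    baseline_keys = pairing_function(antenna_pairs)

    # Unique baselines as scalars, unpacked back into (antenna1, antenna2) pairs
    unique_baselines = inverse_pairing_function(unique_1d(baseline_keys))
    # -> [[0 1]
    #     [0 2]
    #     [3 5]]
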
xradio/_utils/zarr/common.py CHANGED
@@ -2,9 +2,54 @@ import copy
 import xarray as xr
 import zarr
 import s3fs
+import os
+from botocore.exceptions import NoCredentialsError
+
+
+def _get_ms_stores_and_file_system(ps_store: str):
+
+    if os.path.isdir(ps_store):
+        # default to assuming the data are accessible on the local file system
+        items = os.listdir(ps_store)
+        file_system = os
+
+    elif ps_store.startswith("s3"):
+        # only if not found locally, check if dealing with an S3 bucket URL
+        # if not ps_store.endswith("/"):
+        #     # just for consistency, as there is no os.path equivalent in s3fs
+        #     ps_store = ps_store + "/"
+
+        try:
+            # initialize the S3 "file system", first attempting to use pre-configured credentials
+            file_system = s3fs.S3FileSystem(anon=False, requester_pays=False)
+            items = [
+                bd.split(sep="/")[-1]
+                for bd in file_system.listdir(ps_store, detail=False)
+            ]
+
+        except (NoCredentialsError, PermissionError) as e:
+            # only public, read-only buckets will be accessible
+            # we will want to add messaging and error handling here
+            file_system = s3fs.S3FileSystem(anon=True)
+            items = [
+                bd.split(sep="/")[-1]
+                for bd in file_system.listdir(ps_store, detail=False)
+            ]
+    else:
+        raise FileNotFoundError(
+            f"Could not find {ps_store} either locally or in the cloud."
+        )
+
+    items = [
+        item for item in items if not item.startswith(".")
+    ]  # Mac OS likes to place hidden files in the directory (.DS_Store).
+    return file_system, items
 
 
-def _open_dataset(store, xds_isel=None, data_variables=None, load=False, **kwargs):
+def _open_dataset(
+    store, file_system=os, xds_isel=None, data_variables=None, load=False
+):
     """
 
     Parameters
@@ -26,8 +71,8 @@ def _open_dataset(store, xds_isel=None, data_variables=None, load=False, **kwargs):
 
     import dask
 
-    if "s3" in kwargs.keys():
-        mapping = s3fs.S3Map(root=store, s3=kwargs["s3"], check=False)
+    if isinstance(file_system, s3fs.core.S3FileSystem):
+        mapping = s3fs.S3Map(root=store, s3=file_system, check=False)
         xds = xr.open_zarr(store=mapping)
     else:
         xds = xr.open_zarr(store)
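
The new helper resolves the store location up front: a path that exists locally keeps the plain os module as the "file system", an s3:// URL yields an s3fs.S3FileSystem (falling back to anonymous access when no credentials are configured), and anything else raises FileNotFoundError. The detected file system is then handed to _open_dataset. A usage sketch with made-up store paths, assuming the helpers are imported from xradio._utils.zarr.common:

    from xradio._utils.zarr.common import _get_ms_stores_and_file_system, _open_dataset

    # Local processing set: file_system is the os module, items are the contained sub-stores
    ps_store = "/data/my_processing_set.ps.zarr"
    # For a (hypothetical) bucket instead:
    # ps_store = "s3://example-bucket/my_processing_set.ps.zarr"
    file_system, items = _get_ms_stores_and_file_system(ps_store)

    # Open each sub-store with the file system detected above
    datasets = [
        _open_dataset(f"{ps_store}/{item}", file_system=file_system) for item in items
    ]
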
xradio/image/_util/zarr.py CHANGED
@@ -25,7 +25,10 @@ def _xds_from_zarr(
 
 
 def _load_image_from_zarr_no_dask(zarr_file: str, selection: dict) -> xr.Dataset:
-    image_xds = _open_dataset(zarr_file, selection, load=True)
+    # At the moment the image module does not support the S3 file system; file_system=os is hardcoded.
+    image_xds = _open_dataset(
+        store=zarr_file, file_system=os, xds_isel=selection, load=True
+    )
     for h in ["HISTORY", "_attrs_xds_history"]:
         history = os.sep.join([zarr_file, h])
         if os.path.isdir(history):
xradio/schema/__init__.py CHANGED
@@ -1,17 +1,35 @@
-from .metamodel import AttrSchemaRef, ArraySchema, ArraySchemaRef, DatasetSchema
 from .dataclass import (
     xarray_dataclass_to_array_schema,
     xarray_dataclass_to_dataset_schema,
+    xarray_dataclass_to_dict_schema,
+)
+from .bases import (
+    xarray_dataarray_schema,
+    xarray_dataset_schema,
+    dict_schema,
+)
+from .check import (
+    SchemaIssue,
+    SchemaIssues,
+    check_array,
+    check_dataset,
+    check_dict,
+    schema_checked,
 )
-from .bases import AsDataArray, AsDataset
 
 __all__ = [
-    "AttrSchemaRef",
-    "ArraySchema",
-    "ArraySchemaRef",
-    "DatasetSchema",
     "xarray_dataclass_to_array_schema",
     "xarray_dataclass_to_dataset_schema",
+    "xarray_dataclass_to_dict_schema",
     "AsDataArray",
     "AsDataset",
+    "AsDict",
+    "xarray_dataarray_schema",
+    "xarray_dataset_schema",
+    "SchemaIssue",
+    "SchemaIssues",
+    "check_array",
+    "check_dataset",
+    "check_dict",
+    "schema_checked",
 ]
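
For downstream code, the practical effect is that the metamodel classes (AttrSchemaRef, ArraySchema, ArraySchemaRef, DatasetSchema) are no longer re-exported at package level, while the schema decorators and check helpers now are. A sketch of the 0.0.29 import surface as declared above:

    # Previously re-exported at package level, now only in the submodule:
    from xradio.schema.metamodel import ArraySchema, DatasetSchema

    # New top-level exports:
    from xradio.schema import (
        xarray_dataarray_schema,
        xarray_dataset_schema,
        dict_schema,
        check_array,
        check_dataset,
        check_dict,
        SchemaIssue,
        SchemaIssues,
        schema_checked,
    )
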
xradio/schema/bases.py CHANGED
@@ -1,6 +1,444 @@
+import xarray
+import inspect
+from . import dataclass, check, metamodel, typing
+import numpy
+import dataclasses
+
+
+def _guess_dtype(obj: typing.Any):
+    try:
+        return _guess_dtype(next(iter(obj)))
+    except TypeError:
+        return numpy.dtype(type(obj))
+
+
+def _set_parameter(
+    val: typing.Any, args: dict, schema: typing.Union["AttrSchemaRef", "ArraySchemaRef"]
+):
+    """
+    Extract given entry from parameters - while taking care that the
+    parameter value might have been set either before or after, and that
+    defaults might apply.
+
+    :param val: Value from xarray-constructor style ("data_vars"/"coords")
+    :param args: Bound arguments to constructor (positional or named)
+    :param schema: Schema of argument (either attribute or array)
+    :returns: Updated value
+    """
+
+    # If value appears in named parameters, overwrite
+    if args.get(schema.name) is not None:
+        if val is not None:
+            raise ValueError(
+                f"Parameter {schema.name} was passed twice ({val} vs {args[schema.name]})!"
+            )
+        val = args[schema.name]
+
+    # Otherwise apply defaults *if* it doesn't exist already or deactivate
+    # (typically because we are constructing from a dataset/data array)
+    if val is None and schema.default is not dataclasses.MISSING:
+        default = schema.default
+        if default is not None:
+            val = default
+
+    return val
+
+
+def _np_convert(val: typing.Any, schema: metamodel.ArraySchemaRef):
+    """
+    Convert value to numpy, if appropriate.
+
+    This attempts to catch "early" conversions that we can do more
+    appropriately than xarray because we have more information from the schema.
+    Specifically, if it's a type where the dtype to choose is somewhat
+    ambiguous, we can use this chance to "bias" it towards an allowed one.
+
+    :param val: Received value
+    :param schema: Expected array schema
+    :returns: Possibly converted value
+    """
+
+    # Array schema refs that are not yet a numpy or xarray data type?
+    if isinstance(val, list) or isinstance(val, tuple) and isinstance(val[1], list):
+        # Check whether we can "guess" the dtype from the object
+        dtype = None
+        if len(schema.dtypes) > 1:
+            guessed = _guess_dtype(val)
+            for dt in schema.dtypes:
+                # Actually look for closest in precision etc?
+                if dt == guessed:
+                    dtype = dt
+                    break
+
+        # Otherwise just use the first one
+        if dtype is None:
+            dtype = schema.dtypes[0]
+
+        # Attempt conversion
+        try:
+            if isinstance(val, list):
+                val = numpy.array(val, dtype=dtype)
+            else:
+                val = tuple([val[0], numpy.array(val[1], dtype=dtype), *val[2:]])
+
+        except TypeError:
+            pass
+
+    return val
+
+
+def _dataarray_new(
+    cls,
+    data=None,
+    *args,
+    coords=None,
+    dims=None,
+    name=None,
+    attrs=None,
+    indexes=None,
+    **kwargs,
+):
+    # Convert parameters
+    if coords is not None and isinstance(coords, list):
+        coords = dict(coords)
+    if coords is None:
+        coords = {}
+    if attrs is None:
+        attrs = {}
+
+    # Get signature of __init__, map parameters and apply defaults. This
+    # will raise an exception if there are any extra parameters.
+    sig = inspect.Signature.from_callable(cls.__init__)
+    sig = sig.replace(parameters=[v for k, v in sig.parameters.items() if k != "self"])
+    mapping = sig.bind_partial(data, *args, **kwargs)
+
+    # Check whether we have a "data" argument now. This happens if we pass
+    # it as a positional argument.
+    if mapping.arguments.get("data") is not None:
+        data = mapping.arguments["data"]
+
+    # No dims specified? Select one matching the data dimensionality from
+    # the schema
+    schema = dataclass.xarray_dataclass_to_array_schema(cls)
+    data = _np_convert(data, schema)
+    for schema_dims in schema.dimensions:
+        if len(schema_dims) == len(data.shape):
+            dims = schema_dims
+            break
+
+    # If we are constructing from a data array / variable, take over attributes
+    if isinstance(data, (xarray.DataArray, xarray.Variable)):
+        for attr, attr_val in data.attrs.items():
+            # Explicit parameters take precedence though
+            if attr not in attrs:
+                attrs[attr] = attr_val
+
+    # Get any coordinates or attributes and add them to the appropriate lists
+    for coord in schema.coordinates:
+        val = _np_convert(
+            _set_parameter(coords.get(coord.name), mapping.arguments, coord), coord
+        )
+
+        # Default to simple range of specified dtype if part of dimensions
+        # (that's roughly the behaviour of the xarray constructor as well)
+        if val is None and dims is not None:
+            dim_ix = dims.index(coord.name)
+            if dim_ix is not None and dim_ix < len(data.shape):
+                dtype = coord.dtypes[0]
+                val = numpy.arange(data.shape[dim_ix], dtype=dtype)
+
+        if val is not None:
+            coords[coord.name] = val
+    for attr in schema.attributes:
+        val = _set_parameter(attrs.get(attr.name), mapping.arguments, attr)
+        if val is not None:
+            attrs[attr.name] = val
+
+    # Redirect to xarray.DataArray constructor
+    instance = xarray.DataArray(data, coords, dims, name, attrs, indexes)
+
+    # Perform schema check
+    check.check_array(instance, schema).expect()
+    return instance
+
+
+def xarray_dataarray_schema(cls):
+    """Decorator for classes representing :py:class:`xarray.DataArray` schemas.
+    The annotated class should exactly contain:
+
+    * one field called "``data``" annotated with :py:data:`~typing.Data`
+      to indicate the array type
+    * fields annotated with :py:data:`~typing.Coord` to indicate mappings of
+      dimensions to coordinates (coordinates directly associated with dimensions
+      should have the same name as the dimension)
+    * fields annotated with :py:data:`~typing.Attr` to declare attributes
+
+    Decorated schema classes can be used with
+    :py:func:`~xradio.schema.check.check_array` for checking
+    :py:class:`xarray.DataArray` objects against the schema. Furthermore, the
+    class constructor will be overwritten to generate schema-conforming
+    :py:class:`xarray.DataArray` objects.
+
+    For example::
+
+        from xradio.schema import xarray_dataarray_schema
+        from xradio.schema.typing import Data, Coord, Attr
+        from typing import Optional, Literal
+        import dataclasses
+
+        Coo = Literal["coo"]
+
+        @xarray_dataarray_schema
+        class TestArray:
+            data: Data[Coo, complex]
+            coo: Coord[Coo, float]
+            attr1: Attr[str]
+            attr2: Attr[int] = 123
+            attr3: Optional[Attr[int]] = None
+
+    This data class represents a one-dimensional :py:class:`xarray.DataArray`
+    with complex data, a ``float`` coordinate and three attributes. Instances of
+    this class cannot actually be constructed; instead you will get an appropriate
+    :py:class:`xarray.DataArray` object::
+
+        >>> TestArray(data=[1,2,3], attr1="foo")
+        <xarray.DataArray (coo: 3)>
+        array([1.+0.j, 2.+0.j, 3.+0.j])
+        Coordinates:
+          * coo      (coo) float64 0.0 1.0 2.0
+        Attributes:
+            attr1:    foo
+            attr2:    123
+
+    Note that:
+
+    * The constructor uses the annotations to identify the role of every parameter
+    * The data was automatically converted into a :py:class:`numpy.ndarray`
+    * As there was no coordinate given, it was automatically filled with an
+      enumeration of the type specified in the annotation
+    * Default attribute values were assigned. A value of ``None`` is interpreted
+      as the attribute being missing.
+    * For the returned :py:class:`~xarray.DataArray` object ``data``, ``coo``,
+      ``attr1`` and ``attr2`` can be accessed as if they were members. This works
+      as long as the names don't collide with :py:class:`~xarray.DataArray` members.
+
+    Positional parameters are also supported, and ``coords`` and ``attrs`` passed as
+    keyword arguments can supply additional coordinates and attributes::
+
+        >>> TestArray([1,2,3], [3,4,5], 'bar', coords={'coo_new': ('coo', [3,2,1])}, attrs={'xattr': 'baz'})
+        <xarray.DataArray (coo: 3)>
+        array([1.+0.j, 2.+0.j, 3.+0.j])
+        Coordinates:
+            coo_new  (coo) int64 3 2 1
+          * coo      (coo) float64 3.0 4.0 5.0
+        Attributes:
+            xattr:    baz
+            attr1:    bar
+            attr2:    123
+
+    """
+
+    # Make into a dataclass (might not even be needed at some point?)
+    cls = dataclasses.dataclass(cls, init=True, repr=False, eq=False, frozen=True)
+
+    # Make schema
+    cls.__xradio_array_schema = dataclass.xarray_dataclass_to_array_schema(cls)
+
+    # Replace __new__
+    cls.__new__ = _dataarray_new
+
+    return cls
+
+
+def is_dataarray_schema(val: typing.Any):
+    return type(val) == type and hasattr(val, "__xradio_array_schema")
+
+
 class AsDataArray:
-    """Mix-in class that provides shorthand methods."""
+    __new__ = _dataarray_new
+
+
+def _dataset_new(cls, *args, data_vars=None, coords=None, attrs=None, **kwargs):
+    # Get standard xarray parameters (data_vars, coords, attrs)
+    # Note that we only support these as keyword arguments for now
+    if data_vars is None:
+        data_vars = {}
+    if coords is None:
+        coords = {}
+    if attrs is None:
+        attrs = {}
+
+    # Get signature of __init__, map parameters and apply defaults. This
+    # will raise an exception if there are any extra parameters.
+    sig = inspect.Signature.from_callable(cls.__init__)
+    sig = sig.replace(parameters=[v for k, v in sig.parameters.items() if k != "self"])
+    mapping = sig.bind_partial(*args, **kwargs)
+
+    # Now get schema for this class and identify which of the parameters
+    # were meant to be data variables, coordinates and attributes
+    # respectively. Note that we interpret "None" as "missing" here, so
+    # setting an attribute to `None` will require passing them as
+    # attrs.
+    schema = dataclass.xarray_dataclass_to_dataset_schema(cls)
+    for coord in schema.coordinates:
+        val = _np_convert(
+            _set_parameter(coords.get(coord.name), mapping.arguments, coord), coord
+        )
+        if val is not None:
+            coords[coord.name] = val
+    for data_var in schema.data_vars:
+        val = _set_parameter(data_vars.get(data_var.name), mapping.arguments, data_var)
+
+        # Determine dimensions / convert to Variable
+        dims = None
+        if isinstance(val, xarray.Variable):
+            dims = val.dims
+        elif isinstance(val, xarray.DataArray):
+            val = val.variable
+            dims = val.dims
+        elif isinstance(val, tuple):
+            val = xarray.Variable(*val)
+            dims = val.dims
+        else:
+            # We are dealing with a plain value. Try to convert it to numpy first
+            val = _np_convert(val, data_var)
+
+            # Then identify dimensions by matching against dimensionality
+            dims = None
+            for ds in data_var.dimensions:
+                if len(ds) == len(val.shape):
+                    dims = ds
+                    break
+            if dims is None:
+                options = ["[" + ",".join(dims) + "]" for dims in data_var.dimensions]
+                raise ValueError(
+                    f"Data variable {data_var.name} shape has {len(val.shape)} dimensions,"
+                    f" expected {' or '.join(options)}!"
+                )
+
+            # Replace by variable
+            val = xarray.Variable(dims, val)
+
+        # Default coordinates used by this data variable to numpy arange. We
+        # can only do this now because we need an example to determine the
+        # intended size of the coordinate
+        for coord in schema.coordinates:
+            if coord.name in dims and coords.get(coord.name) is None:
+                dim_ix = dims.index(coord.name)
+                if dim_ix is not None and dim_ix < len(val.shape):
+                    dtype = coord.dtypes[0]
+                    coords[coord.name] = numpy.arange(val.shape[dim_ix], dtype=dtype)
+
+        if val is not None:
+            data_vars[data_var.name] = val
+
+    for attr in schema.attributes:
+        val = _set_parameter(attrs.get(attr.name), mapping.arguments, attr)
+        if val is not None:
+            attrs[attr.name] = val
+
+    # Redirect to xarray.Dataset constructor
+    instance = xarray.Dataset(data_vars, coords, attrs)
+
+    # Finally check schema
+    check.check_dataset(instance, schema).expect()
+
+    return instance
+
+
+def xarray_dataset_schema(cls):
+    """Decorator for classes representing :py:class:`xarray.Dataset` schemas.
+    The annotated class should exactly contain:
+
+    * fields annotated with :py:data:`~typing.Coord` to indicate mappings of
+      dimensions to coordinates (coordinates directly associated with dimensions
+      should have the same name as the dimension)
+    * fields annotated with :py:data:`~typing.Data`
+      to indicate data variables
+    * fields annotated with :py:data:`~typing.Attr` to declare attributes
+
+    Decorated schema classes can be used with
+    :py:func:`~xradio.schema.check.check_dataset` for checking
+    :py:class:`xarray.Dataset` objects against the schema. Furthermore, the
+    class constructor will be overwritten to generate schema-conforming
+    :py:class:`xarray.Dataset` objects.
+    """
+
+    # Make into a dataclass (might not even be needed at some point?)
+    cls = dataclasses.dataclass(cls, init=True, repr=False, eq=False, frozen=True)
+
+    # Make schema
+    cls.__xradio_dataset_schema = dataclass.xarray_dataclass_to_dataset_schema(cls)
+
+    # Replace __new__
+    cls.__new__ = _dataset_new
+
+    return cls
+
+
+def is_dataset_schema(val: typing.Any):
+    return type(val) == type and hasattr(val, "__xradio_dataset_schema")
 
 
 class AsDataset:
-    """Mix-in class that provides shorthand methods."""
+    """Mix-in class to indicate dataset data classes
+
+    Deprecated - use decorator :py:func:`xarray_dataset_schema` instead
+    """
+
+    __new__ = _dataset_new
+
+
+def _dict_new(cls, *args, **kwargs):
+    # Get signature of __init__, map parameters and apply defaults. This
+    # will raise an exception if there are any extra parameters.
+    sig = inspect.Signature.from_callable(cls.__init__)
+    sig = sig.replace(parameters=[v for k, v in sig.parameters.items() if k != "self"])
+    mapping = sig.bind_partial(*args, **kwargs)
+    mapping.apply_defaults()
+
+    # The dictionary is now simply the arguments. Note that this means that
+    # in contrast to the behaviour of AsDataset/AsDataarray, for
+    # dictionaries we actually interpret a default of "None" as setting the
+    # value in question to "None".
+    instance = mapping.arguments
+
+    # Check schema
+    check.check_dict(instance, cls).expect()
+    return instance
+
+
+def dict_schema(cls):
+    """Decorator for classes representing ``dict`` schemas, along the lines
+    of :py:func:`xarray_dataarray_schema` and :py:func:`xarray_dataset_schema`.
+
+    The annotated class can contain fields with arbitrary annotations, similar
+    to a dataclass. They can be used with
+    :py:func:`~xradio.schema.check.check_dict` for checking dictionaries
+    against the schema. Furthermore, the class constructor will be overwritten
+    to generate schema-conforming ``dict`` objects.
+    """
+
+    # Make into a dataclass (might not even be needed at some point?)
+    cls = dataclasses.dataclass(cls, init=True, repr=False, eq=False, frozen=True)
+
+    # Make schema
+    cls.__xradio_dict_schema = dataclass.xarray_dataclass_to_dict_schema(cls)
+
+    # Replace __new__
+    cls.__new__ = _dict_new
+
+    return cls
+
+
+def is_dict_schema(val: typing.Any):
+    return type(val) == type and hasattr(val, "__xradio_dict_schema")
+
+
+class AsDict:
+    """Mix-in class to indicate dictionary data classes
+
+    Deprecated - use decorator :py:func:`dict_schema` instead
+    """
+
+    __new__ = _dict_new
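
The decorators above mirror the behaviour documented for xarray_dataarray_schema: constructing a decorated class returns a schema-checked object rather than an instance of the class itself. A brief illustrative sketch of the dataset and dict variants; the class and field names are hypothetical, not schemas that ship with xradio:

    from xradio.schema import xarray_dataset_schema, dict_schema
    from xradio.schema.typing import Data, Coord, Attr
    from typing import Literal, Optional

    Time = Literal["time"]


    @xarray_dataset_schema
    class TestDataset:
        time: Coord[Time, float]
        flux: Data[Time, float]
        telescope: Attr[str] = "unknown"


    @dict_schema
    class SourceInfo:
        name: str
        code: Optional[str] = None


    # Construction yields a schema-checked xarray.Dataset (with the "time"
    # coordinate defaulted to an arange matching "flux") and a plain dict.
    xds = TestDataset(flux=[1.0, 2.0, 3.0], telescope="ALMA")
    info = SourceInfo(name="3C286")  # -> {"name": "3C286", "code": None}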