xradio 0.0.28__py3-none-any.whl → 0.0.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xradio/__init__.py +5 -4
- xradio/_utils/array.py +90 -0
- xradio/_utils/zarr/common.py +48 -3
- xradio/image/_util/zarr.py +4 -1
- xradio/schema/__init__.py +24 -6
- xradio/schema/bases.py +440 -2
- xradio/schema/check.py +96 -55
- xradio/schema/dataclass.py +123 -27
- xradio/schema/metamodel.py +21 -4
- xradio/schema/typing.py +33 -18
- xradio/vis/__init__.py +5 -2
- xradio/vis/_processing_set.py +71 -32
- xradio/vis/_vis_utils/_ms/_tables/create_field_and_source_xds.py +710 -0
- xradio/vis/_vis_utils/_ms/_tables/load.py +23 -10
- xradio/vis/_vis_utils/_ms/_tables/load_main_table.py +145 -64
- xradio/vis/_vis_utils/_ms/_tables/read.py +747 -172
- xradio/vis/_vis_utils/_ms/_tables/read_main_table.py +173 -44
- xradio/vis/_vis_utils/_ms/_tables/read_subtables.py +79 -28
- xradio/vis/_vis_utils/_ms/_tables/write.py +102 -45
- xradio/vis/_vis_utils/_ms/_tables/write_exp_api.py +127 -65
- xradio/vis/_vis_utils/_ms/chunks.py +58 -21
- xradio/vis/_vis_utils/_ms/conversion.py +582 -102
- xradio/vis/_vis_utils/_ms/descr.py +52 -20
- xradio/vis/_vis_utils/_ms/msv2_to_msv4_meta.py +72 -35
- xradio/vis/_vis_utils/_ms/msv4_infos.py +0 -59
- xradio/vis/_vis_utils/_ms/msv4_sub_xdss.py +76 -9
- xradio/vis/_vis_utils/_ms/optimised_functions.py +0 -46
- xradio/vis/_vis_utils/_ms/partition_queries.py +308 -119
- xradio/vis/_vis_utils/_ms/partitions.py +82 -25
- xradio/vis/_vis_utils/_ms/subtables.py +32 -14
- xradio/vis/_vis_utils/_utils/partition_attrs.py +30 -11
- xradio/vis/_vis_utils/_utils/xds_helper.py +136 -45
- xradio/vis/_vis_utils/_zarr/read.py +60 -22
- xradio/vis/_vis_utils/_zarr/write.py +83 -9
- xradio/vis/_vis_utils/ms.py +48 -29
- xradio/vis/_vis_utils/zarr.py +44 -20
- xradio/vis/convert_msv2_to_processing_set.py +43 -32
- xradio/vis/load_processing_set.py +38 -61
- xradio/vis/read_processing_set.py +64 -96
- xradio/vis/schema.py +687 -0
- xradio/vis/vis_io.py +75 -43
- {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/LICENSE.txt +6 -1
- {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/METADATA +10 -5
- xradio-0.0.30.dist-info/RECORD +73 -0
- {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/WHEEL +1 -1
- xradio/vis/model.py +0 -497
- xradio-0.0.28.dist-info/RECORD +0 -71
- {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/top_level.txt +0 -0
xradio/schema/check.py
CHANGED
|
@@ -5,13 +5,14 @@ import functools
|
|
|
5
5
|
|
|
6
6
|
import xarray
|
|
7
7
|
import numpy
|
|
8
|
+
from typeguard import check_type, TypeCheckError
|
|
8
9
|
|
|
9
10
|
from xradio.schema import (
|
|
10
11
|
metamodel,
|
|
12
|
+
bases,
|
|
11
13
|
xarray_dataclass_to_array_schema,
|
|
12
14
|
xarray_dataclass_to_dataset_schema,
|
|
13
|
-
|
|
14
|
-
AsDataArray,
|
|
15
|
+
xarray_dataclass_to_dict_schema,
|
|
15
16
|
)
|
|
16
17
|
|
|
17
18
|
|
|
@@ -110,6 +111,9 @@ class SchemaIssues(Exception):
|
|
|
110
111
|
issues_string = "\n * ".join(repr(issue) for issue in self.issues)
|
|
111
112
|
return f"\n * {issues_string}"
|
|
112
113
|
|
|
114
|
+
def __repr__(self):
|
|
115
|
+
return f"SchemaIssues({str(self)})"
|
|
116
|
+
|
|
113
117
|
def expect(
|
|
114
118
|
self, elem: typing.Optional[str] = None, ix: typing.Optional[str] = None
|
|
115
119
|
):
|
|
@@ -129,12 +133,14 @@ class SchemaIssues(Exception):
|
|
|
129
133
|
raise self
|
|
130
134
|
|
|
131
135
|
|
|
132
|
-
def check_array(
|
|
136
|
+
def check_array(
|
|
137
|
+
array: xarray.DataArray, schema: typing.Union[type, metamodel.ArraySchema]
|
|
138
|
+
) -> SchemaIssues:
|
|
133
139
|
"""
|
|
134
140
|
Check whether an xarray DataArray conforms to a schema
|
|
135
141
|
|
|
136
142
|
:param array: DataArray to check
|
|
137
|
-
:param schema: Schema to check against
|
|
143
|
+
:param schema: Schema to check against (possibly as :py:class:`AsDataArray` dataclass)
|
|
138
144
|
:returns: List of :py:class:`SchemaIssue`s found
|
|
139
145
|
"""
|
|
140
146
|
|
|
@@ -143,7 +149,7 @@ def check_array(array: xarray.DataArray, schema: metamodel.ArraySchema) -> Schem
|
|
|
143
149
|
raise TypeError(
|
|
144
150
|
f"check_array: Expected xarray.DataArray, but got {type(array)}!"
|
|
145
151
|
)
|
|
146
|
-
if
|
|
152
|
+
if bases.is_dataarray_schema(schema):
|
|
147
153
|
schema = xarray_dataclass_to_array_schema(schema)
|
|
148
154
|
if not isinstance(schema, metamodel.ArraySchema):
|
|
149
155
|
raise TypeError(f"check_array: Expected ArraySchema, but got {type(schema)}!")
|
|
@@ -164,13 +170,13 @@ def check_array(array: xarray.DataArray, schema: metamodel.ArraySchema) -> Schem
|
|
|
164
170
|
|
|
165
171
|
|
|
166
172
|
def check_dataset(
|
|
167
|
-
dataset: xarray.Dataset, schema: metamodel.DatasetSchema
|
|
173
|
+
dataset: xarray.Dataset, schema: typing.Union[type, metamodel.DatasetSchema]
|
|
168
174
|
) -> SchemaIssues:
|
|
169
175
|
"""
|
|
170
176
|
Check whether an xarray Dataset conforms to a schema
|
|
171
177
|
|
|
172
178
|
:param dataset: Dataset to check
|
|
173
|
-
:param schema: Schema to check against (possibly as dataclass)
|
|
179
|
+
:param schema: Schema to check against (possibly as :py:class:`AsDataset` dataclass)
|
|
174
180
|
:returns: List of :py:class:`SchemaIssue`s found
|
|
175
181
|
"""
|
|
176
182
|
|
|
@@ -179,7 +185,7 @@ def check_dataset(
|
|
|
179
185
|
raise TypeError(
|
|
180
186
|
f"check_dataset: Expected xarray.Dataset, but got {type(dataset)}!"
|
|
181
187
|
)
|
|
182
|
-
if
|
|
188
|
+
if bases.is_dataset_schema(schema):
|
|
183
189
|
schema = xarray_dataclass_to_dataset_schema(schema)
|
|
184
190
|
if not isinstance(schema, metamodel.DatasetSchema):
|
|
185
191
|
raise TypeError(
|
|
@@ -247,7 +253,7 @@ def check_dimensions(
|
|
|
247
253
|
hint_remove = [f"'{hint}'" for hint in dims_set - best]
|
|
248
254
|
hint_add = [f"'{hint}'" for hint in best - dims_set]
|
|
249
255
|
if hint_remove and hint_add:
|
|
250
|
-
message = f"Unexpected coordinates, replace {','.join(
|
|
256
|
+
message = f"Unexpected coordinates, replace {','.join(hint_remove)} by {','.join(hint_add)}?"
|
|
251
257
|
elif hint_remove:
|
|
252
258
|
message = f"Superflous coordinate {','.join(hint_remove)}?"
|
|
253
259
|
elif hint_add:
|
|
@@ -259,8 +265,8 @@ def check_dimensions(
|
|
|
259
265
|
SchemaIssue(
|
|
260
266
|
path=[("dims", None)],
|
|
261
267
|
message=message,
|
|
262
|
-
found=dims,
|
|
263
|
-
expected=expected,
|
|
268
|
+
found=list(dims),
|
|
269
|
+
expected=list(expected),
|
|
264
270
|
)
|
|
265
271
|
]
|
|
266
272
|
)
|
|
@@ -276,7 +282,13 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
|
|
|
276
282
|
"""
|
|
277
283
|
|
|
278
284
|
for exp_dtype in expected:
|
|
279
|
-
|
|
285
|
+
# If the expected dtype has no size (e.g. "U", a.k.a. a string of
|
|
286
|
+
# arbitrary length), we don't check itemsize, only kind.
|
|
287
|
+
if (
|
|
288
|
+
dtype.kind == exp_dtype.kind
|
|
289
|
+
and exp_dtype.itemsize == 0
|
|
290
|
+
or exp_dtype == dtype
|
|
291
|
+
):
|
|
280
292
|
return SchemaIssues()
|
|
281
293
|
|
|
282
294
|
# Not sure there's anything more helpful that we can do here? Any special
|
|
@@ -287,7 +299,7 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
|
|
|
287
299
|
path=[("dtype", None)],
|
|
288
300
|
message="Wrong numpy dtype",
|
|
289
301
|
found=dtype,
|
|
290
|
-
expected=expected,
|
|
302
|
+
expected=list(expected),
|
|
291
303
|
)
|
|
292
304
|
]
|
|
293
305
|
)
|
|
@@ -296,6 +308,7 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
|
|
|
296
308
|
def check_attributes(
|
|
297
309
|
attrs: typing.Dict[str, typing.Any],
|
|
298
310
|
attrs_schema: typing.List[metamodel.AttrSchemaRef],
|
|
311
|
+
attr_kind: str = "attrs",
|
|
299
312
|
) -> SchemaIssues:
|
|
300
313
|
"""
|
|
301
314
|
Check whether an attribute set conforms to a schema
|
|
@@ -307,23 +320,29 @@ def check_attributes(
|
|
|
307
320
|
|
|
308
321
|
issues = SchemaIssues()
|
|
309
322
|
for attr_schema in attrs_schema:
|
|
310
|
-
|
|
311
323
|
# Attribute missing? Note that a value of "None" is equivalent for the
|
|
312
324
|
# purpose of the check
|
|
313
325
|
val = attrs.get(attr_schema.name)
|
|
314
326
|
if val is None:
|
|
315
327
|
if not attr_schema.optional:
|
|
328
|
+
# Get options
|
|
329
|
+
if typing.get_origin(attr_schema.typ) is typing.Union:
|
|
330
|
+
options = typing.get_args(attr_schema.typ)
|
|
331
|
+
else:
|
|
332
|
+
options = [attr_schema.typ]
|
|
333
|
+
|
|
316
334
|
issues.add(
|
|
317
335
|
SchemaIssue(
|
|
318
|
-
path=[(
|
|
336
|
+
path=[(attr_kind, attr_schema.name)],
|
|
319
337
|
message=f"Required attribute {attr_schema.name} is missing!",
|
|
338
|
+
expected=options,
|
|
320
339
|
)
|
|
321
340
|
)
|
|
322
341
|
continue
|
|
323
342
|
|
|
324
343
|
# Check attribute value
|
|
325
344
|
issues += _check_value_union(val, attr_schema.typ).at_path(
|
|
326
|
-
|
|
345
|
+
attr_kind, attr_schema.name
|
|
327
346
|
)
|
|
328
347
|
|
|
329
348
|
# Extra attributes are always okay
|
|
@@ -352,7 +371,6 @@ def check_data_vars(
|
|
|
352
371
|
|
|
353
372
|
issues = SchemaIssues()
|
|
354
373
|
for data_var_schema in data_vars_schema:
|
|
355
|
-
|
|
356
374
|
# Data variable missing?
|
|
357
375
|
data_var = data_vars.get(data_var_schema.name)
|
|
358
376
|
if data_var is None:
|
|
@@ -382,6 +400,21 @@ def check_data_vars(
|
|
|
382
400
|
return issues
|
|
383
401
|
|
|
384
402
|
|
|
403
|
+
def check_dict(
|
|
404
|
+
dct: dict, schema: typing.Union[type, metamodel.DictSchema]
|
|
405
|
+
) -> SchemaIssues:
|
|
406
|
+
# Check that this is actually a dictionary
|
|
407
|
+
if not isinstance(dct, dict):
|
|
408
|
+
raise TypeError(f"check_dict: Expected dictionary, but got {type(dct)}!")
|
|
409
|
+
if bases.is_dict_schema(schema):
|
|
410
|
+
schema = xarray_dataclass_to_dict_schema(schema)
|
|
411
|
+
if not isinstance(schema, metamodel.DictSchema):
|
|
412
|
+
raise TypeError(f"check_dict: Expected DictSchema, but got {type(schema)}!")
|
|
413
|
+
|
|
414
|
+
# Check attributes
|
|
415
|
+
return check_attributes(dct, schema.attributes, attr_kind="")
|
|
416
|
+
|
|
417
|
+
|
|
385
418
|
def _check_value(val, ann):
|
|
386
419
|
"""
|
|
387
420
|
Check whether value satisfies annotation
|
|
@@ -394,46 +427,60 @@ def _check_value(val, ann):
|
|
|
394
427
|
"""
|
|
395
428
|
|
|
396
429
|
# Is supposed to be a data array?
|
|
397
|
-
if
|
|
430
|
+
if bases.is_dataarray_schema(ann):
|
|
431
|
+
# Attempt to convert dictionaries automatically
|
|
432
|
+
if isinstance(val, dict):
|
|
433
|
+
try:
|
|
434
|
+
val = xarray.DataArray.from_dict(val)
|
|
435
|
+
except ValueError as e:
|
|
436
|
+
return SchemaIssues(
|
|
437
|
+
[
|
|
438
|
+
SchemaIssue(
|
|
439
|
+
path=[], message=str(e), expected=[ann], found=type(val)
|
|
440
|
+
)
|
|
441
|
+
]
|
|
442
|
+
)
|
|
443
|
+
|
|
398
444
|
if not isinstance(val, xarray.DataArray):
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
SchemaIssue(
|
|
402
|
-
path=[],
|
|
403
|
-
message="Unexpected type",
|
|
404
|
-
expected=[xarray.DataArray],
|
|
405
|
-
found=type(val),
|
|
406
|
-
)
|
|
407
|
-
]
|
|
408
|
-
)
|
|
445
|
+
# Fall through to plain type check
|
|
446
|
+
ann = xarray.DataArray
|
|
409
447
|
else:
|
|
410
448
|
return check_array(val, ann)
|
|
411
449
|
|
|
412
450
|
# Is supposed to be a dataset?
|
|
413
|
-
if
|
|
451
|
+
if bases.is_dataset_schema(ann):
|
|
452
|
+
# Attempt to convert dictionaries automatically
|
|
453
|
+
if isinstance(val, dict):
|
|
454
|
+
try:
|
|
455
|
+
val = xarray.Dataset.from_dict(val)
|
|
456
|
+
except ValueError as e:
|
|
457
|
+
return SchemaIssues(
|
|
458
|
+
[
|
|
459
|
+
SchemaIssue(
|
|
460
|
+
path=[], message=str(t), expected=[ann], found=type(val)
|
|
461
|
+
)
|
|
462
|
+
]
|
|
463
|
+
)
|
|
414
464
|
if not isinstance(val, xarray.Dataset):
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
SchemaIssue(
|
|
418
|
-
path=[],
|
|
419
|
-
message="Unexpected type",
|
|
420
|
-
expected=[xarray.DataArray],
|
|
421
|
-
found=type(val),
|
|
422
|
-
)
|
|
423
|
-
]
|
|
424
|
-
)
|
|
465
|
+
# Fall through to plain type check
|
|
466
|
+
ann = xarray.Dataset
|
|
425
467
|
else:
|
|
426
468
|
return check_dataset(val, ann)
|
|
427
469
|
|
|
428
|
-
#
|
|
429
|
-
|
|
430
|
-
|
|
470
|
+
# Is supposed to be a dictionary?
|
|
471
|
+
if bases.is_dict_schema(ann):
|
|
472
|
+
if not isinstance(val, dict):
|
|
473
|
+
# Fall through to plain type check
|
|
474
|
+
ann = dict
|
|
475
|
+
else:
|
|
476
|
+
return check_dict(val, ann)
|
|
477
|
+
|
|
478
|
+
# Otherwise straight type check using typeguard
|
|
479
|
+
try:
|
|
480
|
+
check_type(val, ann)
|
|
481
|
+
except TypeCheckError as t:
|
|
431
482
|
return SchemaIssues(
|
|
432
|
-
[
|
|
433
|
-
SchemaIssue(
|
|
434
|
-
path=[], message="Unexpected type", expected=[ann], found=type(val)
|
|
435
|
-
)
|
|
436
|
-
]
|
|
483
|
+
[SchemaIssue(path=[], message=str(t), expected=[ann], found=type(val))]
|
|
437
484
|
)
|
|
438
485
|
|
|
439
486
|
return SchemaIssues()
|
|
@@ -471,13 +518,8 @@ def _check_value_union(val, ann):
|
|
|
471
518
|
if args_issues is None:
|
|
472
519
|
args_issues = arg_issues
|
|
473
520
|
|
|
474
|
-
#
|
|
475
|
-
elif (
|
|
476
|
-
len(args_issues) == 1
|
|
477
|
-
and len(arg_issues) == 1
|
|
478
|
-
and args_issues[0].message == arg_issues[0].message
|
|
479
|
-
):
|
|
480
|
-
|
|
521
|
+
# Crude merging of expected options (for "unexpected type")
|
|
522
|
+
elif len(args_issues) == 1 and len(arg_issues) == 1:
|
|
481
523
|
args_issues[0].expected += arg_issues[0].expected
|
|
482
524
|
|
|
483
525
|
# Return representative issues list
|
|
@@ -511,7 +553,6 @@ def schema_checked(fn, check_parameters: bool = True, check_return: bool = True)
|
|
|
511
553
|
|
|
512
554
|
@functools.wraps(fn)
|
|
513
555
|
def _check_fn(*args, **kwargs):
|
|
514
|
-
|
|
515
556
|
# Hide this function in pytest tracebacks
|
|
516
557
|
__tracebackhide__ = True
|
|
517
558
|
|
xradio/schema/dataclass.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
from typing import get_type_hints, get_args
|
|
2
|
-
from .typing import get_dims,
|
|
2
|
+
from .typing import get_dims, get_types, get_role, Role, get_annotated, is_optional
|
|
3
3
|
|
|
4
4
|
import typing
|
|
5
5
|
import inspect
|
|
6
6
|
import ast
|
|
7
7
|
import dataclasses
|
|
8
|
+
import numpy
|
|
9
|
+
import itertools
|
|
10
|
+
import textwrap
|
|
8
11
|
|
|
9
12
|
from xradio.schema.metamodel import *
|
|
10
13
|
|
|
@@ -18,8 +21,11 @@ def extract_field_docstrings(klass):
|
|
|
18
21
|
"""
|
|
19
22
|
|
|
20
23
|
# Parse class body
|
|
21
|
-
|
|
22
|
-
|
|
24
|
+
try:
|
|
25
|
+
src = inspect.getsource(klass)
|
|
26
|
+
except OSError:
|
|
27
|
+
return {}
|
|
28
|
+
module = ast.parse(textwrap.dedent(src))
|
|
23
29
|
|
|
24
30
|
# Expect module containing a class definition
|
|
25
31
|
if not isinstance(module, ast.Module) or len(module.body) != 1:
|
|
@@ -31,7 +37,6 @@ def extract_field_docstrings(klass):
|
|
|
31
37
|
# Go through body, collect docstrings
|
|
32
38
|
docstrings = {}
|
|
33
39
|
for i, assign in enumerate(cls.body):
|
|
34
|
-
|
|
35
40
|
# Handle both annotated and unannotated case
|
|
36
41
|
if isinstance(assign, ast.AnnAssign):
|
|
37
42
|
if not isinstance(assign.target, ast.Name):
|
|
@@ -78,7 +83,6 @@ def extract_xarray_dataclass(klass):
|
|
|
78
83
|
data_vars = []
|
|
79
84
|
attributes = []
|
|
80
85
|
for field in dataclasses.fields(klass):
|
|
81
|
-
|
|
82
86
|
# Get field "role" (coordinate / data variable / attribute) from its
|
|
83
87
|
# type hint
|
|
84
88
|
typ = type_hints[field.name]
|
|
@@ -111,7 +115,6 @@ def extract_xarray_dataclass(klass):
|
|
|
111
115
|
# Defined using a dataclass, i.e. Coordof/Dataof?
|
|
112
116
|
dataclass = typing.get_args(get_annotated(typ))[0]
|
|
113
117
|
if dataclasses.is_dataclass(dataclass):
|
|
114
|
-
|
|
115
118
|
# Recursively get array schema for data class
|
|
116
119
|
arr_schema = xarray_dataclass_to_array_schema(dataclass)
|
|
117
120
|
arr_schema_fields = {
|
|
@@ -129,24 +132,51 @@ def extract_xarray_dataclass(klass):
|
|
|
129
132
|
)
|
|
130
133
|
|
|
131
134
|
else:
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
135
|
+
# Get dimensions and dtypes
|
|
136
|
+
dims = get_dims(typ)
|
|
137
|
+
types = get_types(typ)
|
|
138
|
+
|
|
139
|
+
# Is types a (single) dataclass?
|
|
140
|
+
if len(types) == 1 and dataclasses.is_dataclass(types[0]):
|
|
141
|
+
# Recursively get array schema for data class
|
|
142
|
+
arr_schema = xarray_dataclass_to_array_schema(types[0])
|
|
143
|
+
|
|
144
|
+
# Prepend dimensions to array schema
|
|
145
|
+
combined_dimensions = [
|
|
146
|
+
dims1 + dims2
|
|
147
|
+
for dims1, dims2 in itertools.product(dims, arr_schema.dimensions)
|
|
148
|
+
]
|
|
149
|
+
|
|
150
|
+
# Repackage as reference
|
|
151
|
+
arr_schema_fields = {
|
|
152
|
+
f.name: getattr(arr_schema, f.name)
|
|
153
|
+
for f in dataclasses.fields(ArraySchema)
|
|
154
|
+
}
|
|
155
|
+
arr_schema_fields["dimensions"] = combined_dimensions
|
|
156
|
+
schema_ref = ArraySchemaRef(
|
|
157
|
+
name=field.name,
|
|
158
|
+
optional=is_optional(typ),
|
|
159
|
+
default=field.default,
|
|
160
|
+
docstring=field_docstrings.get(field.name),
|
|
161
|
+
**arr_schema_fields,
|
|
162
|
+
)
|
|
163
|
+
else:
|
|
164
|
+
# Assume that it's an "inline" declaration using "Coord"/"Data"
|
|
165
|
+
schema_ref = ArraySchemaRef(
|
|
166
|
+
name=field.name,
|
|
167
|
+
optional=is_optional(typ),
|
|
168
|
+
default=field.default,
|
|
169
|
+
docstring=field_docstrings.get(field.name),
|
|
170
|
+
schema_name=f"{klass.__module__}.{klass.__qualname__}.{field.name}",
|
|
171
|
+
dimensions=dims,
|
|
172
|
+
dtypes=[numpy.dtype(typ) for typ in types],
|
|
173
|
+
coordinates=[],
|
|
174
|
+
attributes=[],
|
|
175
|
+
class_docstring=None,
|
|
176
|
+
data_docstring=None,
|
|
177
|
+
)
|
|
147
178
|
|
|
148
179
|
if is_coord:
|
|
149
|
-
|
|
150
180
|
# Make sure that it is valid to use as a coordinate - i.e. we don't
|
|
151
181
|
# have "recursive" (?!) coordinate definitions
|
|
152
182
|
if not schema_ref.is_coord():
|
|
@@ -171,23 +201,34 @@ def xarray_dataclass_to_array_schema(klass):
|
|
|
171
201
|
refer to using CoordOf or DataOf
|
|
172
202
|
"""
|
|
173
203
|
|
|
204
|
+
# Cached?
|
|
205
|
+
if hasattr(klass, "__xradio_array_schema"):
|
|
206
|
+
return klass.__xradio_array_schema
|
|
207
|
+
|
|
174
208
|
# Extract from data class
|
|
175
209
|
coordinates, data_vars, attributes = extract_xarray_dataclass(klass)
|
|
176
210
|
|
|
177
211
|
# For a dataclass there must be exactly one data variable
|
|
178
|
-
# (typically called "data", but we don't check that)
|
|
179
212
|
if not data_vars:
|
|
180
213
|
raise ValueError(
|
|
181
|
-
f"Found no data declaration in (supposed) data
|
|
214
|
+
f"Found no data declaration in (supposed) data array class {klass.__name__}!"
|
|
182
215
|
)
|
|
183
216
|
if len(data_vars) > 1:
|
|
184
217
|
raise ValueError(
|
|
185
218
|
f"Found multiple data variables ({', '.join(v.name for v in data_vars)})"
|
|
186
|
-
f" in supposed data
|
|
219
|
+
f" in supposed data array class {klass.__name__}!"
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
# Check that data variable is named "data". This is important for this to
|
|
223
|
+
# match up with parameters to xarray.DataArray() later (see bases.AsArray)
|
|
224
|
+
if data_vars[0].name != "data":
|
|
225
|
+
raise ValueError(
|
|
226
|
+
f"Data variable in data array class {klass.__name__} "
|
|
227
|
+
f'must be called "data", not {data_vars[0].name}!'
|
|
187
228
|
)
|
|
188
229
|
|
|
189
230
|
# Make class
|
|
190
|
-
|
|
231
|
+
schema = ArraySchema(
|
|
191
232
|
schema_name=f"{klass.__module__}.{klass.__qualname__}",
|
|
192
233
|
dimensions=data_vars[0].dimensions,
|
|
193
234
|
dtypes=data_vars[0].dtypes,
|
|
@@ -196,6 +237,8 @@ def xarray_dataclass_to_array_schema(klass):
|
|
|
196
237
|
class_docstring=inspect.cleandoc(klass.__doc__),
|
|
197
238
|
data_docstring=data_vars[0].docstring,
|
|
198
239
|
)
|
|
240
|
+
klass.__xradio_array_schema = schema
|
|
241
|
+
return schema
|
|
199
242
|
|
|
200
243
|
|
|
201
244
|
def xarray_dataclass_to_dataset_schema(klass):
|
|
@@ -206,6 +249,10 @@ def xarray_dataclass_to_dataset_schema(klass):
|
|
|
206
249
|
refer to using CoordOf or DataOf
|
|
207
250
|
"""
|
|
208
251
|
|
|
252
|
+
# Cached?
|
|
253
|
+
if hasattr(klass, "__xradio_dataset_schema"):
|
|
254
|
+
return klass.__xradio_dataset_schema
|
|
255
|
+
|
|
209
256
|
# Extract from data class
|
|
210
257
|
coordinates, data_vars, attributes = extract_xarray_dataclass(klass)
|
|
211
258
|
|
|
@@ -242,7 +289,7 @@ def xarray_dataclass_to_dataset_schema(klass):
|
|
|
242
289
|
dimensions = [[dim for dim in all_dimensions if dim in dims] for dims in dimensions]
|
|
243
290
|
|
|
244
291
|
# Make class
|
|
245
|
-
|
|
292
|
+
schema = DatasetSchema(
|
|
246
293
|
schema_name=f"{klass.__module__}.{klass.__qualname__}",
|
|
247
294
|
dimensions=dimensions,
|
|
248
295
|
coordinates=coordinates,
|
|
@@ -250,3 +297,52 @@ def xarray_dataclass_to_dataset_schema(klass):
|
|
|
250
297
|
attributes=attributes,
|
|
251
298
|
class_docstring=inspect.cleandoc(klass.__doc__),
|
|
252
299
|
)
|
|
300
|
+
klass.__xradio_dataset_schema = schema
|
|
301
|
+
return schema
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def xarray_dataclass_to_dict_schema(klass):
|
|
305
|
+
"""
|
|
306
|
+
Convert an xarray-dataclass style schema dataclass to a DictSchema
|
|
307
|
+
|
|
308
|
+
This should work on any class annotated with :py:func:`~xradio.schema.bases.dict_schema`
|
|
309
|
+
"""
|
|
310
|
+
|
|
311
|
+
# Cached?
|
|
312
|
+
if hasattr(klass, "__xradio_dict_schema"):
|
|
313
|
+
return klass.__xradio_dict_schema
|
|
314
|
+
|
|
315
|
+
# Get docstrings and type hints
|
|
316
|
+
field_docstrings = extract_field_docstrings(klass)
|
|
317
|
+
type_hints = get_type_hints(klass, include_extras=True)
|
|
318
|
+
attributes = []
|
|
319
|
+
for field in dataclasses.fields(klass):
|
|
320
|
+
typ = type_hints[field.name]
|
|
321
|
+
|
|
322
|
+
# Handle optional value: Strip "None" from the types
|
|
323
|
+
optional = is_optional(typ)
|
|
324
|
+
if optional:
|
|
325
|
+
typs = [typ for typ in get_args(typ) if typ is not None.__class__]
|
|
326
|
+
if len(typs) == 1:
|
|
327
|
+
typ = typs[0]
|
|
328
|
+
else:
|
|
329
|
+
typ = typing.Union.__getitem__[tuple(typs)]
|
|
330
|
+
|
|
331
|
+
attributes.append(
|
|
332
|
+
AttrSchemaRef(
|
|
333
|
+
name=field.name,
|
|
334
|
+
typ=typ,
|
|
335
|
+
optional=optional,
|
|
336
|
+
default=field.default,
|
|
337
|
+
docstring=field_docstrings.get(field.name),
|
|
338
|
+
)
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
# Return
|
|
342
|
+
schema = DictSchema(
|
|
343
|
+
schema_name=f"{klass.__module__}.{klass.__qualname__}",
|
|
344
|
+
attributes=attributes,
|
|
345
|
+
class_docstring=inspect.cleandoc(klass.__doc__),
|
|
346
|
+
)
|
|
347
|
+
klass.__xradio_dict_schema = schema
|
|
348
|
+
return schema
|
xradio/schema/metamodel.py
CHANGED
|
@@ -6,10 +6,11 @@ __all__ = [
|
|
|
6
6
|
"ArraySchema",
|
|
7
7
|
"ArraySchemaRef",
|
|
8
8
|
"DatasetSchema",
|
|
9
|
+
"DictSchema",
|
|
9
10
|
]
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
@dataclass
|
|
13
|
+
@dataclass(frozen=True)
|
|
13
14
|
class AttrSchemaRef:
|
|
14
15
|
"""
|
|
15
16
|
Schema information about an attribute as referenced from an array or
|
|
@@ -34,7 +35,7 @@ class AttrSchemaRef:
|
|
|
34
35
|
"""Documentation string of attribute reference"""
|
|
35
36
|
|
|
36
37
|
|
|
37
|
-
@dataclass
|
|
38
|
+
@dataclass(frozen=True)
|
|
38
39
|
class ArraySchema:
|
|
39
40
|
"""
|
|
40
41
|
Schema for xarray data array
|
|
@@ -83,7 +84,7 @@ class ArraySchema:
|
|
|
83
84
|
return req_dims
|
|
84
85
|
|
|
85
86
|
|
|
86
|
-
@dataclass
|
|
87
|
+
@dataclass(frozen=True)
|
|
87
88
|
class ArraySchemaRef(ArraySchema):
|
|
88
89
|
"""
|
|
89
90
|
Schema for xarray data array as referenced from a dataset schema
|
|
@@ -102,7 +103,7 @@ class ArraySchemaRef(ArraySchema):
|
|
|
102
103
|
"""Documentation string of array reference"""
|
|
103
104
|
|
|
104
105
|
|
|
105
|
-
@dataclass
|
|
106
|
+
@dataclass(frozen=True)
|
|
106
107
|
class DatasetSchema:
|
|
107
108
|
"""
|
|
108
109
|
Schema for an xarray dataset
|
|
@@ -122,3 +123,19 @@ class DatasetSchema:
|
|
|
122
123
|
|
|
123
124
|
class_docstring: typing.Optional[str]
|
|
124
125
|
"""Documentation string of class"""
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass(frozen=True)
|
|
129
|
+
class DictSchema:
|
|
130
|
+
"""
|
|
131
|
+
Schema for a simple dictionary
|
|
132
|
+
"""
|
|
133
|
+
|
|
134
|
+
schema_name: str
|
|
135
|
+
"""(Class) name of the schema"""
|
|
136
|
+
|
|
137
|
+
attributes: [AttrSchemaRef]
|
|
138
|
+
"""List of attributes"""
|
|
139
|
+
|
|
140
|
+
class_docstring: typing.Optional[str]
|
|
141
|
+
"""Documentation string of class"""
|
xradio/schema/typing.py
CHANGED
|
@@ -29,16 +29,24 @@ from typing import (
|
|
|
29
29
|
Protocol,
|
|
30
30
|
)
|
|
31
31
|
|
|
32
|
+
from typing import Union
|
|
33
|
+
|
|
32
34
|
try:
|
|
35
|
+
# Python 3.10 forward: TypeAlias, ParamSpec are standard, and there is the
|
|
36
|
+
# "a | b" UnionType alternative to "Union[a,b]"
|
|
33
37
|
from typing import TypeAlias, ParamSpec
|
|
34
38
|
from types import UnionType
|
|
39
|
+
|
|
40
|
+
HAVE_UNIONTYPE = True
|
|
35
41
|
except ImportError:
|
|
36
|
-
#
|
|
42
|
+
# Python 3.9: Get TypeAlias, ParamSpec from typing_extensions, no support
|
|
43
|
+
# for "a | b"
|
|
37
44
|
from typing_extensions import (
|
|
38
45
|
TypeAlias,
|
|
39
46
|
ParamSpec,
|
|
40
47
|
)
|
|
41
|
-
|
|
48
|
+
|
|
49
|
+
HAVE_UNIONTYPE = False
|
|
42
50
|
import numpy as np
|
|
43
51
|
from itertools import chain
|
|
44
52
|
from enum import Enum
|
|
@@ -295,7 +303,9 @@ def get_dims(tp: Any) -> List[Dims]:
|
|
|
295
303
|
raise TypeError(f"Could not find any dims in {tp!r}.")
|
|
296
304
|
|
|
297
305
|
# List of allowed dims (might just be one)
|
|
298
|
-
if get_origin(dims) is
|
|
306
|
+
if get_origin(dims) is Union:
|
|
307
|
+
dims_in = get_args(dims)
|
|
308
|
+
elif HAVE_UNIONTYPE and get_origin(dims) is UnionType:
|
|
299
309
|
dims_in = get_args(dims)
|
|
300
310
|
else:
|
|
301
311
|
dims_in = [dims]
|
|
@@ -326,35 +336,40 @@ def get_dims(tp: Any) -> List[Dims]:
|
|
|
326
336
|
return dims_out
|
|
327
337
|
|
|
328
338
|
|
|
329
|
-
def
|
|
330
|
-
"""Extract
|
|
339
|
+
def get_types(tp: Any) -> List[AnyDType]:
|
|
340
|
+
"""Extract data types from type annotation
|
|
341
|
+
|
|
342
|
+
E.g. Coord[..., Type1 | Type2 | ...] or Data[..., Type1 | Type2 | ...]
|
|
343
|
+
|
|
344
|
+
"""
|
|
331
345
|
try:
|
|
332
|
-
|
|
346
|
+
typ = get_args(get_args(get_annotated(tp))[1])[0]
|
|
333
347
|
except TypeError:
|
|
334
348
|
raise TypeError(f"Could not find any dtype in {tp!r}.")
|
|
335
349
|
|
|
336
350
|
# List of allowed dtypes (might just be one)
|
|
337
|
-
if get_origin(
|
|
338
|
-
|
|
351
|
+
if get_origin(typ) is Union:
|
|
352
|
+
types_in = get_args(typ)
|
|
353
|
+
elif HAVE_UNIONTYPE and get_origin(typ) is UnionType:
|
|
354
|
+
types_in = get_args(typ)
|
|
339
355
|
else:
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
dtypes_out = []
|
|
343
|
-
for dt in dtypes_in:
|
|
356
|
+
types_in = [typ]
|
|
344
357
|
|
|
358
|
+
types_out = []
|
|
359
|
+
for dt in types_in:
|
|
345
360
|
# Handle case that we want to allow "Any"
|
|
346
361
|
if dt is Any or dt is type(None):
|
|
347
|
-
|
|
362
|
+
types_out.append(None)
|
|
348
363
|
continue
|
|
349
364
|
|
|
350
|
-
# Allow specifying
|
|
365
|
+
# Allow specifying type as literal (e.g. string)
|
|
351
366
|
elif get_origin(dt) is Literal:
|
|
352
|
-
|
|
367
|
+
dt = get_args(dt)[0]
|
|
353
368
|
|
|
354
|
-
#
|
|
355
|
-
|
|
369
|
+
# Return type
|
|
370
|
+
types_out.append(dt)
|
|
356
371
|
|
|
357
|
-
return
|
|
372
|
+
return types_out
|
|
358
373
|
|
|
359
374
|
|
|
360
375
|
def get_name(tp: Any, default: Hashable = None) -> Hashable:
|