xradio 0.0.28__py3-none-any.whl → 0.0.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. xradio/__init__.py +5 -4
  2. xradio/_utils/array.py +90 -0
  3. xradio/_utils/zarr/common.py +48 -3
  4. xradio/image/_util/zarr.py +4 -1
  5. xradio/schema/__init__.py +24 -6
  6. xradio/schema/bases.py +440 -2
  7. xradio/schema/check.py +96 -55
  8. xradio/schema/dataclass.py +123 -27
  9. xradio/schema/metamodel.py +21 -4
  10. xradio/schema/typing.py +33 -18
  11. xradio/vis/__init__.py +5 -2
  12. xradio/vis/_processing_set.py +71 -32
  13. xradio/vis/_vis_utils/_ms/_tables/create_field_and_source_xds.py +710 -0
  14. xradio/vis/_vis_utils/_ms/_tables/load.py +23 -10
  15. xradio/vis/_vis_utils/_ms/_tables/load_main_table.py +145 -64
  16. xradio/vis/_vis_utils/_ms/_tables/read.py +747 -172
  17. xradio/vis/_vis_utils/_ms/_tables/read_main_table.py +173 -44
  18. xradio/vis/_vis_utils/_ms/_tables/read_subtables.py +79 -28
  19. xradio/vis/_vis_utils/_ms/_tables/write.py +102 -45
  20. xradio/vis/_vis_utils/_ms/_tables/write_exp_api.py +127 -65
  21. xradio/vis/_vis_utils/_ms/chunks.py +58 -21
  22. xradio/vis/_vis_utils/_ms/conversion.py +582 -102
  23. xradio/vis/_vis_utils/_ms/descr.py +52 -20
  24. xradio/vis/_vis_utils/_ms/msv2_to_msv4_meta.py +72 -35
  25. xradio/vis/_vis_utils/_ms/msv4_infos.py +0 -59
  26. xradio/vis/_vis_utils/_ms/msv4_sub_xdss.py +76 -9
  27. xradio/vis/_vis_utils/_ms/optimised_functions.py +0 -46
  28. xradio/vis/_vis_utils/_ms/partition_queries.py +308 -119
  29. xradio/vis/_vis_utils/_ms/partitions.py +82 -25
  30. xradio/vis/_vis_utils/_ms/subtables.py +32 -14
  31. xradio/vis/_vis_utils/_utils/partition_attrs.py +30 -11
  32. xradio/vis/_vis_utils/_utils/xds_helper.py +136 -45
  33. xradio/vis/_vis_utils/_zarr/read.py +60 -22
  34. xradio/vis/_vis_utils/_zarr/write.py +83 -9
  35. xradio/vis/_vis_utils/ms.py +48 -29
  36. xradio/vis/_vis_utils/zarr.py +44 -20
  37. xradio/vis/convert_msv2_to_processing_set.py +43 -32
  38. xradio/vis/load_processing_set.py +38 -61
  39. xradio/vis/read_processing_set.py +64 -96
  40. xradio/vis/schema.py +687 -0
  41. xradio/vis/vis_io.py +75 -43
  42. {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/LICENSE.txt +6 -1
  43. {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/METADATA +10 -5
  44. xradio-0.0.30.dist-info/RECORD +73 -0
  45. {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/WHEEL +1 -1
  46. xradio/vis/model.py +0 -497
  47. xradio-0.0.28.dist-info/RECORD +0 -71
  48. {xradio-0.0.28.dist-info → xradio-0.0.30.dist-info}/top_level.txt +0 -0
xradio/schema/check.py CHANGED
@@ -5,13 +5,14 @@ import functools
5
5
 
6
6
  import xarray
7
7
  import numpy
8
+ from typeguard import check_type, TypeCheckError
8
9
 
9
10
  from xradio.schema import (
10
11
  metamodel,
12
+ bases,
11
13
  xarray_dataclass_to_array_schema,
12
14
  xarray_dataclass_to_dataset_schema,
13
- AsDataset,
14
- AsDataArray,
15
+ xarray_dataclass_to_dict_schema,
15
16
  )
16
17
 
17
18
 
@@ -110,6 +111,9 @@ class SchemaIssues(Exception):
110
111
  issues_string = "\n * ".join(repr(issue) for issue in self.issues)
111
112
  return f"\n * {issues_string}"
112
113
 
114
+ def __repr__(self):
115
+ return f"SchemaIssues({str(self)})"
116
+
113
117
  def expect(
114
118
  self, elem: typing.Optional[str] = None, ix: typing.Optional[str] = None
115
119
  ):
@@ -129,12 +133,14 @@ class SchemaIssues(Exception):
129
133
  raise self
130
134
 
131
135
 
132
- def check_array(array: xarray.DataArray, schema: metamodel.ArraySchema) -> SchemaIssues:
136
+ def check_array(
137
+ array: xarray.DataArray, schema: typing.Union[type, metamodel.ArraySchema]
138
+ ) -> SchemaIssues:
133
139
  """
134
140
  Check whether an xarray DataArray conforms to a schema
135
141
 
136
142
  :param array: DataArray to check
137
- :param schema: Schema to check against
143
+ :param schema: Schema to check against (possibly as :py:class:`AsDataset`)
138
144
  :returns: List of :py:class:`SchemaIssue`s found
139
145
  """
140
146
 
@@ -143,7 +149,7 @@ def check_array(array: xarray.DataArray, schema: metamodel.ArraySchema) -> Schem
143
149
  raise TypeError(
144
150
  f"check_array: Expected xarray.DataArray, but got {type(array)}!"
145
151
  )
146
- if type(schema) == type and issubclass(schema, AsDataArray):
152
+ if bases.is_dataarray_schema(schema):
147
153
  schema = xarray_dataclass_to_array_schema(schema)
148
154
  if not isinstance(schema, metamodel.ArraySchema):
149
155
  raise TypeError(f"check_array: Expected ArraySchema, but got {type(schema)}!")
@@ -164,13 +170,13 @@ def check_array(array: xarray.DataArray, schema: metamodel.ArraySchema) -> Schem
164
170
 
165
171
 
166
172
  def check_dataset(
167
- dataset: xarray.Dataset, schema: metamodel.DatasetSchema
173
+ dataset: xarray.Dataset, schema: typing.Union[type, metamodel.DatasetSchema]
168
174
  ) -> SchemaIssues:
169
175
  """
170
176
  Check whether an xarray DataArray conforms to a schema
171
177
 
172
178
  :param array: DataArray to check
173
- :param schema: Schema to check against (possibly as dataclass)
179
+ :param schema: Schema to check against (possibly as :py:class:`AsDataArray` dataclass)
174
180
  :returns: List of :py:class:`SchemaIssue`s found
175
181
  """
176
182
 
@@ -179,7 +185,7 @@ def check_dataset(
179
185
  raise TypeError(
180
186
  f"check_dataset: Expected xarray.Dataset, but got {type(dataset)}!"
181
187
  )
182
- if isinstance(schema, AsDataset):
188
+ if bases.is_dataset_schema(schema):
183
189
  schema = xarray_dataclass_to_dataset_schema(schema)
184
190
  if not isinstance(schema, metamodel.DatasetSchema):
185
191
  raise TypeError(
@@ -247,7 +253,7 @@ def check_dimensions(
247
253
  hint_remove = [f"'{hint}'" for hint in dims_set - best]
248
254
  hint_add = [f"'{hint}'" for hint in best - dims_set]
249
255
  if hint_remove and hint_add:
250
- message = f"Unexpected coordinates, replace {','.join(hint_add)} by {','.join(hint_remove)}?"
256
+ message = f"Unexpected coordinates, replace {','.join(hint_remove)} by {','.join(hint_add)}?"
251
257
  elif hint_remove:
252
258
  message = f"Superflous coordinate {','.join(hint_remove)}?"
253
259
  elif hint_add:
@@ -259,8 +265,8 @@ def check_dimensions(
259
265
  SchemaIssue(
260
266
  path=[("dims", None)],
261
267
  message=message,
262
- found=dims,
263
- expected=expected,
268
+ found=list(dims),
269
+ expected=list(expected),
264
270
  )
265
271
  ]
266
272
  )
@@ -276,7 +282,13 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
276
282
  """
277
283
 
278
284
  for exp_dtype in expected:
279
- if dtype == exp_dtype:
285
+ # If the expected dtype has no size (e.g. "U", a.k.a. a string of
286
+ # arbitrary length), we don't check itemsize, only kind.
287
+ if (
288
+ dtype.kind == exp_dtype.kind
289
+ and exp_dtype.itemsize == 0
290
+ or exp_dtype == dtype
291
+ ):
280
292
  return SchemaIssues()
281
293
 
282
294
  # Not sure there's anything more helpful that we can do here? Any special
@@ -287,7 +299,7 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
287
299
  path=[("dtype", None)],
288
300
  message="Wrong numpy dtype",
289
301
  found=dtype,
290
- expected=expected,
302
+ expected=list(expected),
291
303
  )
292
304
  ]
293
305
  )
@@ -296,6 +308,7 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
296
308
  def check_attributes(
297
309
  attrs: typing.Dict[str, typing.Any],
298
310
  attrs_schema: typing.List[metamodel.AttrSchemaRef],
311
+ attr_kind: str = "attrs",
299
312
  ) -> SchemaIssues:
300
313
  """
301
314
  Check whether an attribute set conforms to a schema
@@ -307,23 +320,29 @@ def check_attributes(
307
320
 
308
321
  issues = SchemaIssues()
309
322
  for attr_schema in attrs_schema:
310
-
311
323
  # Attribute missing? Note that a value of "None" is equivalent for the
312
324
  # purpose of the check
313
325
  val = attrs.get(attr_schema.name)
314
326
  if val is None:
315
327
  if not attr_schema.optional:
328
+ # Get options
329
+ if typing.get_origin(attr_schema.typ) is typing.Union:
330
+ options = typing.get_args(attr_schema.typ)
331
+ else:
332
+ options = [attr_schema.typ]
333
+
316
334
  issues.add(
317
335
  SchemaIssue(
318
- path=[("attrs", attr_schema.name)],
336
+ path=[(attr_kind, attr_schema.name)],
319
337
  message=f"Required attribute {attr_schema.name} is missing!",
338
+ expected=options,
320
339
  )
321
340
  )
322
341
  continue
323
342
 
324
343
  # Check attribute value
325
344
  issues += _check_value_union(val, attr_schema.typ).at_path(
326
- "attrs", attr_schema.name
345
+ attr_kind, attr_schema.name
327
346
  )
328
347
 
329
348
  # Extra attributes are always okay
@@ -352,7 +371,6 @@ def check_data_vars(
352
371
 
353
372
  issues = SchemaIssues()
354
373
  for data_var_schema in data_vars_schema:
355
-
356
374
  # Data variable missing?
357
375
  data_var = data_vars.get(data_var_schema.name)
358
376
  if data_var is None:
@@ -382,6 +400,21 @@ def check_data_vars(
382
400
  return issues
383
401
 
384
402
 
403
+ def check_dict(
404
+ dct: dict, schema: typing.Union[type, metamodel.DictSchema]
405
+ ) -> SchemaIssues:
406
+ # Check that this is actually a dictionary
407
+ if not isinstance(dct, dict):
408
+ raise TypeError(f"check_dict: Expected dictionary, but got {type(dct)}!")
409
+ if bases.is_dict_schema(schema):
410
+ schema = xarray_dataclass_to_dict_schema(schema)
411
+ if not isinstance(schema, metamodel.DictSchema):
412
+ raise TypeError(f"check_dict: Expected DictSchema, but got {type(schema)}!")
413
+
414
+ # Check attributes
415
+ return check_attributes(dct, schema.attributes, attr_kind="")
416
+
417
+
385
418
  def _check_value(val, ann):
386
419
  """
387
420
  Check whether value satisfies annotation
@@ -394,46 +427,60 @@ def _check_value(val, ann):
394
427
  """
395
428
 
396
429
  # Is supposed to be a data array?
397
- if type(ann) == type and issubclass(ann, AsDataArray):
430
+ if bases.is_dataarray_schema(ann):
431
+ # Attempt to convert dictionaries automatically
432
+ if isinstance(val, dict):
433
+ try:
434
+ val = xarray.DataArray.from_dict(val)
435
+ except ValueError as e:
436
+ return SchemaIssues(
437
+ [
438
+ SchemaIssue(
439
+ path=[], message=str(e), expected=[ann], found=type(val)
440
+ )
441
+ ]
442
+ )
443
+
398
444
  if not isinstance(val, xarray.DataArray):
399
- return SchemaIssues(
400
- [
401
- SchemaIssue(
402
- path=[],
403
- message="Unexpected type",
404
- expected=[xarray.DataArray],
405
- found=type(val),
406
- )
407
- ]
408
- )
445
+ # Fall through to plain type check
446
+ ann = xarray.DataArray
409
447
  else:
410
448
  return check_array(val, ann)
411
449
 
412
450
  # Is supposed to be a dataset?
413
- if type(ann) == type and issubclass(ann, AsDataset):
451
+ if bases.is_dataset_schema(ann):
452
+ # Attempt to convert dictionaries automatically
453
+ if isinstance(val, dict):
454
+ try:
455
+ val = xarray.Dataset.from_dict(val)
456
+ except ValueError as e:
457
+ return SchemaIssues(
458
+ [
459
+ SchemaIssue(
460
+ path=[], message=str(e), expected=[ann], found=type(val)
461
+ )
462
+ ]
463
+ )
414
464
  if not isinstance(val, xarray.Dataset):
415
- return SchemaIssues(
416
- [
417
- SchemaIssue(
418
- path=[],
419
- message="Unexpected type",
420
- expected=[xarray.DataArray],
421
- found=type(val),
422
- )
423
- ]
424
- )
465
+ # Fall through to plain type check
466
+ ann = xarray.Dataset
425
467
  else:
426
468
  return check_dataset(val, ann)
427
469
 
428
- # Otherwise straight type check (TODO - be more fancy, possibly by
429
- # importing from Typeguard module? Don't want to overdo it...)
430
- if not isinstance(val, ann):
470
+ # Is supposed to be a dictionary?
471
+ if bases.is_dict_schema(ann):
472
+ if not isinstance(val, dict):
473
+ # Fall through to plain type check
474
+ ann = dict
475
+ else:
476
+ return check_dict(val, ann)
477
+
478
+ # Otherwise straight type check using typeguard
479
+ try:
480
+ check_type(val, ann)
481
+ except TypeCheckError as t:
431
482
  return SchemaIssues(
432
- [
433
- SchemaIssue(
434
- path=[], message="Unexpected type", expected=[ann], found=type(val)
435
- )
436
- ]
483
+ [SchemaIssue(path=[], message=str(t), expected=[ann], found=type(val))]
437
484
  )
438
485
 
439
486
  return SchemaIssues()
@@ -471,13 +518,8 @@ def _check_value_union(val, ann):
471
518
  if args_issues is None:
472
519
  args_issues = arg_issues
473
520
 
474
- # Fancy merging of expected options (for "unexpected type")
475
- elif (
476
- len(args_issues) == 1
477
- and len(arg_issues) == 1
478
- and args_issues[0].message == arg_issues[0].message
479
- ):
480
-
521
+ # Crude merging of expected options (for "unexpected type")
522
+ elif len(args_issues) == 1 and len(arg_issues) == 1:
481
523
  args_issues[0].expected += arg_issues[0].expected
482
524
 
483
525
  # Return representative issues list
@@ -511,7 +553,6 @@ def schema_checked(fn, check_parameters: bool = True, check_return: bool = True)
511
553
 
512
554
  @functools.wraps(fn)
513
555
  def _check_fn(*args, **kwargs):
514
-
515
556
  # Hide this function in pytest tracebacks
516
557
  __tracebackhide__ = True
517
558
 
@@ -1,10 +1,13 @@
1
1
  from typing import get_type_hints, get_args
2
- from .typing import get_dims, get_dtypes, get_role, Role, get_annotated, is_optional
2
+ from .typing import get_dims, get_types, get_role, Role, get_annotated, is_optional
3
3
 
4
4
  import typing
5
5
  import inspect
6
6
  import ast
7
7
  import dataclasses
8
+ import numpy
9
+ import itertools
10
+ import textwrap
8
11
 
9
12
  from xradio.schema.metamodel import *
10
13
 
@@ -18,8 +21,11 @@ def extract_field_docstrings(klass):
18
21
  """
19
22
 
20
23
  # Parse class body
21
- src = inspect.getsource(klass)
22
- module = ast.parse(src)
24
+ try:
25
+ src = inspect.getsource(klass)
26
+ except OSError:
27
+ return {}
28
+ module = ast.parse(textwrap.dedent(src))
23
29
 
24
30
  # Expect module containing a class definition
25
31
  if not isinstance(module, ast.Module) or len(module.body) != 1:
@@ -31,7 +37,6 @@ def extract_field_docstrings(klass):
31
37
  # Go through body, collect docstrings
32
38
  docstrings = {}
33
39
  for i, assign in enumerate(cls.body):
34
-
35
40
  # Handle both annotated and unannotated case
36
41
  if isinstance(assign, ast.AnnAssign):
37
42
  if not isinstance(assign.target, ast.Name):
@@ -78,7 +83,6 @@ def extract_xarray_dataclass(klass):
78
83
  data_vars = []
79
84
  attributes = []
80
85
  for field in dataclasses.fields(klass):
81
-
82
86
  # Get field "role" (coordinate / data variable / attribute) from its
83
87
  # type hint
84
88
  typ = type_hints[field.name]
@@ -111,7 +115,6 @@ def extract_xarray_dataclass(klass):
111
115
  # Defined using a dataclass, i.e. Coordof/Dataof?
112
116
  dataclass = typing.get_args(get_annotated(typ))[0]
113
117
  if dataclasses.is_dataclass(dataclass):
114
-
115
118
  # Recursively get array schema for data class
116
119
  arr_schema = xarray_dataclass_to_array_schema(dataclass)
117
120
  arr_schema_fields = {
@@ -129,24 +132,51 @@ def extract_xarray_dataclass(klass):
129
132
  )
130
133
 
131
134
  else:
132
-
133
- # Assume that it's an "inline" declaration using "Coord"/"Data"
134
- schema_ref = ArraySchemaRef(
135
- name=field.name,
136
- optional=is_optional(typ),
137
- default=field.default,
138
- docstring=field_docstrings.get(field.name),
139
- schema_name=f"{klass.__module__}.{klass.__qualname__}.{field.name}",
140
- dimensions=get_dims(typ),
141
- dtypes=get_dtypes(typ),
142
- coordinates=[],
143
- attributes=[],
144
- class_docstring=None,
145
- data_docstring=None,
146
- )
135
+ # Get dimensions and dtypes
136
+ dims = get_dims(typ)
137
+ types = get_types(typ)
138
+
139
+ # Is types a (single) dataclass?
140
+ if len(types) == 1 and dataclasses.is_dataclass(types[0]):
141
+ # Recursively get array schema for data class
142
+ arr_schema = xarray_dataclass_to_array_schema(types[0])
143
+
144
+ # Prepend dimensions to array schema
145
+ combined_dimensions = [
146
+ dims1 + dims2
147
+ for dims1, dims2 in itertools.product(dims, arr_schema.dimensions)
148
+ ]
149
+
150
+ # Repackage as reference
151
+ arr_schema_fields = {
152
+ f.name: getattr(arr_schema, f.name)
153
+ for f in dataclasses.fields(ArraySchema)
154
+ }
155
+ arr_schema_fields["dimensions"] = combined_dimensions
156
+ schema_ref = ArraySchemaRef(
157
+ name=field.name,
158
+ optional=is_optional(typ),
159
+ default=field.default,
160
+ docstring=field_docstrings.get(field.name),
161
+ **arr_schema_fields,
162
+ )
163
+ else:
164
+ # Assume that it's an "inline" declaration using "Coord"/"Data"
165
+ schema_ref = ArraySchemaRef(
166
+ name=field.name,
167
+ optional=is_optional(typ),
168
+ default=field.default,
169
+ docstring=field_docstrings.get(field.name),
170
+ schema_name=f"{klass.__module__}.{klass.__qualname__}.{field.name}",
171
+ dimensions=dims,
172
+ dtypes=[numpy.dtype(typ) for typ in types],
173
+ coordinates=[],
174
+ attributes=[],
175
+ class_docstring=None,
176
+ data_docstring=None,
177
+ )
147
178
 
148
179
  if is_coord:
149
-
150
180
  # Make sure that it is valid to use as a coordinate - i.e. we don't
151
181
  # have "recursive" (?!) coordinate definitions
152
182
  if not schema_ref.is_coord():
@@ -171,23 +201,34 @@ def xarray_dataclass_to_array_schema(klass):
171
201
  refer to using CoordOf or DataOf
172
202
  """
173
203
 
204
+ # Cached?
205
+ if hasattr(klass, "__xradio_array_schema"):
206
+ return klass.__xradio_array_schema
207
+
174
208
  # Extract from data class
175
209
  coordinates, data_vars, attributes = extract_xarray_dataclass(klass)
176
210
 
177
211
  # For a dataclass there must be exactly one data variable
178
- # (typically called "data", but we don't check that)
179
212
  if not data_vars:
180
213
  raise ValueError(
181
- f"Found no data declaration in (supposed) data darray class {klass.__name__}!"
214
+ f"Found no data declaration in (supposed) data array class {klass.__name__}!"
182
215
  )
183
216
  if len(data_vars) > 1:
184
217
  raise ValueError(
185
218
  f"Found multiple data variables ({', '.join(v.name for v in data_vars)})"
186
- f" in supposed data darray class {klass.__name__}!"
219
+ f" in supposed data array class {klass.__name__}!"
220
+ )
221
+
222
+ # Check that data variable is named "data". This is important for this to
223
+ # match up with parameters to xarray.DataArray() later (see bases.AsArray)
224
+ if data_vars[0].name != "data":
225
+ raise ValueError(
226
+ f"Data variable in data array class {klass.__name__} "
227
+ f'must be called "data", not {data_vars[0].name}!'
187
228
  )
188
229
 
189
230
  # Make class
190
- return ArraySchema(
231
+ schema = ArraySchema(
191
232
  schema_name=f"{klass.__module__}.{klass.__qualname__}",
192
233
  dimensions=data_vars[0].dimensions,
193
234
  dtypes=data_vars[0].dtypes,
@@ -196,6 +237,8 @@ def xarray_dataclass_to_array_schema(klass):
196
237
  class_docstring=inspect.cleandoc(klass.__doc__),
197
238
  data_docstring=data_vars[0].docstring,
198
239
  )
240
+ klass.__xradio_array_schema = schema
241
+ return schema
199
242
 
200
243
 
201
244
  def xarray_dataclass_to_dataset_schema(klass):
@@ -206,6 +249,10 @@ def xarray_dataclass_to_dataset_schema(klass):
206
249
  refer to using CoordOf or DataOf
207
250
  """
208
251
 
252
+ # Cached?
253
+ if hasattr(klass, "__xradio_dataset_schema"):
254
+ return klass.__xradio_dataset_schema
255
+
209
256
  # Extract from data class
210
257
  coordinates, data_vars, attributes = extract_xarray_dataclass(klass)
211
258
 
@@ -242,7 +289,7 @@ def xarray_dataclass_to_dataset_schema(klass):
242
289
  dimensions = [[dim for dim in all_dimensions if dim in dims] for dims in dimensions]
243
290
 
244
291
  # Make class
245
- return DatasetSchema(
292
+ schema = DatasetSchema(
246
293
  schema_name=f"{klass.__module__}.{klass.__qualname__}",
247
294
  dimensions=dimensions,
248
295
  coordinates=coordinates,
@@ -250,3 +297,52 @@ def xarray_dataclass_to_dataset_schema(klass):
250
297
  attributes=attributes,
251
298
  class_docstring=inspect.cleandoc(klass.__doc__),
252
299
  )
300
+ klass.__xradio_dataset_schema = schema
301
+ return schema
302
+
303
+
304
+ def xarray_dataclass_to_dict_schema(klass):
305
+ """
306
+ Convert an xarray-dataclass style schema dataclass to an DictSchema
307
+
308
+ This should work on any class annotated with :py:func:`~xradio.schema.bases.dict_schema`
309
+ """
310
+
311
+ # Cached?
312
+ if hasattr(klass, "__xradio_dict_schema"):
313
+ return klass.__xradio_dict_schema
314
+
315
+ # Get docstrings and type hints
316
+ field_docstrings = extract_field_docstrings(klass)
317
+ type_hints = get_type_hints(klass, include_extras=True)
318
+ attributes = []
319
+ for field in dataclasses.fields(klass):
320
+ typ = type_hints[field.name]
321
+
322
+ # Handle optional value: Strip "None" from the types
323
+ optional = is_optional(typ)
324
+ if optional:
325
+ typs = [typ for typ in get_args(typ) if typ is not None.__class__]
326
+ if len(typs) == 1:
327
+ typ = typs[0]
328
+ else:
329
+ typ = typing.Union[tuple(typs)]
330
+
331
+ attributes.append(
332
+ AttrSchemaRef(
333
+ name=field.name,
334
+ typ=typ,
335
+ optional=optional,
336
+ default=field.default,
337
+ docstring=field_docstrings.get(field.name),
338
+ )
339
+ )
340
+
341
+ # Return
342
+ schema = DictSchema(
343
+ schema_name=f"{klass.__module__}.{klass.__qualname__}",
344
+ attributes=attributes,
345
+ class_docstring=inspect.cleandoc(klass.__doc__),
346
+ )
347
+ klass.__xradio_dict_schema = schema
348
+ return schema
@@ -6,10 +6,11 @@ __all__ = [
6
6
  "ArraySchema",
7
7
  "ArraySchemaRef",
8
8
  "DatasetSchema",
9
+ "DictSchema",
9
10
  ]
10
11
 
11
12
 
12
- @dataclass
13
+ @dataclass(frozen=True)
13
14
  class AttrSchemaRef:
14
15
  """
15
16
  Schema information about an attribute as referenced from an array or
@@ -34,7 +35,7 @@ class AttrSchemaRef:
34
35
  """Documentation string of attribute reference"""
35
36
 
36
37
 
37
- @dataclass
38
+ @dataclass(frozen=True)
38
39
  class ArraySchema:
39
40
  """
40
41
  Schema for xarray data array
@@ -83,7 +84,7 @@ class ArraySchema:
83
84
  return req_dims
84
85
 
85
86
 
86
- @dataclass
87
+ @dataclass(frozen=True)
87
88
  class ArraySchemaRef(ArraySchema):
88
89
  """
89
90
  Schema for xarray data array as referenced from a dataset schema
@@ -102,7 +103,7 @@ class ArraySchemaRef(ArraySchema):
102
103
  """Documentation string of array reference"""
103
104
 
104
105
 
105
- @dataclass
106
+ @dataclass(frozen=True)
106
107
  class DatasetSchema:
107
108
  """
108
109
  Schema for an xarray dataset
@@ -122,3 +123,19 @@ class DatasetSchema:
122
123
 
123
124
  class_docstring: typing.Optional[str]
124
125
  """Documentation string of class"""
126
+
127
+
128
+ @dataclass(frozen=True)
129
+ class DictSchema:
130
+ """
131
+ Schema for a simple dictionary
132
+ """
133
+
134
+ schema_name: str
135
+ """(Class) name of the schema"""
136
+
137
+ attributes: [AttrSchemaRef]
138
+ """List of attributes"""
139
+
140
+ class_docstring: typing.Optional[str]
141
+ """Documentation string of class"""
xradio/schema/typing.py CHANGED
@@ -29,16 +29,24 @@ from typing import (
29
29
  Protocol,
30
30
  )
31
31
 
32
+ from typing import Union
33
+
32
34
  try:
35
+ # Python 3.10 forward: TypeAlias, ParamSpec are standard, and there is the
36
+ # "a | b" UnionType alternative to "Union[a,b]"
33
37
  from typing import TypeAlias, ParamSpec
34
38
  from types import UnionType
39
+
40
+ HAVE_UNIONTYPE = True
35
41
  except ImportError:
36
- # Required for Python 3.9
42
+ # Python 3.9: Get TypeAlias, ParamSpec from typing_extensions, no support
43
+ # for "a | b"
37
44
  from typing_extensions import (
38
45
  TypeAlias,
39
46
  ParamSpec,
40
47
  )
41
- from typing import Union as UnionType
48
+
49
+ HAVE_UNIONTYPE = False
42
50
  import numpy as np
43
51
  from itertools import chain
44
52
  from enum import Enum
@@ -295,7 +303,9 @@ def get_dims(tp: Any) -> List[Dims]:
295
303
  raise TypeError(f"Could not find any dims in {tp!r}.")
296
304
 
297
305
  # List of allowed dtypes (might just be one)
298
- if get_origin(dims) is UnionType:
306
+ if get_origin(dims) is Union:
307
+ dims_in = get_args(dims)
308
+ elif HAVE_UNIONTYPE and get_origin(dims) is UnionType:
299
309
  dims_in = get_args(dims)
300
310
  else:
301
311
  dims_in = [dims]
@@ -326,35 +336,40 @@ def get_dims(tp: Any) -> List[Dims]:
326
336
  return dims_out
327
337
 
328
338
 
329
- def get_dtypes(tp: Any) -> List[AnyDType]:
330
- """Extract a NumPy data types (dtype)."""
339
+ def get_types(tp: Any) -> List[AnyDType]:
340
+ """Extract data types from type annotation
341
+
342
+ E.g. Coord[..., Type1 | Type2 | ...] or Data[..., Type1 | Type2 | ...]
343
+
344
+ """
331
345
  try:
332
- dtype = get_args(get_args(get_annotated(tp))[1])[0]
346
+ typ = get_args(get_args(get_annotated(tp))[1])[0]
333
347
  except TypeError:
334
348
  raise TypeError(f"Could not find any dtype in {tp!r}.")
335
349
 
336
350
  # List of allowed dtypes (might just be one)
337
- if get_origin(dtype) is UnionType:
338
- dtypes_in = get_args(dtype)
351
+ if get_origin(typ) is Union:
352
+ types_in = get_args(typ)
353
+ elif HAVE_UNIONTYPE and get_origin(typ) is UnionType:
354
+ types_in = get_args(typ)
339
355
  else:
340
- dtypes_in = [dtype]
341
-
342
- dtypes_out = []
343
- for dt in dtypes_in:
356
+ types_in = [typ]
344
357
 
358
+ types_out = []
359
+ for dt in types_in:
345
360
  # Handle case that we want to allow "Any"
346
361
  if dt is Any or dt is type(None):
347
- dtypes_out.append(None)
362
+ types_out.append(None)
348
363
  continue
349
364
 
350
- # Allow specifying dtype as literal (e.g. string)
365
+ # Allow specifying type as literal (e.g. string)
351
366
  elif get_origin(dt) is Literal:
352
- dtype = get_args(dt)[0]
367
+ dt = get_args(dt)[0]
353
368
 
354
- # Construct numpy dtype
355
- dtypes_out.append(np.dtype(dt))
369
+ # Return type
370
+ types_out.append(dt)
356
371
 
357
- return dtypes_out
372
+ return types_out
358
373
 
359
374
 
360
375
  def get_name(tp: Any, default: Hashable = None) -> Hashable: