xradio 0.0.56__py3-none-any.whl → 0.0.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. xradio/__init__.py +2 -2
  2. xradio/_utils/_casacore/casacore_from_casatools.py +12 -2
  3. xradio/_utils/_casacore/tables.py +1 -0
  4. xradio/_utils/coord_math.py +22 -23
  5. xradio/_utils/dict_helpers.py +76 -11
  6. xradio/_utils/schema.py +5 -2
  7. xradio/_utils/zarr/common.py +1 -73
  8. xradio/image/_util/_casacore/xds_from_casacore.py +49 -33
  9. xradio/image/_util/_casacore/xds_to_casacore.py +41 -14
  10. xradio/image/_util/_fits/xds_from_fits.py +146 -35
  11. xradio/image/_util/casacore.py +4 -3
  12. xradio/image/_util/common.py +4 -4
  13. xradio/image/_util/image_factory.py +8 -8
  14. xradio/image/image.py +45 -5
  15. xradio/measurement_set/__init__.py +19 -9
  16. xradio/measurement_set/_utils/__init__.py +1 -3
  17. xradio/measurement_set/_utils/_msv2/__init__.py +0 -0
  18. xradio/measurement_set/_utils/_msv2/_tables/read.py +17 -76
  19. xradio/measurement_set/_utils/_msv2/_tables/read_main_table.py +2 -685
  20. xradio/measurement_set/_utils/_msv2/conversion.py +123 -145
  21. xradio/measurement_set/_utils/_msv2/create_antenna_xds.py +9 -16
  22. xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +125 -221
  23. xradio/measurement_set/_utils/_msv2/msv2_to_msv4_meta.py +1 -2
  24. xradio/measurement_set/_utils/_msv2/msv4_info_dicts.py +8 -7
  25. xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +27 -72
  26. xradio/measurement_set/_utils/_msv2/partition_queries.py +1 -261
  27. xradio/measurement_set/_utils/_msv2/subtables.py +0 -107
  28. xradio/measurement_set/_utils/_utils/interpolate.py +60 -0
  29. xradio/measurement_set/_utils/_zarr/encoding.py +2 -7
  30. xradio/measurement_set/convert_msv2_to_processing_set.py +0 -2
  31. xradio/measurement_set/load_processing_set.py +2 -2
  32. xradio/measurement_set/measurement_set_xdt.py +14 -14
  33. xradio/measurement_set/open_processing_set.py +1 -3
  34. xradio/measurement_set/processing_set_xdt.py +41 -835
  35. xradio/measurement_set/schema.py +95 -122
  36. xradio/schema/check.py +91 -97
  37. xradio/schema/dataclass.py +159 -22
  38. xradio/schema/export.py +99 -0
  39. xradio/schema/metamodel.py +51 -16
  40. xradio/schema/typing.py +5 -5
  41. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/METADATA +2 -1
  42. xradio-0.0.58.dist-info/RECORD +65 -0
  43. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/WHEEL +1 -1
  44. xradio/image/_util/fits.py +0 -13
  45. xradio/measurement_set/_utils/_msv2/_tables/load.py +0 -66
  46. xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py +0 -490
  47. xradio/measurement_set/_utils/_msv2/_tables/read_subtables.py +0 -398
  48. xradio/measurement_set/_utils/_msv2/_tables/write.py +0 -323
  49. xradio/measurement_set/_utils/_msv2/_tables/write_exp_api.py +0 -388
  50. xradio/measurement_set/_utils/_msv2/chunks.py +0 -115
  51. xradio/measurement_set/_utils/_msv2/descr.py +0 -165
  52. xradio/measurement_set/_utils/_msv2/msv2_msv3.py +0 -7
  53. xradio/measurement_set/_utils/_msv2/partitions.py +0 -392
  54. xradio/measurement_set/_utils/_utils/cds.py +0 -40
  55. xradio/measurement_set/_utils/_utils/xds_helper.py +0 -404
  56. xradio/measurement_set/_utils/_zarr/read.py +0 -263
  57. xradio/measurement_set/_utils/_zarr/write.py +0 -329
  58. xradio/measurement_set/_utils/msv2.py +0 -106
  59. xradio/measurement_set/_utils/zarr.py +0 -133
  60. xradio-0.0.56.dist-info/RECORD +0 -78
  61. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/licenses/LICENSE.txt +0 -0
  62. {xradio-0.0.56.dist-info → xradio-0.0.58.dist-info}/top_level.txt +0 -0
xradio/schema/check.py CHANGED
@@ -1,3 +1,4 @@
+ import builtins
  import dataclasses
  import typing
  import inspect
@@ -15,6 +16,8 @@ from xradio.schema import (
      xarray_dataclass_to_dataset_schema,
      xarray_dataclass_to_dict_schema,
  )
+ from xradio.schema.dataclass import value_schema
+ from xradio.schema.metamodel import AttrSchemaRef, ValueSchema
  
  
  @dataclasses.dataclass
@@ -295,7 +298,8 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
      :returns: List of :py:class:`SchemaIssue`s found
      """
  
-     for exp_dtype in expected:
+     for exp_dtype_str in expected:
+         exp_dtype = numpy.dtype(exp_dtype_str)
          # If the expected dtype has no size (e.g. "U", a.k.a. a string of
          # arbitrary length), we don't check itemsize, only kind.
          if (
@@ -312,7 +316,7 @@ def check_dtype(dtype: numpy.dtype, expected: [numpy.dtype]) -> SchemaIssues:
                  SchemaIssue(
                      path=[("dtype", None)],
                      message="Wrong numpy dtype",
-                     found=dtype,
+                     found=dtype.str,
                      expected=list(expected),
                  )
              ]
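
Side note (not part of the diff): this change lets schema dtypes be stored as strings. Passing both sides through numpy.dtype makes aliases compare equal, and .str gives the canonical byte-order-qualified spelling now used in issue reports. A minimal sketch of the numpy behavior this relies on:

    import numpy

    # Aliases collapse to the same dtype object after conversion:
    assert numpy.dtype("float64") == numpy.dtype("<f8") == numpy.dtype(numpy.float64)

    # .str is the canonical spelling (on a little-endian host):
    assert numpy.dtype("float64").str == "<f8"

    # A size-less string dtype like "U" has no meaningful itemsize, which is
    # why check_dtype compares only the kind in that case:
    assert numpy.dtype("U").kind == "U" and numpy.dtype("U").itemsize == 0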
@@ -334,30 +338,23 @@ def check_attributes(
  
      issues = SchemaIssues()
      for attr_schema in attrs_schema:
-         # Attribute missing? Note that a value of "None" is equivalent for the
-         # purpose of the check
+         # A missing attribute is equivalent to a value of "None" for the
+         # purpose of this check
          val = attrs.get(attr_schema.name)
          if val is None:
              if not attr_schema.optional:
-                 # Get options
-                 if typing.get_origin(attr_schema.typ) is typing.Union:
-                     options = typing.get_args(attr_schema.typ)
-                 else:
-                     options = [attr_schema.typ]
- 
                  issues.add(
                      SchemaIssue(
                          path=[(attr_kind, attr_schema.name)],
-                         message=f"Required attribute {attr_schema.name} is missing!",
-                         expected=options,
+                         message="Non-optional attribute is missing!",
+                         found=None,
+                         expected=[attr_schema.type],
                      )
                  )
              continue
  
-         # Check attribute value
-         issues += _check_value_union(val, attr_schema.typ).at_path(
-             attr_kind, attr_schema.name
-         )
+         # Check actual value
+         issues += _check_value(val, attr_schema).at_path(attr_kind, attr_schema.name)
  
      # Extra attributes are always okay
  
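Illustration (not from the diff; the attribute name and schema construction are hypothetical): with the new code, a non-optional attribute that is missing or None yields one issue whose expected list carries the schema's type name instead of a typing annotation:

    # Hypothetical AttrSchemaRef using the new ValueSchema-style fields:
    units_schema = AttrSchemaRef(name="units", type="str", optional=False)

    issues = check_attributes({"units": None}, [units_schema], attr_kind="attrs")
    # -> one SchemaIssue: path [("attrs", "units")],
    #    message "Non-optional attribute is missing!", found None, expected ["str"]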
@@ -385,7 +382,6 @@ def check_data_vars(
  
      issues = SchemaIssues()
      for data_var_schema in data_vars_schema:
- 
          allow_mutiple_versions = False
          for attr in data_var_schema.attributes:
              if hasattr(attr, "name"):
@@ -450,125 +446,103 @@ def check_dict(
      return check_attributes(dct, schema.attributes, attr_kind="")
  
  
- def _check_value(val, ann):
+ def _check_value(val: typing.Any, schema: metamodel.ValueSchema):
      """
      Check whether value satisfies annotation
  
      If the annotation is a data array or dataset schema, it will be checked.
  
      :param val: Value to check
-     :param ann: Type annotation of value
+     :param schema: Schema of value
      :returns: Schema issues
      """
  
+     # Unspecified?
+     if schema.type is None:
+         return SchemaIssues()
+ 
+     # Optional?
+     if schema.optional and val is None:
+         return SchemaIssues()
+ 
      # Is supposed to be a data array?
-     if bases.is_dataarray_schema(ann):
+     if schema.type == "dataarray":
          # Attempt to convert dictionaries automatically
          if isinstance(val, dict):
              try:
                  val = xarray.DataArray.from_dict(val)
              except ValueError as e:
+                 expected = [DataArray]
+                 if schema.optional:
+                     expected.append(type(None))
                  return SchemaIssues(
                      [
                          SchemaIssue(
-                             path=[], message=str(e), expected=[ann], found=type(val)
+                             path=[], message=str(e), expected=expected, found=type(val)
                          )
                      ]
                  )
              except TypeError as e:
+                 expected = [DataArray]
+                 if schema.optional:
+                     expected.append(type(None))
                  return SchemaIssues(
                      [
                          SchemaIssue(
-                             path=[], message=str(e), expected=[ann], found=type(val)
+                             path=[], message=str(e), expected=expected, found=type(val)
                          )
                      ]
                  )
  
          if not isinstance(val, xarray.DataArray):
              # Fall through to plain type check
-             ann = xarray.DataArray
-         else:
-             return check_array(val, ann)
- 
-     # Is supposed to be a dataset?
-     if bases.is_dataset_schema(ann):
-         # Attempt to convert dictionaries automatically
-         if isinstance(val, dict):
-             try:
-                 val = xarray.Dataset.from_dict(val)
-             except ValueError as e:
-                 return SchemaIssues(
-                     [
-                         SchemaIssue(
-                             path=[], message=str(t), expected=[ann], found=type(val)
-                         )
-                     ]
-                 )
-         if not isinstance(val, xarray.Dataset):
-             # Fall through to plain type check
-             ann = xarray.Dataset
+             type_to_check = xarray.DataArray
          else:
-             return check_dataset(val, ann)
+             return check_array(val, schema.array_schema)
  
      # Is supposed to be a dictionary?
-     if bases.is_dict_schema(ann):
+     elif schema.type == "dict":
          if not isinstance(val, dict):
              # Fall through to plain type check
-             ann = dict
+             type_to_check = dict
          else:
-             return check_dict(val, ann)
+             return check_dict(val, schema.dict_schema)
+ 
+     elif schema.type == "list[str]":
+         type_to_check = typing.List[str]
+     elif schema.type in ["bool", "str", "int", "float"]:
+         type_to_check = getattr(builtins, schema.type)
+     else:
+         raise ValueError(f"Invalid typ_name in schema: {schema.type}")
  
      # Otherwise straight type check using typeguard
      try:
-         check_type(val, ann)
+         check_type(val, type_to_check)
      except TypeCheckError as t:
+         expected = [type_to_check]
+         if schema.optional:
+             expected.append(type(None))
          return SchemaIssues(
-             [SchemaIssue(path=[], message=str(t), expected=[ann], found=type(val))]
+             [SchemaIssue(path=[], message=str(t), expected=expected, found=type(val))]
          )
  
-     return SchemaIssues()
- 
- 
- def _check_value_union(val, ann):
-     """
-     Check whether value satisfies annotations, including union types
- 
-     If the annotation is a data array or dataset schema, it will be checked.
- 
-     :param val: Value to check
-     :param ann: Type annotation of value
-     :returns: Schema issues
-     """
- 
-     if ann is None or ann is inspect.Signature.empty:
-         return SchemaIssues()
- 
-     # Account for union types (this especially catches "Optional")
-     if typing.get_origin(ann) is typing.Union:
-         options = typing.get_args(ann)
-     else:
-         options = [ann]
- 
-     # Go through options, try to find one without issues
-     args_issues = None
-     okay = False
-     for option in options:
-         arg_issues = _check_value(val, option)
-         # We can immediately return if we find no issues with
-         # some schema check
-         if not arg_issues:
-             return SchemaIssues()
-         if args_issues is None:
-             args_issues = arg_issues
- 
-         # Crude merging of expected options (for "unexpected type")
-         elif len(args_issues) == 1 and len(arg_issues) == 1:
-             args_issues[0].expected += arg_issues[0].expected
+     # List of literals given?
+     if schema.literal is not None:
+         for lit in schema.literal:
+             if val == lit:
+                 return SchemaIssues()
+         return SchemaIssues(
+             [
+                 SchemaIssue(
+                     path=[],
+                     message=f"Disallowed literal value!",
+                     expected=schema.literal,
+                     found=val,
+                 )
+             ]
+         )
  
-     # Return representative issues list
-     if not args_issues:
-         raise ValueError("Empty union set?")
-     return args_issues
+     return SchemaIssues()
  
  
  _DATASET_TYPES = {}
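
To make the new dispatch concrete (a sketch; the ValueSchema field names come from the diff, but constructing ValueSchema directly like this and the example values are assumptions):

    from xradio.schema.metamodel import ValueSchema

    # Plain scalar type restricted to two literal values:
    vs = ValueSchema("str", literal=["linear", "circular"])
    assert not _check_value("linear", vs)      # no issues; SchemaIssues is falsy when empty
    assert _check_value("elliptical", vs)      # "Disallowed literal value!"

    # Optional values short-circuit before any type or literal check:
    vs_opt = ValueSchema("float", optional=True)
    assert not _check_value(None, vs_opt)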
@@ -591,7 +565,7 @@ def register_dataset_type(schema: metamodel.DatasetSchema):
              continue
  
          # Type should be a kind of literal
-         if typing.get_origin(attr.typ) is not typing.Literal:
+         if attr.literal is None:
              warnings.warn(
                  f"In dataset schema {schema.schema_name}:"
                  'Attribute "type" should be a literal!'
@@ -599,7 +573,12 @@ def register_dataset_type(schema: metamodel.DatasetSchema):
              continue
  
          # Register type names
-         for typ in typing.get_args(attr.typ):
+         for typ in attr.literal:
+             assert isinstance(typ, str), (
+                 f"In dataset schema {schema.schema_name}:"
+                 'Attribute "type" should be a literal giving '
+                 "names of schema!"
+             )
              _DATASET_TYPES[typ] = schema
  
  
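Standalone sketch of the registry mechanics above (stand-in classes, not the real metamodel types): every literal value of a schema's "type" attribute becomes a lookup key resolving back to that schema.

    _DATASET_TYPES = {}

    class FakeAttr:                    # stand-in for an extracted AttrSchemaRef
        name = "type"
        literal = ["pointing"]         # e.g. from Attr[Literal["pointing"]]

    class FakeSchema:                  # stand-in for a DatasetSchema
        schema_name = "PointingXds"
        attributes = [FakeAttr()]

    for attr in FakeSchema.attributes:
        if attr.name == "type" and attr.literal is not None:
            for typ in attr.literal:
                assert isinstance(typ, str)
                _DATASET_TYPES[typ] = FakeSchema

    assert _DATASET_TYPES["pointing"] is FakeSchema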
@@ -621,7 +600,6 @@ def check_datatree(
      # Loop through all groups in datatree
      issues = SchemaIssues()
      for xds_name in datatree.groups:
- 
          # Ignore any leaf without data
          node = datatree[xds_name]
          if not node.has_data:
@@ -679,7 +657,7 @@ def schema_checked(fn, check_parameters: bool = True, check_return: bool = True)
      @functools.wraps(fn)
      def _check_fn(*args, **kwargs):
          # Hide this function in pytest tracebacks
-         __tracebackhide__ = True
+         # __tracebackhide__ = True
  
          # Bind parameters, collect (potential) issues
          bound = signature.bind(*args, **kwargs)
@@ -689,7 +667,15 @@ def schema_checked(fn, check_parameters: bool = True, check_return: bool = True)
                  continue
  
              # Get annotation
-             issues += _check_value_union(val, anns.get(arg)).at_path(arg)
+             vschema = value_schema(anns.get(arg), "function", arg)
+             pseudo_attr_schema = AttrSchemaRef(
+                 name=arg,
+                 **{
+                     fld.name: getattr(vschema, fld.name)
+                     for fld in dataclasses.fields(ValueSchema)
+                 },
+             )
+             issues += _check_value(val, pseudo_attr_schema).at_path(arg)
  
          # Any issues found? raise
          issues.expect()
@@ -699,7 +685,15 @@ def schema_checked(fn, check_parameters: bool = True, check_return: bool = True)
  
          # Check return
          if check_return:
-             issues = _check_value_union(val, signature.return_annotation)
+             vschema = value_schema(anns.get(arg), "function", "return")
+             pseudo_attr_schema = AttrSchemaRef(
+                 name="return",
+                 **{
+                     fld.name: getattr(vschema, fld.name)
+                     for fld in dataclasses.fields(ValueSchema)
+                 },
+             )
+             issues = _check_value(val, pseudo_attr_schema)
          issues.at_path("return").expect()
  
          # Check return value
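
Usage sketch (the decorated function is hypothetical; decorator behavior as in the diff): schema_checked now funnels each annotated parameter through value_schema, wraps the result in a pseudo AttrSchemaRef, and validates it with _check_value:

    from typing import Literal, Optional

    @schema_checked
    def set_mode(mode: Literal["raw", "calibrated"], factor: Optional[float] = None):
        ...

    set_mode("raw", 2.0)     # passes
    set_mode("averaged")     # issues.expect() raises: disallowed literal at path "mode"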
xradio/schema/dataclass.py CHANGED
@@ -76,7 +76,6 @@ def _check_invalid_dims(
  
      # Filter out dimension possibilities with undefined coordinates
      valid_dims = [ds for ds in dims if set(ds).issubset(all_coord_names)]
-     # print(f"{klass_name}.{field_name}", valid_dims, dims, all_coord_names)
  
      # Raise an exception if this makes the dimension set impossible
      if dims and not valid_dims:
@@ -88,6 +87,132 @@
      return valid_dims
  
  
+ def value_schema(ann: typing.Any, klass_name: str, field_name: str) -> "ValueSchema":
+     """
+     Take attribute type annotation and convert into type name and
+     - optionally - a list of literal allowed values
+ 
+     :param ann: Annotation
+     :param klass_name: Name of class where annotation origins from
+     :param field_name: Name of field where annotation origins from
+     :returns: ValueSchema
+     """
+ 
+     # No annotation?
+     if ann is None:
+         return ValueSchema(None)
+ 
+     # Optional?
+     if is_optional(ann):
+ 
+         # Optional is actually represented as a union... Construct
+         # same union type without the "None" type.
+         typs = [typ for typ in get_args(ann) if typ is not None.__class__]
+         if len(typs) == 1:
+             typ = typs[0]
+         else:
+             raise ValueError(
+                 f"In '{klass_name}', field '{field_name}' has"
+                 f" a union type, which is not allowed!"
+             )
+ 
+         # Convert to schema recursively
+         vschema = value_schema(typ, klass_name, field_name)
+         vschema.optional = True
+         return vschema
+ 
+     # Is a type?
+     if isinstance(ann, type):
+         # Array type?
+         if hasattr(ann, "__xradio_array_schema"):
+             return ValueSchema("dataarray", array_schema=ann.__xradio_array_schema)
+ 
+         # Dictionary type?
+         if hasattr(ann, "__xradio_dict_schema"):
+             return ValueSchema("dict", dict_schema=ann.__xradio_dict_schema)
+ 
+         # Check that it is an allowable type
+         if ann not in [bool, str, int, float, bool]:
+             raise ValueError(
+                 f"In '{klass_name}', field '{field_name}' has"
+                 f" type {ann} - but only str, int, float or list are allowed!"
+             )
+         return ValueSchema(ann.__name__)
+ 
+     # Is a list
+     if typing.get_origin(ann) in [typing.List, list]:
+         args = typing.get_args(ann)
+ 
+         # Must be a string list
+         if args != (str,):
+             raise ValueError(
+                 f"In '{klass_name}', field '{field_name}' has"
+                 f" annotation {ann}, but only str, int, float, list[str] or Literal allowed!"
+             )
+ 
+         return ValueSchema("list[str]")
+ 
+     # Is a literal?
+     if typing.get_origin(ann) is typing.Literal:
+         args = typing.get_args(ann)
+ 
+         # Check that it is an allowable type
+         if len(args) == 0:
+             raise ValueError(
+                 f"In '{klass_name}', field '{field_name}' has"
+                 f" literal annotation, but allows no values!"
+             )
+ 
+         # String list?
+         typ = type(args[0])
+         if typ is list:
+             elem_type = type(args[0][0])
+             if elem_type is not str:
+                 raise ValueError(
+                     f"In '{klass_name}', field '{field_name}' has"
+                     f" literal type list[{elem_type}] - but only list[str] is allowed!"
+                 )
+             for lit in args:
+                 if not isinstance(lit, typ):
+                     raise ValueError(
+                         f"In '{klass_name}', field '{field_name}' literal"
+                         f" {lit} has inconsistent type ({typ(lit)}) vs ({typ})!"
+                     )
+                 for elem in lit:
+                     if not isinstance(elem, elem_type):
+                         raise ValueError(
+                             f"In '{klass_name}', field '{field_name}' literal"
+                             f" {lit} has inconsistent element type "
+                             f"({typ(elem)}) vs ({elem_type})!"
+                         )
+             return ValueSchema(
+                 "list[str]",
+                 literal=[[str(elem) for elem in arg] for arg in args],
+             )
+ 
+         # Check that it is an allowable type
+         if typ not in [bool, str, int, float]:
+             raise ValueError(
+                 f"In '{klass_name}', field '{field_name}' has"
+                 f" literal type {typ} - but only str, int, float or list[str] are allowed!"
+             )
+ 
+         # Check that all literals have the same type
+         for lit in args:
+             if not isinstance(lit, typ):
+                 raise ValueError(
+                     f"In '{klass_name}', field '{field_name}' literal"
+                     f" {lit} has inconsistent type ({typ(lit)}) vs ({typ})!"
+                 )
+ 
+         return ValueSchema(typ.__name__, literal=[typ(arg) for arg in args])
+ 
+     raise ValueError(
+         f"In '{klass_name}', field '{field_name}' has"
+         f" annotation {ann}, but only type or Literal allowed!"
+     )
+ 
+ 
  def extract_xarray_dataclass(klass, allow_undefined_coords: bool = False):
      """
      Go through dataclass fields and interpret them according to xarray-dataclass
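
A few conversions the new value_schema performs (grounded in the branches above; result reprs are abbreviated):

    import typing

    value_schema(str, "MyXds", "name")
    #   -> ValueSchema("str")

    value_schema(typing.Optional[int], "MyXds", "count")
    #   -> ValueSchema("int") with optional=True (union unwrapped, None stripped)

    value_schema(typing.Literal["J2000", "ICRS"], "MyXds", "frame")
    #   -> ValueSchema("str", literal=["J2000", "ICRS"])

    value_schema(typing.Union[int, str], "MyXds", "bad")
    #   -> raises ValueError: unions other than Optional are not allowed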
@@ -132,13 +257,27 @@ def extract_xarray_dataclass(klass, allow_undefined_coords: bool = False):
  
          # Is it an attribute?
          if role == Role.ATTR:
+             try:
+                 ann = get_annotated(typ)
+             except TypeError as e:
+                 raise ValueError(
+                     f"Could not get annotation in '{klass.__name__}' field '{field.name}': {e}"
+                 )
+             vschema = value_schema(get_annotated(typ), klass.__name__, field.name)
+             if is_optional(typ):
+                 vschema.optional = True
+ 
              attributes.append(
                  AttrSchemaRef(
                      name=field.name,
-                     typ=get_annotated(typ),
-                     optional=is_optional(typ),
-                     default=field.default,
+                     default=(
+                         None if field.default is dataclasses.MISSING else field.default
+                     ),
                      docstring=field_docstrings.get(field.name),
+                     **{
+                         fld.name: getattr(vschema, fld.name)
+                         for fld in dataclasses.fields(ValueSchema)
+                     },
                  )
              )
              continue
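
Context for the dataclasses.MISSING handling (not part of the diff): a dataclass field without a default reports the MISSING sentinel, which does not serialize to JSON and is awkward to compare, so the extraction now normalizes it to None. A minimal demonstration:

    import dataclasses

    @dataclasses.dataclass
    class Example:
        required: int            # no default
        optional: int = 42       # explicit default

    flds = {f.name: f for f in dataclasses.fields(Example)}
    assert flds["required"].default is dataclasses.MISSING
    assert flds["optional"].default == 42

    # The schema extraction maps the sentinel to None before storing it:
    default = flds["required"].default
    default = None if default is dataclasses.MISSING else default
    assert default is None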
@@ -151,7 +290,7 @@ def extract_xarray_dataclass(klass, allow_undefined_coords: bool = False):
          else:
              raise ValueError(
                  f"Expected field '{field.name}' in '{klass.__name__}' "
-                 "to be annotated with either Coord, Data or Attr!"
+                 f"to be annotated with either Coord, Data or Attr!"
              )
  
      # Defined using a dataclass, i.e. Coordof/Dataof?
@@ -173,7 +312,7 @@ def extract_xarray_dataclass(klass, allow_undefined_coords: bool = False):
              schema_ref = ArraySchemaRef(
                  name=field.name,
                  optional=is_optional(typ),
-                 default=field.default,
+                 default=None if field.default is dataclasses.MISSING else field.default,
                  docstring=field_docstrings.get(field.name),
                  **arr_schema_fields,
              )
@@ -206,7 +345,9 @@ def extract_xarray_dataclass(klass, allow_undefined_coords: bool = False):
              schema_ref = ArraySchemaRef(
                  name=field.name,
                  optional=is_optional(typ),
-                 default=field.default,
+                 default=(
+                     None if field.default is dataclasses.MISSING else field.default
+                 ),
                  docstring=field_docstrings.get(field.name),
                  **arr_schema_fields,
              )
@@ -215,11 +356,13 @@ def extract_xarray_dataclass(klass, allow_undefined_coords: bool = False):
              schema_ref = ArraySchemaRef(
                  name=field.name,
                  optional=is_optional(typ),
-                 default=field.default,
+                 default=(
+                     None if field.default is dataclasses.MISSING else field.default
+                 ),
                  docstring=field_docstrings.get(field.name),
                  schema_name=None,
                  dimensions=check_invalid_dims(dims, field.name),
-                 dtypes=[numpy.dtype(typ) for typ in types],
+                 dtypes=[numpy.dtype(typ).str for typ in types],
                  coordinates=[],
                  attributes=[],
                  class_docstring=None,
@@ -281,7 +424,7 @@ def xarray_dataclass_to_array_schema(klass):
      schema = ArraySchema(
          schema_name=f"{klass.__module__}.{klass.__qualname__}",
          dimensions=data_vars[0].dimensions,
-         dtypes=data_vars[0].dtypes,
+         dtypes=[numpy.dtype(dt).str for dt in data_vars[0].dtypes],
          coordinates=coordinates,
          attributes=attributes,
          class_docstring=inspect.cleandoc(klass.__doc__),
@@ -369,22 +512,16 @@ def xarray_dataclass_to_dict_schema(klass):
      for field in dataclasses.fields(klass):
          typ = type_hints[field.name]
  
-         # Handle optional value: Strip "None" from the types
-         optional = is_optional(typ)
-         if optional:
-             typs = [typ for typ in get_args(typ) if typ is not None.__class__]
-             if len(typs) == 1:
-                 typ = typs[0]
-             else:
-                 typ = typing.Union.__getitem__[tuple(typs)]
- 
+         vschema = value_schema(typ, klass.__name__, field.name)
          attributes.append(
              AttrSchemaRef(
                  name=field.name,
-                 typ=typ,
-                 optional=optional,
-                 default=field.default,
+                 default=None if field.default is dataclasses.MISSING else field.default,
                  docstring=field_docstrings.get(field.name),
+                 **{
+                     fld.name: getattr(vschema, fld.name)
+                     for fld in dataclasses.fields(ValueSchema)
+                 },
              )
          )
  
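The **{fld.name: getattr(vschema, fld.name) ...} pattern now used in three places copies every ValueSchema field into the AttrSchemaRef constructor, which implies AttrSchemaRef carries the same fields (plausibly by inheriting from ValueSchema; an assumption). Mechanics with stand-in classes:

    import dataclasses

    @dataclasses.dataclass
    class ValueSchemaSketch:                      # stand-in for metamodel.ValueSchema
        type: str = None
        optional: bool = False
        literal: list = None

    @dataclasses.dataclass
    class AttrSchemaRefSketch(ValueSchemaSketch): # stand-in for AttrSchemaRef
        name: str = None
        default: object = None

    vschema = ValueSchemaSketch("str", literal=["linear", "circular"])

    # Copy all ValueSchema fields over, exactly as the diff does:
    ref = AttrSchemaRefSketch(
        name="polarization",
        **{f.name: getattr(vschema, f.name)
           for f in dataclasses.fields(ValueSchemaSketch)},
    )
    assert ref.type == "str" and ref.literal == ["linear", "circular"]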
xradio/schema/export.py ADDED
@@ -0,0 +1,99 @@
+ import dataclasses
+ import json
+ 
+ from xradio.schema import (
+     bases,
+     metamodel,
+     xarray_dataclass_to_array_schema,
+     xarray_dataclass_to_dataset_schema,
+     xarray_dataclass_to_dict_schema,
+ )
+ 
+ CLASS_ATTR = "$class"
+ 
+ 
+ class DataclassEncoder(json.JSONEncoder):
+     """
+     General-purpose encoder that represents data classes as
+     dictionaries, omitting defaults and annotating the original class
+     as a ``'$class'`` attribute.
+     """
+ 
+     def default(self, o):
+         if dataclasses.is_dataclass(o):
+             res = {CLASS_ATTR: o.__class__.__name__}
+             for fld in dataclasses.fields(type(o)):
+                 if (
+                     getattr(o, fld.name) is not fld.default
+                     and getattr(o, fld.name) is not dataclasses.MISSING
+                 ):
+                     res[fld.name] = getattr(o, fld.name)
+             return res
+         return super().default(o)
+ 
+ 
+ DATACLASS_MAP = {
+     cls.__name__: cls
+     for cls in [
+         metamodel.DictSchema,
+         metamodel.ValueSchema,
+         metamodel.AttrSchemaRef,
+         metamodel.ArraySchema,
+         metamodel.ArraySchemaRef,
+         metamodel.DatasetSchema,
+     ]
+ }
+ 
+ 
+ class DataclassDecoder(json.JSONDecoder):
+     """
+     General-purpose decoder that reads JSON as generated by
+     :py:class:`DataclassEncoder`.
+     """
+ 
+     def __init__(self, dataclass_map, *args, **kwargs):
+         self._dataclass_map = dataclass_map
+         super().__init__(*args, object_hook=self.object_hook, **kwargs)
+ 
+     def object_hook(self, obj):
+ 
+         # Detect dictionaries with '$class' annotation
+         if isinstance(obj, dict) and CLASS_ATTR in obj:
+ 
+             # Identify the class
+             cls_name = obj[CLASS_ATTR]
+             cls = self._dataclass_map.get(cls_name)
+             if not cls:
+                 raise ValueError(
+                     f"Unknown $dataclass encountered while decoding JSON: {cls_name}"
+                 )
+ 
+             # Instantiate
+             del obj[CLASS_ATTR]
+             obj = cls(**obj)
+ 
+         return obj
+ 
+ 
+ def export_schema_json_file(schema, fname):
+     """
+     Exports given schema as a JSON file
+     """
+ 
+     # Check that this is actually a Dataset
+     if bases.is_dataset_schema(schema):
+         schema = xarray_dataclass_to_dataset_schema(schema)
+     if not isinstance(schema, metamodel.DatasetSchema):
+         raise TypeError(
+             f"export_schema_json_file: Expected DatasetSchema, but got {type(schema)}!"
+         )
+ 
+     # Perform export
+     with open(fname, "w", encoding="utf8") as f:
+         json.dump(schema, f, cls=DataclassEncoder, ensure_ascii=False, indent=" ")
+ 
+ 
+ def import_schema_json_file(fname):
+ 
+     with open(fname, "r", encoding="utf8") as f:
+         return json.load(f, cls=DataclassDecoder, dataclass_map=DATACLASS_MAP)
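
Usage sketch for the new module (function names from the diff; the schema class and file name are assumptions):

    from xradio.schema.export import export_schema_json_file, import_schema_json_file
    from xradio.measurement_set.schema import VisibilityXds   # assumed example schema

    # Dataset schema classes are first converted via
    # xarray_dataclass_to_dataset_schema, then written with a '$class'
    # marker on every dataclass and default-valued fields omitted:
    export_schema_json_file(VisibilityXds, "visibility_schema.json")

    # Decoding rebuilds the metamodel dataclasses from the '$class' markers:
    schema = import_schema_json_file("visibility_schema.json")   # DatasetSchema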