anemoi-datasets 0.5.24__py3-none-any.whl → 0.5.26__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (58)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/finalise-additions.py +2 -1
  3. anemoi/datasets/commands/finalise.py +2 -1
  4. anemoi/datasets/commands/grib-index.py +1 -1
  5. anemoi/datasets/commands/init-additions.py +2 -1
  6. anemoi/datasets/commands/load-additions.py +2 -1
  7. anemoi/datasets/commands/load.py +2 -1
  8. anemoi/datasets/create/__init__.py +24 -33
  9. anemoi/datasets/create/filter.py +22 -24
  10. anemoi/datasets/create/input/__init__.py +0 -20
  11. anemoi/datasets/create/input/step.py +2 -16
  12. anemoi/datasets/create/sources/accumulations.py +7 -6
  13. anemoi/datasets/create/sources/planetary_computer.py +44 -0
  14. anemoi/datasets/create/sources/xarray_support/__init__.py +6 -22
  15. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -0
  16. anemoi/datasets/create/sources/xarray_support/field.py +1 -4
  17. anemoi/datasets/create/sources/xarray_support/flavour.py +44 -6
  18. anemoi/datasets/create/sources/xarray_support/patch.py +44 -1
  19. anemoi/datasets/create/sources/xarray_support/variable.py +6 -2
  20. anemoi/datasets/data/complement.py +44 -10
  21. anemoi/datasets/data/dataset.py +29 -0
  22. anemoi/datasets/data/forwards.py +8 -2
  23. anemoi/datasets/data/misc.py +74 -16
  24. anemoi/datasets/data/observations/__init__.py +316 -0
  25. anemoi/datasets/data/observations/legacy_obs_dataset.py +200 -0
  26. anemoi/datasets/data/observations/multi.py +64 -0
  27. anemoi/datasets/data/padded.py +227 -0
  28. anemoi/datasets/data/records/__init__.py +442 -0
  29. anemoi/datasets/data/records/backends/__init__.py +157 -0
  30. anemoi/datasets/data/stores.py +7 -56
  31. anemoi/datasets/data/subset.py +5 -0
  32. anemoi/datasets/grids.py +6 -3
  33. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/METADATA +3 -2
  34. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/RECORD +38 -51
  35. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/WHEEL +1 -1
  36. anemoi/datasets/create/filters/__init__.py +0 -33
  37. anemoi/datasets/create/filters/empty.py +0 -37
  38. anemoi/datasets/create/filters/legacy.py +0 -93
  39. anemoi/datasets/create/filters/noop.py +0 -37
  40. anemoi/datasets/create/filters/orog_to_z.py +0 -58
  41. anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +0 -83
  42. anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +0 -84
  43. anemoi/datasets/create/filters/rename.py +0 -205
  44. anemoi/datasets/create/filters/rotate_winds.py +0 -105
  45. anemoi/datasets/create/filters/single_level_dewpoint_to_relative_humidity.py +0 -78
  46. anemoi/datasets/create/filters/single_level_relative_humidity_to_dewpoint.py +0 -84
  47. anemoi/datasets/create/filters/single_level_relative_humidity_to_specific_humidity.py +0 -163
  48. anemoi/datasets/create/filters/single_level_specific_humidity_to_relative_humidity.py +0 -451
  49. anemoi/datasets/create/filters/speeddir_to_uv.py +0 -95
  50. anemoi/datasets/create/filters/sum.py +0 -68
  51. anemoi/datasets/create/filters/transform.py +0 -51
  52. anemoi/datasets/create/filters/unrotate_winds.py +0 -105
  53. anemoi/datasets/create/filters/uv_to_speeddir.py +0 -94
  54. anemoi/datasets/create/filters/wz_to_w.py +0 -98
  55. anemoi/datasets/create/testing.py +0 -76
  56. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/entry_points.txt +0 -0
  57. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/licenses/LICENSE +0 -0
  58. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/top_level.txt +0 -0

--- a/anemoi/datasets/_version.py
+++ b/anemoi/datasets/_version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.5.24'
-__version_tuple__ = version_tuple = (0, 5, 24)
+__version__ = version = '0.5.26'
+__version_tuple__ = version_tuple = (0, 5, 26)

--- a/anemoi/datasets/commands/finalise-additions.py
+++ b/anemoi/datasets/commands/finalise-additions.py
@@ -61,7 +61,8 @@ class FinaliseAdditions(Command):
 
         if "debug" in options:
             options.pop("debug")
-        task(step, options)
+
+        task(step, options)
 
         LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}")
 

--- a/anemoi/datasets/commands/finalise.py
+++ b/anemoi/datasets/commands/finalise.py
@@ -55,7 +55,8 @@ class Finalise(Command):
 
         if "debug" in options:
             options.pop("debug")
-        task(step, options)
+
+        task(step, options)
 
         LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}")
 

--- a/anemoi/datasets/commands/grib-index.py
+++ b/anemoi/datasets/commands/grib-index.py
@@ -81,7 +81,7 @@ class GribIndexCmd(Command):
             bool
                 True if the path matches, False otherwise.
             """
-            return fnmatch.fnmatch(path, args.match)
+            return fnmatch.fnmatch(os.path.basename(path), args.match)
 
         from anemoi.datasets.create.sources.grib_index import GribIndex
 
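
Note on the grib-index change: the `args.match` pattern is now applied to the file name only, not the full path, so patterns without directory components behave as expected. A minimal sketch of the difference (the path is hypothetical):

    import fnmatch
    import os

    path = "/archive/2024/forecast.grib"

    # Before: the pattern had to match the whole path
    print(fnmatch.fnmatch(path, "forecast.*"))                    # False
    # After: only the basename is matched
    print(fnmatch.fnmatch(os.path.basename(path), "forecast.*"))  # True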

--- a/anemoi/datasets/commands/init-additions.py
+++ b/anemoi/datasets/commands/init-additions.py
@@ -61,7 +61,8 @@ class InitAdditions(Command):
 
         if "debug" in options:
             options.pop("debug")
-        task(step, options)
+
+        task(step, options)
 
         LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}")
 

--- a/anemoi/datasets/commands/load-additions.py
+++ b/anemoi/datasets/commands/load-additions.py
@@ -62,7 +62,8 @@ class LoadAdditions(Command):
 
         if "debug" in options:
             options.pop("debug")
-        task(step, options)
+
+        task(step, options)
 
         LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}")
 

--- a/anemoi/datasets/commands/load.py
+++ b/anemoi/datasets/commands/load.py
@@ -62,7 +62,8 @@ class Load(Command):
 
         if "debug" in options:
             options.pop("debug")
-        task(step, options)
+
+        task(step, options)
 
         LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}")
 

--- a/anemoi/datasets/create/__init__.py
+++ b/anemoi/datasets/create/__init__.py
@@ -44,7 +44,7 @@ from .check import check_data_values
 from .chunks import ChunkFilter
 from .config import build_output
 from .config import loader_config
-from .input import build_input
+from .input import InputBuilder
 from .statistics import Summary
 from .statistics import TmpStatistics
 from .statistics import check_variance
@@ -101,7 +101,9 @@ def json_tidy(o: Any) -> Any:
 
 
 def build_statistics_dates(
-    dates: list[datetime.datetime], start: Optional[datetime.datetime], end: Optional[datetime.datetime]
+    dates: list[datetime.datetime],
+    start: Optional[datetime.datetime],
+    end: Optional[datetime.datetime],
 ) -> tuple[str, str]:
     """Compute the start and end dates for the statistics.
 
@@ -551,36 +553,16 @@ class HasElementForDataMixin:
 
         self.output = build_output(config.output, parent=self)
 
-        self.input = build_input_(main_config=config, output_config=self.output)
-        # LOG.info("%s", self.input)
-
-
-def build_input_(main_config: Any, output_config: Any) -> Any:
-    """Build the input for the dataset.
-
-    Parameters
-    ----------
-    main_config : Any
-        The main configuration.
-    output_config : Any
-        The output configuration.
-
-    Returns
-    -------
-    Any
-        The input builder.
-    """
-    builder = build_input(
-        main_config.input,
-        data_sources=main_config.get("data_sources", {}),
-        order_by=output_config.order_by,
-        flatten_grid=output_config.flatten_grid,
-        remapping=build_remapping(output_config.remapping),
-        use_grib_paramid=main_config.build.use_grib_paramid,
-    )
-    LOG.debug("✅ INPUT_BUILDER")
-    LOG.debug(builder)
-    return builder
+        self.input = InputBuilder(
+            config.input,
+            data_sources=config.get("data_sources", {}),
+            order_by=self.output.order_by,
+            flatten_grid=self.output.flatten_grid,
+            remapping=build_remapping(self.output.remapping),
+            use_grib_paramid=config.build.use_grib_paramid,
+        )
+        LOG.debug("✅ INPUT_BUILDER")
+        LOG.debug(self.input)
 
 
 class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
@@ -1541,7 +1523,16 @@ class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin):
         if not all(self.registry.get_flags(sync=False)):
             raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.")
 
-        for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
+        for k in [
+            "mean",
+            "stdev",
+            "minimum",
+            "maximum",
+            "sums",
+            "squares",
+            "count",
+            "has_nans",
+        ]:
             self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",))
 
         self.registry.add_to_history("compute_statistics_end")

--- a/anemoi/datasets/create/filter.py
+++ b/anemoi/datasets/create/filter.py
@@ -7,44 +7,42 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-from abc import ABC
-from abc import abstractmethod
 from typing import Any
+from typing import Dict
 
 import earthkit.data as ekd
 
 
-class Filter(ABC):
-    """A base class for filters."""
+class TransformFilter:
+    """Calls filters from anemoi.transform.filters
 
-    def __init__(self, context: Any, *args: Any, **kwargs: Any) -> None:
-        """Initialise the filter.
+    Parameters
+    ----------
+    context : Any
+        The context in which the filter is created.
+    name : str
+        The name of the filter.
+    config : Dict[str, Any]
+        The configuration for the filter.
+    """
 
-        Parameters
-        ----------
-        context : Any
-            The context in which the filter is created.
-        *args : tuple
-            Positional arguments.
-        **kwargs : dict
-            Keyword arguments.
-        """
+    def __init__(self, context: Any, name: str, config: Dict[str, Any]) -> None:
+        from anemoi.transform.filters import create_filter
 
-        self.context = context
+        self.name = name
+        self.transform_filter = create_filter(context, config)
 
-    @abstractmethod
-    def execute(self, data: ekd.FieldList) -> ekd.FieldList:
-        """Execute the filter.
+    def execute(self, input: ekd.FieldList) -> ekd.FieldList:
+        """Execute the transformation filter.
 
         Parameters
         ----------
-        data : ekd.FieldList
-            The input data.
+        input : ekd.FieldList
+            The input data to be transformed.
 
         Returns
         -------
         ekd.FieldList
-            The output data.
+            The transformed data.
         """
-
-        pass
+        return self.transform_filter.forward(input)
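
The module now reduces to this thin adapter over anemoi-transform. A minimal sketch of how it is driven, assuming a filter name registered in anemoi-transform; the "rename" spec is a hypothetical config, and in real use the context object comes from the creation pipeline (None may not suit every filter):

    from anemoi.transform.filters import filter_registry

    from anemoi.datasets.create.filter import TransformFilter

    name, config = "rename", {"rename": {"2t": "t2m"}}  # hypothetical filter spec

    if filter_registry.is_registered(name):
        f = TransformFilter(context=None, name=name, config=config)
        # `fields` would be an earthkit.data FieldList from an earlier step:
        # result = f.execute(fields)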

--- a/anemoi/datasets/create/input/__init__.py
+++ b/anemoi/datasets/create/input/__init__.py
@@ -104,23 +104,3 @@ class InputBuilder:
             Trace string.
         """
         return f"InputBuilder({group_of_dates})"
-
-
-def build_input(config: dict, data_sources: Union[dict, list], **kwargs: Any) -> InputBuilder:
-    """Build an InputBuilder instance.
-
-    Parameters
-    ----------
-    config : dict
-        Configuration dictionary.
-    data_sources : Union[dict, list]
-        Data sources.
-    **kwargs : Any
-        Additional keyword arguments.
-
-    Returns
-    -------
-    InputBuilder
-        An instance of InputBuilder.
-    """
-    return InputBuilder(config, data_sources, **kwargs)

--- a/anemoi/datasets/create/input/step.py
+++ b/anemoi/datasets/create/input/step.py
@@ -8,7 +8,6 @@
 # nor does it submit to any jurisdiction.
 
 import logging
-import warnings
 from copy import deepcopy
 from typing import Any
 from typing import Dict
@@ -165,24 +164,11 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li
     if cls is not None:
         return cls(context, action_path, previous_step, *args, **kwargs)
 
-    # Try filters from datasets filter registry
+    # Try filters from transform filter registry
     from anemoi.transform.filters import filter_registry as transform_filter_registry
 
-    from ..filters import create_filter as create_datasets_filter
-    from ..filters import filter_registry as datasets_filter_registry
-
-    if datasets_filter_registry.is_registered(key):
-
-        if transform_filter_registry.is_registered(key):
-            warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries")
-
-        filter = create_datasets_filter(None, config)
-        return FunctionStepAction(context, action_path + [key], previous_step, key, filter)
-
-    # Use filters from transform registry
-
     if transform_filter_registry.is_registered(key):
-        from ..filters.transform import TransformFilter
+        from ..filter import TransformFilter
 
         return FunctionStepAction(
             context, action_path + [key], previous_step, key, TransformFilter(context, key, config)

--- a/anemoi/datasets/create/sources/accumulations.py
+++ b/anemoi/datasets/create/sources/accumulations.py
@@ -459,12 +459,13 @@ class AccumulationFromStart(Accumulation):
             A tuple representing the MARS date-time step.
         """
         assert user_date is None, user_date
-        assert not frequency, frequency
 
         steps = (step1 + add_step, step2 + add_step)
         if steps[0] == 0:
             steps = (steps[1],)
 
+        assert frequency == 0 or frequency == (step2 - step1), frequency
+
         return (
             base_date.year * 10000 + base_date.month * 100 + base_date.day,
             base_date.hour * 100 + base_date.minute,
@@ -824,6 +825,11 @@ def _compute_accumulations(
     step1, step2 = user_accumulation_period
     assert step1 < step2, user_accumulation_period
 
+    if accumulations_reset_frequency is not None:
+        AccumulationClass = AccumulationFromLastReset
+    else:
+        AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep
+
     if data_accumulation_period is None:
         data_accumulation_period = user_accumulation_period[1] - user_accumulation_period[0]
 
@@ -838,11 +844,6 @@ def _compute_accumulations(
 
     base_times = [t // 100 if t > 100 else t for t in base_times]
 
-    if accumulations_reset_frequency is not None:
-        AccumulationClass = AccumulationFromLastReset
-    else:
-        AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep
-
     mars_date_time_steps = AccumulationClass.mars_date_time_steps(
         dates=dates,
         step1=step1,
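
Moving the class selection above the defaulting of data_accumulation_period appears to be the point of this hunk: previously a None value was rewritten to step2 - step1 before the `in (0, None)` test ran, so AccumulationFromStart could only ever be chosen for an explicit 0. A reduced sketch of the two orderings (plain Python, names mirror the diff):

    data_accumulation_period = None   # i.e. fields accumulated from the start of the forecast
    step1, step2 = 0, 6

    # Old order: default first, select after
    defaulted = step2 - step1 if data_accumulation_period is None else data_accumulation_period
    old_choice = "FromStart" if defaulted in (0, None) else "FromLastStep"  # -> "FromLastStep" (wrong)

    # New order: select while the value is still None, default afterwards
    new_choice = "FromStart" if data_accumulation_period in (0, None) else "FromLastStep"  # -> "FromStart"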

--- /dev/null
+++ b/anemoi/datasets/create/sources/planetary_computer.py
@@ -0,0 +1,44 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+from . import source_registry
+from .xarray import XarraySourceBase
+
+
+@source_registry.register("planetary_computer")
+class PlanetaryComputerSource(XarraySourceBase):
+    """An Xarray data source for the planetary_computer."""
+
+    emoji = "🪐"
+
+    def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict):
+
+        import planetary_computer
+        import pystac_client
+
+        self.data_catalog_id = data_catalog_id
+        self.flavour = kwargs.pop("flavour", None)
+        self.patch = kwargs.pop("patch", None)
+        self.options = kwargs.pop("options", {})
+
+        catalog = pystac_client.Client.open(
+            f"https://planetarycomputer.microsoft.com/api/stac/{version}/",
+            modifier=planetary_computer.sign_inplace,
+        )
+        collection = catalog.get_collection(self.data_catalog_id)
+
+        asset = collection.assets["zarr-abfs"]
+
+        if "xarray:storage_options" in asset.extra_fields:
+            self.options["storage_options"] = asset.extra_fields["xarray:storage_options"]
+
+        self.options.update(asset.extra_fields["xarray:open_kwargs"])
+
+        super().__init__(context, url=asset.href, *args, **kwargs)
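
The new source is STAC asset resolution plus the generic Xarray loader. A standalone sketch of the same resolution outside anemoi; the collection id "era5-pds" is an assumption, and any collection exposing a "zarr-abfs" asset should behave the same way:

    import planetary_computer
    import pystac_client
    import xarray as xr

    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1/",
        modifier=planetary_computer.sign_inplace,
    )
    asset = catalog.get_collection("era5-pds").assets["zarr-abfs"]

    ds = xr.open_dataset(
        asset.href,
        engine="zarr",  # typically also present in the asset's "xarray:open_kwargs"
        storage_options=asset.extra_fields.get("xarray:storage_options", {}),
    )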

--- a/anemoi/datasets/create/sources/xarray_support/__init__.py
+++ b/anemoi/datasets/create/sources/xarray_support/__init__.py
@@ -20,7 +20,6 @@ import xarray as xr
 from earthkit.data.core.fieldlist import MultiFieldList
 
 from anemoi.datasets.create.sources.patterns import iterate_patterns
-from anemoi.datasets.data.stores import name_to_zarr_store
 
 from ..legacy import legacy_source
 from .fieldlist import XarrayFieldList
@@ -89,37 +88,22 @@ def load_one(
         The loaded dataset.
     """
 
-    """
-    We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack
-    zarr/fsspec/s3fs/boto3 (?) seem to flags files as missing when they actually are not (maybe when S3 reports some sort of
-    connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs.
-    See https://github.com/pydata/xarray/issues/8842
-
-    We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`.
-    """
-
     if options is None:
         options = {}
 
     context.trace(emoji, dataset, options, kwargs)
 
-    if isinstance(dataset, str) and ".zarr" in dataset:
-        data = xr.open_zarr(name_to_zarr_store(dataset), **options)
-    elif "planetarycomputer" in dataset:
-        store = name_to_zarr_store(dataset)
-        if "store" in store:
-            data = xr.open_zarr(**store)
-        if "filename_or_obj" in store:
-            data = xr.open_dataset(**store)
-    else:
-        data = xr.open_dataset(dataset, **options)
+    if isinstance(dataset, str) and dataset.endswith(".zarr"):
+        # If the dataset is a zarr store, we need to use the zarr engine
+        options["engine"] = "zarr"
+
+    data = xr.open_dataset(dataset, **options)
 
     fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
 
     if len(dates) == 0:
         result = fs.sel(**kwargs)
     else:
-        print("dates", dates, kwargs)
         result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
 
     if len(result) == 0:
@@ -130,7 +114,7 @@ def load_one(
             a = ["valid_datetime", k.metadata("valid_datetime", default=None)]
             for n in kwargs.keys():
                 a.extend([n, k.metadata(n, default=None)])
-            print([str(x) for x in a])
+            LOG.warning(f"{[str(x) for x in a]}")
 
         if i > 16:
             break
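
With the bespoke store handling removed, any path ending in ".zarr" now goes through xarray's zarr backend. For a local or fsspec-style path, the two spellings below are interchangeable apart from chunking defaults (open_zarr defaults to chunks="auto", open_dataset to chunks=None); "data.zarr" is a placeholder:

    import xarray as xr

    ds_a = xr.open_dataset("data.zarr", engine="zarr")  # the new load_one code path
    ds_b = xr.open_zarr("data.zarr")                    # same store, lazily chunked by default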

--- a/anemoi/datasets/create/sources/xarray_support/coordinates.py
+++ b/anemoi/datasets/create/sources/xarray_support/coordinates.py
@@ -95,6 +95,7 @@ class Coordinate:
     is_member = False
     is_x = False
     is_y = False
+    is_point = False
 
     def __init__(self, variable: xr.DataArray) -> None:
         """Initialize the coordinate.
@@ -390,6 +391,13 @@ class EnsembleCoordinate(Coordinate):
         return value
 
 
+class PointCoordinate(Coordinate):
+    """Coordinate class for point data."""
+
+    is_point = True
+    mars_names = ("point",)
+
+
 class LongitudeCoordinate(Coordinate):
     """Coordinate class for longitude."""
 

--- a/anemoi/datasets/create/sources/xarray_support/field.py
+++ b/anemoi/datasets/create/sources/xarray_support/field.py
@@ -87,13 +87,10 @@ class XArrayField(Field):
             coordinate = owner.by_name[coord_name]
             self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value))
 
-        # print(values.ndim, values.shape, selection.dims)
         # By now, the only dimensions should be latitude and longitude
         self._shape = tuple(list(self.selection.shape)[-2:])
         if math.prod(self._shape) != math.prod(self.selection.shape):
-            print(self.selection.ndim, self.selection.shape)
-            print(self.selection)
-            raise ValueError("Invalid shape for selection")
+            raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}")
 
     @property
     def shape(self) -> Tuple[int, int]:

--- a/anemoi/datasets/create/sources/xarray_support/flavour.py
+++ b/anemoi/datasets/create/sources/xarray_support/flavour.py
@@ -26,6 +26,7 @@ from .coordinates import EnsembleCoordinate
 from .coordinates import LatitudeCoordinate
 from .coordinates import LevelCoordinate
 from .coordinates import LongitudeCoordinate
+from .coordinates import PointCoordinate
 from .coordinates import ScalarCoordinate
 from .coordinates import StepCoordinate
 from .coordinates import TimeCoordinate
@@ -134,6 +135,10 @@ class CoordinateGuesser(ABC):
 
         d: Optional[Coordinate] = None
 
+        d = self._is_point(coordinate, attributes)
+        if d is not None:
+            return d
+
         d = self._is_longitude(coordinate, attributes)
         if d is not None:
             return d
@@ -308,9 +313,9 @@ class CoordinateGuesser(ABC):
             return self._grid_cache[(x.name, y.name, dim_vars)]
 
         grid_mapping = variable.attrs.get("grid_mapping", None)
-        if grid_mapping is not None:
-            print(f"grid_mapping: {grid_mapping}")
-            print(self.ds[grid_mapping])
+        # if grid_mapping is not None:
+        #     print(f"grid_mapping: {grid_mapping}")
+        #     print(self.ds[grid_mapping])
 
         if grid_mapping is None:
             LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'")
@@ -392,6 +397,10 @@ class CoordinateGuesser(ABC):
         """
         pass
 
+    @abstractmethod
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        pass
+
     @abstractmethod
     def _is_latitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LatitudeCoordinate]:
         """Checks if the coordinate is a latitude.
@@ -550,6 +559,15 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         """
         super().__init__(ds)
 
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+            return PointCoordinate(c)
+
+        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+            return PointCoordinate(c)
+
+        return None
+
     def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LongitudeCoordinate]:
         """Checks if the coordinate is a longitude.
 
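
The names checked here (cell, station, poi, point) suggest the target is scattered or station data, where latitude and longitude hang off a single point-like dimension instead of being grid dimensions. A toy dataset of that shape, for illustration only:

    import numpy as np
    import xarray as xr

    n = 5
    ds = xr.Dataset(
        {"t2m": (("time", "station"), np.random.rand(2, n))},
        coords={
            "time": np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[ns]"),
            "station": np.arange(n),
            "latitude": ("station", np.linspace(-60.0, 60.0, n)),
            "longitude": ("station", np.linspace(0.0, 300.0, n)),
        },
    )
    # The "station" dimension is now classified as a point coordinate rather
    # than being mistaken for part of a regular grid.
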
@@ -750,6 +768,9 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         if attributes.standard_name == "air_pressure" and attributes.units == "hPa":
             return LevelCoordinate(c, "pl")
 
+        if attributes.long_name == "pressure" and attributes.units in ["hPa", "Pa"]:
+            return LevelCoordinate(c, "pl")
+
         if attributes.name == "level":
             return LevelCoordinate(c, "pl")
 
@@ -759,9 +780,6 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         if attributes.standard_name == "depth":
             return LevelCoordinate(c, "depth")
 
-        if attributes.name == "vertical" and attributes.units == "hPa":
-            return LevelCoordinate(c, "pl")
-
         return None
 
     def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[EnsembleCoordinate]:
@@ -1040,3 +1058,23 @@ class FlavourCoordinateGuesser(CoordinateGuesser):
             return EnsembleCoordinate(c)
 
         return None
+
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        """Checks if the coordinate is a point coordinate using the flavour rules.
+
+        Parameters
+        ----------
+        c : xr.DataArray
+            The coordinate to check.
+        attributes : CoordinateAttributes
+            The attributes of the coordinate.
+
+        Returns
+        -------
+        Optional[PointCoordinate]
+            The PointCoordinate if matched, else None.
+        """
+        if self._match(c, "point", attributes):
+            return PointCoordinate(c)
+
+        return None

--- a/anemoi/datasets/create/sources/xarray_support/patch.py
+++ b/anemoi/datasets/create/sources/xarray_support/patch.py
@@ -61,9 +61,50 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any:
     return ds
 
 
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+    """Rename variables in the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    renames : dict[str, str]
+        Mapping from old variable names to new variable names.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+    return ds.rename(renames)
+
+
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: List[str]) -> Any:
+    """Sort the coordinates of the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    sort_coordinates : List[str]
+        The coordinates to sort.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+
+    for name in sort_coordinates:
+        ds = ds.sortby(name)
+    return ds
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
+    "rename": patch_rename,
+    "sort_coordinates": patch_sort_coordinate,
 }
 
@@ -82,7 +123,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any:
     Any
         The patched dataset.
     """
-    for what, values in patch.items():
+
+    ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"]
+    for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
         if what not in PATCHES:
             raise ValueError(f"Unknown patch type {what!r}")

--- a/anemoi/datasets/create/sources/xarray_support/variable.py
+++ b/anemoi/datasets/create/sources/xarray_support/variable.py
@@ -82,8 +82,12 @@ class Variable:
 
         self.time = time
 
-        self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
-        self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
+        self.shape = tuple(
+            len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+        )
+        self.names = {
+            c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+        }
         self.by_name = {c.variable.name: c for c in coordinates}
 
         # We need that alias for the time dimension