anemoi-datasets 0.5.25__py3-none-any.whl → 0.5.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/grib-index.py +1 -1
  3. anemoi/datasets/create/filter.py +22 -24
  4. anemoi/datasets/create/input/step.py +2 -16
  5. anemoi/datasets/create/sources/planetary_computer.py +44 -0
  6. anemoi/datasets/create/sources/xarray_support/__init__.py +6 -22
  7. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -0
  8. anemoi/datasets/create/sources/xarray_support/field.py +1 -4
  9. anemoi/datasets/create/sources/xarray_support/flavour.py +44 -6
  10. anemoi/datasets/create/sources/xarray_support/patch.py +44 -1
  11. anemoi/datasets/create/sources/xarray_support/variable.py +6 -2
  12. anemoi/datasets/data/complement.py +44 -10
  13. anemoi/datasets/data/forwards.py +8 -2
  14. anemoi/datasets/data/stores.py +7 -56
  15. anemoi/datasets/grids.py +6 -3
  16. {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/METADATA +3 -2
  17. {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/RECORD +21 -40
  18. anemoi/datasets/create/filters/__init__.py +0 -33
  19. anemoi/datasets/create/filters/empty.py +0 -37
  20. anemoi/datasets/create/filters/legacy.py +0 -93
  21. anemoi/datasets/create/filters/noop.py +0 -37
  22. anemoi/datasets/create/filters/orog_to_z.py +0 -58
  23. anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +0 -83
  24. anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +0 -84
  25. anemoi/datasets/create/filters/rename.py +0 -205
  26. anemoi/datasets/create/filters/rotate_winds.py +0 -105
  27. anemoi/datasets/create/filters/single_level_dewpoint_to_relative_humidity.py +0 -78
  28. anemoi/datasets/create/filters/single_level_relative_humidity_to_dewpoint.py +0 -84
  29. anemoi/datasets/create/filters/single_level_relative_humidity_to_specific_humidity.py +0 -163
  30. anemoi/datasets/create/filters/single_level_specific_humidity_to_relative_humidity.py +0 -451
  31. anemoi/datasets/create/filters/speeddir_to_uv.py +0 -95
  32. anemoi/datasets/create/filters/sum.py +0 -68
  33. anemoi/datasets/create/filters/transform.py +0 -51
  34. anemoi/datasets/create/filters/unrotate_winds.py +0 -105
  35. anemoi/datasets/create/filters/uv_to_speeddir.py +0 -94
  36. anemoi/datasets/create/filters/wz_to_w.py +0 -98
  37. anemoi/datasets/create/testing.py +0 -76
  38. {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/WHEEL +0 -0
  39. {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/entry_points.txt +0 -0
  40. {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/licenses/LICENSE +0 -0
  41. {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/top_level.txt +0 -0
anemoi/datasets/_version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.5.25'
- __version_tuple__ = version_tuple = (0, 5, 25)
+ __version__ = version = '0.5.26'
+ __version_tuple__ = version_tuple = (0, 5, 26)
anemoi/datasets/commands/grib-index.py
@@ -81,7 +81,7 @@ class GribIndexCmd(Command):
          bool
              True if the path matches, False otherwise.
          """
-         return fnmatch.fnmatch(path, args.match)
+         return fnmatch.fnmatch(os.path.basename(path), args.match)

      from anemoi.datasets.create.sources.grib_index import GribIndex

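The fix above makes the match pattern apply to the file name only, rather than the full path. A small standard-library illustration (the path is made up):

    import fnmatch
    import os

    path = "/data/grib/2024/analysis.grib2"
    fnmatch.fnmatch(path, "analysis*")                     # False: pattern tested against the whole path
    fnmatch.fnmatch(os.path.basename(path), "analysis*")   # True: pattern tested against the file name only
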
anemoi/datasets/create/filter.py
@@ -7,44 +7,42 @@
  # granted to it by virtue of its status as an intergovernmental organisation
  # nor does it submit to any jurisdiction.

- from abc import ABC
- from abc import abstractmethod
  from typing import Any
+ from typing import Dict

  import earthkit.data as ekd


- class Filter(ABC):
-     """A base class for filters."""
+ class TransformFilter:
+     """Calls filters from anemoi.transform.filters

-     def __init__(self, context: Any, *args: Any, **kwargs: Any) -> None:
-         """Initialise the filter.
+     Parameters
+     ----------
+     context : Any
+         The context in which the filter is created.
+     name : str
+         The name of the filter.
+     config : Dict[str, Any]
+         The configuration for the filter.
+     """

-         Parameters
-         ----------
-         context : Any
-             The context in which the filter is created.
-         *args : tuple
-             Positional arguments.
-         **kwargs : dict
-             Keyword arguments.
-         """
+     def __init__(self, context: Any, name: str, config: Dict[str, Any]) -> None:
+         from anemoi.transform.filters import create_filter

-         self.context = context
+         self.name = name
+         self.transform_filter = create_filter(context, config)

-     @abstractmethod
-     def execute(self, data: ekd.FieldList) -> ekd.FieldList:
-         """Execute the filter.
+     def execute(self, input: ekd.FieldList) -> ekd.FieldList:
+         """Execute the transformation filter.

          Parameters
          ----------
-         data : ekd.FieldList
-             The input data.
+         input : ekd.FieldList
+             The input data to be transformed.

          Returns
          -------
          ekd.FieldList
-             The output data.
+             The transformed data.
          """
-
-         pass
+         return self.transform_filter.forward(input)
anemoi/datasets/create/input/step.py
@@ -8,7 +8,6 @@
  # nor does it submit to any jurisdiction.

  import logging
- import warnings
  from copy import deepcopy
  from typing import Any
  from typing import Dict
@@ -165,24 +164,11 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li
      if cls is not None:
          return cls(context, action_path, previous_step, *args, **kwargs)

-     # Try filters from datasets filter registry
+     # Try filters from transform filter registry
      from anemoi.transform.filters import filter_registry as transform_filter_registry

-     from ..filters import create_filter as create_datasets_filter
-     from ..filters import filter_registry as datasets_filter_registry
-
-     if datasets_filter_registry.is_registered(key):
-
-         if transform_filter_registry.is_registered(key):
-             warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries")
-
-         filter = create_datasets_filter(None, config)
-         return FunctionStepAction(context, action_path + [key], previous_step, key, filter)
-
-     # Use filters from transform registry
-
      if transform_filter_registry.is_registered(key):
-         from ..filters.transform import TransformFilter
+         from ..filter import TransformFilter

          return FunctionStepAction(
              context, action_path + [key], previous_step, key, TransformFilter(context, key, config)
anemoi/datasets/create/sources/planetary_computer.py (new file)
@@ -0,0 +1,44 @@
+ # (C) Copyright 2024 Anemoi contributors.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+
+
+ from . import source_registry
+ from .xarray import XarraySourceBase
+
+
+ @source_registry.register("planetary_computer")
+ class PlanetaryComputerSource(XarraySourceBase):
+     """An Xarray data source for the planetary_computer."""
+
+     emoji = "🪐"
+
+     def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict):
+
+         import planetary_computer
+         import pystac_client
+
+         self.data_catalog_id = data_catalog_id
+         self.flavour = kwargs.pop("flavour", None)
+         self.patch = kwargs.pop("patch", None)
+         self.options = kwargs.pop("options", {})
+
+         catalog = pystac_client.Client.open(
+             f"https://planetarycomputer.microsoft.com/api/stac/{version}/",
+             modifier=planetary_computer.sign_inplace,
+         )
+         collection = catalog.get_collection(self.data_catalog_id)
+
+         asset = collection.assets["zarr-abfs"]
+
+         if "xarray:storage_options" in asset.extra_fields:
+             self.options["storage_options"] = asset.extra_fields["xarray:storage_options"]
+
+         self.options.update(asset.extra_fields["xarray:open_kwargs"])
+
+         super().__init__(context, url=asset.href, *args, **kwargs)
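This new source replaces the PlanetaryComputerStore removed from anemoi/datasets/data/stores.py further down: the STAC asset is signed and its URL plus open options are handed to the generic xarray source. The underlying access pattern can be sketched outside anemoi-datasets roughly as follows, assuming pystac-client, planetary-computer and a zarr-capable xarray stack are installed; the collection id is illustrative and must be one that exposes a collection-level "zarr-abfs" asset:

    import planetary_computer
    import pystac_client
    import xarray as xr

    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1/",
        modifier=planetary_computer.sign_inplace,  # signs asset URLs in place
    )
    asset = catalog.get_collection("daymet-daily-na").assets["zarr-abfs"]

    options = dict(asset.extra_fields["xarray:open_kwargs"])
    if "xarray:storage_options" in asset.extra_fields:
        options["storage_options"] = asset.extra_fields["xarray:storage_options"]

    ds = xr.open_dataset(asset.href, **options)
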
anemoi/datasets/create/sources/xarray_support/__init__.py
@@ -20,7 +20,6 @@ import xarray as xr
  from earthkit.data.core.fieldlist import MultiFieldList

  from anemoi.datasets.create.sources.patterns import iterate_patterns
- from anemoi.datasets.data.stores import name_to_zarr_store

  from ..legacy import legacy_source
  from .fieldlist import XarrayFieldList
@@ -89,37 +88,22 @@ def load_one(
          The loaded dataset.
      """

-     """
-     We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack
-     zarr/fsspec/s3fs/boto3 (?) seem to flags files as missing when they actually are not (maybe when S3 reports some sort of
-     connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs.
-     See https://github.com/pydata/xarray/issues/8842
-
-     We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`.
-     """
-
      if options is None:
          options = {}

      context.trace(emoji, dataset, options, kwargs)

-     if isinstance(dataset, str) and ".zarr" in dataset:
-         data = xr.open_zarr(name_to_zarr_store(dataset), **options)
-     elif "planetarycomputer" in dataset:
-         store = name_to_zarr_store(dataset)
-         if "store" in store:
-             data = xr.open_zarr(**store)
-         if "filename_or_obj" in store:
-             data = xr.open_dataset(**store)
-     else:
-         data = xr.open_dataset(dataset, **options)
+     if isinstance(dataset, str) and dataset.endswith(".zarr"):
+         # If the dataset is a zarr store, we need to use the zarr engine
+         options["engine"] = "zarr"
+
+     data = xr.open_dataset(dataset, **options)

      fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)

      if len(dates) == 0:
          result = fs.sel(**kwargs)
      else:
-         print("dates", dates, kwargs)
          result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])

      if len(result) == 0:
@@ -130,7 +114,7 @@ def load_one(
              a = ["valid_datetime", k.metadata("valid_datetime", default=None)]
              for n in kwargs.keys():
                  a.extend([n, k.metadata(n, default=None)])
-             print([str(x) for x in a])
+             LOG.warning(f"{[str(x) for x in a]}")

              if i > 16:
                  break
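With the store-specific branches gone (including the planetarycomputer one), load_one() always goes through xr.open_dataset and simply forces the zarr engine for ".zarr" paths. For reference, the two calls below are roughly equivalent ways of opening a zarr store with xarray (the path is made up):

    import xarray as xr

    ds_a = xr.open_zarr("example.zarr")                    # dedicated helper used by the old code path
    ds_b = xr.open_dataset("example.zarr", engine="zarr")  # generic entry point used by the new code path
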
anemoi/datasets/create/sources/xarray_support/coordinates.py
@@ -95,6 +95,7 @@ class Coordinate:
      is_member = False
      is_x = False
      is_y = False
+     is_point = False

      def __init__(self, variable: xr.DataArray) -> None:
          """Initialize the coordinate.
@@ -390,6 +391,13 @@ class EnsembleCoordinate(Coordinate):
          return value


+ class PointCoordinate(Coordinate):
+     """Coordinate class for point data."""
+
+     is_point = True
+     mars_names = ("point",)
+
+
  class LongitudeCoordinate(Coordinate):
      """Coordinate class for longitude."""
anemoi/datasets/create/sources/xarray_support/field.py
@@ -87,13 +87,10 @@ class XArrayField(Field):
              coordinate = owner.by_name[coord_name]
              self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value))

-         # print(values.ndim, values.shape, selection.dims)
          # By now, the only dimensions should be latitude and longitude
          self._shape = tuple(list(self.selection.shape)[-2:])
          if math.prod(self._shape) != math.prod(self.selection.shape):
-             print(self.selection.ndim, self.selection.shape)
-             print(self.selection)
-             raise ValueError("Invalid shape for selection")
+             raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}")

      @property
      def shape(self) -> Tuple[int, int]:
anemoi/datasets/create/sources/xarray_support/flavour.py
@@ -26,6 +26,7 @@ from .coordinates import EnsembleCoordinate
  from .coordinates import LatitudeCoordinate
  from .coordinates import LevelCoordinate
  from .coordinates import LongitudeCoordinate
+ from .coordinates import PointCoordinate
  from .coordinates import ScalarCoordinate
  from .coordinates import StepCoordinate
  from .coordinates import TimeCoordinate
@@ -134,6 +135,10 @@ class CoordinateGuesser(ABC):

      d: Optional[Coordinate] = None

+     d = self._is_point(coordinate, attributes)
+     if d is not None:
+         return d
+
      d = self._is_longitude(coordinate, attributes)
      if d is not None:
          return d
@@ -308,9 +313,9 @@ class CoordinateGuesser(ABC):
          return self._grid_cache[(x.name, y.name, dim_vars)]

      grid_mapping = variable.attrs.get("grid_mapping", None)
-     if grid_mapping is not None:
-         print(f"grid_mapping: {grid_mapping}")
-         print(self.ds[grid_mapping])
+     # if grid_mapping is not None:
+     #     print(f"grid_mapping: {grid_mapping}")
+     #     print(self.ds[grid_mapping])

      if grid_mapping is None:
          LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'")
@@ -392,6 +397,10 @@ class CoordinateGuesser(ABC):
          """
          pass

+     @abstractmethod
+     def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+         pass
+
      @abstractmethod
      def _is_latitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LatitudeCoordinate]:
          """Checks if the coordinate is a latitude.
@@ -550,6 +559,15 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
          """
          super().__init__(ds)

+     def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+         if attributes.standard_name in ["cell", "station", "poi", "point"]:
+             return PointCoordinate(c)
+
+         if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+             return PointCoordinate(c)
+
+         return None
+
      def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LongitudeCoordinate]:
          """Checks if the coordinate is a longitude.

@@ -750,6 +768,9 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
          if attributes.standard_name == "air_pressure" and attributes.units == "hPa":
              return LevelCoordinate(c, "pl")

+         if attributes.long_name == "pressure" and attributes.units in ["hPa", "Pa"]:
+             return LevelCoordinate(c, "pl")
+
          if attributes.name == "level":
              return LevelCoordinate(c, "pl")

@@ -759,9 +780,6 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
          if attributes.standard_name == "depth":
              return LevelCoordinate(c, "depth")

-         if attributes.name == "vertical" and attributes.units == "hPa":
-             return LevelCoordinate(c, "pl")
-
          return None

      def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[EnsembleCoordinate]:
@@ -1040,3 +1058,23 @@ class FlavourCoordinateGuesser(CoordinateGuesser):
              return EnsembleCoordinate(c)

          return None
+
+     def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+         """Checks if the coordinate is a point coordinate using the flavour rules.
+
+         Parameters
+         ----------
+         c : xr.DataArray
+             The coordinate to check.
+         attributes : CoordinateAttributes
+             The attributes of the coordinate.
+
+         Returns
+         -------
+         Optional[PointCoordinate]
+             The PointCoordinate if matched, else None.
+         """
+         if self._match(c, "point", attributes):
+             return PointCoordinate(c)
+
+         return None
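The PointCoordinate added above targets station-style, non-gridded data: a single point-like dimension (named e.g. "cell", "station", "poi" or "point") carries the locations, with latitude and longitude given per point rather than as a 2D grid. An illustrative xarray dataset of the kind the default guesser would now classify this way (names and values are made up):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"t2m": (("time", "station"), np.random.rand(2, 5))},
        coords={
            "time": np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[ns]"),
            "station": np.arange(5),  # point dimension, matched by name
            "latitude": ("station", np.linspace(-10.0, 10.0, 5)),
            "longitude": ("station", np.linspace(0.0, 20.0, 5)),
        },
    )
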
anemoi/datasets/create/sources/xarray_support/patch.py
@@ -61,9 +61,50 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any:
      return ds


+ def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+     """Rename variables in the dataset.
+
+     Parameters
+     ----------
+     ds : xr.Dataset
+         The dataset to patch.
+     renames : dict[str, str]
+         Mapping from old variable names to new variable names.
+
+     Returns
+     -------
+     Any
+         The patched dataset.
+     """
+     return ds.rename(renames)
+
+
+ def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: List[str]) -> Any:
+     """Sort the coordinates of the dataset.
+
+     Parameters
+     ----------
+     ds : xr.Dataset
+         The dataset to patch.
+     sort_coordinates : List[str]
+         The coordinates to sort.
+
+     Returns
+     -------
+     Any
+         The patched dataset.
+     """
+
+     for name in sort_coordinates:
+         ds = ds.sortby(name)
+     return ds
+
+
  PATCHES = {
      "attributes": patch_attributes,
      "coordinates": patch_coordinates,
+     "rename": patch_rename,
+     "sort_coordinates": patch_sort_coordinate,
  }


@@ -82,7 +123,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any:
      Any
          The patched dataset.
      """
-     for what, values in patch.items():
+
+     ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"]
+     for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
          if what not in PATCHES:
              raise ValueError(f"Unknown patch type {what!r}")

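The two new patch types plug into the same PATCHES mapping as "attributes" and "coordinates", and patch_dataset() now applies the sections in a fixed order (coordinates, attributes, rename, sort_coordinates) instead of relying on dictionary order. A hedged usage sketch, with made-up variable and coordinate names and assuming ds is an existing xarray.Dataset:

    from anemoi.datasets.create.sources.xarray_support.patch import patch_dataset

    patch = {
        "rename": {"2t": "t2m"},           # old name -> new name, applied via ds.rename()
        "sort_coordinates": ["latitude"],  # each listed coordinate is passed to ds.sortby()
    }
    ds = patch_dataset(ds, patch)
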
anemoi/datasets/create/sources/xarray_support/variable.py
@@ -82,8 +82,12 @@ class Variable:

          self.time = time

-         self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
-         self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
+         self.shape = tuple(
+             len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+         )
+         self.names = {
+             c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+         }
          self.by_name = {c.variable.name: c for c in coordinates}

          # We need that alias for the time dimension
anemoi/datasets/data/complement.py
@@ -7,7 +7,6 @@
  # granted to it by virtue of its status as an intergovernmental organisation
  # nor does it submit to any jurisdiction.

-
  import datetime
  import logging
  from abc import abstractmethod
@@ -19,6 +18,7 @@ from typing import Optional
  from typing import Set
  from typing import Tuple

+ import numpy as np
  from numpy.typing import NDArray

  from ..grids import nearest_grid_points
@@ -85,6 +85,7 @@ class Complement(Combined):
          for v in self._source.variables:
              if v not in self._target.variables:
                  self._variables.append(v)
+         LOG.info(f"The following variables will be complemented: {self._variables}")

          if not self._variables:
              raise ValueError("Augment: no missing variables")
@@ -96,9 +97,11 @@ class Complement(Combined):

      @property
      def statistics(self) -> Dict[str, NDArray[Any]]:
-         """Returns the statistics of the complemented dataset."""
-         index = [self._source.name_to_index[v] for v in self._variables]
-         return {k: v[index] for k, v in self._source.statistics.items()}
+         datasets = [self._source, self._target]
+         return {
+             k: [d.statistics[k][d.name_to_index[i]] for d in datasets for i in d.variables if i in self.variables]
+             for k in datasets[0].statistics
+         }

      def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]:
          index = [self._source.name_to_index[v] for v in self._variables]
@@ -120,7 +123,11 @@ class Complement(Combined):
      @property
      def variables_metadata(self) -> Dict[str, Any]:
          """Returns the metadata of the variables to be added to the target dataset."""
-         return {k: v for k, v in self._source.variables_metadata.items() if k in self._variables}
+         # Merge the two dicts first
+         all_meta = {**self._source.variables_metadata, **self._target.variables_metadata}
+
+         # Filter to keep only desired variables
+         return {k: v for k, v in all_meta.items() if k in self._variables}

      def check_same_variables(self, d1: Dataset, d2: Dataset) -> None:
          """Checks if the variables in two datasets are the same.
@@ -231,7 +238,7 @@ class ComplementNone(Complement):
  class ComplementNearest(Complement):
      """A class to complement a target dataset with variables from a source dataset using nearest neighbor interpolation."""

-     def __init__(self, target: Any, source: Any, max_distance: float = None) -> None:
+     def __init__(self, target: Any, source: Any, max_distance: float = None, k: int = 1) -> None:
          """Initializes the ComplementNearest class.

          Parameters
@@ -242,17 +249,25 @@ class ComplementNearest(Complement):
              The source dataset.
          max_distance : float, optional
              The maximum distance for nearest neighbor interpolation, default is None.
+         k : int, optional
+             The number of k closest neighbors to consider for interpolation
          """
          super().__init__(target, source)

-         self._nearest_grid_points = nearest_grid_points(
+         self.k = k
+         self._distances, self._nearest_grid_points = nearest_grid_points(
              self._source.latitudes,
              self._source.longitudes,
              self._target.latitudes,
              self._target.longitudes,
              max_distance=max_distance,
+             k=k,
          )

+         if k == 1:
+             self._distances = np.expand_dims(self._distances, axis=1)
+             self._nearest_grid_points = np.expand_dims(self._nearest_grid_points, axis=1)
+
      def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
          """Checks the compatibility of two datasets for nearest neighbor interpolation.

@@ -285,7 +300,19 @@ class ComplementNearest(Complement):
          source_data = self._source[index[0], source_index, index[2], ...]
          target_data = source_data[..., self._nearest_grid_points]

-         result = target_data[..., index[3]]
+         epsilon = 1e-8  # prevent division by zero
+         weights = 1.0 / (self._distances + epsilon)
+         weights = weights.astype(target_data.dtype)
+         weights /= weights.sum(axis=1, keepdims=True)  # normalize
+
+         # Reshape weights to broadcast correctly
+         # Add leading singleton dimensions so it matches target_data shape
+         while weights.ndim < target_data.ndim:
+             weights = np.expand_dims(weights, axis=0)
+
+         # Compute weighted average along the last dimension
+         final_point = np.sum(target_data * weights, axis=-1)
+         result = final_point[..., index[3]]

          return apply_index_to_slices_changes(result, changes)

@@ -330,6 +357,13 @@ def complement_factory(args: Tuple, kwargs: dict) -> Dataset:
          "nearest": ComplementNearest,
      }[interpolation]

-     complement = Class(target=target, source=source)._subset(**kwargs)
+     if interpolation == "nearest":
+         k = kwargs.pop("k", "1")
+         complement = Class(target=target, source=source, k=k)._subset(**kwargs)
+
+     else:
+         complement = Class(target=target, source=source)._subset(**kwargs)
+
+     joined = _open_dataset([target, complement])

-     return _open_dataset([target, complement], reorder=source.variables)
+     return _open_dataset(joined, reorder=sorted(joined.variables))
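ComplementNearest now keeps the k nearest source points per target point and blends them with inverse-distance weights (k defaults to 1, in which case the extra axis is added explicitly so the same code path applies). A standalone NumPy sketch of that weighting, with made-up shapes:

    import numpy as np

    k = 4
    distances = np.random.rand(100, k)         # distance to each of the k nearest source points
    neighbour_values = np.random.rand(100, k)  # source values at those points

    eps = 1e-8                                 # avoid division by zero for coincident points
    weights = 1.0 / (distances + eps)
    weights /= weights.sum(axis=1, keepdims=True)  # normalise per target point

    interpolated = (neighbour_values * weights).sum(axis=-1)  # one value per target point

Presumably the new k option is passed alongside interpolation="nearest" when a complemented dataset is opened, but the exact open_dataset/recipe syntax is not part of this diff.
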
anemoi/datasets/data/forwards.py
@@ -330,8 +330,14 @@ class Combined(Forwards):
          ValueError
              If the grids are not the same.
          """
-         if (d1.latitudes != d2.latitudes).any() or (d1.longitudes != d2.longitudes).any():
-             raise ValueError(f"Incompatible grid ({d1} {d2})")
+
+         # note: not a proper implementation, should be handled
+         # in a more consolidated way ...
+         rtol = 1.0e-7
+         if not np.allclose(d1.latitudes, d2.latitudes, rtol=rtol) or not np.allclose(
+             d1.longitudes, d2.longitudes, rtol=rtol
+         ):
+             raise ValueError(f"Incompatible grid ({d1.longitudes} {d2.longitudes})")

      def check_same_shape(self, d1: Dataset, d2: Dataset) -> None:
          """Checks if the shapes of two datasets are the same.
anemoi/datasets/data/stores.py
@@ -107,51 +107,6 @@ class S3Store(ReadOnlyStore):
          return response["Body"].read()


- class PlanetaryComputerStore(ReadOnlyStore):
-     """We write our own Store to access catalogs on Planetary Computer,
-     as it requires some extra arguments to use xr.open_zarr.
-     """
-
-     def __init__(self, data_catalog_id: str) -> None:
-         """Initialize the PlanetaryComputerStore with a data catalog ID.
-
-         Parameters
-         ----------
-         data_catalog_id : str
-             The data catalog ID.
-         """
-         self.data_catalog_id = data_catalog_id
-
-         import planetary_computer
-         import pystac_client
-
-         catalog = pystac_client.Client.open(
-             "https://planetarycomputer.microsoft.com/api/stac/v1/",
-             modifier=planetary_computer.sign_inplace,
-         )
-         collection = catalog.get_collection(self.data_catalog_id)
-
-         asset = collection.assets["zarr-abfs"]
-
-         if "xarray:storage_options" in asset.extra_fields:
-             store = {
-                 "store": asset.href,
-                 "storage_options": asset.extra_fields["xarray:storage_options"],
-                 **asset.extra_fields["xarray:open_kwargs"],
-             }
-         else:
-             store = {
-                 "filename_or_obj": asset.href,
-                 **asset.extra_fields["xarray:open_kwargs"],
-             }
-
-         self.store = store
-
-     def __getitem__(self, key: str) -> bytes:
-         """Retrieve an item from the store."""
-         raise NotImplementedError()
-
-
  class DebugStore(ReadOnlyStore):
      """A store to debug the zarr loading."""

@@ -190,11 +145,11 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:

      if store.startswith("http://") or store.startswith("https://"):

-         parsed = urlparse(store)
-
          if store.endswith(".zip"):
              import multiurl

+             parsed = urlparse(store)
+
              # Zarr cannot handle zip files over HTTP
              tmpdir = tempfile.gettempdir()
              name = os.path.basename(parsed.path)
@@ -210,15 +165,7 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:
              os.rename(path + ".tmp", path)
              return name_to_zarr_store(path)

-         bits = parsed.netloc.split(".")
-         if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"):
-             s3_url = f"s3://{bits[0]}{parsed.path}"
-             store = S3Store(s3_url, region=bits[2])
-         elif store.startswith("https://planetarycomputer.microsoft.com/"):
-             data_catalog_id = store.rsplit("/", 1)[-1]
-             store = PlanetaryComputerStore(data_catalog_id).store
-         else:
-             store = HTTPStore(store)
+         return HTTPStore(store)

      return store

@@ -565,6 +512,10 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]:
      config = load_config()["datasets"]
      use_search_path_not_found = config.get("use_search_path_not_found", False)

+     if name.endswith(".zarr/"):
+         LOG.warning("Removing trailing slash from path: %s", name)
+         name = name[:-1]
+
      if name.endswith(".zarr") or name.endswith(".zip"):

          if os.path.exists(name):