roms-tools 3.1.2-py3-none-any.whl → 3.2.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (41)
  1. roms_tools/__init__.py +3 -0
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +75 -21
  8. roms_tools/setup/boundary_forcing.py +44 -19
  9. roms_tools/setup/cdr_forcing.py +122 -8
  10. roms_tools/setup/cdr_release.py +161 -8
  11. roms_tools/setup/datasets.py +626 -340
  12. roms_tools/setup/grid.py +138 -137
  13. roms_tools/setup/initial_conditions.py +113 -48
  14. roms_tools/setup/mask.py +63 -7
  15. roms_tools/setup/nesting.py +67 -42
  16. roms_tools/setup/river_forcing.py +45 -19
  17. roms_tools/setup/surface_forcing.py +4 -6
  18. roms_tools/setup/tides.py +1 -2
  19. roms_tools/setup/topography.py +4 -4
  20. roms_tools/setup/utils.py +134 -22
  21. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  22. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  23. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  24. roms_tools/tests/test_setup/test_boundary_forcing.py +54 -52
  25. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  26. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  27. roms_tools/tests/test_setup/test_datasets.py +392 -44
  28. roms_tools/tests/test_setup/test_grid.py +222 -115
  29. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  30. roms_tools/tests/test_setup/test_surface_forcing.py +2 -1
  31. roms_tools/tests/test_setup/test_utils.py +91 -1
  32. roms_tools/tests/test_setup/utils.py +71 -0
  33. roms_tools/tests/test_tiling/test_join.py +241 -0
  34. roms_tools/tests/test_utils.py +139 -17
  35. roms_tools/tiling/join.py +189 -0
  36. roms_tools/utils.py +131 -99
  37. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/METADATA +12 -2
  38. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/RECORD +41 -33
  39. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  40. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  41. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,18 @@
1
+ from __future__ import annotations
2
+
1
3
  import importlib.util
2
4
  import logging
3
- import time
5
+ import typing
4
6
  from collections import Counter, defaultdict
5
- from collections.abc import Callable
7
+ from collections.abc import Callable, Mapping
6
8
  from dataclasses import dataclass, field
7
9
  from datetime import datetime, timedelta
8
10
  from pathlib import Path
9
11
  from types import ModuleType
10
- from typing import ClassVar
12
+ from typing import Any, ClassVar, Literal, TypeAlias, cast
13
+
14
+ if typing.TYPE_CHECKING:
15
+ from roms_tools.setup.grid import Grid
11
16
 
12
17
  import numpy as np
13
18
  import xarray as xr
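Note: the new `from __future__ import annotations` plus `typing.TYPE_CHECKING` combination is what lets the `Grid` import above exist purely for type checkers without creating a runtime (and potentially circular) import. A minimal, self-contained sketch of the pattern, with a hypothetical module name standing in for `roms_tools.setup.grid`:

```python
from __future__ import annotations

import typing

if typing.TYPE_CHECKING:
    # Only evaluated by static type checkers; never imported at runtime,
    # so there is no circular-import risk. "mypkg.grid" is a hypothetical module.
    from mypkg.grid import Grid


def describe(grid: Grid) -> str:
    # With postponed evaluation, this annotation stays a string at runtime.
    return f"got a grid object: {type(grid).__name__}"
```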
@@ -21,15 +26,32 @@ from roms_tools.download import (
21
26
  )
22
27
  from roms_tools.setup.fill import LateralFill
23
28
  from roms_tools.setup.utils import (
29
+ Timed,
24
30
  assign_dates_to_climatology,
25
31
  convert_cftime_to_datetime,
26
32
  gc_dist,
33
+ get_target_coords,
27
34
  get_time_type,
28
35
  interpolate_cyclic_time,
29
36
  interpolate_from_climatology,
30
37
  one_dim_fill,
31
38
  )
32
- from roms_tools.utils import _get_pkg_error_msg, _has_gcsfs, _load_data
39
+ from roms_tools.utils import get_dask_chunks, get_pkg_error_msg, has_gcsfs, load_data
40
+
41
+ TConcatEndTypes = Literal["lower", "upper", "both"]
42
+ REPO_ROOT = Path(__file__).resolve().parents[2]
43
+ GLORYS_GLOBAL_GRID_PATH = (
44
+ REPO_ROOT / "roms_tools" / "data" / "grids" / "GLORYS_global_grid.nc"
45
+ )
46
+ DEFAULT_NR_BUFFER_POINTS = (
47
+ 20 # Default number of buffer points for subdomain selection.
48
+ )
49
+ # Balances performance and accuracy:
50
+ # - Too many points → more expensive computations
51
+ # - Too few points → potential boundary artifacts when lateral refill is performed
52
+ # See discussion: https://github.com/CWorthy-ocean/roms-tools/issues/153
53
+ # This default will be applied consistently across all datasets requiring lateral fill.
54
+ RawDataSource: TypeAlias = dict[str, str | Path | list[str | Path] | bool]
33
55
 
34
56
  # lat-lon datasets
35
57
 
@@ -47,7 +69,7 @@ class Dataset:
47
69
  Start time for selecting relevant data. If not provided, no time-based filtering is applied.
48
70
  end_time : Optional[datetime], optional
49
71
  End time for selecting relevant data. If not provided, the dataset selects the time entry
50
- closest to `start_time` within the range `[start_time, start_time + 24 hours]`.
72
+ closest to `start_time` within the range `[start_time, start_time + 24 hours)`.
51
73
  If `start_time` is also not provided, no time-based filtering is applied.
52
74
  dim_names: Dict[str, str], optional
53
75
  Dictionary specifying the names of dimensions in the dataset.
@@ -62,8 +84,19 @@ class Dataset:
62
84
  Indicates whether land values require lateral filling. If `True`, ocean values will be extended into land areas
63
85
  to replace NaNs or non-ocean values (such as atmospheric values in ERA5 data). If `False`, it is assumed that
64
86
  land values are already correctly assigned, and lateral filling will be skipped. Defaults to `True`.
65
- use_dask: bool
87
+ use_dask: bool, optional
66
88
  Indicates whether to use dask for chunking. If True, data is loaded with dask; if False, data is loaded eagerly. Defaults to False.
89
+ read_zarr: bool, optional
90
+ If True, use the zarr engine to read the dataset, and don't use mfdataset.
91
+ Defaults to False.
92
+ allow_flex_time: bool, optional
93
+ Controls how strictly the dataset selects a time entry when `end_time` is not provided (relevant for initial conditions):
94
+
95
+ - If False (default): requires an exact match to `start_time`. Raises a ValueError if no match exists.
96
+ - If True: allows a +24h search window after `start_time` and selects the closest available
97
+ time entry within that window. Raises a ValueError if none are found.
98
+
99
+ Only used when `end_time` is None. Has no effect otherwise.
67
100
  apply_post_processing: bool
68
101
  Indicates whether to post-process the dataset for further use. Defaults to True.
69
102
 
@@ -94,14 +127,15 @@ class Dataset:
94
127
  }
95
128
  )
96
129
  var_names: dict[str, str]
97
- opt_var_names: dict[str, str] | None = field(default_factory=dict)
98
- climatology: bool | None = False
130
+ opt_var_names: dict[str, str] = field(default_factory=dict)
131
+ climatology: bool = False
99
132
  needs_lateral_fill: bool | None = True
100
- use_dask: bool | None = False
133
+ use_dask: bool = False
134
+ read_zarr: bool = False
135
+ allow_flex_time: bool = False
101
136
  apply_post_processing: bool | None = True
102
- read_zarr: bool | None = False
103
- ds_loader_fn: Callable[[], xr.Dataset] | None = None
104
137
 
138
+ ds_loader_fn: Callable[[], xr.Dataset] | None = None
105
139
  is_global: bool = field(init=False, repr=False)
106
140
  ds: xr.Dataset = field(init=False, repr=False)
107
141
 
@@ -172,17 +206,17 @@ class Dataset:
172
206
  ValueError
173
207
  If a list of files is provided but self.dim_names["time"] is not available or use_dask=False.
174
208
  """
175
- ds = _load_data(
176
- self.filename,
177
- self.dim_names,
178
- self.use_dask or False,
179
- read_zarr=self.read_zarr or False,
209
+ ds = load_data(
210
+ filename=self.filename,
211
+ dim_names=self.dim_names,
212
+ use_dask=self.use_dask,
213
+ read_zarr=self.read_zarr,
180
214
  ds_loader_fn=self.ds_loader_fn,
181
215
  )
182
216
 
183
217
  return ds
184
218
 
185
- def clean_up(self, ds: xr.Dataset, **kwargs) -> xr.Dataset:
219
+ def clean_up(self, ds: xr.Dataset) -> xr.Dataset:
186
220
  """Dummy method to be overridden by child classes to clean up the dataset.
187
221
 
188
222
  This method is intended as a placeholder and should be implemented in subclasses
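Since `clean_up` is documented as a placeholder to be overridden, the narrowed `(self, ds)` signature is what subclass overrides now have to match. A hedged sketch of what such an override could look like; the class and dimension names below are made up for illustration and are not taken from roms-tools:

```python
import xarray as xr


class ExampleDataset:
    """Stand-in for a Dataset subclass; only the clean_up hook is shown."""

    def clean_up(self, ds: xr.Dataset) -> xr.Dataset:
        # Illustrative clean-up: normalize dimension names before validation.
        rename_map = {
            old: new
            for old, new in {"nav_lat": "latitude", "nav_lon": "longitude"}.items()
            if old in ds.dims
        }
        return ds.rename(rename_map)
```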
@@ -215,7 +249,7 @@ class Dataset:
215
249
  """
216
250
  _check_dataset(ds, self.dim_names, self.var_names)
217
251
 
218
- def select_relevant_fields(self, ds) -> xr.Dataset:
252
+ def select_relevant_fields(self, ds: xr.Dataset) -> xr.Dataset:
219
253
  """Selects and returns a subset of the dataset containing only the variables
220
254
  specified in `self.var_names`.
221
255
 
@@ -258,7 +292,7 @@ class Dataset:
258
292
  """
259
293
  return ds
260
294
 
261
- def select_relevant_times(self, ds) -> xr.Dataset:
295
+ def select_relevant_times(self, ds: xr.Dataset) -> xr.Dataset:
262
296
  """Select a subset of the dataset based on the specified time range.
263
297
 
264
298
  This method filters the dataset to include all records between `start_time` and `end_time`.
@@ -266,7 +300,7 @@ class Dataset:
266
300
  after `end_time` are included, even if they fall outside the strict time range.
267
301
 
268
302
  If no `end_time` is specified, the method will select the time range of
269
- [start_time, start_time + 24 hours] and return the closest time entry to `start_time` within that range.
303
+ [start_time, start_time + 24 hours) and return the closest time entry to `start_time` within that range.
270
304
 
271
305
  Parameters
272
306
  ----------
@@ -305,8 +339,17 @@ class Dataset:
305
339
  """
306
340
  time_dim = self.dim_names["time"]
307
341
 
342
+ # Ensure start_time is not None for type safety
343
+ if self.start_time is None:
344
+ raise ValueError("select_relevant_times called but start_time is None.")
345
+
308
346
  ds = _select_relevant_times(
309
- ds, time_dim, self.start_time, self.end_time, self.climatology
347
+ ds,
348
+ time_dim,
349
+ self.start_time,
350
+ self.end_time,
351
+ self.climatology,
352
+ self.allow_flex_time,
310
353
  )
311
354
 
312
355
  return ds
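The selection rule described here (everything strictly inside the range, plus the nearest record at or before `start_time` and at or after `end_time`) can be reproduced with plain xarray. A standalone sketch that mirrors the documented behavior rather than calling the private helper:

```python
import numpy as np
import xarray as xr

times = np.array(
    ["2011-12-20", "2012-01-01", "2012-01-05", "2012-01-10", "2012-01-15", "2012-02-01"],
    dtype="datetime64[ns]",
)
ds = xr.Dataset({"zeta": ("time", [0.0, 0.1, 0.2, 0.3, 0.4, 0.5])}, coords={"time": times})

start, end = np.datetime64("2012-01-04"), np.datetime64("2012-01-11")

# Nearest bracketing records outside the strict range.
closest_before = ds.time.where(ds.time <= start, drop=True).max()
closest_after = ds.time.where(ds.time >= end, drop=True).min()
within = (ds.time > start) & (ds.time < end)

selected = ds.sel(
    time=ds.time.where(
        within | (ds.time == closest_before) | (ds.time == closest_after), drop=True
    )
)
print(selected.time.values)
# 2012-01-01, 2012-01-05, 2012-01-10, 2012-01-15 (bracketing records kept)
```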
@@ -353,7 +396,7 @@ class Dataset:
353
396
 
354
397
  return ds
355
398
 
356
- def infer_horizontal_resolution(self, ds: xr.Dataset):
399
+ def infer_horizontal_resolution(self, ds: xr.Dataset) -> None:
357
400
  """Estimate and set the average horizontal resolution of a dataset based on
358
401
  latitude and longitude spacing.
359
402
 
@@ -381,7 +424,7 @@ class Dataset:
381
424
  # Set the computed resolution as an attribute
382
425
  self.resolution = resolution
383
426
 
384
- def compute_minimal_grid_spacing(self, ds: xr.Dataset):
427
+ def compute_minimal_grid_spacing(self, ds: xr.Dataset) -> float:
385
428
  """Compute the minimal grid spacing in a dataset based on latitude and longitude
386
429
  spacing, considering Earth's radius.
387
430
 
@@ -443,7 +486,12 @@ class Dataset:
443
486
 
444
487
  return is_global
445
488
 
446
- def concatenate_longitudes(self, ds, end="upper", verbose=False):
489
+ def concatenate_longitudes(
490
+ self,
491
+ ds: xr.Dataset,
492
+ end: TConcatEndTypes = "upper",
493
+ verbose: bool = False,
494
+ ) -> xr.Dataset:
447
495
  """Concatenates fields in dataset twice along the longitude dimension.
448
496
 
449
497
  Parameters
@@ -466,58 +514,12 @@ class Dataset:
466
514
  ds_concatenated : xr.Dataset
467
515
  The concatenated dataset.
468
516
  """
469
- if verbose:
470
- start_time = time.time()
471
-
472
- ds_concatenated = xr.Dataset()
473
-
474
- lon = ds[self.dim_names["longitude"]]
475
- if end == "lower":
476
- lon_minus360 = lon - 360
477
- lon_concatenated = xr.concat(
478
- [lon_minus360, lon], dim=self.dim_names["longitude"]
479
- )
480
-
481
- elif end == "upper":
482
- lon_plus360 = lon + 360
483
- lon_concatenated = xr.concat(
484
- [lon, lon_plus360], dim=self.dim_names["longitude"]
485
- )
486
-
487
- elif end == "both":
488
- lon_minus360 = lon - 360
489
- lon_plus360 = lon + 360
490
- lon_concatenated = xr.concat(
491
- [lon_minus360, lon, lon_plus360], dim=self.dim_names["longitude"]
492
- )
493
-
494
- for var in ds.data_vars:
495
- if self.dim_names["longitude"] in ds[var].dims:
496
- field = ds[var]
497
-
498
- if end == "both":
499
- field_concatenated = xr.concat(
500
- [field, field, field], dim=self.dim_names["longitude"]
501
- )
502
- else:
503
- field_concatenated = xr.concat(
504
- [field, field], dim=self.dim_names["longitude"]
505
- )
506
-
507
- if self.use_dask:
508
- field_concatenated = field_concatenated.chunk(
509
- {self.dim_names["longitude"]: -1}
510
- )
511
- field_concatenated[self.dim_names["longitude"]] = lon_concatenated
512
- ds_concatenated[var] = field_concatenated
513
- else:
514
- ds_concatenated[var] = ds[var]
515
-
516
- ds_concatenated[self.dim_names["longitude"]] = lon_concatenated
517
-
518
- if verbose:
519
- logging.info(
520
- f"Concatenating the data along the longitude dimension: {time.time() - start_time:.3f} seconds"
517
+ with Timed(
518
+ "=== Concatenating the data along the longitude dimension ===",
519
+ verbose=verbose,
520
+ ):
521
+ ds_concatenated = _concatenate_longitudes(
522
+ ds, self.dim_names, end, self.use_dask
521
523
  )
522
524
 
523
525
  return ds_concatenated
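The timing refactor replaces the manual `time.time()` bookkeeping with a `Timed` context manager imported from `roms_tools.setup.utils`. Its implementation is not part of this diff, so the following is only a plausible sketch of such a helper (written in lowercase to avoid implying it is the real class):

```python
import logging
import time
from contextlib import contextmanager


@contextmanager
def timed(message: str, verbose: bool = False):
    # Plausible sketch of a Timed-style helper; the real class in
    # roms_tools.setup.utils may have a different interface.
    start = time.perf_counter()
    try:
        yield
    finally:
        if verbose:
            logging.info("%s: %.3f seconds", message, time.perf_counter() - start)


logging.basicConfig(level=logging.INFO)
with timed("=== Concatenating the data along the longitude dimension ===", verbose=True):
    _ = sum(range(1_000_000))
```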
@@ -552,14 +554,16 @@ class Dataset:
552
554
  ds = self.ds.astype({var: "float64" for var in self.ds.data_vars})
553
555
  self.ds = ds
554
556
 
557
+ return None
558
+
555
559
  def choose_subdomain(
556
560
  self,
557
- target_coords,
558
- buffer_points=20,
559
- return_copy=False,
560
- return_coords_only=False,
561
- verbose=False,
562
- ):
561
+ target_coords: dict[str, Any],
562
+ buffer_points: int = DEFAULT_NR_BUFFER_POINTS,
563
+ return_copy: bool = False,
564
+ return_coords_only: bool = False,
565
+ verbose: bool = False,
566
+ ) -> xr.Dataset | Dataset | None:
563
567
  """Selects a subdomain from the xarray Dataset based on specified target
564
568
  coordinates, extending the selection by a defined buffer. Adjusts longitude
565
569
  ranges as necessary to accommodate the dataset's expected range and handles
@@ -596,94 +600,15 @@ class Dataset:
596
600
  ValueError
597
601
  If the selected latitude or longitude range does not intersect with the dataset.
598
602
  """
599
- lat_min = target_coords["lat"].min().values
600
- lat_max = target_coords["lat"].max().values
601
- lon_min = target_coords["lon"].min().values
602
- lon_max = target_coords["lon"].max().values
603
-
604
- margin = self.resolution * buffer_points
605
-
606
- # Select the subdomain in latitude direction (so that we have to concatenate fewer latitudes below if concatenation is necessary)
607
- subdomain = self.ds.sel(
608
- **{
609
- self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
610
- }
603
+ subdomain = choose_subdomain(
604
+ ds=self.ds,
605
+ dim_names=self.dim_names,
606
+ resolution=self.resolution,
607
+ is_global=self.is_global,
608
+ target_coords=target_coords,
609
+ buffer_points=buffer_points,
610
+ use_dask=self.use_dask,
611
611
  )
612
- lon = subdomain[self.dim_names["longitude"]]
613
-
614
- if self.is_global:
615
- concats = []
616
- # Concatenate only if necessary
617
- if lon_max + margin > lon.max():
618
- # See if shifting by +360 degrees helps
619
- if (lon_min - margin > (lon + 360).min()) and (
620
- lon_max + margin < (lon + 360).max()
621
- ):
622
- subdomain[self.dim_names["longitude"]] = lon + 360
623
- lon = subdomain[self.dim_names["longitude"]]
624
- else:
625
- concats.append("upper")
626
- if lon_min - margin < lon.min():
627
- # See if shifting by -360 degrees helps
628
- if (lon_min - margin > (lon - 360).min()) and (
629
- lon_max + margin < (lon - 360).max()
630
- ):
631
- subdomain[self.dim_names["longitude"]] = lon - 360
632
- lon = subdomain[self.dim_names["longitude"]]
633
- else:
634
- concats.append("lower")
635
-
636
- if concats:
637
- end = "both" if len(concats) == 2 else concats[0]
638
- subdomain = self.concatenate_longitudes(
639
- subdomain, end=end, verbose=False
640
- )
641
- lon = subdomain[self.dim_names["longitude"]]
642
-
643
- else:
644
- # Adjust longitude range if needed to match the expected range
645
- if not target_coords["straddle"]:
646
- if lon.min() < -180:
647
- if lon_max + margin > 0:
648
- lon_min -= 360
649
- lon_max -= 360
650
- elif lon.min() < 0:
651
- if lon_max + margin > 180:
652
- lon_min -= 360
653
- lon_max -= 360
654
-
655
- if target_coords["straddle"]:
656
- if lon.max() > 360:
657
- if lon_min - margin < 180:
658
- lon_min += 360
659
- lon_max += 360
660
- elif lon.max() > 180:
661
- if lon_min - margin < 0:
662
- lon_min += 360
663
- lon_max += 360
664
- # Select the subdomain in longitude direction
665
-
666
- subdomain = subdomain.sel(
667
- **{
668
- self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
669
- }
670
- )
671
-
672
- # Check if the selected subdomain has zero dimensions in latitude or longitude
673
- if subdomain[self.dim_names["latitude"]].size == 0:
674
- raise ValueError("Selected latitude range does not intersect with dataset.")
675
-
676
- if subdomain[self.dim_names["longitude"]].size == 0:
677
- raise ValueError(
678
- "Selected longitude range does not intersect with dataset."
679
- )
680
-
681
- # Adjust longitudes to expected range if needed
682
- lon = subdomain[self.dim_names["longitude"]]
683
- if target_coords["straddle"]:
684
- subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
685
- else:
686
- subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
687
612
 
688
613
  if return_coords_only:
689
614
  # Create and return a dataset with only latitudes and longitudes
@@ -696,6 +621,7 @@ class Dataset:
696
621
  return Dataset.from_ds(self, subdomain)
697
622
  else:
698
623
  self.ds = subdomain
624
+ return None
699
625
 
700
626
  def apply_lateral_fill(self):
701
627
  """Apply lateral fill to variables using the dataset's mask and grid dimensions.
@@ -715,10 +641,6 @@ class Dataset:
715
641
  point to the same variable in the dataset.
716
642
  """
717
643
  if self.needs_lateral_fill:
718
- logging.info(
719
- "Applying 2D horizontal fill to the source data before regridding."
720
- )
721
-
722
644
  lateral_fill = LateralFill(
723
645
  self.ds["mask"],
724
646
  [self.dim_names["latitude"], self.dim_names["longitude"]],
@@ -749,10 +671,6 @@ class Dataset:
749
671
  else:
750
672
  # Apply standard lateral fill for other variables
751
673
  self.ds[var_name] = lateral_fill.apply(self.ds[var_name])
752
- else:
753
- logging.info(
754
- "2D horizontal fill is skipped because source data already contains filled values."
755
- )
756
674
 
757
675
  def extrapolate_deepest_to_bottom(self):
758
676
  """Extrapolate deepest non-NaN values to fill bottom NaNs along the depth
@@ -769,7 +687,7 @@ class Dataset:
769
687
  )
770
688
 
771
689
  @classmethod
772
- def from_ds(cls, original_dataset: "Dataset", ds: xr.Dataset) -> "Dataset":
690
+ def from_ds(cls, original_dataset: Dataset, ds: xr.Dataset) -> Dataset:
773
691
  """Substitute the internal dataset of a Dataset object with a new xarray
774
692
  Dataset.
775
693
 
@@ -871,7 +789,7 @@ class TPXODataset(Dataset):
871
789
  ValueError
872
790
  If longitude or latitude values do not match the grid.
873
791
  """
874
- ds_grid = _load_data(self.grid_filename, self.dim_names, self.use_dask)
792
+ ds_grid = load_data(self.grid_filename, self.dim_names, self.use_dask)
875
793
 
876
794
  # Define mask and coordinate names based on location
877
795
  if self.location == "h":
@@ -902,21 +820,13 @@ class TPXODataset(Dataset):
902
820
 
903
821
  # Drop all dimensions except 'longitude' and 'latitude'
904
822
  dims_to_keep = {"longitude", "latitude"}
905
- dims_to_drop = [dim for dim in ds_grid.dims if dim not in dims_to_keep]
823
+ dims_to_drop: set[str] = set(ds_grid.dims) - dims_to_keep
906
824
  if dims_to_drop:
907
825
  ds_grid = ds_grid.isel({dim: 0 for dim in dims_to_drop})
908
826
 
909
827
  # Ensure correct dimension order
910
828
  ds_grid = ds_grid.transpose("latitude", "longitude")
911
829
 
912
- dims_to_keep = {"longitude", "latitude"}
913
- dims_to_drop = set(ds_grid.dims) - dims_to_keep
914
- ds_grid = (
915
- ds_grid.isel({dim: 0 for dim in dims_to_drop}) if dims_to_drop else ds_grid
916
- )
917
- # Bring dimensions in correct order
918
- ds_grid = ds_grid.transpose("latitude", "longitude")
919
-
920
830
  ds = ds.rename({"con": "nc"})
921
831
  ds = ds.assign_coords(
922
832
  {
@@ -1051,7 +961,7 @@ class GLORYSDataset(Dataset):
1051
961
  }
1052
962
  )
1053
963
 
1054
- climatology: bool | None = False
964
+ climatology: bool = False
1055
965
 
1056
966
  def post_process(self):
1057
967
  """Apply a mask to the dataset based on the 'zeta' variable, with 0 where 'zeta'
@@ -1067,19 +977,29 @@ class GLORYSDataset(Dataset):
1067
977
  None
1068
978
  The dataset is modified in-place by applying the mask to each variable.
1069
979
  """
1070
- mask = xr.where(
1071
- self.ds[self.var_names["zeta"]].isel({self.dim_names["time"]: 0}).isnull(),
1072
- 0,
1073
- 1,
1074
- )
1075
- mask_vel = xr.where(
1076
- self.ds[self.var_names["u"]]
1077
- .isel({self.dim_names["time"]: 0, self.dim_names["depth"]: 0})
1078
- .isnull(),
1079
- 0,
1080
- 1,
1081
- )
980
+ zeta = self.ds[self.var_names["zeta"]]
981
+ u = self.ds[self.var_names["u"]]
1082
982
 
983
+ # Select time=0 if time dimension exists, otherwise use data as-is
984
+ if self.dim_names["time"] in zeta.dims:
985
+ zeta_ref = zeta.isel({self.dim_names["time"]: 0})
986
+ else:
987
+ zeta_ref = zeta
988
+
989
+ if self.dim_names["time"] in u.dims:
990
+ u_ref = u.isel({self.dim_names["time"]: 0})
991
+ else:
992
+ u_ref = u
993
+
994
+ # Also handle depth for velocity
995
+ if self.dim_names["depth"] in u_ref.dims:
996
+ u_ref = u_ref.isel({self.dim_names["depth"]: 0})
997
+
998
+ # Create masks
999
+ mask = xr.where(zeta_ref.isnull(), 0, 1)
1000
+ mask_vel = xr.where(u_ref.isnull(), 0, 1)
1001
+
1002
+ # Save to dataset
1083
1003
  self.ds["mask"] = mask
1084
1004
  self.ds["mask_vel"] = mask_vel
1085
1005
 
@@ -1130,7 +1050,7 @@ class GLORYSDefaultDataset(GLORYSDataset):
1130
1050
 
1131
1051
  spec = importlib.util.find_spec(package_name)
1132
1052
  if not spec:
1133
- msg = _get_pkg_error_msg("cloud-based GLORYS data", package_name, "stream")
1053
+ msg = get_pkg_error_msg("cloud-based GLORYS data", package_name, "stream")
1134
1054
  raise RuntimeError(msg)
1135
1055
 
1136
1056
  try:
@@ -1151,14 +1071,36 @@ class GLORYSDefaultDataset(GLORYSDataset):
1151
1071
  The streaming dataset
1152
1072
  """
1153
1073
  copernicusmarine = self._load_copernicus()
1154
- return copernicusmarine.open_dataset(
1074
+
1075
+ # ds = copernicusmarine.download_functions.download_zarr.open_dataset_from_arco_series(
1076
+ # dataset_url="https://s3.waw3-1.cloudferro.com/mdl-arco-geo-025/arco/GLOBAL_MULTIYEAR_PHY_001_030/cmems_mod_glo_phy_my_0.083deg_P1D-m_202311/geoChunked.zarr",
1077
+ # variables=["thetao", "so", "uo", "vo", "zos"],
1078
+ # geographical_parameters=copernicusmarine.download_functions.subset_parameters.GeographicalParameters(),
1079
+ # temporal_parameters=copernicusmarine.download_functions.subset_parameters.TemporalParameters(
1080
+ # start_datetime=self.start_time, end_datetime=self.end_time
1081
+ # ),
1082
+ # depth_parameters=copernicusmarine.download_functions.subset_parameters.DepthParameters(),
1083
+ # coordinates_selection_method="outside",
1084
+ # optimum_dask_chunking={
1085
+ # "time": 1,
1086
+ # "depth": -1,
1087
+ # "latitude": -1,
1088
+ # "longitude": -1,
1089
+ # },
1090
+ # )
1091
+
1092
+ ds = copernicusmarine.open_dataset(
1155
1093
  self.dataset_name,
1156
1094
  start_datetime=self.start_time,
1157
1095
  end_datetime=self.end_time,
1158
1096
  service="arco-geo-series",
1159
- coordinates_selection_method="inside",
1160
- chunk_size_limit=2,
1097
+ coordinates_selection_method="outside",
1098
+ chunk_size_limit=-1,
1161
1099
  )
1100
+ chunks = get_dask_chunks(self.dim_names)
1101
+ ds = ds.chunk(chunks)
1102
+
1103
+ return ds
1162
1104
 
1163
1105
 
1164
1106
  @dataclass(kw_only=True)
@@ -1285,7 +1227,7 @@ class UnifiedBGCDataset(UnifiedDataset):
1285
1227
  }
1286
1228
  )
1287
1229
 
1288
- climatology: bool | None = True
1230
+ climatology: bool = True
1289
1231
 
1290
1232
 
1291
1233
  @dataclass(kw_only=True)
@@ -1307,7 +1249,7 @@ class UnifiedBGCSurfaceDataset(UnifiedDataset):
1307
1249
  }
1308
1250
  )
1309
1251
 
1310
- climatology: bool | None = True
1252
+ climatology: bool = True
1311
1253
 
1312
1254
 
1313
1255
  @dataclass(kw_only=True)
@@ -1422,9 +1364,9 @@ class CESMBGCDataset(CESMDataset):
1422
1364
  }
1423
1365
  )
1424
1366
 
1425
- climatology: bool | None = False
1367
+ climatology: bool = False
1426
1368
 
1427
- def post_process(self):
1369
+ def post_process(self) -> None:
1428
1370
  """
1429
1371
  Processes and converts CESM data values as follows:
1430
1372
  - Convert depth values from cm to m.
@@ -1493,9 +1435,9 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
1493
1435
  }
1494
1436
  )
1495
1437
 
1496
- climatology: bool | None = False
1438
+ climatology: bool = False
1497
1439
 
1498
- def post_process(self):
1440
+ def post_process(self) -> None:
1499
1441
  """Perform post-processing on the dataset to remove specific variables.
1500
1442
 
1501
1443
  This method checks if the variable "z_t" exists in the dataset. If it does,
@@ -1542,9 +1484,9 @@ class ERA5Dataset(Dataset):
1542
1484
  }
1543
1485
  )
1544
1486
 
1545
- climatology: bool | None = False
1487
+ climatology: bool = False
1546
1488
 
1547
- def post_process(self):
1489
+ def post_process(self) -> None:
1548
1490
  """
1549
1491
  Processes and converts ERA5 data values as follows:
1550
1492
  - Convert radiation values from J/m^2 to W/m^2.
@@ -1632,10 +1574,10 @@ class ERA5ARCODataset(ERA5Dataset):
1632
1574
  }
1633
1575
  )
1634
1576
 
1635
- def __post_init__(self):
1577
+ def __post_init__(self) -> None:
1636
1578
  self.read_zarr = True
1637
- if not _has_gcsfs():
1638
- msg = _get_pkg_error_msg("cloud-based ERA5 data", "gcsfs", "stream")
1579
+ if not has_gcsfs():
1580
+ msg = get_pkg_error_msg("cloud-based ERA5 data", "gcsfs", "stream")
1639
1581
  raise RuntimeError(msg)
1640
1582
 
1641
1583
  super().__post_init__()
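Both streaming datasets (`GLORYSDefaultDataset` with copernicusmarine, `ERA5ARCODataset` with gcsfs) now guard their optional dependency with `get_pkg_error_msg` and raise a `RuntimeError`. The exact message built by that helper is not shown in this diff; the sketch below only illustrates the `importlib.util.find_spec` guard pattern with stand-in wording:

```python
import importlib.util


def require_package(feature: str, package: str, action: str = "stream") -> None:
    # Stand-in for the get_pkg_error_msg + RuntimeError pattern above;
    # the real error wording in roms-tools may differ.
    if importlib.util.find_spec(package) is None:
        raise RuntimeError(
            f"The '{package}' package is required to {action} {feature}. "
            f"Install it (e.g. `pip install {package}`) and try again."
        )


try:
    require_package("cloud-based ERA5 data", "gcsfs")
except RuntimeError as err:
    print(err)
```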
@@ -1664,9 +1606,9 @@ class ERA5Correction(Dataset):
1664
1606
  "time": "time",
1665
1607
  }
1666
1608
  )
1667
- climatology: bool | None = True
1609
+ climatology: bool = True
1668
1610
 
1669
- def __post_init__(self):
1611
+ def __post_init__(self) -> None:
1670
1612
  if not self.climatology:
1671
1613
  raise NotImplementedError(
1672
1614
  "Correction data must be a climatology. Set climatology to True."
@@ -1674,32 +1616,31 @@ class ERA5Correction(Dataset):
1674
1616
 
1675
1617
  super().__post_init__()
1676
1618
 
1677
- def choose_subdomain(self, target_coords, straddle: bool):
1678
- """Converts longitude values in the dataset if necessary and selects a subdomain
1679
- based on the specified coordinates.
1619
+ def match_subdomain(self, target_coords: dict[str, Any]) -> None:
1620
+ """
1621
+ Selects a subdomain from the dataset matching the specified coordinates.
1680
1622
 
1681
- This method converts longitude values between different ranges if required and then extracts a subset of the
1682
- dataset according to the given coordinates. It updates the dataset in place to reflect the selected subdomain.
1623
+ This method extracts a subset of the dataset (`self.ds`) based on given latitude
1624
+ and longitude values. If the dataset spans the globe, it concatenates longitudes
1625
+ to ensure seamless wrapping.
1683
1626
 
1684
1627
  Parameters
1685
1628
  ----------
1686
- target_coords : dict
1687
- A dictionary specifying the target coordinates for selecting the subdomain. Keys should correspond to the
1688
- dimension names of the dataset (e.g., latitude and longitude), and values should be the desired ranges or
1689
- specific coordinate values.
1690
- straddle : bool
1691
- If True, assumes that target longitudes are in the range [-180, 180]. If False, assumes longitudes are in the
1692
- range [0, 360]. This parameter determines how longitude values are converted if necessary.
1629
+ target_coords : dict[str, Any]
1630
+ A dictionary containing the target latitude and longitude values to select.
1631
+ Expected keys: "lat" and "lon", each mapped to a DataArray of coordinates.
1693
1632
 
1694
1633
  Raises
1695
1634
  ------
1696
1635
  ValueError
1697
- If the specified subdomain does not fully contain the specified latitude or longitude values. This can occur
1698
- if the dataset does not cover the full range of provided coordinates.
1636
+ If the selected subdomain does not contain all specified latitude or
1637
+ longitude values.
1699
1638
 
1700
1639
  Notes
1701
1640
  -----
1702
- - The dataset (`self.ds`) is updated in place to reflect the chosen subdomain.
1641
+ - The dataset (`self.ds`) is updated in place.
1642
+ - Assumes latitude values in `target_coords["lat"]` are within dataset bounds.
1643
+ - For global datasets, longitude concatenation is applied unconditionally.
1703
1644
  """
1704
1645
  # Select the subdomain in latitude direction (so that we have to concatenate fewer latitudes below if concatenation is performed)
1705
1646
  subdomain = self.ds.sel({self.dim_names["latitude"]: target_coords["lat"]})
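`match_subdomain` relies on `ds.sel` with the target coordinate arrays, so every requested latitude/longitude value must already exist in the correction dataset (the ValueError case above). A tiny standalone illustration of that exact-label selection, using toy coordinates rather than real ERA5 correction data:

```python
import numpy as np
import xarray as xr

# Toy "correction" field on a coarse grid (illustrative values only).
ds = xr.Dataset(
    {"ssr_corr": (("latitude", "longitude"), np.arange(12.0).reshape(3, 4))},
    coords={"latitude": [-10.0, 0.0, 10.0], "longitude": [0.0, 90.0, 180.0, 270.0]},
)

# Target coordinates to select, analogous to target_coords["lat"] / ["lon"].
target_lat = xr.DataArray([0.0, 10.0], dims="points")
target_lon = xr.DataArray([90.0, 180.0], dims="points")

# Label-based selection requires every requested value to exist in the dataset;
# match_subdomain raises a ValueError when that is not the case.
subset = ds.sel(latitude=target_lat, longitude=target_lon)
print(subset.ssr_corr.values)  # [ 5. 10.]
```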
@@ -1813,7 +1754,7 @@ class RiverDataset:
1813
1754
  dim_names: dict[str, str]
1814
1755
  var_names: dict[str, str]
1815
1756
  opt_var_names: dict[str, str] | None = field(default_factory=dict)
1816
- climatology: bool | None = False
1757
+ climatology: bool = False
1817
1758
  ds: xr.Dataset = field(init=False, repr=False)
1818
1759
 
1819
1760
  def __post_init__(self):
@@ -1846,7 +1787,7 @@ class RiverDataset:
1846
1787
  ds : xr.Dataset
1847
1788
  The loaded xarray Dataset containing the forcing data.
1848
1789
  """
1849
- ds = _load_data(
1790
+ ds = load_data(
1850
1791
  self.filename, self.dim_names, use_dask=False, decode_times=False
1851
1792
  )
1852
1793
 
@@ -1998,7 +1939,7 @@ class RiverDataset:
1998
1939
  The dataset with rivers sorted by their volume in descending order.
1999
1940
  If the volume variable is not available, the original dataset is returned.
2000
1941
  """
2001
- if "vol" in self.opt_var_names:
1942
+ if self.opt_var_names is not None and "vol" in self.opt_var_names:
2002
1943
  volume_values = ds[self.opt_var_names["vol"]].values
2003
1944
  if isinstance(volume_values, np.ndarray):
2004
1945
  # Check if all volume values are the same
@@ -2158,7 +2099,7 @@ class DaiRiverDataset(RiverDataset):
2158
2099
  "vol": "vol_stn",
2159
2100
  }
2160
2101
  )
2161
- climatology: bool | None = False
2102
+ climatology: bool = False
2162
2103
 
2163
2104
  def add_time_info(self, ds: xr.Dataset) -> xr.Dataset:
2164
2105
  """Adds time information to the dataset based on the climatology flag and
@@ -2737,139 +2678,212 @@ def _check_dataset(
2737
2678
 
2738
2679
 
2739
2680
  def _select_relevant_times(
2740
- ds, time_dim, start_time, end_time=None, climatology=False
2681
+ ds: xr.Dataset,
2682
+ time_dim: str,
2683
+ start_time: datetime,
2684
+ end_time: datetime | None = None,
2685
+ climatology: bool = False,
2686
+ allow_flex_time: bool = False,
2741
2687
  ) -> xr.Dataset:
2742
- """Select a subset of the dataset based on the specified time range.
2688
+ """
2689
+ Select a subset of the dataset based on time constraints.
2743
2690
 
2744
- This method filters the dataset to include all records between `start_time` and `end_time`.
2745
- Additionally, it ensures that one record at or before `start_time` and one record at or
2746
- after `end_time` are included, even if they fall outside the strict time range.
2691
+ This function supports two main use cases:
2747
2692
 
2748
- If no `end_time` is specified, the method will select the time range of
2749
- [start_time, start_time + 24 hours] and return the closest time entry to `start_time` within that range.
2693
+ 1. **Time range selection (start_time + end_time provided):**
2694
+ - Returns all records strictly between `start_time` and `end_time`.
2695
+ - Ensures at least one record at or before `start_time` and one record at or
2696
+ after `end_time` are included, even if they fall outside the strict range.
2697
+
2698
+ 2. **Initial condition selection (start_time provided, end_time=None):**
2699
+ - Delegates to `_select_initial_time`, which reduces the dataset to exactly one
2700
+ time entry.
2701
+ - If `allow_flex_time=True`, a +24-hour buffer around `start_time` is allowed,
2702
+ and the closest timestamp is chosen.
2703
+ - If `allow_flex_time=False`, requires an exact timestamp match.
2704
+
2705
+ Additional behavior:
2706
+ - If `climatology=True`, the dataset must contain exactly 12 time steps. If valid,
2707
+ the climatology dataset is returned without further filtering.
2708
+ - If the dataset uses `cftime` datetime objects, these are converted to
2709
+ `np.datetime64` before filtering.
2750
2710
 
2751
2711
  Parameters
2752
2712
  ----------
2753
2713
  ds : xr.Dataset
2754
- The input dataset to be filtered. Must contain a time dimension.
2755
- time_dim: str
2756
- Name of time dimension.
2714
+ The dataset to filter. Must contain a valid time dimension.
2715
+ time_dim : str
2716
+ Name of the time dimension in `ds`.
2757
2717
  start_time : datetime
2758
- The start time for selecting relevant data.
2759
- end_time : Optional[datetime], optional
2760
- The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided.
2761
- climatology : bool
2762
- Indicates whether the dataset is climatological. Defaults to False.
2718
+ Start time for filtering.
2719
+ end_time : datetime or None
2720
+ End time for filtering. If `None`, the function assumes an initial condition
2721
+ use case and selects exactly one timestamp.
2722
+ climatology : bool, optional
2723
+ If True, requires exactly 12 time steps and bypasses normal filtering.
2724
+ Defaults to False.
2725
+ allow_flex_time : bool, optional
2726
+ Whether to allow a +24h search window after `start_time` when `end_time`
2727
+ is None. If False (default), requires an exact match.
2763
2728
 
2764
2729
  Returns
2765
2730
  -------
2766
2731
  xr.Dataset
2767
- A dataset filtered to the specified time range, including the closest entries
2768
- at or before `start_time` and at or after `end_time` if applicable.
2732
+ A filtered dataset containing only the selected time entries.
2769
2733
 
2770
2734
  Raises
2771
2735
  ------
2772
2736
  ValueError
2773
- If no matching times are found between `start_time` and `start_time + 24 hours`.
2737
+ - If `climatology=True` but the dataset does not contain exactly 12 time steps.
2738
+ - If `climatology=False` and the dataset contains integer time values.
2739
+ - If no valid records are found within the requested range or window.
2774
2740
 
2775
2741
  Warns
2776
2742
  -----
2777
2743
  UserWarning
2778
- If the dataset contains exactly 12 time steps but the climatology flag is not set.
2779
- This may indicate that the dataset represents climatology data.
2780
-
2781
- UserWarning
2782
- If no records at or before `start_time` or no records at or after `end_time` are found.
2783
-
2784
- UserWarning
2785
- If the dataset does not contain any time dimension or the time dimension is incorrectly named.
2744
+ - If no records exist at or before `start_time` or at or after `end_time`.
2745
+ - If the specified time dimension does not exist in the dataset.
2786
2746
 
2787
2747
  Notes
2788
2748
  -----
2789
- - If the `climatology` flag is set and `end_time` is not provided, the method will
2790
- interpolate initial conditions from climatology data.
2791
- - If the dataset uses `cftime` datetime objects, these will be converted to standard
2792
- `np.datetime64` objects before filtering.
2749
+ - For initial conditions (end_time=None), see `_select_initial_time` for details
2750
+ on strict vs. flexible selection behavior.
2751
+ - Logs warnings instead of failing hard when boundary records are missing, and
2752
+ defaults to using the earliest or latest available time in such cases.
2793
2753
  """
2794
- if time_dim in ds.variables:
2795
- if climatology:
2796
- if len(ds[time_dim]) != 12:
2797
- raise ValueError(
2798
- f"The dataset contains {len(ds[time_dim])} time steps, but the climatology flag is set to True, which requires exactly 12 time steps."
2799
- )
2800
- if not end_time:
2801
- # Convert from timedelta64[ns] to fractional days
2802
- ds["time"] = ds["time"] / np.timedelta64(1, "D")
2803
- # Interpolate from climatology for initial conditions
2804
- ds = interpolate_from_climatology(ds, time_dim, start_time)
2805
- else:
2806
- time_type = get_time_type(ds[time_dim])
2807
- if time_type == "int":
2808
- raise ValueError(
2809
- "The dataset contains integer time values, which are only supported when the climatology flag is set to True. However, your climatology flag is set to False."
2810
- )
2811
- if time_type == "cftime":
2812
- ds = ds.assign_coords(
2813
- {time_dim: convert_cftime_to_datetime(ds[time_dim])}
2814
- )
2815
- if end_time:
2816
- end_time = end_time
2817
-
2818
- # Identify records before or at start_time
2819
- before_start = ds[time_dim] <= np.datetime64(start_time)
2820
- if before_start.any():
2821
- closest_before_start = (
2822
- ds[time_dim].where(before_start, drop=True).max()
2823
- )
2824
- else:
2825
- logging.warning("No records found at or before the start_time.")
2826
- closest_before_start = ds[time_dim].min()
2754
+ if time_dim not in ds.variables:
2755
+ logging.warning(
2756
+ f"Dataset does not contain time dimension '{time_dim}'. "
2757
+ "Please check variable naming or dataset structure."
2758
+ )
2759
+ return ds
2827
2760
 
2828
- # Identify records after or at end_time
2829
- after_end = ds[time_dim] >= np.datetime64(end_time)
2830
- if after_end.any():
2831
- closest_after_end = ds[time_dim].where(after_end, drop=True).min()
2832
- else:
2833
- logging.warning("No records found at or after the end_time.")
2834
- closest_after_end = ds[time_dim].max()
2761
+ time_type = get_time_type(ds[time_dim])
2835
2762
 
2836
- # Select records within the time range and add the closest before/after
2837
- within_range = (ds[time_dim] > np.datetime64(start_time)) & (
2838
- ds[time_dim] < np.datetime64(end_time)
2839
- )
2840
- selected_times = ds[time_dim].where(
2841
- within_range
2842
- | (ds[time_dim] == closest_before_start)
2843
- | (ds[time_dim] == closest_after_end),
2844
- drop=True,
2845
- )
2846
- ds = ds.sel({time_dim: selected_times})
2847
- else:
2848
- # Look in time range [start_time, start_time + 24h]
2849
- end_time = start_time + timedelta(days=1)
2850
- times = (np.datetime64(start_time) <= ds[time_dim]) & (
2851
- ds[time_dim] < np.datetime64(end_time)
2852
- )
2853
- if np.all(~times):
2854
- raise ValueError(
2855
- f"The dataset does not contain any time entries between the specified start_time: {start_time} "
2856
- f"and {start_time + timedelta(hours=24)}. "
2857
- "Please ensure the dataset includes time entries for that range."
2858
- )
2763
+ if climatology:
2764
+ if len(ds[time_dim]) != 12:
2765
+ raise ValueError(
2766
+ f"The dataset contains {len(ds[time_dim])} time steps, but the climatology flag is set to True, which requires exactly 12 time steps."
2767
+ )
2768
+ else:
2769
+ if time_type == "int":
2770
+ raise ValueError(
2771
+ "The dataset contains integer time values, which are only supported when the climatology flag is set to True. However, your climatology flag is set to False."
2772
+ )
2773
+ if time_type == "cftime":
2774
+ ds = ds.assign_coords({time_dim: convert_cftime_to_datetime(ds[time_dim])})
2859
2775
 
2860
- ds = ds.where(times, drop=True)
2861
- if ds.sizes[time_dim] > 1:
2862
- # Pick the time closest to start_time
2863
- ds = ds.isel({time_dim: 0})
2864
- logging.info(
2865
- f"Selected time entry closest to the specified start_time ({start_time}) within the range [{start_time}, {start_time + timedelta(hours=24)}]: {ds[time_dim].values}"
2866
- )
2776
+ if not end_time:
2777
+ # Assume we are looking for exactly one time record for initial conditions
2778
+ return _select_initial_time(
2779
+ ds, time_dim, start_time, climatology, allow_flex_time
2780
+ )
2781
+
2782
+ if climatology:
2783
+ return ds
2784
+
2785
+ # Identify records before or at start_time
2786
+ before_start = ds[time_dim] <= np.datetime64(start_time)
2787
+ if before_start.any():
2788
+ closest_before_start = ds[time_dim].where(before_start, drop=True)[-1]
2789
+ else:
2790
+ logging.warning(f"No records found at or before the start_time: {start_time}.")
2791
+ closest_before_start = ds[time_dim][0]
2792
+
2793
+ # Identify records after or at end_time
2794
+ after_end = ds[time_dim] >= np.datetime64(end_time)
2795
+ if after_end.any():
2796
+ closest_after_end = ds[time_dim].where(after_end, drop=True).min()
2867
2797
  else:
2798
+ logging.warning(f"No records found at or after the end_time: {end_time}.")
2799
+ closest_after_end = ds[time_dim].max()
2800
+
2801
+ # Select records within the time range and add the closest before/after
2802
+ within_range = (ds[time_dim] > np.datetime64(start_time)) & (
2803
+ ds[time_dim] < np.datetime64(end_time)
2804
+ )
2805
+ selected_times = ds[time_dim].where(
2806
+ within_range
2807
+ | (ds[time_dim] == closest_before_start)
2808
+ | (ds[time_dim] == closest_after_end),
2809
+ drop=True,
2810
+ )
2811
+ ds = ds.sel({time_dim: selected_times})
2812
+
2813
+ return ds
2814
+
2815
+
2816
+ def _select_initial_time(
2817
+ ds: xr.Dataset,
2818
+ time_dim: str,
2819
+ ini_time: datetime,
2820
+ climatology: bool,
2821
+ allow_flex_time: bool = False,
2822
+ ) -> xr.Dataset:
2823
+ """Select exactly one initial time from dataset.
2824
+
2825
+ Parameters
2826
+ ----------
2827
+ ds : xr.Dataset
2828
+ The input dataset with a time dimension.
2829
+ time_dim : str
2830
+ Name of the time dimension.
2831
+ ini_time : datetime
2832
+ The desired initial time.
2833
+ allow_flex_time : bool
2834
+ - If True: allow a +24h window and pick the closest available timestamp.
2835
+ - If False (default): require an exact match, otherwise raise ValueError.
2836
+
2837
+ Returns
2838
+ -------
2839
+ xr.Dataset
2840
+ Dataset reduced to exactly one timestamp.
2841
+
2842
+ Raises
2843
+ ------
2844
+ ValueError
2845
+ If no matching time is found (when `allow_flex_time=False`), or no entries are
2846
+ available within the +24h window (when `allow_flex_time=True`).
2847
+ """
2848
+ if climatology:
2849
+ # Convert from timedelta64[ns] to fractional days
2850
+ ds["time"] = ds["time"] / np.timedelta64(1, "D")
2851
+ # Interpolate from climatology for initial conditions
2852
+ return interpolate_from_climatology(ds, time_dim, ini_time)
2853
+
2854
+ if allow_flex_time:
2855
+ # Look in time range [ini_time, ini_time + 24h)
2856
+ end_time = ini_time + timedelta(days=1)
2857
+ times = (np.datetime64(ini_time) <= ds[time_dim]) & (
2858
+ ds[time_dim] < np.datetime64(end_time)
2859
+ )
2860
+
2861
+ if np.all(~times):
2862
+ raise ValueError(
2863
+ f"No time entries found between {ini_time} and {end_time}."
2864
+ )
2865
+
2866
+ ds = ds.where(times, drop=True)
2867
+ if ds.sizes[time_dim] > 1:
2868
+ # Pick the time closest to start_time
2869
+ ds = ds.isel({time_dim: 0})
2870
+
2868
2871
  logging.warning(
2869
- "Dataset does not contain any time information. Please check if the time dimension "
2870
- "is correctly named or if the dataset includes time data."
2872
+ f"Selected time entry closest to the specified start_time in +24 hour range: {ds[time_dim].values}"
2871
2873
  )
2872
2874
 
2875
+ else:
2876
+ # Strict match required
2877
+ if not (ds[time_dim].values == np.datetime64(ini_time)).any():
2878
+ raise ValueError(
2879
+ f"No exact match found for initial time {ini_time}. Consider setting allow_flex_time to True."
2880
+ )
2881
+
2882
+ ds = ds.sel({time_dim: np.datetime64(ini_time)})
2883
+
2884
+ if time_dim not in ds.dims:
2885
+ ds = ds.expand_dims(time_dim)
2886
+
2873
2887
  return ds
2874
2888
 
2875
2889
 
@@ -2998,7 +3012,7 @@ def _deduplicate_river_names(
2998
3012
 
2999
3013
  # Count all names
3000
3014
  name_counts = Counter(names)
3001
- seen = defaultdict(int)
3015
+ seen: defaultdict[str, int] = defaultdict(int)
3002
3016
 
3003
3017
  unique_names = []
3004
3018
  for name in names:
@@ -3017,3 +3031,275 @@ def _deduplicate_river_names(
3017
3031
  ds[name_var] = updated_array
3018
3032
 
3019
3033
  return ds
3034
+
3035
+
3036
+ def _concatenate_longitudes(
3037
+ ds: xr.Dataset,
3038
+ dim_names: Mapping[str, str],
3039
+ end: TConcatEndTypes,
3040
+ use_dask: bool = False,
3041
+ ) -> xr.Dataset:
3042
+ """
3043
+ Concatenate longitude dimension to handle global grids that cross
3044
+ the 0/360-degree or -180/180-degree boundary.
3045
+
3046
+ Extends the longitude dimension either lower, upper, or both sides
3047
+ by +/- 360 degrees and duplicates the corresponding variables along
3048
+ that dimension.
3049
+
3050
+ Parameters
3051
+ ----------
3052
+ ds : xr.Dataset
3053
+ Input xarray Dataset to be concatenated.
3054
+ dim_names : Mapping[str, str]
3055
+ Dictionary or mapping containing dimension names. Must include "longitude".
3056
+ end : str
3057
+ Specifies which side(s) to extend:
3058
+ - "lower": extend by subtracting 360 degrees.
3059
+ - "upper": extend by adding 360 degrees.
3060
+ - "both": extend on both sides.
3061
+ use_dask : bool, default False
3062
+ If True, chunk the concatenated longitude dimension using Dask.
3063
+
3064
+ Returns
3065
+ -------
3066
+ xr.Dataset
3067
+ Dataset with longitude dimension extended and data variables duplicated.
3068
+
3069
+ Notes
3070
+ -----
3071
+ Only data variables containing the longitude dimension are concatenated;
3072
+ others are left unchanged.
3073
+ """
3074
+ ds_concat = xr.Dataset()
3075
+
3076
+ lon_name = dim_names["longitude"]
3077
+ lon = ds[lon_name]
3078
+
3079
+ match end:
3080
+ case "lower":
3081
+ lon_concat = xr.concat([lon - 360, lon], dim=lon_name)
3082
+ n_copies = 2
3083
+ case "upper":
3084
+ lon_concat = xr.concat([lon, lon + 360], dim=lon_name)
3085
+ n_copies = 2
3086
+ case "both":
3087
+ lon_concat = xr.concat([lon - 360, lon, lon + 360], dim=lon_name)
3088
+ n_copies = 3
3089
+ case _:
3090
+ raise ValueError(f"Invalid `end` value: {end}")
3091
+
3092
+ for var in ds.variables:
3093
+ if lon_name in ds[var].dims:
3094
+ field = ds[var]
3095
+ field_concat = xr.concat([field] * n_copies, dim=lon_name)
3096
+
3097
+ if use_dask:
3098
+ field_concat = field_concat.chunk({lon_name: -1})
3099
+
3100
+ ds_concat[var] = field_concat
3101
+ else:
3102
+ ds_concat[var] = ds[var]
3103
+
3104
+ ds_concat = ds_concat.assign_coords({lon_name: lon_concat.values})
3105
+
3106
+ return ds_concat
3107
+
3108
+
3109
+ def choose_subdomain(
3110
+ ds: xr.Dataset,
3111
+ dim_names: Mapping[str, str],
3112
+ resolution: float,
3113
+ is_global: bool,
3114
+ target_coords: Mapping[str, Any],
3115
+ buffer_points: int = 20,
3116
+ use_dask: bool = False,
3117
+ ) -> xr.Dataset:
3118
+ """
3119
+ Select a subdomain from an xarray Dataset based on target coordinates,
3120
+ with optional buffer points and global longitude handling.
3121
+
3122
+ Parameters
3123
+ ----------
3124
+ ds : xr.Dataset
3125
+ The full xarray Dataset to subset.
3126
+ dim_names : Mapping[str, str]
3127
+ Dictionary mapping logical dimension names to dataset dimension names.
3128
+ Example: {"latitude": "latitude", "longitude": "longitude"}.
3129
+ resolution : float
3130
+ Spatial resolution of the dataset, used to compute buffer margin.
3131
+ is_global : bool
3132
+ Whether the dataset covers global longitude (affects concatenation logic).
3133
+ target_coords : Mapping[str, Any]
3134
+ Dictionary containing target latitude and longitude coordinates.
3135
+ Expected keys: "lat", "lon", and "straddle" (boolean for crossing 180°).
3136
+ buffer_points : int, default 20
3137
+ Number of grid points to extend beyond the target coordinates.
3138
+ use_dask: bool, optional
3139
+ Indicates whether to use dask for chunking. If True, data is loaded with dask; if False, data is processed eagerly. Defaults to False.
3140
+
3141
+ Returns
3142
+ -------
3143
+ xr.Dataset
3144
+ Subset of the input Dataset covering the requested coordinates plus buffer.
3145
+
3146
+ Raises
3147
+ ------
3148
+ ValueError
3149
+ If the selected latitude or longitude range does not intersect the dataset.
3150
+ """
3151
+ lat_min = target_coords["lat"].min().values
3152
+ lat_max = target_coords["lat"].max().values
3153
+ lon_min = target_coords["lon"].min().values
3154
+ lon_max = target_coords["lon"].max().values
3155
+
3156
+ margin = resolution * buffer_points
3157
+
3158
+ # Select the subdomain in latitude direction (so that we have to concatenate fewer latitudes below if concatenation is necessary)
3159
+ subdomain = ds.sel(
3160
+ **{
3161
+ dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
3162
+ }
3163
+ )
3164
+ lon = subdomain[dim_names["longitude"]]
3165
+
3166
+ if is_global:
3167
+ concats = []
3168
+ # Concatenate only if necessary
3169
+ if lon_max + margin > lon.max():
3170
+ # See if shifting by +360 degrees helps
3171
+ if (lon_min - margin > (lon + 360).min()) and (
3172
+ lon_max + margin < (lon + 360).max()
3173
+ ):
3174
+ subdomain[dim_names["longitude"]] = lon + 360
3175
+ lon = subdomain[dim_names["longitude"]]
3176
+ else:
3177
+ concats.append("upper")
3178
+ if lon_min - margin < lon.min():
3179
+ # See if shifting by -360 degrees helps
3180
+ if (lon_min - margin > (lon - 360).min()) and (
3181
+ lon_max + margin < (lon - 360).max()
3182
+ ):
3183
+ subdomain[dim_names["longitude"]] = lon - 360
3184
+ lon = subdomain[dim_names["longitude"]]
3185
+ else:
3186
+ concats.append("lower")
3187
+
3188
+ if concats:
3189
+ end = "both" if len(concats) == 2 else concats[0]
3190
+ end = cast(TConcatEndTypes, end)
3191
+ subdomain = _concatenate_longitudes(
3192
+ subdomain, dim_names=dim_names, end=end, use_dask=use_dask
3193
+ )
3194
+ lon = subdomain[dim_names["longitude"]]
3195
+
3196
+ else:
3197
+ # Adjust longitude range if needed to match the expected range
3198
+ if not target_coords["straddle"]:
3199
+ if lon.min() < -180:
3200
+ if lon_max + margin > 0:
3201
+ lon_min -= 360
3202
+ lon_max -= 360
3203
+ elif lon.min() < 0:
3204
+ if lon_max + margin > 180:
3205
+ lon_min -= 360
3206
+ lon_max -= 360
3207
+
3208
+ if target_coords["straddle"]:
3209
+ if lon.max() > 360:
3210
+ if lon_min - margin < 180:
3211
+ lon_min += 360
3212
+ lon_max += 360
3213
+ elif lon.max() > 180:
3214
+ if lon_min - margin < 0:
3215
+ lon_min += 360
3216
+ lon_max += 360
3217
+ # Select the subdomain in longitude direction
3218
+ subdomain = subdomain.sel(
3219
+ **{
3220
+ dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
3221
+ }
3222
+ )
3223
+ # Check if the selected subdomain has zero dimensions in latitude or longitude
3224
+ if (
3225
+ dim_names["latitude"] not in subdomain
3226
+ or subdomain[dim_names["latitude"]].size == 0
3227
+ ):
3228
+ raise ValueError("Selected latitude range does not intersect with dataset.")
3229
+
3230
+ if (
3231
+ dim_names["longitude"] not in subdomain
3232
+ or subdomain[dim_names["longitude"]].size == 0
3233
+ ):
3234
+ raise ValueError("Selected longitude range does not intersect with dataset.")
3235
+
3236
+ # Adjust longitudes to expected range if needed
3237
+ lon = subdomain[dim_names["longitude"]]
3238
+ if target_coords["straddle"]:
3239
+ subdomain[dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
3240
+ else:
3241
+ subdomain[dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
3242
+
3243
+ return subdomain
3244
+
3245
+
3246
+ def get_glorys_bounds(
3247
+ grid: Grid,
3248
+ glorys_grid_path: Path | str | None = None,
3249
+ ) -> dict[str, float]:
3250
+ """
3251
+ Compute the latitude/longitude bounds of a GLORYS spatial subset
3252
+ that fully covers the given ROMS grid (with margin for regridding).
3253
+
3254
+ Parameters
3255
+ ----------
3256
+ grid : Grid
3257
+ The grid object.
3258
+ glorys_grid_path : str, optional
3259
+ Path to the GLORYS global grid file. If None, defaults to
3260
+ "<repo_root>/data/grids/GLORYS_global_grid.nc".
3261
+
3262
+ Returns
3263
+ -------
3264
+ dict[str, float]
3265
+ Dictionary containing the bounding box values:
3266
+
3267
+ - `"minimum_latitude"` : float
3268
+ - `"maximum_latitude"` : float
3269
+ - `"minimum_longitude"` : float
3270
+ - `"maximum_longitude"` : float
3271
+
3272
+ Notes
3273
+ -----
3274
+ - The resolution is estimated as the mean of latitude and longitude spacing.
3275
+ """
3276
+ if glorys_grid_path is None:
3277
+ glorys_grid_path = GLORYS_GLOBAL_GRID_PATH
3278
+
3279
+ ds = xr.open_dataset(glorys_grid_path)
3280
+
3281
+ # Estimate grid resolution (mean spacing in degrees)
3282
+ res_lat = ds.latitude.diff("latitude").mean()
3283
+ res_lon = ds.longitude.diff("longitude").mean()
3284
+ resolution = (res_lat + res_lon) / 2
3285
+
3286
+ # Extract target grid coordinates
3287
+ target_coords = get_target_coords(grid)
3288
+
3289
+ # Select subdomain with margin
3290
+ ds_subset = choose_subdomain(
3291
+ ds=ds,
3292
+ dim_names={"latitude": "latitude", "longitude": "longitude"},
3293
+ resolution=resolution,
3294
+ is_global=True,
3295
+ target_coords=target_coords,
3296
+ buffer_points=DEFAULT_NR_BUFFER_POINTS + 1,
3297
+ )
3298
+
3299
+ # Compute bounds
3300
+ return {
3301
+ "minimum_latitude": float(ds_subset.latitude.min()),
3302
+ "maximum_latitude": float(ds_subset.latitude.max()),
3303
+ "minimum_longitude": float(ds_subset.longitude.min()),
3304
+ "maximum_longitude": float(ds_subset.longitude.max()),
3305
+ }
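A usage sketch for the new helper: request GLORYS bounds that fully cover a ROMS grid before subsetting or downloading. The `Grid` keyword arguments below are assumed from the roms-tools documentation (domain size in km, center point, rotation) and should be checked against the current API; the printed bounds are whatever the helper computes, not hard-coded values:

```python
# Hedged usage sketch; verify the Grid constructor signature before relying on it.
from roms_tools import Grid
from roms_tools.setup.datasets import get_glorys_bounds

grid = Grid(
    nx=100, ny=100, size_x=1800, size_y=2400, center_lon=-21, center_lat=61, rot=20
)

bounds = get_glorys_bounds(grid)
print(bounds)
# {'minimum_latitude': ..., 'maximum_latitude': ...,
#  'minimum_longitude': ..., 'maximum_longitude': ...}
# These values can be passed to a GLORYS subsetting request so the downloaded
# region covers the ROMS domain plus the regridding buffer.
```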