roms-tools 3.1.1-py3-none-any.whl → 3.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,18 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib.util
1
4
  import logging
2
- import time
5
+ import typing
3
6
  from collections import Counter, defaultdict
7
+ from collections.abc import Callable, Mapping
4
8
  from dataclasses import dataclass, field
5
9
  from datetime import datetime, timedelta
6
10
  from pathlib import Path
11
+ from types import ModuleType
12
+ from typing import Any, ClassVar, Literal, TypeAlias, cast
13
+
14
+ if typing.TYPE_CHECKING:
15
+ from roms_tools.setup.grid import Grid
7
16
 
8
17
  import numpy as np
9
18
  import xarray as xr
@@ -17,15 +26,32 @@ from roms_tools.download import (
17
26
  )
18
27
  from roms_tools.setup.fill import LateralFill
19
28
  from roms_tools.setup.utils import (
29
+ Timed,
20
30
  assign_dates_to_climatology,
21
31
  convert_cftime_to_datetime,
22
32
  gc_dist,
33
+ get_target_coords,
23
34
  get_time_type,
24
35
  interpolate_cyclic_time,
25
36
  interpolate_from_climatology,
26
37
  one_dim_fill,
27
38
  )
28
- from roms_tools.utils import _has_gcsfs, _load_data
39
+ from roms_tools.utils import get_dask_chunks, get_pkg_error_msg, has_gcsfs, load_data
40
+
41
+ TConcatEndTypes = Literal["lower", "upper", "both"]
42
+ REPO_ROOT = Path(__file__).resolve().parents[2]
43
+ GLORYS_GLOBAL_GRID_PATH = (
44
+ REPO_ROOT / "roms_tools" / "data" / "grids" / "GLORYS_global_grid.nc"
45
+ )
46
+ DEFAULT_NR_BUFFER_POINTS = (
47
+ 20 # Default number of buffer points for subdomain selection.
48
+ )
49
+ # Balances performance and accuracy:
50
+ # - Too many points → more expensive computations
51
+ # - Too few points → potential boundary artifacts when lateral refill is performed
52
+ # See discussion: https://github.com/CWorthy-ocean/roms-tools/issues/153
53
+ # This default will be applied consistently across all datasets requiring lateral fill.
54
+ RawDataSource: TypeAlias = dict[str, str | Path | list[str | Path] | bool]
29
55
 
30
56
  # lat-lon datasets
31
57
 
@@ -43,7 +69,7 @@ class Dataset:
43
69
  Start time for selecting relevant data. If not provided, no time-based filtering is applied.
44
70
  end_time : Optional[datetime], optional
45
71
  End time for selecting relevant data. If not provided, the dataset selects the time entry
46
- closest to `start_time` within the range `[start_time, start_time + 24 hours]`.
72
+ closest to `start_time` within the range `[start_time, start_time + 24 hours)`.
47
73
  If `start_time` is also not provided, no time-based filtering is applied.
48
74
  dim_names: Dict[str, str], optional
49
75
  Dictionary specifying the names of dimensions in the dataset.
@@ -58,8 +84,19 @@ class Dataset:
58
84
  Indicates whether land values require lateral filling. If `True`, ocean values will be extended into land areas
59
85
  to replace NaNs or non-ocean values (such as atmospheric values in ERA5 data). If `False`, it is assumed that
60
86
  land values are already correctly assigned, and lateral filling will be skipped. Defaults to `True`.
61
- use_dask: bool
87
+ use_dask: bool, optional
62
88
  Indicates whether to use dask for chunking. If True, data is loaded with dask; if False, data is loaded eagerly. Defaults to False.
89
+ read_zarr: bool, optional
90
+ If True, read the dataset with the zarr engine rather than with `xr.open_mfdataset`.
91
+ Defaults to False.
92
+ allow_flex_time: bool, optional
93
+ Controls how strictly the dataset selects a time entry when `end_time` is not provided (relevant for initial conditions):
94
+
95
+ - If False (default): requires an exact match to `start_time`. Raises a ValueError if no match exists.
96
+ - If True: allows a +24h search window after `start_time` and selects the closest available
97
+ time entry within that window. Raises a ValueError if none are found.
98
+
99
+ Only used when `end_time` is None. Has no effect otherwise.
63
100
  apply_post_processing: bool
64
101
  Indicates whether to post-process the dataset for further use. Defaults to True.
65
102
 
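For readers scanning this hunk, a hedged usage sketch of the new `allow_flex_time` option when pulling a single snapshot for initial conditions. The field names come from this diff; the filename and the exact set of required constructor arguments are illustrative assumptions.

    from datetime import datetime
    from roms_tools.setup.datasets import GLORYSDataset

    # Strict (default): the source must contain a record exactly at start_time,
    # otherwise a ValueError is raised.
    strict = GLORYSDataset(
        filename="glorys_snapshot.nc",   # hypothetical local GLORYS file
        start_time=datetime(2012, 1, 1),
        allow_flex_time=False,
    )

    # Flexible: accept the record closest to start_time within
    # [start_time, start_time + 24 hours).
    flexible = GLORYSDataset(
        filename="glorys_snapshot.nc",
        start_time=datetime(2012, 1, 1),
        allow_flex_time=True,
    )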
@@ -90,23 +127,25 @@ class Dataset:
90
127
  }
91
128
  )
92
129
  var_names: dict[str, str]
93
- opt_var_names: dict[str, str] | None = field(default_factory=dict)
94
- climatology: bool | None = False
130
+ opt_var_names: dict[str, str] = field(default_factory=dict)
131
+ climatology: bool = False
95
132
  needs_lateral_fill: bool | None = True
96
- use_dask: bool | None = False
133
+ use_dask: bool = False
134
+ read_zarr: bool = False
135
+ allow_flex_time: bool = False
97
136
  apply_post_processing: bool | None = True
98
- read_zarr: bool | None = False
99
137
 
138
+ ds_loader_fn: Callable[[], xr.Dataset] | None = None
100
139
  is_global: bool = field(init=False, repr=False)
101
140
  ds: xr.Dataset = field(init=False, repr=False)
102
141
 
103
- def __post_init__(self):
104
- """
105
- Post-initialization processing:
142
+ def __post_init__(self) -> None:
143
+ """Perform post-initialization processing.
144
+
106
145
  1. Loads the dataset from the specified filename.
107
- 2. Applies time filtering based on start_time and end_time if provided.
108
- 3. Selects relevant fields as specified by var_names.
109
- 4. Ensures latitude values and depth values are in ascending order.
146
+ 2. Applies time filtering based on start_time and end_time (if provided).
147
+ 3. Selects relevant fields as specified by `var_names`.
148
+ 4. Ensures latitude, longitude, and depth values are in ascending order.
110
149
  5. Checks if the dataset covers the entire globe and adjusts if necessary.
111
150
  """
112
151
  # Validate start_time and end_time
@@ -167,13 +206,17 @@ class Dataset:
167
206
  ValueError
168
207
  If a list of files is provided but self.dim_names["time"] is not available or use_dask=False.
169
208
  """
170
- ds = _load_data(
171
- self.filename, self.dim_names, self.use_dask, read_zarr=self.read_zarr
209
+ ds = load_data(
210
+ filename=self.filename,
211
+ dim_names=self.dim_names,
212
+ use_dask=self.use_dask,
213
+ read_zarr=self.read_zarr,
214
+ ds_loader_fn=self.ds_loader_fn,
172
215
  )
173
216
 
174
217
  return ds
175
218
 
176
- def clean_up(self, ds: xr.Dataset, **kwargs) -> xr.Dataset:
219
+ def clean_up(self, ds: xr.Dataset) -> xr.Dataset:
177
220
  """Dummy method to be overridden by child classes to clean up the dataset.
178
221
 
179
222
  This method is intended as a placeholder and should be implemented in subclasses
@@ -206,7 +249,7 @@ class Dataset:
206
249
  """
207
250
  _check_dataset(ds, self.dim_names, self.var_names)
208
251
 
209
- def select_relevant_fields(self, ds) -> xr.Dataset:
252
+ def select_relevant_fields(self, ds: xr.Dataset) -> xr.Dataset:
210
253
  """Selects and returns a subset of the dataset containing only the variables
211
254
  specified in `self.var_names`.
212
255
 
@@ -249,7 +292,7 @@ class Dataset:
249
292
  """
250
293
  return ds
251
294
 
252
- def select_relevant_times(self, ds) -> xr.Dataset:
295
+ def select_relevant_times(self, ds: xr.Dataset) -> xr.Dataset:
253
296
  """Select a subset of the dataset based on the specified time range.
254
297
 
255
298
  This method filters the dataset to include all records between `start_time` and `end_time`.
@@ -257,7 +300,7 @@ class Dataset:
257
300
  after `end_time` are included, even if they fall outside the strict time range.
258
301
 
259
302
  If no `end_time` is specified, the method will select the time range of
260
- [start_time, start_time + 24 hours] and return the closest time entry to `start_time` within that range.
303
+ [start_time, start_time + 24 hours) and return the closest time entry to `start_time` within that range.
261
304
 
262
305
  Parameters
263
306
  ----------
@@ -296,8 +339,17 @@ class Dataset:
296
339
  """
297
340
  time_dim = self.dim_names["time"]
298
341
 
342
+ # Ensure start_time is not None for type safety
343
+ if self.start_time is None:
344
+ raise ValueError("select_relevant_times called but start_time is None.")
345
+
299
346
  ds = _select_relevant_times(
300
- ds, time_dim, self.start_time, self.end_time, self.climatology
347
+ ds,
348
+ time_dim,
349
+ self.start_time,
350
+ self.end_time,
351
+ self.climatology,
352
+ self.allow_flex_time,
301
353
  )
302
354
 
303
355
  return ds
@@ -344,7 +396,7 @@ class Dataset:
344
396
 
345
397
  return ds
346
398
 
347
- def infer_horizontal_resolution(self, ds: xr.Dataset):
399
+ def infer_horizontal_resolution(self, ds: xr.Dataset) -> None:
348
400
  """Estimate and set the average horizontal resolution of a dataset based on
349
401
  latitude and longitude spacing.
350
402
 
@@ -372,7 +424,7 @@ class Dataset:
372
424
  # Set the computed resolution as an attribute
373
425
  self.resolution = resolution
374
426
 
375
- def compute_minimal_grid_spacing(self, ds: xr.Dataset):
427
+ def compute_minimal_grid_spacing(self, ds: xr.Dataset) -> float:
376
428
  """Compute the minimal grid spacing in a dataset based on latitude and longitude
377
429
  spacing, considering Earth's radius.
378
430
 
@@ -434,7 +486,12 @@ class Dataset:
434
486
 
435
487
  return is_global
436
488
 
437
- def concatenate_longitudes(self, ds, end="upper", verbose=False):
489
+ def concatenate_longitudes(
490
+ self,
491
+ ds: xr.Dataset,
492
+ end: TConcatEndTypes = "upper",
493
+ verbose: bool = False,
494
+ ) -> xr.Dataset:
438
495
  """Concatenates fields in dataset twice along the longitude dimension.
439
496
 
440
497
  Parameters
@@ -457,58 +514,12 @@ class Dataset:
457
514
  ds_concatenated : xr.Dataset
458
515
  The concatenated dataset.
459
516
  """
460
- if verbose:
461
- start_time = time.time()
462
-
463
- ds_concatenated = xr.Dataset()
464
-
465
- lon = ds[self.dim_names["longitude"]]
466
- if end == "lower":
467
- lon_minus360 = lon - 360
468
- lon_concatenated = xr.concat(
469
- [lon_minus360, lon], dim=self.dim_names["longitude"]
470
- )
471
-
472
- elif end == "upper":
473
- lon_plus360 = lon + 360
474
- lon_concatenated = xr.concat(
475
- [lon, lon_plus360], dim=self.dim_names["longitude"]
476
- )
477
-
478
- elif end == "both":
479
- lon_minus360 = lon - 360
480
- lon_plus360 = lon + 360
481
- lon_concatenated = xr.concat(
482
- [lon_minus360, lon, lon_plus360], dim=self.dim_names["longitude"]
483
- )
484
-
485
- for var in ds.data_vars:
486
- if self.dim_names["longitude"] in ds[var].dims:
487
- field = ds[var]
488
-
489
- if end == "both":
490
- field_concatenated = xr.concat(
491
- [field, field, field], dim=self.dim_names["longitude"]
492
- )
493
- else:
494
- field_concatenated = xr.concat(
495
- [field, field], dim=self.dim_names["longitude"]
496
- )
497
-
498
- if self.use_dask:
499
- field_concatenated = field_concatenated.chunk(
500
- {self.dim_names["longitude"]: -1}
501
- )
502
- field_concatenated[self.dim_names["longitude"]] = lon_concatenated
503
- ds_concatenated[var] = field_concatenated
504
- else:
505
- ds_concatenated[var] = ds[var]
506
-
507
- ds_concatenated[self.dim_names["longitude"]] = lon_concatenated
508
-
509
- if verbose:
510
- logging.info(
511
- f"Concatenating the data along the longitude dimension: {time.time() - start_time:.3f} seconds"
517
+ with Timed(
518
+ "=== Concatenating the data along the longitude dimension ===",
519
+ verbose=verbose,
520
+ ):
521
+ ds_concatenated = _concatenate_longitudes(
522
+ ds, self.dim_names, end, self.use_dask
512
523
  )
513
524
 
514
525
  return ds_concatenated
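The hunk above swaps the manual time.time()/logging pattern for the new `Timed` context manager imported from `roms_tools.setup.utils`. A minimal sketch of that pattern, assuming only the message-plus-verbose signature visible in this diff:

    import time
    from roms_tools.setup.utils import Timed

    # Log the elapsed time of a block only when verbose=True, mirroring the
    # usage added in concatenate_longitudes above.
    with Timed("=== Doing something expensive ===", verbose=True):
        time.sleep(0.5)  # stand-in for a slow step such as regridding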
@@ -543,14 +554,16 @@ class Dataset:
543
554
  ds = self.ds.astype({var: "float64" for var in self.ds.data_vars})
544
555
  self.ds = ds
545
556
 
557
+ return None
558
+
546
559
  def choose_subdomain(
547
560
  self,
548
- target_coords,
549
- buffer_points=20,
550
- return_copy=False,
551
- return_coords_only=False,
552
- verbose=False,
553
- ):
561
+ target_coords: dict[str, Any],
562
+ buffer_points: int = DEFAULT_NR_BUFFER_POINTS,
563
+ return_copy: bool = False,
564
+ return_coords_only: bool = False,
565
+ verbose: bool = False,
566
+ ) -> xr.Dataset | Dataset | None:
554
567
  """Selects a subdomain from the xarray Dataset based on specified target
555
568
  coordinates, extending the selection by a defined buffer. Adjusts longitude
556
569
  ranges as necessary to accommodate the dataset's expected range and handles
@@ -587,94 +600,15 @@ class Dataset:
587
600
  ValueError
588
601
  If the selected latitude or longitude range does not intersect with the dataset.
589
602
  """
590
- lat_min = target_coords["lat"].min().values
591
- lat_max = target_coords["lat"].max().values
592
- lon_min = target_coords["lon"].min().values
593
- lon_max = target_coords["lon"].max().values
594
-
595
- margin = self.resolution * buffer_points
596
-
597
- # Select the subdomain in latitude direction (so that we have to concatenate fewer latitudes below if concatenation is necessary)
598
- subdomain = self.ds.sel(
599
- **{
600
- self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
601
- }
603
+ subdomain = choose_subdomain(
604
+ ds=self.ds,
605
+ dim_names=self.dim_names,
606
+ resolution=self.resolution,
607
+ is_global=self.is_global,
608
+ target_coords=target_coords,
609
+ buffer_points=buffer_points,
610
+ use_dask=self.use_dask,
602
611
  )
603
- lon = subdomain[self.dim_names["longitude"]]
604
-
605
- if self.is_global:
606
- concats = []
607
- # Concatenate only if necessary
608
- if lon_max + margin > lon.max():
609
- # See if shifting by +360 degrees helps
610
- if (lon_min - margin > (lon + 360).min()) and (
611
- lon_max + margin < (lon + 360).max()
612
- ):
613
- subdomain[self.dim_names["longitude"]] = lon + 360
614
- lon = subdomain[self.dim_names["longitude"]]
615
- else:
616
- concats.append("upper")
617
- if lon_min - margin < lon.min():
618
- # See if shifting by -360 degrees helps
619
- if (lon_min - margin > (lon - 360).min()) and (
620
- lon_max + margin < (lon - 360).max()
621
- ):
622
- subdomain[self.dim_names["longitude"]] = lon - 360
623
- lon = subdomain[self.dim_names["longitude"]]
624
- else:
625
- concats.append("lower")
626
-
627
- if concats:
628
- end = "both" if len(concats) == 2 else concats[0]
629
- subdomain = self.concatenate_longitudes(
630
- subdomain, end=end, verbose=False
631
- )
632
- lon = subdomain[self.dim_names["longitude"]]
633
-
634
- else:
635
- # Adjust longitude range if needed to match the expected range
636
- if not target_coords["straddle"]:
637
- if lon.min() < -180:
638
- if lon_max + margin > 0:
639
- lon_min -= 360
640
- lon_max -= 360
641
- elif lon.min() < 0:
642
- if lon_max + margin > 180:
643
- lon_min -= 360
644
- lon_max -= 360
645
-
646
- if target_coords["straddle"]:
647
- if lon.max() > 360:
648
- if lon_min - margin < 180:
649
- lon_min += 360
650
- lon_max += 360
651
- elif lon.max() > 180:
652
- if lon_min - margin < 0:
653
- lon_min += 360
654
- lon_max += 360
655
- # Select the subdomain in longitude direction
656
-
657
- subdomain = subdomain.sel(
658
- **{
659
- self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
660
- }
661
- )
662
-
663
- # Check if the selected subdomain has zero dimensions in latitude or longitude
664
- if subdomain[self.dim_names["latitude"]].size == 0:
665
- raise ValueError("Selected latitude range does not intersect with dataset.")
666
-
667
- if subdomain[self.dim_names["longitude"]].size == 0:
668
- raise ValueError(
669
- "Selected longitude range does not intersect with dataset."
670
- )
671
-
672
- # Adjust longitudes to expected range if needed
673
- lon = subdomain[self.dim_names["longitude"]]
674
- if target_coords["straddle"]:
675
- subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
676
- else:
677
- subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
678
612
 
679
613
  if return_coords_only:
680
614
  # Create and return a dataset with only latitudes and longitudes
@@ -687,6 +621,7 @@ class Dataset:
687
621
  return Dataset.from_ds(self, subdomain)
688
622
  else:
689
623
  self.ds = subdomain
624
+ return None
690
625
 
691
626
  def apply_lateral_fill(self):
692
627
  """Apply lateral fill to variables using the dataset's mask and grid dimensions.
@@ -706,10 +641,6 @@ class Dataset:
706
641
  point to the same variable in the dataset.
707
642
  """
708
643
  if self.needs_lateral_fill:
709
- logging.info(
710
- "Applying 2D horizontal fill to the source data before regridding."
711
- )
712
-
713
644
  lateral_fill = LateralFill(
714
645
  self.ds["mask"],
715
646
  [self.dim_names["latitude"], self.dim_names["longitude"]],
@@ -740,10 +671,6 @@ class Dataset:
740
671
  else:
741
672
  # Apply standard lateral fill for other variables
742
673
  self.ds[var_name] = lateral_fill.apply(self.ds[var_name])
743
- else:
744
- logging.info(
745
- "2D horizontal fill is skipped because source data already contains filled values."
746
- )
747
674
 
748
675
  def extrapolate_deepest_to_bottom(self):
749
676
  """Extrapolate deepest non-NaN values to fill bottom NaNs along the depth
@@ -760,7 +687,7 @@ class Dataset:
760
687
  )
761
688
 
762
689
  @classmethod
763
- def from_ds(cls, original_dataset: "Dataset", ds: xr.Dataset) -> "Dataset":
690
+ def from_ds(cls, original_dataset: Dataset, ds: xr.Dataset) -> Dataset:
764
691
  """Substitute the internal dataset of a Dataset object with a new xarray
765
692
  Dataset.
766
693
 
@@ -862,7 +789,7 @@ class TPXODataset(Dataset):
862
789
  ValueError
863
790
  If longitude or latitude values do not match the grid.
864
791
  """
865
- ds_grid = _load_data(self.grid_filename, self.dim_names, self.use_dask)
792
+ ds_grid = load_data(self.grid_filename, self.dim_names, self.use_dask)
866
793
 
867
794
  # Define mask and coordinate names based on location
868
795
  if self.location == "h":
@@ -893,21 +820,13 @@ class TPXODataset(Dataset):
893
820
 
894
821
  # Drop all dimensions except 'longitude' and 'latitude'
895
822
  dims_to_keep = {"longitude", "latitude"}
896
- dims_to_drop = [dim for dim in ds_grid.dims if dim not in dims_to_keep]
823
+ dims_to_drop: set[str] = set(ds_grid.dims) - dims_to_keep
897
824
  if dims_to_drop:
898
825
  ds_grid = ds_grid.isel({dim: 0 for dim in dims_to_drop})
899
826
 
900
827
  # Ensure correct dimension order
901
828
  ds_grid = ds_grid.transpose("latitude", "longitude")
902
829
 
903
- dims_to_keep = {"longitude", "latitude"}
904
- dims_to_drop = set(ds_grid.dims) - dims_to_keep
905
- ds_grid = (
906
- ds_grid.isel({dim: 0 for dim in dims_to_drop}) if dims_to_drop else ds_grid
907
- )
908
- # Bring dimensions in correct order
909
- ds_grid = ds_grid.transpose("latitude", "longitude")
910
-
911
830
  ds = ds.rename({"con": "nc"})
912
831
  ds = ds.assign_coords(
913
832
  {
@@ -1042,7 +961,7 @@ class GLORYSDataset(Dataset):
1042
961
  }
1043
962
  )
1044
963
 
1045
- climatology: bool | None = False
964
+ climatology: bool = False
1046
965
 
1047
966
  def post_process(self):
1048
967
  """Apply a mask to the dataset based on the 'zeta' variable, with 0 where 'zeta'
@@ -1058,23 +977,132 @@ class GLORYSDataset(Dataset):
1058
977
  None
1059
978
  The dataset is modified in-place by applying the mask to each variable.
1060
979
  """
1061
- mask = xr.where(
1062
- self.ds[self.var_names["zeta"]].isel({self.dim_names["time"]: 0}).isnull(),
1063
- 0,
1064
- 1,
1065
- )
1066
- mask_vel = xr.where(
1067
- self.ds[self.var_names["u"]]
1068
- .isel({self.dim_names["time"]: 0, self.dim_names["depth"]: 0})
1069
- .isnull(),
1070
- 0,
1071
- 1,
1072
- )
980
+ zeta = self.ds[self.var_names["zeta"]]
981
+ u = self.ds[self.var_names["u"]]
1073
982
 
983
+ # Select time=0 if time dimension exists, otherwise use data as-is
984
+ if self.dim_names["time"] in zeta.dims:
985
+ zeta_ref = zeta.isel({self.dim_names["time"]: 0})
986
+ else:
987
+ zeta_ref = zeta
988
+
989
+ if self.dim_names["time"] in u.dims:
990
+ u_ref = u.isel({self.dim_names["time"]: 0})
991
+ else:
992
+ u_ref = u
993
+
994
+ # Also handle depth for velocity
995
+ if self.dim_names["depth"] in u_ref.dims:
996
+ u_ref = u_ref.isel({self.dim_names["depth"]: 0})
997
+
998
+ # Create masks
999
+ mask = xr.where(zeta_ref.isnull(), 0, 1)
1000
+ mask_vel = xr.where(u_ref.isnull(), 0, 1)
1001
+
1002
+ # Save to dataset
1074
1003
  self.ds["mask"] = mask
1075
1004
  self.ds["mask_vel"] = mask_vel
1076
1005
 
1077
1006
 
1007
+ @dataclass(kw_only=True)
1008
+ class GLORYSDefaultDataset(GLORYSDataset):
1009
+ """A GLORYS dataset that is loaded from the Copernicus Marine Data Store."""
1010
+
1011
+ dataset_name: ClassVar[str] = "cmems_mod_glo_phy_my_0.083deg_P1D-m"
1012
+ """The GLORYS dataset-id for requests to the Copernicus Marine Toolkit"""
1013
+ _tk_module: ModuleType | None = None
1014
+ """The dynamically imported Copernicus Marine module."""
1015
+
1016
+ def __post_init__(self) -> None:
1017
+ """Configure attributes to ensure use of the correct upstream data-source."""
1018
+ self.read_zarr = True
1019
+ self.use_dask = True
1020
+ self.filename = self.dataset_name
1021
+ self.ds_loader_fn = self._load_from_copernicus
1022
+
1023
+ super().__post_init__()
1024
+
1025
+ def _check_auth(self, package_name: str) -> None:
1026
+ """Check the local credential hierarchy for auth credentials.
1027
+
1028
+ Raises
1029
+ ------
1030
+ RuntimeError
1031
+ If auth credentials cannot be found.
1032
+ """
1033
+ if self._tk_module and not self._tk_module.login(check_credentials_valid=True):
1034
+ msg = f"Authenticate with `{package_name} login` to retrieve GLORYS data."
1035
+ raise RuntimeError(msg)
1036
+
1037
+ def _load_copernicus(self) -> ModuleType:
1038
+ """Dynamically load the optional Copernicus Marine Toolkit dependency.
1039
+
1040
+ Raises
1041
+ ------
1042
+ RuntimeError
1043
+ - If the toolkit module is not available or cannot be imported.
1044
+ - If auth credentials cannot be found.
1045
+ """
1046
+ package_name = "copernicusmarine"
1047
+ if self._tk_module:
1048
+ self._check_auth(package_name)
1049
+ return self._tk_module
1050
+
1051
+ spec = importlib.util.find_spec(package_name)
1052
+ if not spec:
1053
+ msg = get_pkg_error_msg("cloud-based GLORYS data", package_name, "stream")
1054
+ raise RuntimeError(msg)
1055
+
1056
+ try:
1057
+ self._tk_module = importlib.import_module(package_name)
1058
+ except ImportError as e:
1059
+ msg = f"Package `{package_name}` was found but could not be loaded."
1060
+ raise RuntimeError(msg) from e
1061
+
1062
+ self._check_auth(package_name)
1063
+ return self._tk_module
1064
+
1065
+ def _load_from_copernicus(self) -> xr.Dataset:
1066
+ """Load a GLORYS dataset supporting streaming.
1067
+
1068
+ Returns
1069
+ -------
1070
+ xr.Dataset
1071
+ The streaming dataset
1072
+ """
1073
+ copernicusmarine = self._load_copernicus()
1074
+
1075
+ # ds = copernicusmarine.download_functions.download_zarr.open_dataset_from_arco_series(
1076
+ # dataset_url="https://s3.waw3-1.cloudferro.com/mdl-arco-geo-025/arco/GLOBAL_MULTIYEAR_PHY_001_030/cmems_mod_glo_phy_my_0.083deg_P1D-m_202311/geoChunked.zarr",
1077
+ # variables=["thetao", "so", "uo", "vo", "zos"],
1078
+ # geographical_parameters=copernicusmarine.download_functions.subset_parameters.GeographicalParameters(),
1079
+ # temporal_parameters=copernicusmarine.download_functions.subset_parameters.TemporalParameters(
1080
+ # start_datetime=self.start_time, end_datetime=self.end_time
1081
+ # ),
1082
+ # depth_parameters=copernicusmarine.download_functions.subset_parameters.DepthParameters(),
1083
+ # coordinates_selection_method="outside",
1084
+ # optimum_dask_chunking={
1085
+ # "time": 1,
1086
+ # "depth": -1,
1087
+ # "latitude": -1,
1088
+ # "longitude": -1,
1089
+ # },
1090
+ # )
1091
+
1092
+ ds = copernicusmarine.open_dataset(
1093
+ self.dataset_name,
1094
+ start_datetime=self.start_time,
1095
+ end_datetime=self.end_time,
1096
+ service="arco-geo-series",
1097
+ coordinates_selection_method="outside",
1098
+ chunk_size_limit=-1,
1099
+ )
1100
+ chunks = get_dask_chunks(self.dim_names)
1101
+ ds = ds.chunk(chunks)
1102
+
1103
+ return ds
1104
+
1105
+
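The new `GLORYSDefaultDataset` streams GLORYS directly from the Copernicus Marine Data Store rather than reading local files. A hedged usage sketch, assuming the optional `copernicusmarine` package is installed and `copernicusmarine login` has already been run; whether further constructor arguments (for example `filename`, which `__post_init__` overrides with the dataset id) must be supplied is not visible in this diff:

    from datetime import datetime
    from roms_tools.setup.datasets import GLORYSDefaultDataset

    # Opens cmems_mod_glo_phy_my_0.083deg_P1D-m via the ARCO zarr service;
    # read_zarr, use_dask and filename are forced in __post_init__ (see above).
    glorys = GLORYSDefaultDataset(
        start_time=datetime(2012, 1, 1),
        end_time=datetime(2012, 1, 5),
    )
    print(glorys.ds)  # chunked xarray.Dataset covering the requested window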
1078
1106
  @dataclass(kw_only=True)
1079
1107
  class UnifiedDataset(Dataset):
1080
1108
  """Represents unified BGC data on original grid.
@@ -1199,7 +1227,7 @@ class UnifiedBGCDataset(UnifiedDataset):
1199
1227
  }
1200
1228
  )
1201
1229
 
1202
- climatology: bool | None = True
1230
+ climatology: bool = True
1203
1231
 
1204
1232
 
1205
1233
  @dataclass(kw_only=True)
@@ -1221,7 +1249,7 @@ class UnifiedBGCSurfaceDataset(UnifiedDataset):
1221
1249
  }
1222
1250
  )
1223
1251
 
1224
- climatology: bool | None = True
1252
+ climatology: bool = True
1225
1253
 
1226
1254
 
1227
1255
  @dataclass(kw_only=True)
@@ -1336,9 +1364,9 @@ class CESMBGCDataset(CESMDataset):
1336
1364
  }
1337
1365
  )
1338
1366
 
1339
- climatology: bool | None = False
1367
+ climatology: bool = False
1340
1368
 
1341
- def post_process(self):
1369
+ def post_process(self) -> None:
1342
1370
  """
1343
1371
  Processes and converts CESM data values as follows:
1344
1372
  - Convert depth values from cm to m.
@@ -1407,9 +1435,9 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
1407
1435
  }
1408
1436
  )
1409
1437
 
1410
- climatology: bool | None = False
1438
+ climatology: bool = False
1411
1439
 
1412
- def post_process(self):
1440
+ def post_process(self) -> None:
1413
1441
  """Perform post-processing on the dataset to remove specific variables.
1414
1442
 
1415
1443
  This method checks if the variable "z_t" exists in the dataset. If it does,
@@ -1456,9 +1484,9 @@ class ERA5Dataset(Dataset):
1456
1484
  }
1457
1485
  )
1458
1486
 
1459
- climatology: bool | None = False
1487
+ climatology: bool = False
1460
1488
 
1461
- def post_process(self):
1489
+ def post_process(self) -> None:
1462
1490
  """
1463
1491
  Processes and converts ERA5 data values as follows:
1464
1492
  - Convert radiation values from J/m^2 to W/m^2.
@@ -1546,15 +1574,11 @@ class ERA5ARCODataset(ERA5Dataset):
1546
1574
  }
1547
1575
  )
1548
1576
 
1549
- def __post_init__(self):
1577
+ def __post_init__(self) -> None:
1550
1578
  self.read_zarr = True
1551
- if not _has_gcsfs():
1552
- raise RuntimeError(
1553
- "To use cloud-based ERA5 data, GCSFS is required but not installed. Install it with:\n"
1554
- " • `pip install roms-tools[stream]` or\n"
1555
- " • `conda install gcsfs`\n"
1556
- "Alternatively, install `roms-tools` with conda to include all dependencies."
1557
- )
1579
+ if not has_gcsfs():
1580
+ msg = get_pkg_error_msg("cloud-based ERA5 data", "gcsfs", "stream")
1581
+ raise RuntimeError(msg)
1558
1582
 
1559
1583
  super().__post_init__()
1560
1584
 
@@ -1582,9 +1606,9 @@ class ERA5Correction(Dataset):
1582
1606
  "time": "time",
1583
1607
  }
1584
1608
  )
1585
- climatology: bool | None = True
1609
+ climatology: bool = True
1586
1610
 
1587
- def __post_init__(self):
1611
+ def __post_init__(self) -> None:
1588
1612
  if not self.climatology:
1589
1613
  raise NotImplementedError(
1590
1614
  "Correction data must be a climatology. Set climatology to True."
@@ -1592,32 +1616,31 @@ class ERA5Correction(Dataset):
1592
1616
 
1593
1617
  super().__post_init__()
1594
1618
 
1595
- def choose_subdomain(self, target_coords, straddle: bool):
1596
- """Converts longitude values in the dataset if necessary and selects a subdomain
1597
- based on the specified coordinates.
1619
+ def match_subdomain(self, target_coords: dict[str, Any]) -> None:
1620
+ """
1621
+ Selects a subdomain from the dataset matching the specified coordinates.
1598
1622
 
1599
- This method converts longitude values between different ranges if required and then extracts a subset of the
1600
- dataset according to the given coordinates. It updates the dataset in place to reflect the selected subdomain.
1623
+ This method extracts a subset of the dataset (`self.ds`) based on given latitude
1624
+ and longitude values. If the dataset spans the globe, it concatenates longitudes
1625
+ to ensure seamless wrapping.
1601
1626
 
1602
1627
  Parameters
1603
1628
  ----------
1604
- target_coords : dict
1605
- A dictionary specifying the target coordinates for selecting the subdomain. Keys should correspond to the
1606
- dimension names of the dataset (e.g., latitude and longitude), and values should be the desired ranges or
1607
- specific coordinate values.
1608
- straddle : bool
1609
- If True, assumes that target longitudes are in the range [-180, 180]. If False, assumes longitudes are in the
1610
- range [0, 360]. This parameter determines how longitude values are converted if necessary.
1629
+ target_coords : dict[str, Any]
1630
+ A dictionary containing the target latitude and longitude values to select.
1631
+ Expected keys: "lat" and "lon", each mapped to a DataArray of coordinates.
1611
1632
 
1612
1633
  Raises
1613
1634
  ------
1614
1635
  ValueError
1615
- If the specified subdomain does not fully contain the specified latitude or longitude values. This can occur
1616
- if the dataset does not cover the full range of provided coordinates.
1636
+ If the selected subdomain does not contain all specified latitude or
1637
+ longitude values.
1617
1638
 
1618
1639
  Notes
1619
1640
  -----
1620
- - The dataset (`self.ds`) is updated in place to reflect the chosen subdomain.
1641
+ - The dataset (`self.ds`) is updated in place.
1642
+ - Assumes latitude values in `target_coords["lat"]` are within dataset bounds.
1643
+ - For global datasets, longitude concatenation is applied unconditionally.
1621
1644
  """
1622
1645
  # Select the subdomain in latitude direction (so that we have to concatenate fewer latitudes below if concatenation is performed)
1623
1646
  subdomain = self.ds.sel({self.dim_names["latitude"]: target_coords["lat"]})
@@ -1731,7 +1754,7 @@ class RiverDataset:
1731
1754
  dim_names: dict[str, str]
1732
1755
  var_names: dict[str, str]
1733
1756
  opt_var_names: dict[str, str] | None = field(default_factory=dict)
1734
- climatology: bool | None = False
1757
+ climatology: bool = False
1735
1758
  ds: xr.Dataset = field(init=False, repr=False)
1736
1759
 
1737
1760
  def __post_init__(self):
@@ -1764,7 +1787,7 @@ class RiverDataset:
1764
1787
  ds : xr.Dataset
1765
1788
  The loaded xarray Dataset containing the forcing data.
1766
1789
  """
1767
- ds = _load_data(
1790
+ ds = load_data(
1768
1791
  self.filename, self.dim_names, use_dask=False, decode_times=False
1769
1792
  )
1770
1793
 
@@ -1916,7 +1939,7 @@ class RiverDataset:
1916
1939
  The dataset with rivers sorted by their volume in descending order.
1917
1940
  If the volume variable is not available, the original dataset is returned.
1918
1941
  """
1919
- if "vol" in self.opt_var_names:
1942
+ if self.opt_var_names is not None and "vol" in self.opt_var_names:
1920
1943
  volume_values = ds[self.opt_var_names["vol"]].values
1921
1944
  if isinstance(volume_values, np.ndarray):
1922
1945
  # Check if all volume values are the same
@@ -2076,7 +2099,7 @@ class DaiRiverDataset(RiverDataset):
2076
2099
  "vol": "vol_stn",
2077
2100
  }
2078
2101
  )
2079
- climatology: bool | None = False
2102
+ climatology: bool = False
2080
2103
 
2081
2104
  def add_time_info(self, ds: xr.Dataset) -> xr.Dataset:
2082
2105
  """Adds time information to the dataset based on the climatology flag and
@@ -2655,139 +2678,212 @@ def _check_dataset(
2655
2678
 
2656
2679
 
2657
2680
  def _select_relevant_times(
2658
- ds, time_dim, start_time, end_time=None, climatology=False
2681
+ ds: xr.Dataset,
2682
+ time_dim: str,
2683
+ start_time: datetime,
2684
+ end_time: datetime | None = None,
2685
+ climatology: bool = False,
2686
+ allow_flex_time: bool = False,
2659
2687
  ) -> xr.Dataset:
2660
- """Select a subset of the dataset based on the specified time range.
2688
+ """
2689
+ Select a subset of the dataset based on time constraints.
2690
+
2691
+ This function supports two main use cases:
2692
+
2693
+ 1. **Time range selection (start_time + end_time provided):**
2694
+ - Returns all records strictly between `start_time` and `end_time`.
2695
+ - Ensures at least one record at or before `start_time` and one record at or
2696
+ after `end_time` are included, even if they fall outside the strict range.
2661
2697
 
2662
- This method filters the dataset to include all records between `start_time` and `end_time`.
2663
- Additionally, it ensures that one record at or before `start_time` and one record at or
2664
- after `end_time` are included, even if they fall outside the strict time range.
2698
+ 2. **Initial condition selection (start_time provided, end_time=None):**
2699
+ - Delegates to `_select_initial_time`, which reduces the dataset to exactly one
2700
+ time entry.
2701
+ - If `allow_flex_time=True`, a +24-hour buffer around `start_time` is allowed,
2702
+ and the closest timestamp is chosen.
2703
+ - If `allow_flex_time=False`, requires an exact timestamp match.
2665
2704
 
2666
- If no `end_time` is specified, the method will select the time range of
2667
- [start_time, start_time + 24 hours] and return the closest time entry to `start_time` within that range.
2705
+ Additional behavior:
2706
+ - If `climatology=True`, the dataset must contain exactly 12 time steps. If valid,
2707
+ the climatology dataset is returned without further filtering.
2708
+ - If the dataset uses `cftime` datetime objects, these are converted to
2709
+ `np.datetime64` before filtering.
2668
2710
 
2669
2711
  Parameters
2670
2712
  ----------
2671
2713
  ds : xr.Dataset
2672
- The input dataset to be filtered. Must contain a time dimension.
2673
- time_dim: str
2674
- Name of time dimension.
2714
+ The dataset to filter. Must contain a valid time dimension.
2715
+ time_dim : str
2716
+ Name of the time dimension in `ds`.
2675
2717
  start_time : datetime
2676
- The start time for selecting relevant data.
2677
- end_time : Optional[datetime], optional
2678
- The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided.
2679
- climatology : bool
2680
- Indicates whether the dataset is climatological. Defaults to False.
2718
+ Start time for filtering.
2719
+ end_time : datetime or None
2720
+ End time for filtering. If `None`, the function assumes an initial condition
2721
+ use case and selects exactly one timestamp.
2722
+ climatology : bool, optional
2723
+ If True, requires exactly 12 time steps and bypasses normal filtering.
2724
+ Defaults to False.
2725
+ allow_flex_time : bool, optional
2726
+ Whether to allow a +24h search window after `start_time` when `end_time`
2727
+ is None. If False (default), requires an exact match.
2681
2728
 
2682
2729
  Returns
2683
2730
  -------
2684
2731
  xr.Dataset
2685
- A dataset filtered to the specified time range, including the closest entries
2686
- at or before `start_time` and at or after `end_time` if applicable.
2732
+ A filtered dataset containing only the selected time entries.
2687
2733
 
2688
2734
  Raises
2689
2735
  ------
2690
2736
  ValueError
2691
- If no matching times are found between `start_time` and `start_time + 24 hours`.
2737
+ - If `climatology=True` but the dataset does not contain exactly 12 time steps.
2738
+ - If `climatology=False` and the dataset contains integer time values.
2739
+ - If no valid records are found within the requested range or window.
2692
2740
 
2693
2741
  Warns
2694
2742
  -----
2695
2743
  UserWarning
2696
- If the dataset contains exactly 12 time steps but the climatology flag is not set.
2697
- This may indicate that the dataset represents climatology data.
2698
-
2699
- UserWarning
2700
- If no records at or before `start_time` or no records at or after `end_time` are found.
2701
-
2702
- UserWarning
2703
- If the dataset does not contain any time dimension or the time dimension is incorrectly named.
2744
+ - If no records exist at or before `start_time` or at or after `end_time`.
2745
+ - If the specified time dimension does not exist in the dataset.
2704
2746
 
2705
2747
  Notes
2706
2748
  -----
2707
- - If the `climatology` flag is set and `end_time` is not provided, the method will
2708
- interpolate initial conditions from climatology data.
2709
- - If the dataset uses `cftime` datetime objects, these will be converted to standard
2710
- `np.datetime64` objects before filtering.
2749
+ - For initial conditions (end_time=None), see `_select_initial_time` for details
2750
+ on strict vs. flexible selection behavior.
2751
+ - Logs warnings instead of failing hard when boundary records are missing, and
2752
+ defaults to using the earliest or latest available time in such cases.
2711
2753
  """
2712
- if time_dim in ds.variables:
2713
- if climatology:
2714
- if len(ds[time_dim]) != 12:
2715
- raise ValueError(
2716
- f"The dataset contains {len(ds[time_dim])} time steps, but the climatology flag is set to True, which requires exactly 12 time steps."
2717
- )
2718
- if not end_time:
2719
- # Convert from timedelta64[ns] to fractional days
2720
- ds["time"] = ds["time"] / np.timedelta64(1, "D")
2721
- # Interpolate from climatology for initial conditions
2722
- ds = interpolate_from_climatology(ds, time_dim, start_time)
2723
- else:
2724
- time_type = get_time_type(ds[time_dim])
2725
- if time_type == "int":
2726
- raise ValueError(
2727
- "The dataset contains integer time values, which are only supported when the climatology flag is set to True. However, your climatology flag is set to False."
2728
- )
2729
- if time_type == "cftime":
2730
- ds = ds.assign_coords(
2731
- {time_dim: convert_cftime_to_datetime(ds[time_dim])}
2732
- )
2733
- if end_time:
2734
- end_time = end_time
2735
-
2736
- # Identify records before or at start_time
2737
- before_start = ds[time_dim] <= np.datetime64(start_time)
2738
- if before_start.any():
2739
- closest_before_start = (
2740
- ds[time_dim].where(before_start, drop=True).max()
2741
- )
2742
- else:
2743
- logging.warning("No records found at or before the start_time.")
2744
- closest_before_start = ds[time_dim].min()
2754
+ if time_dim not in ds.variables:
2755
+ logging.warning(
2756
+ f"Dataset does not contain time dimension '{time_dim}'. "
2757
+ "Please check variable naming or dataset structure."
2758
+ )
2759
+ return ds
2745
2760
 
2746
- # Identify records after or at end_time
2747
- after_end = ds[time_dim] >= np.datetime64(end_time)
2748
- if after_end.any():
2749
- closest_after_end = ds[time_dim].where(after_end, drop=True).min()
2750
- else:
2751
- logging.warning("No records found at or after the end_time.")
2752
- closest_after_end = ds[time_dim].max()
2761
+ time_type = get_time_type(ds[time_dim])
2753
2762
 
2754
- # Select records within the time range and add the closest before/after
2755
- within_range = (ds[time_dim] > np.datetime64(start_time)) & (
2756
- ds[time_dim] < np.datetime64(end_time)
2757
- )
2758
- selected_times = ds[time_dim].where(
2759
- within_range
2760
- | (ds[time_dim] == closest_before_start)
2761
- | (ds[time_dim] == closest_after_end),
2762
- drop=True,
2763
- )
2764
- ds = ds.sel({time_dim: selected_times})
2765
- else:
2766
- # Look in time range [start_time, start_time + 24h]
2767
- end_time = start_time + timedelta(days=1)
2768
- times = (np.datetime64(start_time) <= ds[time_dim]) & (
2769
- ds[time_dim] < np.datetime64(end_time)
2770
- )
2771
- if np.all(~times):
2772
- raise ValueError(
2773
- f"The dataset does not contain any time entries between the specified start_time: {start_time} "
2774
- f"and {start_time + timedelta(hours=24)}. "
2775
- "Please ensure the dataset includes time entries for that range."
2776
- )
2763
+ if climatology:
2764
+ if len(ds[time_dim]) != 12:
2765
+ raise ValueError(
2766
+ f"The dataset contains {len(ds[time_dim])} time steps, but the climatology flag is set to True, which requires exactly 12 time steps."
2767
+ )
2768
+ else:
2769
+ if time_type == "int":
2770
+ raise ValueError(
2771
+ "The dataset contains integer time values, which are only supported when the climatology flag is set to True. However, your climatology flag is set to False."
2772
+ )
2773
+ if time_type == "cftime":
2774
+ ds = ds.assign_coords({time_dim: convert_cftime_to_datetime(ds[time_dim])})
2777
2775
 
2778
- ds = ds.where(times, drop=True)
2779
- if ds.sizes[time_dim] > 1:
2780
- # Pick the time closest to start_time
2781
- ds = ds.isel({time_dim: 0})
2782
- logging.info(
2783
- f"Selected time entry closest to the specified start_time ({start_time}) within the range [{start_time}, {start_time + timedelta(hours=24)}]: {ds[time_dim].values}"
2784
- )
2776
+ if not end_time:
2777
+ # Assume we are looking for exactly one time record for initial conditions
2778
+ return _select_initial_time(
2779
+ ds, time_dim, start_time, climatology, allow_flex_time
2780
+ )
2781
+
2782
+ if climatology:
2783
+ return ds
2784
+
2785
+ # Identify records before or at start_time
2786
+ before_start = ds[time_dim] <= np.datetime64(start_time)
2787
+ if before_start.any():
2788
+ closest_before_start = ds[time_dim].where(before_start, drop=True)[-1]
2785
2789
  else:
2790
+ logging.warning(f"No records found at or before the start_time: {start_time}.")
2791
+ closest_before_start = ds[time_dim][0]
2792
+
2793
+ # Identify records after or at end_time
2794
+ after_end = ds[time_dim] >= np.datetime64(end_time)
2795
+ if after_end.any():
2796
+ closest_after_end = ds[time_dim].where(after_end, drop=True).min()
2797
+ else:
2798
+ logging.warning(f"No records found at or after the end_time: {end_time}.")
2799
+ closest_after_end = ds[time_dim].max()
2800
+
2801
+ # Select records within the time range and add the closest before/after
2802
+ within_range = (ds[time_dim] > np.datetime64(start_time)) & (
2803
+ ds[time_dim] < np.datetime64(end_time)
2804
+ )
2805
+ selected_times = ds[time_dim].where(
2806
+ within_range
2807
+ | (ds[time_dim] == closest_before_start)
2808
+ | (ds[time_dim] == closest_after_end),
2809
+ drop=True,
2810
+ )
2811
+ ds = ds.sel({time_dim: selected_times})
2812
+
2813
+ return ds
2814
+
2815
+
2816
+ def _select_initial_time(
2817
+ ds: xr.Dataset,
2818
+ time_dim: str,
2819
+ ini_time: datetime,
2820
+ climatology: bool,
2821
+ allow_flex_time: bool = False,
2822
+ ) -> xr.Dataset:
2823
+ """Select exactly one initial time from dataset.
2824
+
2825
+ Parameters
2826
+ ----------
2827
+ ds : xr.Dataset
2828
+ The input dataset with a time dimension.
2829
+ time_dim : str
2830
+ Name of the time dimension.
2831
+ ini_time : datetime
2832
+ The desired initial time.
2833
+ allow_flex_time : bool
2834
+ - If True: allow a +24h window and pick the closest available timestamp.
2835
+ - If False (default): require an exact match, otherwise raise ValueError.
2836
+
2837
+ Returns
2838
+ -------
2839
+ xr.Dataset
2840
+ Dataset reduced to exactly one timestamp.
2841
+
2842
+ Raises
2843
+ ------
2844
+ ValueError
2845
+ If no matching time is found (when `allow_flex_time=False`), or no entries are
2846
+ available within the +24h window (when `allow_flex_time=True`).
2847
+ """
2848
+ if climatology:
2849
+ # Convert from timedelta64[ns] to fractional days
2850
+ ds["time"] = ds["time"] / np.timedelta64(1, "D")
2851
+ # Interpolate from climatology for initial conditions
2852
+ return interpolate_from_climatology(ds, time_dim, ini_time)
2853
+
2854
+ if allow_flex_time:
2855
+ # Look in time range [ini_time, ini_time + 24h)
2856
+ end_time = ini_time + timedelta(days=1)
2857
+ times = (np.datetime64(ini_time) <= ds[time_dim]) & (
2858
+ ds[time_dim] < np.datetime64(end_time)
2859
+ )
2860
+
2861
+ if np.all(~times):
2862
+ raise ValueError(
2863
+ f"No time entries found between {ini_time} and {end_time}."
2864
+ )
2865
+
2866
+ ds = ds.where(times, drop=True)
2867
+ if ds.sizes[time_dim] > 1:
2868
+ # Pick the time closest to start_time
2869
+ ds = ds.isel({time_dim: 0})
2870
+
2786
2871
  logging.warning(
2787
- "Dataset does not contain any time information. Please check if the time dimension "
2788
- "is correctly named or if the dataset includes time data."
2872
+ f"Selected time entry closest to the specified start_time in +24 hour range: {ds[time_dim].values}"
2789
2873
  )
2790
2874
 
2875
+ else:
2876
+ # Strict match required
2877
+ if not (ds[time_dim].values == np.datetime64(ini_time)).any():
2878
+ raise ValueError(
2879
+ f"No exact match found for initial time {ini_time}. Consider setting allow_flex_time to True."
2880
+ )
2881
+
2882
+ ds = ds.sel({time_dim: np.datetime64(ini_time)})
2883
+
2884
+ if time_dim not in ds.dims:
2885
+ ds = ds.expand_dims(time_dim)
2886
+
2791
2887
  return ds
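A small self-contained sketch of the strict versus flexible behaviour described above, calling the private helper directly on a toy dataset (illustrative only; within roms-tools this path is reached through `Dataset.select_relevant_times`):

    from datetime import datetime

    import numpy as np
    import xarray as xr

    from roms_tools.setup.datasets import _select_initial_time

    times = np.array(["2012-01-01T06:00", "2012-01-02T06:00"], dtype="datetime64[ns]")
    ds = xr.Dataset({"temp": ("time", [4.0, 5.0])}, coords={"time": times})

    # Flexible: no record exists exactly at midnight, but 06:00 falls inside the
    # +24 h window, so that snapshot is selected.
    ds_flex = _select_initial_time(
        ds, "time", datetime(2012, 1, 1), climatology=False, allow_flex_time=True
    )

    # Strict (default): the same request raises a ValueError because there is no
    # record exactly at 2012-01-01 00:00.
    # _select_initial_time(ds, "time", datetime(2012, 1, 1), climatology=False)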
2792
2888
 
2793
2889
 
@@ -2916,7 +3012,7 @@ def _deduplicate_river_names(
2916
3012
 
2917
3013
  # Count all names
2918
3014
  name_counts = Counter(names)
2919
- seen = defaultdict(int)
3015
+ seen: defaultdict[str, int] = defaultdict(int)
2920
3016
 
2921
3017
  unique_names = []
2922
3018
  for name in names:
@@ -2935,3 +3031,275 @@ def _deduplicate_river_names(
2935
3031
  ds[name_var] = updated_array
2936
3032
 
2937
3033
  return ds
3034
+
3035
+
3036
+ def _concatenate_longitudes(
3037
+ ds: xr.Dataset,
3038
+ dim_names: Mapping[str, str],
3039
+ end: TConcatEndTypes,
3040
+ use_dask: bool = False,
3041
+ ) -> xr.Dataset:
3042
+ """
3043
+ Concatenate longitude dimension to handle global grids that cross
3044
+ the 0/360-degree or -180/180-degree boundary.
3045
+
3046
+ Extends the longitude dimension either lower, upper, or both sides
3047
+ by +/- 360 degrees and duplicates the corresponding variables along
3048
+ that dimension.
3049
+
3050
+ Parameters
3051
+ ----------
3052
+ ds : xr.Dataset
3053
+ Input xarray Dataset to be concatenated.
3054
+ dim_names : Mapping[str, str]
3055
+ Dictionary or mapping containing dimension names. Must include "longitude".
3056
+ end : str
3057
+ Specifies which side(s) to extend:
3058
+ - "lower": extend by subtracting 360 degrees.
3059
+ - "upper": extend by adding 360 degrees.
3060
+ - "both": extend on both sides.
3061
+ use_dask : bool, default False
3062
+ If True, chunk the concatenated longitude dimension using Dask.
3063
+
3064
+ Returns
3065
+ -------
3066
+ xr.Dataset
3067
+ Dataset with longitude dimension extended and data variables duplicated.
3068
+
3069
+ Notes
3070
+ -----
3071
+ Only data variables containing the longitude dimension are concatenated;
3072
+ others are left unchanged.
3073
+ """
3074
+ ds_concat = xr.Dataset()
3075
+
3076
+ lon_name = dim_names["longitude"]
3077
+ lon = ds[lon_name]
3078
+
3079
+ match end:
3080
+ case "lower":
3081
+ lon_concat = xr.concat([lon - 360, lon], dim=lon_name)
3082
+ n_copies = 2
3083
+ case "upper":
3084
+ lon_concat = xr.concat([lon, lon + 360], dim=lon_name)
3085
+ n_copies = 2
3086
+ case "both":
3087
+ lon_concat = xr.concat([lon - 360, lon, lon + 360], dim=lon_name)
3088
+ n_copies = 3
3089
+ case _:
3090
+ raise ValueError(f"Invalid `end` value: {end}")
3091
+
3092
+ for var in ds.variables:
3093
+ if lon_name in ds[var].dims:
3094
+ field = ds[var]
3095
+ field_concat = xr.concat([field] * n_copies, dim=lon_name)
3096
+
3097
+ if use_dask:
3098
+ field_concat = field_concat.chunk({lon_name: -1})
3099
+
3100
+ ds_concat[var] = field_concat
3101
+ else:
3102
+ ds_concat[var] = ds[var]
3103
+
3104
+ ds_concat = ds_concat.assign_coords({lon_name: lon_concat.values})
3105
+
3106
+ return ds_concat
3107
+
3108
+
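A toy illustration of the standalone helper defined above, calling the private function directly (for illustration only):

    import numpy as np
    import xarray as xr

    from roms_tools.setup.datasets import _concatenate_longitudes

    ds = xr.Dataset(
        {"sst": ("longitude", np.array([10.0, 11.0, 12.0, 13.0]))},
        coords={"longitude": [0.0, 90.0, 180.0, 270.0]},
    )
    # end="both" tiles every longitude-dependent variable three times and extends
    # the coordinate by -360 and +360 degrees, so a regional subdomain near the
    # dateline or Greenwich never runs off the end of a global dataset.
    extended = _concatenate_longitudes(ds, {"longitude": "longitude"}, end="both")
    print(extended.longitude.values)  # -360 ... 630 in steps of 90
    print(extended.sst.shape)         # (12,)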
3109
+ def choose_subdomain(
3110
+ ds: xr.Dataset,
3111
+ dim_names: Mapping[str, str],
3112
+ resolution: float,
3113
+ is_global: bool,
3114
+ target_coords: Mapping[str, Any],
3115
+ buffer_points: int = 20,
3116
+ use_dask: bool = False,
3117
+ ) -> xr.Dataset:
3118
+ """
3119
+ Select a subdomain from an xarray Dataset based on target coordinates,
3120
+ with optional buffer points and global longitude handling.
3121
+
3122
+ Parameters
3123
+ ----------
3124
+ ds : xr.Dataset
3125
+ The full xarray Dataset to subset.
3126
+ dim_names : Mapping[str, str]
3127
+ Dictionary mapping logical dimension names to dataset dimension names.
3128
+ Example: {"latitude": "latitude", "longitude": "longitude"}.
3129
+ resolution : float
3130
+ Spatial resolution of the dataset, used to compute buffer margin.
3131
+ is_global : bool
3132
+ Whether the dataset covers global longitude (affects concatenation logic).
3133
+ target_coords : Mapping[str, Any]
3134
+ Dictionary containing target latitude and longitude coordinates.
3135
+ Expected keys: "lat", "lon", and "straddle" (boolean for crossing 180°).
3136
+ buffer_points : int, default 20
3137
+ Number of grid points to extend beyond the target coordinates.
3138
+ use_dask: bool, optional
3139
+ Indicates whether to use dask for chunking. If True, data is loaded with dask; if False, data is processed eagerly. Defaults to False.
3140
+
3141
+ Returns
3142
+ -------
3143
+ xr.Dataset
3144
+ Subset of the input Dataset covering the requested coordinates plus buffer.
3145
+
3146
+ Raises
3147
+ ------
3148
+ ValueError
3149
+ If the selected latitude or longitude range does not intersect the dataset.
3150
+ """
3151
+ lat_min = target_coords["lat"].min().values
3152
+ lat_max = target_coords["lat"].max().values
3153
+ lon_min = target_coords["lon"].min().values
3154
+ lon_max = target_coords["lon"].max().values
3155
+
3156
+ margin = resolution * buffer_points
3157
+
3158
+ # Select the subdomain in latitude direction (so that we have to concatenate fewer latitudes below if concatenation is necessary)
3159
+ subdomain = ds.sel(
3160
+ **{
3161
+ dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
3162
+ }
3163
+ )
3164
+ lon = subdomain[dim_names["longitude"]]
3165
+
3166
+ if is_global:
3167
+ concats = []
3168
+ # Concatenate only if necessary
3169
+ if lon_max + margin > lon.max():
3170
+ # See if shifting by +360 degrees helps
3171
+ if (lon_min - margin > (lon + 360).min()) and (
3172
+ lon_max + margin < (lon + 360).max()
3173
+ ):
3174
+ subdomain[dim_names["longitude"]] = lon + 360
3175
+ lon = subdomain[dim_names["longitude"]]
3176
+ else:
3177
+ concats.append("upper")
3178
+ if lon_min - margin < lon.min():
3179
+ # See if shifting by -360 degrees helps
3180
+ if (lon_min - margin > (lon - 360).min()) and (
3181
+ lon_max + margin < (lon - 360).max()
3182
+ ):
3183
+ subdomain[dim_names["longitude"]] = lon - 360
3184
+ lon = subdomain[dim_names["longitude"]]
3185
+ else:
3186
+ concats.append("lower")
3187
+
3188
+ if concats:
3189
+ end = "both" if len(concats) == 2 else concats[0]
3190
+ end = cast(TConcatEndTypes, end)
3191
+ subdomain = _concatenate_longitudes(
3192
+ subdomain, dim_names=dim_names, end=end, use_dask=use_dask
3193
+ )
3194
+ lon = subdomain[dim_names["longitude"]]
3195
+
3196
+ else:
3197
+ # Adjust longitude range if needed to match the expected range
3198
+ if not target_coords["straddle"]:
3199
+ if lon.min() < -180:
3200
+ if lon_max + margin > 0:
3201
+ lon_min -= 360
3202
+ lon_max -= 360
3203
+ elif lon.min() < 0:
3204
+ if lon_max + margin > 180:
3205
+ lon_min -= 360
3206
+ lon_max -= 360
3207
+
3208
+ if target_coords["straddle"]:
3209
+ if lon.max() > 360:
3210
+ if lon_min - margin < 180:
3211
+ lon_min += 360
3212
+ lon_max += 360
3213
+ elif lon.max() > 180:
3214
+ if lon_min - margin < 0:
3215
+ lon_min += 360
3216
+ lon_max += 360
3217
+ # Select the subdomain in longitude direction
3218
+ subdomain = subdomain.sel(
3219
+ **{
3220
+ dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
3221
+ }
3222
+ )
3223
+ # Check if the selected subdomain has zero dimensions in latitude or longitude
3224
+ if (
3225
+ dim_names["latitude"] not in subdomain
3226
+ or subdomain[dim_names["latitude"]].size == 0
3227
+ ):
3228
+ raise ValueError("Selected latitude range does not intersect with dataset.")
3229
+
3230
+ if (
3231
+ dim_names["longitude"] not in subdomain
3232
+ or subdomain[dim_names["longitude"]].size == 0
3233
+ ):
3234
+ raise ValueError("Selected longitude range does not intersect with dataset.")
3235
+
3236
+ # Adjust longitudes to expected range if needed
3237
+ lon = subdomain[dim_names["longitude"]]
3238
+ if target_coords["straddle"]:
3239
+ subdomain[dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
3240
+ else:
3241
+ subdomain[dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
3242
+
3243
+ return subdomain
3244
+
3245
+
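The `target_coords` mapping consumed by `choose_subdomain` above is normally produced by `get_target_coords(grid)`; a hand-built sketch of the expected structure, using placeholder data (the `straddle` flag is described in the docstring above):

    import numpy as np
    import xarray as xr

    from roms_tools.setup.datasets import choose_subdomain

    # Placeholder source data on a coarse, regular 1-degree global grid.
    src = xr.Dataset(
        {"temp": (("latitude", "longitude"), np.zeros((181, 360)))},
        coords={"latitude": np.arange(-90, 91), "longitude": np.arange(0, 360)},
    )
    target_coords = {
        "lat": xr.DataArray(np.linspace(58.0, 62.0, 20)),
        "lon": xr.DataArray(np.linspace(335.0, 342.0, 20)),
        "straddle": False,
    }
    subset = choose_subdomain(
        ds=src,
        dim_names={"latitude": "latitude", "longitude": "longitude"},
        resolution=1.0,
        is_global=True,
        target_coords=target_coords,
        buffer_points=2,
    )
    # subset now spans roughly 56-64 °N and 333-344 °E, i.e. the target box
    # plus a 2-grid-point margin on each side.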
3246
+ def get_glorys_bounds(
3247
+ grid: Grid,
3248
+ glorys_grid_path: Path | str | None = None,
3249
+ ) -> dict[str, float]:
3250
+ """
3251
+ Compute the latitude/longitude bounds of a GLORYS spatial subset
3252
+ that fully covers the given ROMS grid (with margin for regridding).
3253
+
3254
+ Parameters
3255
+ ----------
3256
+ grid : Grid
3257
+ The grid object.
3258
+ glorys_grid_path : Path or str, optional
3259
+ Path to the GLORYS global grid file. If None, defaults to
3260
+ "<repo_root>/data/grids/GLORYS_global_grid.nc".
3261
+
3262
+ Returns
3263
+ -------
3264
+ dict[str, float]
3265
+ Dictionary containing the bounding box values:
3266
+
3267
+ - `"minimum_latitude"` : float
3268
+ - `"maximum_latitude"` : float
3269
+ - `"minimum_longitude"` : float
3270
+ - `"maximum_longitude"` : float
3271
+
3272
+ Notes
3273
+ -----
3274
+ - The resolution is estimated as the mean of latitude and longitude spacing.
3275
+ """
3276
+ if glorys_grid_path is None:
3277
+ glorys_grid_path = GLORYS_GLOBAL_GRID_PATH
3278
+
3279
+ ds = xr.open_dataset(glorys_grid_path)
3280
+
3281
+ # Estimate grid resolution (mean spacing in degrees)
3282
+ res_lat = ds.latitude.diff("latitude").mean()
3283
+ res_lon = ds.longitude.diff("longitude").mean()
3284
+ resolution = (res_lat + res_lon) / 2
3285
+
3286
+ # Extract target grid coordinates
3287
+ target_coords = get_target_coords(grid)
3288
+
3289
+ # Select subdomain with margin
3290
+ ds_subset = choose_subdomain(
3291
+ ds=ds,
3292
+ dim_names={"latitude": "latitude", "longitude": "longitude"},
3293
+ resolution=resolution,
3294
+ is_global=True,
3295
+ target_coords=target_coords,
3296
+ buffer_points=DEFAULT_NR_BUFFER_POINTS + 1,
3297
+ )
3298
+
3299
+ # Compute bounds
3300
+ return {
3301
+ "minimum_latitude": float(ds_subset.latitude.min()),
3302
+ "maximum_latitude": float(ds_subset.latitude.max()),
3303
+ "minimum_longitude": float(ds_subset.longitude.min()),
3304
+ "maximum_longitude": float(ds_subset.longitude.max()),
3305
+ }
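Finally, a hedged end-to-end sketch of the new `get_glorys_bounds` helper: the Grid arguments follow the documented roms-tools Grid API but are placeholders, and the returned dictionary uses the keys listed in the docstring above (shaped for a Copernicus Marine subset request).

    from roms_tools import Grid
    from roms_tools.setup.datasets import get_glorys_bounds

    # A small illustrative domain south of Iceland (values are placeholders).
    grid = Grid(
        nx=100, ny=100, size_x=500, size_y=500,
        center_lon=-21, center_lat=61, rot=0,
    )
    bounds = get_glorys_bounds(grid)
    # e.g. forward these bounds to a copernicusmarine subset/open_dataset call
    print(bounds["minimum_latitude"], bounds["maximum_latitude"])
    print(bounds["minimum_longitude"], bounds["maximum_longitude"])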