roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
roms_tools/__init__.py CHANGED
@@ -10,6 +10,7 @@ except ImportError: # pragma: no cover
10
10
  # grid must be imported first
11
11
  from roms_tools.setup.grid import Grid # noqa: I001, F401
12
12
  from roms_tools.analysis.roms_output import ROMSOutput # noqa: F401
13
+ from roms_tools.analysis.cdr_ensemble import Ensemble # noqa: F401
13
14
  from roms_tools.setup.boundary_forcing import BoundaryForcing # noqa: F401
14
15
  from roms_tools.setup.cdr_forcing import CDRForcing # noqa: F401
15
16
  from roms_tools.setup.cdr_release import TracerPerturbation, VolumeRelease # noqa: F401
@@ -19,6 +20,12 @@ from roms_tools.setup.river_forcing import RiverForcing # noqa: F401
19
20
  from roms_tools.setup.surface_forcing import SurfaceForcing # noqa: F401
20
21
  from roms_tools.setup.tides import TidalForcing # noqa: F401
21
22
  from roms_tools.tiling.partition import partition_netcdf # noqa: F401
23
+ from roms_tools.setup.datasets import get_glorys_bounds # noqa: F401
24
+ from roms_tools.tiling.join import open_partitions, join_netcdf # noqa: F401
25
+
22
26
 
23
27
  # Configure logging when the package is imported
24
# Log record layout applied package-wide on import, e.g.
# "2024-01-31 12:00:00 - INFO - message".
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
    level=logging.INFO,
)
@@ -0,0 +1,203 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import xarray as xr
5
+
6
+
7
def compute_cdr_metrics(ds: xr.Dataset, grid_ds: xr.Dataset) -> xr.Dataset:
    """
    Compute Carbon Dioxide Removal (CDR) metrics from model output.

    Calculates CDR uptake efficiency using two methods:
    1. Flux-based: from area-integrated CO2 flux differences.
    2. DIC difference-based: from volume-integrated DIC differences.

    Copies selected tracer and flux variables and computes grid cell areas
    and averaging window durations.

    Parameters
    ----------
    ds : xr.Dataset
        Model output with required variables:
        'avg_begin_time', 'avg_end_time', 'ALK_source', 'DIC_source',
        'FG_CO2', 'FG_ALT_CO2', 'hDIC', 'hDIC_ALT_CO2'.

    grid_ds : xr.Dataset
        Grid dataset with 'pm', 'pn' (inverse grid spacing).

    Returns
    -------
    ds_cdr : xr.Dataset
        Dataset containing:
        - 'area', 'window_length'
        - copied flux/tracer variables
        - 'cdr_efficiency' and 'cdr_efficiency_from_delta_diff' (dimensionless),
          set to NaN wherever the ratio to the cumulative source is not finite
          (e.g. while the cumulative source is still zero).

    Raises
    ------
    KeyError
        If required variables are missing from `ds` or `grid_ds`.
    ValueError
        If 'ALK_source' contains negative values or 'DIC_source' contains
        positive values.
    """
    # Every variable read from `ds` below. DIC_source enters the cumulative
    # source term, so it is checked here together with the others.
    ds_vars = [
        "avg_begin_time",
        "avg_end_time",
        "ALK_source",
        "DIC_source",
        "FG_CO2",
        "FG_ALT_CO2",
        "hDIC",
        "hDIC_ALT_CO2",
    ]
    grid_vars = ["pm", "pn"]

    missing_ds = [var for var in ds_vars if var not in ds]
    missing_grid = [var for var in grid_vars if var not in grid_ds]

    if missing_ds:
        raise KeyError(f"Missing required variables in ds: {missing_ds}")
    if missing_grid:
        raise KeyError(f"Missing required variables in grid_ds: {missing_grid}")

    ds_cdr = xr.Dataset()

    # Copy relevant variables
    for var_name in ("FG_CO2", "FG_ALT_CO2", "hDIC", "hDIC_ALT_CO2"):
        ds_cdr[var_name] = ds[var_name]

    # pm and pn are inverse grid spacings, so 1 / (pm * pn) is the cell area.
    ds_cdr["area"] = 1 / (grid_ds["pm"] * grid_ds["pn"])
    ds_cdr["area"].attrs.update(
        long_name="Grid cell area",
        units="m^2",
    )

    # Duration of each averaging window
    ds_cdr["window_length"] = ds["avg_end_time"] - ds["avg_begin_time"]
    ds_cdr["window_length"].attrs.update(
        long_name="Duration of each averaging window",
        units="s",
    )

    # Enforce sign constraints on the release sources (ALK >= 0, DIC <= 0).
    _validate_source(ds)

    # Cumulative alkalinity source
    source = (
        (
            (ds["ALK_source"] - ds["DIC_source"]).sum(
                dim=["s_rho", "eta_rho", "xi_rho"]
            )
            * ds_cdr["window_length"]
        )
        .cumsum(dim="time")
        .compute()
    )

    # Cumulative flux-based uptake (Method 1)
    flux = (
        (
            ((ds["FG_CO2"] - ds["FG_ALT_CO2"]) * ds_cdr["area"]).sum(
                dim=["eta_rho", "xi_rho"]
            )
            * ds_cdr["window_length"]
        )
        .cumsum(dim="time")
        .compute()
    )

    # DIC difference-based uptake (Method 2)
    diff_DIC = (
        ((ds["hDIC"] - ds["hDIC_ALT_CO2"]) * ds_cdr["area"])
        .sum(dim=["s_rho", "eta_rho", "xi_rho"])
        .compute()
    )

    # Normalize by cumulative source (NaN where source is zero)
    uptake_efficiency_flux = _safe_ratio(flux, source)
    uptake_efficiency_diff = _safe_ratio(diff_DIC, source)

    _validate_uptake_efficiency(uptake_efficiency_flux, uptake_efficiency_diff)

    # Store results with metadata
    ds_cdr["cdr_efficiency"] = uptake_efficiency_flux
    ds_cdr["cdr_efficiency"].attrs.update(
        long_name="CDR uptake efficiency (from flux differences)",
        units="nondimensional",
        description="Carbon Dioxide Removal efficiency computed using area-integrated CO2 flux differences",
    )
    ds_cdr["cdr_efficiency_from_delta_diff"] = uptake_efficiency_diff
    ds_cdr["cdr_efficiency_from_delta_diff"].attrs.update(
        long_name="CDR uptake efficiency (from DIC differences)",
        units="nondimensional",
        description="Carbon Dioxide Removal efficiency computed using volume-integrated DIC differences",
    )

    return ds_cdr


def _safe_ratio(numerator: xr.DataArray, denominator: xr.DataArray) -> xr.DataArray:
    """Return ``numerator / denominator`` with non-finite results (inf, NaN) set to NaN."""
    # Division by a zero cumulative source is expected before the release
    # starts; silence the corresponding numpy warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = numerator / denominator
    return ratio.where(np.isfinite(ratio))
140
+
141
+
142
def _validate_uptake_efficiency(
    uptake_efficiency_flux: xr.DataArray,
    uptake_efficiency_diff: xr.DataArray,
) -> float:
    """
    Report how far the two uptake-efficiency estimates disagree.

    The largest absolute pointwise difference between the flux-based and the
    DIC-based estimate is logged at INFO level and returned.

    Parameters
    ----------
    uptake_efficiency_flux : xr.DataArray
        Efficiency derived from cumulative air-sea CO2 flux differences.
    uptake_efficiency_diff : xr.DataArray
        Efficiency derived from DIC inventory differences.

    Returns
    -------
    float
        Largest absolute difference between the two estimates.
    """
    discrepancy = float(np.abs(uptake_efficiency_flux - uptake_efficiency_diff).max())
    logging.info(
        "Max absolute difference between flux-based and DIC-based uptake efficiency: %.3e",
        discrepancy,
    )
    return discrepancy
170
+
171
+
172
def _validate_source(ds: xr.Dataset):
    """
    Check the sign constraints of the CDR release sources in a ROMS dataset.

    'ALK_source' may only add alkalinity (every value >= 0) and 'DIC_source'
    may only remove DIC (every value <= 0). Variables are checked in that order.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset that must hold both 'ALK_source' and 'DIC_source'.

    Raises
    ------
    KeyError
        If either variable is absent from ``ds.data_vars``.
    ValueError
        If either variable has a value of the forbidden sign.
    """
    # (variable, predicate that must hold everywhere, word describing violators)
    rules = (
        ("ALK_source", lambda values: values >= 0, "negative"),
        ("DIC_source", lambda values: values <= 0, "positive"),
    )

    for name, is_valid, offending_sign in rules:
        if name not in ds.data_vars:
            raise KeyError(f"Dataset is missing required variable '{name}'.")
        if not is_valid(ds[name]).all():
            raise ValueError(
                f"'{name}' contains {offending_sign} values, which violates release constraints."
            )
@@ -0,0 +1,198 @@
1
+ from dataclasses import dataclass, field
2
+ from pathlib import Path
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import xarray as xr
7
+
8
+
9
@dataclass
class Ensemble:
    """
    Represents an ensemble of CDR (Carbon Dioxide Removal) experiments.

    Loads, aligns, and analyzes efficiency metrics across multiple members.

    Parameters
    ----------
    members : dict[str, str | Path | xr.Dataset]
        Dictionary mapping member names to either file paths (NetCDF) or
        xarray.Dataset objects containing the CDR metrics.
    """

    members: dict[str, str | Path | xr.Dataset]
    """Dictionary mapping member names to CDR metrics."""
    ds: xr.Dataset = field(init=False)
    """xarray Dataset containing aligned efficiencies for all ensemble members."""

    def __post_init__(self):
        """
        Load the member datasets, extract their efficiencies, align them on a
        common relative time axis, and store them together with the ensemble
        statistics in ``self.ds``.

        No plot is produced here; call :meth:`plot` for that.
        """
        datasets = self._load_members()
        effs = {name: self._extract_efficiency(ds) for name, ds in datasets.items()}
        aligned = self._align_times(effs)
        self.ds = self._compute_statistics(aligned)

    def _load_members(self) -> dict[str, xr.Dataset]:
        """
        Loads ensemble member datasets.

        Converts any file paths (``str`` or ``Path``) in `self.members` to
        xarray Datasets via ``xr.open_dataset``. Members that are already
        xarray Datasets are left unchanged.

        Returns
        -------
        dict[str, xr.Dataset]
            Dictionary mapping member names to xarray.Dataset objects.
        """
        return {
            name: xr.open_dataset(path) if isinstance(path, str | Path) else path
            for name, path in self.members.items()
        }

    def _extract_efficiency(self, ds: xr.Dataset) -> xr.DataArray:
        """
        Extract the CDR efficiency and reindex it to time since release start.

        The release start is taken as the first time with a non-NaN efficiency.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset containing a "cdr_efficiency" variable and an "abs_time"
            coordinate or data variable.

        Returns
        -------
        xr.DataArray
            Efficiency with its ``time`` coordinate replaced by the timedelta
            since release start (full precision, not truncated to days).

        Raises
        ------
        ValueError
            If 'abs_time' is missing or there are no valid efficiency values.
        """
        eff = ds["cdr_efficiency"]

        # abs_time may be attached to the efficiency as a coordinate or be
        # stored as a separate data variable.
        if "abs_time" in eff.coords:
            abs_time = eff.coords["abs_time"]
        elif "abs_time" in ds.data_vars:
            abs_time = ds["abs_time"]
        else:
            raise ValueError(
                "Dataset must contain an 'abs_time' coordinate or data variable."
            )

        # The first non-NaN efficiency marks the release start (efficiencies
        # are NaN, e.g., while no source has been released yet).
        valid_mask = ~np.isnan(eff.values)
        if not valid_mask.any():
            raise ValueError("No valid efficiency values found in dataset.")

        release_start = abs_time.values[valid_mask][0]

        # Keep full precision: truncating to whole days would map sub-daily
        # output onto duplicate time labels and make the reindexing in
        # `_align_times` fail.
        time_rel = (abs_time - release_start).astype("timedelta64[ns]")

        eff_rel = eff.assign_coords(time=time_rel)
        eff_rel.time.attrs["long_name"] = "time since release start"

        # abs_time is member-specific; drop it if it is attached as a
        # coordinate so members can be combined on the relative time axis.
        if "abs_time" in eff_rel.coords:
            eff_rel = eff_rel.drop_vars("abs_time")

        return eff_rel

    def _align_times(self, effs: dict[str, xr.DataArray]) -> xr.Dataset:
        """
        Align all ensemble members to a common time axis.

        Each member is reindexed to the union of all time coordinates.
        Times outside the original range of a member are filled with NaN.

        Parameters
        ----------
        effs : dict[str, xr.DataArray]
            Dictionary mapping member names to efficiency DataArrays
            (reindexed relative to their release start).

        Returns
        -------
        xr.Dataset
            Dataset containing all members aligned to a shared time coordinate, with
            NaNs for missing times.
        """
        all_times = np.unique(
            np.concatenate([eff.time.values for eff in effs.values()])
        )
        aligned = {name: eff.reindex(time=all_times) for name, eff in effs.items()}
        return xr.Dataset(aligned)

    def _compute_statistics(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Computes ensemble statistics: mean and standard deviation.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset containing aligned ensemble member efficiencies. It is
            modified in place and also returned.

        Returns
        -------
        xr.Dataset
            Dataset with additional variables "ensemble_mean" and
            "ensemble_std" representing the ensemble mean and standard
            deviation across members.
        """
        da = ds.to_dataarray("member")  # stack into (member, time)
        ds["ensemble_mean"] = da.mean(dim="member")
        ds["ensemble_std"] = da.std(dim="member")
        return ds

    def plot(
        self,
        save_path: str | Path | None = None,
    ) -> None:
        """
        Plots ensemble members with mean ± standard deviation shading.

        Draws each member's efficiency time series, the ensemble mean, and the
        ±1 standard deviation band as a shaded region.

        Parameters
        ----------
        save_path : str or Path, optional
            Path to save the generated plot (300 dpi). If None, the figure is
            not written to disk. Default is None.

        Returns
        -------
        None
        """
        fig, ax = plt.subplots(figsize=(8, 5))

        # Convert the timedelta axis to float days for plotting.
        time = self.ds.time.values / np.timedelta64(1, "D")

        # Individual ensemble members
        for name in self.members:
            ax.plot(time, self.ds[name], lw=2, label=name)

        # Mean ± std
        mean = self.ds.ensemble_mean
        std = self.ds.ensemble_std
        ax.plot(time, mean, color="black", lw=2, label="ensemble mean")
        ax.fill_between(time, mean - std, mean + std, color="gray", alpha=0.3)

        ax.set_xlabel("Time since release start [days]")
        ax.set_ylabel("CDR Efficiency")
        ax.set_title("Ensemble of interventions")
        ax.legend()
        ax.grid()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
@@ -10,14 +10,15 @@ import xarray as xr
10
10
  from matplotlib.axes import Axes
11
11
 
12
12
  from roms_tools import Grid
13
- from roms_tools.plot import plot
13
+ from roms_tools.analysis.cdr_analysis import compute_cdr_metrics
14
+ from roms_tools.plot import plot, plot_uptake_efficiency
14
15
  from roms_tools.regrid import LateralRegridFromROMS, VerticalRegridFromROMS
15
16
  from roms_tools.utils import (
16
- _generate_coordinate_range,
17
- _load_data,
17
+ generate_coordinate_range,
18
18
  infer_nominal_horizontal_resolution,
19
19
  interpolate_from_rho_to_u,
20
20
  interpolate_from_rho_to_v,
21
+ load_data,
21
22
  )
22
23
  from roms_tools.vertical_coordinate import (
23
24
  compute_depth_coordinates,
@@ -71,6 +72,27 @@ class ROMSOutput:
71
72
  # Dataset for depth coordinates
72
73
  self.ds_depth_coords = xr.Dataset()
73
74
 
75
def cdr_metrics(self) -> None:
    """
    Compute (once) and plot Carbon Dioxide Removal (CDR) metrics.

    On the first call, the metrics are computed from the model output
    (`self.ds`) and the grid (`self.grid.ds`) and cached as `self.ds_cdr`.
    Later calls reuse the cached dataset. In both cases the uptake efficiency
    is then plotted.

    Notes
    -----
    Metrics include:
    - Grid cell area
    - Selected tracer and flux variables
    - Uptake efficiency computed from flux differences and DIC differences
    """
    try:
        metrics = self.ds_cdr
    except AttributeError:
        # Not computed yet: compute once and cache on the instance.
        metrics = compute_cdr_metrics(self.ds, self.grid.ds)
        self.ds_cdr = metrics

    plot_uptake_efficiency(metrics)
95
+
74
96
  def plot(
75
97
  self,
76
98
  var_name: str,
@@ -170,7 +192,7 @@ class ROMSOutput:
170
192
  """
171
193
  # Check if variable exists
172
194
  if var_name not in self.ds:
173
- raise ValueError(f"Variable '{var_name}' is not found in the dataset.")
195
+ raise ValueError(f"Variable '{var_name}' is not found in self.ds.")
174
196
 
175
197
  # Pick the variable
176
198
  field = self.ds[var_name]
@@ -269,11 +291,11 @@ class ROMSOutput:
269
291
 
270
292
  if horizontal_resolution is None:
271
293
  horizontal_resolution = infer_nominal_horizontal_resolution(self.grid.ds)
272
- lons = _generate_coordinate_range(
294
+ lons = generate_coordinate_range(
273
295
  lon_deg.min().values, lon_deg.max().values, horizontal_resolution
274
296
  )
275
297
  lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
276
- lats = _generate_coordinate_range(
298
+ lats = generate_coordinate_range(
277
299
  lat_deg.min().values, lat_deg.max().values, horizontal_resolution
278
300
  )
279
301
  lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
@@ -333,21 +355,22 @@ class ROMSOutput:
333
355
  layer_depth_loc = lateral_regrid.apply(layer_depth_loc)
334
356
  h_loc = lateral_regrid.apply(h_loc)
335
357
  # Vertical regridding
336
- vertical_regrid = VerticalRegridFromROMS(ds_loc)
337
- for var_name in var_names_loc:
338
- if "s_rho" in ds_loc[var_name].dims:
339
- attrs = ds_loc[var_name].attrs
340
- regridded = vertical_regrid.apply(
341
- ds_loc[var_name],
342
- layer_depth_loc,
343
- depth_levels,
344
- mask_edges=False,
345
- )
346
- regridded = regridded.where(regridded.depth < h_loc)
347
- ds_loc[var_name] = regridded
348
- ds_loc[var_name].attrs = attrs
349
-
350
- ds_loc = ds_loc.assign_coords({"depth": depth_levels})
358
+ if "s_rho" in ds_loc.dims:
359
+ vertical_regrid = VerticalRegridFromROMS(ds_loc)
360
+ for var_name in var_names_loc:
361
+ if "s_rho" in ds_loc[var_name].dims:
362
+ attrs = ds_loc[var_name].attrs
363
+ regridded = vertical_regrid.apply(
364
+ ds_loc[var_name],
365
+ layer_depth_loc,
366
+ depth_levels,
367
+ mask_edges=False,
368
+ )
369
+ regridded = regridded.where(regridded.depth < h_loc)
370
+ ds_loc[var_name] = regridded
371
+ ds_loc[var_name].attrs = attrs
372
+
373
+ ds_loc = ds_loc.assign_coords({"depth": depth_levels})
351
374
 
352
375
  # Collect regridded dataset for merging
353
376
  regridded_datasets.append(ds_loc)
@@ -418,10 +441,12 @@ class ROMSOutput:
418
441
  def _load_model_output(self) -> xr.Dataset:
419
442
  """Load the model output."""
420
443
  # Load the dataset
421
- ds = _load_data(
444
+ ds = load_data(
422
445
  self.path,
423
446
  dim_names={"time": "time"},
424
447
  use_dask=self.use_dask,
448
+ decode_times=False,
449
+ decode_timedelta=False,
425
450
  time_chunking=True,
426
451
  force_combine_nested=True,
427
452
  )
@@ -514,35 +539,39 @@ class ROMSOutput:
514
539
 
515
540
  Notes
516
541
  -----
542
+ - Missing attributes trigger a warning instead of an exception.
517
543
  - `theta_s`, `theta_b`, and `hc` are checked for exact equality using `np.array_equal`.
518
544
  - `Cs_r` and `Cs_w` are checked for numerical closeness using `np.allclose`.
519
545
  """
520
- # Check exact equality for theta_s, theta_b, and hc
521
- if not np.array_equal(self.grid.theta_s, ds.attrs["theta_s"]):
522
- raise ValueError(
523
- f"theta_s from grid ({self.grid.theta_s}) does not match dataset ({ds.attrs['theta_s']})."
524
- )
525
-
526
- if not np.array_equal(self.grid.theta_b, ds.attrs["theta_b"]):
527
- raise ValueError(
528
- f"theta_b from grid ({self.grid.theta_b}) does not match dataset ({ds.attrs['theta_b']})."
529
- )
546
+ required_exact = ["theta_s", "theta_b", "hc"]
547
+ required_close = ["Cs_r", "Cs_w"]
530
548
 
531
- if not np.array_equal(self.grid.hc, ds.attrs["hc"]):
532
- raise ValueError(
533
- f"hc from grid ({self.grid.hc}) does not match dataset ({ds.attrs['hc']})."
534
- )
535
-
536
- # Check numerical closeness for Cs_r and Cs_w
537
- if not np.allclose(self.grid.ds.Cs_r, ds.attrs["Cs_r"]):
538
- raise ValueError(
539
- f"Cs_r from grid ({self.grid.ds.Cs_r}) is not close to dataset ({ds.attrs['Cs_r']})."
540
- )
549
+ # Check exact equality
550
+ for param in required_exact:
551
+ value = ds.attrs.get(param, None)
552
+ if value is None:
553
+ logging.warning(
554
+ f"Dataset is missing attribute '{param}'. Skipping this check."
555
+ )
556
+ continue
557
+ if not np.array_equal(getattr(self.grid, param), value):
558
+ raise ValueError(
559
+ f"{param} from grid ({getattr(self.grid, param)}) does not match dataset ({value})."
560
+ )
541
561
 
542
- if not np.allclose(self.grid.ds.Cs_w, ds.attrs["Cs_w"]):
543
- raise ValueError(
544
- f"Cs_w from grid ({self.grid.ds.Cs_w}) is not close to dataset ({ds.attrs['Cs_w']})."
545
- )
562
+ # Check numerical closeness
563
+ for param in required_close:
564
+ value = ds.attrs.get(param, None)
565
+ if value is None:
566
+ logging.warning(
567
+ f"Dataset is missing attribute '{param}'. Skipping this check."
568
+ )
569
+ continue
570
+ grid_value = getattr(self.grid.ds, param)
571
+ if not np.allclose(grid_value, value):
572
+ raise ValueError(
573
+ f"{param} from grid ({grid_value}) is not close to dataset ({value})."
574
+ )
546
575
 
547
576
  def _add_absolute_time(self, ds: xr.Dataset) -> xr.Dataset:
548
577
  """Add absolute time as a coordinate to the dataset.
@@ -560,6 +589,11 @@ class ROMSOutput:
560
589
  xarray.Dataset
561
590
  Dataset with "abs_time" added and "time" removed.
562
591
  """
592
+ if self.model_reference_date is None:
593
+ raise ValueError(
594
+ "`model_reference_date` must be set before computing absolute time."
595
+ )
596
+
563
597
  ocean_time_seconds = ds["ocean_time"].values
564
598
 
565
599
  abs_time = np.array(
roms_tools/download.py CHANGED
@@ -86,6 +86,10 @@ pup_test_data = pooch.create(
86
86
  "eastpac25km_rst.19980106000000.nc": "8f56d72bd8daf72eb736cc6705f93f478f4ad0ae4a95e98c4c9393a38e032f4c",
87
87
  "eastpac25km_rst.19980126000000.nc": "20ad9007c980d211d1e108c50589183120c42a2d96811264cf570875107269e4",
88
88
  "epac25km_grd.nc": "ec26c69cda4c4e96abde5b7756c955a7e1074931ab5a0641f598b099778fb617",
89
+ "GSHHS_l_L1.dbf": "181236ffbf553a83d2afedc5fe5e1f2fea64190f56ea366fc7a8ff5aa6163663",
90
+ "GSHHS_l_L1.prj": "98aaf3d1c0ecadf1a424a4536de261c3daf4e373697cb86c40c43b989daf52eb",
91
+ "GSHHS_l_L1.shp": "bc76f101f9b8671f90e734b4026da91c20066fc627cc8b5889ba22d90cbf97e9",
92
+ "GSHHS_l_L1.shx": "72879354892d80d6c39c612f645661ec0edc75f3f9f8f74b19d9387ae0327377",
89
93
  },
90
94
  )
91
95