roms-tools 2.5.0__py3-none-any.whl → 2.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. ci/environment-with-xesmf.yml +16 -0
  2. roms_tools/analysis/roms_output.py +521 -187
  3. roms_tools/analysis/utils.py +169 -0
  4. roms_tools/plot.py +351 -214
  5. roms_tools/regrid.py +161 -9
  6. roms_tools/setup/boundary_forcing.py +22 -22
  7. roms_tools/setup/datasets.py +40 -44
  8. roms_tools/setup/grid.py +28 -28
  9. roms_tools/setup/initial_conditions.py +23 -31
  10. roms_tools/setup/nesting.py +3 -3
  11. roms_tools/setup/river_forcing.py +22 -23
  12. roms_tools/setup/surface_forcing.py +14 -13
  13. roms_tools/setup/tides.py +7 -7
  14. roms_tools/setup/topography.py +2 -2
  15. roms_tools/tests/test_analysis/test_roms_output.py +299 -188
  16. roms_tools/tests/test_regrid.py +85 -2
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -2
  18. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +2 -2
  19. roms_tools/tests/test_setup/test_river_forcing.py +47 -51
  20. roms_tools/tests/test_vertical_coordinate.py +73 -0
  21. roms_tools/utils.py +11 -7
  22. roms_tools/vertical_coordinate.py +7 -0
  23. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/METADATA +22 -11
  24. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/RECORD +33 -30
  25. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/WHEEL +1 -1
  26. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zarray +0 -0
  27. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zattrs +0 -0
  28. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/0.0 +0 -0
  29. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zarray +0 -0
  30. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zattrs +0 -0
  31. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/0.0 +0 -0
  32. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info/licenses}/LICENSE +0 -0
  33. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,25 @@
1
1
  import xarray as xr
2
2
  import numpy as np
3
3
  import matplotlib.pyplot as plt
4
+ from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
4
5
  from roms_tools.utils import _load_data
6
+ from roms_tools.regrid import LateralRegridFromROMS, VerticalRegridFromROMS
5
7
  from dataclasses import dataclass, field
6
8
  from typing import Union, Optional
7
9
  from pathlib import Path
8
10
  import re
9
11
  import logging
12
+ import warnings
10
13
  from datetime import datetime, timedelta
11
14
  from roms_tools import Grid
12
- from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
13
15
  from roms_tools.vertical_coordinate import (
14
16
  compute_depth_coordinates,
15
17
  )
18
+ from roms_tools.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
19
+ from roms_tools.analysis.utils import _validate_plot_inputs, _generate_coordinate_range
16
20
 
17
21
 
18
- @dataclass(frozen=True, kw_only=True)
22
+ @dataclass(kw_only=True)
19
23
  class ROMSOutput:
20
24
  """Represents ROMS model output.
21
25
 
@@ -49,10 +53,10 @@ class ROMSOutput:
49
53
  self._check_vertical_coordinate(ds)
50
54
  ds = self._add_absolute_time(ds)
51
55
  ds = self._add_lat_lon_coords(ds)
52
- object.__setattr__(self, "ds", ds)
56
+ self.ds = ds
53
57
 
54
58
  # Dataset for depth coordinates
55
- object.__setattr__(self, "ds_depth_coords", xr.Dataset())
59
+ self.ds_depth_coords = xr.Dataset()
56
60
 
57
61
  def plot(
58
62
  self,
@@ -61,12 +65,15 @@ class ROMSOutput:
61
65
  s=None,
62
66
  eta=None,
63
67
  xi=None,
68
+ depth=None,
69
+ lat=None,
70
+ lon=None,
64
71
  include_boundary=False,
65
72
  depth_contours=False,
66
- layer_contours=False,
67
73
  ax=None,
74
+ save_path=None,
68
75
  ) -> None:
69
- """Plot a ROMS output field for a given vertical (s_rho) or horizontal (eta, xi)
76
+ """Generate a plot of a ROMS output field for a specified vertical or horizontal
70
77
  slice.
71
78
 
72
79
  Parameters
@@ -79,31 +86,56 @@ class ROMSOutput:
79
86
 
80
87
  time : int, optional
81
88
  Index of the time dimension to plot. Default is 0.
89
+
82
90
  s : int, optional
83
- The index of the vertical layer (`s_rho`) to plot. If not specified, the plot
84
- will represent a horizontal slice (eta- or xi- plane). Default is None.
91
+ The index of the vertical layer (`s_rho`) to plot. If specified, the plot
92
+ will display a horizontal slice at that layer. Cannot be used simultaneously
93
+ with `depth`. Default is None.
94
+
85
95
  eta : int, optional
86
- The eta-index to plot. Used for vertical sections or horizontal slices.
87
- Default is None.
96
+ The eta-index to plot. Used for generating vertical sections or plotting
97
+ horizontal slices along a constant eta-coordinate. Cannot be used simultaneously
98
+ with `lat` or `lon`, but can be combined with `xi`. Default is None.
99
+
88
100
  xi : int, optional
89
- The xi-index to plot. Used for vertical sections or horizontal slices.
90
- Default is None.
101
+ The xi-index to plot. Used for generating vertical sections or plotting
102
+ horizontal slices along a constant xi-coordinate. Cannot be used simultaneously
103
+ with `lat` or `lon`, but can be combined with `eta`. Default is None.
104
+
105
+ depth : float, optional
106
+ Depth (in meters) to plot a horizontal slice at a specific depth level.
107
+ If specified, the plot will interpolate the field to the given depth.
108
+ Cannot be used simultaneously with `s` or for fields that are inherently
109
+ 2D (such as "zeta"). Default is None.
110
+
111
+ lat : float, optional
112
+ Latitude (in degrees) to plot a vertical section at a specific
113
+ latitude. This option is useful for generating zonal (west-east)
114
+ sections. Cannot be used simultaneously with `eta` or `xi`, bu can be
115
+ combined with `lon`. Default is None.
116
+
117
+ lon : float, optional
118
+ Longitude (in degrees) to plot a vertical section at a specific
119
+ longitude. This option is useful for generating meridional (south-north) sections.
120
+ Cannot be used simultaneously with `eta` or `xi`, but can be combined
121
+ with `lat`. Default is None.
122
+
91
123
  include_boundary : bool, optional
92
124
  Whether to include the outermost grid cells along the `eta`- and `xi`-boundaries in the plot.
93
125
  In diagnostic ROMS output fields, these boundary cells are set to zero, so excluding them can improve visualization.
94
- This option is only relevant for 2D horizontal plots (`eta=None`, `xi=None`).
95
126
  Default is False.
127
+
96
128
  depth_contours : bool, optional
97
- If True, depth contours will be overlaid on the plot, showing lines of constant
98
- depth. This is typically used for plots that show a single vertical layer.
99
- Default is False.
100
- layer_contours : bool, optional
101
- If True, contour lines representing the boundaries between vertical layers will
102
- be added to the plot. This is particularly useful in vertical sections to
103
- visualize the layering of the water column. For clarity, the number of layer
104
- contours displayed is limited to a maximum of 10. Default is False.
129
+ If True, overlays contours representing lines of constant depth on the plot.
130
+ This option is only relevant when the `s` parameter is provided (i.e., not None).
131
+ By default, depth contours are not shown (False).
132
+
105
133
  ax : matplotlib.axes.Axes, optional
106
- The axes to plot on. If None, a new figure is created. Note that this argument does not work for horizontal plots that display the eta- and xi-dimensions at the same time.
134
+ The axes to plot on. If None, a new figure is created. Note that this argument does not work for 2D horizontal plots. Default is None.
135
+
136
+ save_path : str, optional
137
+ Path to save the generated plot. If None, the plot is shown interactively.
138
+ Default is None.
107
139
 
108
140
  Returns
109
141
  -------
@@ -113,44 +145,86 @@ class ROMSOutput:
113
145
  Raises
114
146
  ------
115
147
  ValueError
116
- If the specified `var_name` is not one of the valid options.
117
- If the field specified by `var_name` is 3D and none of `s`, `eta`, or `xi` are specified.
118
- If the field specified by `var_name` is 2D and both `eta` and `xi` are specified.
148
+ - If the specified `var_name` is not one of the valid options.
149
+ - If the field specified by `var_name` is 3D and none of `s`, `eta`, `xi`, `depth`, `lat`, or `lon` are specified.
150
+ - If the field specified by `var_name` is 2D and both `eta` and `xi` or both `lat` and `lon` are specified.
151
+ - If conflicting dimensions are specified (e.g., specifying `eta`/`xi` with `lat`/`lon` or both `s` and `depth`).
152
+ - If more than two dimensions are specified for a 3D field.
153
+ - If `time` exceeds the bounds of the time dimension.
154
+ - If `time` is specified for a field that does not have a time dimension.
155
+ - If `eta` or `xi` indices are out of bounds.
156
+ - If `eta` or `xi` lie on the boundary when `include_boundary=False`.
119
157
  """
120
-
121
- # Input checks
158
+ # Check if variable exists
122
159
  if var_name not in self.ds:
123
- raise ValueError(f"Variable '{var_name}' is not found in dataset.")
160
+ raise ValueError(f"Variable '{var_name}' is not found in the dataset.")
124
161
 
125
- if "time" in self.ds[var_name].dims:
126
- if time >= len(self.ds[var_name].time):
162
+ # Pick the variable
163
+ field = self.ds[var_name]
164
+
165
+ # Check and pick time
166
+ if "time" in field.dims:
167
+ if time >= len(field.time):
127
168
  raise ValueError(
128
169
  f"Invalid time index: The specified time index ({time}) exceeds the maximum index "
129
- f"({len(self.ds[var_name].time) - 1}) for the 'time' dimension in variable '{var_name}'."
170
+ f"({len(field.time) - 1}) for the 'time' dimension."
130
171
  )
131
- field = self.ds[var_name].isel(time=time)
172
+ field = field.isel(time=time)
132
173
  else:
133
174
  if time > 0:
134
175
  raise ValueError(
135
- f"Invalid input: The variable '{var_name}' does not have a 'time' dimension, "
176
+ f"Invalid input: The field does not have a 'time' dimension, "
136
177
  f"but a time index ({time}) greater than 0 was provided."
137
178
  )
138
- field = self.ds[var_name]
139
179
 
140
- if len(field.dims) == 3:
141
- if not any([s is not None, eta is not None, xi is not None]):
142
- raise ValueError(
143
- "Invalid input: For 3D fields, you must specify at least one of the dimensions 's', 'eta', or 'xi'."
144
- )
145
- if all([s is not None, eta is not None, xi is not None]):
146
- raise ValueError(
147
- "Ambiguous input: For 3D fields, specify at most two of 's', 'eta', or 'xi'. Specifying all three is not allowed."
148
- )
180
+ # Input checks
181
+ _validate_plot_inputs(field, s, eta, xi, depth, lat, lon, include_boundary)
182
+
183
+ # Get horizontal dimensions and grid location
184
+ horizontal_dims_dict = {
185
+ "rho": {"eta": "eta_rho", "xi": "xi_rho"},
186
+ "u": {"eta": "eta_rho", "xi": "xi_u"},
187
+ "v": {"eta": "eta_v", "xi": "xi_rho"},
188
+ }
189
+ for loc, horizontal_dims in horizontal_dims_dict.items():
190
+ if all(dim in field.dims for dim in horizontal_dims.values()):
191
+ break
192
+
193
+ # Convert relative to absolute indices
194
+ def _get_absolute_index(idx, field, dim_name):
195
+ index = field[dim_name].isel(**{dim_name: idx}).item()
196
+ return index
197
+
198
+ if eta is not None and eta < 0:
199
+ eta = _get_absolute_index(eta, field, horizontal_dims["eta"])
200
+ if xi is not None and xi < 0:
201
+ xi = _get_absolute_index(xi, field, horizontal_dims["xi"])
202
+ if s is not None and s < 0:
203
+ s = _get_absolute_index(s, field, "s_rho")
204
+
205
+ # Set spatial coordinates
206
+ lat_deg = self.grid.ds[f"lat_{loc}"]
207
+ lon_deg = self.grid.ds[f"lon_{loc}"]
208
+ if self.grid.straddle:
209
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
210
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
149
211
 
150
- if len(field.dims) == 2 and all([eta is not None, xi is not None]):
151
- raise ValueError(
152
- "Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both."
153
- )
212
+ # Mask the field
213
+ mask = self.grid.ds[f"mask_{loc}"]
214
+ field = field.where(mask)
215
+
216
+ # Assign eta and xi as coordinates
217
+ coords_to_assign = {dim: field[dim] for dim in horizontal_dims.values()}
218
+ field = field.assign_coords(**coords_to_assign)
219
+
220
+ # Remove horizontal boundary if desired
221
+ slice_dict = {
222
+ "rho": {"eta_rho": slice(1, -1), "xi_rho": slice(1, -1)},
223
+ "u": {"eta_rho": slice(1, -1), "xi_u": slice(1, -1)},
224
+ "v": {"eta_v": slice(1, -1), "xi_rho": slice(1, -1)},
225
+ }
226
+ if not include_boundary:
227
+ field = field.isel(**slice_dict[loc])
154
228
 
155
229
  # Load the data
156
230
  if self.use_dask:
@@ -159,140 +233,156 @@ class ROMSOutput:
159
233
  with ProgressBar():
160
234
  field.load()
161
235
 
162
- # Get correct mask and spatial coordinates
163
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
164
- loc = "rho"
165
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
166
- loc = "u"
167
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
168
- loc = "v"
169
- else:
170
- ValueError("provided field does not have two horizontal dimension")
171
-
172
- mask = self.grid.ds[f"mask_{loc}"]
173
- lat_deg = self.grid.ds[f"lat_{loc}"]
174
- lon_deg = self.grid.ds[f"lon_{loc}"]
175
-
176
- if self.grid.straddle:
177
- lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
236
+ # Compute layer depth for 3D fields when depth contours are requested or no vertical layer is specified.
237
+ compute_layer_depth = len(field.dims) > 2 and (depth_contours or s is None)
238
+ if compute_layer_depth:
239
+ if eta is not None or xi is not None:
240
+ # Computing depth coordinates directly for the slice in question is more efficient
241
+ # than using .ds_depth_coords, which computes depth coordinates for full field
242
+ if self.adjust_depth_for_sea_surface_height:
243
+ zeta = self.ds.zeta.isel(time=time)
244
+ else:
245
+ zeta = 0
246
+ if compute_layer_depth:
247
+ layer_depth = compute_depth_coordinates(
248
+ self.grid.ds,
249
+ zeta,
250
+ depth_type="layer",
251
+ location=loc,
252
+ eta=eta,
253
+ xi=xi,
254
+ )
255
+ else:
256
+ self._get_depth_coordinates(depth_type="layer", locations=[loc])
257
+ layer_depth = self.ds_depth_coords[f"layer_depth_{loc}"]
258
+ if self.adjust_depth_for_sea_surface_height:
259
+ layer_depth = layer_depth.isel(time=time)
260
+
261
+ if not include_boundary:
262
+ # Apply valid slices only for dimensions that exist in layer_depth.dims
263
+ layer_depth = layer_depth.isel(
264
+ **{
265
+ dim: s
266
+ for dim, s in slice_dict.get(loc, {}).items()
267
+ if dim in layer_depth.dims
268
+ }
269
+ )
270
+ layer_depth.load()
178
271
 
179
- field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
272
+ # Prepare figure title
273
+ formatted_time = np.datetime_as_string(field.abs_time.values, unit="m")
274
+ title = f"time: {formatted_time}"
180
275
 
181
- # Retrieve depth coordinates
182
- compute_layer_depth = (depth_contours or s is None) and len(field.dims) > 2
183
- compute_interface_depth = layer_contours and s is None
276
+ # Slice the field horizontally as desired
277
+ def _slice_along_dimension(field, title, dim_name, idx):
278
+ field = field.sel(**{dim_name: idx})
279
+ title = title + f", {dim_name} = {idx}"
280
+ return field, title
184
281
 
185
- # Compute depth coordinates directly instead of using .ds_depth_coords.
186
- # Many cases below require only a 1D or 2D slice, making direct computation
187
- # more efficient than triggering a full 3D depth computation just to extract a subset.
188
- # This is especially beneficial when using Dask or if .ds_depth_coords is precomputed.
189
- if self.adjust_depth_for_sea_surface_height:
190
- zeta = self.ds.zeta.isel(time=time)
191
- else:
192
- zeta = 0
193
- if compute_layer_depth:
194
- layer_depth = compute_depth_coordinates(
195
- self.grid.ds,
196
- zeta,
197
- depth_type="layer",
198
- location=loc,
199
- eta=eta,
200
- xi=xi,
282
+ if eta is not None:
283
+ field, title = _slice_along_dimension(
284
+ field, title, horizontal_dims["eta"], eta
201
285
  )
202
- if s is not None:
203
- layer_depth = layer_depth.isel(s_rho=s)
204
- if compute_interface_depth:
205
- interface_depth = compute_depth_coordinates(
206
- self.grid.ds,
207
- zeta,
208
- depth_type="interface",
209
- location=loc,
210
- eta=eta,
211
- xi=xi,
286
+ if xi is not None:
287
+ field, title = _slice_along_dimension(
288
+ field, title, horizontal_dims["xi"], xi
212
289
  )
213
- if s is not None:
214
- interface_depth = interface_depth.isel(s_w=s)
215
-
216
- # Slice the field as desired
217
- title = field.long_name
218
290
  if s is not None:
219
- title = title + f", s_rho = {field.s_rho[s].item()}"
220
- field = field.isel(s_rho=s)
291
+ field, title = _slice_along_dimension(field, title, "s_rho", s)
292
+ if compute_layer_depth:
293
+ layer_depth = layer_depth.isel(s_rho=s)
221
294
  else:
222
295
  depth_contours = False
223
296
 
224
- def _process_dimension(field, mask, dim_name, dim_values, idx, title):
225
- if dim_name in field.dims:
226
- title = title + f", {dim_name} = {dim_values[idx].item()}"
227
- field = field.isel(**{dim_name: idx})
228
- mask = mask.isel(**{dim_name: idx})
297
+ # Regrid laterally
298
+ if lat is not None or lon is not None:
299
+
300
+ if lat is not None:
301
+ lats = [lat]
302
+ title = title + f", lat = {lat}°N"
229
303
  else:
230
- raise ValueError(
231
- f"None of the expected dimensions ({dim_name}) found in field."
304
+ resolution = self._infer_nominal_horizontal_resolution()
305
+ lats = _generate_coordinate_range(
306
+ field.lat.min().values, field.lat.max().values, resolution
232
307
  )
233
- return field, mask, title
308
+ lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
234
309
 
235
- if eta is not None:
236
- field, mask, title = _process_dimension(
237
- field,
238
- mask,
239
- "eta_rho" if "eta_rho" in field.dims else "eta_v",
240
- field.eta_rho if "eta_rho" in field.dims else field.eta_v,
241
- eta,
242
- title,
243
- )
244
-
245
- if xi is not None:
246
- field, mask, title = _process_dimension(
247
- field,
248
- mask,
249
- "xi_rho" if "xi_rho" in field.dims else "xi_u",
250
- field.xi_rho if "xi_rho" in field.dims else field.xi_u,
251
- xi,
252
- title,
253
- )
310
+ if lon is not None:
311
+ lons = [lon]
312
+ title = title + f", lon = {lon}°E"
313
+ else:
314
+ resolution = self._infer_nominal_horizontal_resolution(lat)
315
+ lons = _generate_coordinate_range(
316
+ field.lon.min().values, field.lon.max().values, resolution
317
+ )
318
+ lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
254
319
 
255
- # Format to exclude seconds
256
- formatted_time = np.datetime_as_string(field.abs_time.values, unit="m")
257
- title = title + f", time: {formatted_time}"
320
+ target_coords = {"lat": lats, "lon": lons}
321
+ lateral_regrid = LateralRegridFromROMS(field, target_coords)
322
+ field = lateral_regrid.apply(field).squeeze()
323
+ if compute_layer_depth:
324
+ layer_depth = lateral_regrid.apply(layer_depth).squeeze()
258
325
 
326
+ # Assign depth as coordinate
259
327
  if compute_layer_depth:
260
328
  field = field.assign_coords({"layer_depth": layer_depth})
261
329
 
262
- if not include_boundary:
263
- slice_dict = None
264
-
265
- if eta is None and xi is None:
266
- slice_dict = {
267
- "rho": {"eta_rho": slice(1, -1), "xi_rho": slice(1, -1)},
268
- "u": {"eta_rho": slice(1, -1), "xi_u": slice(1, -1)},
269
- "v": {"eta_v": slice(1, -1), "xi_rho": slice(1, -1)},
270
- }
271
- elif eta is None:
272
- slice_dict = {
273
- "rho": {"eta_rho": slice(1, -1)},
274
- "u": {"eta_rho": slice(1, -1)},
275
- "v": {"eta_v": slice(1, -1)},
276
- }
277
- elif xi is None:
278
- slice_dict = {
279
- "rho": {"xi_rho": slice(1, -1)},
280
- "u": {"xi_u": slice(1, -1)},
281
- "v": {"xi_rho": slice(1, -1)},
282
- }
283
- if slice_dict is not None:
284
- if loc in slice_dict:
285
- field = field.isel(**slice_dict[loc])
286
- mask = mask.isel(**slice_dict[loc])
330
+ def _remove_edge_nans(field, xdim, layer_depth=None):
331
+ """Removes NaNs from the edges along the specified dimension."""
332
+ if xdim in field.dims:
333
+ if layer_depth is not None:
334
+ nan_mask = layer_depth.isnull().sum(
335
+ dim=[dim for dim in layer_depth.dims if dim != xdim]
336
+ )
337
+ else:
338
+ nan_mask = field.isnull().sum(
339
+ dim=[dim for dim in field.dims if dim != xdim]
340
+ )
341
+
342
+ # Find the valid indices where the sum of the nans is 0
343
+ valid_indices = np.where(nan_mask.values == 0)[0]
344
+
345
+ if len(valid_indices) > 0:
346
+ first_valid = valid_indices[0]
347
+ last_valid = valid_indices[-1]
348
+
349
+ field = field.isel({xdim: slice(first_valid, last_valid + 1)})
350
+ if layer_depth is not None:
351
+ layer_depth = layer_depth.isel(
352
+ {xdim: slice(first_valid, last_valid + 1)}
353
+ )
354
+
355
+ return field, layer_depth
356
+
357
+ if lat is not None:
358
+ field, layer_depth = _remove_edge_nans(
359
+ field, "lon", layer_depth if "layer_depth" in locals() else None
360
+ )
361
+ if lon is not None:
362
+ field, layer_depth = _remove_edge_nans(
363
+ field, "lat", layer_depth if "layer_depth" in locals() else None
364
+ )
365
+
366
+ # Regrid vertically
367
+ if depth is not None:
368
+ vertical_regrid = VerticalRegridFromROMS(self.ds)
369
+ # Save attributes before vertical regridding
370
+ attrs = field.attrs
371
+ field = vertical_regrid.apply(
372
+ field, layer_depth, np.array([depth])
373
+ ).squeeze()
374
+ # Reset attributes
375
+ field.attrs = attrs
376
+ title = title + f", depth = {depth}m"
287
377
 
288
378
  # Choose colorbar
289
379
  if var_name in ["u", "v", "w", "ubar", "vbar", "zeta"]:
290
- vmax = max(field.where(mask).max().values, -field.where(mask).min().values)
380
+ vmax = max(field.max().values, -field.min().values)
291
381
  vmin = -vmax
292
382
  cmap = plt.colormaps.get_cmap("RdBu_r")
293
383
  else:
294
- vmax = field.where(mask).max().values
295
- vmin = field.where(mask).min().values
384
+ vmax = field.max().values
385
+ vmin = field.min().values
296
386
  if var_name in ["temp", "salt"]:
297
387
  cmap = plt.colormaps.get_cmap("YlOrRd")
298
388
  else:
@@ -301,37 +391,170 @@ class ROMSOutput:
301
391
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
302
392
 
303
393
  # Plotting
304
- if eta is None and xi is None:
305
- _plot(
306
- field=field.where(mask),
394
+ if (eta is None and xi is None) and (lat is None and lon is None):
395
+ fig = _plot(
396
+ field=field,
307
397
  depth_contours=depth_contours,
308
398
  title=title,
309
399
  kwargs=kwargs,
310
- c="g",
400
+ c=None,
311
401
  )
312
402
  else:
313
403
  if len(field.dims) == 2:
314
- if not layer_contours:
315
- interface_depth = None
316
- else:
317
- # restrict number of layer_contours to 10 for the sake of plot clearity
318
- nr_layers = len(interface_depth["s_w"])
319
- selected_layers = np.linspace(
320
- 0, nr_layers - 1, min(nr_layers, 10), dtype=int
321
- )
322
- interface_depth = interface_depth.isel(s_w=selected_layers)
323
- _section_plot(
324
- field.where(mask),
325
- interface_depth=interface_depth,
404
+ fig = _section_plot(
405
+ field,
406
+ interface_depth=None,
326
407
  title=title,
327
408
  kwargs=kwargs,
328
409
  ax=ax,
329
410
  )
330
411
  else:
331
412
  if "s_rho" in field.dims:
332
- _profile_plot(field.where(mask), title=title, ax=ax)
413
+ fig = _profile_plot(field, title=title, ax=ax)
333
414
  else:
334
- _line_plot(field.where(mask), title=title, ax=ax)
415
+ fig = _line_plot(field, title=title, ax=ax)
416
+
417
+ if save_path:
418
+ plt.savefig(save_path, dpi=300, bbox_inches="tight")
419
+
420
+ def regrid(self, var_names=None, horizontal_resolution=None, depth_levels=None):
421
+ """Regrid the dataset both horizontally and vertically.
422
+
423
+ This method selects the specified variables, interpolates them onto a lat-lon-z horizontal grid. The horizontal target resolution and vertical target depth levels are either specified or inferred dynamically.
424
+
425
+ Parameters
426
+ ----------
427
+ var_names : list of str, optional
428
+ List of variable names to be regridded. If None, all variables in the dataset
429
+ are used.
430
+ horizontal_resolution : float, optional
431
+ Target horizontal resolution in degrees. If None, the nominal horizontal resolution is inferred from the grid.
432
+ depth_levels : xarray.DataArray, numpy.ndarray, list, optional
433
+ Target depth levels. If None, depth levels are determined dynamically.
434
+ If provided as a list or numpy array, it is safely converted to an `xarray.DataArray`.
435
+
436
+ Returns
437
+ -------
438
+ xarray.Dataset
439
+ The regridded dataset.
440
+ """
441
+
442
+ if var_names is None:
443
+ var_names = list(self.ds.data_vars)
444
+
445
+ # Check that all var_names exist in self.ds
446
+ missing_vars = [var for var in var_names if var not in self.ds.data_vars]
447
+ if missing_vars:
448
+ raise ValueError(
449
+ f"The following variables are not found in the dataset: {', '.join(missing_vars)}"
450
+ )
451
+
452
+ # Retain only the variables in var_names and drop others
453
+ ds = self.ds[var_names]
454
+
455
+ # Prepare lateral regrid
456
+ lat_deg = self.grid.ds["lat_rho"]
457
+ lon_deg = self.grid.ds["lon_rho"]
458
+ if self.grid.straddle:
459
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
460
+
461
+ if horizontal_resolution is None:
462
+ horizontal_resolution = self._infer_nominal_horizontal_resolution()
463
+ lons = _generate_coordinate_range(
464
+ lon_deg.min().values, lon_deg.max().values, horizontal_resolution
465
+ )
466
+ lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
467
+ lats = _generate_coordinate_range(
468
+ lat_deg.min().values, lat_deg.max().values, horizontal_resolution
469
+ )
470
+ lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
471
+ target_coords = {"lat": lats, "lon": lons}
472
+
473
+ # Prepare vertical regrid
474
+ if depth_levels is None:
475
+ depth_levels, _ = self._compute_exponential_depth_levels()
476
+
477
+ # Ensure depth_levels is an xarray.DataArray
478
+ if not isinstance(depth_levels, xr.DataArray):
479
+ depth_levels = xr.DataArray(
480
+ np.asarray(depth_levels),
481
+ dims=["depth"],
482
+ attrs={"long_name": "Depth", "units": "m"},
483
+ )
484
+
485
+ depth_levels = depth_levels.astype(np.float32)
486
+
487
+ # Initialize list to hold regridded datasets
488
+ regridded_datasets = []
489
+
490
+ for loc, dims in [
491
+ ("rho", ("eta_rho", "xi_rho")),
492
+ ("u", ("eta_rho", "xi_u")),
493
+ ("v", ("eta_v", "xi_rho")),
494
+ ]:
495
+ var_names_loc = [
496
+ var_name
497
+ for var_name in var_names
498
+ if all(dim in ds[var_name].dims for dim in dims)
499
+ ]
500
+ if var_names_loc:
501
+ ds_loc = (
502
+ ds[var_names_loc]
503
+ .rename({f"lat_{loc}": "lat", f"lon_{loc}": "lon"})
504
+ .where(self.grid.ds[f"mask_{loc}"])
505
+ )
506
+ self._get_depth_coordinates(depth_type="layer", locations=[loc])
507
+ layer_depth_loc = self.ds_depth_coords[f"layer_depth_{loc}"]
508
+ h_loc = self.grid.ds.h
509
+ if loc == "u":
510
+ h_loc = interpolate_from_rho_to_u(h_loc)
511
+ elif loc == "v":
512
+ h_loc = interpolate_from_rho_to_v(h_loc)
513
+
514
+ # Exclude the horizontal boundary cells since diagnostic variables may contain zeros there
515
+ ds_loc = ds_loc.isel({dims[0]: slice(1, -1), dims[1]: slice(1, -1)})
516
+ layer_depth_loc = layer_depth_loc.isel(
517
+ {dims[0]: slice(1, -1), dims[1]: slice(1, -1)}
518
+ )
519
+ h_loc = h_loc.isel({dims[0]: slice(1, -1), dims[1]: slice(1, -1)})
520
+
521
+ # Lateral regridding
522
+ lateral_regrid = LateralRegridFromROMS(ds_loc, target_coords)
523
+ ds_loc = lateral_regrid.apply(ds_loc)
524
+ layer_depth_loc = lateral_regrid.apply(layer_depth_loc)
525
+ h_loc = lateral_regrid.apply(h_loc)
526
+ # Vertical regridding
527
+ vertical_regrid = VerticalRegridFromROMS(ds_loc)
528
+ for var_name in var_names_loc:
529
+ if "s_rho" in ds_loc[var_name].dims:
530
+ attrs = ds_loc[var_name].attrs
531
+ regridded = vertical_regrid.apply(
532
+ ds_loc[var_name],
533
+ layer_depth_loc,
534
+ depth_levels,
535
+ mask_edges=False,
536
+ )
537
+ regridded = regridded.where(regridded.depth < h_loc)
538
+ ds_loc[var_name] = regridded
539
+ ds_loc[var_name].attrs = attrs
540
+
541
+ ds_loc = ds_loc.assign_coords({"depth": depth_levels})
542
+
543
+ # Collect regridded dataset for merging
544
+ regridded_datasets.append(ds_loc)
545
+
546
+ # Merge all regridded datasets
547
+ if regridded_datasets:
548
+ ds = xr.merge(regridded_datasets)
549
+
550
+ with warnings.catch_warnings():
551
+ warnings.filterwarnings("ignore", category=UserWarning)
552
+ ds = ds.rename({"abs_time": "time"}).set_index(time="time")
553
+ ds["time"].attrs = {"long_name": "Time"}
554
+ ds["lon"].attrs = {"long_name": "Longitude", "units": "Degrees East"}
555
+ ds["lat"].attrs = {"long_name": "Latitude", "units": "Degrees North"}
556
+
557
+ return ds
335
558
 
336
559
  def _get_depth_coordinates(self, depth_type="layer", locations=["rho"]):
337
560
  """Ensure depth coordinates are stored for a given location and depth type.
@@ -378,9 +601,11 @@ class ROMSOutput:
378
601
  zeta = 0
379
602
 
380
603
  for location in locations:
381
- self.ds_depth_coords[
382
- f"{depth_type}_depth_{location}"
383
- ] = compute_depth_coordinates(self.grid.ds, zeta, depth_type, location)
604
+ var_name = f"{depth_type}_depth_{location}"
605
+ if var_name not in self.ds_depth_coords:
606
+ self.ds_depth_coords[var_name] = compute_depth_coordinates(
607
+ self.grid.ds, zeta, depth_type, location
608
+ )
384
609
 
385
610
  def _load_model_output(self) -> xr.Dataset:
386
611
  """Load the model output."""
@@ -435,7 +660,7 @@ class ROMSOutput:
435
660
  )
436
661
  else:
437
662
  # Set the model reference date if not already set
438
- object.__setattr__(self, "model_reference_date", inferred_date)
663
+ self.model_reference_date = inferred_date
439
664
  else:
440
665
  # Handle case where no match is found
441
666
  if hasattr(self, "model_reference_date") and self.model_reference_date:
@@ -548,22 +773,131 @@ class ROMSOutput:
548
773
  return ds
549
774
 
550
775
  def _add_lat_lon_coords(self, ds: xr.Dataset) -> xr.Dataset:
551
- """Add latitude and longitude coordinates to the dataset.
776
+ """Add latitude and longitude coordinates to the dataset based on the grid.
552
777
 
553
- Adds "lat_rho" and "lon_rho" from the grid object to the dataset.
778
+ This method assigns latitude and longitude coordinates from the grid to the dataset.
779
+ It always adds the "lat_rho" and "lon_rho" coordinates. If the dataset contains the
780
+ "xi_u" or "eta_v" dimensions, it also adds the corresponding "lat_u", "lon_u",
781
+ "lat_v", and "lon_v" coordinates.
554
782
 
555
783
  Parameters
556
784
  ----------
557
785
  ds : xarray.Dataset
558
- Dataset to update.
786
+ Input dataset to which latitude and longitude coordinates will be added.
559
787
 
560
788
  Returns
561
789
  -------
562
790
  xarray.Dataset
563
- Dataset with "lat_rho" and "lon_rho" coordinates added.
791
+ Updated dataset with the appropriate latitude and longitude coordinates
792
+ assigned to "rho", "u", and "v" points if applicable.
564
793
  """
565
- ds = ds.assign_coords(
566
- {"lat_rho": self.grid.ds["lat_rho"], "lon_rho": self.grid.ds["lon_rho"]}
567
- )
794
+ coords_to_add = {
795
+ "lat_rho": self.grid.ds["lat_rho"],
796
+ "lon_rho": self.grid.ds["lon_rho"],
797
+ }
798
+
799
+ if "xi_u" in ds.dims:
800
+ coords_to_add.update(
801
+ {"lat_u": self.grid.ds["lat_u"], "lon_u": self.grid.ds["lon_u"]}
802
+ )
803
+ if "eta_v" in ds.dims:
804
+ coords_to_add.update(
805
+ {"lat_v": self.grid.ds["lat_v"], "lon_v": self.grid.ds["lon_v"]}
806
+ )
568
807
 
808
+ # Add all necessary coordinates in one go
809
+ ds = ds.assign_coords(coords_to_add)
569
810
  return ds
811
+
812
+ def _infer_nominal_horizontal_resolution(self, lat=None):
813
+ """Estimate the nominal horizontal resolution of the grid in degrees at a
814
+ specified latitude.
815
+
816
+ This method calculates the nominal horizontal resolution of the grid by first
817
+ determining the average grid spacing in meters. The spacing is then converted
818
+ to degrees, accounting for the Earth's curvature, and the latitude where the
819
+ resolution is being computed.
820
+
821
+ Parameters
822
+ ----------
823
+ lat : float, optional
824
+ Latitude (in degrees) at which to estimate the horizontal resolution.
825
+ If not provided, the resolution is calculated at the average latitude of
826
+ the grid (`lat_rho`).
827
+
828
+ Returns
829
+ -------
830
+ float
831
+ The estimated horizontal resolution in degrees, adjusted for the Earth's curvature.
832
+ """
833
+ # Earth radius in meters
834
+ r_earth = 6371315.0
835
+
836
+ if lat is None:
837
+ # Center latitude in degrees
838
+ lat = (self.grid.ds.lat_rho.max() + self.grid.ds.lat_rho.min()) / 2
839
+
840
+ # Convert latitude to radians
841
+ lat_rad = np.deg2rad(lat)
842
+
843
+ # Mean resolution in meters
844
+ resolution_in_m = (
845
+ (1 / self.grid.ds.pm).mean() + (1 / self.grid.ds.pn).mean()
846
+ ) / 2
847
+
848
+ # Meters per degree at the equator
849
+ meters_per_degree = 2 * np.pi * r_earth / 360
850
+
851
+ # Correct for latitude by multiplying by cos(latitude) for longitude
852
+ resolution_in_degrees = resolution_in_m / (meters_per_degree * np.cos(lat_rad))
853
+
854
+ return resolution_in_degrees
855
+
856
+ def _compute_exponential_depth_levels(self, Nz=None, depth=None, h=None):
857
+ """Compute vertical grid center and face depths using an exponential profile.
858
+
859
+ Parameters
860
+ ----------
861
+ Nz : int, optional
862
+ Number of vertical levels. Defaults to `len(self.ds.s_rho)`.
863
+
864
+ depth : float, optional
865
+ Total depth of the domain. Defaults to `grid.ds.h.max().values`.
866
+
867
+ h : float, optional
868
+ Scaling parameter for the exponential profile. Defaults to `Nz / 4.5`.
869
+
870
+ Returns
871
+ -------
872
+ tuple of numpy.ndarray
873
+ Depth values at the vertical grid centers (`z_centers`) and grid faces (`z_faces`),
874
+ both rounded to two decimal places.
875
+ """
876
+ if Nz is None:
877
+ Nz = len(self.ds.s_rho)
878
+ if depth is None:
879
+ depth = self.grid.ds.h.max().values
880
+ if h is None:
881
+ h = Nz / 4.5
882
+
883
+ k = np.arange(1, Nz + 2)
884
+
885
+ # Define the exponential profile function
886
+ def exponential_profile(k, Nz, h):
887
+ return np.exp(k / h)
888
+
889
+ z_faces = np.vectorize(exponential_profile)(k, Nz, h)
890
+
891
+ # Normalize
892
+ z_faces -= z_faces[0]
893
+ z_faces *= depth / z_faces[-1]
894
+ z_faces[0] = 0.0
895
+
896
+ # Calculate center depths (average between adjacent face depths)
897
+ z_centers = (z_faces[:-1] + z_faces[1:]) / 2
898
+
899
+ # Round both z_faces and z_centers to two decimal places
900
+ z_faces = np.round(z_faces, 2)
901
+ z_centers = np.round(z_centers, 2)
902
+
903
+ return z_centers, z_faces