roms-tools 2.5.0__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,16 @@
1
+ name: romstools-test
2
+ channels:
3
+ - conda-forge
4
+ - nodefaults
5
+ dependencies:
6
+ - python>=3.10
7
+ - xesmf
8
+ # testing
9
+ - zarr
10
+ - pytest
11
+ - pytest-xdist
12
+ - h5py
13
+ - flake8
14
+ - black
15
+ - pre-commit==3.8.0
16
+ - coverage
@@ -1,7 +1,9 @@
1
1
  import xarray as xr
2
2
  import numpy as np
3
3
  import matplotlib.pyplot as plt
4
+ from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
4
5
  from roms_tools.utils import _load_data
6
+ from roms_tools.regrid import LateralRegridFromROMS, VerticalRegridFromROMS
5
7
  from dataclasses import dataclass, field
6
8
  from typing import Union, Optional
7
9
  from pathlib import Path
@@ -9,13 +11,13 @@ import re
9
11
  import logging
10
12
  from datetime import datetime, timedelta
11
13
  from roms_tools import Grid
12
- from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
13
14
  from roms_tools.vertical_coordinate import (
14
15
  compute_depth_coordinates,
15
16
  )
17
+ from roms_tools.analysis.utils import _validate_plot_inputs, _generate_coordinate_range
16
18
 
17
19
 
18
- @dataclass(frozen=True, kw_only=True)
20
+ @dataclass(kw_only=True)
19
21
  class ROMSOutput:
20
22
  """Represents ROMS model output.
21
23
 
@@ -49,10 +51,10 @@ class ROMSOutput:
49
51
  self._check_vertical_coordinate(ds)
50
52
  ds = self._add_absolute_time(ds)
51
53
  ds = self._add_lat_lon_coords(ds)
52
- object.__setattr__(self, "ds", ds)
54
+ self.ds = ds
53
55
 
54
56
  # Dataset for depth coordinates
55
- object.__setattr__(self, "ds_depth_coords", xr.Dataset())
57
+ self.ds_depth_coords = xr.Dataset()
56
58
 
57
59
  def plot(
58
60
  self,
@@ -61,12 +63,15 @@ class ROMSOutput:
61
63
  s=None,
62
64
  eta=None,
63
65
  xi=None,
66
+ depth=None,
67
+ lat=None,
68
+ lon=None,
64
69
  include_boundary=False,
65
70
  depth_contours=False,
66
- layer_contours=False,
67
71
  ax=None,
72
+ save_path=None,
68
73
  ) -> None:
69
- """Plot a ROMS output field for a given vertical (s_rho) or horizontal (eta, xi)
74
+ """Generate a plot of a ROMS output field for a specified vertical or horizontal
70
75
  slice.
71
76
 
72
77
  Parameters
@@ -79,31 +84,56 @@ class ROMSOutput:
79
84
 
80
85
  time : int, optional
81
86
  Index of the time dimension to plot. Default is 0.
87
+
82
88
  s : int, optional
83
- The index of the vertical layer (`s_rho`) to plot. If not specified, the plot
84
- will represent a horizontal slice (eta- or xi- plane). Default is None.
89
+ The index of the vertical layer (`s_rho`) to plot. If specified, the plot
90
+ will display a horizontal slice at that layer. Cannot be used simultaneously
91
+ with `depth`. Default is None.
92
+
85
93
  eta : int, optional
86
- The eta-index to plot. Used for vertical sections or horizontal slices.
87
- Default is None.
94
+ The eta-index to plot. Used for generating vertical sections or plotting
95
+ horizontal slices along a constant eta-coordinate. Cannot be used simultaneously
96
+ with `lat` or `lon`, but can be combined with `xi`. Default is None.
97
+
88
98
  xi : int, optional
89
- The xi-index to plot. Used for vertical sections or horizontal slices.
90
- Default is None.
99
+ The xi-index to plot. Used for generating vertical sections or plotting
100
+ horizontal slices along a constant xi-coordinate. Cannot be used simultaneously
101
+ with `lat` or `lon`, but can be combined with `eta`. Default is None.
102
+
103
+ depth : float, optional
104
+ Depth (in meters) to plot a horizontal slice at a specific depth level.
105
+ If specified, the plot will interpolate the field to the given depth.
106
+ Cannot be used simultaneously with `s` or for fields that are inherently
107
+ 2D (such as "zeta"). Default is None.
108
+
109
+ lat : float, optional
110
+ Latitude (in degrees) to plot a vertical section at a specific
111
+ latitude. This option is useful for generating zonal (west-east)
112
+ sections. Cannot be used simultaneously with `eta` or `xi`, bu can be
113
+ combined with `lon`. Default is None.
114
+
115
+ lon : float, optional
116
+ Longitude (in degrees) to plot a vertical section at a specific
117
+ longitude. This option is useful for generating meridional (south-north) sections.
118
+ Cannot be used simultaneously with `eta` or `xi`, but can be combined
119
+ with `lat`. Default is None.
120
+
91
121
  include_boundary : bool, optional
92
122
  Whether to include the outermost grid cells along the `eta`- and `xi`-boundaries in the plot.
93
123
  In diagnostic ROMS output fields, these boundary cells are set to zero, so excluding them can improve visualization.
94
- This option is only relevant for 2D horizontal plots (`eta=None`, `xi=None`).
95
124
  Default is False.
125
+
96
126
  depth_contours : bool, optional
97
- If True, depth contours will be overlaid on the plot, showing lines of constant
98
- depth. This is typically used for plots that show a single vertical layer.
99
- Default is False.
100
- layer_contours : bool, optional
101
- If True, contour lines representing the boundaries between vertical layers will
102
- be added to the plot. This is particularly useful in vertical sections to
103
- visualize the layering of the water column. For clarity, the number of layer
104
- contours displayed is limited to a maximum of 10. Default is False.
127
+ If True, overlays contours representing lines of constant depth on the plot.
128
+ This option is only relevant when the `s` parameter is provided (i.e., not None).
129
+ By default, depth contours are not shown (False).
130
+
105
131
  ax : matplotlib.axes.Axes, optional
106
- The axes to plot on. If None, a new figure is created. Note that this argument does not work for horizontal plots that display the eta- and xi-dimensions at the same time.
132
+ The axes to plot on. If None, a new figure is created. Note that this argument does not work for 2D horizontal plots. Default is None.
133
+
134
+ save_path : str, optional
135
+ Path to save the generated plot. If None, the plot is shown interactively.
136
+ Default is None.
107
137
 
108
138
  Returns
109
139
  -------
@@ -113,44 +143,86 @@ class ROMSOutput:
113
143
  Raises
114
144
  ------
115
145
  ValueError
116
- If the specified `var_name` is not one of the valid options.
117
- If the field specified by `var_name` is 3D and none of `s`, `eta`, or `xi` are specified.
118
- If the field specified by `var_name` is 2D and both `eta` and `xi` are specified.
146
+ - If the specified `var_name` is not one of the valid options.
147
+ - If the field specified by `var_name` is 3D and none of `s`, `eta`, `xi`, `depth`, `lat`, or `lon` are specified.
148
+ - If the field specified by `var_name` is 2D and both `eta` and `xi` or both `lat` and `lon` are specified.
149
+ - If conflicting dimensions are specified (e.g., specifying `eta`/`xi` with `lat`/`lon` or both `s` and `depth`).
150
+ - If more than two dimensions are specified for a 3D field.
151
+ - If `time` exceeds the bounds of the time dimension.
152
+ - If `time` is specified for a field that does not have a time dimension.
153
+ - If `eta` or `xi` indices are out of bounds.
154
+ - If `eta` or `xi` lie on the boundary when `include_boundary=False`.
119
155
  """
120
-
121
- # Input checks
156
+ # Check if variable exists
122
157
  if var_name not in self.ds:
123
- raise ValueError(f"Variable '{var_name}' is not found in dataset.")
158
+ raise ValueError(f"Variable '{var_name}' is not found in the dataset.")
124
159
 
125
- if "time" in self.ds[var_name].dims:
126
- if time >= len(self.ds[var_name].time):
160
+ # Pick the variable
161
+ field = self.ds[var_name]
162
+
163
+ # Check and pick time
164
+ if "time" in field.dims:
165
+ if time >= len(field.time):
127
166
  raise ValueError(
128
167
  f"Invalid time index: The specified time index ({time}) exceeds the maximum index "
129
- f"({len(self.ds[var_name].time) - 1}) for the 'time' dimension in variable '{var_name}'."
168
+ f"({len(field.time) - 1}) for the 'time' dimension."
130
169
  )
131
- field = self.ds[var_name].isel(time=time)
170
+ field = field.isel(time=time)
132
171
  else:
133
172
  if time > 0:
134
173
  raise ValueError(
135
- f"Invalid input: The variable '{var_name}' does not have a 'time' dimension, "
174
+ f"Invalid input: The field does not have a 'time' dimension, "
136
175
  f"but a time index ({time}) greater than 0 was provided."
137
176
  )
138
- field = self.ds[var_name]
139
177
 
140
- if len(field.dims) == 3:
141
- if not any([s is not None, eta is not None, xi is not None]):
142
- raise ValueError(
143
- "Invalid input: For 3D fields, you must specify at least one of the dimensions 's', 'eta', or 'xi'."
144
- )
145
- if all([s is not None, eta is not None, xi is not None]):
146
- raise ValueError(
147
- "Ambiguous input: For 3D fields, specify at most two of 's', 'eta', or 'xi'. Specifying all three is not allowed."
148
- )
178
+ # Input checks
179
+ _validate_plot_inputs(field, s, eta, xi, depth, lat, lon, include_boundary)
180
+
181
+ # Get horizontal dimensions and grid location
182
+ horizontal_dims_dict = {
183
+ "rho": {"eta": "eta_rho", "xi": "xi_rho"},
184
+ "u": {"eta": "eta_rho", "xi": "xi_u"},
185
+ "v": {"eta": "eta_v", "xi": "xi_rho"},
186
+ }
187
+ for loc, horizontal_dims in horizontal_dims_dict.items():
188
+ if all(dim in field.dims for dim in horizontal_dims.values()):
189
+ break
190
+
191
+ # Convert relative to absolute indices
192
+ def _get_absolute_index(idx, field, dim_name):
193
+ index = field[dim_name].isel(**{dim_name: idx}).item()
194
+ return index
195
+
196
+ if eta is not None and eta < 0:
197
+ eta = _get_absolute_index(eta, field, horizontal_dims["eta"])
198
+ if xi is not None and xi < 0:
199
+ xi = _get_absolute_index(xi, field, horizontal_dims["xi"])
200
+ if s is not None and s < 0:
201
+ s = _get_absolute_index(s, field, "s_rho")
202
+
203
+ # Set spatial coordinates
204
+ lat_deg = self.grid.ds[f"lat_{loc}"]
205
+ lon_deg = self.grid.ds[f"lon_{loc}"]
206
+ if self.grid.straddle:
207
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
208
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
149
209
 
150
- if len(field.dims) == 2 and all([eta is not None, xi is not None]):
151
- raise ValueError(
152
- "Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both."
153
- )
210
+ # Mask the field
211
+ mask = self.grid.ds[f"mask_{loc}"]
212
+ field = field.where(mask)
213
+
214
+ # Assign eta and xi as coordinates
215
+ coords_to_assign = {dim: field[dim] for dim in horizontal_dims.values()}
216
+ field = field.assign_coords(**coords_to_assign)
217
+
218
+ # Remove horizontal boundary if desired
219
+ slice_dict = {
220
+ "rho": {"eta_rho": slice(1, -1), "xi_rho": slice(1, -1)},
221
+ "u": {"eta_rho": slice(1, -1), "xi_u": slice(1, -1)},
222
+ "v": {"eta_v": slice(1, -1), "xi_rho": slice(1, -1)},
223
+ }
224
+ if not include_boundary:
225
+ field = field.isel(**slice_dict[loc])
154
226
 
155
227
  # Load the data
156
228
  if self.use_dask:
@@ -159,140 +231,156 @@ class ROMSOutput:
159
231
  with ProgressBar():
160
232
  field.load()
161
233
 
162
- # Get correct mask and spatial coordinates
163
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
164
- loc = "rho"
165
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
166
- loc = "u"
167
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
168
- loc = "v"
169
- else:
170
- ValueError("provided field does not have two horizontal dimension")
171
-
172
- mask = self.grid.ds[f"mask_{loc}"]
173
- lat_deg = self.grid.ds[f"lat_{loc}"]
174
- lon_deg = self.grid.ds[f"lon_{loc}"]
175
-
176
- if self.grid.straddle:
177
- lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
234
+ # Compute layer depth for 3D fields when depth contours are requested or no vertical layer is specified.
235
+ compute_layer_depth = len(field.dims) > 2 and (depth_contours or s is None)
236
+ if compute_layer_depth:
237
+ if eta is not None or xi is not None:
238
+ # Computing depth coordinates directly for the slice in question is more efficient
239
+ # than using .ds_depth_coords, which computes depth coordinates for full field
240
+ if self.adjust_depth_for_sea_surface_height:
241
+ zeta = self.ds.zeta.isel(time=time)
242
+ else:
243
+ zeta = 0
244
+ if compute_layer_depth:
245
+ layer_depth = compute_depth_coordinates(
246
+ self.grid.ds,
247
+ zeta,
248
+ depth_type="layer",
249
+ location=loc,
250
+ eta=eta,
251
+ xi=xi,
252
+ )
253
+ else:
254
+ self._get_depth_coordinates(depth_type="layer", locations=[loc])
255
+ layer_depth = self.ds_depth_coords[f"layer_depth_{loc}"]
256
+ if self.adjust_depth_for_sea_surface_height:
257
+ layer_depth = layer_depth.isel(time=time)
258
+
259
+ if not include_boundary:
260
+ # Apply valid slices only for dimensions that exist in layer_depth.dims
261
+ layer_depth = layer_depth.isel(
262
+ **{
263
+ dim: s
264
+ for dim, s in slice_dict.get(loc, {}).items()
265
+ if dim in layer_depth.dims
266
+ }
267
+ )
268
+ layer_depth.load()
178
269
 
179
- field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
270
+ # Prepare figure title
271
+ formatted_time = np.datetime_as_string(field.abs_time.values, unit="m")
272
+ title = f"time: {formatted_time}"
180
273
 
181
- # Retrieve depth coordinates
182
- compute_layer_depth = (depth_contours or s is None) and len(field.dims) > 2
183
- compute_interface_depth = layer_contours and s is None
274
+ # Slice the field horizontally as desired
275
+ def _slice_along_dimension(field, title, dim_name, idx):
276
+ field = field.sel(**{dim_name: idx})
277
+ title = title + f", {dim_name} = {idx}"
278
+ return field, title
184
279
 
185
- # Compute depth coordinates directly instead of using .ds_depth_coords.
186
- # Many cases below require only a 1D or 2D slice, making direct computation
187
- # more efficient than triggering a full 3D depth computation just to extract a subset.
188
- # This is especially beneficial when using Dask or if .ds_depth_coords is precomputed.
189
- if self.adjust_depth_for_sea_surface_height:
190
- zeta = self.ds.zeta.isel(time=time)
191
- else:
192
- zeta = 0
193
- if compute_layer_depth:
194
- layer_depth = compute_depth_coordinates(
195
- self.grid.ds,
196
- zeta,
197
- depth_type="layer",
198
- location=loc,
199
- eta=eta,
200
- xi=xi,
280
+ if eta is not None:
281
+ field, title = _slice_along_dimension(
282
+ field, title, horizontal_dims["eta"], eta
201
283
  )
202
- if s is not None:
203
- layer_depth = layer_depth.isel(s_rho=s)
204
- if compute_interface_depth:
205
- interface_depth = compute_depth_coordinates(
206
- self.grid.ds,
207
- zeta,
208
- depth_type="interface",
209
- location=loc,
210
- eta=eta,
211
- xi=xi,
284
+ if xi is not None:
285
+ field, title = _slice_along_dimension(
286
+ field, title, horizontal_dims["xi"], xi
212
287
  )
213
- if s is not None:
214
- interface_depth = interface_depth.isel(s_w=s)
215
-
216
- # Slice the field as desired
217
- title = field.long_name
218
288
  if s is not None:
219
- title = title + f", s_rho = {field.s_rho[s].item()}"
220
- field = field.isel(s_rho=s)
289
+ field, title = _slice_along_dimension(field, title, "s_rho", s)
290
+ if compute_layer_depth:
291
+ layer_depth = layer_depth.isel(s_rho=s)
221
292
  else:
222
293
  depth_contours = False
223
294
 
224
- def _process_dimension(field, mask, dim_name, dim_values, idx, title):
225
- if dim_name in field.dims:
226
- title = title + f", {dim_name} = {dim_values[idx].item()}"
227
- field = field.isel(**{dim_name: idx})
228
- mask = mask.isel(**{dim_name: idx})
295
+ # Regrid laterally
296
+ if lat is not None or lon is not None:
297
+
298
+ if lat is not None:
299
+ lats = [lat]
300
+ title = title + f", lat = {lat}°N"
229
301
  else:
230
- raise ValueError(
231
- f"None of the expected dimensions ({dim_name}) found in field."
302
+ resolution = self._infer_nominal_horizontal_resolution()
303
+ lats = _generate_coordinate_range(
304
+ field.lat.min().values, field.lat.max().values, resolution
232
305
  )
233
- return field, mask, title
234
-
235
- if eta is not None:
236
- field, mask, title = _process_dimension(
237
- field,
238
- mask,
239
- "eta_rho" if "eta_rho" in field.dims else "eta_v",
240
- field.eta_rho if "eta_rho" in field.dims else field.eta_v,
241
- eta,
242
- title,
243
- )
306
+ lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
244
307
 
245
- if xi is not None:
246
- field, mask, title = _process_dimension(
247
- field,
248
- mask,
249
- "xi_rho" if "xi_rho" in field.dims else "xi_u",
250
- field.xi_rho if "xi_rho" in field.dims else field.xi_u,
251
- xi,
252
- title,
253
- )
308
+ if lon is not None:
309
+ lons = [lon]
310
+ title = title + f", lon = {lon}°E"
311
+ else:
312
+ resolution = self._infer_nominal_horizontal_resolution(lat)
313
+ lons = _generate_coordinate_range(
314
+ field.lon.min().values, field.lon.max().values, resolution
315
+ )
316
+ lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
254
317
 
255
- # Format to exclude seconds
256
- formatted_time = np.datetime_as_string(field.abs_time.values, unit="m")
257
- title = title + f", time: {formatted_time}"
318
+ target_coords = {"lat": lats, "lon": lons}
319
+ lateral_regrid = LateralRegridFromROMS(field, target_coords)
320
+ field = lateral_regrid.apply(field).squeeze()
321
+ if compute_layer_depth:
322
+ layer_depth = lateral_regrid.apply(layer_depth).squeeze()
258
323
 
324
+ # Assign depth as coordinate
259
325
  if compute_layer_depth:
260
326
  field = field.assign_coords({"layer_depth": layer_depth})
261
327
 
262
- if not include_boundary:
263
- slice_dict = None
264
-
265
- if eta is None and xi is None:
266
- slice_dict = {
267
- "rho": {"eta_rho": slice(1, -1), "xi_rho": slice(1, -1)},
268
- "u": {"eta_rho": slice(1, -1), "xi_u": slice(1, -1)},
269
- "v": {"eta_v": slice(1, -1), "xi_rho": slice(1, -1)},
270
- }
271
- elif eta is None:
272
- slice_dict = {
273
- "rho": {"eta_rho": slice(1, -1)},
274
- "u": {"eta_rho": slice(1, -1)},
275
- "v": {"eta_v": slice(1, -1)},
276
- }
277
- elif xi is None:
278
- slice_dict = {
279
- "rho": {"xi_rho": slice(1, -1)},
280
- "u": {"xi_u": slice(1, -1)},
281
- "v": {"xi_rho": slice(1, -1)},
282
- }
283
- if slice_dict is not None:
284
- if loc in slice_dict:
285
- field = field.isel(**slice_dict[loc])
286
- mask = mask.isel(**slice_dict[loc])
328
+ def _remove_edge_nans(field, xdim, layer_depth=None):
329
+ """Removes NaNs from the edges along the specified dimension."""
330
+ if xdim in field.dims:
331
+ if layer_depth is not None:
332
+ nan_mask = layer_depth.isnull().sum(
333
+ dim=[dim for dim in layer_depth.dims if dim != xdim]
334
+ )
335
+ else:
336
+ nan_mask = field.isnull().sum(
337
+ dim=[dim for dim in field.dims if dim != xdim]
338
+ )
339
+
340
+ # Find the valid indices where the sum of the nans is 0
341
+ valid_indices = np.where(nan_mask.values == 0)[0]
342
+
343
+ if len(valid_indices) > 0:
344
+ first_valid = valid_indices[0]
345
+ last_valid = valid_indices[-1]
346
+
347
+ field = field.isel({xdim: slice(first_valid, last_valid + 1)})
348
+ if layer_depth is not None:
349
+ layer_depth = layer_depth.isel(
350
+ {xdim: slice(first_valid, last_valid + 1)}
351
+ )
352
+
353
+ return field, layer_depth
354
+
355
+ if lat is not None:
356
+ field, layer_depth = _remove_edge_nans(
357
+ field, "lon", layer_depth if "layer_depth" in locals() else None
358
+ )
359
+ if lon is not None:
360
+ field, layer_depth = _remove_edge_nans(
361
+ field, "lat", layer_depth if "layer_depth" in locals() else None
362
+ )
363
+
364
+ # Regrid vertically
365
+ if depth is not None:
366
+ vertical_regrid = VerticalRegridFromROMS(self.ds)
367
+ # Save attributes before vertical regridding
368
+ attrs = field.attrs
369
+ field = vertical_regrid.apply(
370
+ field, layer_depth, np.array([depth])
371
+ ).squeeze()
372
+ # Reset attributes
373
+ field.attrs = attrs
374
+ title = title + f", depth = {depth}m"
287
375
 
288
376
  # Choose colorbar
289
377
  if var_name in ["u", "v", "w", "ubar", "vbar", "zeta"]:
290
- vmax = max(field.where(mask).max().values, -field.where(mask).min().values)
378
+ vmax = max(field.max().values, -field.min().values)
291
379
  vmin = -vmax
292
380
  cmap = plt.colormaps.get_cmap("RdBu_r")
293
381
  else:
294
- vmax = field.where(mask).max().values
295
- vmin = field.where(mask).min().values
382
+ vmax = field.max().values
383
+ vmin = field.min().values
296
384
  if var_name in ["temp", "salt"]:
297
385
  cmap = plt.colormaps.get_cmap("YlOrRd")
298
386
  else:
@@ -301,37 +389,31 @@ class ROMSOutput:
301
389
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
302
390
 
303
391
  # Plotting
304
- if eta is None and xi is None:
305
- _plot(
306
- field=field.where(mask),
392
+ if (eta is None and xi is None) and (lat is None and lon is None):
393
+ fig = _plot(
394
+ field=field,
307
395
  depth_contours=depth_contours,
308
396
  title=title,
309
397
  kwargs=kwargs,
310
- c="g",
398
+ c=None,
311
399
  )
312
400
  else:
313
401
  if len(field.dims) == 2:
314
- if not layer_contours:
315
- interface_depth = None
316
- else:
317
- # restrict number of layer_contours to 10 for the sake of plot clearity
318
- nr_layers = len(interface_depth["s_w"])
319
- selected_layers = np.linspace(
320
- 0, nr_layers - 1, min(nr_layers, 10), dtype=int
321
- )
322
- interface_depth = interface_depth.isel(s_w=selected_layers)
323
- _section_plot(
324
- field.where(mask),
325
- interface_depth=interface_depth,
402
+ fig = _section_plot(
403
+ field,
404
+ interface_depth=None,
326
405
  title=title,
327
406
  kwargs=kwargs,
328
407
  ax=ax,
329
408
  )
330
409
  else:
331
410
  if "s_rho" in field.dims:
332
- _profile_plot(field.where(mask), title=title, ax=ax)
411
+ fig = _profile_plot(field, title=title, ax=ax)
333
412
  else:
334
- _line_plot(field.where(mask), title=title, ax=ax)
413
+ fig = _line_plot(field, title=title, ax=ax)
414
+
415
+ if save_path:
416
+ plt.savefig(save_path, dpi=300, bbox_inches="tight")
335
417
 
336
418
  def _get_depth_coordinates(self, depth_type="layer", locations=["rho"]):
337
419
  """Ensure depth coordinates are stored for a given location and depth type.
@@ -435,7 +517,7 @@ class ROMSOutput:
435
517
  )
436
518
  else:
437
519
  # Set the model reference date if not already set
438
- object.__setattr__(self, "model_reference_date", inferred_date)
520
+ self.model_reference_date = inferred_date
439
521
  else:
440
522
  # Handle case where no match is found
441
523
  if hasattr(self, "model_reference_date") and self.model_reference_date:
@@ -567,3 +649,47 @@ class ROMSOutput:
567
649
  )
568
650
 
569
651
  return ds
652
+
653
+ def _infer_nominal_horizontal_resolution(self, lat=None):
654
+ """Estimate the nominal horizontal resolution of the grid in degrees at a
655
+ specified latitude.
656
+
657
+ This method calculates the nominal horizontal resolution of the grid by first
658
+ determining the average grid spacing in meters. The spacing is then converted
659
+ to degrees, accounting for the Earth's curvature, and the latitude where the
660
+ resolution is being computed.
661
+
662
+ Parameters
663
+ ----------
664
+ lat : float, optional
665
+ Latitude (in degrees) at which to estimate the horizontal resolution.
666
+ If not provided, the resolution is calculated at the average latitude of
667
+ the grid (`lat_rho`).
668
+
669
+ Returns
670
+ -------
671
+ float
672
+ The estimated horizontal resolution in degrees, adjusted for the Earth's curvature.
673
+ """
674
+ # Earth radius in meters
675
+ r_earth = 6371315.0
676
+
677
+ if lat is None:
678
+ # Center latitude in degrees
679
+ lat = (self.grid.ds.lat_rho.max() + self.grid.ds.lat_rho.min()) / 2
680
+
681
+ # Convert latitude to radians
682
+ lat_rad = np.deg2rad(lat)
683
+
684
+ # Mean resolution in meters
685
+ resolution_in_m = (
686
+ (1 / self.grid.ds.pm).mean() + (1 / self.grid.ds.pn).mean()
687
+ ) / 2
688
+
689
+ # Meters per degree at the equator
690
+ meters_per_degree = 2 * np.pi * r_earth / 360
691
+
692
+ # Correct for latitude by multiplying by cos(latitude) for longitude
693
+ resolution_in_degrees = resolution_in_m / (meters_per_degree * np.cos(lat_rad))
694
+
695
+ return resolution_in_degrees