roms-tools 2.5.0__py3-none-any.whl → 2.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment-with-xesmf.yml +16 -0
- roms_tools/analysis/roms_output.py +521 -187
- roms_tools/analysis/utils.py +169 -0
- roms_tools/plot.py +351 -214
- roms_tools/regrid.py +161 -9
- roms_tools/setup/boundary_forcing.py +22 -22
- roms_tools/setup/datasets.py +40 -44
- roms_tools/setup/grid.py +28 -28
- roms_tools/setup/initial_conditions.py +23 -31
- roms_tools/setup/nesting.py +3 -3
- roms_tools/setup/river_forcing.py +22 -23
- roms_tools/setup/surface_forcing.py +14 -13
- roms_tools/setup/tides.py +7 -7
- roms_tools/setup/topography.py +2 -2
- roms_tools/tests/test_analysis/test_roms_output.py +299 -188
- roms_tools/tests/test_regrid.py +85 -2
- roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -2
- roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +2 -2
- roms_tools/tests/test_setup/test_river_forcing.py +47 -51
- roms_tools/tests/test_vertical_coordinate.py +73 -0
- roms_tools/utils.py +11 -7
- roms_tools/vertical_coordinate.py +7 -0
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/METADATA +22 -11
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/RECORD +33 -30
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/WHEEL +1 -1
- /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zarray +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zattrs +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/0.0 +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zarray +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zattrs +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/0.0 +0 -0
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info/licenses}/LICENSE +0 -0
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/top_level.txt +0 -0
|
@@ -1,21 +1,25 @@
|
|
|
1
1
|
import xarray as xr
|
|
2
2
|
import numpy as np
|
|
3
3
|
import matplotlib.pyplot as plt
|
|
4
|
+
from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
|
|
4
5
|
from roms_tools.utils import _load_data
|
|
6
|
+
from roms_tools.regrid import LateralRegridFromROMS, VerticalRegridFromROMS
|
|
5
7
|
from dataclasses import dataclass, field
|
|
6
8
|
from typing import Union, Optional
|
|
7
9
|
from pathlib import Path
|
|
8
10
|
import re
|
|
9
11
|
import logging
|
|
12
|
+
import warnings
|
|
10
13
|
from datetime import datetime, timedelta
|
|
11
14
|
from roms_tools import Grid
|
|
12
|
-
from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
|
|
13
15
|
from roms_tools.vertical_coordinate import (
|
|
14
16
|
compute_depth_coordinates,
|
|
15
17
|
)
|
|
18
|
+
from roms_tools.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
|
|
19
|
+
from roms_tools.analysis.utils import _validate_plot_inputs, _generate_coordinate_range
|
|
16
20
|
|
|
17
21
|
|
|
18
|
-
@dataclass(
|
|
22
|
+
@dataclass(kw_only=True)
|
|
19
23
|
class ROMSOutput:
|
|
20
24
|
"""Represents ROMS model output.
|
|
21
25
|
|
|
@@ -49,10 +53,10 @@ class ROMSOutput:
|
|
|
49
53
|
self._check_vertical_coordinate(ds)
|
|
50
54
|
ds = self._add_absolute_time(ds)
|
|
51
55
|
ds = self._add_lat_lon_coords(ds)
|
|
52
|
-
|
|
56
|
+
self.ds = ds
|
|
53
57
|
|
|
54
58
|
# Dataset for depth coordinates
|
|
55
|
-
|
|
59
|
+
self.ds_depth_coords = xr.Dataset()
|
|
56
60
|
|
|
57
61
|
def plot(
|
|
58
62
|
self,
|
|
@@ -61,12 +65,15 @@ class ROMSOutput:
|
|
|
61
65
|
s=None,
|
|
62
66
|
eta=None,
|
|
63
67
|
xi=None,
|
|
68
|
+
depth=None,
|
|
69
|
+
lat=None,
|
|
70
|
+
lon=None,
|
|
64
71
|
include_boundary=False,
|
|
65
72
|
depth_contours=False,
|
|
66
|
-
layer_contours=False,
|
|
67
73
|
ax=None,
|
|
74
|
+
save_path=None,
|
|
68
75
|
) -> None:
|
|
69
|
-
"""
|
|
76
|
+
"""Generate a plot of a ROMS output field for a specified vertical or horizontal
|
|
70
77
|
slice.
|
|
71
78
|
|
|
72
79
|
Parameters
|
|
@@ -79,31 +86,56 @@ class ROMSOutput:
|
|
|
79
86
|
|
|
80
87
|
time : int, optional
|
|
81
88
|
Index of the time dimension to plot. Default is 0.
|
|
89
|
+
|
|
82
90
|
s : int, optional
|
|
83
|
-
The index of the vertical layer (`s_rho`) to plot. If
|
|
84
|
-
will
|
|
91
|
+
The index of the vertical layer (`s_rho`) to plot. If specified, the plot
|
|
92
|
+
will display a horizontal slice at that layer. Cannot be used simultaneously
|
|
93
|
+
with `depth`. Default is None.
|
|
94
|
+
|
|
85
95
|
eta : int, optional
|
|
86
|
-
The eta-index to plot. Used for vertical sections or
|
|
87
|
-
|
|
96
|
+
The eta-index to plot. Used for generating vertical sections or plotting
|
|
97
|
+
horizontal slices along a constant eta-coordinate. Cannot be used simultaneously
|
|
98
|
+
with `lat` or `lon`, but can be combined with `xi`. Default is None.
|
|
99
|
+
|
|
88
100
|
xi : int, optional
|
|
89
|
-
The xi-index to plot. Used for vertical sections or
|
|
90
|
-
|
|
101
|
+
The xi-index to plot. Used for generating vertical sections or plotting
|
|
102
|
+
horizontal slices along a constant xi-coordinate. Cannot be used simultaneously
|
|
103
|
+
with `lat` or `lon`, but can be combined with `eta`. Default is None.
|
|
104
|
+
|
|
105
|
+
depth : float, optional
|
|
106
|
+
Depth (in meters) to plot a horizontal slice at a specific depth level.
|
|
107
|
+
If specified, the plot will interpolate the field to the given depth.
|
|
108
|
+
Cannot be used simultaneously with `s` or for fields that are inherently
|
|
109
|
+
2D (such as "zeta"). Default is None.
|
|
110
|
+
|
|
111
|
+
lat : float, optional
|
|
112
|
+
Latitude (in degrees) to plot a vertical section at a specific
|
|
113
|
+
latitude. This option is useful for generating zonal (west-east)
|
|
114
|
+
sections. Cannot be used simultaneously with `eta` or `xi`, bu can be
|
|
115
|
+
combined with `lon`. Default is None.
|
|
116
|
+
|
|
117
|
+
lon : float, optional
|
|
118
|
+
Longitude (in degrees) to plot a vertical section at a specific
|
|
119
|
+
longitude. This option is useful for generating meridional (south-north) sections.
|
|
120
|
+
Cannot be used simultaneously with `eta` or `xi`, but can be combined
|
|
121
|
+
with `lat`. Default is None.
|
|
122
|
+
|
|
91
123
|
include_boundary : bool, optional
|
|
92
124
|
Whether to include the outermost grid cells along the `eta`- and `xi`-boundaries in the plot.
|
|
93
125
|
In diagnostic ROMS output fields, these boundary cells are set to zero, so excluding them can improve visualization.
|
|
94
|
-
This option is only relevant for 2D horizontal plots (`eta=None`, `xi=None`).
|
|
95
126
|
Default is False.
|
|
127
|
+
|
|
96
128
|
depth_contours : bool, optional
|
|
97
|
-
If True,
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
If True, contour lines representing the boundaries between vertical layers will
|
|
102
|
-
be added to the plot. This is particularly useful in vertical sections to
|
|
103
|
-
visualize the layering of the water column. For clarity, the number of layer
|
|
104
|
-
contours displayed is limited to a maximum of 10. Default is False.
|
|
129
|
+
If True, overlays contours representing lines of constant depth on the plot.
|
|
130
|
+
This option is only relevant when the `s` parameter is provided (i.e., not None).
|
|
131
|
+
By default, depth contours are not shown (False).
|
|
132
|
+
|
|
105
133
|
ax : matplotlib.axes.Axes, optional
|
|
106
|
-
The axes to plot on. If None, a new figure is created. Note that this argument does not work for horizontal plots
|
|
134
|
+
The axes to plot on. If None, a new figure is created. Note that this argument does not work for 2D horizontal plots. Default is None.
|
|
135
|
+
|
|
136
|
+
save_path : str, optional
|
|
137
|
+
Path to save the generated plot. If None, the plot is shown interactively.
|
|
138
|
+
Default is None.
|
|
107
139
|
|
|
108
140
|
Returns
|
|
109
141
|
-------
|
|
@@ -113,44 +145,86 @@ class ROMSOutput:
|
|
|
113
145
|
Raises
|
|
114
146
|
------
|
|
115
147
|
ValueError
|
|
116
|
-
If the specified `var_name` is not one of the valid options.
|
|
117
|
-
If the field specified by `var_name` is 3D and none of `s`, `eta`, or `
|
|
118
|
-
If the field specified by `var_name` is 2D and both `eta` and `xi` are specified.
|
|
148
|
+
- If the specified `var_name` is not one of the valid options.
|
|
149
|
+
- If the field specified by `var_name` is 3D and none of `s`, `eta`, `xi`, `depth`, `lat`, or `lon` are specified.
|
|
150
|
+
- If the field specified by `var_name` is 2D and both `eta` and `xi` or both `lat` and `lon` are specified.
|
|
151
|
+
- If conflicting dimensions are specified (e.g., specifying `eta`/`xi` with `lat`/`lon` or both `s` and `depth`).
|
|
152
|
+
- If more than two dimensions are specified for a 3D field.
|
|
153
|
+
- If `time` exceeds the bounds of the time dimension.
|
|
154
|
+
- If `time` is specified for a field that does not have a time dimension.
|
|
155
|
+
- If `eta` or `xi` indices are out of bounds.
|
|
156
|
+
- If `eta` or `xi` lie on the boundary when `include_boundary=False`.
|
|
119
157
|
"""
|
|
120
|
-
|
|
121
|
-
# Input checks
|
|
158
|
+
# Check if variable exists
|
|
122
159
|
if var_name not in self.ds:
|
|
123
|
-
raise ValueError(f"Variable '{var_name}' is not found in dataset.")
|
|
160
|
+
raise ValueError(f"Variable '{var_name}' is not found in the dataset.")
|
|
124
161
|
|
|
125
|
-
|
|
126
|
-
|
|
162
|
+
# Pick the variable
|
|
163
|
+
field = self.ds[var_name]
|
|
164
|
+
|
|
165
|
+
# Check and pick time
|
|
166
|
+
if "time" in field.dims:
|
|
167
|
+
if time >= len(field.time):
|
|
127
168
|
raise ValueError(
|
|
128
169
|
f"Invalid time index: The specified time index ({time}) exceeds the maximum index "
|
|
129
|
-
f"({len(
|
|
170
|
+
f"({len(field.time) - 1}) for the 'time' dimension."
|
|
130
171
|
)
|
|
131
|
-
field =
|
|
172
|
+
field = field.isel(time=time)
|
|
132
173
|
else:
|
|
133
174
|
if time > 0:
|
|
134
175
|
raise ValueError(
|
|
135
|
-
f"Invalid input: The
|
|
176
|
+
f"Invalid input: The field does not have a 'time' dimension, "
|
|
136
177
|
f"but a time index ({time}) greater than 0 was provided."
|
|
137
178
|
)
|
|
138
|
-
field = self.ds[var_name]
|
|
139
179
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
180
|
+
# Input checks
|
|
181
|
+
_validate_plot_inputs(field, s, eta, xi, depth, lat, lon, include_boundary)
|
|
182
|
+
|
|
183
|
+
# Get horizontal dimensions and grid location
|
|
184
|
+
horizontal_dims_dict = {
|
|
185
|
+
"rho": {"eta": "eta_rho", "xi": "xi_rho"},
|
|
186
|
+
"u": {"eta": "eta_rho", "xi": "xi_u"},
|
|
187
|
+
"v": {"eta": "eta_v", "xi": "xi_rho"},
|
|
188
|
+
}
|
|
189
|
+
for loc, horizontal_dims in horizontal_dims_dict.items():
|
|
190
|
+
if all(dim in field.dims for dim in horizontal_dims.values()):
|
|
191
|
+
break
|
|
192
|
+
|
|
193
|
+
# Convert relative to absolute indices
|
|
194
|
+
def _get_absolute_index(idx, field, dim_name):
|
|
195
|
+
index = field[dim_name].isel(**{dim_name: idx}).item()
|
|
196
|
+
return index
|
|
197
|
+
|
|
198
|
+
if eta is not None and eta < 0:
|
|
199
|
+
eta = _get_absolute_index(eta, field, horizontal_dims["eta"])
|
|
200
|
+
if xi is not None and xi < 0:
|
|
201
|
+
xi = _get_absolute_index(xi, field, horizontal_dims["xi"])
|
|
202
|
+
if s is not None and s < 0:
|
|
203
|
+
s = _get_absolute_index(s, field, "s_rho")
|
|
204
|
+
|
|
205
|
+
# Set spatial coordinates
|
|
206
|
+
lat_deg = self.grid.ds[f"lat_{loc}"]
|
|
207
|
+
lon_deg = self.grid.ds[f"lon_{loc}"]
|
|
208
|
+
if self.grid.straddle:
|
|
209
|
+
lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
|
|
210
|
+
field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
|
|
149
211
|
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
212
|
+
# Mask the field
|
|
213
|
+
mask = self.grid.ds[f"mask_{loc}"]
|
|
214
|
+
field = field.where(mask)
|
|
215
|
+
|
|
216
|
+
# Assign eta and xi as coordinates
|
|
217
|
+
coords_to_assign = {dim: field[dim] for dim in horizontal_dims.values()}
|
|
218
|
+
field = field.assign_coords(**coords_to_assign)
|
|
219
|
+
|
|
220
|
+
# Remove horizontal boundary if desired
|
|
221
|
+
slice_dict = {
|
|
222
|
+
"rho": {"eta_rho": slice(1, -1), "xi_rho": slice(1, -1)},
|
|
223
|
+
"u": {"eta_rho": slice(1, -1), "xi_u": slice(1, -1)},
|
|
224
|
+
"v": {"eta_v": slice(1, -1), "xi_rho": slice(1, -1)},
|
|
225
|
+
}
|
|
226
|
+
if not include_boundary:
|
|
227
|
+
field = field.isel(**slice_dict[loc])
|
|
154
228
|
|
|
155
229
|
# Load the data
|
|
156
230
|
if self.use_dask:
|
|
@@ -159,140 +233,156 @@ class ROMSOutput:
|
|
|
159
233
|
with ProgressBar():
|
|
160
234
|
field.load()
|
|
161
235
|
|
|
162
|
-
#
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
236
|
+
# Compute layer depth for 3D fields when depth contours are requested or no vertical layer is specified.
|
|
237
|
+
compute_layer_depth = len(field.dims) > 2 and (depth_contours or s is None)
|
|
238
|
+
if compute_layer_depth:
|
|
239
|
+
if eta is not None or xi is not None:
|
|
240
|
+
# Computing depth coordinates directly for the slice in question is more efficient
|
|
241
|
+
# than using .ds_depth_coords, which computes depth coordinates for full field
|
|
242
|
+
if self.adjust_depth_for_sea_surface_height:
|
|
243
|
+
zeta = self.ds.zeta.isel(time=time)
|
|
244
|
+
else:
|
|
245
|
+
zeta = 0
|
|
246
|
+
if compute_layer_depth:
|
|
247
|
+
layer_depth = compute_depth_coordinates(
|
|
248
|
+
self.grid.ds,
|
|
249
|
+
zeta,
|
|
250
|
+
depth_type="layer",
|
|
251
|
+
location=loc,
|
|
252
|
+
eta=eta,
|
|
253
|
+
xi=xi,
|
|
254
|
+
)
|
|
255
|
+
else:
|
|
256
|
+
self._get_depth_coordinates(depth_type="layer", locations=[loc])
|
|
257
|
+
layer_depth = self.ds_depth_coords[f"layer_depth_{loc}"]
|
|
258
|
+
if self.adjust_depth_for_sea_surface_height:
|
|
259
|
+
layer_depth = layer_depth.isel(time=time)
|
|
260
|
+
|
|
261
|
+
if not include_boundary:
|
|
262
|
+
# Apply valid slices only for dimensions that exist in layer_depth.dims
|
|
263
|
+
layer_depth = layer_depth.isel(
|
|
264
|
+
**{
|
|
265
|
+
dim: s
|
|
266
|
+
for dim, s in slice_dict.get(loc, {}).items()
|
|
267
|
+
if dim in layer_depth.dims
|
|
268
|
+
}
|
|
269
|
+
)
|
|
270
|
+
layer_depth.load()
|
|
178
271
|
|
|
179
|
-
|
|
272
|
+
# Prepare figure title
|
|
273
|
+
formatted_time = np.datetime_as_string(field.abs_time.values, unit="m")
|
|
274
|
+
title = f"time: {formatted_time}"
|
|
180
275
|
|
|
181
|
-
#
|
|
182
|
-
|
|
183
|
-
|
|
276
|
+
# Slice the field horizontally as desired
|
|
277
|
+
def _slice_along_dimension(field, title, dim_name, idx):
|
|
278
|
+
field = field.sel(**{dim_name: idx})
|
|
279
|
+
title = title + f", {dim_name} = {idx}"
|
|
280
|
+
return field, title
|
|
184
281
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
# This is especially beneficial when using Dask or if .ds_depth_coords is precomputed.
|
|
189
|
-
if self.adjust_depth_for_sea_surface_height:
|
|
190
|
-
zeta = self.ds.zeta.isel(time=time)
|
|
191
|
-
else:
|
|
192
|
-
zeta = 0
|
|
193
|
-
if compute_layer_depth:
|
|
194
|
-
layer_depth = compute_depth_coordinates(
|
|
195
|
-
self.grid.ds,
|
|
196
|
-
zeta,
|
|
197
|
-
depth_type="layer",
|
|
198
|
-
location=loc,
|
|
199
|
-
eta=eta,
|
|
200
|
-
xi=xi,
|
|
282
|
+
if eta is not None:
|
|
283
|
+
field, title = _slice_along_dimension(
|
|
284
|
+
field, title, horizontal_dims["eta"], eta
|
|
201
285
|
)
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
interface_depth = compute_depth_coordinates(
|
|
206
|
-
self.grid.ds,
|
|
207
|
-
zeta,
|
|
208
|
-
depth_type="interface",
|
|
209
|
-
location=loc,
|
|
210
|
-
eta=eta,
|
|
211
|
-
xi=xi,
|
|
286
|
+
if xi is not None:
|
|
287
|
+
field, title = _slice_along_dimension(
|
|
288
|
+
field, title, horizontal_dims["xi"], xi
|
|
212
289
|
)
|
|
213
|
-
if s is not None:
|
|
214
|
-
interface_depth = interface_depth.isel(s_w=s)
|
|
215
|
-
|
|
216
|
-
# Slice the field as desired
|
|
217
|
-
title = field.long_name
|
|
218
290
|
if s is not None:
|
|
219
|
-
title = title
|
|
220
|
-
|
|
291
|
+
field, title = _slice_along_dimension(field, title, "s_rho", s)
|
|
292
|
+
if compute_layer_depth:
|
|
293
|
+
layer_depth = layer_depth.isel(s_rho=s)
|
|
221
294
|
else:
|
|
222
295
|
depth_contours = False
|
|
223
296
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
297
|
+
# Regrid laterally
|
|
298
|
+
if lat is not None or lon is not None:
|
|
299
|
+
|
|
300
|
+
if lat is not None:
|
|
301
|
+
lats = [lat]
|
|
302
|
+
title = title + f", lat = {lat}°N"
|
|
229
303
|
else:
|
|
230
|
-
|
|
231
|
-
|
|
304
|
+
resolution = self._infer_nominal_horizontal_resolution()
|
|
305
|
+
lats = _generate_coordinate_range(
|
|
306
|
+
field.lat.min().values, field.lat.max().values, resolution
|
|
232
307
|
)
|
|
233
|
-
|
|
308
|
+
lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
|
|
234
309
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
)
|
|
244
|
-
|
|
245
|
-
if xi is not None:
|
|
246
|
-
field, mask, title = _process_dimension(
|
|
247
|
-
field,
|
|
248
|
-
mask,
|
|
249
|
-
"xi_rho" if "xi_rho" in field.dims else "xi_u",
|
|
250
|
-
field.xi_rho if "xi_rho" in field.dims else field.xi_u,
|
|
251
|
-
xi,
|
|
252
|
-
title,
|
|
253
|
-
)
|
|
310
|
+
if lon is not None:
|
|
311
|
+
lons = [lon]
|
|
312
|
+
title = title + f", lon = {lon}°E"
|
|
313
|
+
else:
|
|
314
|
+
resolution = self._infer_nominal_horizontal_resolution(lat)
|
|
315
|
+
lons = _generate_coordinate_range(
|
|
316
|
+
field.lon.min().values, field.lon.max().values, resolution
|
|
317
|
+
)
|
|
318
|
+
lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
|
|
254
319
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
320
|
+
target_coords = {"lat": lats, "lon": lons}
|
|
321
|
+
lateral_regrid = LateralRegridFromROMS(field, target_coords)
|
|
322
|
+
field = lateral_regrid.apply(field).squeeze()
|
|
323
|
+
if compute_layer_depth:
|
|
324
|
+
layer_depth = lateral_regrid.apply(layer_depth).squeeze()
|
|
258
325
|
|
|
326
|
+
# Assign depth as coordinate
|
|
259
327
|
if compute_layer_depth:
|
|
260
328
|
field = field.assign_coords({"layer_depth": layer_depth})
|
|
261
329
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
330
|
+
def _remove_edge_nans(field, xdim, layer_depth=None):
|
|
331
|
+
"""Removes NaNs from the edges along the specified dimension."""
|
|
332
|
+
if xdim in field.dims:
|
|
333
|
+
if layer_depth is not None:
|
|
334
|
+
nan_mask = layer_depth.isnull().sum(
|
|
335
|
+
dim=[dim for dim in layer_depth.dims if dim != xdim]
|
|
336
|
+
)
|
|
337
|
+
else:
|
|
338
|
+
nan_mask = field.isnull().sum(
|
|
339
|
+
dim=[dim for dim in field.dims if dim != xdim]
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
# Find the valid indices where the sum of the nans is 0
|
|
343
|
+
valid_indices = np.where(nan_mask.values == 0)[0]
|
|
344
|
+
|
|
345
|
+
if len(valid_indices) > 0:
|
|
346
|
+
first_valid = valid_indices[0]
|
|
347
|
+
last_valid = valid_indices[-1]
|
|
348
|
+
|
|
349
|
+
field = field.isel({xdim: slice(first_valid, last_valid + 1)})
|
|
350
|
+
if layer_depth is not None:
|
|
351
|
+
layer_depth = layer_depth.isel(
|
|
352
|
+
{xdim: slice(first_valid, last_valid + 1)}
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
return field, layer_depth
|
|
356
|
+
|
|
357
|
+
if lat is not None:
|
|
358
|
+
field, layer_depth = _remove_edge_nans(
|
|
359
|
+
field, "lon", layer_depth if "layer_depth" in locals() else None
|
|
360
|
+
)
|
|
361
|
+
if lon is not None:
|
|
362
|
+
field, layer_depth = _remove_edge_nans(
|
|
363
|
+
field, "lat", layer_depth if "layer_depth" in locals() else None
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
# Regrid vertically
|
|
367
|
+
if depth is not None:
|
|
368
|
+
vertical_regrid = VerticalRegridFromROMS(self.ds)
|
|
369
|
+
# Save attributes before vertical regridding
|
|
370
|
+
attrs = field.attrs
|
|
371
|
+
field = vertical_regrid.apply(
|
|
372
|
+
field, layer_depth, np.array([depth])
|
|
373
|
+
).squeeze()
|
|
374
|
+
# Reset attributes
|
|
375
|
+
field.attrs = attrs
|
|
376
|
+
title = title + f", depth = {depth}m"
|
|
287
377
|
|
|
288
378
|
# Choose colorbar
|
|
289
379
|
if var_name in ["u", "v", "w", "ubar", "vbar", "zeta"]:
|
|
290
|
-
vmax = max(field.
|
|
380
|
+
vmax = max(field.max().values, -field.min().values)
|
|
291
381
|
vmin = -vmax
|
|
292
382
|
cmap = plt.colormaps.get_cmap("RdBu_r")
|
|
293
383
|
else:
|
|
294
|
-
vmax = field.
|
|
295
|
-
vmin = field.
|
|
384
|
+
vmax = field.max().values
|
|
385
|
+
vmin = field.min().values
|
|
296
386
|
if var_name in ["temp", "salt"]:
|
|
297
387
|
cmap = plt.colormaps.get_cmap("YlOrRd")
|
|
298
388
|
else:
|
|
@@ -301,37 +391,170 @@ class ROMSOutput:
|
|
|
301
391
|
kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
|
|
302
392
|
|
|
303
393
|
# Plotting
|
|
304
|
-
if eta is None and xi is None:
|
|
305
|
-
_plot(
|
|
306
|
-
field=field
|
|
394
|
+
if (eta is None and xi is None) and (lat is None and lon is None):
|
|
395
|
+
fig = _plot(
|
|
396
|
+
field=field,
|
|
307
397
|
depth_contours=depth_contours,
|
|
308
398
|
title=title,
|
|
309
399
|
kwargs=kwargs,
|
|
310
|
-
c=
|
|
400
|
+
c=None,
|
|
311
401
|
)
|
|
312
402
|
else:
|
|
313
403
|
if len(field.dims) == 2:
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
# restrict number of layer_contours to 10 for the sake of plot clearity
|
|
318
|
-
nr_layers = len(interface_depth["s_w"])
|
|
319
|
-
selected_layers = np.linspace(
|
|
320
|
-
0, nr_layers - 1, min(nr_layers, 10), dtype=int
|
|
321
|
-
)
|
|
322
|
-
interface_depth = interface_depth.isel(s_w=selected_layers)
|
|
323
|
-
_section_plot(
|
|
324
|
-
field.where(mask),
|
|
325
|
-
interface_depth=interface_depth,
|
|
404
|
+
fig = _section_plot(
|
|
405
|
+
field,
|
|
406
|
+
interface_depth=None,
|
|
326
407
|
title=title,
|
|
327
408
|
kwargs=kwargs,
|
|
328
409
|
ax=ax,
|
|
329
410
|
)
|
|
330
411
|
else:
|
|
331
412
|
if "s_rho" in field.dims:
|
|
332
|
-
_profile_plot(field
|
|
413
|
+
fig = _profile_plot(field, title=title, ax=ax)
|
|
333
414
|
else:
|
|
334
|
-
_line_plot(field
|
|
415
|
+
fig = _line_plot(field, title=title, ax=ax)
|
|
416
|
+
|
|
417
|
+
if save_path:
|
|
418
|
+
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
|
419
|
+
|
|
420
|
+
def regrid(self, var_names=None, horizontal_resolution=None, depth_levels=None):
|
|
421
|
+
"""Regrid the dataset both horizontally and vertically.
|
|
422
|
+
|
|
423
|
+
This method selects the specified variables, interpolates them onto a lat-lon-z horizontal grid. The horizontal target resolution and vertical target depth levels are either specified or inferred dynamically.
|
|
424
|
+
|
|
425
|
+
Parameters
|
|
426
|
+
----------
|
|
427
|
+
var_names : list of str, optional
|
|
428
|
+
List of variable names to be regridded. If None, all variables in the dataset
|
|
429
|
+
are used.
|
|
430
|
+
horizontal_resolution : float, optional
|
|
431
|
+
Target horizontal resolution in degrees. If None, the nominal horizontal resolution is inferred from the grid.
|
|
432
|
+
depth_levels : xarray.DataArray, numpy.ndarray, list, optional
|
|
433
|
+
Target depth levels. If None, depth levels are determined dynamically.
|
|
434
|
+
If provided as a list or numpy array, it is safely converted to an `xarray.DataArray`.
|
|
435
|
+
|
|
436
|
+
Returns
|
|
437
|
+
-------
|
|
438
|
+
xarray.Dataset
|
|
439
|
+
The regridded dataset.
|
|
440
|
+
"""
|
|
441
|
+
|
|
442
|
+
if var_names is None:
|
|
443
|
+
var_names = list(self.ds.data_vars)
|
|
444
|
+
|
|
445
|
+
# Check that all var_names exist in self.ds
|
|
446
|
+
missing_vars = [var for var in var_names if var not in self.ds.data_vars]
|
|
447
|
+
if missing_vars:
|
|
448
|
+
raise ValueError(
|
|
449
|
+
f"The following variables are not found in the dataset: {', '.join(missing_vars)}"
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
# Retain only the variables in var_names and drop others
|
|
453
|
+
ds = self.ds[var_names]
|
|
454
|
+
|
|
455
|
+
# Prepare lateral regrid
|
|
456
|
+
lat_deg = self.grid.ds["lat_rho"]
|
|
457
|
+
lon_deg = self.grid.ds["lon_rho"]
|
|
458
|
+
if self.grid.straddle:
|
|
459
|
+
lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
|
|
460
|
+
|
|
461
|
+
if horizontal_resolution is None:
|
|
462
|
+
horizontal_resolution = self._infer_nominal_horizontal_resolution()
|
|
463
|
+
lons = _generate_coordinate_range(
|
|
464
|
+
lon_deg.min().values, lon_deg.max().values, horizontal_resolution
|
|
465
|
+
)
|
|
466
|
+
lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
|
|
467
|
+
lats = _generate_coordinate_range(
|
|
468
|
+
lat_deg.min().values, lat_deg.max().values, horizontal_resolution
|
|
469
|
+
)
|
|
470
|
+
lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
|
|
471
|
+
target_coords = {"lat": lats, "lon": lons}
|
|
472
|
+
|
|
473
|
+
# Prepare vertical regrid
|
|
474
|
+
if depth_levels is None:
|
|
475
|
+
depth_levels, _ = self._compute_exponential_depth_levels()
|
|
476
|
+
|
|
477
|
+
# Ensure depth_levels is an xarray.DataArray
|
|
478
|
+
if not isinstance(depth_levels, xr.DataArray):
|
|
479
|
+
depth_levels = xr.DataArray(
|
|
480
|
+
np.asarray(depth_levels),
|
|
481
|
+
dims=["depth"],
|
|
482
|
+
attrs={"long_name": "Depth", "units": "m"},
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
depth_levels = depth_levels.astype(np.float32)
|
|
486
|
+
|
|
487
|
+
# Initialize list to hold regridded datasets
|
|
488
|
+
regridded_datasets = []
|
|
489
|
+
|
|
490
|
+
for loc, dims in [
|
|
491
|
+
("rho", ("eta_rho", "xi_rho")),
|
|
492
|
+
("u", ("eta_rho", "xi_u")),
|
|
493
|
+
("v", ("eta_v", "xi_rho")),
|
|
494
|
+
]:
|
|
495
|
+
var_names_loc = [
|
|
496
|
+
var_name
|
|
497
|
+
for var_name in var_names
|
|
498
|
+
if all(dim in ds[var_name].dims for dim in dims)
|
|
499
|
+
]
|
|
500
|
+
if var_names_loc:
|
|
501
|
+
ds_loc = (
|
|
502
|
+
ds[var_names_loc]
|
|
503
|
+
.rename({f"lat_{loc}": "lat", f"lon_{loc}": "lon"})
|
|
504
|
+
.where(self.grid.ds[f"mask_{loc}"])
|
|
505
|
+
)
|
|
506
|
+
self._get_depth_coordinates(depth_type="layer", locations=[loc])
|
|
507
|
+
layer_depth_loc = self.ds_depth_coords[f"layer_depth_{loc}"]
|
|
508
|
+
h_loc = self.grid.ds.h
|
|
509
|
+
if loc == "u":
|
|
510
|
+
h_loc = interpolate_from_rho_to_u(h_loc)
|
|
511
|
+
elif loc == "v":
|
|
512
|
+
h_loc = interpolate_from_rho_to_v(h_loc)
|
|
513
|
+
|
|
514
|
+
# Exclude the horizontal boundary cells since diagnostic variables may contain zeros there
|
|
515
|
+
ds_loc = ds_loc.isel({dims[0]: slice(1, -1), dims[1]: slice(1, -1)})
|
|
516
|
+
layer_depth_loc = layer_depth_loc.isel(
|
|
517
|
+
{dims[0]: slice(1, -1), dims[1]: slice(1, -1)}
|
|
518
|
+
)
|
|
519
|
+
h_loc = h_loc.isel({dims[0]: slice(1, -1), dims[1]: slice(1, -1)})
|
|
520
|
+
|
|
521
|
+
# Lateral regridding
|
|
522
|
+
lateral_regrid = LateralRegridFromROMS(ds_loc, target_coords)
|
|
523
|
+
ds_loc = lateral_regrid.apply(ds_loc)
|
|
524
|
+
layer_depth_loc = lateral_regrid.apply(layer_depth_loc)
|
|
525
|
+
h_loc = lateral_regrid.apply(h_loc)
|
|
526
|
+
# Vertical regridding
|
|
527
|
+
vertical_regrid = VerticalRegridFromROMS(ds_loc)
|
|
528
|
+
for var_name in var_names_loc:
|
|
529
|
+
if "s_rho" in ds_loc[var_name].dims:
|
|
530
|
+
attrs = ds_loc[var_name].attrs
|
|
531
|
+
regridded = vertical_regrid.apply(
|
|
532
|
+
ds_loc[var_name],
|
|
533
|
+
layer_depth_loc,
|
|
534
|
+
depth_levels,
|
|
535
|
+
mask_edges=False,
|
|
536
|
+
)
|
|
537
|
+
regridded = regridded.where(regridded.depth < h_loc)
|
|
538
|
+
ds_loc[var_name] = regridded
|
|
539
|
+
ds_loc[var_name].attrs = attrs
|
|
540
|
+
|
|
541
|
+
ds_loc = ds_loc.assign_coords({"depth": depth_levels})
|
|
542
|
+
|
|
543
|
+
# Collect regridded dataset for merging
|
|
544
|
+
regridded_datasets.append(ds_loc)
|
|
545
|
+
|
|
546
|
+
# Merge all regridded datasets
|
|
547
|
+
if regridded_datasets:
|
|
548
|
+
ds = xr.merge(regridded_datasets)
|
|
549
|
+
|
|
550
|
+
with warnings.catch_warnings():
|
|
551
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
|
552
|
+
ds = ds.rename({"abs_time": "time"}).set_index(time="time")
|
|
553
|
+
ds["time"].attrs = {"long_name": "Time"}
|
|
554
|
+
ds["lon"].attrs = {"long_name": "Longitude", "units": "Degrees East"}
|
|
555
|
+
ds["lat"].attrs = {"long_name": "Latitude", "units": "Degrees North"}
|
|
556
|
+
|
|
557
|
+
return ds
|
|
335
558
|
|
|
336
559
|
def _get_depth_coordinates(self, depth_type="layer", locations=["rho"]):
|
|
337
560
|
"""Ensure depth coordinates are stored for a given location and depth type.
|
|
@@ -378,9 +601,11 @@ class ROMSOutput:
|
|
|
378
601
|
zeta = 0
|
|
379
602
|
|
|
380
603
|
for location in locations:
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
604
|
+
var_name = f"{depth_type}_depth_{location}"
|
|
605
|
+
if var_name not in self.ds_depth_coords:
|
|
606
|
+
self.ds_depth_coords[var_name] = compute_depth_coordinates(
|
|
607
|
+
self.grid.ds, zeta, depth_type, location
|
|
608
|
+
)
|
|
384
609
|
|
|
385
610
|
def _load_model_output(self) -> xr.Dataset:
|
|
386
611
|
"""Load the model output."""
|
|
@@ -435,7 +660,7 @@ class ROMSOutput:
|
|
|
435
660
|
)
|
|
436
661
|
else:
|
|
437
662
|
# Set the model reference date if not already set
|
|
438
|
-
|
|
663
|
+
self.model_reference_date = inferred_date
|
|
439
664
|
else:
|
|
440
665
|
# Handle case where no match is found
|
|
441
666
|
if hasattr(self, "model_reference_date") and self.model_reference_date:
|
|
@@ -548,22 +773,131 @@ class ROMSOutput:
|
|
|
548
773
|
return ds
|
|
549
774
|
|
|
550
775
|
def _add_lat_lon_coords(self, ds: xr.Dataset) -> xr.Dataset:
|
|
551
|
-
"""Add latitude and longitude coordinates to the dataset.
|
|
776
|
+
"""Add latitude and longitude coordinates to the dataset based on the grid.
|
|
552
777
|
|
|
553
|
-
|
|
778
|
+
This method assigns latitude and longitude coordinates from the grid to the dataset.
|
|
779
|
+
It always adds the "lat_rho" and "lon_rho" coordinates. If the dataset contains the
|
|
780
|
+
"xi_u" or "eta_v" dimensions, it also adds the corresponding "lat_u", "lon_u",
|
|
781
|
+
"lat_v", and "lon_v" coordinates.
|
|
554
782
|
|
|
555
783
|
Parameters
|
|
556
784
|
----------
|
|
557
785
|
ds : xarray.Dataset
|
|
558
|
-
|
|
786
|
+
Input dataset to which latitude and longitude coordinates will be added.
|
|
559
787
|
|
|
560
788
|
Returns
|
|
561
789
|
-------
|
|
562
790
|
xarray.Dataset
|
|
563
|
-
|
|
791
|
+
Updated dataset with the appropriate latitude and longitude coordinates
|
|
792
|
+
assigned to "rho", "u", and "v" points if applicable.
|
|
564
793
|
"""
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
794
|
+
coords_to_add = {
|
|
795
|
+
"lat_rho": self.grid.ds["lat_rho"],
|
|
796
|
+
"lon_rho": self.grid.ds["lon_rho"],
|
|
797
|
+
}
|
|
798
|
+
|
|
799
|
+
if "xi_u" in ds.dims:
|
|
800
|
+
coords_to_add.update(
|
|
801
|
+
{"lat_u": self.grid.ds["lat_u"], "lon_u": self.grid.ds["lon_u"]}
|
|
802
|
+
)
|
|
803
|
+
if "eta_v" in ds.dims:
|
|
804
|
+
coords_to_add.update(
|
|
805
|
+
{"lat_v": self.grid.ds["lat_v"], "lon_v": self.grid.ds["lon_v"]}
|
|
806
|
+
)
|
|
568
807
|
|
|
808
|
+
# Add all necessary coordinates in one go
|
|
809
|
+
ds = ds.assign_coords(coords_to_add)
|
|
569
810
|
return ds
|
|
811
|
+
|
|
812
|
+
def _infer_nominal_horizontal_resolution(self, lat=None):
|
|
813
|
+
"""Estimate the nominal horizontal resolution of the grid in degrees at a
|
|
814
|
+
specified latitude.
|
|
815
|
+
|
|
816
|
+
This method calculates the nominal horizontal resolution of the grid by first
|
|
817
|
+
determining the average grid spacing in meters. The spacing is then converted
|
|
818
|
+
to degrees, accounting for the Earth's curvature, and the latitude where the
|
|
819
|
+
resolution is being computed.
|
|
820
|
+
|
|
821
|
+
Parameters
|
|
822
|
+
----------
|
|
823
|
+
lat : float, optional
|
|
824
|
+
Latitude (in degrees) at which to estimate the horizontal resolution.
|
|
825
|
+
If not provided, the resolution is calculated at the average latitude of
|
|
826
|
+
the grid (`lat_rho`).
|
|
827
|
+
|
|
828
|
+
Returns
|
|
829
|
+
-------
|
|
830
|
+
float
|
|
831
|
+
The estimated horizontal resolution in degrees, adjusted for the Earth's curvature.
|
|
832
|
+
"""
|
|
833
|
+
# Earth radius in meters
|
|
834
|
+
r_earth = 6371315.0
|
|
835
|
+
|
|
836
|
+
if lat is None:
|
|
837
|
+
# Center latitude in degrees
|
|
838
|
+
lat = (self.grid.ds.lat_rho.max() + self.grid.ds.lat_rho.min()) / 2
|
|
839
|
+
|
|
840
|
+
# Convert latitude to radians
|
|
841
|
+
lat_rad = np.deg2rad(lat)
|
|
842
|
+
|
|
843
|
+
# Mean resolution in meters
|
|
844
|
+
resolution_in_m = (
|
|
845
|
+
(1 / self.grid.ds.pm).mean() + (1 / self.grid.ds.pn).mean()
|
|
846
|
+
) / 2
|
|
847
|
+
|
|
848
|
+
# Meters per degree at the equator
|
|
849
|
+
meters_per_degree = 2 * np.pi * r_earth / 360
|
|
850
|
+
|
|
851
|
+
# Correct for latitude by multiplying by cos(latitude) for longitude
|
|
852
|
+
resolution_in_degrees = resolution_in_m / (meters_per_degree * np.cos(lat_rad))
|
|
853
|
+
|
|
854
|
+
return resolution_in_degrees
|
|
855
|
+
|
|
856
|
+
def _compute_exponential_depth_levels(self, Nz=None, depth=None, h=None):
|
|
857
|
+
"""Compute vertical grid center and face depths using an exponential profile.
|
|
858
|
+
|
|
859
|
+
Parameters
|
|
860
|
+
----------
|
|
861
|
+
Nz : int, optional
|
|
862
|
+
Number of vertical levels. Defaults to `len(self.ds.s_rho)`.
|
|
863
|
+
|
|
864
|
+
depth : float, optional
|
|
865
|
+
Total depth of the domain. Defaults to `grid.ds.h.max().values`.
|
|
866
|
+
|
|
867
|
+
h : float, optional
|
|
868
|
+
Scaling parameter for the exponential profile. Defaults to `Nz / 4.5`.
|
|
869
|
+
|
|
870
|
+
Returns
|
|
871
|
+
-------
|
|
872
|
+
tuple of numpy.ndarray
|
|
873
|
+
Depth values at the vertical grid centers (`z_centers`) and grid faces (`z_faces`),
|
|
874
|
+
both rounded to two decimal places.
|
|
875
|
+
"""
|
|
876
|
+
if Nz is None:
|
|
877
|
+
Nz = len(self.ds.s_rho)
|
|
878
|
+
if depth is None:
|
|
879
|
+
depth = self.grid.ds.h.max().values
|
|
880
|
+
if h is None:
|
|
881
|
+
h = Nz / 4.5
|
|
882
|
+
|
|
883
|
+
k = np.arange(1, Nz + 2)
|
|
884
|
+
|
|
885
|
+
# Define the exponential profile function
|
|
886
|
+
def exponential_profile(k, Nz, h):
|
|
887
|
+
return np.exp(k / h)
|
|
888
|
+
|
|
889
|
+
z_faces = np.vectorize(exponential_profile)(k, Nz, h)
|
|
890
|
+
|
|
891
|
+
# Normalize
|
|
892
|
+
z_faces -= z_faces[0]
|
|
893
|
+
z_faces *= depth / z_faces[-1]
|
|
894
|
+
z_faces[0] = 0.0
|
|
895
|
+
|
|
896
|
+
# Calculate center depths (average between adjacent face depths)
|
|
897
|
+
z_centers = (z_faces[:-1] + z_faces[1:]) / 2
|
|
898
|
+
|
|
899
|
+
# Round both z_faces and z_centers to two decimal places
|
|
900
|
+
z_faces = np.round(z_faces, 2)
|
|
901
|
+
z_centers = np.round(z_centers, 2)
|
|
902
|
+
|
|
903
|
+
return z_centers, z_faces
|