roms-tools 0.1.0__py3-none-any.whl → 0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment.yml +1 -0
- roms_tools/__init__.py +3 -0
- roms_tools/_version.py +1 -1
- roms_tools/setup/atmospheric_forcing.py +335 -393
- roms_tools/setup/boundary_forcing.py +711 -0
- roms_tools/setup/datasets.py +434 -25
- roms_tools/setup/fill.py +118 -5
- roms_tools/setup/grid.py +145 -19
- roms_tools/setup/initial_conditions.py +528 -0
- roms_tools/setup/plot.py +149 -4
- roms_tools/setup/tides.py +570 -437
- roms_tools/setup/topography.py +17 -2
- roms_tools/setup/utils.py +162 -0
- roms_tools/setup/vertical_coordinate.py +494 -0
- roms_tools/tests/test_atmospheric_forcing.py +1645 -0
- roms_tools/tests/test_boundary_forcing.py +332 -0
- roms_tools/tests/test_datasets.py +306 -0
- roms_tools/tests/test_grid.py +226 -0
- roms_tools/tests/test_initial_conditions.py +300 -0
- roms_tools/tests/test_tides.py +366 -0
- roms_tools/tests/test_topography.py +78 -0
- roms_tools/tests/test_vertical_coordinate.py +337 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/METADATA +3 -2
- roms_tools-0.20.dist-info/RECORD +28 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/WHEEL +1 -1
- roms_tools/tests/test_setup.py +0 -181
- roms_tools-0.1.0.dist-info/RECORD +0 -17
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/LICENSE +0 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/top_level.txt +0 -0
roms_tools/setup/topography.py
CHANGED
|
@@ -4,6 +4,7 @@ import gcm_filters
|
|
|
4
4
|
from scipy.interpolate import RegularGridInterpolator
|
|
5
5
|
from scipy.ndimage import label
|
|
6
6
|
from roms_tools.setup.datasets import fetch_topo
|
|
7
|
+
from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
|
|
7
8
|
import warnings
|
|
8
9
|
from itertools import count
|
|
9
10
|
|
|
@@ -19,7 +20,7 @@ def _add_topography_and_mask(
|
|
|
19
20
|
hraw = xr.DataArray(data=hraw, dims=["eta_rho", "xi_rho"])
|
|
20
21
|
|
|
21
22
|
# Mask is obtained by finding locations where ocean depth is positive
|
|
22
|
-
mask = xr.where(hraw > 0, 1, 0)
|
|
23
|
+
mask = xr.where(hraw > 0, 1.0, 0.0)
|
|
23
24
|
|
|
24
25
|
# smooth topography domain-wide with Gaussian kernel to avoid grid scale instabilities
|
|
25
26
|
ds["hraw"] = _smooth_topography_globally(hraw, mask, smooth_factor)
|
|
@@ -37,6 +38,8 @@ def _add_topography_and_mask(
|
|
|
37
38
|
"units": "land/water (0/1)",
|
|
38
39
|
}
|
|
39
40
|
|
|
41
|
+
ds = _add_velocity_masks(ds)
|
|
42
|
+
|
|
40
43
|
# smooth topography locally to satisfy r < rmax
|
|
41
44
|
ds["h"] = _smooth_topography_locally(ds["hraw"] * ds["mask_rho"], hmin, rmax)
|
|
42
45
|
ds["h"].attrs = {
|
|
@@ -57,7 +60,7 @@ def _make_raw_topography(lon, lat, topography_source) -> np.ndarray:
|
|
|
57
60
|
topo_ds = fetch_topo(topography_source)
|
|
58
61
|
|
|
59
62
|
# the following will depend on the topography source
|
|
60
|
-
if topography_source == "
|
|
63
|
+
if topography_source == "ETOPO5":
|
|
61
64
|
topo_lon = topo_ds["topo_lon"].copy()
|
|
62
65
|
# Modify longitude values where necessary
|
|
63
66
|
topo_lon = xr.where(topo_lon < 0, topo_lon + 360, topo_lon)
|
|
@@ -240,3 +243,15 @@ def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
|
|
|
240
243
|
ds.attrs["rmax"] = rmax
|
|
241
244
|
|
|
242
245
|
return ds
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _add_velocity_masks(ds):
    """
    Add u- and v-point land/water masks derived from ``ds["mask_rho"]``.

    Each velocity-point mask is the product of the two adjacent rho-point mask
    values (``method="multiplicative"``), so a velocity point is wet only when
    both neighbouring rho points are wet.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing a ``mask_rho`` variable. It is modified in place.

    Returns
    -------
    xr.Dataset
        The same dataset with ``mask_u`` and ``mask_v`` (and their attributes) added.
    """
    rho_mask = ds["mask_rho"]
    units = "land/water (0/1)"

    targets = (
        ("mask_u", interpolate_from_rho_to_u, "Mask at u-points"),
        ("mask_v", interpolate_from_rho_to_v, "Mask at v-points"),
    )
    for var_name, to_staggered, long_name in targets:
        ds[var_name] = to_staggered(rho_mask, method="multiplicative")
        ds[var_name].attrs = {"long_name": long_name, "units": units}

    return ds
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def nan_check(field, mask) -> None:
    """
    Checks for NaN values at wet points in the field.

    This function examines the interpolated input field for NaN values at positions indicated as wet points by the mask.
    If any NaN values are found at these wet points, a ValueError is raised.

    Parameters
    ----------
    field : array-like
        The data array to be checked for NaN values. This is typically an xarray.DataArray or numpy array.

    mask : array-like
        A boolean mask or data array with the same shape as `field`. The wet points (usually ocean points)
        are indicated by `1` or `True`, and land points by `0` or `False`.

    Raises
    ------
    ValueError
        If the field contains NaN values at any of the wet points indicated by the mask.
        The error message will explain the potential cause and suggest ensuring the dataset's coverage.

    """

    # Replace values in field with 0 where mask is not 1, so NaNs at land
    # points are ignored.
    da = xr.where(mask == 1, field, 0)

    # With plain (e.g. numpy) inputs, xr.where returns a plain array that has
    # no .isnull(); wrap it so numpy inputs are supported as documented.
    if not hasattr(da, "isnull"):
        da = xr.DataArray(da)

    # Check if any NaN values exist in the modified field
    if da.isnull().any().values:
        raise ValueError(
            "NaN values found in interpolated field. This likely occurs because the ROMS grid, including "
            "a small safety margin for interpolation, is not fully contained within the dataset's longitude/latitude range. Please ensure that the "
            "dataset covers the entire area required by the ROMS grid."
        )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def interpolate_from_rho_to_u(field, method="additive"):
    """
    Interpolates the given field from rho points to u points.

    This function performs an interpolation from the rho grid (cell centers) to the u grid
    (cell edges in the xi direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the xi dimension. It also drops the rho-point coordinates ``lat_rho`` and
    ``lon_rho`` (if present) and renames the "xi_rho" dimension to "xi_u".

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "xi_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the u grid with the dimension "xi_u". It has one
        fewer point along xi than the input.

    Raises
    ------
    NotImplementedError
        If `method` is not 'additive' or 'multiplicative'.
    """

    # shift(xi_rho=1) pairs each point with its western neighbour; the first
    # entry has no neighbour and is discarded by the isel.
    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(xi_rho=1)).isel(
            xi_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(xi_rho=1)).isel(xi_rho=slice(1, None))
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # The remaining lat_rho/lon_rho values describe rho points, not u points,
    # so drop them. drop_vars returns a new object, so the result must be
    # reassigned (previously it was discarded and nothing was dropped).
    for coord_name in ("lat_rho", "lon_rho"):
        if coord_name in field_interpolated.coords:
            field_interpolated = field_interpolated.drop_vars([coord_name])

    field_interpolated = field_interpolated.swap_dims({"xi_rho": "xi_u"})

    return field_interpolated
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def interpolate_from_rho_to_v(field, method="additive"):
    """
    Interpolates the given field from rho points to v points.

    This function performs an interpolation from the rho grid (cell centers) to the v grid
    (cell edges in the eta direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the eta dimension. It also drops the rho-point coordinates ``lat_rho`` and
    ``lon_rho`` (if present) and renames the "eta_rho" dimension to "eta_v".

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "eta_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the v grid with the dimension "eta_v". It has one
        fewer point along eta than the input.

    Raises
    ------
    NotImplementedError
        If `method` is not 'additive' or 'multiplicative'.
    """

    # shift(eta_rho=1) pairs each point with its southern neighbour; the first
    # entry has no neighbour and is discarded by the isel.
    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # The remaining lat_rho/lon_rho values describe rho points, not v points,
    # so drop them. drop_vars returns a new object, so the result must be
    # reassigned (previously it was discarded and nothing was dropped).
    for coord_name in ("lat_rho", "lon_rho"):
        if coord_name in field_interpolated.coords:
            field_interpolated = field_interpolated.drop_vars([coord_name])

    field_interpolated = field_interpolated.swap_dims({"eta_rho": "eta_v"})

    return field_interpolated
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def extrapolate_deepest_to_bottom(field: xr.DataArray, dim: str) -> xr.DataArray:
    """
    Fill NaN values along `dim` with the nearest valid value.

    Uses nearest-neighbour interpolation with extrapolation. Leading and trailing
    NaNs along `dim` are therefore replaced by the closest non-NaN value, which
    carries the deepest valid value down to the bottom. Interior gaps are filled
    in the same way. Which end of `dim` is "the bottom" depends on how the
    caller orders that dimension.

    Parameters
    ----------
    field : xr.DataArray
        The input data array containing NaN values that need to be filled. This array
        should have at least one dimension named by `dim`.
    dim : str
        The name of the dimension along which to perform the interpolation and extrapolation.
        Typically, this would be a vertical dimension such as 'depth' or 's_rho'.

    Returns
    -------
    xr.DataArray
        A new data array with NaNs along `dim` filled. The input is not modified.
    """
    # fill_value="extrapolate" is passed through to the underlying 1-D
    # interpolator and is what fills the NaNs beyond the last valid value.
    return field.interpolate_na(
        dim=dim,
        method="nearest",
        fill_value="extrapolate",
    )
|
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import xarray as xr
|
|
3
|
+
import yaml
|
|
4
|
+
import importlib.metadata
|
|
5
|
+
from dataclasses import dataclass, field, asdict
|
|
6
|
+
from roms_tools.setup.grid import Grid
|
|
7
|
+
from roms_tools.setup.utils import (
|
|
8
|
+
interpolate_from_rho_to_u,
|
|
9
|
+
interpolate_from_rho_to_v,
|
|
10
|
+
)
|
|
11
|
+
from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True, kw_only=True)
class VerticalCoordinate:
    """
    Represents vertical coordinate for ROMS.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information.
    N : int
        The number of vertical levels.
    theta_s : float
        The surface control parameter. Must satisfy 0 < theta_s <= 10.
    theta_b : float
        The bottom control parameter. Must satisfy 0 < theta_b <= 4.
    hc : float
        The critical depth.

    Attributes
    ----------
    ds : xr.Dataset
        Xarray Dataset containing the vertical coordinate information:
        - the S-coordinate parameters (``theta_s``, ``theta_b``, ``Tcline``, ``hc``);
        - the S-coordinate values and stretching curves at rho-points (``sc_r``, ``Cs_r``);
        - stored as coordinates, the layer depths (``layer_depth_rho``,
          ``layer_depth_u``, ``layer_depth_v``) and interface depths
          (``interface_depth_rho``, ``interface_depth_u``, ``interface_depth_v``),
          in meters and positive downward (depth = -z).
    """

    grid: Grid
    N: int
    theta_s: float
    theta_b: float
    hc: float

    ds: xr.Dataset = field(init=False, repr=False)

    def __post_init__(self):

        h = self.grid.ds.h

        # Vertical positions of layer midpoints ("r") and interfaces ("w") for
        # a resting ocean: h * 0 is a zero free surface with the same shape as h.
        cs_r, sigma_r = sigma_stretch(self.theta_s, self.theta_b, self.N, "r")
        zr = compute_depth(h * 0, h, self.hc, cs_r, sigma_r)
        cs_w, sigma_w = sigma_stretch(self.theta_s, self.theta_b, self.N, "w")
        zw = compute_depth(h * 0, h, self.hc, cs_w, sigma_w)

        ds = xr.Dataset()

        ds["theta_s"] = np.float32(self.theta_s)
        ds["theta_s"].attrs["long_name"] = "S-coordinate surface control parameter"
        ds["theta_s"].attrs["units"] = "nondimensional"

        ds["theta_b"] = np.float32(self.theta_b)
        ds["theta_b"].attrs["long_name"] = "S-coordinate bottom control parameter"
        ds["theta_b"].attrs["units"] = "nondimensional"

        # NOTE(review): Tcline is written with the value of hc.
        ds["Tcline"] = np.float32(self.hc)
        ds["Tcline"].attrs["long_name"] = "S-coordinate surface/bottom layer width"
        ds["Tcline"].attrs["units"] = "m"

        ds["hc"] = np.float32(self.hc)
        ds["hc"].attrs["long_name"] = "S-coordinate parameter critical depth"
        ds["hc"].attrs["units"] = "m"

        ds["sc_r"] = sigma_r.astype(np.float32)
        ds["sc_r"].attrs["long_name"] = "S-coordinate at rho-points"
        ds["sc_r"].attrs["units"] = "nondimensional"

        ds["Cs_r"] = cs_r.astype(np.float32)
        ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
        ds["Cs_r"].attrs["units"] = "nondimensional"

        # Layer (midpoint) depths, positive downward.
        depth = -zr
        depth.attrs["long_name"] = "Layer depth at rho-points"
        depth.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_rho": depth.astype(np.float32)})

        depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
        depth_u.attrs["long_name"] = "Layer depth at u-points"
        depth_u.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_u": depth_u})

        depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
        depth_v.attrs["long_name"] = "Layer depth at v-points"
        depth_v.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_v": depth_v})

        # Interface depths, positive downward.
        depth = -zw
        depth.attrs["long_name"] = "Interface depth at rho-points"
        depth.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_rho": depth.astype(np.float32)})

        depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
        depth_u.attrs["long_name"] = "Interface depth at u-points"
        depth_u.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_u": depth_u})

        depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
        depth_v.attrs["long_name"] = "Interface depth at v-points"
        depth_v.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_v": depth_v})

        ds = ds.drop_vars(["eta_rho", "xi_rho"])

        ds.attrs["title"] = "ROMS vertical coordinate file created by ROMS-Tools"
        # Include the version of roms-tools
        try:
            roms_tools_version = importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            roms_tools_version = "unknown"
        ds.attrs["roms_tools_version"] = roms_tools_version

        # The dataclass is frozen, so bypass its __setattr__ to store ds.
        object.__setattr__(self, "ds", ds)

    def plot(
        self,
        varname="layer_depth_rho",
        s=None,
        eta=None,
        xi=None,
    ) -> None:
        """
        Plot the vertical coordinate system for a given eta-, xi-, or s-slice.

        Parameters
        ----------
        varname : str, optional
            The field to plot. Options are "layer_depth_rho", "layer_depth_u",
            "layer_depth_v", "interface_depth_rho", "interface_depth_u", and
            "interface_depth_v". Default is "layer_depth_rho".
        s: int, optional
            The s-index to plot. Default is None.
        eta : int, optional
            The eta-index to plot. Default is None.
        xi : int, optional
            The xi-index to plot. Default is None.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        ValueError
            If the specified varname is not found in the dataset.
            If none of s, eta, xi are specified.
            If a requested slice dimension is not a dimension of the field.
        """

        if not any([s is not None, eta is not None, xi is not None]):
            raise ValueError("At least one of s, eta, or xi must be specified.")

        # Raise the documented ValueError instead of an opaque KeyError.
        if varname not in self.ds:
            raise ValueError(f"Variable '{varname}' is not found in dataset.")

        self.ds[varname].load()
        data = self.ds[varname].squeeze()

        # Pick the interface depths on the same horizontal grid as the field;
        # they are used to draw vertical sections.
        if all(dim in data.dims for dim in ["eta_rho", "xi_rho"]):
            interface_depth = self.ds.interface_depth_rho
        elif all(dim in data.dims for dim in ["eta_rho", "xi_u"]):
            interface_depth = self.ds.interface_depth_u
        elif all(dim in data.dims for dim in ["eta_v", "xi_rho"]):
            interface_depth = self.ds.interface_depth_v

        # slice the field as desired
        title = data.long_name
        if s is not None:
            if "s_rho" in data.dims:
                title = title + f", s_rho = {data.s_rho[s].item()}"
                data = data.isel(s_rho=s)
            elif "s_w" in data.dims:
                title = title + f", s_w = {data.s_w[s].item()}"
                data = data.isel(s_w=s)
            else:
                raise ValueError(
                    f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
                )

        if eta is not None:
            if "eta_rho" in data.dims:
                title = title + f", eta_rho = {data.eta_rho[eta].item()}"
                data = data.isel(eta_rho=eta)
                interface_depth = interface_depth.isel(eta_rho=eta)
            elif "eta_v" in data.dims:
                title = title + f", eta_v = {data.eta_v[eta].item()}"
                data = data.isel(eta_v=eta)
                interface_depth = interface_depth.isel(eta_v=eta)
            else:
                raise ValueError(
                    f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
                )
        if xi is not None:
            if "xi_rho" in data.dims:
                title = title + f", xi_rho = {data.xi_rho[xi].item()}"
                data = data.isel(xi_rho=xi)
                interface_depth = interface_depth.isel(xi_rho=xi)
            elif "xi_u" in data.dims:
                title = title + f", xi_u = {data.xi_u[xi].item()}"
                data = data.isel(xi_u=xi)
                interface_depth = interface_depth.isel(xi_u=xi)
            else:
                raise ValueError(
                    f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
                )

        if eta is None and xi is None:
            # Horizontal map at a single s-level.
            vmax = data.max().values
            vmin = data.min().values
            cmap = plt.colormaps.get_cmap("YlGnBu")
            cmap.set_bad(color="gray")
            kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

            _plot(
                self.grid.ds,
                field=data,
                straddle=self.grid.straddle,
                depth_contours=True,
                title=title,
                kwargs=kwargs,
                c="g",
            )
        else:
            if len(data.dims) == 2:
                # Vertical section: only the layer interfaces are of interest,
                # so a zero field is drawn without a colorbar.
                cmap = plt.colormaps.get_cmap("YlGnBu")
                cmap.set_bad(color="gray")
                kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}

                _section_plot(
                    xr.zeros_like(data),
                    interface_depth=interface_depth,
                    title=title,
                    kwargs=kwargs,
                )
            else:
                if "s_rho" in data.dims or "s_w" in data.dims:
                    _profile_plot(data, title=title)
                else:
                    _line_plot(data, title=title)

    def save(self, filepath: str) -> None:
        """
        Save the vertical coordinate information to a netCDF4 file.

        Parameters
        ----------
        filepath : str
            Path of the netCDF file to write ``self.ds`` to.
        """
        self.ds.to_netcdf(filepath)

    def to_yaml(self, filepath: str) -> None:
        """
        Export the parameters of the class to a YAML file, including the version of roms-tools.

        Parameters
        ----------
        filepath : str
            The path to the YAML file where the parameters will be saved.
        """
        # Serialize Grid data
        grid_data = asdict(self.grid)
        grid_data.pop("ds", None)  # Exclude non-serializable fields
        grid_data.pop("straddle", None)

        # Include the version of roms-tools
        try:
            roms_tools_version = importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            roms_tools_version = "unknown"

        # Create header
        header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"

        grid_yaml_data = {"Grid": grid_data}

        # Combine all sections
        vertical_coordinate_data = {
            "VerticalCoordinate": {
                "N": self.N,
                "theta_s": self.theta_s,
                "theta_b": self.theta_b,
                "hc": self.hc,
            }
        }

        # Merge the Grid and VerticalCoordinate sections into one YAML document
        yaml_data = {
            **grid_yaml_data,
            **vertical_coordinate_data,
        }

        with open(filepath, "w") as file:
            # Write header
            file.write(header)
            # Write YAML data
            yaml.dump(yaml_data, file, default_flow_style=False)

    @classmethod
    def from_file(cls, filepath: str) -> "VerticalCoordinate":
        """
        Create a VerticalCoordinate instance from an existing file.

        The returned instance has ``grid`` set to None, so methods that need the
        grid (e.g. ``plot`` for horizontal maps, ``to_yaml``) are not usable on it.

        Parameters
        ----------
        filepath : str
            Path to the file containing the vertical coordinate information.

        Returns
        -------
        VerticalCoordinate
            A new instance of VerticalCoordinate populated with data from the file.
        """
        # Load the dataset from the file
        ds = xr.open_dataset(filepath)

        # Create a new VerticalCoordinate instance without calling __init__ and __post_init__
        vertical_coordinate = cls.__new__(cls)

        # Set the dataset for the vertical_coordinate instance
        object.__setattr__(vertical_coordinate, "ds", ds)

        # Manually set the remaining attributes by extracting parameters from dataset
        object.__setattr__(vertical_coordinate, "N", ds.sizes["s_rho"])
        object.__setattr__(vertical_coordinate, "theta_s", ds["theta_s"].values.item())
        object.__setattr__(vertical_coordinate, "theta_b", ds["theta_b"].values.item())
        object.__setattr__(vertical_coordinate, "hc", ds["hc"].values.item())
        object.__setattr__(vertical_coordinate, "grid", None)

        return vertical_coordinate

    @classmethod
    def from_yaml(cls, filepath: str) -> "VerticalCoordinate":
        """
        Create an instance of the VerticalCoordinate class from a YAML file.

        Parameters
        ----------
        filepath : str
            The path to the YAML file from which the parameters will be read.

        Returns
        -------
        VerticalCoordinate
            An instance of the VerticalCoordinate class.

        Raises
        ------
        ValueError
            If the file contains no ``VerticalCoordinate`` section.
        """
        # Read the entire file content
        with open(filepath, "r") as file:
            file_content = file.read()

        # Split the content into YAML documents
        documents = list(yaml.safe_load_all(file_content))

        vertical_coordinate_data = None

        # Process the YAML documents
        for doc in documents:
            if doc is None:
                continue
            if "VerticalCoordinate" in doc:
                vertical_coordinate_data = doc["VerticalCoordinate"]
                break

        if vertical_coordinate_data is None:
            raise ValueError(
                "No VerticalCoordinate configuration found in the YAML file."
            )

        # Create Grid instance from the YAML file
        grid = Grid.from_yaml(filepath)

        # Create and return an instance of VerticalCoordinate
        return cls(
            grid=grid,
            **vertical_coordinate_data,
        )
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def compute_cs(sigma, theta_s, theta_b):
    """
    Compute the S-coordinate stretching curves according to Shchepetkin and McWilliams (2009).

    The curve is built in two steps: a surface-refinement function controlled by
    ``theta_s``, then a bottom-refinement function controlled by ``theta_b``
    applied on top of it. For sigma = -1 the result is -1 and for sigma = 0 it is 0.

    Parameters
    ----------
    sigma : np.ndarray or float
        The sigma-coordinate values.
    theta_s : float
        The surface control parameter, 0 < theta_s <= 10.
    theta_b : float
        The bottom control parameter, 0 < theta_b <= 4.

    Returns
    -------
    np.ndarray or float
        The stretching curve values.

    Raises
    ------
    ValueError
        If theta_s or theta_b are not within the valid range.
    """
    surface_ok = 0 < theta_s <= 10
    if not surface_ok:
        raise ValueError("theta_s must be between 0 and 10.")
    bottom_ok = 0 < theta_b <= 4
    if not bottom_ok:
        raise ValueError("theta_b must be between 0 and 4.")

    # Surface refinement.
    surface_curve = (1 - np.cosh(theta_s * sigma)) / (np.cosh(theta_s) - 1)
    # Bottom refinement applied to the surface-refined curve.
    stretched = (np.exp(theta_b * surface_curve) - 1) / (1 - np.exp(-theta_b))
    return stretched
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def sigma_stretch(theta_s, theta_b, N, type):
    """
    Compute sigma and stretching curves based on the type and parameters.

    Parameters
    ----------
    theta_s : float
        The surface control parameter.
    theta_b : float
        The bottom control parameter.
    N : int
        The number of vertical levels.
    type : str
        The type of sigma ('w' for vertical velocity points, 'r' for rho-points).

    Returns
    -------
    cs : xr.DataArray
        The stretching curve values.
    sigma : xr.DataArray
        The sigma-coordinate values, along dimension "s_w" (N + 1 values from
        -1 to 0) for 'w', or "s_rho" (N values, offset half a level from the
        'w' values) for 'r'.

    Raises
    ------
    ValueError
        If the type is not 'w' or 'r'.
    """
    if type not in ("w", "r"):
        raise ValueError(
            "Type must be either 'w' for vertical velocity points or 'r' for rho-points."
        )

    if type == "w":
        # Interfaces: k = 0..N maps to sigma = -1..0.
        level_index = xr.DataArray(np.arange(N + 1), dims="s_w")
        sigma = (level_index - N) / N
    else:
        # Layer midpoints: k = 1..N, shifted down by half a level.
        level_index = xr.DataArray(np.arange(1, N + 1), dims="s_rho")
        sigma = (level_index - N - 0.5) / N

    stretching = compute_cs(sigma, theta_s, theta_b)
    return stretching, sigma
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def compute_depth(zeta, h, hc, cs, sigma):
    """
    Compute the vertical position (z) of sigma levels.

    The vertical transformation is

        S = (hc * sigma + h * cs) / (hc + h)
        z = zeta + (zeta + h) * S

    ``sigma`` and ``cs`` are broadcast over the horizontal dimensions of ``h``
    before the transformation is applied.

    Parameters
    ----------
    zeta : xr.DataArray
        The sea surface height.
    h : xr.DataArray
        The depth of the sea bottom. Must have dimensions "eta_rho" and "xi_rho".
    hc : float
        The critical depth.
    cs : xr.DataArray
        The stretching curve values, e.g. as returned by ``sigma_stretch``.
    sigma : xr.DataArray
        The sigma-coordinate values, e.g. as returned by ``sigma_stretch``.

    Returns
    -------
    z : xr.DataArray
        Vertical position of each sigma level, positive upward. For sigma = -1
        (with cs = -1) this gives the bottom, z = -h. For sigma = 0 (with cs = 0)
        it gives the free surface, z = zeta. With zeta = 0, values are therefore
        negative below the surface; negate z to obtain a positive depth.
    """

    # Expand dimensions so sigma and cs carry the horizontal dims of h
    sigma = sigma.expand_dims(dim={"eta_rho": h.eta_rho, "xi_rho": h.xi_rho})
    cs = cs.expand_dims(dim={"eta_rho": h.eta_rho, "xi_rho": h.xi_rho})

    s = (hc * sigma + h * cs) / (hc + h)
    z = zeta + (zeta + h) * s

    return z
|