roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment.yml +29 -0
- roms_tools/__init__.py +6 -0
- roms_tools/_version.py +1 -1
- roms_tools/setup/atmospheric_forcing.py +935 -0
- roms_tools/setup/boundary_forcing.py +711 -0
- roms_tools/setup/datasets.py +457 -0
- roms_tools/setup/fill.py +376 -0
- roms_tools/setup/grid.py +610 -325
- roms_tools/setup/initial_conditions.py +528 -0
- roms_tools/setup/plot.py +203 -0
- roms_tools/setup/tides.py +809 -0
- roms_tools/setup/topography.py +257 -0
- roms_tools/setup/utils.py +162 -0
- roms_tools/setup/vertical_coordinate.py +494 -0
- roms_tools/tests/test_atmospheric_forcing.py +1645 -0
- roms_tools/tests/test_boundary_forcing.py +332 -0
- roms_tools/tests/test_datasets.py +306 -0
- roms_tools/tests/test_grid.py +226 -0
- roms_tools/tests/test_initial_conditions.py +300 -0
- roms_tools/tests/test_tides.py +366 -0
- roms_tools/tests/test_topography.py +78 -0
- roms_tools/tests/test_vertical_coordinate.py +337 -0
- roms_tools-0.20.dist-info/METADATA +90 -0
- roms_tools-0.20.dist-info/RECORD +28 -0
- {roms_tools-0.0.6.dist-info → roms_tools-0.20.dist-info}/WHEEL +1 -1
- {roms_tools-0.0.6.dist-info → roms_tools-0.20.dist-info}/top_level.txt +1 -0
- roms_tools/tests/test_setup.py +0 -54
- roms_tools-0.0.6.dist-info/METADATA +0 -134
- roms_tools-0.0.6.dist-info/RECORD +0 -10
- {roms_tools-0.0.6.dist-info → roms_tools-0.20.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
import numpy as np
|
|
3
|
+
import gcm_filters
|
|
4
|
+
from scipy.interpolate import RegularGridInterpolator
|
|
5
|
+
from scipy.ndimage import label
|
|
6
|
+
from roms_tools.setup.datasets import fetch_topo
|
|
7
|
+
from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
|
|
8
|
+
import warnings
|
|
9
|
+
from itertools import count
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _add_topography_and_mask(
    ds, topography_source, smooth_factor, hmin, rmax
) -> xr.Dataset:
    """
    Attach bathymetry fields and land/sea masks to a grid dataset.

    Steps:
      1. interpolate raw topography from ``topography_source`` onto the
         rho-points given by ``ds.lon_rho`` / ``ds.lat_rho``;
      2. build a wet mask (1.0 where the raw depth is positive, else 0.0);
      3. smooth the raw depth domain-wide -> ``ds["hraw"]``;
      4. keep only the largest connected wet region -> ``ds["mask_rho"]``;
      5. derive ``mask_u`` / ``mask_v`` from ``mask_rho``;
      6. smooth the masked depth locally towards r < ``rmax`` -> ``ds["h"]``;
      7. store the configuration as dataset attributes.

    Parameters
    ----------
    ds : xr.Dataset
        Grid dataset providing ``lon_rho`` and ``lat_rho``; variables are
        assigned into it.
    topography_source : str
        Name of the topography dataset (e.g. "ETOPO5").
    smooth_factor : float
        Scale of the domain-wide Gaussian smoothing.
    hmin : float
        Minimum depth used by the local smoothing.
    rmax : float
        Target maximum slope parameter for the local smoothing.

    Returns
    -------
    xr.Dataset
        The dataset with ``hraw``, ``mask_rho``, ``mask_u``, ``mask_v`` and
        ``h`` added and the topography attributes set.
    """
    rho_lon = ds.lon_rho.values
    rho_lat = ds.lat_rho.values

    # Raw depth interpolated onto the grid's rho-points.
    raw_depth = xr.DataArray(
        data=_make_raw_topography(rho_lon, rho_lat, topography_source),
        dims=["eta_rho", "xi_rho"],
    )

    # Wet points are where the raw ocean depth is positive.
    wet = xr.where(raw_depth > 0, 1.0, 0.0)

    # Domain-wide Gaussian smoothing to avoid grid-scale instabilities.
    ds["hraw"] = _smooth_topography_globally(raw_depth, wet, smooth_factor)
    ds["hraw"].attrs = {
        "long_name": "Working bathymetry at rho-points",
        "source": f"Raw bathymetry from {topography_source} (smoothing diameter {smooth_factor})",
        "units": "meter",
    }

    # Enclosed basins (every wet region but the largest) become land.
    ds["mask_rho"] = xr.DataArray(
        _fill_enclosed_basins(wet.values), dims=("eta_rho", "xi_rho")
    )
    ds["mask_rho"].attrs = {
        "long_name": "Mask at rho-points",
        "units": "land/water (0/1)",
    }

    ds = _add_velocity_masks(ds)

    # Local smoothing of the masked depth to satisfy r < rmax.
    ds["h"] = _smooth_topography_locally(ds["hraw"] * ds["mask_rho"], hmin, rmax)
    ds["h"].attrs = {
        "long_name": "Final bathymetry at rho-points",
        "units": "meter",
    }

    return _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _make_raw_topography(lon, lat, topography_source) -> np.ndarray:
    """
    Fetch a topography dataset and interpolate it onto (lon, lat) points.

    Parameters
    ----------
    lon, lat : np.ndarray
        Longitudes and latitudes of the target points (same shape).
    topography_source : str
        Name of the topography dataset, passed to ``fetch_topo``. Only
        ``"ETOPO5"`` is currently handled here.

    Returns
    -------
    np.ndarray
        The negated ``topo`` field of the source, linearly interpolated onto
        the target points (presumably ``topo`` is elevation, so the result is
        depth positive downward; confirm against the dataset).

    Raises
    ------
    ValueError
        If ``topography_source`` is not supported by this function.
    """
    topo_ds = fetch_topo(topography_source)

    # The following depends on the topography source. Fail loudly instead of
    # falling through to an undefined interpolator.
    if topography_source != "ETOPO5":
        raise ValueError(
            f"Unsupported topography source '{topography_source}'. "
            "Currently only 'ETOPO5' is supported."
        )

    topo_lon = topo_ds["topo_lon"].copy()
    # Map negative longitudes into [0, 360).
    topo_lon = xr.where(topo_lon < 0, topo_lon + 360, topo_lon)
    topo_lon_minus360 = topo_lon - 360
    topo_lon_plus360 = topo_lon + 360
    # Concatenate three shifted copies along longitude so that target points
    # given in either a [-180, 180) or [0, 360) convention fall inside the
    # interpolation range (assumes topo_lon is ascending; TODO confirm).
    topo_lon_concatenated = xr.concat(
        [topo_lon_minus360, topo_lon, topo_lon_plus360], dim="lon"
    )
    topo_concatenated = xr.concat(
        [-topo_ds["topo"], -topo_ds["topo"], -topo_ds["topo"]], dim="lon"
    )

    interp = RegularGridInterpolator(
        (topo_ds["topo_lat"].values, topo_lon_concatenated.values),
        topo_concatenated.values,
        method="linear",
    )

    # Interpolate onto desired domain grid points
    hraw = interp((lat, lon))

    return hraw
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _smooth_topography_globally(hraw, wet_mask, factor) -> xr.DataArray:
    """
    Smooth topography domain-wide with a land-aware Gaussian filter.

    Parameters
    ----------
    hraw : xr.DataArray
        Raw topography with dims ``eta_rho`` and ``xi_rho``.
    wet_mask : xr.DataArray
        Wet mask on the same grid (1 = water, 0 = land).
    factor : float
        Filter scale passed to gcm_filters as ``filter_scale`` with
        ``dx_min=1`` (i.e. measured in grid cells, assuming unit spacing).

    Returns
    -------
    xr.DataArray
        Smoothed topography on the original (un-extended) grid.
    """
    # Since GCM-Filters assumes a periodic domain, extend the domain by one grid
    # cell in each dimension and set that margin to land.
    margin_mask = xr.concat([wet_mask, 0 * wet_mask.isel(eta_rho=-1)], dim="eta_rho")
    margin_mask = xr.concat(
        [margin_mask, 0 * margin_mask.isel(xi_rho=-1)], dim="xi_rho"
    )

    # Gaussian kernel with standard deviation factor/sqrt(12); this matches the
    # standard deviation of a boxcar kernel with total width equal to factor.
    # (Named topo_filter to avoid shadowing the builtin ``filter``.)
    topo_filter = gcm_filters.Filter(
        filter_scale=factor,
        dx_min=1,
        filter_shape=gcm_filters.FilterShape.GAUSSIAN,
        grid_type=gcm_filters.GridType.REGULAR_WITH_LAND,
        grid_vars={"wet_mask": margin_mask},
    )

    # Pad the topography to the extended shape; the padded values sit on
    # margin cells that the mask marks as land.
    hraw_extended = xr.concat([hraw, hraw.isel(eta_rho=-1)], dim="eta_rho")
    hraw_extended = xr.concat(
        [hraw_extended, hraw_extended.isel(xi_rho=-1)], dim="xi_rho"
    )

    hsmooth = topo_filter.apply(hraw_extended, dims=["eta_rho", "xi_rho"])
    # Strip the margin again.
    hsmooth = hsmooth.isel(eta_rho=slice(None, -1), xi_rho=slice(None, -1))

    return hsmooth
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _fill_enclosed_basins(mask) -> np.ndarray:
|
|
118
|
+
"""
|
|
119
|
+
Fills in enclosed basins with land
|
|
120
|
+
"""
|
|
121
|
+
|
|
122
|
+
# Label connected regions in the mask
|
|
123
|
+
reg, nreg = label(mask)
|
|
124
|
+
# Find the largest region
|
|
125
|
+
lint = 0
|
|
126
|
+
lreg = 0
|
|
127
|
+
for ireg in range(nreg):
|
|
128
|
+
int_ = np.sum(reg == ireg)
|
|
129
|
+
if int_ > lint and mask[reg == ireg].sum() > 0:
|
|
130
|
+
lreg = ireg
|
|
131
|
+
lint = int_
|
|
132
|
+
|
|
133
|
+
# Remove regions other than the largest one
|
|
134
|
+
for ireg in range(nreg):
|
|
135
|
+
if ireg != lreg:
|
|
136
|
+
mask[reg == ireg] = 0
|
|
137
|
+
|
|
138
|
+
return mask
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _smooth_topography_locally(h, hmin=5, rmax=0.2):
    """
    Smooth topography locally until the slope parameter satisfies r < rmax.

    Smoothing acts on ``log(h / hmin)``. In each iteration, the part of every
    neighbouring log-depth difference (eta, xi and both diagonals) that exceeds
    ``rmax_log`` is redistributed to the interior points. Axis-aligned
    differences get weight 1/6 and diagonal ones 1/24. Boundary rows/columns
    then copy their interior neighbours (no normal gradient), and depths are
    floored at ``hmin``. Iteration stops once the maximum r-factor from
    ``_compute_rfactor`` in both directions is below ``rmax``.

    Parameters
    ----------
    h : xr.DataArray
        Topography with dims ``eta_rho`` and ``xi_rho``. The interior update
        uses positional indexing, so dims are assumed to be ordered
        (eta_rho, xi_rho); TODO confirm for all callers.
    hmin : float, default 5
        Minimum depth. Must be positive, because it is used as the reference
        value of the logarithm.
    rmax : float, default 0.2
        Target maximum slope parameter. Must be positive.

    Returns
    -------
    xr.DataArray
        Smoothed topography, floored at ``hmin``.

    Raises
    ------
    ValueError
        If ``hmin`` or ``rmax`` is not positive (or is NaN). Such values
        previously made the iteration loop forever.
    """
    # ``not x > 0`` also rejects NaN.
    if not hmin > 0:
        raise ValueError(f"hmin must be positive, got {hmin}.")
    if not rmax > 0:
        raise ValueError(f"rmax must be positive, got {rmax}.")

    # For depths h1, h2 with r = |h1 - h2| / (h1 + h2):
    # |log(h1 / h2)| = log((1 + r) / (1 - r)).
    # rmax_log is that log-ratio for r = 0.9 * rmax, a 10% margin below rmax.
    if rmax * 0.9 < 1.0:
        rmax_log = np.log((1.0 + rmax * 0.9) / (1.0 - rmax * 0.9))
    else:
        # The formula is undefined here. Since r < 1 for positive depths, no
        # neighbour pair needs relaxing.
        rmax_log = np.inf

    # Apply hmin threshold
    h = xr.where(h < hmin, hmin, h)

    # We will smooth logarithmically
    h_log = np.log(h / hmin)

    cf1 = 1.0 / 6
    cf2 = 0.25

    for _ in count():
        # Compute gradients in domain interior

        # in eta-direction
        cff = h_log.diff("eta_rho").isel(xi_rho=slice(1, -1))
        cr = np.abs(cff)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore division by zero warning
            Op1 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))

        # in xi-direction
        cff = h_log.diff("xi_rho").isel(eta_rho=slice(1, -1))
        cr = np.abs(cff)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore division by zero warning
            Op2 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))

        # in diagonal direction
        cff = (h_log - h_log.shift(eta_rho=1, xi_rho=1)).isel(
            eta_rho=slice(1, None), xi_rho=slice(1, None)
        )
        cr = np.abs(cff)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore division by zero warning
            Op3 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))

        # in the other diagonal direction
        cff = (h_log.shift(eta_rho=1) - h_log.shift(xi_rho=1)).isel(
            eta_rho=slice(1, None), xi_rho=slice(1, None)
        )
        cr = np.abs(cff)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore division by zero warning
            Op4 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))

        # Update h_log in domain interior (positional: assumes (eta, xi) order)
        h_log[1:-1, 1:-1] += cf1 * (
            Op1[1:, :]
            - Op1[:-1, :]
            + Op2[:, 1:]
            - Op2[:, :-1]
            + cf2 * (Op3[1:, 1:] - Op3[:-1, :-1] + Op4[:-1, 1:] - Op4[1:, :-1])
        )

        # No gradient at the domain boundaries
        h_log[0, :] = h_log[1, :]
        h_log[-1, :] = h_log[-2, :]
        h_log[:, 0] = h_log[:, 1]
        h_log[:, -1] = h_log[:, -2]

        # Update h
        h = hmin * np.exp(h_log)
        # Apply hmin threshold again
        h = xr.where(h < hmin, hmin, h)

        # compute maximum slope parameter r
        r_eta, r_xi = _compute_rfactor(h)
        rmax0 = np.max([r_eta.max(), r_xi.max()])
        if rmax0 < rmax:
            break

    return h
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _compute_rfactor(h):
    """
    Compute the slope parameter (r-factor) in both horizontal grid directions.

    For neighbouring points i-1 and i along a direction,
    ``r_{i-1/2} = |h_i - h_{i-1}| / (h_i + h_{i-1})``, i.e. |Delta h| divided
    by twice the mean depth of the pair.

    Parameters
    ----------
    h : xr.DataArray
        Topography with dims ``eta_rho`` and ``xi_rho``.

    Returns
    -------
    r_eta : xr.DataArray
        r-factor between eta-neighbours (one point shorter along ``eta_rho``).
    r_xi : xr.DataArray
        r-factor between xi-neighbours (one point shorter along ``xi_rho``).
    """
    # compute r_{i-1/2} = |h_i - h_{i-1}| / (h_i + h_{i-1})
    r_eta = np.abs(h.diff("eta_rho")) / (h + h.shift(eta_rho=1)).isel(
        eta_rho=slice(1, None)
    )
    r_xi = np.abs(h.diff("xi_rho")) / (h + h.shift(xi_rho=1)).isel(
        xi_rho=slice(1, None)
    )

    return r_eta, r_xi
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
|
|
240
|
+
ds.attrs["topography_source"] = topography_source
|
|
241
|
+
ds.attrs["smooth_factor"] = smooth_factor
|
|
242
|
+
ds.attrs["hmin"] = hmin
|
|
243
|
+
ds.attrs["rmax"] = rmax
|
|
244
|
+
|
|
245
|
+
return ds
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _add_velocity_masks(ds):
    """
    Derive u- and v-point masks from ``ds["mask_rho"]`` and store them in ``ds``.

    The "multiplicative" interpolation multiplies the two neighbouring rho
    values, so with a 0/1 mask a u/v point is wet only when both adjacent
    rho-points are wet.

    Parameters
    ----------
    ds : xr.Dataset
        Grid dataset containing ``mask_rho``; variables are assigned into it.

    Returns
    -------
    xr.Dataset
        The same dataset with ``mask_u`` and ``mask_v`` added.
    """
    rho_mask = ds["mask_rho"]
    ds["mask_u"] = interpolate_from_rho_to_u(rho_mask, method="multiplicative")
    ds["mask_v"] = interpolate_from_rho_to_v(rho_mask, method="multiplicative")

    for var_name, long_name in (
        ("mask_u", "Mask at u-points"),
        ("mask_v", "Mask at v-points"),
    ):
        ds[var_name].attrs = {"long_name": long_name, "units": "land/water (0/1)"}

    return ds
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def nan_check(field, mask) -> None:
    """
    Checks for NaN values at wet points in the field.

    This function examines the interpolated input field for NaN values at positions indicated as wet points by the mask.
    If any NaN values are found at these wet points, a ValueError is raised. NaNs at land points are ignored.

    Parameters
    ----------
    field : array-like
        The data array to be checked for NaN values. This is typically an xarray.DataArray or numpy array.

    mask : array-like
        A boolean mask or data array with the same shape as `field`. The wet points (usually ocean points)
        are indicated by `1` or `True`, and land points by `0` or `False`.

    Raises
    ------
    ValueError
        If the field contains NaN values at any of the wet points indicated by the mask.
        The error message will explain the potential cause and suggest ensuring the dataset's coverage.

    """

    # Replace values in field with 0 where mask is not 1, so land NaNs are ignored.
    da = xr.where(mask == 1, field, 0)

    # xr.where returns a plain (e.g. numpy) array when neither input is an
    # xarray object; wrap it so ``isnull`` also works for array inputs.
    if not hasattr(da, "isnull"):
        da = xr.DataArray(da)

    # Check if any NaN values exist in the modified field
    if da.isnull().any().values:
        raise ValueError(
            "NaN values found in interpolated field. This likely occurs because the ROMS grid, including "
            "a small safety margin for interpolation, is not fully contained within the dataset's longitude/latitude range. Please ensure that the "
            "dataset covers the entire area required by the ROMS grid."
        )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def interpolate_from_rho_to_u(field, method="additive"):
    """
    Interpolates the given field from rho points to u points.

    This function performs an interpolation from the rho grid (cell centers) to the u grid
    (cell edges in the xi direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the xi dimension. The rho-point coordinates ``lat_rho`` / ``lon_rho`` are dropped
    (they do not describe u-point locations), and the ``xi_rho`` dimension is renamed to ``xi_u``.

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "xi_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the u grid with the dimension "xi_u"
        (one point shorter than "xi_rho").

    Raises
    ------
    NotImplementedError
        If ``method`` is not 'additive' or 'multiplicative'.
    """

    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(xi_rho=1)).isel(
            xi_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(xi_rho=1)).isel(xi_rho=slice(1, None))
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # drop_vars returns a new object, so the result must be reassigned;
    # errors="ignore" skips coordinates that are not present.
    field_interpolated = field_interpolated.drop_vars(
        ["lat_rho", "lon_rho"], errors="ignore"
    )

    field_interpolated = field_interpolated.swap_dims({"xi_rho": "xi_u"})

    return field_interpolated
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def interpolate_from_rho_to_v(field, method="additive"):
    """
    Interpolates the given field from rho points to v points.

    This function performs an interpolation from the rho grid (cell centers) to the v grid
    (cell edges in the eta direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the eta dimension. The rho-point coordinates ``lat_rho`` / ``lon_rho`` are dropped
    (they do not describe v-point locations), and the ``eta_rho`` dimension is renamed to ``eta_v``.

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "eta_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the v grid with the dimension "eta_v"
        (one point shorter than "eta_rho").

    Raises
    ------
    NotImplementedError
        If ``method`` is not 'additive' or 'multiplicative'.
    """

    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # drop_vars returns a new object, so the result must be reassigned;
    # errors="ignore" skips coordinates that are not present.
    field_interpolated = field_interpolated.drop_vars(
        ["lat_rho", "lon_rho"], errors="ignore"
    )

    field_interpolated = field_interpolated.swap_dims({"eta_rho": "eta_v"})

    return field_interpolated
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def extrapolate_deepest_to_bottom(field: xr.DataArray, dim: str) -> xr.DataArray:
    """
    Fill NaNs along ``dim`` with the nearest valid value, extrapolating past the ends.

    ``interpolate_na`` with ``method="nearest"`` and ``fill_value="extrapolate"``
    fills interior gaps with the nearest valid neighbour. It also fills NaNs
    beyond the first/last valid value with that value. If ``dim`` runs toward
    the bottom, the deepest valid value is therefore carried down to the bottom.

    Parameters
    ----------
    field : xr.DataArray
        Data containing NaNs to fill; must have a dimension named ``dim``.
    dim : str
        Dimension along which to fill, typically a vertical one such as
        'depth' or 's_rho'.

    Returns
    -------
    xr.DataArray
        A new, filled array; ``field`` itself is not modified.
    """
    return field.interpolate_na(
        dim=dim,
        method="nearest",
        fill_value="extrapolate",
    )
|