roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,257 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+ import gcm_filters
4
+ from scipy.interpolate import RegularGridInterpolator
5
+ from scipy.ndimage import label
6
+ from roms_tools.setup.datasets import fetch_topo
7
+ from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
8
+ import warnings
9
+ from itertools import count
10
+
11
+
12
def _add_topography_and_mask(
    ds, topography_source, smooth_factor, hmin, rmax
) -> xr.Dataset:
    """
    Add bathymetry and land/sea masks to a grid dataset.

    Steps, in order:
    1. interpolate ``topography_source`` onto the points given by ``ds.lon_rho`` /
       ``ds.lat_rho``;
    2. derive a wet mask (positive depth) and smooth the raw topography domain-wide
       with a land-aware Gaussian filter -> ``hraw``;
    3. remove enclosed basins from the wet mask -> ``mask_rho``, then derive the
       u- and v-point masks;
    4. smooth the masked ``hraw`` locally so that the slope parameter r < rmax -> ``h``;
    5. record the settings as global attributes.

    Parameters
    ----------
    ds : xr.Dataset
        Grid dataset providing ``lon_rho`` and ``lat_rho``.
    topography_source : str
        Name of the topography dataset to interpolate from.
    smooth_factor : float
        Scale of the domain-wide Gaussian smoothing.
    hmin : float
        Minimum depth enforced by the local smoothing.
    rmax : float
        Target maximum slope parameter for the local smoothing.

    Returns
    -------
    xr.Dataset
        ``ds`` with ``hraw``, ``mask_rho``, ``mask_u``, ``mask_v`` and ``h`` added
        (``ds`` itself is updated via item assignment and also returned).
    """
    # Depth interpolated onto the rho-point grid.
    raw = xr.DataArray(
        data=_make_raw_topography(ds.lon_rho.values, ds.lat_rho.values, topography_source),
        dims=["eta_rho", "xi_rho"],
    )

    # Ocean points are exactly those with positive depth.
    wet = xr.where(raw > 0, 1.0, 0.0)

    # Domain-wide Gaussian smoothing suppresses grid-scale features.
    ds["hraw"] = _smooth_topography_globally(raw, wet, smooth_factor)
    ds["hraw"].attrs = {
        "long_name": "Working bathymetry at rho-points",
        "source": f"Raw bathymetry from {topography_source} (smoothing diameter {smooth_factor})",
        "units": "meter",
    }

    # Only the main connected ocean survives; enclosed basins become land.
    ds["mask_rho"] = xr.DataArray(
        _fill_enclosed_basins(wet.values), dims=("eta_rho", "xi_rho")
    )
    ds["mask_rho"].attrs = {
        "long_name": "Mask at rho-points",
        "units": "land/water (0/1)",
    }

    ds = _add_velocity_masks(ds)

    # Local smoothing of the masked bathymetry until r < rmax.
    ds["h"] = _smooth_topography_locally(ds["hraw"] * ds["mask_rho"], hmin, rmax)
    ds["h"].attrs = {
        "long_name": "Final bathymetry at rho-points",
        "units": "meter",
    }

    return _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax)
53
+
54
+
55
def _make_raw_topography(lon, lat, topography_source) -> np.ndarray:
    """
    Fetch a topography dataset and interpolate its depths onto the given (lon, lat) points.

    Parameters
    ----------
    lon, lat : np.ndarray
        Longitudes and latitudes (degrees) of the target points.
    topography_source : str
        Name of the topography dataset. Only ``"ETOPO5"`` is currently supported.

    Returns
    -------
    np.ndarray
        Depth at each target point, obtained by linear interpolation of the negated
        source topography (so ocean points are positive).

    Raises
    ------
    ValueError
        If ``topography_source`` is not supported.
    """
    # Validate before downloading anything: for any other source no interpolator would
    # be built below (previously this surfaced as an opaque UnboundLocalError).
    if topography_source != "ETOPO5":
        raise ValueError(
            f"Unsupported topography_source '{topography_source}'. "
            "Supported sources: ETOPO5."
        )

    topo_ds = fetch_topo(topography_source)

    # Shift negative longitudes by +360, then tile copies of the grid at -360 and +360
    # so that target longitudes on either side of the dataset's seam stay inside the
    # interpolation grid (RegularGridInterpolator rejects out-of-bounds points by default).
    raw_lon = topo_ds["topo_lon"]
    topo_lon = xr.where(raw_lon < 0, raw_lon + 360, raw_lon)
    topo_lon_tiled = xr.concat([topo_lon - 360, topo_lon, topo_lon + 360], dim="lon")

    # Negate the source values so that depth is positive over the ocean.
    depth = -topo_ds["topo"]
    depth_tiled = xr.concat([depth, depth, depth], dim="lon")

    interp = RegularGridInterpolator(
        (topo_ds["topo_lat"].values, topo_lon_tiled.values),
        depth_tiled.values,
        method="linear",
    )

    # Interpolate onto the desired grid points.
    return interp((lat, lon))
87
+
88
+
89
def _smooth_topography_globally(hraw, wet_mask, factor) -> xr.DataArray:
    """
    Smooth raw topography with a land-aware Gaussian filter (gcm_filters).

    Parameters
    ----------
    hraw : xr.DataArray
        Raw topography with dims ``eta_rho`` and ``xi_rho``.
    wet_mask : xr.DataArray
        Wet mask on the same grid (1 = water, 0 = land).
    factor : float
        ``filter_scale`` passed to gcm_filters. The chosen Gaussian has standard
        deviation factor/sqrt(12), the same as a boxcar kernel of total width ``factor``.

    Returns
    -------
    xr.DataArray
        Smoothed topography on the original (unpadded) grid.
    """
    # gcm_filters treats the domain as periodic. Append one extra row and column, with
    # land in the mask, so the wrap-around sees a land barrier instead of the opposite
    # edge. The topography is padded by repeating its last row/column.
    padded_mask = wet_mask
    padded_h = hraw
    for dim in ("eta_rho", "xi_rho"):
        padded_mask = xr.concat([padded_mask, 0 * padded_mask.isel({dim: -1})], dim=dim)
        padded_h = xr.concat([padded_h, padded_h.isel({dim: -1})], dim=dim)

    gaussian = gcm_filters.Filter(
        filter_scale=factor,
        dx_min=1,
        filter_shape=gcm_filters.FilterShape.GAUSSIAN,
        grid_type=gcm_filters.GridType.REGULAR_WITH_LAND,
        grid_vars={"wet_mask": padded_mask},
    )
    smoothed = gaussian.apply(padded_h, dims=["eta_rho", "xi_rho"])

    # Strip the padding row/column again.
    return smoothed.isel(eta_rho=slice(None, -1), xi_rho=slice(None, -1))
115
+
116
+
117
def _fill_enclosed_basins(mask) -> np.ndarray:
    """
    Fill enclosed basins with land, keeping only the largest connected wet region.

    Connectivity follows the default structuring element of ``scipy.ndimage.label``.

    Parameters
    ----------
    mask : array-like
        2D land/water mask (1 = water, 0 = land).

    Returns
    -------
    np.ndarray
        A copy of ``mask`` in which every connected wet region other than the largest
        one is set to 0. The input is not modified.
    """
    mask = np.array(mask, copy=True)

    # label() numbers the wet regions 1..nregions; label 0 is the land background.
    regions, nregions = label(mask)
    if nregions == 0:
        return mask

    # Cell count per label; zero out the background so it can never be selected.
    sizes = np.bincount(regions.ravel(), minlength=nregions + 1)
    sizes[0] = 0
    # argmax returns the first maximum, so on ties the lowest label wins.
    largest = np.argmax(sizes)

    mask[regions != largest] = 0
    return mask
139
+
140
+
141
def _smooth_topography_locally(h, hmin=5, rmax=0.2):
    """
    Smooth topography locally until the slope parameter satisfies r < rmax.

    The smoothing works on log(h / hmin). Wherever the log-depth jump between
    neighbouring points (along eta, along xi, and along both diagonals) exceeds a
    threshold derived from ``rmax``, the excess is redistributed through a
    flux-divergence update (weight 1/6, with an extra factor 1/4 on the diagonal terms).
    Boundary rows and columns get a zero-gradient condition. The iteration stops once
    the maximum r-factor (as computed by ``_compute_rfactor``) drops below ``rmax``.

    NOTE(review): the update uses positional indexing, so ``h`` is assumed to have
    dims ordered (eta_rho, xi_rho). Confirm this for all callers.

    Parameters
    ----------
    h : xr.DataArray
        Bathymetry on rho-points with dims ``eta_rho`` and ``xi_rho``.
    hmin : float, optional
        Minimum depth; shallower values are raised to ``hmin``. Must be positive.
    rmax : float, optional
        Target maximum slope parameter. Must satisfy ``0 < rmax < 1/0.9``.

    Returns
    -------
    xr.DataArray
        Smoothed bathymetry, with every value >= ``hmin``.

    Raises
    ------
    ValueError
        If ``hmin`` or ``rmax`` lies outside the range for which the iteration is
        well defined (these values previously caused an endless loop), or if the
        r-factor becomes undefined (all NaN).
    """
    # The ``not (... > ...)`` form also rejects NaN.
    if not hmin > 0:
        raise ValueError(f"hmin must be positive, got {hmin}.")
    if not 0.0 < rmax * 0.9 < 1.0:
        raise ValueError(f"rmax must satisfy 0 < rmax < 1/0.9, got {rmax}.")

    # For two depths with r = |h1 - h2| / (h1 + h2), the log-depth jump is
    # |log(h1 / h2)| = log((1 + r) / (1 - r)). The 0.9 factor keeps a safety margin
    # below rmax.
    rmax_log = np.log((1.0 + rmax * 0.9) / (1.0 - rmax * 0.9))

    # Apply hmin threshold
    h = xr.where(h < hmin, hmin, h)

    # We will smooth logarithmically
    h_log = np.log(h / hmin)

    cf1 = 1.0 / 6
    cf2 = 0.25

    def _excess(cff):
        # Part of each log-depth jump that exceeds rmax_log (0 where acceptable).
        cr = np.abs(cff)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore division by zero warning
            return xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))

    while True:
        # Excess jumps in the domain interior: eta-direction, xi-direction,
        # and the two diagonal directions.
        Op1 = _excess(h_log.diff("eta_rho").isel(xi_rho=slice(1, -1)))
        Op2 = _excess(h_log.diff("xi_rho").isel(eta_rho=slice(1, -1)))
        Op3 = _excess(
            (h_log - h_log.shift(eta_rho=1, xi_rho=1)).isel(
                eta_rho=slice(1, None), xi_rho=slice(1, None)
            )
        )
        Op4 = _excess(
            (h_log.shift(eta_rho=1) - h_log.shift(xi_rho=1)).isel(
                eta_rho=slice(1, None), xi_rho=slice(1, None)
            )
        )

        # Update h_log in domain interior
        h_log[1:-1, 1:-1] += cf1 * (
            Op1[1:, :]
            - Op1[:-1, :]
            + Op2[:, 1:]
            - Op2[:, :-1]
            + cf2 * (Op3[1:, 1:] - Op3[:-1, :-1] + Op4[:-1, 1:] - Op4[1:, :-1])
        )

        # No gradient at the domain boundaries
        h_log[0, :] = h_log[1, :]
        h_log[-1, :] = h_log[-2, :]
        h_log[:, 0] = h_log[:, 1]
        h_log[:, -1] = h_log[:, -2]

        # Update h and apply the hmin threshold again
        h = hmin * np.exp(h_log)
        h = xr.where(h < hmin, hmin, h)

        # Compute the maximum slope parameter r
        r_eta, r_xi = _compute_rfactor(h)
        rmax0 = np.max([r_eta.max(), r_xi.max()])
        if np.isnan(rmax0):
            # A NaN maximum can never satisfy the stopping criterion.
            raise ValueError(
                "Slope parameter r is undefined (NaN); check the input bathymetry for NaNs."
            )
        if rmax0 < rmax:
            break

    return h
222
+
223
+
224
def _compute_rfactor(h):
    """
    Compute the slope parameter (r-factor) along both horizontal grid directions.

    For neighbouring points i-1 and i the r-factor is

        r_{i-1/2} = |h_i - h_{i-1}| / (h_i + h_{i-1}),

    i.e. |Delta h| divided by twice the mean depth of the pair.

    Parameters
    ----------
    h : xr.DataArray
        Depth with dims ``eta_rho`` and ``xi_rho``.

    Returns
    -------
    tuple of xr.DataArray
        ``(r_eta, r_xi)``. Each is one point shorter than ``h`` along its direction.
    """

    def _pairwise_r(dim):
        # h_i + h_{i-1}, aligned with diff()'s "upper" labelling by dropping the first entry.
        pair_sum = (h + h.shift({dim: 1})).isel({dim: slice(1, None)})
        return np.abs(h.diff(dim)) / pair_sum

    return _pairwise_r("eta_rho"), _pairwise_r("xi_rho")
237
+
238
+
239
def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
    """
    Record the topography-generation settings as global attributes of ``ds``.

    The attributes ``topography_source``, ``smooth_factor``, ``hmin`` and ``rmax``
    are written (overwriting existing ones) in that order. ``ds`` is modified in
    place and also returned.
    """
    ds.attrs.update(
        topography_source=topography_source,
        smooth_factor=smooth_factor,
        hmin=hmin,
        rmax=rmax,
    )
    return ds
246
+
247
+
248
def _add_velocity_masks(ds):
    """
    Derive u- and v-point land/water masks from ``ds["mask_rho"]``.

    Each velocity mask is the product of the two adjacent rho-point masks
    ("multiplicative" interpolation), so a velocity point is wet only if both
    neighbouring rho-points are wet. ``ds`` gains ``mask_u`` and ``mask_v``
    and is returned.
    """
    velocity_masks = (
        ("mask_u", interpolate_from_rho_to_u, "Mask at u-points"),
        ("mask_v", interpolate_from_rho_to_v, "Mask at v-points"),
    )
    for var_name, to_velocity_grid, long_name in velocity_masks:
        ds[var_name] = to_velocity_grid(ds["mask_rho"], method="multiplicative")
        ds[var_name].attrs = {"long_name": long_name, "units": "land/water (0/1)"}

    return ds
@@ -0,0 +1,162 @@
1
+ import xarray as xr
2
+
3
+
4
def nan_check(field, mask) -> None:
    """
    Checks for NaN values at wet points in the field.

    This function examines the interpolated input field for NaN values at positions indicated as wet points by the mask.
    If any NaN values are found at these wet points, a ValueError is raised. NaNs at land points are ignored.

    Parameters
    ----------
    field : array-like
        The data array to be checked for NaN values. This is typically an xarray.DataArray or numpy array.

    mask : array-like
        A boolean mask or data array with the same shape as `field`. The wet points (usually ocean points)
        are indicated by `1` or `True`, and land points by `0` or `False`.

    Raises
    ------
    ValueError
        If the field contains NaN values at any of the wet points indicated by the mask.
        The error message will explain the potential cause and suggest ensuring the dataset's coverage.

    """

    # Replace values in field with 0 where mask is not 1, so land NaNs are ignored.
    wet_field = xr.where(mask == 1, field, 0)

    # xr.where returns a plain numpy array when neither input is an xarray object;
    # wrap it so the null check below also works for the numpy inputs documented above.
    if not isinstance(wet_field, (xr.DataArray, xr.Dataset)):
        wet_field = xr.DataArray(wet_field)

    # Check if any NaN values exist in the modified field
    if wet_field.isnull().any().values:
        raise ValueError(
            "NaN values found in interpolated field. This likely occurs because the ROMS grid, including "
            "a small safety margin for interpolation, is not fully contained within the dataset's longitude/latitude range. Please ensure that the "
            "dataset covers the entire area required by the ROMS grid."
        )
38
+
39
+
40
def interpolate_from_rho_to_u(field, method="additive"):
    """
    Interpolates the given field from rho points to u points.

    This function performs an interpolation from the rho grid (cell centers) to the u grid
    (cell edges in the xi direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the xi dimension. It also removes the rho-point coordinates ``lat_rho`` and
    ``lon_rho`` (when present), which no longer describe the new points, and renames
    the dimension accordingly.

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "xi_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the u grid with the dimension "xi_u".

    Raises
    ------
    NotImplementedError
        If ``method`` is not one of the supported options.
    """

    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(xi_rho=1)).isel(
            xi_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(xi_rho=1)).isel(xi_rho=slice(1, None))
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # drop_vars returns a new object, so the result must be reassigned (previously it was
    # discarded and the coordinates survived). errors="ignore" skips any absent coordinate.
    field_interpolated = field_interpolated.drop_vars(
        ["lat_rho", "lon_rho"], errors="ignore"
    )

    return field_interpolated.swap_dims({"xi_rho": "xi_u"})
86
+
87
+
88
def interpolate_from_rho_to_v(field, method="additive"):
    """
    Interpolates the given field from rho points to v points.

    This function performs an interpolation from the rho grid (cell centers) to the v grid
    (cell edges in the eta direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the eta dimension. It also removes the rho-point coordinates ``lat_rho`` and
    ``lon_rho`` (when present), which no longer describe the new points, and renames
    the dimension accordingly.

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "eta_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the v grid with the dimension "eta_v".

    Raises
    ------
    NotImplementedError
        If ``method`` is not one of the supported options.
    """

    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # drop_vars returns a new object, so the result must be reassigned (previously it was
    # discarded and the coordinates survived). errors="ignore" skips any absent coordinate.
    field_interpolated = field_interpolated.drop_vars(
        ["lat_rho", "lon_rho"], errors="ignore"
    )

    return field_interpolated.swap_dims({"eta_rho": "eta_v"})
136
+
137
+
138
def extrapolate_deepest_to_bottom(field: xr.DataArray, dim: str) -> xr.DataArray:
    """
    Extrapolate the deepest non-NaN values to the bottom along a specified dimension.

    Parameters
    ----------
    field : xr.DataArray
        Data array whose NaN values along ``dim`` should be filled. It must have a
        dimension named ``dim``.
    dim : str
        Dimension to fill along, typically a vertical one such as 'depth' or 's_rho'.

    Returns
    -------
    xr.DataArray
        A new array in which NaNs along ``dim`` are replaced by the nearest valid
        value; ``field`` itself is left unchanged.

    Notes
    -----
    This is a thin wrapper around ``DataArray.interpolate_na`` with
    ``method="nearest"`` and ``fill_value="extrapolate"``. Per the xarray
    documentation, the fill value is forwarded to SciPy's ``interp1d``, so interior
    gaps and NaNs at the start of ``dim`` are filled as well, not only those at the
    bottom end.
    """
    return field.interpolate_na(dim=dim, method="nearest", fill_value="extrapolate")