roms-tools 0.1.0__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ import gcm_filters
4
4
  from scipy.interpolate import RegularGridInterpolator
5
5
  from scipy.ndimage import label
6
6
  from roms_tools.setup.datasets import fetch_topo
7
+ from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
7
8
  import warnings
8
9
  from itertools import count
9
10
 
@@ -19,7 +20,7 @@ def _add_topography_and_mask(
19
20
  hraw = xr.DataArray(data=hraw, dims=["eta_rho", "xi_rho"])
20
21
 
21
22
  # Mask is obtained by finding locations where ocean depth is positive
22
- mask = xr.where(hraw > 0, 1, 0)
23
+ mask = xr.where(hraw > 0, 1.0, 0.0)
23
24
 
24
25
  # smooth topography domain-wide with Gaussian kernel to avoid grid scale instabilities
25
26
  ds["hraw"] = _smooth_topography_globally(hraw, mask, smooth_factor)
@@ -37,6 +38,8 @@ def _add_topography_and_mask(
37
38
  "units": "land/water (0/1)",
38
39
  }
39
40
 
41
+ ds = _add_velocity_masks(ds)
42
+
40
43
  # smooth topography locally to satisfy r < rmax
41
44
  ds["h"] = _smooth_topography_locally(ds["hraw"] * ds["mask_rho"], hmin, rmax)
42
45
  ds["h"].attrs = {
@@ -57,7 +60,7 @@ def _make_raw_topography(lon, lat, topography_source) -> np.ndarray:
57
60
  topo_ds = fetch_topo(topography_source)
58
61
 
59
62
  # the following will depend on the topography source
60
- if topography_source == "etopo5":
63
+ if topography_source == "ETOPO5":
61
64
  topo_lon = topo_ds["topo_lon"].copy()
62
65
  # Modify longitude values where necessary
63
66
  topo_lon = xr.where(topo_lon < 0, topo_lon + 360, topo_lon)
@@ -240,3 +243,15 @@ def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
240
243
  ds.attrs["rmax"] = rmax
241
244
 
242
245
  return ds
246
+
247
+
248
def _add_velocity_masks(ds):
    """Attach land/water masks at u- and v-points, derived from ``mask_rho``.

    Multiplicative interpolation marks a velocity point as water (1) only
    when both adjacent rho points are water, which is the appropriate rule
    for a binary mask.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing a ``mask_rho`` variable.

    Returns
    -------
    xr.Dataset
        The same dataset with ``mask_u`` and ``mask_v`` added.
    """
    rho_mask = ds["mask_rho"]

    mask_u = interpolate_from_rho_to_u(rho_mask, method="multiplicative")
    mask_u.attrs = {"long_name": "Mask at u-points", "units": "land/water (0/1)"}
    ds["mask_u"] = mask_u

    mask_v = interpolate_from_rho_to_v(rho_mask, method="multiplicative")
    mask_v.attrs = {"long_name": "Mask at v-points", "units": "land/water (0/1)"}
    ds["mask_v"] = mask_v

    return ds
@@ -0,0 +1,162 @@
1
+ import xarray as xr
2
+
3
+
4
def nan_check(field, mask) -> None:
    """
    Raise if the interpolated field contains NaNs at wet points.

    Land points (mask != 1) are zeroed out first, so only NaNs located at
    wet (ocean) points trigger the error.

    Parameters
    ----------
    field : array-like
        The data array to be checked for NaN values, typically an
        xarray.DataArray or numpy array.
    mask : array-like
        Boolean/binary mask with the same shape as `field`; wet points are
        `1`/`True`, land points `0`/`False`.

    Raises
    ------
    ValueError
        If any wet point of `field` is NaN. This usually means the source
        dataset does not fully cover the ROMS grid.
    """

    # Zero out land points so that only wet-point NaNs remain detectable.
    wet_only = xr.where(mask == 1, field, 0)

    has_nan = bool(wet_only.isnull().any().values)
    if has_nan:
        raise ValueError(
            "NaN values found in interpolated field. This likely occurs because the ROMS grid, including "
            "a small safety margin for interpolation, is not fully contained within the dataset's longitude/latitude range. Please ensure that the "
            "dataset covers the entire area required by the ROMS grid."
        )
+ )
38
+
39
+
40
def interpolate_from_rho_to_u(field, method="additive"):
    """
    Interpolates the given field from rho points to u points.

    This function performs an interpolation from the rho grid (cell centers) to the u grid
    (cell edges in the xi direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the xi dimension. It also removes the rho-point coordinate variables
    (lat_rho/lon_rho) and renames the dimension to "xi_u".

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "xi_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the u grid with the dimension "xi_u".

    Raises
    ------
    NotImplementedError
        If `method` is neither 'additive' nor 'multiplicative'.
    """

    shifted = field.shift(xi_rho=1)
    if method == "additive":
        field_interpolated = 0.5 * (field + shifted).isel(xi_rho=slice(1, None))
    elif method == "multiplicative":
        field_interpolated = (field * shifted).isel(xi_rho=slice(1, None))
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # BUG FIX: drop_vars returns a new DataArray; the original code discarded
    # the result, so lat_rho/lon_rho were never actually removed.
    stale_coords = [c for c in ("lat_rho", "lon_rho") if c in field_interpolated.coords]
    if stale_coords:
        field_interpolated = field_interpolated.drop_vars(stale_coords)

    field_interpolated = field_interpolated.swap_dims({"xi_rho": "xi_u"})

    return field_interpolated
86
+
87
+
88
def interpolate_from_rho_to_v(field, method="additive"):
    """
    Interpolates the given field from rho points to v points.

    This function performs an interpolation from the rho grid (cell centers) to the v grid
    (cell edges in the eta direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the eta dimension. It also removes the rho-point coordinate variables
    (lat_rho/lon_rho) and renames the dimension to "eta_v".

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "eta_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the v grid with the dimension "eta_v".

    Raises
    ------
    NotImplementedError
        If `method` is neither 'additive' nor 'multiplicative'.
    """

    shifted = field.shift(eta_rho=1)
    if method == "additive":
        field_interpolated = 0.5 * (field + shifted).isel(eta_rho=slice(1, None))
    elif method == "multiplicative":
        field_interpolated = (field * shifted).isel(eta_rho=slice(1, None))
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # BUG FIX: drop_vars returns a new DataArray; the original code discarded
    # the result, so lat_rho/lon_rho were never actually removed.
    stale_coords = [c for c in ("lat_rho", "lon_rho") if c in field_interpolated.coords]
    if stale_coords:
        field_interpolated = field_interpolated.drop_vars(stale_coords)

    field_interpolated = field_interpolated.swap_dims({"eta_rho": "eta_v"})

    return field_interpolated
136
+
137
+
138
def extrapolate_deepest_to_bottom(field: xr.DataArray, dim: str) -> xr.DataArray:
    """
    Fill NaNs along `dim` by nearest-neighbor interpolation with extrapolation.

    In practice this propagates the deepest valid value down to the bottom of
    a vertical column, leaving the original array unmodified.

    Parameters
    ----------
    field : xr.DataArray
        The input data array containing NaN values to be filled. It must have
        a dimension named by `dim`.
    dim : str
        The dimension along which to interpolate/extrapolate, typically a
        vertical one such as 'depth' or 's_rho'.

    Returns
    -------
    xr.DataArray
        A new array with NaNs along `dim` replaced via nearest-neighbor
        interpolation and bottom extrapolation.
    """
    return field.interpolate_na(dim=dim, method="nearest", fill_value="extrapolate")
@@ -0,0 +1,494 @@
1
+ import numpy as np
2
+ import xarray as xr
3
+ import yaml
4
+ import importlib.metadata
5
+ from dataclasses import dataclass, field, asdict
6
+ from roms_tools.setup.grid import Grid
7
+ from roms_tools.setup.utils import (
8
+ interpolate_from_rho_to_u,
9
+ interpolate_from_rho_to_v,
10
+ )
11
+ from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
12
+ import matplotlib.pyplot as plt
13
+
14
+
15
@dataclass(frozen=True, kw_only=True)
class VerticalCoordinate:
    """
    Represents vertical coordinate for ROMS.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information.
    N : int
        The number of vertical levels.
    theta_s : float
        The surface control parameter. Must satisfy 0 < theta_s <= 10.
    theta_b : float
        The bottom control parameter. Must satisfy 0 < theta_b <= 4.
    hc : float
        The critical depth.

    Attributes
    ----------
    ds : xr.Dataset
        Xarray Dataset containing the vertical coordinate data: stretching
        parameters plus layer/interface depth coordinates at rho-, u-, and
        v-points. Populated in ``__post_init__``.
    """

    grid: Grid
    N: int
    theta_s: float
    theta_b: float
    hc: float

    # Derived dataset; built in __post_init__ (frozen dataclass, so it is
    # assigned via object.__setattr__).
    ds: xr.Dataset = field(init=False, repr=False)

    def __post_init__(self):

        h = self.grid.ds.h

        # Depths are computed for a flat free surface (zeta = h * 0) at both
        # layer centers ("r") and layer interfaces ("w").
        cs_r, sigma_r = sigma_stretch(self.theta_s, self.theta_b, self.N, "r")
        zr = compute_depth(h * 0, h, self.hc, cs_r, sigma_r)
        cs_w, sigma_w = sigma_stretch(self.theta_s, self.theta_b, self.N, "w")
        zw = compute_depth(h * 0, h, self.hc, cs_w, sigma_w)

        ds = xr.Dataset()

        ds["theta_s"] = np.float32(self.theta_s)
        ds["theta_s"].attrs["long_name"] = "S-coordinate surface control parameter"
        ds["theta_s"].attrs["units"] = "nondimensional"

        ds["theta_b"] = np.float32(self.theta_b)
        ds["theta_b"].attrs["long_name"] = "S-coordinate bottom control parameter"
        ds["theta_b"].attrs["units"] = "nondimensional"

        # Tcline and hc are both set to the critical depth value.
        ds["Tcline"] = np.float32(self.hc)
        ds["Tcline"].attrs["long_name"] = "S-coordinate surface/bottom layer width"
        ds["Tcline"].attrs["units"] = "m"

        ds["hc"] = np.float32(self.hc)
        ds["hc"].attrs["long_name"] = "S-coordinate parameter critical depth"
        ds["hc"].attrs["units"] = "m"

        # NOTE(review): only the rho-point sigma/stretching curves (sc_r, Cs_r)
        # are stored; the w-point curves are used for depths but not saved.
        ds["sc_r"] = sigma_r.astype(np.float32)
        ds["sc_r"].attrs["long_name"] = "S-coordinate at rho-points"
        ds["sc_r"].attrs["units"] = "nondimensional"

        ds["Cs_r"] = cs_r.astype(np.float32)
        ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
        ds["Cs_r"].attrs["units"] = "nondimensional"

        # Depths are stored as positive-down values (-z) at layer centers.
        depth = -zr
        depth.attrs["long_name"] = "Layer depth at rho-points"
        depth.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_rho": depth.astype(np.float32)})

        depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
        depth_u.attrs["long_name"] = "Layer depth at u-points"
        depth_u.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_u": depth_u})

        depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
        depth_v.attrs["long_name"] = "Layer depth at v-points"
        depth_v.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_v": depth_v})

        # Same treatment for layer interfaces.
        depth = -zw
        depth.attrs["long_name"] = "Interface depth at rho-points"
        depth.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_rho": depth.astype(np.float32)})

        depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
        depth_u.attrs["long_name"] = "Interface depth at u-points"
        depth_u.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_u": depth_u})

        depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
        depth_v.attrs["long_name"] = "Interface depth at v-points"
        depth_v.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_v": depth_v})

        ds = ds.drop_vars(["eta_rho", "xi_rho"])

        ds.attrs["title"] = "ROMS vertical coordinate file created by ROMS-Tools"
        # Include the version of roms-tools
        try:
            roms_tools_version = importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            roms_tools_version = "unknown"
        ds.attrs["roms_tools_version"] = roms_tools_version

        object.__setattr__(self, "ds", ds)

    def plot(
        self,
        varname="layer_depth_rho",
        s=None,
        eta=None,
        xi=None,
    ) -> None:
        """
        Plot the vertical coordinate system for a given eta-, xi-, or s-slice.

        Parameters
        ----------
        varname : str, optional
            The field to plot. Options are the depth coordinates created in
            ``__post_init__``: "layer_depth_rho", "layer_depth_u",
            "layer_depth_v", "interface_depth_rho", "interface_depth_u",
            "interface_depth_v".
        s: int, optional
            The s-index to plot. Default is None.
        eta : int, optional
            The eta-index to plot. Default is None.
        xi : int, optional
            The xi-index to plot. Default is None.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        ValueError
            If the specified varname is not one of the valid options.
            If none of s, eta, xi are specified.
        """

        if not any([s is not None, eta is not None, xi is not None]):
            raise ValueError("At least one of s, eta, or xi must be specified.")

        self.ds[varname].load()
        field = self.ds[varname].squeeze()

        # Pick the interface depth on the same horizontal staggering as field.
        # NOTE(review): if field matches none of these dim combinations,
        # interface_depth stays unbound and later use raises NameError —
        # confirm all supported varnames hit one of these branches.
        if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
            interface_depth = self.ds.interface_depth_rho
        elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
            interface_depth = self.ds.interface_depth_u
        elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
            interface_depth = self.ds.interface_depth_v

        # slice the field as desired
        title = field.long_name
        if s is not None:
            if "s_rho" in field.dims:
                title = title + f", s_rho = {field.s_rho[s].item()}"
                field = field.isel(s_rho=s)
            elif "s_w" in field.dims:
                title = title + f", s_w = {field.s_w[s].item()}"
                field = field.isel(s_w=s)
            else:
                raise ValueError(
                    f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
                )

        if eta is not None:
            if "eta_rho" in field.dims:
                title = title + f", eta_rho = {field.eta_rho[eta].item()}"
                field = field.isel(eta_rho=eta)
                interface_depth = interface_depth.isel(eta_rho=eta)
            elif "eta_v" in field.dims:
                title = title + f", eta_v = {field.eta_v[eta].item()}"
                field = field.isel(eta_v=eta)
                interface_depth = interface_depth.isel(eta_v=eta)
            else:
                raise ValueError(
                    f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
                )
        if xi is not None:
            if "xi_rho" in field.dims:
                title = title + f", xi_rho = {field.xi_rho[xi].item()}"
                field = field.isel(xi_rho=xi)
                interface_depth = interface_depth.isel(xi_rho=xi)
            elif "xi_u" in field.dims:
                title = title + f", xi_u = {field.xi_u[xi].item()}"
                field = field.isel(xi_u=xi)
                interface_depth = interface_depth.isel(xi_u=xi)
            else:
                raise ValueError(
                    f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
                )

        if eta is None and xi is None:
            # Horizontal (map-view) plot of a single s-level.
            vmax = field.max().values
            vmin = field.min().values
            cmap = plt.colormaps.get_cmap("YlGnBu")
            cmap.set_bad(color="gray")
            kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

            _plot(
                self.grid.ds,
                field=field,
                straddle=self.grid.straddle,
                depth_contours=True,
                title=title,
                kwargs=kwargs,
                c="g",
            )
        else:
            if len(field.dims) == 2:
                # Vertical section: show only the layer interfaces (the field
                # itself is zeroed out and the colorbar suppressed).
                cmap = plt.colormaps.get_cmap("YlGnBu")
                cmap.set_bad(color="gray")
                kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}

                _section_plot(
                    xr.zeros_like(field),
                    interface_depth=interface_depth,
                    title=title,
                    kwargs=kwargs,
                )
            else:
                # 1-D data: vertical profile if an s-dimension remains,
                # otherwise a simple line plot.
                if "s_rho" in field.dims or "s_w" in field.dims:
                    _profile_plot(field, title=title)
                else:
                    _line_plot(field, title=title)

    def save(self, filepath: str) -> None:
        """
        Save the vertical coordinate information to a netCDF4 file.

        Parameters
        ----------
        filepath : str
            The path to the netCDF4 file where the dataset will be written.
        """
        self.ds.to_netcdf(filepath)

    def to_yaml(self, filepath: str) -> None:
        """
        Export the parameters of the class to a YAML file, including the version of roms-tools.

        Parameters
        ----------
        filepath : str
            The path to the YAML file where the parameters will be saved.
        """
        # Serialize Grid data
        grid_data = asdict(self.grid)
        grid_data.pop("ds", None)  # Exclude non-serializable fields
        grid_data.pop("straddle", None)

        # Include the version of roms-tools
        try:
            roms_tools_version = importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            roms_tools_version = "unknown"

        # Create header
        header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"

        grid_yaml_data = {"Grid": grid_data}

        # Combine all sections
        vertical_coordinate_data = {
            "VerticalCoordinate": {
                "N": self.N,
                "theta_s": self.theta_s,
                "theta_b": self.theta_b,
                "hc": self.hc,
            }
        }

        # Merge YAML data while excluding empty sections
        yaml_data = {
            **grid_yaml_data,
            **vertical_coordinate_data,
        }

        with open(filepath, "w") as file:
            # Write header
            file.write(header)
            # Write YAML data
            yaml.dump(yaml_data, file, default_flow_style=False)

    @classmethod
    def from_file(cls, filepath: str) -> "VerticalCoordinate":
        """
        Create a VerticalCoordinate instance from an existing file.

        Parameters
        ----------
        filepath : str
            Path to the file containing the vertical coordinate information.

        Returns
        -------
        VerticalCoordinate
            A new instance of VerticalCoordinate populated with data from the file.
            Note that the ``grid`` attribute is set to None, since the grid
            cannot be reconstructed from the vertical coordinate file alone.
        """
        # Load the dataset from the file
        ds = xr.open_dataset(filepath)

        # Create a new VerticalCoordinate instance without calling __init__ and __post_init__
        vertical_coordinate = cls.__new__(cls)

        # Set the dataset for the vertical_coordinate instance
        object.__setattr__(vertical_coordinate, "ds", ds)

        # Manually set the remaining attributes by extracting parameters from dataset
        object.__setattr__(vertical_coordinate, "N", ds.sizes["s_rho"])
        object.__setattr__(vertical_coordinate, "theta_s", ds["theta_s"].values.item())
        object.__setattr__(vertical_coordinate, "theta_b", ds["theta_b"].values.item())
        object.__setattr__(vertical_coordinate, "hc", ds["hc"].values.item())
        object.__setattr__(vertical_coordinate, "grid", None)

        return vertical_coordinate

    @classmethod
    def from_yaml(cls, filepath: str) -> "VerticalCoordinate":
        """
        Create an instance of the VerticalCoordinate class from a YAML file.

        Parameters
        ----------
        filepath : str
            The path to the YAML file from which the parameters will be read.

        Returns
        -------
        VerticalCoordinate
            An instance of the VerticalCoordinate class.

        Raises
        ------
        ValueError
            If the YAML file contains no VerticalCoordinate section.
        """
        # Read the entire file content
        with open(filepath, "r") as file:
            file_content = file.read()

        # Split the content into YAML documents
        documents = list(yaml.safe_load_all(file_content))

        vertical_coordinate_data = None

        # Process the YAML documents
        for doc in documents:
            if doc is None:
                continue
            if "VerticalCoordinate" in doc:
                vertical_coordinate_data = doc["VerticalCoordinate"]
                break

        if vertical_coordinate_data is None:
            raise ValueError(
                "No VerticalCoordinate configuration found in the YAML file."
            )

        # Create Grid instance from the YAML file
        grid = Grid.from_yaml(filepath)

        # Create and return an instance of VerticalCoordinate
        return cls(
            grid=grid,
            **vertical_coordinate_data,
        )
380
+
381
+
382
def compute_cs(sigma, theta_s, theta_b):
    """
    Compute the S-coordinate stretching curves according to Shchepetkin and McWilliams (2009).

    Parameters
    ----------
    sigma : np.ndarray or float
        The sigma-coordinate values.
    theta_s : float
        The surface control parameter; must satisfy 0 < theta_s <= 10.
    theta_b : float
        The bottom control parameter; must satisfy 0 < theta_b <= 4.

    Returns
    -------
    np.ndarray or float
        The stretching curve values.

    Raises
    ------
    ValueError
        If theta_s or theta_b are not within the valid range.
    """
    if not 0 < theta_s <= 10:
        raise ValueError("theta_s must be between 0 and 10.")
    if not 0 < theta_b <= 4:
        raise ValueError("theta_b must be between 0 and 4.")

    # Surface refinement, then bottom refinement applied on top of it.
    surface_curve = (1 - np.cosh(theta_s * sigma)) / (np.cosh(theta_s) - 1)
    return (np.exp(theta_b * surface_curve) - 1) / (1 - np.exp(-theta_b))
414
+
415
+
416
def sigma_stretch(theta_s, theta_b, N, type):
    """
    Compute sigma and stretching curves based on the type and parameters.

    Parameters
    ----------
    theta_s : float
        The surface control parameter.
    theta_b : float
        The bottom control parameter.
    N : int
        The number of vertical levels.
    type : str
        The type of sigma ('w' for vertical velocity points, 'r' for rho-points).

    Returns
    -------
    cs : xr.DataArray
        The stretching curve values.
    sigma : xr.DataArray
        The sigma-coordinate values.

    Raises
    ------
    ValueError
        If the type is not 'w' or 'r'.
    """
    # Interfaces ('w') span [-1, 0] at N+1 levels; centers ('r') sit halfway
    # between interfaces at N levels.
    if type == "r":
        level = xr.DataArray(np.arange(1, N + 1), dims="s_rho")
        sigma = (level - N - 0.5) / N
    elif type == "w":
        level = xr.DataArray(np.arange(N + 1), dims="s_w")
        sigma = (level - N) / N
    else:
        raise ValueError(
            "Type must be either 'w' for vertical velocity points or 'r' for rho-points."
        )

    return compute_cs(sigma, theta_s, theta_b), sigma
457
+
458
+
459
def compute_depth(zeta, h, hc, cs, sigma):
    """
    Compute the depth at different sigma levels.

    Parameters
    ----------
    zeta : xr.DataArray
        The sea surface height.
    h : xr.DataArray
        The depth of the sea bottom.
    hc : float
        The critical depth.
    cs : xr.DataArray
        The stretching curve values.
    sigma : xr.DataArray
        The sigma-coordinate values.

    Returns
    -------
    z : xr.DataArray
        The depth at different sigma levels.
    """

    # Broadcast the 1-D vertical profiles over the horizontal grid of h.
    horizontal = {"eta_rho": h.eta_rho, "xi_rho": h.xi_rho}
    sigma3d = sigma.expand_dims(dim=horizontal)
    cs3d = cs.expand_dims(dim=horizontal)

    # Standard ROMS transform: blend sigma and stretching curve by hc,
    # then scale by the free-surface-adjusted water column.
    blended = (hc * sigma3d + h * cs3d) / (hc + h)
    return zeta + (zeta + h) * blended