roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,494 @@
1
+ import numpy as np
2
+ import xarray as xr
3
+ import yaml
4
+ import importlib.metadata
5
+ from dataclasses import dataclass, field, asdict
6
+ from roms_tools.setup.grid import Grid
7
+ from roms_tools.setup.utils import (
8
+ interpolate_from_rho_to_u,
9
+ interpolate_from_rho_to_v,
10
+ )
11
+ from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
12
+ import matplotlib.pyplot as plt
13
+
14
+
15
@dataclass(frozen=True, kw_only=True)
class VerticalCoordinate:
    """
    Represents vertical coordinate for ROMS.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information.
    N : int
        The number of vertical levels.
    theta_s : float
        The surface control parameter. Must satisfy 0 < theta_s <= 10.
    theta_b : float
        The bottom control parameter. Must satisfy 0 < theta_b <= 4.
    hc : float
        The critical depth.

    Attributes
    ----------
    ds : xr.Dataset
        Xarray Dataset containing the vertical coordinate information: the
        stretching parameters (theta_s, theta_b, Tcline, hc), the S-coordinate
        values and stretching curves at rho-points (sc_r, Cs_r), and the layer
        and interface depths at rho-, u-, and v-points (stored as coordinates).
    """

    grid: Grid
    N: int
    theta_s: float
    theta_b: float
    hc: float

    ds: xr.Dataset = field(init=False, repr=False)

    @staticmethod
    def _get_roms_tools_version() -> str:
        """Return the installed roms-tools version, or "unknown" if it is not installed."""
        try:
            return importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            return "unknown"

    def __post_init__(self):
        """Compute the vertical coordinate dataset and store it in ``self.ds``."""
        h = self.grid.ds.h

        # Depths are computed with a zero free surface; ``h * 0`` gives a zero
        # field that shares the dims/coords of ``h``.
        cs_r, sigma_r = sigma_stretch(self.theta_s, self.theta_b, self.N, "r")
        zr = compute_depth(h * 0, h, self.hc, cs_r, sigma_r)
        cs_w, sigma_w = sigma_stretch(self.theta_s, self.theta_b, self.N, "w")
        zw = compute_depth(h * 0, h, self.hc, cs_w, sigma_w)

        ds = xr.Dataset()

        ds["theta_s"] = np.float32(self.theta_s)
        ds["theta_s"].attrs["long_name"] = "S-coordinate surface control parameter"
        ds["theta_s"].attrs["units"] = "nondimensional"

        ds["theta_b"] = np.float32(self.theta_b)
        ds["theta_b"].attrs["long_name"] = "S-coordinate bottom control parameter"
        ds["theta_b"].attrs["units"] = "nondimensional"

        # Tcline is written with the same value as hc.
        ds["Tcline"] = np.float32(self.hc)
        ds["Tcline"].attrs["long_name"] = "S-coordinate surface/bottom layer width"
        ds["Tcline"].attrs["units"] = "m"

        ds["hc"] = np.float32(self.hc)
        ds["hc"].attrs["long_name"] = "S-coordinate parameter critical depth"
        ds["hc"].attrs["units"] = "m"

        ds["sc_r"] = sigma_r.astype(np.float32)
        ds["sc_r"].attrs["long_name"] = "S-coordinate at rho-points"
        ds["sc_r"].attrs["units"] = "nondimensional"

        ds["Cs_r"] = cs_r.astype(np.float32)
        ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
        ds["Cs_r"].attrs["units"] = "nondimensional"

        # compute_depth returns negative-downward z; store positive depths.
        depth = -zr
        depth.attrs["long_name"] = "Layer depth at rho-points"
        depth.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_rho": depth.astype(np.float32)})

        depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
        depth_u.attrs["long_name"] = "Layer depth at u-points"
        depth_u.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_u": depth_u})

        depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
        depth_v.attrs["long_name"] = "Layer depth at v-points"
        depth_v.attrs["units"] = "m"
        ds = ds.assign_coords({"layer_depth_v": depth_v})

        depth = -zw
        depth.attrs["long_name"] = "Interface depth at rho-points"
        depth.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_rho": depth.astype(np.float32)})

        depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
        depth_u.attrs["long_name"] = "Interface depth at u-points"
        depth_u.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_u": depth_u})

        depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
        depth_v.attrs["long_name"] = "Interface depth at v-points"
        depth_v.attrs["units"] = "m"
        ds = ds.assign_coords({"interface_depth_v": depth_v})

        ds = ds.drop_vars(["eta_rho", "xi_rho"])

        ds.attrs["title"] = "ROMS vertical coordinate file created by ROMS-Tools"
        # Include the version of roms-tools
        ds.attrs["roms_tools_version"] = self._get_roms_tools_version()

        # The dataclass is frozen, so bypass __setattr__ to store the result.
        object.__setattr__(self, "ds", ds)

    def plot(
        self,
        varname="layer_depth_rho",
        s=None,
        eta=None,
        xi=None,
    ) -> None:
        """
        Plot the vertical coordinate system for a given eta-, xi-, or s-slice.

        Parameters
        ----------
        varname : str, optional
            The field to plot. Options are "layer_depth_rho", "layer_depth_u",
            "layer_depth_v", "interface_depth_rho", "interface_depth_u",
            "interface_depth_v". Default is "layer_depth_rho".
        s: int, optional
            The s-index to plot. Default is None.
        eta : int, optional
            The eta-index to plot. Default is None.
        xi : int, optional
            The xi-index to plot. Default is None.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        ValueError
            If the specified varname is not one of the valid options.
            If none of s, eta, xi are specified.
            If a requested slice dimension is not present in the field.
        """

        if not any([s is not None, eta is not None, xi is not None]):
            raise ValueError("At least one of s, eta, or xi must be specified.")

        valid_varnames = (
            "layer_depth_rho",
            "layer_depth_u",
            "layer_depth_v",
            "interface_depth_rho",
            "interface_depth_u",
            "interface_depth_v",
        )
        if varname not in valid_varnames:
            raise ValueError(
                f"Invalid varname '{varname}'. Valid options are: "
                f"{', '.join(valid_varnames)}."
            )

        self.ds[varname].load()
        data = self.ds[varname].squeeze()

        # Interface depths on the same horizontal grid (rho, u, or v) as the
        # plotted field; the grid location is the varname suffix. Used to draw
        # layer boundaries in section plots.
        location = varname.rsplit("_", 1)[-1]
        interface_depth = self.ds[f"interface_depth_{location}"]

        # slice the field as desired
        title = data.long_name
        if s is not None:
            if "s_rho" in data.dims:
                title = title + f", s_rho = {data.s_rho[s].item()}"
                data = data.isel(s_rho=s)
            elif "s_w" in data.dims:
                title = title + f", s_w = {data.s_w[s].item()}"
                data = data.isel(s_w=s)
            else:
                raise ValueError(
                    f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
                )

        if eta is not None:
            if "eta_rho" in data.dims:
                title = title + f", eta_rho = {data.eta_rho[eta].item()}"
                data = data.isel(eta_rho=eta)
                interface_depth = interface_depth.isel(eta_rho=eta)
            elif "eta_v" in data.dims:
                title = title + f", eta_v = {data.eta_v[eta].item()}"
                data = data.isel(eta_v=eta)
                interface_depth = interface_depth.isel(eta_v=eta)
            else:
                raise ValueError(
                    f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
                )
        if xi is not None:
            if "xi_rho" in data.dims:
                title = title + f", xi_rho = {data.xi_rho[xi].item()}"
                data = data.isel(xi_rho=xi)
                interface_depth = interface_depth.isel(xi_rho=xi)
            elif "xi_u" in data.dims:
                title = title + f", xi_u = {data.xi_u[xi].item()}"
                data = data.isel(xi_u=xi)
                interface_depth = interface_depth.isel(xi_u=xi)
            else:
                raise ValueError(
                    f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
                )

        if eta is None and xi is None:
            # Horizontal map at a single s-level.
            # NOTE(review): uses self.grid, which is None for instances built
            # via from_file; this branch would then fail.
            vmax = data.max().values
            vmin = data.min().values
            cmap = plt.colormaps.get_cmap("YlGnBu")
            cmap.set_bad(color="gray")
            kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

            _plot(
                self.grid.ds,
                field=data,
                straddle=self.grid.straddle,
                depth_contours=True,
                title=title,
                kwargs=kwargs,
                c="g",
            )
        else:
            if len(data.dims) == 2:
                # Vertical section: plot a zero field so only the layer
                # interfaces are visible (no colorbar).
                cmap = plt.colormaps.get_cmap("YlGnBu")
                cmap.set_bad(color="gray")
                kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}

                _section_plot(
                    xr.zeros_like(data),
                    interface_depth=interface_depth,
                    title=title,
                    kwargs=kwargs,
                )
            else:
                if "s_rho" in data.dims or "s_w" in data.dims:
                    _profile_plot(data, title=title)
                else:
                    _line_plot(data, title=title)

    def save(self, filepath: str) -> None:
        """
        Save the vertical coordinate information to a netCDF4 file.

        Parameters
        ----------
        filepath : str
            Path of the netCDF file to write ``self.ds`` to.
        """
        self.ds.to_netcdf(filepath)

    def to_yaml(self, filepath: str) -> None:
        """
        Export the parameters of the class to a YAML file, including the version of roms-tools.

        Parameters
        ----------
        filepath : str
            The path to the YAML file where the parameters will be saved.
        """
        # Serialize Grid data
        grid_data = asdict(self.grid)
        grid_data.pop("ds", None)  # Exclude non-serializable fields
        grid_data.pop("straddle", None)

        # Header document carrying the roms-tools version
        header = f"---\nroms_tools_version: {self._get_roms_tools_version()}\n---\n"

        # Combine the Grid and VerticalCoordinate sections into one document
        yaml_data = {
            "Grid": grid_data,
            "VerticalCoordinate": {
                "N": self.N,
                "theta_s": self.theta_s,
                "theta_b": self.theta_b,
                "hc": self.hc,
            },
        }

        with open(filepath, "w") as file:
            # Write header
            file.write(header)
            # Write YAML data
            yaml.dump(yaml_data, file, default_flow_style=False)

    @classmethod
    def from_file(cls, filepath: str) -> "VerticalCoordinate":
        """
        Create a VerticalCoordinate instance from an existing file.

        The resulting instance has ``grid`` set to None, because the file does
        not contain the grid information.

        Parameters
        ----------
        filepath : str
            Path to the file containing the vertical coordinate information.

        Returns
        -------
        VerticalCoordinate
            A new instance of VerticalCoordinate populated with data from the file.
        """
        # Load the dataset from the file
        ds = xr.open_dataset(filepath)

        # Create a new VerticalCoordinate instance without calling __init__ and __post_init__
        vertical_coordinate = cls.__new__(cls)

        # Set the dataset for the vertical_coordinate instance
        object.__setattr__(vertical_coordinate, "ds", ds)

        # Manually set the remaining attributes by extracting parameters from dataset
        object.__setattr__(vertical_coordinate, "N", ds.sizes["s_rho"])
        object.__setattr__(vertical_coordinate, "theta_s", ds["theta_s"].values.item())
        object.__setattr__(vertical_coordinate, "theta_b", ds["theta_b"].values.item())
        object.__setattr__(vertical_coordinate, "hc", ds["hc"].values.item())
        object.__setattr__(vertical_coordinate, "grid", None)

        return vertical_coordinate

    @classmethod
    def from_yaml(cls, filepath: str) -> "VerticalCoordinate":
        """
        Create an instance of the VerticalCoordinate class from a YAML file.

        Parameters
        ----------
        filepath : str
            The path to the YAML file from which the parameters will be read.

        Returns
        -------
        VerticalCoordinate
            An instance of the VerticalCoordinate class.

        Raises
        ------
        ValueError
            If no VerticalCoordinate section is found in the YAML file.
        """
        # Read the entire file content
        with open(filepath, "r") as file:
            file_content = file.read()

        # Split the content into YAML documents
        documents = list(yaml.safe_load_all(file_content))

        vertical_coordinate_data = None

        # Use the first document that contains a VerticalCoordinate section
        for doc in documents:
            if doc is None:
                continue
            if "VerticalCoordinate" in doc:
                vertical_coordinate_data = doc["VerticalCoordinate"]
                break

        if vertical_coordinate_data is None:
            raise ValueError(
                "No VerticalCoordinate configuration found in the YAML file."
            )

        # Create Grid instance from the YAML file
        grid = Grid.from_yaml(filepath)

        # Create and return an instance of VerticalCoordinate
        return cls(
            grid=grid,
            **vertical_coordinate_data,
        )
380
+
381
+
382
def compute_cs(sigma, theta_s, theta_b):
    """
    Compute the S-coordinate stretching curves according to Shchepetkin and McWilliams (2009).

    A hyperbolic-cosine surface refinement controlled by ``theta_s`` is
    applied first. Its result is then passed through an exponential bottom
    refinement controlled by ``theta_b``.

    Parameters
    ----------
    sigma : np.ndarray or float
        The sigma-coordinate values.
    theta_s : float
        The surface control parameter, required to satisfy 0 < theta_s <= 10.
    theta_b : float
        The bottom control parameter, required to satisfy 0 < theta_b <= 4.

    Returns
    -------
    C : np.ndarray or float
        The stretching curve values, with the same shape as ``sigma``.

    Raises
    ------
    ValueError
        If theta_s or theta_b are not within the valid range.
    """
    # Guard clauses: the formulas below divide by cosh(theta_s) - 1 and by
    # 1 - exp(-theta_b), both of which vanish when the parameter is zero.
    if not 0 < theta_s <= 10:
        raise ValueError("theta_s must be between 0 and 10.")
    if not 0 < theta_b <= 4:
        raise ValueError("theta_b must be between 0 and 4.")

    surface_curve = (1 - np.cosh(theta_s * sigma)) / (np.cosh(theta_s) - 1)
    bottom_refined = (np.exp(theta_b * surface_curve) - 1) / (1 - np.exp(-theta_b))
    return bottom_refined
414
+
415
+
416
def sigma_stretch(theta_s, theta_b, N, type):
    """
    Compute sigma and stretching curves based on the type and parameters.

    Parameters
    ----------
    theta_s : float
        The surface control parameter.
    theta_b : float
        The bottom control parameter.
    N : int
        The number of vertical levels. Must be at least 1.
    type : str
        The type of sigma ('w' for vertical velocity points, 'r' for rho-points).
        (The parameter name shadows the builtin ``type`` but is kept for
        backward compatibility with keyword callers.)

    Returns
    -------
    cs : xr.DataArray
        The stretching curve values.
    sigma : xr.DataArray
        The sigma-coordinate values. For 'w' these are N + 1 values along the
        ``s_w`` dimension, running from -1 to 0. For 'r' these are N values
        along the ``s_rho`` dimension, at the midpoints between consecutive
        'w' values.

    Raises
    ------
    ValueError
        If N is smaller than 1, or if the type is not 'w' or 'r'.
        Also raised by ``compute_cs`` if theta_s or theta_b are out of range.
    """
    # N is used as a divisor below; N == 0 would yield inf/nan levels.
    if N < 1:
        raise ValueError("N must be a positive integer.")

    if type == "w":
        k = xr.DataArray(np.arange(N + 1), dims="s_w")
        sigma = (k - N) / N
    elif type == "r":
        k = xr.DataArray(np.arange(1, N + 1), dims="s_rho")
        # Shift by half a level so rho-points sit between w-points.
        sigma = (k - N - 0.5) / N
    else:
        raise ValueError(
            "Type must be either 'w' for vertical velocity points or 'r' for rho-points."
        )

    cs = compute_cs(sigma, theta_s, theta_b)

    return cs, sigma
457
+
458
+
459
def compute_depth(zeta, h, hc, cs, sigma):
    """
    Compute the depth at different sigma levels.

    Uses the S-coordinate transform

        s = (hc * sigma + h * cs) / (hc + h)
        z = zeta + (zeta + h) * s

    so that z == zeta where sigma = cs = 0. Where sigma = cs = -1, which is
    what ``compute_cs`` returns at sigma = -1, z == -h.

    Parameters
    ----------
    zeta : xr.DataArray
        The sea surface height.
    h : xr.DataArray
        The depth of the sea bottom (positive downward). It must carry
        ``eta_rho`` and ``xi_rho`` coordinates.
    hc : float
        The critical depth.
    cs : xr.DataArray
        The stretching curve values.
    sigma : xr.DataArray
        The sigma-coordinate values.

    Returns
    -------
    z : xr.DataArray
        The vertical position of the sigma levels (negative below the
        surface), broadcast over the horizontal ``eta_rho``/``xi_rho``
        dimensions of ``h``.
    """

    # Expand the 1D vertical profiles over the horizontal rho-grid of h
    sigma = sigma.expand_dims(dim={"eta_rho": h.eta_rho, "xi_rho": h.xi_rho})
    cs = cs.expand_dims(dim={"eta_rho": h.eta_rho, "xi_rho": h.xi_rho})

    s = (hc * sigma + h * cs) / (hc + h)
    z = zeta + (zeta + h) * s

    return z