roms-tools 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,6 @@ import importlib.metadata
5
5
  from dataclasses import dataclass, field, asdict
6
6
  from typing import Optional, Dict, Union
7
7
  from roms_tools.setup.grid import Grid
8
- from roms_tools.setup.vertical_coordinate import VerticalCoordinate
9
8
  from datetime import datetime
10
9
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
11
10
  from roms_tools.setup.utils import (
@@ -25,8 +24,6 @@ class InitialConditions(ROMSToolsMixins):
25
24
  ----------
26
25
  grid : Grid
27
26
  Object representing the grid information used for the model.
28
- vertical_coordinate : VerticalCoordinate
29
- Object representing the vertical coordinate system.
30
27
  ini_time : datetime
31
28
  The date and time at which the initial conditions are set.
32
29
  physics_source : Dict[str, Union[str, None]]
@@ -51,7 +48,6 @@ class InitialConditions(ROMSToolsMixins):
51
48
  --------
52
49
  >>> initial_conditions = InitialConditions(
53
50
  ... grid=grid,
54
- ... vertical_coordinate=vertical_coordinate,
55
51
  ... ini_time=datetime(2022, 1, 1),
56
52
  ... physics_source={"name": "GLORYS", "path": "physics_data.nc"},
57
53
  ... bgc_source={
@@ -63,7 +59,6 @@ class InitialConditions(ROMSToolsMixins):
63
59
  """
64
60
 
65
61
  grid: Grid
66
- vertical_coordinate: VerticalCoordinate
67
62
  ini_time: datetime
68
63
  physics_source: Dict[str, Union[str, None]]
69
64
  bgc_source: Optional[Dict[str, Union[str, None]]] = None
@@ -202,9 +197,7 @@ class InitialConditions(ROMSToolsMixins):
202
197
 
203
198
  # initialize vertical velocity to zero
204
199
  ds["w"] = xr.zeros_like(
205
- self.vertical_coordinate.ds["interface_depth_rho"].expand_dims(
206
- time=data_vars["u"].time
207
- )
200
+ self.grid.ds["interface_depth_rho"].expand_dims(time=data_vars["u"].time)
208
201
  ).astype(np.float32)
209
202
  ds["w"].attrs["long_name"] = d_meta["w"]["long_name"]
210
203
  ds["w"].attrs["units"] = d_meta["w"]["units"]
@@ -215,10 +208,10 @@ class InitialConditions(ROMSToolsMixins):
215
208
 
216
209
  ds = ds.assign_coords(
217
210
  {
218
- "layer_depth_u": self.vertical_coordinate.ds["layer_depth_u"],
219
- "layer_depth_v": self.vertical_coordinate.ds["layer_depth_v"],
220
- "interface_depth_u": self.vertical_coordinate.ds["interface_depth_u"],
221
- "interface_depth_v": self.vertical_coordinate.ds["interface_depth_v"],
211
+ "layer_depth_u": self.grid.ds["layer_depth_u"],
212
+ "layer_depth_v": self.grid.ds["layer_depth_v"],
213
+ "interface_depth_u": self.grid.ds["interface_depth_u"],
214
+ "interface_depth_v": self.grid.ds["interface_depth_v"],
222
215
  }
223
216
  )
224
217
 
@@ -246,12 +239,11 @@ class InitialConditions(ROMSToolsMixins):
246
239
  ] = f"seconds since {np.datetime_as_string(model_reference_date, unit='s')}"
247
240
  ds["ocean_time"].attrs["units"] = "seconds"
248
241
 
249
- ds.attrs["theta_s"] = self.vertical_coordinate.ds["theta_s"].item()
250
- ds.attrs["theta_b"] = self.vertical_coordinate.ds["theta_b"].item()
251
- ds.attrs["Tcline"] = self.vertical_coordinate.ds["Tcline"].item()
252
- ds.attrs["hc"] = self.vertical_coordinate.ds["hc"].item()
253
- ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
254
- ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
242
+ ds.attrs["theta_s"] = self.grid.ds.attrs["theta_s"]
243
+ ds.attrs["theta_b"] = self.grid.ds.attrs["theta_b"]
244
+ ds.attrs["hc"] = self.grid.ds.attrs["hc"]
245
+ ds["sc_r"] = self.grid.ds["sc_r"]
246
+ ds["Cs_r"] = self.grid.ds["Cs_r"]
255
247
 
256
248
  ds = ds.drop_vars(["s_rho"])
257
249
 
@@ -462,11 +454,6 @@ class InitialConditions(ROMSToolsMixins):
462
454
  grid_data.pop("ds", None) # Exclude non-serializable fields
463
455
  grid_data.pop("straddle", None)
464
456
 
465
- # Serialize VerticalCoordinate data
466
- vertical_coordinate_data = asdict(self.vertical_coordinate)
467
- vertical_coordinate_data.pop("ds", None) # Exclude non-serializable fields
468
- vertical_coordinate_data.pop("grid", None) # Exclude non-serializable fields
469
-
470
457
  # Include the version of roms-tools
471
458
  try:
472
459
  roms_tools_version = importlib.metadata.version("roms-tools")
@@ -477,7 +464,6 @@ class InitialConditions(ROMSToolsMixins):
477
464
  header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
478
465
 
479
466
  grid_yaml_data = {"Grid": grid_data}
480
- vertical_coordinate_yaml_data = {"VerticalCoordinate": vertical_coordinate_data}
481
467
 
482
468
  initial_conditions_data = {
483
469
  "InitialConditions": {
@@ -492,7 +478,6 @@ class InitialConditions(ROMSToolsMixins):
492
478
 
493
479
  yaml_data = {
494
480
  **grid_yaml_data,
495
- **vertical_coordinate_yaml_data,
496
481
  **initial_conditions_data,
497
482
  }
498
483
 
@@ -545,13 +530,10 @@ class InitialConditions(ROMSToolsMixins):
545
530
  initial_conditions_data[date_string]
546
531
  )
547
532
 
548
- # Create VerticalCoordinate instance from the YAML file
549
- vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
550
- grid = vertical_coordinate.grid
533
+ grid = Grid.from_yaml(filepath)
551
534
 
552
535
  # Create and return an instance of InitialConditions
553
536
  return cls(
554
537
  grid=grid,
555
- vertical_coordinate=vertical_coordinate,
556
538
  **initial_conditions_data,
557
539
  )
@@ -1,6 +1,5 @@
1
1
  from dataclasses import dataclass
2
2
  from roms_tools.setup.grid import Grid
3
- from roms_tools.setup.vertical_coordinate import VerticalCoordinate
4
3
  from roms_tools.setup.fill import fill_and_interpolate
5
4
  from roms_tools.setup.utils import (
6
5
  extrapolate_deepest_to_bottom,
@@ -21,13 +20,10 @@ class ROMSToolsMixins:
21
20
  ----------
22
21
  grid : Grid
23
22
  Object representing the grid information used for the model.
24
- vertical_coordinate : VerticalCoordinate
25
- Object representing the vertical coordinate system. Defaults to None.
26
23
 
27
24
  """
28
25
 
29
26
  grid: Grid
30
- vertical_coordinate: VerticalCoordinate = None
31
27
 
32
28
  def get_target_lon_lat(self, use_coarse_grid=False):
33
29
  """
@@ -54,11 +50,6 @@ class ROMSToolsMixins:
54
50
  """
55
51
 
56
52
  if use_coarse_grid:
57
- if "lon_coarse" not in self.grid.ds:
58
- raise ValueError(
59
- "Grid has not been coarsened yet. Execute grid.coarsen() first."
60
- )
61
-
62
53
  lon = self.grid.ds.lon_coarse
63
54
  lat = self.grid.ds.lat_coarse
64
55
  angle = self.grid.ds.angle_coarse
@@ -134,7 +125,7 @@ class ROMSToolsMixins:
134
125
  coords = {
135
126
  data.dim_names["latitude"]: lat,
136
127
  data.dim_names["longitude"]: lon,
137
- data.dim_names["depth"]: self.vertical_coordinate.ds["layer_depth_rho"],
128
+ data.dim_names["depth"]: self.grid.ds["layer_depth_rho"],
138
129
  }
139
130
  # extrapolate deepest value all the way to bottom ("flooding")
140
131
  for var in vars_3d:
@@ -167,8 +158,7 @@ class ROMSToolsMixins:
167
158
  This method performs the following steps:
168
159
  1. Rotates the velocity components to align with the grid orientation using the provided angle.
169
160
  2. Optionally interpolates the rotated velocities to the u- and v-points of the grid.
170
- 3. If a vertical coordinate is provided, computes the barotropic velocities by integrating
171
- over the vertical dimension.
161
+ 3. If the velocities are 3D (with vertical coordinates), computes barotropic (depth-averaged) velocities.
172
162
 
173
163
  Parameters
174
164
  ----------
@@ -183,8 +173,8 @@ class ROMSToolsMixins:
183
173
  Returns
184
174
  -------
185
175
  dict of str: xarray.DataArray
186
- Dictionary of processed velocity components. Includes "ubar" and "vbar" if a vertical coordinate
187
- is provided.
176
+ Dictionary of processed velocity components. Includes "ubar" and "vbar" if the velocity components
177
+ have vertical coordinates and are processed for barotropic (depth-averaged) velocities.
188
178
  """
189
179
  # Determine the correct variable names based on the keys in data_vars
190
180
  uname = "u" if "u" in data_vars else "uwnd"
@@ -202,7 +192,7 @@ class ROMSToolsMixins:
202
192
  data_vars[uname] = u_rot
203
193
  data_vars[vname] = v_rot
204
194
 
205
- if self.vertical_coordinate is not None:
195
+ if "s_rho" in data_vars[uname].dims and "s_rho" in data_vars[vname].dims:
206
196
  # 3D masks for ROMS domain
207
197
  umask = self.grid.ds.mask_u.expand_dims({"s_rho": data_vars[uname].s_rho})
208
198
  vmask = self.grid.ds.mask_v.expand_dims({"s_rho": data_vars[vname].s_rho})
@@ -211,7 +201,7 @@ class ROMSToolsMixins:
211
201
  data_vars[vname] = data_vars[vname] * vmask
212
202
 
213
203
  # Compute barotropic velocity
214
- dz = -self.vertical_coordinate.ds["interface_depth_rho"].diff(dim="s_w")
204
+ dz = -self.grid.ds["interface_depth_rho"].diff(dim="s_w")
215
205
  dz = dz.rename({"s_w": "s_rho"})
216
206
  dzu = interpolate_from_rho_to_u(dz)
217
207
  dzv = interpolate_from_rho_to_v(dz)
@@ -10,8 +10,42 @@ from itertools import count
10
10
 
11
11
 
12
12
  def _add_topography_and_mask(
13
- ds, topography_source, smooth_factor, hmin, rmax
13
+ ds, topography_source, hmin, smooth_factor=8.0, rmax=0.2
14
14
  ) -> xr.Dataset:
15
+ """
16
+ Adds topography and a land/water mask to the dataset based on the provided topography source.
17
+
18
+ This function performs the following operations:
19
+ 1. Interpolates topography data onto the desired grid.
20
+ 2. Applies a mask based on ocean depth.
21
+ 3. Smooths the topography globally to reduce grid-scale instabilities.
22
+ 4. Fills enclosed basins with land.
23
+ 5. Smooths the topography locally to ensure the steepness ratio satisfies the rmax criterion.
24
+ 6. Adds topography metadata.
25
+
26
+ Parameters
27
+ ----------
28
+ ds : xr.Dataset
29
+ The dataset to which topography and the land/water mask will be added.
30
+ topography_source : str
31
+ The source of the topography data.
32
+ hmin : float
33
+ The minimum allowable depth for the topography.
34
+ smooth_factor : float, optional
35
+ The smoothing factor used in the domain-wide Gaussian smoothing of the
36
+ topography. Smaller values result in less smoothing, while larger
37
+ values produce more smoothing. The default is 8.0.
38
+ rmax : float, optional
39
+ The maximum allowable steepness ratio for the topography smoothing.
40
+ This parameter controls the local smoothing of the topography. Smaller values result in
41
+ smoother topography, while larger values preserve more detail. The default is 0.2.
42
+
43
+ Returns
44
+ -------
45
+ xr.Dataset
46
+ The dataset with added topography, mask, and metadata.
47
+ """
48
+
15
49
  lon = ds.lon_rho.values
16
50
  lat = ds.lat_rho.values
17
51
 
@@ -23,16 +57,11 @@ def _add_topography_and_mask(
23
57
  mask = xr.where(hraw > 0, 1.0, 0.0)
24
58
 
25
59
  # smooth topography domain-wide with Gaussian kernel to avoid grid scale instabilities
26
- ds["hraw"] = _smooth_topography_globally(hraw, mask, smooth_factor)
27
- ds["hraw"].attrs = {
28
- "long_name": "Working bathymetry at rho-points",
29
- "source": f"Raw bathymetry from {topography_source} (smoothing diameter {smooth_factor})",
30
- "units": "meter",
31
- }
60
+ hraw = _smooth_topography_globally(hraw, mask, smooth_factor)
32
61
 
33
62
  # fill enclosed basins with land
34
63
  mask = _fill_enclosed_basins(mask.values)
35
- ds["mask_rho"] = xr.DataArray(mask, dims=("eta_rho", "xi_rho"))
64
+ ds["mask_rho"] = xr.DataArray(mask.astype(np.int32), dims=("eta_rho", "xi_rho"))
36
65
  ds["mask_rho"].attrs = {
37
66
  "long_name": "Mask at rho-points",
38
67
  "units": "land/water (0/1)",
@@ -41,7 +70,7 @@ def _add_topography_and_mask(
41
70
  ds = _add_velocity_masks(ds)
42
71
 
43
72
  # smooth topography locally to satisfy r < rmax
44
- ds["h"] = _smooth_topography_locally(ds["hraw"] * ds["mask_rho"], hmin, rmax)
73
+ ds["h"] = _smooth_topography_locally(hraw * ds["mask_rho"], hmin, rmax)
45
74
  ds["h"].attrs = {
46
75
  "long_name": "Final bathymetry at rho-points",
47
76
  "units": "meter",
@@ -238,9 +267,7 @@ def _compute_rfactor(h):
238
267
 
239
268
  def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
240
269
  ds.attrs["topography_source"] = topography_source
241
- ds.attrs["smooth_factor"] = smooth_factor
242
270
  ds.attrs["hmin"] = hmin
243
- ds.attrs["rmax"] = rmax
244
271
 
245
272
  return ds
246
273
 
@@ -248,8 +275,12 @@ def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
248
275
  def _add_velocity_masks(ds):
249
276
 
250
277
  # add u- and v-masks
251
- ds["mask_u"] = interpolate_from_rho_to_u(ds["mask_rho"], method="multiplicative")
252
- ds["mask_v"] = interpolate_from_rho_to_v(ds["mask_rho"], method="multiplicative")
278
+ ds["mask_u"] = interpolate_from_rho_to_u(
279
+ ds["mask_rho"], method="multiplicative"
280
+ ).astype(np.int32)
281
+ ds["mask_v"] = interpolate_from_rho_to_v(
282
+ ds["mask_rho"], method="multiplicative"
283
+ ).astype(np.int32)
253
284
 
254
285
  ds["mask_u"].attrs = {"long_name": "Mask at u-points", "units": "land/water (0/1)"}
255
286
  ds["mask_v"].attrs = {"long_name": "Mask at v-points", "units": "land/water (0/1)"}
@@ -1,382 +1,5 @@
1
1
  import numpy as np
2
2
  import xarray as xr
3
- import yaml
4
- import importlib.metadata
5
- from dataclasses import dataclass, field, asdict
6
- from roms_tools.setup.grid import Grid
7
- from roms_tools.setup.utils import (
8
- interpolate_from_rho_to_u,
9
- interpolate_from_rho_to_v,
10
- )
11
- from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
12
- import matplotlib.pyplot as plt
13
-
14
-
15
- @dataclass(frozen=True, kw_only=True)
16
- class VerticalCoordinate:
17
- """
18
- Represents vertical coordinate for ROMS.
19
-
20
- Parameters
21
- ----------
22
- grid : Grid
23
- Object representing the grid information.
24
- N : int
25
- The number of vertical levels.
26
- theta_s : float
27
- The surface control parameter. Must satisfy 0 < theta_s <= 10.
28
- theta_b : float
29
- The bottom control parameter. Must satisfy 0 < theta_b <= 4.
30
- hc : float
31
- The critical depth.
32
-
33
- Attributes
34
- ----------
35
- ds : xr.Dataset
36
- Xarray Dataset containing the atmospheric forcing data.
37
- """
38
-
39
- grid: Grid
40
- N: int
41
- theta_s: float
42
- theta_b: float
43
- hc: float
44
-
45
- ds: xr.Dataset = field(init=False, repr=False)
46
-
47
- def __post_init__(self):
48
-
49
- h = self.grid.ds.h
50
-
51
- cs_r, sigma_r = sigma_stretch(self.theta_s, self.theta_b, self.N, "r")
52
- zr = compute_depth(h * 0, h, self.hc, cs_r, sigma_r)
53
- cs_w, sigma_w = sigma_stretch(self.theta_s, self.theta_b, self.N, "w")
54
- zw = compute_depth(h * 0, h, self.hc, cs_w, sigma_w)
55
-
56
- ds = xr.Dataset()
57
-
58
- ds["theta_s"] = np.float32(self.theta_s)
59
- ds["theta_s"].attrs["long_name"] = "S-coordinate surface control parameter"
60
- ds["theta_s"].attrs["units"] = "nondimensional"
61
-
62
- ds["theta_b"] = np.float32(self.theta_b)
63
- ds["theta_b"].attrs["long_name"] = "S-coordinate bottom control parameter"
64
- ds["theta_b"].attrs["units"] = "nondimensional"
65
-
66
- ds["Tcline"] = np.float32(self.hc)
67
- ds["Tcline"].attrs["long_name"] = "S-coordinate surface/bottom layer width"
68
- ds["Tcline"].attrs["units"] = "m"
69
-
70
- ds["hc"] = np.float32(self.hc)
71
- ds["hc"].attrs["long_name"] = "S-coordinate parameter critical depth"
72
- ds["hc"].attrs["units"] = "m"
73
-
74
- ds["sc_r"] = sigma_r.astype(np.float32)
75
- ds["sc_r"].attrs["long_name"] = "S-coordinate at rho-points"
76
- ds["sc_r"].attrs["units"] = "nondimensional"
77
-
78
- ds["Cs_r"] = cs_r.astype(np.float32)
79
- ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
80
- ds["Cs_r"].attrs["units"] = "nondimensional"
81
-
82
- depth = -zr
83
- depth.attrs["long_name"] = "Layer depth at rho-points"
84
- depth.attrs["units"] = "m"
85
- ds = ds.assign_coords({"layer_depth_rho": depth.astype(np.float32)})
86
-
87
- depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
88
- depth_u.attrs["long_name"] = "Layer depth at u-points"
89
- depth_u.attrs["units"] = "m"
90
- ds = ds.assign_coords({"layer_depth_u": depth_u})
91
-
92
- depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
93
- depth_v.attrs["long_name"] = "Layer depth at v-points"
94
- depth_v.attrs["units"] = "m"
95
- ds = ds.assign_coords({"layer_depth_v": depth_v})
96
-
97
- depth = -zw
98
- depth.attrs["long_name"] = "Interface depth at rho-points"
99
- depth.attrs["units"] = "m"
100
- ds = ds.assign_coords({"interface_depth_rho": depth.astype(np.float32)})
101
-
102
- depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
103
- depth_u.attrs["long_name"] = "Interface depth at u-points"
104
- depth_u.attrs["units"] = "m"
105
- ds = ds.assign_coords({"interface_depth_u": depth_u})
106
-
107
- depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
108
- depth_v.attrs["long_name"] = "Interface depth at v-points"
109
- depth_v.attrs["units"] = "m"
110
- ds = ds.assign_coords({"interface_depth_v": depth_v})
111
-
112
- ds = ds.drop_vars(["eta_rho", "xi_rho"])
113
-
114
- ds.attrs["title"] = "ROMS vertical coordinate file created by ROMS-Tools"
115
- # Include the version of roms-tools
116
- try:
117
- roms_tools_version = importlib.metadata.version("roms-tools")
118
- except importlib.metadata.PackageNotFoundError:
119
- roms_tools_version = "unknown"
120
- ds.attrs["roms_tools_version"] = roms_tools_version
121
-
122
- object.__setattr__(self, "ds", ds)
123
-
124
- def plot(
125
- self,
126
- varname="layer_depth_rho",
127
- s=None,
128
- eta=None,
129
- xi=None,
130
- ) -> None:
131
- """
132
- Plot the vertical coordinate system for a given eta-, xi-, or s-slice.
133
-
134
- Parameters
135
- ----------
136
- varname : str, optional
137
- The field to plot. Options are "depth_rho", "depth_u", "depth_v".
138
- s: int, optional
139
- The s-index to plot. Default is None.
140
- eta : int, optional
141
- The eta-index to plot. Default is None.
142
- xi : int, optional
143
- The xi-index to plot. Default is None.
144
-
145
- Returns
146
- -------
147
- None
148
- This method does not return any value. It generates and displays a plot.
149
-
150
- Raises
151
- ------
152
- ValueError
153
- If the specified varname is not one of the valid options.
154
- If none of s, eta, xi are specified.
155
- """
156
-
157
- if not any([s is not None, eta is not None, xi is not None]):
158
- raise ValueError("At least one of s, eta, or xi must be specified.")
159
-
160
- self.ds[varname].load()
161
- field = self.ds[varname].squeeze()
162
-
163
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
164
- interface_depth = self.ds.interface_depth_rho
165
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
166
- interface_depth = self.ds.interface_depth_u
167
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
168
- interface_depth = self.ds.interface_depth_v
169
-
170
- # slice the field as desired
171
- title = field.long_name
172
- if s is not None:
173
- if "s_rho" in field.dims:
174
- title = title + f", s_rho = {field.s_rho[s].item()}"
175
- field = field.isel(s_rho=s)
176
- elif "s_w" in field.dims:
177
- title = title + f", s_w = {field.s_w[s].item()}"
178
- field = field.isel(s_w=s)
179
- else:
180
- raise ValueError(
181
- f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
182
- )
183
-
184
- if eta is not None:
185
- if "eta_rho" in field.dims:
186
- title = title + f", eta_rho = {field.eta_rho[eta].item()}"
187
- field = field.isel(eta_rho=eta)
188
- interface_depth = interface_depth.isel(eta_rho=eta)
189
- elif "eta_v" in field.dims:
190
- title = title + f", eta_v = {field.eta_v[eta].item()}"
191
- field = field.isel(eta_v=eta)
192
- interface_depth = interface_depth.isel(eta_v=eta)
193
- else:
194
- raise ValueError(
195
- f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
196
- )
197
- if xi is not None:
198
- if "xi_rho" in field.dims:
199
- title = title + f", xi_rho = {field.xi_rho[xi].item()}"
200
- field = field.isel(xi_rho=xi)
201
- interface_depth = interface_depth.isel(xi_rho=xi)
202
- elif "xi_u" in field.dims:
203
- title = title + f", xi_u = {field.xi_u[xi].item()}"
204
- field = field.isel(xi_u=xi)
205
- interface_depth = interface_depth.isel(xi_u=xi)
206
- else:
207
- raise ValueError(
208
- f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
209
- )
210
-
211
- if eta is None and xi is None:
212
- vmax = field.max().values
213
- vmin = field.min().values
214
- cmap = plt.colormaps.get_cmap("YlGnBu")
215
- cmap.set_bad(color="gray")
216
- kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
217
-
218
- _plot(
219
- self.grid.ds,
220
- field=field,
221
- straddle=self.grid.straddle,
222
- depth_contours=True,
223
- title=title,
224
- kwargs=kwargs,
225
- c="g",
226
- )
227
- else:
228
- if len(field.dims) == 2:
229
- cmap = plt.colormaps.get_cmap("YlGnBu")
230
- cmap.set_bad(color="gray")
231
- kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}
232
-
233
- _section_plot(
234
- xr.zeros_like(field),
235
- interface_depth=interface_depth,
236
- title=title,
237
- kwargs=kwargs,
238
- )
239
- else:
240
- if "s_rho" in field.dims or "s_w" in field.dims:
241
- _profile_plot(field, title=title)
242
- else:
243
- _line_plot(field, title=title)
244
-
245
- def save(self, filepath: str) -> None:
246
- """
247
- Save the vertical coordinate information to a netCDF4 file.
248
-
249
- Parameters
250
- ----------
251
- filepath
252
- """
253
- self.ds.to_netcdf(filepath)
254
-
255
- def to_yaml(self, filepath: str) -> None:
256
- """
257
- Export the parameters of the class to a YAML file, including the version of roms-tools.
258
-
259
- Parameters
260
- ----------
261
- filepath : str
262
- The path to the YAML file where the parameters will be saved.
263
- """
264
- # Serialize Grid data
265
- grid_data = asdict(self.grid)
266
- grid_data.pop("ds", None) # Exclude non-serializable fields
267
- grid_data.pop("straddle", None)
268
-
269
- # Include the version of roms-tools
270
- try:
271
- roms_tools_version = importlib.metadata.version("roms-tools")
272
- except importlib.metadata.PackageNotFoundError:
273
- roms_tools_version = "unknown"
274
-
275
- # Create header
276
- header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
277
-
278
- grid_yaml_data = {"Grid": grid_data}
279
-
280
- # Combine all sections
281
- vertical_coordinate_data = {
282
- "VerticalCoordinate": {
283
- "N": self.N,
284
- "theta_s": self.theta_s,
285
- "theta_b": self.theta_b,
286
- "hc": self.hc,
287
- }
288
- }
289
-
290
- # Merge YAML data while excluding empty sections
291
- yaml_data = {
292
- **grid_yaml_data,
293
- **vertical_coordinate_data,
294
- }
295
-
296
- with open(filepath, "w") as file:
297
- # Write header
298
- file.write(header)
299
- # Write YAML data
300
- yaml.dump(yaml_data, file, default_flow_style=False)
301
-
302
- @classmethod
303
- def from_file(cls, filepath: str) -> "VerticalCoordinate":
304
- """
305
- Create a VerticalCoordinate instance from an existing file.
306
-
307
- Parameters
308
- ----------
309
- filepath : str
310
- Path to the file containing the vertical coordinate information.
311
-
312
- Returns
313
- -------
314
- VerticalCoordinate
315
- A new instance of VerticalCoordinate populated with data from the file.
316
- """
317
- # Load the dataset from the file
318
- ds = xr.open_dataset(filepath)
319
-
320
- # Create a new VerticalCoordinate instance without calling __init__ and __post_init__
321
- vertical_coordinate = cls.__new__(cls)
322
-
323
- # Set the dataset for the vertical_corodinate instance
324
- object.__setattr__(vertical_coordinate, "ds", ds)
325
-
326
- # Manually set the remaining attributes by extracting parameters from dataset
327
- object.__setattr__(vertical_coordinate, "N", ds.sizes["s_rho"])
328
- object.__setattr__(vertical_coordinate, "theta_s", ds["theta_s"].values.item())
329
- object.__setattr__(vertical_coordinate, "theta_b", ds["theta_b"].values.item())
330
- object.__setattr__(vertical_coordinate, "hc", ds["hc"].values.item())
331
- object.__setattr__(vertical_coordinate, "grid", None)
332
-
333
- return vertical_coordinate
334
-
335
- @classmethod
336
- def from_yaml(cls, filepath: str) -> "VerticalCoordinate":
337
- """
338
- Create an instance of the VerticalCoordinate class from a YAML file.
339
-
340
- Parameters
341
- ----------
342
- filepath : str
343
- The path to the YAML file from which the parameters will be read.
344
-
345
- Returns
346
- -------
347
- VerticalCoordinate
348
- An instance of the VerticalCoordinate class.
349
- """
350
- # Read the entire file content
351
- with open(filepath, "r") as file:
352
- file_content = file.read()
353
-
354
- # Split the content into YAML documents
355
- documents = list(yaml.safe_load_all(file_content))
356
-
357
- vertical_coordinate_data = None
358
-
359
- # Process the YAML documents
360
- for doc in documents:
361
- if doc is None:
362
- continue
363
- if "VerticalCoordinate" in doc:
364
- vertical_coordinate_data = doc["VerticalCoordinate"]
365
- break
366
-
367
- if vertical_coordinate_data is None:
368
- raise ValueError(
369
- "No VerticalCoordinate configuration found in the YAML file."
370
- )
371
-
372
- # Create Grid instance from the YAML file
373
- grid = Grid.from_yaml(filepath)
374
-
375
- # Create and return an instance of TidalForcing
376
- return cls(
377
- grid=grid,
378
- **vertical_coordinate_data,
379
- )
380
3
 
381
4
 
382
5
  def compute_cs(sigma, theta_s, theta_b):