roms-tools 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/setup/tides.py CHANGED
@@ -15,11 +15,12 @@ from roms_tools.setup.utils import (
15
15
  interpolate_from_rho_to_u,
16
16
  interpolate_from_rho_to_v,
17
17
  )
18
+ from roms_tools.setup.mixins import ROMSToolsMixins
18
19
  import matplotlib.pyplot as plt
19
20
 
20
21
 
21
22
  @dataclass(frozen=True, kw_only=True)
22
- class TidalForcing:
23
+ class TidalForcing(ROMSToolsMixins):
23
24
  """
24
25
  Represents tidal forcing data used in ocean modeling.
25
26
 
@@ -59,30 +60,13 @@ class TidalForcing:
59
60
  ds: xr.Dataset = field(init=False, repr=False)
60
61
 
61
62
  def __post_init__(self):
62
- if "name" not in self.source.keys():
63
- raise ValueError("`source` must include a 'name'.")
64
- if "path" not in self.source.keys():
65
- raise ValueError("`source` must include a 'path'.")
66
- if self.source["name"] == "TPXO":
67
- data = TPXODataset(filename=self.source["path"])
68
- else:
69
- raise ValueError('Only "TPXO" is a valid option for source["name"].')
63
+
64
+ self._input_checks()
65
+ lon, lat, angle, straddle = super().get_target_lon_lat()
66
+
67
+ data = self._get_data()
70
68
 
71
69
  data.check_number_constituents(self.ntides)
72
- # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in longitude away from Greenwich meridian
73
- lon = self.grid.ds.lon_rho
74
- lat = self.grid.ds.lat_rho
75
- angle = self.grid.ds.angle
76
-
77
- lon = xr.where(lon > 180, lon - 360, lon)
78
- straddle = True
79
- if not self.grid.straddle and abs(lon).min() > 5:
80
- lon = xr.where(lon < 0, lon + 360, lon)
81
- straddle = False
82
-
83
- # Restrict data to relevant subdomain to achieve better performance and to avoid discontinuous longitudes introduced by converting
84
- # to a different longitude range (+- 360 degrees). Discontinuous longitudes can lead to artifacts in the interpolation process that
85
- # would not be detected by the nan_check function.
86
70
  data.choose_subdomain(
87
71
  latitude_range=[lat.min().values, lat.max().values],
88
72
  longitude_range=[lon.min().values, lon.max().values],
@@ -121,56 +105,64 @@ class TidalForcing:
121
105
  method="linear",
122
106
  )
123
107
 
124
- # Rotate to grid orientation
125
- u_Re = data_vars["u_Re"] * np.cos(angle) + data_vars["v_Re"] * np.sin(angle)
126
- v_Re = data_vars["v_Re"] * np.cos(angle) - data_vars["u_Re"] * np.sin(angle)
127
- u_Im = data_vars["u_Im"] * np.cos(angle) + data_vars["v_Im"] * np.sin(angle)
128
- v_Im = data_vars["v_Im"] * np.cos(angle) - data_vars["u_Im"] * np.sin(angle)
108
+ data_vars = super().process_velocities(
109
+ data_vars, angle, "u_Re", "v_Re", interpolate=False
110
+ )
111
+ data_vars = super().process_velocities(
112
+ data_vars, angle, "u_Im", "v_Im", interpolate=False
113
+ )
129
114
 
130
115
  # Convert to barotropic velocity
131
- u_Re = u_Re / self.grid.ds.h
132
- v_Re = v_Re / self.grid.ds.h
133
- u_Im = u_Im / self.grid.ds.h
134
- v_Im = v_Im / self.grid.ds.h
116
+ for varname in ["u_Re", "v_Re", "u_Im", "v_Im"]:
117
+ data_vars[varname] = data_vars[varname] / self.grid.ds.h
135
118
 
136
119
  # Interpolate from rho- to velocity points
137
- u_Re = interpolate_from_rho_to_u(u_Re)
138
- v_Re = interpolate_from_rho_to_v(v_Re)
139
- u_Im = interpolate_from_rho_to_u(u_Im)
140
- v_Im = interpolate_from_rho_to_v(v_Im)
120
+ for uname in ["u_Re", "u_Im"]:
121
+ data_vars[uname] = interpolate_from_rho_to_u(data_vars[uname])
122
+ for vname in ["v_Re", "v_Im"]:
123
+ data_vars[vname] = interpolate_from_rho_to_v(data_vars[vname])
124
+
125
+ d_meta = super().get_variable_metadata()
126
+ ds = self._write_into_dataset(data_vars, d_meta)
127
+ ds["omega"] = tides["omega"]
128
+
129
+ ds = self._add_global_metadata(ds)
130
+
131
+ for var in ["ssh_Re", "u_Re", "v_Im"]:
132
+ nan_check(ds[var].isel(ntides=0), self.grid.ds.mask_rho)
133
+
134
+ object.__setattr__(self, "ds", ds)
135
+
136
+ def _input_checks(self):
137
+
138
+ if "name" not in self.source.keys():
139
+ raise ValueError("`source` must include a 'name'.")
140
+ if "path" not in self.source.keys():
141
+ raise ValueError("`source` must include a 'path'.")
142
+
143
+ def _get_data(self):
144
+
145
+ if self.source["name"] == "TPXO":
146
+ data = TPXODataset(filename=self.source["path"])
147
+ else:
148
+ raise ValueError('Only "TPXO" is a valid option for source["name"].')
149
+ return data
150
+
151
+ def _write_into_dataset(self, data_vars, d_meta):
141
152
 
142
153
  # save in new dataset
143
154
  ds = xr.Dataset()
144
155
 
145
- # ds["omega"] = tides["omega"]
146
-
147
- ds["ssh_Re"] = data_vars["ssh_Re"].astype(np.float32)
148
- ds["ssh_Im"] = data_vars["ssh_Im"].astype(np.float32)
149
- ds["ssh_Re"].attrs["long_name"] = "Tidal elevation, real part"
150
- ds["ssh_Im"].attrs["long_name"] = "Tidal elevation, complex part"
151
- ds["ssh_Re"].attrs["units"] = "m"
152
- ds["ssh_Im"].attrs["units"] = "m"
153
-
154
- ds["pot_Re"] = data_vars["pot_Re"].astype(np.float32)
155
- ds["pot_Im"] = data_vars["pot_Im"].astype(np.float32)
156
- ds["pot_Re"].attrs["long_name"] = "Tidal potential, real part"
157
- ds["pot_Im"].attrs["long_name"] = "Tidal potential, complex part"
158
- ds["pot_Re"].attrs["units"] = "m"
159
- ds["pot_Im"].attrs["units"] = "m"
160
-
161
- ds["u_Re"] = u_Re.astype(np.float32)
162
- ds["u_Im"] = u_Im.astype(np.float32)
163
- ds["u_Re"].attrs["long_name"] = "Tidal velocity in x-direction, real part"
164
- ds["u_Im"].attrs["long_name"] = "Tidal velocity in x-direction, complex part"
165
- ds["u_Re"].attrs["units"] = "m/s"
166
- ds["u_Im"].attrs["units"] = "m/s"
167
-
168
- ds["v_Re"] = v_Re.astype(np.float32)
169
- ds["v_Im"] = v_Im.astype(np.float32)
170
- ds["v_Re"].attrs["long_name"] = "Tidal velocity in y-direction, real part"
171
- ds["v_Im"].attrs["long_name"] = "Tidal velocity in y-direction, complex part"
172
- ds["v_Re"].attrs["units"] = "m/s"
173
- ds["v_Im"].attrs["units"] = "m/s"
156
+ for var in data_vars.keys():
157
+ ds[var] = data_vars[var].astype(np.float32)
158
+ ds[var].attrs["long_name"] = d_meta[var]["long_name"]
159
+ ds[var].attrs["units"] = d_meta[var]["units"]
160
+
161
+ ds = ds.drop_vars(["lat_rho", "lon_rho"])
162
+
163
+ return ds
164
+
165
+ def _add_global_metadata(self, ds):
174
166
 
175
167
  ds.attrs["title"] = "ROMS tidal forcing created by ROMS-Tools"
176
168
  # Include the version of roms-tools
@@ -185,10 +177,7 @@ class TidalForcing:
185
177
  ds.attrs["model_reference_date"] = str(self.model_reference_date)
186
178
  ds.attrs["allan_factor"] = self.allan_factor
187
179
 
188
- object.__setattr__(self, "ds", ds)
189
-
190
- for var in ["ssh_Re", "u_Re", "v_Im"]:
191
- nan_check(self.ds[var].isel(ntides=0), self.grid.ds.mask_rho)
180
+ return ds
192
181
 
193
182
  def plot(self, varname, ntides=0) -> None:
194
183
  """
@@ -228,6 +217,25 @@ class TidalForcing:
228
217
  """
229
218
 
230
219
  field = self.ds[varname].isel(ntides=ntides).compute()
220
+ if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
221
+ field = field.where(self.grid.ds.mask_rho)
222
+ field = field.assign_coords(
223
+ {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
224
+ )
225
+
226
+ elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
227
+ field = field.where(self.grid.ds.mask_u)
228
+ field = field.assign_coords(
229
+ {"lon": self.grid.ds.lon_u, "lat": self.grid.ds.lat_u}
230
+ )
231
+
232
+ elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
233
+ field = field.where(self.grid.ds.mask_v)
234
+ field = field.assign_coords(
235
+ {"lon": self.grid.ds.lon_v, "lat": self.grid.ds.lat_v}
236
+ )
237
+ else:
238
+ ValueError("provided field does not have two horizontal dimension")
231
239
 
232
240
  title = "%s, ntides = %i" % (field.long_name, self.ds[varname].ntides[ntides])
233
241
 
@@ -10,8 +10,42 @@ from itertools import count
10
10
 
11
11
 
12
12
  def _add_topography_and_mask(
13
- ds, topography_source, smooth_factor, hmin, rmax
13
+ ds, topography_source, hmin, smooth_factor=8.0, rmax=0.2
14
14
  ) -> xr.Dataset:
15
+ """
16
+ Adds topography and a land/water mask to the dataset based on the provided topography source.
17
+
18
+ This function performs the following operations:
19
+ 1. Interpolates topography data onto the desired grid.
20
+ 2. Applies a mask based on ocean depth.
21
+ 3. Smooths the topography globally to reduce grid-scale instabilities.
22
+ 4. Fills enclosed basins with land.
23
+ 5. Smooths the topography locally to ensure the steepness ratio satisfies the rmax criterion.
24
+ 6. Adds topography metadata.
25
+
26
+ Parameters
27
+ ----------
28
+ ds : xr.Dataset
29
+ The dataset to which topography and the land/water mask will be added.
30
+ topography_source : str
31
+ The source of the topography data.
32
+ hmin : float
33
+ The minimum allowable depth for the topography.
34
+ smooth_factor : float, optional
35
+ The smoothing factor used in the domain-wide Gaussian smoothing of the
36
+ topography. Smaller values result in less smoothing, while larger
37
+ values produce more smoothing. The default is 8.0.
38
+ rmax : float, optional
39
+ The maximum allowable steepness ratio for the topography smoothing.
40
+ This parameter controls the local smoothing of the topography. Smaller values result in
41
+ smoother topography, while larger values preserve more detail. The default is 0.2.
42
+
43
+ Returns
44
+ -------
45
+ xr.Dataset
46
+ The dataset with added topography, mask, and metadata.
47
+ """
48
+
15
49
  lon = ds.lon_rho.values
16
50
  lat = ds.lat_rho.values
17
51
 
@@ -23,16 +57,11 @@ def _add_topography_and_mask(
23
57
  mask = xr.where(hraw > 0, 1.0, 0.0)
24
58
 
25
59
  # smooth topography domain-wide with Gaussian kernel to avoid grid scale instabilities
26
- ds["hraw"] = _smooth_topography_globally(hraw, mask, smooth_factor)
27
- ds["hraw"].attrs = {
28
- "long_name": "Working bathymetry at rho-points",
29
- "source": f"Raw bathymetry from {topography_source} (smoothing diameter {smooth_factor})",
30
- "units": "meter",
31
- }
60
+ hraw = _smooth_topography_globally(hraw, mask, smooth_factor)
32
61
 
33
62
  # fill enclosed basins with land
34
63
  mask = _fill_enclosed_basins(mask.values)
35
- ds["mask_rho"] = xr.DataArray(mask, dims=("eta_rho", "xi_rho"))
64
+ ds["mask_rho"] = xr.DataArray(mask.astype(np.int32), dims=("eta_rho", "xi_rho"))
36
65
  ds["mask_rho"].attrs = {
37
66
  "long_name": "Mask at rho-points",
38
67
  "units": "land/water (0/1)",
@@ -41,7 +70,7 @@ def _add_topography_and_mask(
41
70
  ds = _add_velocity_masks(ds)
42
71
 
43
72
  # smooth topography locally to satisfy r < rmax
44
- ds["h"] = _smooth_topography_locally(ds["hraw"] * ds["mask_rho"], hmin, rmax)
73
+ ds["h"] = _smooth_topography_locally(hraw * ds["mask_rho"], hmin, rmax)
45
74
  ds["h"].attrs = {
46
75
  "long_name": "Final bathymetry at rho-points",
47
76
  "units": "meter",
@@ -238,9 +267,7 @@ def _compute_rfactor(h):
238
267
 
239
268
  def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
240
269
  ds.attrs["topography_source"] = topography_source
241
- ds.attrs["smooth_factor"] = smooth_factor
242
270
  ds.attrs["hmin"] = hmin
243
- ds.attrs["rmax"] = rmax
244
271
 
245
272
  return ds
246
273
 
@@ -248,8 +275,12 @@ def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
248
275
  def _add_velocity_masks(ds):
249
276
 
250
277
  # add u- and v-masks
251
- ds["mask_u"] = interpolate_from_rho_to_u(ds["mask_rho"], method="multiplicative")
252
- ds["mask_v"] = interpolate_from_rho_to_v(ds["mask_rho"], method="multiplicative")
278
+ ds["mask_u"] = interpolate_from_rho_to_u(
279
+ ds["mask_rho"], method="multiplicative"
280
+ ).astype(np.int32)
281
+ ds["mask_v"] = interpolate_from_rho_to_v(
282
+ ds["mask_rho"], method="multiplicative"
283
+ ).astype(np.int32)
253
284
 
254
285
  ds["mask_u"].attrs = {"long_name": "Mask at u-points", "units": "land/water (0/1)"}
255
286
  ds["mask_v"].attrs = {"long_name": "Mask at v-points", "units": "land/water (0/1)"}
@@ -1,382 +1,5 @@
1
1
  import numpy as np
2
2
  import xarray as xr
3
- import yaml
4
- import importlib.metadata
5
- from dataclasses import dataclass, field, asdict
6
- from roms_tools.setup.grid import Grid
7
- from roms_tools.setup.utils import (
8
- interpolate_from_rho_to_u,
9
- interpolate_from_rho_to_v,
10
- )
11
- from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
12
- import matplotlib.pyplot as plt
13
-
14
-
15
- @dataclass(frozen=True, kw_only=True)
16
- class VerticalCoordinate:
17
- """
18
- Represents vertical coordinate for ROMS.
19
-
20
- Parameters
21
- ----------
22
- grid : Grid
23
- Object representing the grid information.
24
- N : int
25
- The number of vertical levels.
26
- theta_s : float
27
- The surface control parameter. Must satisfy 0 < theta_s <= 10.
28
- theta_b : float
29
- The bottom control parameter. Must satisfy 0 < theta_b <= 4.
30
- hc : float
31
- The critical depth.
32
-
33
- Attributes
34
- ----------
35
- ds : xr.Dataset
36
- Xarray Dataset containing the vertical coordinate data.
37
- """
38
-
39
- grid: Grid
40
- N: int
41
- theta_s: float
42
- theta_b: float
43
- hc: float
44
-
45
- ds: xr.Dataset = field(init=False, repr=False)
46
-
47
- def __post_init__(self):
48
-
49
- h = self.grid.ds.h
50
-
51
- cs_r, sigma_r = sigma_stretch(self.theta_s, self.theta_b, self.N, "r")
52
- zr = compute_depth(h * 0, h, self.hc, cs_r, sigma_r)
53
- cs_w, sigma_w = sigma_stretch(self.theta_s, self.theta_b, self.N, "w")
54
- zw = compute_depth(h * 0, h, self.hc, cs_w, sigma_w)
55
-
56
- ds = xr.Dataset()
57
-
58
- ds["theta_s"] = np.float32(self.theta_s)
59
- ds["theta_s"].attrs["long_name"] = "S-coordinate surface control parameter"
60
- ds["theta_s"].attrs["units"] = "nondimensional"
61
-
62
- ds["theta_b"] = np.float32(self.theta_b)
63
- ds["theta_b"].attrs["long_name"] = "S-coordinate bottom control parameter"
64
- ds["theta_b"].attrs["units"] = "nondimensional"
65
-
66
- ds["Tcline"] = np.float32(self.hc)
67
- ds["Tcline"].attrs["long_name"] = "S-coordinate surface/bottom layer width"
68
- ds["Tcline"].attrs["units"] = "m"
69
-
70
- ds["hc"] = np.float32(self.hc)
71
- ds["hc"].attrs["long_name"] = "S-coordinate parameter critical depth"
72
- ds["hc"].attrs["units"] = "m"
73
-
74
- ds["sc_r"] = sigma_r.astype(np.float32)
75
- ds["sc_r"].attrs["long_name"] = "S-coordinate at rho-points"
76
- ds["sc_r"].attrs["units"] = "nondimensional"
77
-
78
- ds["Cs_r"] = cs_r.astype(np.float32)
79
- ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
80
- ds["Cs_r"].attrs["units"] = "nondimensional"
81
-
82
- depth = -zr
83
- depth.attrs["long_name"] = "Layer depth at rho-points"
84
- depth.attrs["units"] = "m"
85
- ds = ds.assign_coords({"layer_depth_rho": depth.astype(np.float32)})
86
-
87
- depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
88
- depth_u.attrs["long_name"] = "Layer depth at u-points"
89
- depth_u.attrs["units"] = "m"
90
- ds = ds.assign_coords({"layer_depth_u": depth_u})
91
-
92
- depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
93
- depth_v.attrs["long_name"] = "Layer depth at v-points"
94
- depth_v.attrs["units"] = "m"
95
- ds = ds.assign_coords({"layer_depth_v": depth_v})
96
-
97
- depth = -zw
98
- depth.attrs["long_name"] = "Interface depth at rho-points"
99
- depth.attrs["units"] = "m"
100
- ds = ds.assign_coords({"interface_depth_rho": depth.astype(np.float32)})
101
-
102
- depth_u = interpolate_from_rho_to_u(depth).astype(np.float32)
103
- depth_u.attrs["long_name"] = "Interface depth at u-points"
104
- depth_u.attrs["units"] = "m"
105
- ds = ds.assign_coords({"interface_depth_u": depth_u})
106
-
107
- depth_v = interpolate_from_rho_to_v(depth).astype(np.float32)
108
- depth_v.attrs["long_name"] = "Interface depth at v-points"
109
- depth_v.attrs["units"] = "m"
110
- ds = ds.assign_coords({"interface_depth_v": depth_v})
111
-
112
- ds = ds.drop_vars(["eta_rho", "xi_rho"])
113
-
114
- ds.attrs["title"] = "ROMS vertical coordinate file created by ROMS-Tools"
115
- # Include the version of roms-tools
116
- try:
117
- roms_tools_version = importlib.metadata.version("roms-tools")
118
- except importlib.metadata.PackageNotFoundError:
119
- roms_tools_version = "unknown"
120
- ds.attrs["roms_tools_version"] = roms_tools_version
121
-
122
- object.__setattr__(self, "ds", ds)
123
-
124
- def plot(
125
- self,
126
- varname="layer_depth_rho",
127
- s=None,
128
- eta=None,
129
- xi=None,
130
- ) -> None:
131
- """
132
- Plot the vertical coordinate system for a given eta-, xi-, or s-slice.
133
-
134
- Parameters
135
- ----------
136
- varname : str, optional
137
- The field to plot. Options are "layer_depth_rho", "layer_depth_u", "layer_depth_v", "interface_depth_rho", "interface_depth_u", "interface_depth_v".
138
- s: int, optional
139
- The s-index to plot. Default is None.
140
- eta : int, optional
141
- The eta-index to plot. Default is None.
142
- xi : int, optional
143
- The xi-index to plot. Default is None.
144
-
145
- Returns
146
- -------
147
- None
148
- This method does not return any value. It generates and displays a plot.
149
-
150
- Raises
151
- ------
152
- ValueError
153
- If the specified varname is not one of the valid options.
154
- If none of s, eta, xi are specified.
155
- """
156
-
157
- if not any([s is not None, eta is not None, xi is not None]):
158
- raise ValueError("At least one of s, eta, or xi must be specified.")
159
-
160
- self.ds[varname].load()
161
- field = self.ds[varname].squeeze()
162
-
163
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
164
- interface_depth = self.ds.interface_depth_rho
165
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
166
- interface_depth = self.ds.interface_depth_u
167
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
168
- interface_depth = self.ds.interface_depth_v
169
-
170
- # slice the field as desired
171
- title = field.long_name
172
- if s is not None:
173
- if "s_rho" in field.dims:
174
- title = title + f", s_rho = {field.s_rho[s].item()}"
175
- field = field.isel(s_rho=s)
176
- elif "s_w" in field.dims:
177
- title = title + f", s_w = {field.s_w[s].item()}"
178
- field = field.isel(s_w=s)
179
- else:
180
- raise ValueError(
181
- f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
182
- )
183
-
184
- if eta is not None:
185
- if "eta_rho" in field.dims:
186
- title = title + f", eta_rho = {field.eta_rho[eta].item()}"
187
- field = field.isel(eta_rho=eta)
188
- interface_depth = interface_depth.isel(eta_rho=eta)
189
- elif "eta_v" in field.dims:
190
- title = title + f", eta_v = {field.eta_v[eta].item()}"
191
- field = field.isel(eta_v=eta)
192
- interface_depth = interface_depth.isel(eta_v=eta)
193
- else:
194
- raise ValueError(
195
- f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
196
- )
197
- if xi is not None:
198
- if "xi_rho" in field.dims:
199
- title = title + f", xi_rho = {field.xi_rho[xi].item()}"
200
- field = field.isel(xi_rho=xi)
201
- interface_depth = interface_depth.isel(xi_rho=xi)
202
- elif "xi_u" in field.dims:
203
- title = title + f", xi_u = {field.xi_u[xi].item()}"
204
- field = field.isel(xi_u=xi)
205
- interface_depth = interface_depth.isel(xi_u=xi)
206
- else:
207
- raise ValueError(
208
- f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
209
- )
210
-
211
- if eta is None and xi is None:
212
- vmax = field.max().values
213
- vmin = field.min().values
214
- cmap = plt.colormaps.get_cmap("YlGnBu")
215
- cmap.set_bad(color="gray")
216
- kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
217
-
218
- _plot(
219
- self.grid.ds,
220
- field=field,
221
- straddle=self.grid.straddle,
222
- depth_contours=True,
223
- title=title,
224
- kwargs=kwargs,
225
- c="g",
226
- )
227
- else:
228
- if len(field.dims) == 2:
229
- cmap = plt.colormaps.get_cmap("YlGnBu")
230
- cmap.set_bad(color="gray")
231
- kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}
232
-
233
- _section_plot(
234
- xr.zeros_like(field),
235
- interface_depth=interface_depth,
236
- title=title,
237
- kwargs=kwargs,
238
- )
239
- else:
240
- if "s_rho" in field.dims or "s_w" in field.dims:
241
- _profile_plot(field, title=title)
242
- else:
243
- _line_plot(field, title=title)
244
-
245
- def save(self, filepath: str) -> None:
246
- """
247
- Save the vertical coordinate information to a netCDF4 file.
248
-
249
- Parameters
250
- ----------
251
- filepath
252
- """
253
- self.ds.to_netcdf(filepath)
254
-
255
- def to_yaml(self, filepath: str) -> None:
256
- """
257
- Export the parameters of the class to a YAML file, including the version of roms-tools.
258
-
259
- Parameters
260
- ----------
261
- filepath : str
262
- The path to the YAML file where the parameters will be saved.
263
- """
264
- # Serialize Grid data
265
- grid_data = asdict(self.grid)
266
- grid_data.pop("ds", None) # Exclude non-serializable fields
267
- grid_data.pop("straddle", None)
268
-
269
- # Include the version of roms-tools
270
- try:
271
- roms_tools_version = importlib.metadata.version("roms-tools")
272
- except importlib.metadata.PackageNotFoundError:
273
- roms_tools_version = "unknown"
274
-
275
- # Create header
276
- header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
277
-
278
- grid_yaml_data = {"Grid": grid_data}
279
-
280
- # Combine all sections
281
- vertical_coordinate_data = {
282
- "VerticalCoordinate": {
283
- "N": self.N,
284
- "theta_s": self.theta_s,
285
- "theta_b": self.theta_b,
286
- "hc": self.hc,
287
- }
288
- }
289
-
290
- # Merge YAML data while excluding empty sections
291
- yaml_data = {
292
- **grid_yaml_data,
293
- **vertical_coordinate_data,
294
- }
295
-
296
- with open(filepath, "w") as file:
297
- # Write header
298
- file.write(header)
299
- # Write YAML data
300
- yaml.dump(yaml_data, file, default_flow_style=False)
301
-
302
- @classmethod
303
- def from_file(cls, filepath: str) -> "VerticalCoordinate":
304
- """
305
- Create a VerticalCoordinate instance from an existing file.
306
-
307
- Parameters
308
- ----------
309
- filepath : str
310
- Path to the file containing the vertical coordinate information.
311
-
312
- Returns
313
- -------
314
- VerticalCoordinate
315
- A new instance of VerticalCoordinate populated with data from the file.
316
- """
317
- # Load the dataset from the file
318
- ds = xr.open_dataset(filepath)
319
-
320
- # Create a new VerticalCoordinate instance without calling __init__ and __post_init__
321
- vertical_coordinate = cls.__new__(cls)
322
-
323
- # Set the dataset for the vertical_coordinate instance
324
- object.__setattr__(vertical_coordinate, "ds", ds)
325
-
326
- # Manually set the remaining attributes by extracting parameters from dataset
327
- object.__setattr__(vertical_coordinate, "N", ds.sizes["s_rho"])
328
- object.__setattr__(vertical_coordinate, "theta_s", ds["theta_s"].values.item())
329
- object.__setattr__(vertical_coordinate, "theta_b", ds["theta_b"].values.item())
330
- object.__setattr__(vertical_coordinate, "hc", ds["hc"].values.item())
331
- object.__setattr__(vertical_coordinate, "grid", None)
332
-
333
- return vertical_coordinate
334
-
335
- @classmethod
336
- def from_yaml(cls, filepath: str) -> "VerticalCoordinate":
337
- """
338
- Create an instance of the VerticalCoordinate class from a YAML file.
339
-
340
- Parameters
341
- ----------
342
- filepath : str
343
- The path to the YAML file from which the parameters will be read.
344
-
345
- Returns
346
- -------
347
- VerticalCoordinate
348
- An instance of the VerticalCoordinate class.
349
- """
350
- # Read the entire file content
351
- with open(filepath, "r") as file:
352
- file_content = file.read()
353
-
354
- # Split the content into YAML documents
355
- documents = list(yaml.safe_load_all(file_content))
356
-
357
- vertical_coordinate_data = None
358
-
359
- # Process the YAML documents
360
- for doc in documents:
361
- if doc is None:
362
- continue
363
- if "VerticalCoordinate" in doc:
364
- vertical_coordinate_data = doc["VerticalCoordinate"]
365
- break
366
-
367
- if vertical_coordinate_data is None:
368
- raise ValueError(
369
- "No VerticalCoordinate configuration found in the YAML file."
370
- )
371
-
372
- # Create Grid instance from the YAML file
373
- grid = Grid.from_yaml(filepath)
374
-
375
- # Create and return an instance of VerticalCoordinate
376
- return cls(
377
- grid=grid,
378
- **vertical_coordinate_data,
379
- )
380
3
 
381
4
 
382
5
  def compute_cs(sigma, theta_s, theta_b):
@@ -491,4 +114,9 @@ def compute_depth(zeta, h, hc, cs, sigma):
491
114
  s = (hc * sigma + h * cs) / (hc + h)
492
115
  z = zeta + (zeta + h) * s
493
116
 
117
+ if "s_rho" in z.dims:
118
+ z = z.transpose("s_rho", "eta_rho", "xi_rho")
119
+ elif "s_w" in z.dims:
120
+ z = z.transpose("s_w", "eta_rho", "xi_rho")
121
+
494
122
  return z