roms-tools 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -123,9 +123,9 @@ class ROMSToolsMixins:
123
123
  if vars_3d:
124
124
  # 3d interpolation
125
125
  coords = {
126
+ data.dim_names["depth"]: self.grid.ds["layer_depth_rho"],
126
127
  data.dim_names["latitude"]: lat,
127
128
  data.dim_names["longitude"]: lon,
128
- data.dim_names["depth"]: self.grid.ds["layer_depth_rho"],
129
129
  }
130
130
  # extrapolate deepest value all the way to bottom ("flooding")
131
131
  for var in vars_3d:
@@ -149,36 +149,50 @@ class ROMSToolsMixins:
149
149
  if data.dim_names["time"] != "time":
150
150
  data_vars[var] = data_vars[var].rename({data.dim_names["time"]: "time"})
151
151
 
152
+ # transpose to correct order (time, s_rho, eta_rho, xi_rho)
153
+ data_vars[var] = data_vars[var].transpose(
154
+ "time", "s_rho", "eta_rho", "xi_rho"
155
+ )
156
+
152
157
  return data_vars
153
158
 
154
- def process_velocities(self, data_vars, angle, interpolate=True):
159
+ def process_velocities(self, data_vars, angle, uname, vname, interpolate=True):
155
160
  """
156
- Processes and rotates velocity components, and interpolates them to the appropriate grid points.
161
+ Process and rotate velocity components to align with the grid orientation and optionally interpolate
162
+ them to the appropriate grid points.
157
163
 
158
164
  This method performs the following steps:
159
- 1. Rotates the velocity components to align with the grid orientation using the provided angle.
160
- 2. Optionally interpolates the rotated velocities to the u- and v-points of the grid.
161
- 3. If the velocities are 3D (with vertical coordinates), computes barotropic (depth-averaged) velocities.
165
+
166
+ 1. **Rotation**: Rotates the velocity components (e.g., `u`, `v`) to align with the grid orientation
167
+ using the provided angle data.
168
+ 2. **Interpolation**: Optionally interpolates the rotated velocities from rho-points to u- and v-points
169
+ of the grid.
170
+ 3. **Barotropic Velocity Calculation**: If the velocity components are 3D (with vertical coordinates),
171
+ computes the barotropic (depth-averaged) velocities.
162
172
 
163
173
  Parameters
164
174
  ----------
165
175
  data_vars : dict of str: xarray.DataArray
166
- Dictionary containing the velocity components to be processed. Must include keys "u" and "v"
167
- or "uwnd" and "vwnd".
176
+ Dictionary containing the velocity components to be processed. The dictionary should include keys
177
+ corresponding to the velocity component names (e.g., `uname`, `vname`).
168
178
  angle : xarray.DataArray
169
- DataArray containing the angle used for rotating the velocity components to the grid orientation.
179
+ DataArray containing the grid angle values used to rotate the velocity components to the correct
180
+ orientation on the grid.
181
+ uname : str
182
+ The key corresponding to the zonal (east-west) velocity component in `data_vars`.
183
+ vname : str
184
+ The key corresponding to the meridional (north-south) velocity component in `data_vars`.
170
185
  interpolate : bool, optional
171
- If True, interpolates the velocities to the u- and v-points. Defaults to True.
186
+ If True, interpolates the rotated velocity components to the u- and v-points of the grid.
187
+ Defaults to True.
172
188
 
173
189
  Returns
174
190
  -------
175
191
  dict of str: xarray.DataArray
176
- Dictionary of processed velocity components. Includes "ubar" and "vbar" if the velocity components
177
- have vertical coordinates and are processed for barotropic (depth-averaged) velocities.
192
+ A dictionary of the processed velocity components. The returned dictionary includes the rotated and,
193
+ if applicable, interpolated velocity components. If the input velocities are 3D (having a vertical
194
+ dimension), the dictionary also includes the barotropic (depth-averaged) velocities (`ubar` and `vbar`).
178
195
  """
179
- # Determine the correct variable names based on the keys in data_vars
180
- uname = "u" if "u" in data_vars else "uwnd"
181
- vname = "v" if "v" in data_vars else "vwnd"
182
196
 
183
197
  # Rotate velocities to grid orientation
184
198
  u_rot = data_vars[uname] * np.cos(angle) + data_vars[vname] * np.sin(angle)
@@ -230,6 +244,26 @@ class ROMSToolsMixins:
230
244
  """
231
245
 
232
246
  d = {
247
+ "ssh_Re": {"long_name": "Tidal elevation, real part", "units": "m"},
248
+ "ssh_Im": {"long_name": "Tidal elevation, complex part", "units": "m"},
249
+ "pot_Re": {"long_name": "Tidal potential, real part", "units": "m"},
250
+ "pot_Im": {"long_name": "Tidal potential, complex part", "units": "m"},
251
+ "u_Re": {
252
+ "long_name": "Tidal velocity in x-direction, real part",
253
+ "units": "m/s",
254
+ },
255
+ "u_Im": {
256
+ "long_name": "Tidal velocity in x-direction, complex part",
257
+ "units": "m/s",
258
+ },
259
+ "v_Re": {
260
+ "long_name": "Tidal velocity in y-direction, real part",
261
+ "units": "m/s",
262
+ },
263
+ "v_Im": {
264
+ "long_name": "Tidal velocity in y-direction, complex part",
265
+ "units": "m/s",
266
+ },
233
267
  "uwnd": {"long_name": "10 meter wind in x-direction", "units": "m/s"},
234
268
  "vwnd": {"long_name": "10 meter wind in y-direction", "units": "m/s"},
235
269
  "swrad": {
@@ -323,18 +357,18 @@ class ROMSToolsMixins:
323
357
  return d
324
358
 
325
359
  def get_boundary_info(self):
326
- """
327
- Provides boundary coordinate information and renaming conventions for grid boundaries.
328
360
 
329
- This method returns two dictionaries: one specifying the boundary coordinates for different types of
330
- grid variables (e.g., "rho", "u", "v"), and another specifying how to rename dimensions for these boundaries.
361
+ """
362
+ This method provides information about the boundary points for the rho, u, and v
363
+ variables on the grid, specifying the indices for the south, east, north, and west
364
+ boundaries.
331
365
 
332
366
  Returns
333
367
  -------
334
- tuple of (dict, dict)
335
- - A dictionary mapping variable types and directions to boundary coordinates.
336
- - A dictionary mapping variable types and directions to new dimension names.
337
-
368
+ dict
369
+ A dictionary where keys are variable types ("rho", "u", "v"), and values
370
+ are nested dictionaries mapping directions ("south", "east", "north", "west")
371
+ to the corresponding boundary coordinates.
338
372
  """
339
373
 
340
374
  # Boundary coordinates
@@ -359,27 +393,4 @@ class ROMSToolsMixins:
359
393
  },
360
394
  }
361
395
 
362
- # How to rename the dimensions
363
-
364
- rename = {
365
- "rho": {
366
- "south": {"xi_rho": "xi_rho_south"},
367
- "east": {"eta_rho": "eta_rho_east"},
368
- "north": {"xi_rho": "xi_rho_north"},
369
- "west": {"eta_rho": "eta_rho_west"},
370
- },
371
- "u": {
372
- "south": {"xi_u": "xi_u_south"},
373
- "east": {"eta_rho": "eta_u_east"},
374
- "north": {"xi_u": "xi_u_north"},
375
- "west": {"eta_rho": "eta_u_west"},
376
- },
377
- "v": {
378
- "south": {"xi_rho": "xi_v_south"},
379
- "east": {"eta_v": "eta_v_east"},
380
- "north": {"xi_rho": "xi_v_north"},
381
- "west": {"eta_v": "eta_v_west"},
382
- },
383
- }
384
-
385
- return bdry_coords, rename
396
+ return bdry_coords
roms_tools/setup/plot.py CHANGED
@@ -8,7 +8,6 @@ def _plot(
8
8
  field=None,
9
9
  depth_contours=False,
10
10
  straddle=False,
11
- coarse_grid=False,
12
11
  c="red",
13
12
  title="",
14
13
  kwargs={},
@@ -21,29 +20,8 @@ def _plot(
21
20
  else:
22
21
 
23
22
  field = field.squeeze()
24
-
25
- if coarse_grid:
26
-
27
- field = field.rename({"eta_rho": "eta_coarse", "xi_rho": "xi_coarse"})
28
- field = field.where(grid_ds.mask_coarse)
29
- lon_deg = field.lon
30
- lat_deg = field.lat
31
-
32
- else:
33
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
34
- field = field.where(grid_ds.mask_rho)
35
- lon_deg = grid_ds["lon_rho"]
36
- lat_deg = grid_ds["lat_rho"]
37
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
38
- field = field.where(grid_ds.mask_u)
39
- lon_deg = grid_ds["lon_u"]
40
- lat_deg = grid_ds["lat_u"]
41
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
42
- field = field.where(grid_ds.mask_v)
43
- lon_deg = grid_ds["lon_v"]
44
- lat_deg = grid_ds["lat_v"]
45
- else:
46
- ValueError("provided field does not have two horizontal dimension")
23
+ lon_deg = field.lon
24
+ lat_deg = field.lat
47
25
 
48
26
  # check if North or South pole are in domain
49
27
  if lat_deg.max().values > 89 or lat_deg.min().values < -89:
@@ -96,23 +74,7 @@ def _plot(
96
74
  plt.colorbar(p, label=f"{field.long_name} [{field.units}]")
97
75
 
98
76
  if depth_contours:
99
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
100
- if "layer_depth_rho" in field.coords:
101
- depth = field.layer_depth_rho
102
- else:
103
- depth = field.interface_depth_rho
104
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
105
- if "layer_depth_u" in field.coords:
106
- depth = field.layer_depth_u
107
- else:
108
- depth = field.interface_depth_u
109
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
110
- if "layer_depth_v" in field.coords:
111
- depth = field.layer_depth_v
112
- else:
113
- depth = field.interface_depth_v
114
-
115
- cs = ax.contour(lon_deg, lat_deg, depth, transform=proj, colors="k")
77
+ cs = ax.contour(lon_deg, lat_deg, field.layer_depth, transform=proj, colors="k")
116
78
  ax.clabel(cs, inline=True, fontsize=10)
117
79
 
118
80
  return fig
@@ -135,12 +97,8 @@ def _section_plot(field, interface_depth=None, title="", kwargs={}):
135
97
  )
136
98
 
137
99
  depths_to_check = [
138
- "layer_depth_rho",
139
- "layer_depth_u",
140
- "layer_depth_v",
141
- "interface_depth_rho",
142
- "interface_depth_u",
143
- "interface_depth_v",
100
+ "layer_depth",
101
+ "interface_depth",
144
102
  ]
145
103
  try:
146
104
  depth_label = next(
@@ -95,7 +95,9 @@ class SurfaceForcing(ROMSToolsMixins):
95
95
  vars_2d = ["uwnd", "vwnd", "swrad", "lwrad", "Tair", "qair", "rain"]
96
96
  vars_3d = []
97
97
  data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
98
- data_vars = super().process_velocities(data_vars, angle, interpolate=False)
98
+ data_vars = super().process_velocities(
99
+ data_vars, angle, "uwnd", "vwnd", interpolate=False
100
+ )
99
101
 
100
102
  if self.correct_radiation:
101
103
  correction_data = self._get_correction_data()
@@ -235,18 +237,19 @@ class SurfaceForcing(ROMSToolsMixins):
235
237
 
236
238
  if self.bgc_source["name"] == "CESM_REGRIDDED":
237
239
 
238
- bgc_data = CESMBGCSurfaceForcingDataset(
240
+ data = CESMBGCSurfaceForcingDataset(
239
241
  filename=self.bgc_source["path"],
240
242
  start_time=self.start_time,
241
243
  end_time=self.end_time,
242
244
  climatology=self.bgc_source["climatology"],
243
245
  )
246
+ data.post_process()
244
247
  else:
245
248
  raise ValueError(
246
249
  'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
247
250
  )
248
251
 
249
- return bgc_data
252
+ return data
250
253
 
251
254
  def _write_into_dataset(self, data, d_meta):
252
255
 
@@ -259,7 +262,6 @@ class SurfaceForcing(ROMSToolsMixins):
259
262
  ds[var].attrs["units"] = d_meta[var]["units"]
260
263
 
261
264
  if self.use_coarse_grid:
262
- ds = ds.assign_coords({"lon": self.target_lon, "lat": self.target_lat})
263
265
  ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})
264
266
 
265
267
  # Preserve absolute time coordinate for readability
@@ -295,6 +297,10 @@ class SurfaceForcing(ROMSToolsMixins):
295
297
  if data.climatology:
296
298
  ds["time"].attrs["cycle_length"] = 365.25
297
299
 
300
+ variables_to_drop = ["lat_rho", "lon_rho", "lat_coarse", "lon_coarse"]
301
+ existing_vars = [var for var in variables_to_drop if var in ds]
302
+ ds = ds.drop_vars(existing_vars)
303
+
298
304
  return ds
299
305
 
300
306
  def _write_into_datatree(self, data, bgc_data, d_meta):
@@ -392,6 +398,15 @@ class SurfaceForcing(ROMSToolsMixins):
392
398
  field = ds[varname].isel(time=time).load()
393
399
  title = field.long_name
394
400
 
401
+ # assign lat / lon
402
+ if self.use_coarse_grid:
403
+ field = field.rename({"eta_rho": "eta_coarse", "xi_rho": "xi_coarse"})
404
+ field = field.where(self.grid.ds.mask_coarse)
405
+ else:
406
+ field = field.where(self.grid.ds.mask_rho)
407
+
408
+ field = field.assign_coords({"lon": self.target_lon, "lat": self.target_lat})
409
+
395
410
  # choose colorbar
396
411
  if varname in ["uwnd", "vwnd"]:
397
412
  vmax = max(field.max().values, -field.min().values)
@@ -412,7 +427,6 @@ class SurfaceForcing(ROMSToolsMixins):
412
427
  self.grid.ds,
413
428
  field=field,
414
429
  straddle=self.grid.straddle,
415
- coarse_grid=self.use_coarse_grid,
416
430
  title=title,
417
431
  kwargs=kwargs,
418
432
  c="g",
roms_tools/setup/tides.py CHANGED
@@ -15,11 +15,12 @@ from roms_tools.setup.utils import (
15
15
  interpolate_from_rho_to_u,
16
16
  interpolate_from_rho_to_v,
17
17
  )
18
+ from roms_tools.setup.mixins import ROMSToolsMixins
18
19
  import matplotlib.pyplot as plt
19
20
 
20
21
 
21
22
  @dataclass(frozen=True, kw_only=True)
22
- class TidalForcing:
23
+ class TidalForcing(ROMSToolsMixins):
23
24
  """
24
25
  Represents tidal forcing data used in ocean modeling.
25
26
 
@@ -59,30 +60,13 @@ class TidalForcing:
59
60
  ds: xr.Dataset = field(init=False, repr=False)
60
61
 
61
62
  def __post_init__(self):
62
- if "name" not in self.source.keys():
63
- raise ValueError("`source` must include a 'name'.")
64
- if "path" not in self.source.keys():
65
- raise ValueError("`source` must include a 'path'.")
66
- if self.source["name"] == "TPXO":
67
- data = TPXODataset(filename=self.source["path"])
68
- else:
69
- raise ValueError('Only "TPXO" is a valid option for source["name"].')
63
+
64
+ self._input_checks()
65
+ lon, lat, angle, straddle = super().get_target_lon_lat()
66
+
67
+ data = self._get_data()
70
68
 
71
69
  data.check_number_constituents(self.ntides)
72
- # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in longitude away from Greenwich meridian
73
- lon = self.grid.ds.lon_rho
74
- lat = self.grid.ds.lat_rho
75
- angle = self.grid.ds.angle
76
-
77
- lon = xr.where(lon > 180, lon - 360, lon)
78
- straddle = True
79
- if not self.grid.straddle and abs(lon).min() > 5:
80
- lon = xr.where(lon < 0, lon + 360, lon)
81
- straddle = False
82
-
83
- # Restrict data to relevant subdomain to achieve better performance and to avoid discontinuous longitudes introduced by converting
84
- # to a different longitude range (+- 360 degrees). Discontinues longitudes can lead to artifacts in the interpolation process that
85
- # would not be detected by the nan_check function.
86
70
  data.choose_subdomain(
87
71
  latitude_range=[lat.min().values, lat.max().values],
88
72
  longitude_range=[lon.min().values, lon.max().values],
@@ -121,56 +105,64 @@ class TidalForcing:
121
105
  method="linear",
122
106
  )
123
107
 
124
- # Rotate to grid orientation
125
- u_Re = data_vars["u_Re"] * np.cos(angle) + data_vars["v_Re"] * np.sin(angle)
126
- v_Re = data_vars["v_Re"] * np.cos(angle) - data_vars["u_Re"] * np.sin(angle)
127
- u_Im = data_vars["u_Im"] * np.cos(angle) + data_vars["v_Im"] * np.sin(angle)
128
- v_Im = data_vars["v_Im"] * np.cos(angle) - data_vars["u_Im"] * np.sin(angle)
108
+ data_vars = super().process_velocities(
109
+ data_vars, angle, "u_Re", "v_Re", interpolate=False
110
+ )
111
+ data_vars = super().process_velocities(
112
+ data_vars, angle, "u_Im", "v_Im", interpolate=False
113
+ )
129
114
 
130
115
  # Convert to barotropic velocity
131
- u_Re = u_Re / self.grid.ds.h
132
- v_Re = v_Re / self.grid.ds.h
133
- u_Im = u_Im / self.grid.ds.h
134
- v_Im = v_Im / self.grid.ds.h
116
+ for varname in ["u_Re", "v_Re", "u_Im", "v_Im"]:
117
+ data_vars[varname] = data_vars[varname] / self.grid.ds.h
135
118
 
136
119
  # Interpolate from rho- to velocity points
137
- u_Re = interpolate_from_rho_to_u(u_Re)
138
- v_Re = interpolate_from_rho_to_v(v_Re)
139
- u_Im = interpolate_from_rho_to_u(u_Im)
140
- v_Im = interpolate_from_rho_to_v(v_Im)
120
+ for uname in ["u_Re", "u_Im"]:
121
+ data_vars[uname] = interpolate_from_rho_to_u(data_vars[uname])
122
+ for vname in ["v_Re", "v_Im"]:
123
+ data_vars[vname] = interpolate_from_rho_to_v(data_vars[vname])
124
+
125
+ d_meta = super().get_variable_metadata()
126
+ ds = self._write_into_dataset(data_vars, d_meta)
127
+ ds["omega"] = tides["omega"]
128
+
129
+ ds = self._add_global_metadata(ds)
130
+
131
+ for var in ["ssh_Re", "u_Re", "v_Im"]:
132
+ nan_check(ds[var].isel(ntides=0), self.grid.ds.mask_rho)
133
+
134
+ object.__setattr__(self, "ds", ds)
135
+
136
+ def _input_checks(self):
137
+
138
+ if "name" not in self.source.keys():
139
+ raise ValueError("`source` must include a 'name'.")
140
+ if "path" not in self.source.keys():
141
+ raise ValueError("`source` must include a 'path'.")
142
+
143
+ def _get_data(self):
144
+
145
+ if self.source["name"] == "TPXO":
146
+ data = TPXODataset(filename=self.source["path"])
147
+ else:
148
+ raise ValueError('Only "TPXO" is a valid option for source["name"].')
149
+ return data
150
+
151
+ def _write_into_dataset(self, data_vars, d_meta):
141
152
 
142
153
  # save in new dataset
143
154
  ds = xr.Dataset()
144
155
 
145
- # ds["omega"] = tides["omega"]
146
-
147
- ds["ssh_Re"] = data_vars["ssh_Re"].astype(np.float32)
148
- ds["ssh_Im"] = data_vars["ssh_Im"].astype(np.float32)
149
- ds["ssh_Re"].attrs["long_name"] = "Tidal elevation, real part"
150
- ds["ssh_Im"].attrs["long_name"] = "Tidal elevation, complex part"
151
- ds["ssh_Re"].attrs["units"] = "m"
152
- ds["ssh_Im"].attrs["units"] = "m"
153
-
154
- ds["pot_Re"] = data_vars["pot_Re"].astype(np.float32)
155
- ds["pot_Im"] = data_vars["pot_Im"].astype(np.float32)
156
- ds["pot_Re"].attrs["long_name"] = "Tidal potential, real part"
157
- ds["pot_Im"].attrs["long_name"] = "Tidal potential, complex part"
158
- ds["pot_Re"].attrs["units"] = "m"
159
- ds["pot_Im"].attrs["units"] = "m"
160
-
161
- ds["u_Re"] = u_Re.astype(np.float32)
162
- ds["u_Im"] = u_Im.astype(np.float32)
163
- ds["u_Re"].attrs["long_name"] = "Tidal velocity in x-direction, real part"
164
- ds["u_Im"].attrs["long_name"] = "Tidal velocity in x-direction, complex part"
165
- ds["u_Re"].attrs["units"] = "m/s"
166
- ds["u_Im"].attrs["units"] = "m/s"
167
-
168
- ds["v_Re"] = v_Re.astype(np.float32)
169
- ds["v_Im"] = v_Im.astype(np.float32)
170
- ds["v_Re"].attrs["long_name"] = "Tidal velocity in y-direction, real part"
171
- ds["v_Im"].attrs["long_name"] = "Tidal velocity in y-direction, complex part"
172
- ds["v_Re"].attrs["units"] = "m/s"
173
- ds["v_Im"].attrs["units"] = "m/s"
156
+ for var in data_vars.keys():
157
+ ds[var] = data_vars[var].astype(np.float32)
158
+ ds[var].attrs["long_name"] = d_meta[var]["long_name"]
159
+ ds[var].attrs["units"] = d_meta[var]["units"]
160
+
161
+ ds = ds.drop_vars(["lat_rho", "lon_rho"])
162
+
163
+ return ds
164
+
165
+ def _add_global_metadata(self, ds):
174
166
 
175
167
  ds.attrs["title"] = "ROMS tidal forcing created by ROMS-Tools"
176
168
  # Include the version of roms-tools
@@ -185,10 +177,7 @@ class TidalForcing:
185
177
  ds.attrs["model_reference_date"] = str(self.model_reference_date)
186
178
  ds.attrs["allan_factor"] = self.allan_factor
187
179
 
188
- object.__setattr__(self, "ds", ds)
189
-
190
- for var in ["ssh_Re", "u_Re", "v_Im"]:
191
- nan_check(self.ds[var].isel(ntides=0), self.grid.ds.mask_rho)
180
+ return ds
192
181
 
193
182
  def plot(self, varname, ntides=0) -> None:
194
183
  """
@@ -228,6 +217,25 @@ class TidalForcing:
228
217
  """
229
218
 
230
219
  field = self.ds[varname].isel(ntides=ntides).compute()
220
+ if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
221
+ field = field.where(self.grid.ds.mask_rho)
222
+ field = field.assign_coords(
223
+ {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
224
+ )
225
+
226
+ elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
227
+ field = field.where(self.grid.ds.mask_u)
228
+ field = field.assign_coords(
229
+ {"lon": self.grid.ds.lon_u, "lat": self.grid.ds.lat_u}
230
+ )
231
+
232
+ elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
233
+ field = field.where(self.grid.ds.mask_v)
234
+ field = field.assign_coords(
235
+ {"lon": self.grid.ds.lon_v, "lat": self.grid.ds.lat_v}
236
+ )
237
+ else:
238
+ ValueError("provided field does not have two horizontal dimension")
231
239
 
232
240
  title = "%s, ntides = %i" % (field.long_name, self.ds[varname].ntides[ntides])
233
241
 
@@ -114,4 +114,9 @@ def compute_depth(zeta, h, hc, cs, sigma):
114
114
  s = (hc * sigma + h * cs) / (hc + h)
115
115
  z = zeta + (zeta + h) * s
116
116
 
117
+ if "s_rho" in z.dims:
118
+ z = z.transpose("s_rho", "eta_rho", "xi_rho")
119
+ elif "s_w" in z.dims:
120
+ z = z.transpose("s_w", "eta_rho", "xi_rho")
121
+
117
122
  return z