roms-tools 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/_version.py CHANGED
@@ -1,2 +1,2 @@
1
1
  # Do not change! Do not track in version control!
2
- __version__ = "1.1.0"
2
+ __version__ = "1.2.0"
@@ -101,7 +101,7 @@ class BoundaryForcing(ROMSToolsMixins):
101
101
  vars_2d = ["zeta"]
102
102
  vars_3d = ["temp", "salt", "u", "v"]
103
103
  data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
104
- data_vars = super().process_velocities(data_vars, angle)
104
+ data_vars = super().process_velocities(data_vars, angle, "u", "v")
105
105
  object.__setattr__(data, "data_vars", data_vars)
106
106
 
107
107
  if self.bgc_source is not None:
@@ -121,9 +121,9 @@ class BoundaryForcing(ROMSToolsMixins):
121
121
  bgc_data = None
122
122
 
123
123
  d_meta = super().get_variable_metadata()
124
- bdry_coords, rename = super().get_boundary_info()
124
+ bdry_coords = super().get_boundary_info()
125
125
 
126
- ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords, rename)
126
+ ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords)
127
127
 
128
128
  for direction in ["south", "east", "north", "west"]:
129
129
  if self.boundaries[direction]:
@@ -202,7 +202,7 @@ class BoundaryForcing(ROMSToolsMixins):
202
202
 
203
203
  return data
204
204
 
205
- def _write_into_dataset(self, data, d_meta, bdry_coords, rename):
205
+ def _write_into_dataset(self, data, d_meta, bdry_coords):
206
206
 
207
207
  # save in new dataset
208
208
  ds = xr.Dataset()
@@ -215,21 +215,18 @@ class BoundaryForcing(ROMSToolsMixins):
215
215
  ds[f"{var}_{direction}"] = (
216
216
  data.data_vars[var]
217
217
  .isel(**bdry_coords["u"][direction])
218
- .rename(**rename["u"][direction])
219
218
  .astype(np.float32)
220
219
  )
221
220
  elif var in ["v", "vbar"]:
222
221
  ds[f"{var}_{direction}"] = (
223
222
  data.data_vars[var]
224
223
  .isel(**bdry_coords["v"][direction])
225
- .rename(**rename["v"][direction])
226
224
  .astype(np.float32)
227
225
  )
228
226
  else:
229
227
  ds[f"{var}_{direction}"] = (
230
228
  data.data_vars[var]
231
229
  .isel(**bdry_coords["rho"][direction])
232
- .rename(**rename["rho"][direction])
233
230
  .astype(np.float32)
234
231
  )
235
232
  ds[f"{var}_{direction}"].attrs[
@@ -288,7 +285,7 @@ class BoundaryForcing(ROMSToolsMixins):
288
285
 
289
286
  return ds
290
287
 
291
- def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords, rename):
288
+ def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords):
292
289
 
293
290
  ds = self._add_global_metadata()
294
291
  ds["sc_r"] = self.grid.ds["sc_r"]
@@ -296,119 +293,55 @@ class BoundaryForcing(ROMSToolsMixins):
296
293
 
297
294
  ds = DataTree(name="root", data=ds)
298
295
 
299
- ds_physics = self._write_into_dataset(data, d_meta, bdry_coords, rename)
300
- ds_physics = self._add_coordinates(bdry_coords, rename, ds_physics)
296
+ ds_physics = self._write_into_dataset(data, d_meta, bdry_coords)
301
297
  ds_physics = self._add_global_metadata(ds_physics)
302
298
  ds_physics.attrs["physics_source"] = self.physics_source["name"]
303
299
 
304
300
  ds_physics = DataTree(name="physics", parent=ds, data=ds_physics)
305
301
 
306
302
  if bgc_data:
307
- ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords, rename)
308
- ds_bgc = self._add_coordinates(bdry_coords, rename, ds_bgc)
303
+ ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords)
309
304
  ds_bgc = self._add_global_metadata(ds_bgc)
310
305
  ds_bgc.attrs["bgc_source"] = self.bgc_source["name"]
311
306
  ds_bgc = DataTree(name="bgc", parent=ds, data=ds_bgc)
312
307
 
313
308
  return ds
314
309
 
315
- def _add_coordinates(self, bdry_coords, rename, ds=None):
316
-
317
- if ds is None:
318
- ds = xr.Dataset()
319
-
320
- for direction in ["south", "east", "north", "west"]:
321
-
322
- if self.boundaries[direction]:
310
+ def _get_coordinates(self, direction, point):
311
+ """
312
+ Retrieve layer and interface depth coordinates for a specified grid boundary.
323
313
 
324
- lat_rho = self.grid.ds.lat_rho.isel(
325
- **bdry_coords["rho"][direction]
326
- ).rename(**rename["rho"][direction])
327
- lon_rho = self.grid.ds.lon_rho.isel(
328
- **bdry_coords["rho"][direction]
329
- ).rename(**rename["rho"][direction])
330
- layer_depth_rho = (
331
- self.grid.ds["layer_depth_rho"]
332
- .isel(**bdry_coords["rho"][direction])
333
- .rename(**rename["rho"][direction])
334
- )
335
- interface_depth_rho = (
336
- self.grid.ds["interface_depth_rho"]
337
- .isel(**bdry_coords["rho"][direction])
338
- .rename(**rename["rho"][direction])
339
- )
314
+ This method extracts the layer depth and interface depth coordinates along
315
+ a specified boundary (north, south, east, or west) and for a specified point
316
+ type (rho, u, or v) from the grid dataset.
340
317
 
341
- lat_u = self.grid.ds.lat_u.isel(**bdry_coords["u"][direction]).rename(
342
- **rename["u"][direction]
343
- )
344
- lon_u = self.grid.ds.lon_u.isel(**bdry_coords["u"][direction]).rename(
345
- **rename["u"][direction]
346
- )
347
- layer_depth_u = (
348
- self.grid.ds["layer_depth_u"]
349
- .isel(**bdry_coords["u"][direction])
350
- .rename(**rename["u"][direction])
351
- )
352
- interface_depth_u = (
353
- self.grid.ds["interface_depth_u"]
354
- .isel(**bdry_coords["u"][direction])
355
- .rename(**rename["u"][direction])
356
- )
318
+ Parameters
319
+ ----------
320
+ direction : str
321
+ The direction of the boundary to retrieve coordinates for. Valid options
322
+ are "north", "south", "east", and "west".
323
+ point : str
324
+ The type of grid point to retrieve coordinates for. Valid options are
325
+ "rho" for the grid's central points, "u" for the u-flux points, and "v"
326
+ for the v-flux points.
357
327
 
358
- lat_v = self.grid.ds.lat_v.isel(**bdry_coords["v"][direction]).rename(
359
- **rename["v"][direction]
360
- )
361
- lon_v = self.grid.ds.lon_v.isel(**bdry_coords["v"][direction]).rename(
362
- **rename["v"][direction]
363
- )
364
- layer_depth_v = (
365
- self.grid.ds["layer_depth_v"]
366
- .isel(**bdry_coords["v"][direction])
367
- .rename(**rename["v"][direction])
368
- )
369
- interface_depth_v = (
370
- self.grid.ds["interface_depth_v"]
371
- .isel(**bdry_coords["v"][direction])
372
- .rename(**rename["v"][direction])
373
- )
328
+ Returns
329
+ -------
330
+ xarray.DataArray, xarray.DataArray
331
+ The layer depth and interface depth coordinates for the specified grid
332
+ boundary and point type.
333
+ """
374
334
 
375
- ds = ds.assign_coords(
376
- {
377
- f"layer_depth_rho_{direction}": layer_depth_rho,
378
- f"layer_depth_u_{direction}": layer_depth_u,
379
- f"layer_depth_v_{direction}": layer_depth_v,
380
- f"interface_depth_rho_{direction}": interface_depth_rho,
381
- f"interface_depth_u_{direction}": interface_depth_u,
382
- f"interface_depth_v_{direction}": interface_depth_v,
383
- f"lat_rho_{direction}": lat_rho,
384
- f"lat_u_{direction}": lat_u,
385
- f"lat_v_{direction}": lat_v,
386
- f"lon_rho_{direction}": lon_rho,
387
- f"lon_u_{direction}": lon_u,
388
- f"lon_v_{direction}": lon_v,
389
- }
390
- )
335
+ bdry_coords = super().get_boundary_info()
391
336
 
392
- # Gracefully handle dropping variables that might not be present
393
- variables_to_drop = [
394
- "s_rho",
395
- "layer_depth_rho",
396
- "layer_depth_u",
397
- "layer_depth_v",
398
- "interface_depth_rho",
399
- "interface_depth_u",
400
- "interface_depth_v",
401
- "lat_rho",
402
- "lon_rho",
403
- "lat_u",
404
- "lon_u",
405
- "lat_v",
406
- "lon_v",
407
- ]
408
- existing_vars = [var for var in variables_to_drop if var in ds]
409
- ds = ds.drop_vars(existing_vars)
337
+ layer_depth = self.grid.ds[f"layer_depth_{point}"].isel(
338
+ **bdry_coords[point][direction]
339
+ )
340
+ interface_depth = self.grid.ds[f"interface_depth_{point}"].isel(
341
+ **bdry_coords[point][direction]
342
+ )
410
343
 
411
- return ds
344
+ return layer_depth, interface_depth
412
345
 
413
346
  def _add_global_metadata(self, ds=None):
414
347
 
@@ -486,8 +419,9 @@ class BoundaryForcing(ROMSToolsMixins):
486
419
  time : int, optional
487
420
  The time index to plot. Default is 0.
488
421
  layer_contours : bool, optional
489
- Whether to include layer contours in the plot. This can help visualize the depth levels
490
- of the field. Default is False.
422
+ If True, contour lines representing the boundaries between vertical layers will
423
+ be added to the plot. For clarity, the number of layer
424
+ contours displayed is limited to a maximum of 10. Default is False.
491
425
 
492
426
  Returns
493
427
  -------
@@ -513,6 +447,19 @@ class BoundaryForcing(ROMSToolsMixins):
513
447
  field = ds[varname].isel(bry_time=time).load()
514
448
  title = field.long_name
515
449
 
450
+ if "s_rho" in field.dims:
451
+ if varname.startswith(("u_", "ubar_")):
452
+ point = "u"
453
+ elif varname.startswith(("v_", "vbar_")):
454
+ point = "v"
455
+ else:
456
+ point = "rho"
457
+ direction = varname.split("_")[-1]
458
+
459
+ layer_depth, interface_depth = self._get_coordinates(direction, point)
460
+
461
+ field = field.assign_coords({"layer_depth": layer_depth})
462
+
516
463
  # choose colorbar
517
464
  if varname.startswith(("u", "v", "ubar", "vbar", "zeta")):
518
465
  vmax = max(field.max().values, -field.min().values)
@@ -530,27 +477,6 @@ class BoundaryForcing(ROMSToolsMixins):
530
477
 
531
478
  if len(field.dims) == 2:
532
479
  if layer_contours:
533
- depths_to_check = [
534
- "interface_depth_rho",
535
- "interface_depth_u",
536
- "interface_depth_v",
537
- ]
538
- try:
539
- interface_depth = next(
540
- ds[depth_label]
541
- for depth_label in ds.coords
542
- if any(
543
- depth_label.startswith(prefix) for prefix in depths_to_check
544
- )
545
- and (
546
- set(ds[depth_label].dims) - {"s_w"}
547
- == set(field.dims) - {"s_rho"}
548
- )
549
- )
550
- except StopIteration:
551
- raise ValueError(
552
- f"None of the expected depths ({', '.join(depths_to_check)}) have dimensions matching field.dims"
553
- )
554
480
  # restrict number of layer_contours to 10 for the sake of plot clarity
555
481
  nr_layers = len(interface_depth["s_w"])
556
482
  selected_layers = np.linspace(
@@ -584,9 +584,14 @@ class TPXODataset(Dataset):
584
584
  "ntides": self.dim_names["ntides"],
585
585
  },
586
586
  )
587
+ self.check_dataset(ds)
588
+
587
589
  # Select relevant fields
588
590
  ds = super().select_relevant_fields(ds)
589
591
 
592
+ # Make sure that latitude is ascending
593
+ ds = super().ensure_latitude_ascending(ds)
594
+
590
595
  # Check whether the data covers the entire globe
591
596
  object.__setattr__(self, "is_global", super().check_if_global(ds))
592
597
 
@@ -769,6 +774,8 @@ class CESMDataset(Dataset):
769
774
  ds = assign_dates_to_climatology(ds, time_dim)
770
775
  # rename dimension
771
776
  ds = ds.swap_dims({time_dim: "time"})
777
+ if time_dim in ds.variables:
778
+ ds = ds.drop_vars(time_dim)
772
779
  # Update dimension names
773
780
  updated_dim_names = self.dim_names.copy()
774
781
  updated_dim_names["time"] = "time"
@@ -872,9 +879,9 @@ class CESMBGCDataset(CESMDataset):
872
879
  ds["depth"].attrs["long_name"] = "Depth"
873
880
  ds["depth"].attrs["units"] = "m"
874
881
  ds = ds.swap_dims({"z_t": "depth"})
875
- if "z_t" in ds:
882
+ if "z_t" in ds.variables:
876
883
  ds = ds.drop_vars("z_t")
877
- if "z_t_150m" in ds:
884
+ if "z_t_150m" in ds.variables:
878
885
  ds = ds.drop_vars("z_t_150m")
879
886
  # update dataset
880
887
  object.__setattr__(self, "ds", ds)
@@ -932,6 +939,19 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
932
939
 
933
940
  climatology: Optional[bool] = False
934
941
 
942
+ def post_process(self):
943
+ """
944
+ Perform post-processing on the dataset to remove specific variables.
945
+
946
+ This method checks if the variable "z_t" exists in the dataset. If it does,
947
+ the variable is removed from the dataset. The modified dataset is then
948
+ reassigned to the `ds` attribute of the object.
949
+ """
950
+
951
+ if "z_t" in self.ds.variables:
952
+ ds = self.ds.drop_vars("z_t")
953
+ object.__setattr__(self, "ds", ds)
954
+
935
955
 
936
956
  @dataclass(frozen=True, kw_only=True)
937
957
  class ERA5Dataset(Dataset):
roms_tools/setup/grid.py CHANGED
@@ -361,11 +361,20 @@ class Grid:
361
361
  """
362
362
 
363
363
  if bathymetry:
364
- kwargs = {"cmap": "YlGnBu"}
364
+ field = self.ds.h.where(self.ds.mask_rho)
365
+ field = field.assign_coords(
366
+ {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
367
+ )
368
+
369
+ vmax = field.max().values
370
+ vmin = field.min().values
371
+ cmap = plt.colormaps.get_cmap("YlGnBu")
372
+ cmap.set_bad(color="gray")
373
+ kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
365
374
 
366
375
  _plot(
367
376
  self.ds,
368
- field=self.ds.h.where(self.ds.mask_rho),
377
+ field=field,
369
378
  straddle=self.straddle,
370
379
  kwargs=kwargs,
371
380
  )
@@ -419,10 +428,18 @@ class Grid:
419
428
 
420
429
  if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
421
430
  interface_depth = self.ds.interface_depth_rho
431
+ field = field.where(self.ds.mask_rho)
432
+ field = field.assign_coords(
433
+ {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
434
+ )
422
435
  elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
423
436
  interface_depth = self.ds.interface_depth_u
437
+ field = field.where(self.ds.mask_u)
438
+ field = field.assign_coords({"lon": self.ds.lon_u, "lat": self.ds.lat_u})
424
439
  elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
425
440
  interface_depth = self.ds.interface_depth_v
441
+ field = field.where(self.ds.mask_v)
442
+ field = field.assign_coords({"lon": self.ds.lon_v, "lat": self.ds.lat_v})
426
443
 
427
444
  # slice the field as desired
428
445
  title = field.long_name
@@ -476,10 +493,9 @@ class Grid:
476
493
  self.ds,
477
494
  field=field,
478
495
  straddle=self.straddle,
479
- depth_contours=True,
496
+ depth_contours=False,
480
497
  title=title,
481
498
  kwargs=kwargs,
482
- c="g",
483
499
  )
484
500
  else:
485
501
  if len(field.dims) == 2:
@@ -82,7 +82,7 @@ class InitialConditions(ROMSToolsMixins):
82
82
  vars_2d = ["zeta"]
83
83
  vars_3d = ["temp", "salt", "u", "v"]
84
84
  data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
85
- data_vars = super().process_velocities(data_vars, angle)
85
+ data_vars = super().process_velocities(data_vars, angle, "u", "v")
86
86
 
87
87
  if self.bgc_source is not None:
88
88
  bgc_data = self._get_bgc_data()
@@ -172,18 +172,18 @@ class InitialConditions(ROMSToolsMixins):
172
172
 
173
173
  if self.bgc_source["name"] == "CESM_REGRIDDED":
174
174
 
175
- bgc_data = CESMBGCDataset(
175
+ data = CESMBGCDataset(
176
176
  filename=self.bgc_source["path"],
177
177
  start_time=self.ini_time,
178
178
  climatology=self.bgc_source["climatology"],
179
179
  )
180
- bgc_data.post_process()
180
+ data.post_process()
181
181
  else:
182
182
  raise ValueError(
183
183
  'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
184
184
  )
185
185
 
186
- return bgc_data
186
+ return data
187
187
 
188
188
  def _write_into_dataset(self, data_vars, d_meta):
189
189
 
@@ -202,19 +202,43 @@ class InitialConditions(ROMSToolsMixins):
202
202
  ds["w"].attrs["long_name"] = d_meta["w"]["long_name"]
203
203
  ds["w"].attrs["units"] = d_meta["w"]["units"]
204
204
 
205
+ variables_to_drop = [
206
+ "s_rho",
207
+ "lat_rho",
208
+ "lon_rho",
209
+ "layer_depth_rho",
210
+ "interface_depth_rho",
211
+ "lat_u",
212
+ "lon_u",
213
+ "lat_v",
214
+ "lon_v",
215
+ ]
216
+ existing_vars = [var for var in variables_to_drop if var in ds]
217
+ ds = ds.drop_vars(existing_vars)
218
+
219
+ ds["sc_r"] = self.grid.ds["sc_r"]
220
+ ds["Cs_r"] = self.grid.ds["Cs_r"]
221
+
222
+ # Preserve absolute time coordinate for readability
223
+ ds = ds.assign_coords({"abs_time": ds["time"]})
224
+
225
+ # Translate the time coordinate to seconds since the model reference date
226
+ model_reference_date = np.datetime64(self.model_reference_date)
227
+
228
+ # Convert the time coordinate to the format expected by ROMS (seconds since model reference date)
229
+ ocean_time = (ds["time"] - model_reference_date).astype("float64") * 1e-9
230
+ ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
231
+ ds["ocean_time"].attrs[
232
+ "long_name"
233
+ ] = f"seconds since {np.datetime_as_string(model_reference_date, unit='s')}"
234
+ ds["ocean_time"].attrs["units"] = "seconds"
235
+ ds = ds.swap_dims({"time": "ocean_time"})
236
+ ds = ds.drop_vars("time")
237
+
205
238
  return ds
206
239
 
207
240
  def _add_global_metadata(self, ds):
208
241
 
209
- ds = ds.assign_coords(
210
- {
211
- "layer_depth_u": self.grid.ds["layer_depth_u"],
212
- "layer_depth_v": self.grid.ds["layer_depth_v"],
213
- "interface_depth_u": self.grid.ds["interface_depth_u"],
214
- "interface_depth_v": self.grid.ds["interface_depth_v"],
215
- }
216
- )
217
-
218
242
  ds.attrs["title"] = "ROMS initial conditions file created by ROMS-Tools"
219
243
  # Include the version of roms-tools
220
244
  try:
@@ -228,24 +252,9 @@ class InitialConditions(ROMSToolsMixins):
228
252
  if self.bgc_source is not None:
229
253
  ds.attrs["bgc_source"] = self.bgc_source["name"]
230
254
 
231
- # Translate the time coordinate to days since the model reference date
232
- model_reference_date = np.datetime64(self.model_reference_date)
233
-
234
- # Convert the time coordinate to the format expected by ROMS (days since model reference date)
235
- ocean_time = (ds["time"] - model_reference_date).astype("float64") * 1e-9
236
- ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
237
- ds["ocean_time"].attrs[
238
- "long_name"
239
- ] = f"seconds since {np.datetime_as_string(model_reference_date, unit='s')}"
240
- ds["ocean_time"].attrs["units"] = "seconds"
241
-
242
255
  ds.attrs["theta_s"] = self.grid.ds.attrs["theta_s"]
243
256
  ds.attrs["theta_b"] = self.grid.ds.attrs["theta_b"]
244
257
  ds.attrs["hc"] = self.grid.ds.attrs["hc"]
245
- ds["sc_r"] = self.grid.ds["sc_r"]
246
- ds["Cs_r"] = self.grid.ds["Cs_r"]
247
-
248
- ds = ds.drop_vars(["s_rho"])
249
258
 
250
259
  return ds
251
260
 
@@ -306,13 +315,23 @@ class InitialConditions(ROMSToolsMixins):
306
315
  - "diazP": Diazotroph Phosphorus (mmol/m³).
307
316
  - "diazFe": Diazotroph Iron (mmol/m³).
308
317
  s : int, optional
309
- The index of the vertical layer to plot. Default is None.
318
+ The index of the vertical layer (`s_rho`) to plot. If not specified, the plot
319
+ will represent a horizontal slice (eta- or xi- plane). Default is None.
310
320
  eta : int, optional
311
- The eta-index to plot. Default is None.
321
+ The eta-index to plot. Used for vertical sections or horizontal slices.
322
+ Default is None.
312
323
  xi : int, optional
313
- The xi-index to plot. Default is None.
324
+ The xi-index to plot. Used for vertical sections or horizontal slices.
325
+ Default is None.
314
326
  depth_contours : bool, optional
315
- Whether to include depth contours in the plot. Default is False.
327
+ If True, depth contours will be overlaid on the plot, showing lines of constant
328
+ depth. This is typically used for plots that show a single vertical layer.
329
+ Default is False.
330
+ layer_contours : bool, optional
331
+ If True, contour lines representing the boundaries between vertical layers will
332
+ be added to the plot. This is particularly useful in vertical sections to
333
+ visualize the layering of the water column. For clarity, the number of layer
334
+ contours displayed is limited to a maximum of 10. Default is False.
316
335
 
317
336
  Returns
318
337
  -------
@@ -325,6 +344,7 @@ class InitialConditions(ROMSToolsMixins):
325
344
  If the specified `varname` is not one of the valid options.
326
345
  If the field specified by `varname` is 3D and none of `s`, `eta`, or `xi` are specified.
327
346
  If the field specified by `varname` is 2D and both `eta` and `xi` are specified.
347
+
328
348
  """
329
349
 
330
350
  if len(self.ds[varname].squeeze().dims) == 3 and not any(
@@ -343,17 +363,38 @@ class InitialConditions(ROMSToolsMixins):
343
363
  field = self.ds[varname].squeeze()
344
364
 
345
365
  if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
346
- interface_depth = self.ds.interface_depth_rho
366
+ interface_depth = self.grid.ds.interface_depth_rho
367
+ layer_depth = self.grid.ds.layer_depth_rho
368
+ field = field.where(self.grid.ds.mask_rho)
369
+ field = field.assign_coords(
370
+ {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
371
+ )
372
+
347
373
  elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
348
- interface_depth = self.ds.interface_depth_u
374
+ interface_depth = self.grid.ds.interface_depth_u
375
+ layer_depth = self.grid.ds.layer_depth_u
376
+ field = field.where(self.grid.ds.mask_u)
377
+ field = field.assign_coords(
378
+ {"lon": self.grid.ds.lon_u, "lat": self.grid.ds.lat_u}
379
+ )
380
+
349
381
  elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
350
- interface_depth = self.ds.interface_depth_v
382
+ interface_depth = self.grid.ds.interface_depth_v
383
+ layer_depth = self.grid.ds.layer_depth_v
384
+ field = field.where(self.grid.ds.mask_v)
385
+ field = field.assign_coords(
386
+ {"lon": self.grid.ds.lon_v, "lat": self.grid.ds.lat_v}
387
+ )
388
+ else:
389
+ ValueError("provided field does not have two horizontal dimension")
351
390
 
352
391
  # slice the field as desired
353
392
  title = field.long_name
354
393
  if s is not None:
355
394
  title = title + f", s_rho = {field.s_rho[s].item()}"
356
395
  field = field.isel(s_rho=s)
396
+ layer_depth = layer_depth.isel(s_rho=s)
397
+ field = field.assign_coords({"layer_depth": layer_depth})
357
398
  else:
358
399
  depth_contours = False
359
400
 
@@ -361,10 +402,14 @@ class InitialConditions(ROMSToolsMixins):
361
402
  if "eta_rho" in field.dims:
362
403
  title = title + f", eta_rho = {field.eta_rho[eta].item()}"
363
404
  field = field.isel(eta_rho=eta)
405
+ layer_depth = layer_depth.isel(eta_rho=eta)
406
+ field = field.assign_coords({"layer_depth": layer_depth})
364
407
  interface_depth = interface_depth.isel(eta_rho=eta)
365
408
  elif "eta_v" in field.dims:
366
409
  title = title + f", eta_v = {field.eta_v[eta].item()}"
367
410
  field = field.isel(eta_v=eta)
411
+ layer_depth = layer_depth.isel(eta_v=eta)
412
+ field = field.assign_coords({"layer_depth": layer_depth})
368
413
  interface_depth = interface_depth.isel(eta_v=eta)
369
414
  else:
370
415
  raise ValueError(
@@ -374,10 +419,14 @@ class InitialConditions(ROMSToolsMixins):
374
419
  if "xi_rho" in field.dims:
375
420
  title = title + f", xi_rho = {field.xi_rho[xi].item()}"
376
421
  field = field.isel(xi_rho=xi)
422
+ layer_depth = layer_depth.isel(xi_rho=xi)
423
+ field = field.assign_coords({"layer_depth": layer_depth})
377
424
  interface_depth = interface_depth.isel(xi_rho=xi)
378
425
  elif "xi_u" in field.dims:
379
426
  title = title + f", xi_u = {field.xi_u[xi].item()}"
380
427
  field = field.isel(xi_u=xi)
428
+ layer_depth = layer_depth.isel(xi_u=xi)
429
+ field = field.assign_coords({"layer_depth": layer_depth})
381
430
  interface_depth = interface_depth.isel(xi_u=xi)
382
431
  else:
383
432
  raise ValueError(