roms-tools 2.2.1__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +586 -0
  3. roms_tools/{setup/download.py → download.py} +3 -0
  4. roms_tools/{setup/plot.py → plot.py} +34 -28
  5. roms_tools/setup/boundary_forcing.py +23 -12
  6. roms_tools/setup/datasets.py +2 -135
  7. roms_tools/setup/grid.py +54 -15
  8. roms_tools/setup/initial_conditions.py +105 -149
  9. roms_tools/setup/nesting.py +4 -4
  10. roms_tools/setup/river_forcing.py +7 -9
  11. roms_tools/setup/surface_forcing.py +14 -14
  12. roms_tools/setup/tides.py +24 -21
  13. roms_tools/setup/topography.py +1 -1
  14. roms_tools/setup/utils.py +20 -154
  15. roms_tools/tests/test_analysis/test_roms_output.py +269 -0
  16. roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
  17. roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
  18. roms_tools/tests/test_setup/test_datasets.py +1 -1
  19. roms_tools/tests/test_setup/test_grid.py +1 -1
  20. roms_tools/tests/test_setup/test_initial_conditions.py +1 -1
  21. roms_tools/tests/test_setup/test_river_forcing.py +1 -1
  22. roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
  23. roms_tools/tests/test_setup/test_tides.py +1 -1
  24. roms_tools/tests/test_setup/test_topography.py +1 -1
  25. roms_tools/tests/test_setup/test_utils.py +56 -1
  26. roms_tools/utils.py +301 -0
  27. roms_tools/vertical_coordinate.py +306 -0
  28. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
  29. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
  30. roms_tools/setup/vertical_coordinate.py +0 -109
  31. /roms_tools/{setup/regrid.py → regrid.py} +0 -0
  32. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
  33. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
  34. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,14 @@ import numpy as np
3
3
  import importlib.metadata
4
4
  from dataclasses import dataclass, field
5
5
  from typing import Dict, Union, List, Optional
6
- from roms_tools.setup.grid import Grid
6
+ import matplotlib.pyplot as plt
7
+ from pathlib import Path
7
8
  from datetime import datetime
9
+ from roms_tools import Grid
10
+ from roms_tools.regrid import LateralRegrid, VerticalRegrid
11
+ from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
12
+ from roms_tools.utils import transpose_dimensions
8
13
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
9
- from roms_tools.setup.vertical_coordinate import compute_depth
10
14
  from roms_tools.setup.utils import (
11
15
  nan_check,
12
16
  substitute_nans_by_fillvalue,
@@ -15,16 +19,9 @@ from roms_tools.setup.utils import (
15
19
  get_target_coords,
16
20
  rotate_velocities,
17
21
  compute_barotropic_velocity,
18
- transpose_dimensions,
19
- interpolate_from_rho_to_u,
20
- interpolate_from_rho_to_v,
21
22
  _to_yaml,
22
23
  _from_yaml,
23
24
  )
24
- from roms_tools.setup.regrid import LateralRegrid, VerticalRegrid
25
- from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
26
- import matplotlib.pyplot as plt
27
- from pathlib import Path
28
25
 
29
26
 
30
27
  @dataclass(frozen=True, kw_only=True)
@@ -396,54 +393,7 @@ class InitialConditions:
396
393
  - f"{type}_depth_v": Depth coordinates at v points (if applicable).
397
394
  """
398
395
 
399
- layer_vars = []
400
- for location in ["rho"] + additional_locations:
401
- layer_vars.append(f"{type}_depth_{location}")
402
-
403
- if all(layer_var in self.grid.ds for layer_var in layer_vars):
404
- # Vertical coordinate data already exists
405
- pass
406
-
407
- elif f"{type}_depth_rho" in self.grid.ds:
408
- depth = self.grid.ds[f"{type}_depth_rho"]
409
-
410
- if "u" in additional_locations or "v" in additional_locations:
411
- # interpolation
412
- if "u" in additional_locations:
413
- depth_u = interpolate_from_rho_to_u(depth)
414
- depth_u.attrs["long_name"] = f"{type} depth at u-points"
415
- depth_u.attrs["units"] = "m"
416
- self.grid.ds[f"{type}_depth_u"] = depth_u
417
- if "v" in additional_locations:
418
- depth_v = interpolate_from_rho_to_v(depth)
419
- depth_v.attrs["long_name"] = f"{type} depth at v-points"
420
- depth_v.attrs["units"] = "m"
421
- self.grid.ds[f"{type}_depth_v"] = depth_v
422
- else:
423
- h = self.grid.ds["h"]
424
- if type == "layer":
425
- depth = compute_depth(
426
- 0, h, self.grid.hc, self.grid.ds.Cs_r, self.grid.ds.sigma_r
427
- )
428
- else:
429
- depth = compute_depth(
430
- 0, h, self.grid.hc, self.grid.ds.Cs_w, self.grid.ds.sigma_w
431
- )
432
-
433
- depth.attrs["long_name"] = f"{type} depth at rho-points"
434
- depth.attrs["units"] = "m"
435
- self.grid.ds[f"{type}_depth_rho"] = depth
436
-
437
- if "u" in additional_locations or "v" in additional_locations:
438
- # interpolation
439
- depth_u = interpolate_from_rho_to_u(depth)
440
- depth_u.attrs["long_name"] = f"{type} depth at u-points"
441
- depth_u.attrs["units"] = "m"
442
- depth_v = interpolate_from_rho_to_v(depth)
443
- depth_v.attrs["long_name"] = f"{type} depth at v-points"
444
- depth_v.attrs["units"] = "m"
445
- self.grid.ds[f"{type}_depth_u"] = depth_u
446
- self.grid.ds[f"{type}_depth_v"] = depth_v
396
+ self.grid.compute_depth_coordinates(type, additional_locations)
447
397
 
448
398
  def _write_into_dataset(self, processed_fields, d_meta):
449
399
 
@@ -656,18 +606,24 @@ class InitialConditions:
656
606
  If the field specified by `var_name` is 2D and both `eta` and `xi` are specified.
657
607
  """
658
608
 
659
- if len(self.ds[var_name].squeeze().dims) == 3 and not any(
660
- [s is not None, eta is not None, xi is not None]
661
- ):
609
+ field = self.ds[var_name].squeeze()
610
+
611
+ if len(field.dims) == 3:
612
+ if not any([s is not None, eta is not None, xi is not None]):
613
+ raise ValueError(
614
+ "Invalid input: For 3D fields, you must specify at least one of the dimensions 's', 'eta', or 'xi'."
615
+ )
616
+ if all([s is not None, eta is not None, xi is not None]):
617
+ raise ValueError(
618
+ "Ambiguous input: For 3D fields, specify at most two of 's', 'eta', or 'xi'. Specifying all three is not allowed."
619
+ )
620
+
621
+ if len(field.dims) == 2 and all([eta is not None, xi is not None]):
662
622
  raise ValueError(
663
- "For 3D fields, at least one of s, eta, or xi must be specified."
623
+ "Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both."
664
624
  )
665
625
 
666
- if len(self.ds[var_name].squeeze().dims) == 2 and all(
667
- [eta is not None, xi is not None]
668
- ):
669
- raise ValueError("For 2D fields, specify either eta or xi, not both.")
670
-
626
+ # Load the data
671
627
  if self.use_dask:
672
628
  from dask.diagnostics import ProgressBar
673
629
 
@@ -675,54 +631,73 @@ class InitialConditions:
675
631
  self.ds[var_name].load()
676
632
 
677
633
  field = self.ds[var_name].squeeze()
678
- if s is not None:
679
- layer_contours = False
680
634
 
635
+ # Get correct mask and horizontal coordinates
681
636
  if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
682
- if layer_contours:
683
- if "interface_depth_rho" in self.grid.ds:
684
- interface_depth = self.grid.ds.interface_depth_rho
685
- else:
686
- self.get_vertical_coordinates(
687
- type="interface", additional_locations=[]
688
- )
689
- layer_depth = self.grid.ds.layer_depth_rho
690
- mask = self.grid.ds.mask_rho
691
- field = field.assign_coords(
692
- {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
693
- )
694
-
637
+ loc = "rho"
695
638
  elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
696
- if layer_contours:
697
- if "interface_depth_u" in self.grid.ds:
698
- interface_depth = self.grid.ds.interface_depth_u
699
- else:
700
- self.get_vertical_coordinates(
701
- type="interface", additional_locations=["u", "v"]
702
- )
703
- layer_depth = self.grid.ds.layer_depth_u
704
- mask = self.grid.ds.mask_u
705
- field = field.assign_coords(
706
- {"lon": self.grid.ds.lon_u, "lat": self.grid.ds.lat_u}
707
- )
639
+ loc = "u"
708
640
 
709
641
  elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
710
- if layer_contours:
711
- if "interface_depth_v" in self.grid.ds:
712
- interface_depth = self.grid.ds.interface_depth_v
642
+ loc = "v"
643
+ else:
644
+ ValueError("provided field does not have two horizontal dimension")
645
+
646
+ mask = self.grid.ds[f"mask_{loc}"]
647
+ lat_deg = self.grid.ds[f"lat_{loc}"]
648
+ lon_deg = self.grid.ds[f"lon_{loc}"]
649
+
650
+ if self.grid.straddle:
651
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
652
+
653
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
654
+
655
+ # Retrieve depth coordinates
656
+ if s is not None:
657
+ layer_contours = False
658
+ # Note that `layer_depth_{loc}` has already been computed during `__post_init__`.
659
+ layer_depth = self.grid.ds[f"layer_depth_{loc}"]
660
+ if layer_contours:
661
+ if f"interface_depth_{loc}" not in self.grid.ds:
662
+ if loc == "rho":
663
+ self.get_vertical_coordinates(
664
+ type="interface", additional_locations=[]
665
+ )
713
666
  else:
714
667
  self.get_vertical_coordinates(
715
668
  type="interface", additional_locations=["u", "v"]
716
669
  )
717
- layer_depth = self.grid.ds.layer_depth_v
718
- mask = self.grid.ds.mask_v
719
- field = field.assign_coords(
720
- {"lon": self.grid.ds.lon_v, "lat": self.grid.ds.lat_v}
721
- )
670
+ interface_depth = self.grid.ds[f"interface_depth_{loc}"]
722
671
  else:
723
- ValueError("provided field does not have two horizontal dimension")
672
+ interface_depth = None
673
+
674
+ # Slice the field as desired
675
+ def _slice_and_assign(
676
+ field,
677
+ mask,
678
+ layer_depth,
679
+ interface_depth,
680
+ title,
681
+ dim_name,
682
+ dim_values,
683
+ idx,
684
+ layer_contours=False,
685
+ ):
686
+ if dim_name in field.dims:
687
+ title = title + f", {dim_name} = {dim_values[idx].item()}"
688
+ field = field.isel(**{dim_name: idx})
689
+ mask = mask.isel(**{dim_name: idx})
690
+ layer_depth = layer_depth.isel(**{dim_name: idx})
691
+ if layer_contours:
692
+ interface_depth = interface_depth.isel(**{dim_name: idx})
693
+ if "s_rho" in field.dims:
694
+ field = field.assign_coords({"layer_depth": layer_depth})
695
+ else:
696
+ raise ValueError(
697
+ f"None of the expected dimensions ({dim_name}) found in field."
698
+ )
699
+ return field, mask, layer_depth, interface_depth, title
724
700
 
725
- # slice the field as desired
726
701
  title = field.long_name
727
702
  if s is not None:
728
703
  title = title + f", s_rho = {field.s_rho[s].item()}"
@@ -733,49 +708,32 @@ class InitialConditions:
733
708
  depth_contours = False
734
709
 
735
710
  if eta is not None:
736
- if "eta_rho" in field.dims:
737
- title = title + f", eta_rho = {field.eta_rho[eta].item()}"
738
- field = field.isel(eta_rho=eta)
739
- layer_depth = layer_depth.isel(eta_rho=eta)
740
- if layer_contours:
741
- interface_depth = interface_depth.isel(eta_rho=eta)
742
- if "s_rho" in field.dims:
743
- field = field.assign_coords({"layer_depth": layer_depth})
744
- elif "eta_v" in field.dims:
745
- title = title + f", eta_v = {field.eta_v[eta].item()}"
746
- field = field.isel(eta_v=eta)
747
- layer_depth = layer_depth.isel(eta_v=eta)
748
- if layer_contours:
749
- interface_depth = interface_depth.isel(eta_v=eta)
750
- if "s_rho" in field.dims:
751
- field = field.assign_coords({"layer_depth": layer_depth})
752
- else:
753
- raise ValueError(
754
- f"None of the expected dimensions (eta_rho, eta_v) found in ds[{var_name}]."
755
- )
711
+ field, mask, layer_depth, interface_depth, title = _slice_and_assign(
712
+ field,
713
+ mask,
714
+ layer_depth,
715
+ interface_depth,
716
+ title,
717
+ "eta_rho" if "eta_rho" in field.dims else "eta_v",
718
+ field.eta_rho if "eta_rho" in field.dims else field.eta_v,
719
+ eta,
720
+ layer_contours,
721
+ )
722
+
756
723
  if xi is not None:
757
- if "xi_rho" in field.dims:
758
- title = title + f", xi_rho = {field.xi_rho[xi].item()}"
759
- field = field.isel(xi_rho=xi)
760
- layer_depth = layer_depth.isel(xi_rho=xi)
761
- if layer_contours:
762
- interface_depth = interface_depth.isel(xi_rho=xi)
763
- if "s_rho" in field.dims:
764
- field = field.assign_coords({"layer_depth": layer_depth})
765
- elif "xi_u" in field.dims:
766
- title = title + f", xi_u = {field.xi_u[xi].item()}"
767
- field = field.isel(xi_u=xi)
768
- layer_depth = layer_depth.isel(xi_u=xi)
769
- if layer_contours:
770
- interface_depth = interface_depth.isel(xi_u=xi)
771
- if "s_rho" in field.dims:
772
- field = field.assign_coords({"layer_depth": layer_depth})
773
- else:
774
- raise ValueError(
775
- f"None of the expected dimensions (xi_rho, xi_u) found in ds[{var_name}]."
776
- )
724
+ field, mask, layer_depth, interface_depth, title = _slice_and_assign(
725
+ field,
726
+ mask,
727
+ layer_depth,
728
+ interface_depth,
729
+ title,
730
+ "xi_rho" if "xi_rho" in field.dims else "xi_u",
731
+ field.xi_rho if "xi_rho" in field.dims else field.xi_u,
732
+ xi,
733
+ layer_contours,
734
+ )
777
735
 
778
- # chose colorbar
736
+ # Choose colorbar
779
737
  if var_name in ["u", "v", "w", "ubar", "vbar", "zeta"]:
780
738
  vmax = max(field.max().values, -field.min().values)
781
739
  vmin = -vmax
@@ -792,9 +750,7 @@ class InitialConditions:
792
750
 
793
751
  if eta is None and xi is None:
794
752
  _plot(
795
- self.grid.ds,
796
753
  field=field.where(mask),
797
- straddle=self.grid.straddle,
798
754
  depth_contours=depth_contours,
799
755
  title=title,
800
756
  kwargs=kwargs,
@@ -813,7 +769,7 @@ class InitialConditions:
813
769
 
814
770
  if len(field.dims) == 2:
815
771
  _section_plot(
816
- field,
772
+ field.where(mask),
817
773
  interface_depth=interface_depth,
818
774
  title=title,
819
775
  kwargs=kwargs,
@@ -821,9 +777,9 @@ class InitialConditions:
821
777
  )
822
778
  else:
823
779
  if "s_rho" in field.dims:
824
- _profile_plot(field, title=title, ax=ax)
780
+ _profile_plot(field.where(mask), title=title, ax=ax)
825
781
  else:
826
- _line_plot(field, title=title, ax=ax)
782
+ _line_plot(field.where(mask), title=title, ax=ax)
827
783
 
828
784
  def save(
829
785
  self, filepath: Union[str, Path], np_eta: int = None, np_xi: int = None
@@ -4,7 +4,10 @@ from scipy.interpolate import griddata
4
4
  from dataclasses import dataclass, field
5
5
  from typing import Dict, Union
6
6
  from pathlib import Path
7
- from roms_tools.setup.grid import Grid
7
+ import logging
8
+ from scipy.interpolate import interp1d
9
+ from roms_tools import Grid
10
+ from roms_tools.plot import _plot_nesting
8
11
  from roms_tools.setup.utils import (
9
12
  interpolate_from_rho_to_u,
10
13
  interpolate_from_rho_to_v,
@@ -14,9 +17,6 @@ from roms_tools.setup.utils import (
14
17
  _to_yaml,
15
18
  _from_yaml,
16
19
  )
17
- from roms_tools.setup.plot import _plot_nesting
18
- import logging
19
- from scipy.interpolate import interp1d
20
20
 
21
21
 
22
22
  @dataclass(frozen=True, kw_only=True)
@@ -2,12 +2,14 @@ import xarray as xr
2
2
  import numpy as np
3
3
  import logging
4
4
  from dataclasses import dataclass, field
5
- from roms_tools.setup.grid import Grid
5
+ import cartopy.crs as ccrs
6
6
  from datetime import datetime
7
7
  from typing import Dict, Union, List
8
- from roms_tools.setup.datasets import DaiRiverDataset
9
8
  from pathlib import Path
10
9
  import matplotlib.pyplot as plt
10
+ from roms_tools import Grid
11
+ from roms_tools.plot import _get_projection, _add_field_to_ax
12
+ from roms_tools.setup.datasets import DaiRiverDataset
11
13
  from roms_tools.setup.utils import (
12
14
  get_target_coords,
13
15
  gc_dist,
@@ -18,8 +20,6 @@ from roms_tools.setup.utils import (
18
20
  _from_yaml,
19
21
  get_variable_metadata,
20
22
  )
21
- from roms_tools.setup.plot import _get_projection, _add_field_to_ax
22
- import cartopy.crs as ccrs
23
23
 
24
24
 
25
25
  @dataclass(frozen=True, kw_only=True)
@@ -416,16 +416,13 @@ class RiverForcing:
416
416
  """Plots the original and updated river locations on a map projection."""
417
417
 
418
418
  field = self.grid.ds.mask_rho
419
- field = field.assign_coords(
420
- {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
421
- )
422
419
  vmax = 3
423
420
  vmin = 0
424
421
  cmap = plt.colormaps.get_cmap("Blues")
425
422
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
426
423
 
427
- lon_deg = field.lon
428
- lat_deg = field.lat
424
+ lon_deg = self.grid.ds.lon_rho
425
+ lat_deg = self.grid.ds.lat_rho
429
426
 
430
427
  # check if North or South pole are in domain
431
428
  if lat_deg.max().values > 89 or lat_deg.min().values < -89:
@@ -435,6 +432,7 @@ class RiverForcing:
435
432
 
436
433
  if self.grid.straddle:
437
434
  lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
435
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
438
436
 
439
437
  trans = _get_projection(lon_deg, lat_deg)
440
438
 
@@ -1,11 +1,14 @@
1
1
  import xarray as xr
2
2
  import importlib.metadata
3
3
  from dataclasses import dataclass, field
4
- from roms_tools.setup.grid import Grid
5
4
  from datetime import datetime
6
5
  import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from pathlib import Path
7
8
  from typing import Dict, Union, List
8
- from roms_tools.setup.regrid import LateralRegrid
9
+ from roms_tools import Grid
10
+ from roms_tools.regrid import LateralRegrid
11
+ from roms_tools.plot import _plot
9
12
  from roms_tools.setup.datasets import (
10
13
  ERA5Dataset,
11
14
  ERA5Correction,
@@ -24,9 +27,6 @@ from roms_tools.setup.utils import (
24
27
  _to_yaml,
25
28
  _from_yaml,
26
29
  )
27
- from roms_tools.setup.plot import _plot
28
- import matplotlib.pyplot as plt
29
- from pathlib import Path
30
30
 
31
31
 
32
32
  @dataclass(frozen=True, kw_only=True)
@@ -434,21 +434,23 @@ class SurfaceForcing:
434
434
  raise ValueError(f"Variable '{var_name}' is not found in dataset.")
435
435
 
436
436
  field = self.ds[var_name].isel(time=time)
437
+
437
438
  if self.use_dask:
438
439
  from dask.diagnostics import ProgressBar
439
440
 
440
441
  with ProgressBar():
441
442
  field = field.load()
442
443
 
443
- title = field.long_name
444
-
445
444
  field = field.where(self.target_coords["mask"])
446
445
 
447
- field = field.assign_coords(
448
- {"lon": self.target_coords["lon"], "lat": self.target_coords["lat"]}
449
- )
446
+ lon_deg = self.target_coords["lon"]
447
+ lat_deg = self.target_coords["lat"]
448
+ if self.grid.straddle:
449
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
450
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
451
+
452
+ title = field.long_name
450
453
 
451
- # choose colorbar
452
454
  if var_name in ["uwnd", "vwnd"]:
453
455
  vmax = max(field.max().values, -field.min().values)
454
456
  vmin = -vmax
@@ -465,12 +467,10 @@ class SurfaceForcing:
465
467
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
466
468
 
467
469
  _plot(
468
- self.grid.ds,
469
470
  field=field,
470
- straddle=self.grid.straddle,
471
471
  title=title,
472
- kwargs=kwargs,
473
472
  c="g",
473
+ kwargs=kwargs,
474
474
  )
475
475
 
476
476
  def save(
roms_tools/setup/tides.py CHANGED
@@ -3,9 +3,12 @@ import xarray as xr
3
3
  import numpy as np
4
4
  from typing import Dict, Union, List
5
5
  import importlib.metadata
6
+ import matplotlib.pyplot as plt
7
+ from pathlib import Path
6
8
  from dataclasses import dataclass, field
7
- from roms_tools.setup.grid import Grid
8
- from roms_tools.setup.plot import _plot
9
+ from roms_tools import Grid
10
+ from roms_tools.plot import _plot
11
+ from roms_tools.regrid import LateralRegrid
9
12
  from roms_tools.setup.datasets import TPXODataset
10
13
  from roms_tools.setup.utils import (
11
14
  nan_check,
@@ -20,9 +23,6 @@ from roms_tools.setup.utils import (
20
23
  _to_yaml,
21
24
  _from_yaml,
22
25
  )
23
- from roms_tools.setup.regrid import LateralRegrid
24
- import matplotlib.pyplot as plt
25
- from pathlib import Path
26
26
 
27
27
 
28
28
  @dataclass(frozen=True, kw_only=True)
@@ -319,6 +319,8 @@ class TidalForcing:
319
319
  >>> tidal_forcing.plot("ssh_Re", nc=0)
320
320
  """
321
321
 
322
+ if var_name not in self.ds:
323
+ raise ValueError(f"Variable '{var_name}' is not found in dataset.")
322
324
  field = self.ds[var_name].isel(ntides=ntides)
323
325
 
324
326
  if self.use_dask:
@@ -328,25 +330,28 @@ class TidalForcing:
328
330
  field = field.load()
329
331
 
330
332
  if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
331
- field = field.where(self.grid.ds.mask_rho)
332
- field = field.assign_coords(
333
- {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
334
- )
333
+ lon_deg = self.grid.ds["lon_rho"]
334
+ lat_deg = self.grid.ds["lat_rho"]
335
+ mask = self.grid.ds["mask_rho"]
335
336
 
336
337
  elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
337
- field = field.where(self.grid.ds.mask_u)
338
- field = field.assign_coords(
339
- {"lon": self.grid.ds.lon_u, "lat": self.grid.ds.lat_u}
340
- )
338
+ lon_deg = self.grid.ds["lon_u"]
339
+ lat_deg = self.grid.ds["lat_u"]
340
+ mask = self.grid.ds["mask_u"]
341
341
 
342
342
  elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
343
- field = field.where(self.grid.ds.mask_v)
344
- field = field.assign_coords(
345
- {"lon": self.grid.ds.lon_v, "lat": self.grid.ds.lat_v}
346
- )
343
+ lon_deg = self.grid.ds["lon_v"]
344
+ lat_deg = self.grid.ds["lat_v"]
345
+ mask = self.grid.ds["mask_v"]
346
+
347
347
  else:
348
348
  ValueError("provided field does not have two horizontal dimension")
349
349
 
350
+ field = field.where(mask)
351
+ if self.grid.straddle:
352
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
353
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
354
+
350
355
  title = "%s, ntides = %i" % (field.long_name, self.ds[var_name].ntides[ntides])
351
356
 
352
357
  vmax = max(field.max(), -field.min())
@@ -357,12 +362,10 @@ class TidalForcing:
357
362
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
358
363
 
359
364
  _plot(
360
- self.grid.ds,
361
365
  field=field,
362
- straddle=self.grid.straddle,
366
+ title=title,
363
367
  c="g",
364
368
  kwargs=kwargs,
365
- title=title,
366
369
  )
367
370
 
368
371
  def save(
@@ -464,7 +467,7 @@ class TidalForcing:
464
467
  grid=grid,
465
468
  **tidal_forcing_params,
466
469
  use_dask=use_dask,
467
- bypass_validation=bypass_validation
470
+ bypass_validation=bypass_validation,
468
471
  )
469
472
 
470
473
  def _correct_tides(self, data):
@@ -6,8 +6,8 @@ import gcm_filters
6
6
  from roms_tools.setup.utils import handle_boundaries
7
7
  import warnings
8
8
  from itertools import count
9
+ from roms_tools.regrid import LateralRegrid
9
10
  from roms_tools.setup.datasets import ETOPO5Dataset, SRTM15Dataset
10
- from roms_tools.setup.regrid import LateralRegrid
11
11
 
12
12
 
13
13
  def _add_topography(