roms-tools 3.1.0__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,12 +6,10 @@ from datetime import datetime
6
6
  from pathlib import Path
7
7
  from typing import Annotated
8
8
 
9
- import cartopy.crs as ccrs
10
9
  import matplotlib.gridspec as gridspec
11
10
  import matplotlib.pyplot as plt
12
11
  import numpy as np
13
12
  import xarray as xr
14
- from matplotlib.axes import Axes
15
13
  from pydantic import (
16
14
  BaseModel,
17
15
  Field,
@@ -22,7 +20,14 @@ from pydantic import (
22
20
  )
23
21
 
24
22
  from roms_tools import Grid
25
- from roms_tools.plot import get_projection, plot, plot_2d_horizontal_field
23
+ from roms_tools.constants import MAX_DISTINCT_COLORS
24
+ from roms_tools.plot import (
25
+ assign_category_colors,
26
+ get_projection,
27
+ plot,
28
+ plot_2d_horizontal_field,
29
+ plot_location,
30
+ )
26
31
  from roms_tools.setup.cdr_release import (
27
32
  Release,
28
33
  ReleaseType,
@@ -36,6 +41,7 @@ from roms_tools.setup.utils import (
36
41
  gc_dist,
37
42
  get_target_coords,
38
43
  to_dict,
44
+ validate_names,
39
45
  write_to_yaml,
40
46
  )
41
47
  from roms_tools.utils import (
@@ -45,6 +51,7 @@ from roms_tools.utils import (
45
51
  from roms_tools.vertical_coordinate import compute_depth_coordinates
46
52
 
47
53
  INCLUDE_ALL_RELEASE_NAMES = "all"
54
+ MAX_RELEASES_TO_PLOT = 20 # must be <= MAX_DISTINCT_COLORS
48
55
 
49
56
 
50
57
  class ReleaseSimulationManager(BaseModel):
@@ -389,7 +396,10 @@ class CDRForcing(BaseModel):
389
396
  return self._ds
390
397
 
391
398
  def plot_volume_flux(
392
- self, start=None, end=None, release_names=INCLUDE_ALL_RELEASE_NAMES
399
+ self,
400
+ start: datetime | None = None,
401
+ end: datetime | None = None,
402
+ release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
393
403
  ):
394
404
  """Plot the volume flux for each specified release within the given time range.
395
405
 
@@ -419,12 +429,7 @@ class CDRForcing(BaseModel):
419
429
  start = start or self.start_time
420
430
  end = end or self.end_time
421
431
 
422
- valid_release_names = [r.name for r in self.releases]
423
-
424
- if release_names == INCLUDE_ALL_RELEASE_NAMES:
425
- release_names = valid_release_names
426
-
427
- _validate_release_input(release_names, valid_release_names)
432
+ release_names = _validate_release_names(release_names, self.releases)
428
433
 
429
434
  data = self.ds["cdr_volume"]
430
435
 
@@ -440,9 +445,9 @@ class CDRForcing(BaseModel):
440
445
  def plot_tracer_concentration(
441
446
  self,
442
447
  tracer_name: str,
443
- start=None,
444
- end=None,
445
- release_names=INCLUDE_ALL_RELEASE_NAMES,
448
+ start: datetime | None = None,
449
+ end: datetime | None = None,
450
+ release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
446
451
  ):
447
452
  """Plot the concentration of a given tracer for each specified release within
448
453
  the given time range.
@@ -476,12 +481,7 @@ class CDRForcing(BaseModel):
476
481
  start = start or self.start_time
477
482
  end = end or self.end_time
478
483
 
479
- valid_release_names = [r.name for r in self.releases]
480
-
481
- if release_names == INCLUDE_ALL_RELEASE_NAMES:
482
- release_names = valid_release_names
483
-
484
- _validate_release_input(release_names, valid_release_names)
484
+ release_names = _validate_release_names(release_names, self.releases)
485
485
 
486
486
  tracer_names = list(self.ds["tracer_name"].values)
487
487
  if tracer_name not in tracer_names:
@@ -511,9 +511,9 @@ class CDRForcing(BaseModel):
511
511
  def plot_tracer_flux(
512
512
  self,
513
513
  tracer_name: str,
514
- start=None,
515
- end=None,
516
- release_names=INCLUDE_ALL_RELEASE_NAMES,
514
+ start: datetime | None = None,
515
+ end: datetime | None = None,
516
+ release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
517
517
  ):
518
518
  """Plot the flux of a given tracer for each specified release within the given
519
519
  time range.
@@ -547,12 +547,7 @@ class CDRForcing(BaseModel):
547
547
  start = start or self.start_time
548
548
  end = end or self.end_time
549
549
 
550
- valid_release_names = [r.name for r in self.releases]
551
-
552
- if release_names == INCLUDE_ALL_RELEASE_NAMES:
553
- release_names = valid_release_names
554
-
555
- _validate_release_input(release_names, valid_release_names)
550
+ release_names = _validate_release_names(release_names, self.releases)
556
551
 
557
552
  tracer_names = list(self.ds["tracer_name"].values)
558
553
  if tracer_name not in tracer_names:
@@ -577,7 +572,10 @@ class CDRForcing(BaseModel):
577
572
  def _plot_line(self, data, release_names, start, end, title="", ylabel=""):
578
573
  """Plots a line graph for the specified releases and time range."""
579
574
  valid_release_names = [r.name for r in self.releases]
580
- colors = _get_release_colors(valid_release_names)
575
+ if len(valid_release_names) > MAX_DISTINCT_COLORS:
576
+ colors = assign_category_colors(release_names)
577
+ else:
578
+ colors = assign_category_colors(valid_release_names)
581
579
 
582
580
  fig, ax = plt.subplots(1, 1, figsize=(7, 4))
583
581
  for name in release_names:
@@ -596,7 +594,9 @@ class CDRForcing(BaseModel):
596
594
  ax.set(title=title, ylabel=ylabel, xlabel="time")
597
595
  ax.set_xlim([start, end])
598
596
 
599
- def plot_locations(self, release_names="all"):
597
+ def plot_locations(
598
+ self, release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES
599
+ ):
600
600
  """Plot centers of release locations in top-down view.
601
601
 
602
602
  Parameters
@@ -619,12 +619,7 @@ class CDRForcing(BaseModel):
619
619
  "A grid must be provided for plotting. Please pass a valid `Grid` object."
620
620
  )
621
621
 
622
- valid_release_names = [r.name for r in self.releases]
623
-
624
- if release_names == "all":
625
- release_names = valid_release_names
626
-
627
- _validate_release_input(release_names, valid_release_names)
622
+ release_names = _validate_release_names(release_names, self.releases)
628
623
 
629
624
  lon_deg = self.grid.ds.lon_rho
630
625
  lat_deg = self.grid.ds.lat_rho
@@ -645,12 +640,22 @@ class CDRForcing(BaseModel):
645
640
  plot_2d_horizontal_field(field, kwargs=kwargs, ax=ax, add_colorbar=False)
646
641
 
647
642
  # Plot release locations
648
- colors = _get_release_colors(valid_release_names)
649
- _plot_location(
650
- grid=self.grid,
651
- releases=[self.releases[name] for name in release_names],
643
+ valid_release_names = [r.name for r in self.releases]
644
+ if len(valid_release_names) > MAX_DISTINCT_COLORS:
645
+ colors = assign_category_colors(release_names)
646
+ else:
647
+ colors = assign_category_colors(valid_release_names)
648
+ plot_location(
649
+ grid_ds=self.grid.ds,
650
+ points={
651
+ name: {
652
+ "lat": self.releases[name].lat,
653
+ "lon": self.releases[name].lon,
654
+ "color": colors.get(name, "k"),
655
+ }
656
+ for name in release_names
657
+ },
652
658
  ax=ax,
653
- colors=colors,
654
659
  )
655
660
 
656
661
  def plot_distribution(self, release_name: str, mark_release_center: bool = True):
@@ -680,8 +685,13 @@ class CDRForcing(BaseModel):
680
685
  "A grid must be provided for plotting. Please pass a valid `Grid` object."
681
686
  )
682
687
 
683
- valid_release_names = [r.name for r in self.releases]
684
- _validate_release_input(release_name, valid_release_names, list_allowed=False)
688
+ if not isinstance(release_name, str):
689
+ raise ValueError(
690
+ f"Only a single release name (string) is allowed. Got: {release_name!r}"
691
+ )
692
+
693
+ release_name = _validate_release_names([release_name], self.releases)[0]
694
+
685
695
  release = self.releases[release_name]
686
696
 
687
697
  # Prepare grid coordinates
@@ -713,8 +723,16 @@ class CDRForcing(BaseModel):
713
723
  title="Depth-integrated distribution",
714
724
  )
715
725
  if mark_release_center:
716
- _plot_location(
717
- grid=self.grid, releases=[release], ax=ax0, include_legend=False
726
+ plot_location(
727
+ grid_ds=self.grid.ds,
728
+ points={
729
+ release.name: {
730
+ "lat": release.lat,
731
+ "lon": release.lon,
732
+ }
733
+ },
734
+ ax=ax0,
735
+ include_legend=False,
718
736
  )
719
737
 
720
738
  # Spread horizontal Gaussian field into the vertical
@@ -828,106 +846,39 @@ class CDRForcing(BaseModel):
828
846
  return cls(grid=grid, **params)
829
847
 
830
848
 
831
- def _validate_release_input(releases, valid_releases, list_allowed=True):
832
- """Validates the input for release names in plotting methods to ensure they are in
833
- an acceptable format and exist within the set of valid releases.
834
-
835
- This method ensures that the `releases` parameter is either a single release name (string) or a list
836
- of release names (strings), and checks that each release exists in the set of valid releases.
837
-
838
- Parameters
839
- ----------
840
- releases : str or list of str
841
- A single release name as a string, or a list of release names (strings) to validate.
842
-
843
- list_allowed : bool, optional
844
- If `True`, a list of release names is allowed. If `False`, only a single release name (string)
845
- is allowed. Default is `True`.
846
-
847
- Raises
848
- ------
849
- ValueError
850
- If `releases` is not a string or list of strings, or if any release name is invalid (not in `self.releases`).
851
-
852
- Notes
853
- -----
854
- This method checks that the `releases` input is in a valid format (either a string or a list of strings),
855
- and ensures each release is present in the set of valid releases defined in `self.releases`. Invalid releases
856
- are reported in the error message.
857
-
858
- If `list_allowed` is set to `False`, only a single release name (string) will be accepted. Otherwise, a
859
- list of release names is also acceptable.
849
+ def _validate_release_names(
850
+ release_names: list[str] | str, releases: ReleaseCollector
851
+ ) -> list[str]:
860
852
  """
861
- # Ensure that a list of releases is only allowed if `list_allowed` is True
862
- if not list_allowed and not isinstance(releases, str):
863
- raise ValueError(
864
- f"Only a single release name (string) is allowed. Got: {releases}"
865
- )
853
+ Validate and filter a list of release names.
866
854
 
867
- if isinstance(releases, str):
868
- releases = [releases] # Convert to list if a single string is provided
869
- elif isinstance(releases, list):
870
- if not all(isinstance(r, str) for r in releases):
871
- raise ValueError("All elements in `releases` list must be strings.")
872
- else:
873
- raise ValueError(
874
- "`releases` should be a string (single release name) or a list of strings (release names)."
875
- )
876
-
877
- # Validate that the specified releases exist in self.releases
878
- invalid_releases = [
879
- release for release in releases if release not in valid_releases
880
- ]
881
- if invalid_releases:
882
- raise ValueError(f"Invalid releases: {', '.join(invalid_releases)}")
883
-
884
-
885
- def _get_release_colors(valid_releases: list[str]) -> dict[str, tuple]:
886
- """Returns a dictionary of colors for the valid releases, based on a consistent
887
- colormap.
855
+ Ensures that each release name exists in `releases` and limits the list
856
+ to `MAX_RELEASES_TO_PLOT` entries with a warning if truncated.
888
857
 
889
858
  Parameters
890
859
  ----------
891
- valid_releases : List[str]
892
- List of release names to assign colors to.
860
+ release_names : list of str or INCLUDE_ALL_RELEASE_NAMES
861
+ Names of releases to plot, or sentinel to include all.
862
+ releases : ReleaseCollector
863
+ Object containing valid release names.
893
864
 
894
865
  Returns
895
866
  -------
896
- Dict[str, tuple]
897
- A dictionary where the keys are release names and the values are their corresponding colors,
898
- assigned based on the order of releases in the valid releases list.
867
+ list of str
868
+ Validated and truncated list of release names.
899
869
 
900
870
  Raises
901
871
  ------
902
872
  ValueError
903
- If the number of valid releases exceeds the available colormap capacity.
904
-
905
- Notes
906
- -----
907
- The colormap is chosen dynamically based on the number of valid releases:
908
-
909
- - If there are 10 or fewer releases, the "tab10" colormap is used.
910
- - If there are more than 10 but fewer than or equal to 20 releases, the "tab20" colormap is used.
911
- - For more than 20 releases, the "tab20b" colormap is used.
873
+ If any names are invalid.
912
874
  """
913
- # Determine the colormap based on the number of releases
914
- if len(valid_releases) <= 10:
915
- color_map = plt.get_cmap("tab10")
916
- elif len(valid_releases) <= 20:
917
- color_map = plt.get_cmap("tab20")
918
- else:
919
- color_map = plt.get_cmap("tab20b")
920
-
921
- # Ensure the number of releases doesn't exceed the available colormap capacity
922
- if len(valid_releases) > color_map.N:
923
- raise ValueError(
924
- f"Too many releases. The selected colormap supports up to {color_map.N} releases."
925
- )
926
-
927
- # Create a dictionary of colors based on the release indices
928
- colors = {name: color_map(i) for i, name in enumerate(valid_releases)}
929
-
930
- return colors
875
+ return validate_names(
876
+ release_names,
877
+ [r.name for r in releases],
878
+ INCLUDE_ALL_RELEASE_NAMES,
879
+ MAX_RELEASES_TO_PLOT,
880
+ label="release",
881
+ )
931
882
 
932
883
 
933
884
  def _validate_release_location(grid, release: Release):
@@ -1088,91 +1039,15 @@ def _map_3d_gaussian(
1088
1039
  # Stack 2D distribution at that vertical level
1089
1040
  distribution_3d[{"s_rho": vertical_idx}] = distribution_2d
1090
1041
  else:
1091
- # Compute layer thickness
1092
- depth_interface = compute_depth_coordinates(
1093
- grid.ds, zeta=0, depth_type="interface", location="rho"
1094
- )
1095
- dz = depth_interface.diff("s_w").rename({"s_w": "s_rho"})
1096
-
1097
1042
  # Compute vertical Gaussian shape
1098
1043
  exponent = -(((depth - release.depth) / release.vsc) ** 2)
1099
1044
  vertical_profile = np.exp(exponent)
1100
1045
 
1101
1046
  # Apply vertical Gaussian scaling
1102
- distribution_3d = distribution_2d * vertical_profile * dz
1047
+ distribution_3d = distribution_2d * vertical_profile
1103
1048
 
1104
1049
  # Normalize
1105
1050
  distribution_3d /= release.vsc * np.sqrt(np.pi)
1106
1051
  distribution_3d /= distribution_3d.sum()
1107
1052
 
1108
1053
  return distribution_3d
1109
-
1110
-
1111
- def _plot_location(
1112
- grid: Grid,
1113
- releases: ReleaseCollector,
1114
- ax: Axes,
1115
- colors: dict[str, tuple] | None = None,
1116
- include_legend: bool = True,
1117
- ) -> None:
1118
- """Plot the center location of each release on a top-down map view.
1119
-
1120
- Each release is represented as a point on the map, with its color
1121
- determined by the `colors` dictionary.
1122
-
1123
- Parameters
1124
- ----------
1125
- grid : Grid
1126
- The grid object defining the spatial extent and coordinate system for the plot.
1127
-
1128
- releases : ReleaseCollector
1129
- Collection of `Release` objects to plot. Each `Release` must have `.lat`, `.lon`,
1130
- and `.name` attributes.
1131
-
1132
- ax : matplotlib.axes.Axes
1133
- The Matplotlib axis object to plot on.
1134
-
1135
- colors : dict of str to tuple, optional
1136
- Optional dictionary mapping release names to RGBA color tuples. If not provided,
1137
- all releases are plotted in a default color (`"#dd1c77"`).
1138
-
1139
- include_legend : bool, default True
1140
- Whether to include a legend showing release names.
1141
-
1142
- Returns
1143
- -------
1144
- None
1145
- """
1146
- lon_deg = grid.ds.lon_rho
1147
- lat_deg = grid.ds.lat_rho
1148
- if grid.straddle:
1149
- lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
1150
- trans = get_projection(lon_deg, lat_deg)
1151
-
1152
- proj = ccrs.PlateCarree()
1153
-
1154
- for release in releases:
1155
- # transform coordinates to projected space
1156
- transformed_lon, transformed_lat = trans.transform_point(
1157
- release.lon,
1158
- release.lat,
1159
- proj,
1160
- )
1161
-
1162
- if colors is not None:
1163
- color = colors[release.name]
1164
- else:
1165
- color = "k"
1166
-
1167
- ax.plot(
1168
- transformed_lon,
1169
- transformed_lat,
1170
- marker="x",
1171
- markersize=8,
1172
- markeredgewidth=2,
1173
- label=release.name,
1174
- color=color,
1175
- )
1176
-
1177
- if include_legend:
1178
- ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
@@ -1,9 +1,13 @@
1
+ import importlib.util
1
2
  import logging
2
3
  import time
3
4
  from collections import Counter, defaultdict
5
+ from collections.abc import Callable
4
6
  from dataclasses import dataclass, field
5
7
  from datetime import datetime, timedelta
6
8
  from pathlib import Path
9
+ from types import ModuleType
10
+ from typing import ClassVar
7
11
 
8
12
  import numpy as np
9
13
  import xarray as xr
@@ -25,7 +29,7 @@ from roms_tools.setup.utils import (
25
29
  interpolate_from_climatology,
26
30
  one_dim_fill,
27
31
  )
28
- from roms_tools.utils import _has_gcsfs, _load_data
32
+ from roms_tools.utils import _get_pkg_error_msg, _has_gcsfs, _load_data
29
33
 
30
34
  # lat-lon datasets
31
35
 
@@ -96,17 +100,18 @@ class Dataset:
96
100
  use_dask: bool | None = False
97
101
  apply_post_processing: bool | None = True
98
102
  read_zarr: bool | None = False
103
+ ds_loader_fn: Callable[[], xr.Dataset] | None = None
99
104
 
100
105
  is_global: bool = field(init=False, repr=False)
101
106
  ds: xr.Dataset = field(init=False, repr=False)
102
107
 
103
- def __post_init__(self):
104
- """
105
- Post-initialization processing:
108
+ def __post_init__(self) -> None:
109
+ """Perform post-initialization processing.
110
+
106
111
  1. Loads the dataset from the specified filename.
107
- 2. Applies time filtering based on start_time and end_time if provided.
108
- 3. Selects relevant fields as specified by var_names.
109
- 4. Ensures latitude values and depth values are in ascending order.
112
+ 2. Applies time filtering based on start_time and end_time (if provided).
113
+ 3. Selects relevant fields as specified by `var_names`.
114
+ 4. Ensures latitude, longitude, and depth values are in ascending order.
110
115
  5. Checks if the dataset covers the entire globe and adjusts if necessary.
111
116
  """
112
117
  # Validate start_time and end_time
@@ -168,7 +173,11 @@ class Dataset:
168
173
  If a list of files is provided but self.dim_names["time"] is not available or use_dask=False.
169
174
  """
170
175
  ds = _load_data(
171
- self.filename, self.dim_names, self.use_dask, read_zarr=self.read_zarr
176
+ self.filename,
177
+ self.dim_names,
178
+ self.use_dask or False,
179
+ read_zarr=self.read_zarr or False,
180
+ ds_loader_fn=self.ds_loader_fn,
172
181
  )
173
182
 
174
183
  return ds
@@ -1075,6 +1084,83 @@ class GLORYSDataset(Dataset):
1075
1084
  self.ds["mask_vel"] = mask_vel
1076
1085
 
1077
1086
 
1087
+ @dataclass(kw_only=True)
1088
+ class GLORYSDefaultDataset(GLORYSDataset):
1089
+ """A GLORYS dataset that is loaded from the Copernicus Marine Data Store."""
1090
+
1091
+ dataset_name: ClassVar[str] = "cmems_mod_glo_phy_my_0.083deg_P1D-m"
1092
+ """The GLORYS dataset-id for requests to the Copernicus Marine Toolkit"""
1093
+ _tk_module: ModuleType | None = None
1094
+ """The dynamically imported Copernicus Marine module."""
1095
+
1096
+ def __post_init__(self) -> None:
1097
+ """Configure attributes to ensure use of the correct upstream data-source."""
1098
+ self.read_zarr = True
1099
+ self.use_dask = True
1100
+ self.filename = self.dataset_name
1101
+ self.ds_loader_fn = self._load_from_copernicus
1102
+
1103
+ super().__post_init__()
1104
+
1105
+ def _check_auth(self, package_name: str) -> None:
1106
+ """Check the local credential hierarchy for auth credentials.
1107
+
1108
+ Raises
1109
+ ------
1110
+ RuntimeError
1111
+ If auth credentials cannot be found.
1112
+ """
1113
+ if self._tk_module and not self._tk_module.login(check_credentials_valid=True):
1114
+ msg = f"Authenticate with `{package_name} login` to retrieve GLORYS data."
1115
+ raise RuntimeError(msg)
1116
+
1117
+ def _load_copernicus(self) -> ModuleType:
1118
+ """Dynamically load the optional Copernicus Marine Toolkit dependency.
1119
+
1120
+ Raises
1121
+ ------
1122
+ RuntimeError
1123
+ - If the toolkit module is not available or cannot be imported.
1124
+ - If auth credentials cannot be found.
1125
+ """
1126
+ package_name = "copernicusmarine"
1127
+ if self._tk_module:
1128
+ self._check_auth(package_name)
1129
+ return self._tk_module
1130
+
1131
+ spec = importlib.util.find_spec(package_name)
1132
+ if not spec:
1133
+ msg = _get_pkg_error_msg("cloud-based GLORYS data", package_name, "stream")
1134
+ raise RuntimeError(msg)
1135
+
1136
+ try:
1137
+ self._tk_module = importlib.import_module(package_name)
1138
+ except ImportError as e:
1139
+ msg = f"Package `{package_name}` was found but could not be loaded."
1140
+ raise RuntimeError(msg) from e
1141
+
1142
+ self._check_auth(package_name)
1143
+ return self._tk_module
1144
+
1145
+ def _load_from_copernicus(self) -> xr.Dataset:
1146
+ """Load a GLORYS dataset supporting streaming.
1147
+
1148
+ Returns
1149
+ -------
1150
+ xr.Dataset
1151
+ The streaming dataset
1152
+ """
1153
+ copernicusmarine = self._load_copernicus()
1154
+ return copernicusmarine.open_dataset(
1155
+ self.dataset_name,
1156
+ start_datetime=self.start_time,
1157
+ end_datetime=self.end_time,
1158
+ service="arco-geo-series",
1159
+ coordinates_selection_method="inside",
1160
+ chunk_size_limit=2,
1161
+ )
1162
+
1163
+
1078
1164
  @dataclass(kw_only=True)
1079
1165
  class UnifiedDataset(Dataset):
1080
1166
  """Represents unified BGC data on original grid.
@@ -1549,12 +1635,8 @@ class ERA5ARCODataset(ERA5Dataset):
1549
1635
  def __post_init__(self):
1550
1636
  self.read_zarr = True
1551
1637
  if not _has_gcsfs():
1552
- raise RuntimeError(
1553
- "To use cloud-based ERA5 data, GCSFS is required but not installed. Install it with:\n"
1554
- " • `pip install roms-tools[stream]` or\n"
1555
- " • `conda install gcsfs`\n"
1556
- "Alternatively, install `roms-tools` with conda to include all dependencies."
1557
- )
1638
+ msg = _get_pkg_error_msg("cloud-based ERA5 data", "gcsfs", "stream")
1639
+ raise RuntimeError(msg)
1558
1640
 
1559
1641
  super().__post_init__()
1560
1642
 
roms_tools/setup/grid.py CHANGED
@@ -415,30 +415,57 @@ class Grid:
415
415
 
416
416
  def plot(
417
417
  self,
418
+ lat: float | None = None,
419
+ lon: float | None = None,
418
420
  with_dim_names: bool = False,
419
421
  save_path: str | None = None,
420
422
  ) -> None:
421
- """Plot the grid.
423
+ """Plot the grid with bathymetry.
424
+
425
+ Depending on the arguments, this will either:
426
+ * Plot the full horizontal grid (if both `lat` and `lon` are None),
427
+ * Plot a zonal (east-west) vertical section at a given latitude (`lat`),
428
+ * Plot a meridional (south-north) vertical section at a given longitude (`lon`).
422
429
 
423
430
  Parameters
424
431
  ----------
432
+ lat : float, optional
433
+ Latitude in degrees at which to plot a vertical (zonal) section. Cannot be
434
+ provided together with `lon`. Default is None.
435
+
436
+ lon : float, optional
437
+ Longitude in degrees at which to plot a vertical (meridional) section. Cannot be
438
+ provided together with `lat`. Default is None.
439
+
425
440
  with_dim_names : bool, optional
426
- Whether or not to plot the dimension names. Default is False.
441
+ If True and no section is requested (i.e., both `lat` and `lon` are None), annotate
442
+ the plot with the underlying dimension names. Default is False.
427
443
 
428
444
  save_path : str, optional
429
445
  Path to save the generated plot. If None, the plot is shown interactively.
430
446
  Default is None.
431
447
 
448
+ Raises
449
+ ------
450
+ ValueError
451
+ If both `lat` and `lon` are specified simultaneously.
452
+
432
453
  Returns
433
454
  -------
434
455
  None
435
456
  This method does not return any value. It generates and displays a plot.
436
457
  """
458
+ if lat is not None and lon is not None:
459
+ raise ValueError("Specify either `lat` or `lon`, not both.")
460
+
437
461
  field = self.ds["h"]
438
462
 
439
463
  plot(
440
464
  field=field,
441
465
  grid_ds=self.ds,
466
+ lat=lat,
467
+ lon=lon,
468
+ yincrease=False,
442
469
  with_dim_names=with_dim_names,
443
470
  save_path=save_path,
444
471
  cmap_name="YlGnBu",