roms-tools 3.1.0__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,14 +4,18 @@ from dataclasses import dataclass, field
4
4
  from datetime import datetime
5
5
  from pathlib import Path
6
6
 
7
- import cartopy.crs as ccrs
8
- import matplotlib.cm as cm
9
7
  import matplotlib.pyplot as plt
10
8
  import numpy as np
11
9
  import xarray as xr
12
10
 
13
11
  from roms_tools import Grid
14
- from roms_tools.plot import get_projection, plot_2d_horizontal_field
12
+ from roms_tools.constants import MAX_DISTINCT_COLORS
13
+ from roms_tools.plot import (
14
+ assign_category_colors,
15
+ get_projection,
16
+ plot_2d_horizontal_field,
17
+ plot_location,
18
+ )
15
19
  from roms_tools.setup.datasets import (
16
20
  DaiRiverDataset,
17
21
  get_indices_of_nearest_grid_cell_for_rivers,
@@ -26,10 +30,14 @@ from roms_tools.setup.utils import (
26
30
  get_variable_metadata,
27
31
  substitute_nans_by_fillvalue,
28
32
  to_dict,
33
+ validate_names,
29
34
  write_to_yaml,
30
35
  )
31
36
  from roms_tools.utils import save_datasets
32
37
 
38
# Sentinel for the ``river_names`` argument of the river plotting methods:
# passing this string selects every river in the forcing.
INCLUDE_ALL_RIVER_NAMES: str = "all"

# Cap on the number of rivers drawn in a single plot. It must not exceed
# MAX_DISTINCT_COLORS so that every plotted river can receive its own color.
MAX_RIVERS_TO_PLOT: int = 20
+
33
41
 
34
42
  @dataclass(kw_only=True)
35
43
  class RiverForcing:
@@ -672,8 +680,24 @@ class RiverForcing:
672
680
  "`convert_to_climatology = 'if_any_missing'` to automatically fill missing values with climatological data."
673
681
  )
674
682
 
675
- def plot_locations(self):
676
- """Plots the original and updated river locations on a map projection."""
683
+ def plot_locations(self, river_names: list[str] | str = INCLUDE_ALL_RIVER_NAMES):
684
+ """Plots the original and updated river locations on a map projection.
685
+
686
+ Parameters
687
+ ----------
688
+ river_names : list[str], or str, optional
689
+ A list of release names to plot.
690
+ If a string equal to "all", all rivers will be plotted.
691
+ Defaults to "all".
692
+
693
+ """
694
+ valid_river_names = list(self.indices.keys())
695
+ river_names = _validate_river_names(river_names, valid_river_names)
696
+ if len(valid_river_names) > MAX_DISTINCT_COLORS:
697
+ colors = assign_category_colors(river_names)
698
+ else:
699
+ colors = assign_category_colors(valid_river_names)
700
+
677
701
  field = self.grid.ds.mask_rho
678
702
  lon_deg = self.grid.ds.lon_rho
679
703
  lat_deg = self.grid.ds.lat_rho
@@ -695,53 +719,37 @@ class RiverForcing:
695
719
  for ax in axs:
696
720
  plot_2d_horizontal_field(field, kwargs=kwargs, ax=ax, add_colorbar=False)
697
721
 
698
- proj = ccrs.PlateCarree()
699
-
700
- if len(self.indices) <= 10:
701
- color_map = cm.get_cmap("tab10")
702
- elif len(self.indices) <= 20:
703
- color_map = cm.get_cmap("tab20")
704
- else:
705
- color_map = cm.get_cmap("tab20b")
706
- # Create a dictionary of colors
707
- colors = {name: color_map(i) for i, name in enumerate(self.indices.keys())}
708
-
709
- for ax, indices in zip(axs, [self.original_indices, self.indices]):
710
- added_labels = set()
711
- for name in indices.keys():
712
- for tuple in indices[name]:
713
- eta_index = tuple[0]
714
- xi_index = tuple[1]
715
-
716
- # transform coordinates to projected space
717
- transformed_lon, transformed_lat = trans.transform_point(
718
- self.grid.ds.lon_rho[eta_index, xi_index],
719
- self.grid.ds.lat_rho[eta_index, xi_index],
720
- proj,
721
- )
722
-
723
- if name not in added_labels:
724
- added_labels.add(name)
725
- label = name
726
- else:
727
- label = "_None"
728
-
729
- ax.plot(
730
- transformed_lon,
731
- transformed_lat,
732
- marker="x",
733
- markersize=8,
734
- markeredgewidth=2,
735
- label=label,
736
- color=colors[name],
737
- )
722
+ points = {}
723
+ for j, (ax, indices) in enumerate(
724
+ [(ax, ind) for ax, ind in zip(axs, [self.original_indices, self.indices])]
725
+ ):
726
+ for name in river_names:
727
+ if name in indices:
728
+ for i, (eta_index, xi_index) in enumerate(indices[name]):
729
+ lon = self.grid.ds.lon_rho[eta_index, xi_index].item()
730
+ lat = self.grid.ds.lat_rho[eta_index, xi_index].item()
731
+ key = name if i == 0 else f"_{name}_{i}"
732
+ points[key] = {
733
+ "lon": lon,
734
+ "lat": lat,
735
+ "color": colors[name],
736
+ }
737
+
738
+ plot_location(
739
+ grid_ds=self.grid.ds,
740
+ points=points,
741
+ ax=ax,
742
+ include_legend=(j == 1),
743
+ )
738
744
 
739
745
  axs[0].set_title("Original river locations")
740
746
  axs[1].set_title("Updated river locations")
741
747
 
742
- axs[1].legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
743
-
744
- def plot(self, var_name="river_volume"):
748
+ def plot(
749
+ self,
750
+ var_name: str = "river_volume",
751
+ river_names: list[str] | str = INCLUDE_ALL_RIVER_NAMES,
752
+ ):
745
753
  """Plots the river flux (e.g., volume, temperature, or salinity) over time for
746
754
  all rivers.
747
755
 
@@ -791,8 +799,19 @@ class RiverForcing:
791
799
  - 'river_diazFe' : river diazFe (from river_tracer).
792
800
 
793
801
  The default is 'river_volume'.
802
+
803
+ river_names : list[str], or str, optional
804
+ A list of release names to plot.
805
+ If a string equal to "all", all rivers will be plotted.
806
+ Defaults to "all".
807
+
794
808
  """
795
- fig, ax = plt.subplots(1, 1, figsize=(9, 5))
809
+ valid_river_names = list(self.indices.keys())
810
+ river_names = _validate_river_names(river_names, valid_river_names)
811
+ if len(valid_river_names) > MAX_DISTINCT_COLORS:
812
+ colors = assign_category_colors(river_names)
813
+ else:
814
+ colors = assign_category_colors(valid_river_names)
796
815
 
797
816
  if self.climatology:
798
817
  xticks = self.ds.month.values
@@ -814,15 +833,19 @@ class RiverForcing:
814
833
  units = d[var_name_wo_river]["units"]
815
834
  long_name = f"River {d[var_name_wo_river]['long_name']}"
816
835
 
817
- for i in range(len(self.ds.nriver)):
836
+ fig, ax = plt.subplots(1, 1, figsize=(9, 5))
837
+ for name in river_names:
838
+ nriver = np.where(self.ds["river_name"].values == name)[0].item()
839
+
818
840
  ax.plot(
819
841
  xticks,
820
- field.isel(nriver=i),
842
+ field.isel(nriver=nriver),
821
843
  marker="x",
822
844
  markersize=8,
823
845
  markeredgewidth=2,
824
846
  lw=2,
825
- label=self.ds.isel(nriver=i).river_name.values,
847
+ label=name,
848
+ color=colors[name],
826
849
  )
827
850
 
828
851
  ax.set_xticks(xticks)
@@ -965,3 +988,38 @@ def check_river_locations_are_along_coast(mask, indices):
965
988
  raise ValueError(
966
989
  f"River `{key}` is not located on the coast at grid cell ({eta_rho}, {xi_rho})."
967
990
  )
991
+
992
+
993
def _validate_river_names(
    river_names: list[str] | str, valid_river_names: list[str]
) -> list[str]:
    """Check requested river names against the rivers that exist.

    Thin wrapper around ``validate_names`` configured for rivers. The sentinel
    ``INCLUDE_ALL_RIVER_NAMES`` selects every valid river. At most
    ``MAX_RIVERS_TO_PLOT`` names are returned, and a warning is logged when the
    selection is cut short. Messages refer to "river"/"rivers".

    Parameters
    ----------
    river_names : list of str or str
        Rivers requested by the caller, or ``INCLUDE_ALL_RIVER_NAMES``.
    valid_river_names : list of str
        Rivers available in the forcing.

    Returns
    -------
    list of str
        Validated and possibly truncated list of river names.

    Raises
    ------
    ValueError
        If ``river_names`` is neither the sentinel nor a list of strings, or if
        it contains names missing from ``valid_river_names``.
    """
    return validate_names(
        names=river_names,
        valid_names=valid_river_names,
        include_all_sentinel=INCLUDE_ALL_RIVER_NAMES,
        max_to_plot=MAX_RIVERS_TO_PLOT,
        label="river",
    )
@@ -150,12 +150,20 @@ class SurfaceForcing:
150
150
  use_coarse_grid = False
151
151
  elif self.coarse_grid_mode == "auto":
152
152
  use_coarse_grid = self._determine_coarse_grid_usage(data)
153
- if use_coarse_grid:
154
- logging.info("Data will be interpolated onto grid coarsened by factor 2.")
155
- else:
156
- logging.info("Data will be interpolated onto fine grid.")
157
153
  self.use_coarse_grid = use_coarse_grid
158
154
 
155
+ opt_file = "bulk_frc.opt" if self.type == "physics" else "bgc.opt"
156
+ grid_desc = "grid coarsened by factor 2" if use_coarse_grid else "fine grid"
157
+ interp_flag = 1 if use_coarse_grid else 0
158
+
159
+ logging.info(
160
+ "Data will be interpolated onto the %s. "
161
+ "Remember to set `interp_frc = %d` in your `%s` ROMS option file.",
162
+ grid_desc,
163
+ interp_flag,
164
+ opt_file,
165
+ )
166
+
159
167
  target_coords = get_target_coords(self.grid, self.use_coarse_grid)
160
168
  self.target_coords = target_coords
161
169
 
roms_tools/setup/utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import importlib.metadata
2
+ import logging
2
3
  from collections.abc import Sequence
3
4
  from dataclasses import asdict, fields, is_dataclass
4
5
  from datetime import datetime
@@ -1794,3 +1795,59 @@ def to_float(val):
1794
1795
  if isinstance(val, list):
1795
1796
  return [float(v) for v in val]
1796
1797
  return float(val)
1798
+
1799
+
1800
+ def validate_names(
1801
+ names: list[str] | str,
1802
+ valid_names: list[str],
1803
+ include_all_sentinel: str,
1804
+ max_to_plot: int,
1805
+ label: str = "item",
1806
+ ) -> list[str]:
1807
+ """
1808
+ Generic validation and filtering for a list of names.
1809
+
1810
+ Parameters
1811
+ ----------
1812
+ names : list of str or sentinel
1813
+ Names to validate, or sentinel value to include all valid names.
1814
+ valid_names : list of str
1815
+ List of valid names to check against.
1816
+ include_all_sentinel : str
1817
+ Sentinel value to indicate all names should be included.
1818
+ max_to_plot : int
1819
+ Maximum number of names to return.
1820
+ label : str, default "item"
1821
+ Label to use in error and warning messages.
1822
+
1823
+ Returns
1824
+ -------
1825
+ list of str
1826
+ Validated and possibly truncated list of names.
1827
+
1828
+ Raises
1829
+ ------
1830
+ ValueError
1831
+ If any names are invalid or input is not a list of strings.
1832
+ """
1833
+ if names == include_all_sentinel:
1834
+ names = valid_names
1835
+
1836
+ if isinstance(names, list):
1837
+ if not all(isinstance(n, str) for n in names):
1838
+ raise ValueError(f"All elements in `{label}_names` must be strings.")
1839
+ else:
1840
+ raise ValueError(f"`{label}_names` should be a list of strings.")
1841
+
1842
+ invalid = [n for n in names if n not in valid_names]
1843
+ if invalid:
1844
+ raise ValueError(f"Invalid {label}s: {', '.join(invalid)}")
1845
+
1846
+ if len(names) > max_to_plot:
1847
+ logging.warning(
1848
+ f"Only the first {max_to_plot} {label}s will be plotted "
1849
+ f"(received {len(names)})."
1850
+ )
1851
+ names = names[:max_to_plot]
1852
+
1853
+ return names
@@ -1,7 +1,9 @@
1
1
  import logging
2
+ import os
2
3
  import textwrap
3
4
  from datetime import datetime
4
5
  from pathlib import Path
6
+ from unittest import mock
5
7
 
6
8
  import matplotlib.pyplot as plt
7
9
  import numpy as np
@@ -758,3 +760,58 @@ def test_from_yaml_missing_boundary_forcing(tmp_path, use_dask):
758
760
 
759
761
  yaml_filepath = Path(yaml_filepath)
760
762
  yaml_filepath.unlink()
763
+
764
+
765
@pytest.mark.stream
@pytest.mark.use_dask
@pytest.mark.use_copernicus
def test_default_glorys_dataset_loading(tiny_grid: Grid) -> None:
    """Verify the default GLORYS dataset is loaded when a path is not provided.

    Only ``source["name"]`` is given, so ``BoundaryForcing`` has to use the
    default GLORYS product (streamed remotely, per the ``stream`` marker).
    """
    start_time = datetime(2010, 2, 1)
    end_time = datetime(2010, 3, 1)

    # Add the PYDEVD_* setting (presumably a debugger evaluation timeout for
    # this slow test) on top of the existing environment. ``clear=True`` would
    # drop every other variable, e.g. credentials supplied via the environment,
    # HOME or proxy settings needed to reach the remote service.
    with mock.patch.dict(os.environ, {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}):
        bf = BoundaryForcing(
            grid=tiny_grid,
            source={"name": "GLORYS"},
            type="physics",
            start_time=start_time,
            end_time=end_time,
            use_dask=True,
            bypass_validation=True,
        )

    expected_vars = {"u_south", "v_south", "temp_south", "salt_south"}
    assert set(bf.ds.data_vars).issuperset(expected_vars)
788
+
789
+
790
@pytest.mark.parametrize(
    "use_dask",
    [pytest.param(True, marks=pytest.mark.use_dask), False],
)
def test_nondefault_glorys_dataset_loading(small_grid: Grid, use_dask: bool) -> None:
    """Verify a non-default GLORYS dataset is loaded when a path is provided.

    The source points at a locally downloaded GLORYS test file instead of the
    default product. The test runs both with and without dask.
    """
    start_time = datetime(2012, 1, 1)
    end_time = datetime(2012, 12, 31)

    local_path = Path(download_test_data("GLORYS_NA_20120101.nc"))

    # Add the PYDEVD_* setting on top of the existing environment instead of
    # replacing it. ``clear=True`` would drop every other variable (HOME, PATH,
    # library configuration, ...) for the duration of the test.
    with mock.patch.dict(os.environ, {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}):
        bf = BoundaryForcing(
            grid=small_grid,
            source={"name": "GLORYS", "path": local_path},
            type="physics",
            start_time=start_time,
            end_time=end_time,
            use_dask=use_dask,
        )

    expected_vars = {"u_south", "v_south", "temp_south", "salt_south"}
    assert set(bf.ds.data_vars).issuperset(expected_vars)
@@ -9,7 +9,7 @@ from pydantic import ValidationError
9
9
 
10
10
  from conftest import calculate_file_hash
11
11
  from roms_tools import CDRForcing, Grid, TracerPerturbation, VolumeRelease
12
- from roms_tools.constants import NUM_TRACERS
12
+ from roms_tools.constants import MAX_DISTINCT_COLORS, NUM_TRACERS
13
13
  from roms_tools.setup.cdr_forcing import (
14
14
  CDRForcingDatasetBuilder,
15
15
  ReleaseCollector,
@@ -725,6 +725,8 @@ class TestCDRForcing:
725
725
  rot=0,
726
726
  N=3,
727
727
  )
728
+ self.grid = grid
729
+
728
730
  grid_that_straddles = Grid(
729
731
  nx=18,
730
732
  ny=18,
@@ -817,8 +819,13 @@ class TestCDRForcing:
817
819
  self.volume_release_cdr_forcing_with_straddling_grid,
818
820
  ]:
819
821
  cdr.plot_volume_flux()
822
+ cdr.plot_volume_flux(release_names=["first_release"])
823
+
820
824
  cdr.plot_tracer_concentration("ALK")
825
+ cdr.plot_tracer_concentration("ALK", release_names=["first_release"])
826
+
821
827
  cdr.plot_tracer_concentration("DIC")
828
+ cdr.plot_tracer_concentration("DIC", release_names=["first_release"])
822
829
 
823
830
  self.volume_release_cdr_forcing.plot_locations()
824
831
  self.volume_release_cdr_forcing.plot_locations(release_names=["first_release"])
@@ -830,13 +837,56 @@ class TestCDRForcing:
830
837
  self.tracer_perturbation_cdr_forcing_with_straddling_grid,
831
838
  ]:
832
839
  cdr.plot_tracer_flux("ALK")
840
+ cdr.plot_tracer_flux("ALK", release_names=["first_release"])
841
+
833
842
  cdr.plot_tracer_flux("DIC")
843
+ cdr.plot_tracer_flux("DIC", release_names=["first_release"])
834
844
 
835
845
  self.tracer_perturbation_cdr_forcing.plot_locations()
836
846
  self.tracer_perturbation_cdr_forcing.plot_locations(
837
847
  release_names=["first_release"]
838
848
  )
839
849
 
850
def test_plot_max_releases(self, caplog):
    """Plotting more releases than can be colored distinctly logs a truncation warning.

    Builds ``MAX_DISTINCT_COLORS + 1`` uniquely named copies of
    ``first_volume_release``. It then checks that every plotting method taking
    ``release_names`` warns that only the first ``MAX_DISTINCT_COLORS``
    releases are plotted.
    """
    # Prepare releases with more than MAX_DISTINCT_COLORS unique names.
    releases = [
        self.first_volume_release.__replace__(name=f"release_{i}")
        for i in range(MAX_DISTINCT_COLORS + 1)
    ]

    # Construct a CDRForcing object with too many releases to plot.
    cdr_forcing = CDRForcing(
        grid=self.grid,
        start_time=self.start_time,
        end_time=self.end_time,
        releases=releases,
    )

    release_names = [r.name for r in releases]
    expected = f"Only the first {MAX_DISTINCT_COLORS} releases will be plotted"

    for plot_func in (cdr_forcing.plot_locations, cdr_forcing.plot_volume_flux):
        # Clear between methods so each assertion only sees that method's logs.
        caplog.clear()
        with caplog.at_level("WARNING"):
            plot_func(release_names=release_names)
        assert any(expected in message for message in caplog.messages), (
            f"Warning not raised by {plot_func.__name__}"
        )
889
+
840
890
  @pytest.mark.skipif(xesmf is None, reason="xesmf required")
841
891
  def test_plot_distribution(self):
842
892
  self.volume_release_cdr_forcing.plot_distribution("first_release")
@@ -856,10 +906,10 @@ class TestCDRForcing:
856
906
  with pytest.raises(ValueError, match="Invalid releases"):
857
907
  self.volume_release_cdr_forcing.plot_locations(release_names=["fake"])
858
908
 
859
- with pytest.raises(ValueError, match="should be a string"):
909
+ with pytest.raises(ValueError, match="should be a list"):
860
910
  self.volume_release_cdr_forcing.plot_locations(release_names=4)
861
911
 
862
- with pytest.raises(ValueError, match="list must be strings"):
912
+ with pytest.raises(ValueError, match="must be strings"):
863
913
  self.volume_release_cdr_forcing.plot_locations(release_names=[4])
864
914
 
865
915
  def test_cdr_forcing_save(self, tmp_path):
@@ -2,6 +2,7 @@ import logging
2
2
  from collections import OrderedDict
3
3
  from datetime import datetime
4
4
  from pathlib import Path
5
+ from unittest import mock
5
6
 
6
7
  import numpy as np
7
8
  import pytest
@@ -11,11 +12,14 @@ from roms_tools.download import download_test_data
11
12
  from roms_tools.setup.datasets import (
12
13
  CESMBGCDataset,
13
14
  Dataset,
15
+ ERA5ARCODataset,
14
16
  ERA5Correction,
15
17
  GLORYSDataset,
18
+ GLORYSDefaultDataset,
16
19
  RiverDataset,
17
20
  TPXODataset,
18
21
  )
22
+ from roms_tools.setup.surface_forcing import DEFAULT_ERA5_ARCO_PATH
19
23
 
20
24
 
21
25
  @pytest.fixture
@@ -437,6 +441,78 @@ def test_era5_correction_choose_subdomain(use_dask):
437
441
  assert (data.ds["longitude"] == lons).all()
438
442
 
439
443
 
444
@pytest.mark.use_gcsfs
def test_default_era5_dataset_loading_without_dask() -> None:
    """Loading the default ERA5 (ARCO) dataset with ``use_dask=False`` must fail."""
    window = {
        "start_time": datetime(2020, 2, 1),
        "end_time": datetime(2020, 2, 2),
    }

    with pytest.raises(ValueError):
        ERA5ARCODataset(
            filename=DEFAULT_ERA5_ARCO_PATH,
            use_dask=False,
            **window,
        )
457
+
458
+
459
@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
@pytest.mark.stream
@pytest.mark.use_dask
@pytest.mark.use_gcsfs
def test_default_era5_dataset_loading() -> None:
    """Load one day of the default ERA5 (ARCO) dataset and check its variables."""
    dataset = ERA5ARCODataset(
        filename=DEFAULT_ERA5_ARCO_PATH,
        start_time=datetime(2020, 2, 1),
        end_time=datetime(2020, 2, 2),
        use_dask=True,
    )

    required = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
    missing = required - set(dataset.var_names)
    assert not missing
477
+
478
+
479
@pytest.mark.use_copernicus
def test_default_glorys_dataset_loading_dask_not_installed() -> None:
    """Requesting the default GLORYS dataset fails when ``_has_dask`` reports False."""
    pretend_no_dask = mock.patch("roms_tools.utils._has_dask", return_value=False)

    with pytest.raises(RuntimeError):
        with pretend_no_dask:
            GLORYSDefaultDataset(
                filename=GLORYSDefaultDataset.dataset_name,
                start_time=datetime(2020, 2, 1),
                end_time=datetime(2020, 2, 2),
                use_dask=True,
            )
495
+
496
+
497
@pytest.mark.stream
@pytest.mark.use_copernicus
@pytest.mark.use_dask
def test_default_glorys_dataset_loading() -> None:
    """Open 2012 of the default GLORYS dataset and check the core variables exist."""
    glorys = GLORYSDefaultDataset(
        filename=GLORYSDefaultDataset.dataset_name,
        start_time=datetime(2012, 1, 1),
        end_time=datetime(2013, 1, 1),
        use_dask=True,
    )

    available = set(glorys.var_names)
    assert all(var in available for var in ("temp", "salt", "u", "v", "zeta"))
514
+
515
+
440
516
  def test_data_concatenation(use_dask):
441
517
  fname = download_test_data("GLORYS_NA_2012.nc")
442
518
  data = GLORYSDataset(
@@ -21,6 +21,11 @@ from roms_tools.constants import (
21
21
  from roms_tools.download import download_test_data
22
22
  from roms_tools.setup.topography import _compute_rfactor
23
23
 
24
+ try:
25
+ import xesmf # type: ignore
26
+ except ImportError:
27
+ xesmf = None
28
+
24
29
 
25
30
  @pytest.fixture()
26
31
  def counter_clockwise_rotated_grid():
@@ -177,13 +182,18 @@ def test_successful_initialization_with_topography(grid_fixture, request):
177
182
  assert grid is not None
178
183
 
179
184
 
180
- def test_plot():
181
- grid = Grid(
182
- nx=20, ny=20, size_x=100, size_y=100, center_lon=-20, center_lat=0, rot=0
183
- )
185
def test_plot(grid_that_straddles_180_degree_meridian):
    """Plot a grid straddling the 180° meridian with and without dimension names."""
    grid = grid_that_straddles_180_degree_meridian
    for with_dim_names in (False, True):
        grid.plot(with_dim_names=with_dim_names)
188
+
189
+
190
@pytest.mark.skipif(xesmf is None, reason="xesmf required")
def test_plot_along_lat_lon(grid_that_straddles_180_degree_meridian):
    """Plot the grid along a fixed latitude and a fixed longitude; reject both at once."""
    grid = grid_that_straddles_180_degree_meridian
    grid.plot(lat=61)
    grid.plot(lon=180)

    with pytest.raises(ValueError, match="Specify either `lat` or `lon`, not both"):
        grid.plot(lat=61, lon=180)
187
197
 
188
198
 
189
199
  def test_save(tmp_path):