roms-tools 3.1.0-py3-none-any.whl → 3.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roms_tools/__init__.py +5 -1
- roms_tools/constants.py +1 -0
- roms_tools/plot.py +164 -9
- roms_tools/regrid.py +6 -1
- roms_tools/setup/boundary_forcing.py +55 -30
- roms_tools/setup/cdr_forcing.py +84 -209
- roms_tools/setup/datasets.py +96 -14
- roms_tools/setup/grid.py +29 -2
- roms_tools/setup/river_forcing.py +110 -52
- roms_tools/setup/surface_forcing.py +12 -4
- roms_tools/setup/utils.py +57 -0
- roms_tools/tests/test_setup/test_boundary_forcing.py +57 -0
- roms_tools/tests/test_setup/test_cdr_forcing.py +53 -3
- roms_tools/tests/test_setup/test_datasets.py +76 -0
- roms_tools/tests/test_setup/test_grid.py +16 -6
- roms_tools/tests/test_setup/test_river_forcing.py +63 -6
- roms_tools/tests/test_setup/test_surface_forcing.py +26 -2
- roms_tools/tests/test_setup/test_utils.py +52 -3
- roms_tools/tests/test_setup/test_validation.py +21 -15
- roms_tools/tests/test_tiling/test_partition.py +45 -0
- roms_tools/tests/test_utils.py +101 -1
- roms_tools/tiling/partition.py +44 -30
- roms_tools/utils.py +426 -131
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/METADATA +6 -3
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/RECORD +28 -28
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/WHEEL +0 -0
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/licenses/LICENSE +0 -0
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/top_level.txt +0 -0

roms_tools/setup/river_forcing.py
CHANGED

@@ -4,14 +4,18 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from pathlib import Path

-import cartopy.crs as ccrs
-import matplotlib.cm as cm
 import matplotlib.pyplot as plt
 import numpy as np
 import xarray as xr

 from roms_tools import Grid
-from roms_tools.
+from roms_tools.constants import MAX_DISTINCT_COLORS
+from roms_tools.plot import (
+    assign_category_colors,
+    get_projection,
+    plot_2d_horizontal_field,
+    plot_location,
+)
 from roms_tools.setup.datasets import (
     DaiRiverDataset,
     get_indices_of_nearest_grid_cell_for_rivers,
@@ -26,10 +30,14 @@ from roms_tools.setup.utils import (
     get_variable_metadata,
     substitute_nans_by_fillvalue,
     to_dict,
+    validate_names,
     write_to_yaml,
 )
 from roms_tools.utils import save_datasets

+INCLUDE_ALL_RIVER_NAMES = "all"
+MAX_RIVERS_TO_PLOT = 20  # must be <= MAX_DISTINCT_COLORS
+

 @dataclass(kw_only=True)
 class RiverForcing:
@@ -672,8 +680,24 @@ class RiverForcing:
             "`convert_to_climatology = 'if_any_missing'` to automatically fill missing values with climatological data."
         )

-    def plot_locations(self):
-        """Plots the original and updated river locations on a map projection.
+    def plot_locations(self, river_names: list[str] | str = INCLUDE_ALL_RIVER_NAMES):
+        """Plots the original and updated river locations on a map projection.
+
+        Parameters
+        ----------
+        river_names : list[str], or str, optional
+            A list of release names to plot.
+            If a string equal to "all", all rivers will be plotted.
+            Defaults to "all".
+
+        """
+        valid_river_names = list(self.indices.keys())
+        river_names = _validate_river_names(river_names, valid_river_names)
+        if len(valid_river_names) > MAX_DISTINCT_COLORS:
+            colors = assign_category_colors(river_names)
+        else:
+            colors = assign_category_colors(valid_river_names)
+
         field = self.grid.ds.mask_rho
         lon_deg = self.grid.ds.lon_rho
         lat_deg = self.grid.ds.lat_rho
@@ -695,53 +719,37 @@ class RiverForcing:
         for ax in axs:
             plot_2d_horizontal_field(field, kwargs=kwargs, ax=ax, add_colorbar=False)

-        [removed lines 698-719 not captured in this diff view]
-            proj,
-        )
-
-        if name not in added_labels:
-            added_labels.add(name)
-            label = name
-        else:
-            label = "_None"
-
-        ax.plot(
-            transformed_lon,
-            transformed_lat,
-            marker="x",
-            markersize=8,
-            markeredgewidth=2,
-            label=label,
-            color=colors[name],
-        )
+        points = {}
+        for j, (ax, indices) in enumerate(
+            [(ax, ind) for ax, ind in zip(axs, [self.original_indices, self.indices])]
+        ):
+            for name in river_names:
+                if name in indices:
+                    for i, (eta_index, xi_index) in enumerate(indices[name]):
+                        lon = self.grid.ds.lon_rho[eta_index, xi_index].item()
+                        lat = self.grid.ds.lat_rho[eta_index, xi_index].item()
+                        key = name if i == 0 else f"_{name}_{i}"
+                        points[key] = {
+                            "lon": lon,
+                            "lat": lat,
+                            "color": colors[name],
+                        }
+
+            plot_location(
+                grid_ds=self.grid.ds,
+                points=points,
+                ax=ax,
+                include_legend=(j == 1),
+            )

         axs[0].set_title("Original river locations")
         axs[1].set_title("Updated river locations")

-        [removed lines 742-744 not captured in this diff view]
+    def plot(
+        self,
+        var_name: str = "river_volume",
+        river_names: list[str] | str = INCLUDE_ALL_RIVER_NAMES,
+    ):
         """Plots the river flux (e.g., volume, temperature, or salinity) over time for
         all rivers.

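For orientation, a minimal usage sketch of the extended `plot_locations` API above; the `RiverForcing` instance and the river names are assumed to come from an existing setup and are illustrative, not taken from this diff.

    from roms_tools.setup.river_forcing import RiverForcing


    def plot_selected_river_locations(river_forcing: RiverForcing, names: list[str]) -> None:
        """Plot original vs. updated locations for a subset of rivers (sketch)."""
        # New in 3.1.2: `river_names` defaults to "all"; a list restricts the plot.
        # Passing more than MAX_RIVERS_TO_PLOT (20) names triggers a truncation warning.
        river_forcing.plot_locations(river_names=names)
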
@@ -791,8 +799,19 @@ class RiverForcing:
             - 'river_diazFe' : river diazFe (from river_tracer).

             The default is 'river_volume'.
+
+        river_names : list[str], or str, optional
+            A list of release names to plot.
+            If a string equal to "all", all rivers will be plotted.
+            Defaults to "all".
+
         """
-
+        valid_river_names = list(self.indices.keys())
+        river_names = _validate_river_names(river_names, valid_river_names)
+        if len(valid_river_names) > MAX_DISTINCT_COLORS:
+            colors = assign_category_colors(river_names)
+        else:
+            colors = assign_category_colors(valid_river_names)

         if self.climatology:
             xticks = self.ds.month.values
@@ -814,15 +833,19 @@ class RiverForcing:
         units = d[var_name_wo_river]["units"]
         long_name = f"River {d[var_name_wo_river]['long_name']}"

-
+        fig, ax = plt.subplots(1, 1, figsize=(9, 5))
+        for name in river_names:
+            nriver = np.where(self.ds["river_name"].values == name)[0].item()
+
             ax.plot(
                 xticks,
-                field.isel(nriver=
+                field.isel(nriver=nriver),
                 marker="x",
                 markersize=8,
                 markeredgewidth=2,
                 lw=2,
-                label=
+                label=name,
+                color=colors[name],
             )

         ax.set_xticks(xticks)
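
Likewise, a hedged sketch of the time-series `plot` call; the instance and names are again illustrative.

    from roms_tools.setup.river_forcing import RiverForcing


    def plot_selected_river_volume(river_forcing: RiverForcing, names: list[str]) -> None:
        """Plot the river volume flux over time for a subset of rivers (sketch)."""
        # 'river_volume' is the documented default; other accepted variable names are
        # listed in the docstring above. `river_names` defaults to "all".
        river_forcing.plot(var_name="river_volume", river_names=names)
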
@@ -965,3 +988,38 @@ def check_river_locations_are_along_coast(mask, indices):
             raise ValueError(
                 f"River `{key}` is not located on the coast at grid cell ({eta_rho}, {xi_rho})."
             )
+
+
+def _validate_river_names(
+    river_names: list[str] | str, valid_river_names: list[str]
+) -> list[str]:
+    """
+    Validate and filter a list of river names.
+
+    Ensures that each river name exists in `valid_river_names` and limits the list
+    to `MAX_RIVERS_TO_PLOT` entries with a warning if truncated.
+
+    Parameters
+    ----------
+    river_names : list of str or INCLUDE_ALL_RIVER_NAMES
+        Names of rivers to plot, or sentinel to include all.
+    valid_river_names : list of str
+        List of valid river names.
+
+    Returns
+    -------
+    list of str
+        Validated and truncated list of river names.
+
+    Raises
+    ------
+    ValueError
+        If any names are invalid.
+    """
+    return validate_names(
+        river_names,
+        valid_river_names,
+        INCLUDE_ALL_RIVER_NAMES,
+        MAX_RIVERS_TO_PLOT,
+        label="river",
+    )

roms_tools/setup/surface_forcing.py
CHANGED

@@ -150,12 +150,20 @@ class SurfaceForcing:
             use_coarse_grid = False
         elif self.coarse_grid_mode == "auto":
             use_coarse_grid = self._determine_coarse_grid_usage(data)
-        if use_coarse_grid:
-            logging.info("Data will be interpolated onto grid coarsened by factor 2.")
-        else:
-            logging.info("Data will be interpolated onto fine grid.")
         self.use_coarse_grid = use_coarse_grid

+        opt_file = "bulk_frc.opt" if self.type == "physics" else "bgc.opt"
+        grid_desc = "grid coarsened by factor 2" if use_coarse_grid else "fine grid"
+        interp_flag = 1 if use_coarse_grid else 0
+
+        logging.info(
+            "Data will be interpolated onto the %s. "
+            "Remember to set `interp_frc = %d` in your `%s` ROMS option file.",
+            grid_desc,
+            interp_flag,
+            opt_file,
+        )
+
         target_coords = get_target_coords(self.grid, self.use_coarse_grid)
         self.target_coords = target_coords

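For reference, a standalone sketch that renders both variants of the new log message, so the `interp_frc` hint is visible without constructing a `SurfaceForcing` object; the forcing type and coarse-grid flag below are illustrative values.

    import logging

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    # Render the two branches of the message added above.
    for forcing_type, use_coarse_grid in [("physics", True), ("bgc", False)]:
        opt_file = "bulk_frc.opt" if forcing_type == "physics" else "bgc.opt"
        grid_desc = "grid coarsened by factor 2" if use_coarse_grid else "fine grid"
        interp_flag = 1 if use_coarse_grid else 0
        logging.info(
            "Data will be interpolated onto the %s. "
            "Remember to set `interp_frc = %d` in your `%s` ROMS option file.",
            grid_desc,
            interp_flag,
            opt_file,
        )
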
roms_tools/setup/utils.py
CHANGED

@@ -1,4 +1,5 @@
 import importlib.metadata
+import logging
 from collections.abc import Sequence
 from dataclasses import asdict, fields, is_dataclass
 from datetime import datetime

@@ -1794,3 +1795,59 @@ def to_float(val):
     if isinstance(val, list):
         return [float(v) for v in val]
     return float(val)
+
+
+def validate_names(
+    names: list[str] | str,
+    valid_names: list[str],
+    include_all_sentinel: str,
+    max_to_plot: int,
+    label: str = "item",
+) -> list[str]:
+    """
+    Generic validation and filtering for a list of names.
+
+    Parameters
+    ----------
+    names : list of str or sentinel
+        Names to validate, or sentinel value to include all valid names.
+    valid_names : list of str
+        List of valid names to check against.
+    include_all_sentinel : str
+        Sentinel value to indicate all names should be included.
+    max_to_plot : int
+        Maximum number of names to return.
+    label : str, default "item"
+        Label to use in error and warning messages.
+
+    Returns
+    -------
+    list of str
+        Validated and possibly truncated list of names.
+
+    Raises
+    ------
+    ValueError
+        If any names are invalid or input is not a list of strings.
+    """
+    if names == include_all_sentinel:
+        names = valid_names
+
+    if isinstance(names, list):
+        if not all(isinstance(n, str) for n in names):
+            raise ValueError(f"All elements in `{label}_names` must be strings.")
+    else:
+        raise ValueError(f"`{label}_names` should be a list of strings.")
+
+    invalid = [n for n in names if n not in valid_names]
+    if invalid:
+        raise ValueError(f"Invalid {label}s: {', '.join(invalid)}")
+
+    if len(names) > max_to_plot:
+        logging.warning(
+            f"Only the first {max_to_plot} {label}s will be plotted "
+            f"(received {len(names)})."
+        )
+        names = names[:max_to_plot]
+
+    return names
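
A short, self-contained usage sketch of the new helper; the river names are illustrative, and the call mirrors how `_validate_river_names` in river_forcing.py delegates to it.

    from roms_tools.setup.utils import validate_names

    valid = ["Amazon", "Orinoco", "Mississippi"]

    # The sentinel expands to the full list of valid names.
    names = validate_names("all", valid, include_all_sentinel="all", max_to_plot=20, label="river")
    print(names)  # ['Amazon', 'Orinoco', 'Mississippi']

    # Unknown names raise ValueError("Invalid rivers: ..."); passing more than
    # `max_to_plot` names logs a warning and truncates the list.
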

roms_tools/tests/test_setup/test_boundary_forcing.py
CHANGED

@@ -1,7 +1,9 @@
 import logging
+import os
 import textwrap
 from datetime import datetime
 from pathlib import Path
+from unittest import mock

 import matplotlib.pyplot as plt
 import numpy as np

@@ -758,3 +760,58 @@ def test_from_yaml_missing_boundary_forcing(tmp_path, use_dask):

     yaml_filepath = Path(yaml_filepath)
     yaml_filepath.unlink()
+
+
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_copernicus
+def test_default_glorys_dataset_loading(tiny_grid: Grid) -> None:
+    """Verify the default GLORYS dataset is loaded when a path is not provided."""
+    start_time = datetime(2010, 2, 1)
+    end_time = datetime(2010, 3, 1)
+
+    with mock.patch.dict(
+        os.environ, {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}, clear=True
+    ):
+        bf = BoundaryForcing(
+            grid=tiny_grid,
+            source={"name": "GLORYS"},
+            type="physics",
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=True,
+            bypass_validation=True,
+        )
+
+    expected_vars = {"u_south", "v_south", "temp_south", "salt_south"}
+    assert set(bf.ds.data_vars).issuperset(expected_vars)
+
+
+@pytest.mark.parametrize(
+    "use_dask",
+    [pytest.param(True, marks=pytest.mark.use_dask), False],
+)
+def test_nondefault_glorys_dataset_loading(small_grid: Grid, use_dask: bool) -> None:
+    """Verify a non-default GLORYS dataset is loaded when a path is provided."""
+    start_time = datetime(2012, 1, 1)
+    end_time = datetime(2012, 12, 31)
+
+    local_path = Path(download_test_data("GLORYS_NA_20120101.nc"))
+
+    with mock.patch.dict(
+        os.environ, {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}, clear=True
+    ):
+        bf = BoundaryForcing(
+            grid=small_grid,
+            source={
+                "name": "GLORYS",
+                "path": local_path,
+            },
+            type="physics",
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=use_dask,
+        )
+
+    expected_vars = {"u_south", "v_south", "temp_south", "salt_south"}
+    assert set(bf.ds.data_vars).issuperset(expected_vars)

roms_tools/tests/test_setup/test_cdr_forcing.py
CHANGED

@@ -9,7 +9,7 @@ from pydantic import ValidationError

 from conftest import calculate_file_hash
 from roms_tools import CDRForcing, Grid, TracerPerturbation, VolumeRelease
-from roms_tools.constants import NUM_TRACERS
+from roms_tools.constants import MAX_DISTINCT_COLORS, NUM_TRACERS
 from roms_tools.setup.cdr_forcing import (
     CDRForcingDatasetBuilder,
     ReleaseCollector,

@@ -725,6 +725,8 @@ class TestCDRForcing:
             rot=0,
             N=3,
         )
+        self.grid = grid
+
         grid_that_straddles = Grid(
             nx=18,
             ny=18,

@@ -817,8 +819,13 @@ class TestCDRForcing:
             self.volume_release_cdr_forcing_with_straddling_grid,
         ]:
             cdr.plot_volume_flux()
+            cdr.plot_volume_flux(release_names=["first_release"])
+
             cdr.plot_tracer_concentration("ALK")
+            cdr.plot_tracer_concentration("ALK", release_names=["first_release"])
+
             cdr.plot_tracer_concentration("DIC")
+            cdr.plot_tracer_concentration("DIC", release_names=["first_release"])

         self.volume_release_cdr_forcing.plot_locations()
         self.volume_release_cdr_forcing.plot_locations(release_names=["first_release"])

@@ -830,13 +837,56 @@ class TestCDRForcing:
             self.tracer_perturbation_cdr_forcing_with_straddling_grid,
         ]:
             cdr.plot_tracer_flux("ALK")
+            cdr.plot_tracer_flux("ALK", release_names=["first_release"])
+
             cdr.plot_tracer_flux("DIC")
+            cdr.plot_tracer_flux("DIC", release_names=["first_release"])

         self.tracer_perturbation_cdr_forcing.plot_locations()
         self.tracer_perturbation_cdr_forcing.plot_locations(
             release_names=["first_release"]
         )

+    def test_plot_max_releases(self, caplog):
+        # Prepare releases with more than MAX_DISTINCT_COLORS unique names
+        releases = []
+        for i in range(MAX_DISTINCT_COLORS + 1):
+            release = self.first_volume_release.__replace__(name=f"release_{i}")
+            releases.append(release)
+
+        # Construct a CDRForcing object with too many releases to plot
+        cdr_forcing = CDRForcing(
+            grid=self.grid,
+            start_time=self.start_time,
+            end_time=self.end_time,
+            releases=releases,
+        )
+
+        release_names = [r.name for r in releases]
+
+        plot_methods_with_release_names = [
+            cdr_forcing.plot_locations,
+            cdr_forcing.plot_volume_flux,
+        ]
+
+        for plot_func in plot_methods_with_release_names:
+            caplog.clear()
+            with caplog.at_level("WARNING"):
+                plot_func(release_names=release_names)
+            assert any(
+                f"Only the first {MAX_DISTINCT_COLORS} releases will be plotted"
+                in message
+                for message in caplog.messages
+            ), f"Warning not raised by {plot_func.__name__}"
+
+        with caplog.at_level("WARNING"):
+            cdr_forcing.plot_locations(release_names=release_names)
+
+        assert any(
+            f"Only the first {MAX_DISTINCT_COLORS} releases will be plotted" in message
+            for message in caplog.messages
+        )
+
     @pytest.mark.skipif(xesmf is None, reason="xesmf required")
     def test_plot_distribution(self):
         self.volume_release_cdr_forcing.plot_distribution("first_release")

@@ -856,10 +906,10 @@ class TestCDRForcing:
         with pytest.raises(ValueError, match="Invalid releases"):
             self.volume_release_cdr_forcing.plot_locations(release_names=["fake"])

-        with pytest.raises(ValueError, match="should be a
+        with pytest.raises(ValueError, match="should be a list"):
             self.volume_release_cdr_forcing.plot_locations(release_names=4)

-        with pytest.raises(ValueError, match="
+        with pytest.raises(ValueError, match="must be strings"):
             self.volume_release_cdr_forcing.plot_locations(release_names=[4])

     def test_cdr_forcing_save(self, tmp_path):

roms_tools/tests/test_setup/test_datasets.py
CHANGED

@@ -2,6 +2,7 @@ import logging
 from collections import OrderedDict
 from datetime import datetime
 from pathlib import Path
+from unittest import mock

 import numpy as np
 import pytest

@@ -11,11 +12,14 @@ from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import (
     CESMBGCDataset,
     Dataset,
+    ERA5ARCODataset,
     ERA5Correction,
     GLORYSDataset,
+    GLORYSDefaultDataset,
     RiverDataset,
     TPXODataset,
 )
+from roms_tools.setup.surface_forcing import DEFAULT_ERA5_ARCO_PATH


 @pytest.fixture

@@ -437,6 +441,78 @@ def test_era5_correction_choose_subdomain(use_dask):
     assert (data.ds["longitude"] == lons).all()


+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading_without_dask() -> None:
+    """Verify that loading the default ERA5 dataset fails if use_dask is not True."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    with pytest.raises(ValueError):
+        _ = ERA5ARCODataset(
+            filename=DEFAULT_ERA5_ARCO_PATH,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=False,
+        )
+
+
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading() -> None:
+    """Verify the default ERA5 dataset is loaded correctly."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    ds = ERA5ARCODataset(
+        filename=DEFAULT_ERA5_ARCO_PATH,
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
+    assert set(ds.var_names).issuperset(expected_vars)
+
+
+@pytest.mark.use_copernicus
+def test_default_glorys_dataset_loading_dask_not_installed() -> None:
+    """Verify that loading the default GLORYS dataset fails if dask is not available."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    with (
+        pytest.raises(RuntimeError),
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+    ):
+        _ = GLORYSDefaultDataset(
+            filename=GLORYSDefaultDataset.dataset_name,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=True,
+        )
+
+
+@pytest.mark.stream
+@pytest.mark.use_copernicus
+@pytest.mark.use_dask
+def test_default_glorys_dataset_loading() -> None:
+    """Verify the default GLORYS dataset is loaded correctly."""
+    start_time = datetime(2012, 1, 1)
+    end_time = datetime(2013, 1, 1)
+
+    ds = GLORYSDefaultDataset(
+        filename=GLORYSDefaultDataset.dataset_name,
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"temp", "salt", "u", "v", "zeta"}
+    assert set(ds.var_names).issuperset(expected_vars)
+
+
 def test_data_concatenation(use_dask):
     fname = download_test_data("GLORYS_NA_2012.nc")
     data = GLORYSDataset(

roms_tools/tests/test_setup/test_grid.py
CHANGED

@@ -21,6 +21,11 @@ from roms_tools.constants import (
 from roms_tools.download import download_test_data
 from roms_tools.setup.topography import _compute_rfactor

+try:
+    import xesmf  # type: ignore
+except ImportError:
+    xesmf = None
+

 @pytest.fixture()
 def counter_clockwise_rotated_grid():

@@ -177,13 +182,18 @@ def test_successful_initialization_with_topography(grid_fixture, request):
     assert grid is not None


-def test_plot():
-    [removed lines 181-183 not captured in this diff view]
+def test_plot(grid_that_straddles_180_degree_meridian):
+    grid_that_straddles_180_degree_meridian.plot(with_dim_names=False)
+    grid_that_straddles_180_degree_meridian.plot(with_dim_names=True)
+
+
+@pytest.mark.skipif(xesmf is None, reason="xesmf required")
+def test_plot_along_lat_lon(grid_that_straddles_180_degree_meridian):
+    grid_that_straddles_180_degree_meridian.plot(lat=61)
+    grid_that_straddles_180_degree_meridian.plot(lon=180)

-    [removed lines 185-186 not captured in this diff view]
+    with pytest.raises(ValueError, match="Specify either `lat` or `lon`, not both"):
+        grid_that_straddles_180_degree_meridian.plot(lat=61, lon=180)


 def test_save(tmp_path):