roms-tools 3.1.0__py3-none-any.whl → 3.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roms_tools/constants.py +1 -0
- roms_tools/plot.py +108 -0
- roms_tools/setup/cdr_forcing.py +83 -202
- roms_tools/setup/river_forcing.py +110 -52
- roms_tools/setup/utils.py +57 -0
- roms_tools/tests/test_setup/test_cdr_forcing.py +53 -3
- roms_tools/tests/test_setup/test_river_forcing.py +63 -6
- roms_tools/tests/test_setup/test_utils.py +52 -3
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.1.dist-info}/METADATA +3 -1
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.1.dist-info}/RECORD +13 -13
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.1.dist-info}/WHEEL +0 -0
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.1.dist-info}/licenses/LICENSE +0 -0
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.1.dist-info}/top_level.txt +0 -0
roms_tools/constants.py
CHANGED
|
@@ -3,3 +3,4 @@ MAXIMUM_GRID_SIZE = 25000 # in km
|
|
|
3
3
|
UPPER_BOUND_THETA_S = 10 # upper bound for surface vertical stretching parameter
|
|
4
4
|
UPPER_BOUND_THETA_B = 10 # upper bound for bottom vertical stretching parameter
|
|
5
5
|
NUM_TRACERS = 34 # Number of tracers (temperature, salinity, BGC tracers)
|
|
6
|
+
MAX_DISTINCT_COLORS = 20 # Based on tab20 colormap
|
roms_tools/plot.py
CHANGED
|
@@ -1097,3 +1097,111 @@ def plot(
|
|
|
1097
1097
|
|
|
1098
1098
|
if save_path:
|
|
1099
1099
|
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
|
1100
|
+
|
|
1101
|
+
|
|
1102
|
+
def assign_category_colors(names: list[str]) -> dict[str, tuple]:
    """
    Assign a distinct color to each name using a Matplotlib categorical colormap.

    Parameters
    ----------
    names : list[str]
        List of category names (e.g., releases, rivers, etc.) to assign colors to.

    Returns
    -------
    dict[str, tuple]
        Dictionary mapping each name to the color returned by the selected
        colormap for that name's position in `names`.

    Raises
    ------
    ValueError
        If the number of names exceeds the selected colormap's capacity.

    Notes
    -----
    Colormap selection is based on the number of items:
    - <= 10: 'tab10'
    - <= 20: 'tab20'
    - > 20 : 'tab20b'
    """
    n = len(names)

    if n <= 10:
        cmap = plt.get_cmap("tab10")
    elif n <= 20:
        cmap = plt.get_cmap("tab20")
    else:
        # This branch was missing: for n > 20 `cmap` was never bound, so the
        # capacity check below failed with UnboundLocalError instead of the
        # documented ValueError.
        cmap = plt.get_cmap("tab20b")

    # Refuse to hand out repeated colors: every name must map to its own entry.
    if n > cmap.N:
        raise ValueError(
            f"Too many categories ({n}) for selected colormap ({cmap.name}) "
            f"which supports only {cmap.N} distinct entries."
        )

    return {name: cmap(i) for i, name in enumerate(names)}
|
|
1142
|
+
|
|
1143
|
+
|
|
1144
|
+
def plot_location(
    grid_ds: xr.Dataset,
    points: dict[str, dict],
    ax: Axes,
    include_legend: bool = True,
) -> None:
    """Draw named geographic points as "x" markers on a top-down map axis.

    The helper is not tied to a particular kind of point; callers pass
    whatever named locations they want marked (releases, rivers, ...).

    Parameters
    ----------
    grid_ds : xr.Dataset
        Grid dataset providing `lon_rho` and `lat_rho`, plus a 'straddle'
        entry in its attributes.

    points : dict[str, dict]
        Points to draw, keyed by the label used for the marker. Each value
        is a dict with:
        - "lat": latitude in degrees
        - "lon": longitude in degrees
        - optional "color": any matplotlib color; black ("k") if absent

    ax : matplotlib.axes.Axes
        Axis that receives the markers.

    include_legend : bool, default True
        If True, a legend is attached to the right of the axis.

    Returns
    -------
    None

    Raises
    ------
    AttributeError
        If `grid_ds.attrs` has no 'straddle' entry.
    """
    grid_lon = grid_ds.lon_rho
    grid_lat = grid_ds.lat_rho

    if "straddle" not in grid_ds.attrs:
        raise AttributeError("Grid dataset must have a 'straddle' attribute.")

    # The attribute is compared against the string "True"; any other value
    # leaves the longitudes untouched.
    if grid_ds.attrs["straddle"] == "True":
        # Shift longitudes above 180 degrees down by 360 degrees.
        grid_lon = xr.where(grid_lon > 180, grid_lon - 360, grid_lon)

    map_projection = get_projection(grid_lon, grid_lat)
    geographic = ccrs.PlateCarree()

    marker_style = {"marker": "x", "markersize": 8, "markeredgewidth": 2}

    for label, spec in points.items():
        point_lon = spec["lon"]
        point_lat = spec["lat"]
        point_color = spec.get("color", "k")

        # Convert the geographic position into the map projection's coordinates.
        x, y = map_projection.transform_point(point_lon, point_lat, geographic)

        ax.plot(x, y, label=label, color=point_color, **marker_style)

    if include_legend:
        ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
|
roms_tools/setup/cdr_forcing.py
CHANGED
|
@@ -6,12 +6,10 @@ from datetime import datetime
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import Annotated
|
|
8
8
|
|
|
9
|
-
import cartopy.crs as ccrs
|
|
10
9
|
import matplotlib.gridspec as gridspec
|
|
11
10
|
import matplotlib.pyplot as plt
|
|
12
11
|
import numpy as np
|
|
13
12
|
import xarray as xr
|
|
14
|
-
from matplotlib.axes import Axes
|
|
15
13
|
from pydantic import (
|
|
16
14
|
BaseModel,
|
|
17
15
|
Field,
|
|
@@ -22,7 +20,14 @@ from pydantic import (
|
|
|
22
20
|
)
|
|
23
21
|
|
|
24
22
|
from roms_tools import Grid
|
|
25
|
-
from roms_tools.
|
|
23
|
+
from roms_tools.constants import MAX_DISTINCT_COLORS
|
|
24
|
+
from roms_tools.plot import (
|
|
25
|
+
assign_category_colors,
|
|
26
|
+
get_projection,
|
|
27
|
+
plot,
|
|
28
|
+
plot_2d_horizontal_field,
|
|
29
|
+
plot_location,
|
|
30
|
+
)
|
|
26
31
|
from roms_tools.setup.cdr_release import (
|
|
27
32
|
Release,
|
|
28
33
|
ReleaseType,
|
|
@@ -36,6 +41,7 @@ from roms_tools.setup.utils import (
|
|
|
36
41
|
gc_dist,
|
|
37
42
|
get_target_coords,
|
|
38
43
|
to_dict,
|
|
44
|
+
validate_names,
|
|
39
45
|
write_to_yaml,
|
|
40
46
|
)
|
|
41
47
|
from roms_tools.utils import (
|
|
@@ -45,6 +51,7 @@ from roms_tools.utils import (
|
|
|
45
51
|
from roms_tools.vertical_coordinate import compute_depth_coordinates
|
|
46
52
|
|
|
47
53
|
INCLUDE_ALL_RELEASE_NAMES = "all"
|
|
54
|
+
MAX_RELEASES_TO_PLOT = 20 # must be <= MAX_DISTINCT_COLORS
|
|
48
55
|
|
|
49
56
|
|
|
50
57
|
class ReleaseSimulationManager(BaseModel):
|
|
@@ -389,7 +396,10 @@ class CDRForcing(BaseModel):
|
|
|
389
396
|
return self._ds
|
|
390
397
|
|
|
391
398
|
def plot_volume_flux(
|
|
392
|
-
self,
|
|
399
|
+
self,
|
|
400
|
+
start: datetime | None = None,
|
|
401
|
+
end: datetime | None = None,
|
|
402
|
+
release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
|
|
393
403
|
):
|
|
394
404
|
"""Plot the volume flux for each specified release within the given time range.
|
|
395
405
|
|
|
@@ -419,12 +429,7 @@ class CDRForcing(BaseModel):
|
|
|
419
429
|
start = start or self.start_time
|
|
420
430
|
end = end or self.end_time
|
|
421
431
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
if release_names == INCLUDE_ALL_RELEASE_NAMES:
|
|
425
|
-
release_names = valid_release_names
|
|
426
|
-
|
|
427
|
-
_validate_release_input(release_names, valid_release_names)
|
|
432
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
428
433
|
|
|
429
434
|
data = self.ds["cdr_volume"]
|
|
430
435
|
|
|
@@ -440,9 +445,9 @@ class CDRForcing(BaseModel):
|
|
|
440
445
|
def plot_tracer_concentration(
|
|
441
446
|
self,
|
|
442
447
|
tracer_name: str,
|
|
443
|
-
start=None,
|
|
444
|
-
end=None,
|
|
445
|
-
release_names=INCLUDE_ALL_RELEASE_NAMES,
|
|
448
|
+
start: datetime | None = None,
|
|
449
|
+
end: datetime | None = None,
|
|
450
|
+
release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
|
|
446
451
|
):
|
|
447
452
|
"""Plot the concentration of a given tracer for each specified release within
|
|
448
453
|
the given time range.
|
|
@@ -476,12 +481,7 @@ class CDRForcing(BaseModel):
|
|
|
476
481
|
start = start or self.start_time
|
|
477
482
|
end = end or self.end_time
|
|
478
483
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
if release_names == INCLUDE_ALL_RELEASE_NAMES:
|
|
482
|
-
release_names = valid_release_names
|
|
483
|
-
|
|
484
|
-
_validate_release_input(release_names, valid_release_names)
|
|
484
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
485
485
|
|
|
486
486
|
tracer_names = list(self.ds["tracer_name"].values)
|
|
487
487
|
if tracer_name not in tracer_names:
|
|
@@ -511,9 +511,9 @@ class CDRForcing(BaseModel):
|
|
|
511
511
|
def plot_tracer_flux(
|
|
512
512
|
self,
|
|
513
513
|
tracer_name: str,
|
|
514
|
-
start=None,
|
|
515
|
-
end=None,
|
|
516
|
-
release_names=INCLUDE_ALL_RELEASE_NAMES,
|
|
514
|
+
start: datetime | None = None,
|
|
515
|
+
end: datetime | None = None,
|
|
516
|
+
release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
|
|
517
517
|
):
|
|
518
518
|
"""Plot the flux of a given tracer for each specified release within the given
|
|
519
519
|
time range.
|
|
@@ -547,12 +547,7 @@ class CDRForcing(BaseModel):
|
|
|
547
547
|
start = start or self.start_time
|
|
548
548
|
end = end or self.end_time
|
|
549
549
|
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
if release_names == INCLUDE_ALL_RELEASE_NAMES:
|
|
553
|
-
release_names = valid_release_names
|
|
554
|
-
|
|
555
|
-
_validate_release_input(release_names, valid_release_names)
|
|
550
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
556
551
|
|
|
557
552
|
tracer_names = list(self.ds["tracer_name"].values)
|
|
558
553
|
if tracer_name not in tracer_names:
|
|
@@ -577,7 +572,10 @@ class CDRForcing(BaseModel):
|
|
|
577
572
|
def _plot_line(self, data, release_names, start, end, title="", ylabel=""):
|
|
578
573
|
"""Plots a line graph for the specified releases and time range."""
|
|
579
574
|
valid_release_names = [r.name for r in self.releases]
|
|
580
|
-
|
|
575
|
+
if len(valid_release_names) > MAX_DISTINCT_COLORS:
|
|
576
|
+
colors = assign_category_colors(release_names)
|
|
577
|
+
else:
|
|
578
|
+
colors = assign_category_colors(valid_release_names)
|
|
581
579
|
|
|
582
580
|
fig, ax = plt.subplots(1, 1, figsize=(7, 4))
|
|
583
581
|
for name in release_names:
|
|
@@ -596,7 +594,9 @@ class CDRForcing(BaseModel):
|
|
|
596
594
|
ax.set(title=title, ylabel=ylabel, xlabel="time")
|
|
597
595
|
ax.set_xlim([start, end])
|
|
598
596
|
|
|
599
|
-
def plot_locations(
|
|
597
|
+
def plot_locations(
|
|
598
|
+
self, release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES
|
|
599
|
+
):
|
|
600
600
|
"""Plot centers of release locations in top-down view.
|
|
601
601
|
|
|
602
602
|
Parameters
|
|
@@ -619,12 +619,7 @@ class CDRForcing(BaseModel):
|
|
|
619
619
|
"A grid must be provided for plotting. Please pass a valid `Grid` object."
|
|
620
620
|
)
|
|
621
621
|
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
if release_names == "all":
|
|
625
|
-
release_names = valid_release_names
|
|
626
|
-
|
|
627
|
-
_validate_release_input(release_names, valid_release_names)
|
|
622
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
628
623
|
|
|
629
624
|
lon_deg = self.grid.ds.lon_rho
|
|
630
625
|
lat_deg = self.grid.ds.lat_rho
|
|
@@ -645,12 +640,22 @@ class CDRForcing(BaseModel):
|
|
|
645
640
|
plot_2d_horizontal_field(field, kwargs=kwargs, ax=ax, add_colorbar=False)
|
|
646
641
|
|
|
647
642
|
# Plot release locations
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
643
|
+
valid_release_names = [r.name for r in self.releases]
|
|
644
|
+
if len(valid_release_names) > MAX_DISTINCT_COLORS:
|
|
645
|
+
colors = assign_category_colors(release_names)
|
|
646
|
+
else:
|
|
647
|
+
colors = assign_category_colors(valid_release_names)
|
|
648
|
+
plot_location(
|
|
649
|
+
grid_ds=self.grid.ds,
|
|
650
|
+
points={
|
|
651
|
+
name: {
|
|
652
|
+
"lat": self.releases[name].lat,
|
|
653
|
+
"lon": self.releases[name].lon,
|
|
654
|
+
"color": colors.get(name, "k"),
|
|
655
|
+
}
|
|
656
|
+
for name in release_names
|
|
657
|
+
},
|
|
652
658
|
ax=ax,
|
|
653
|
-
colors=colors,
|
|
654
659
|
)
|
|
655
660
|
|
|
656
661
|
def plot_distribution(self, release_name: str, mark_release_center: bool = True):
|
|
@@ -680,8 +685,13 @@ class CDRForcing(BaseModel):
|
|
|
680
685
|
"A grid must be provided for plotting. Please pass a valid `Grid` object."
|
|
681
686
|
)
|
|
682
687
|
|
|
683
|
-
|
|
684
|
-
|
|
688
|
+
if not isinstance(release_name, str):
|
|
689
|
+
raise ValueError(
|
|
690
|
+
f"Only a single release name (string) is allowed. Got: {release_name!r}"
|
|
691
|
+
)
|
|
692
|
+
|
|
693
|
+
release_name = _validate_release_names([release_name], self.releases)[0]
|
|
694
|
+
|
|
685
695
|
release = self.releases[release_name]
|
|
686
696
|
|
|
687
697
|
# Prepare grid coordinates
|
|
@@ -713,8 +723,16 @@ class CDRForcing(BaseModel):
|
|
|
713
723
|
title="Depth-integrated distribution",
|
|
714
724
|
)
|
|
715
725
|
if mark_release_center:
|
|
716
|
-
|
|
717
|
-
|
|
726
|
+
plot_location(
|
|
727
|
+
grid_ds=self.grid.ds,
|
|
728
|
+
points={
|
|
729
|
+
release.name: {
|
|
730
|
+
"lat": release.lat,
|
|
731
|
+
"lon": release.lon,
|
|
732
|
+
}
|
|
733
|
+
},
|
|
734
|
+
ax=ax0,
|
|
735
|
+
include_legend=False,
|
|
718
736
|
)
|
|
719
737
|
|
|
720
738
|
# Spread horizontal Gaussian field into the vertical
|
|
@@ -828,106 +846,39 @@ class CDRForcing(BaseModel):
|
|
|
828
846
|
return cls(grid=grid, **params)
|
|
829
847
|
|
|
830
848
|
|
|
831
|
-
def
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
This method ensures that the `releases` parameter is either a single release name (string) or a list
|
|
836
|
-
of release names (strings), and checks that each release exists in the set of valid releases.
|
|
837
|
-
|
|
838
|
-
Parameters
|
|
839
|
-
----------
|
|
840
|
-
releases : str or list of str
|
|
841
|
-
A single release name as a string, or a list of release names (strings) to validate.
|
|
842
|
-
|
|
843
|
-
list_allowed : bool, optional
|
|
844
|
-
If `True`, a list of release names is allowed. If `False`, only a single release name (string)
|
|
845
|
-
is allowed. Default is `True`.
|
|
846
|
-
|
|
847
|
-
Raises
|
|
848
|
-
------
|
|
849
|
-
ValueError
|
|
850
|
-
If `releases` is not a string or list of strings, or if any release name is invalid (not in `self.releases`).
|
|
851
|
-
|
|
852
|
-
Notes
|
|
853
|
-
-----
|
|
854
|
-
This method checks that the `releases` input is in a valid format (either a string or a list of strings),
|
|
855
|
-
and ensures each release is present in the set of valid releases defined in `self.releases`. Invalid releases
|
|
856
|
-
are reported in the error message.
|
|
857
|
-
|
|
858
|
-
If `list_allowed` is set to `False`, only a single release name (string) will be accepted. Otherwise, a
|
|
859
|
-
list of release names is also acceptable.
|
|
849
|
+
def _validate_release_names(
|
|
850
|
+
release_names: list[str] | str, releases: ReleaseCollector
|
|
851
|
+
) -> list[str]:
|
|
860
852
|
"""
|
|
861
|
-
|
|
862
|
-
if not list_allowed and not isinstance(releases, str):
|
|
863
|
-
raise ValueError(
|
|
864
|
-
f"Only a single release name (string) is allowed. Got: {releases}"
|
|
865
|
-
)
|
|
866
|
-
|
|
867
|
-
if isinstance(releases, str):
|
|
868
|
-
releases = [releases] # Convert to list if a single string is provided
|
|
869
|
-
elif isinstance(releases, list):
|
|
870
|
-
if not all(isinstance(r, str) for r in releases):
|
|
871
|
-
raise ValueError("All elements in `releases` list must be strings.")
|
|
872
|
-
else:
|
|
873
|
-
raise ValueError(
|
|
874
|
-
"`releases` should be a string (single release name) or a list of strings (release names)."
|
|
875
|
-
)
|
|
876
|
-
|
|
877
|
-
# Validate that the specified releases exist in self.releases
|
|
878
|
-
invalid_releases = [
|
|
879
|
-
release for release in releases if release not in valid_releases
|
|
880
|
-
]
|
|
881
|
-
if invalid_releases:
|
|
882
|
-
raise ValueError(f"Invalid releases: {', '.join(invalid_releases)}")
|
|
853
|
+
Validate and filter a list of release names.
|
|
883
854
|
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
"""Returns a dictionary of colors for the valid releases, based on a consistent
|
|
887
|
-
colormap.
|
|
855
|
+
Ensures that each release name exists in `releases` and limits the list
|
|
856
|
+
to `MAX_RELEASES_TO_PLOT` entries with a warning if truncated.
|
|
888
857
|
|
|
889
858
|
Parameters
|
|
890
859
|
----------
|
|
891
|
-
|
|
892
|
-
|
|
860
|
+
release_names : list of str or INCLUDE_ALL_RELEASE_NAMES
|
|
861
|
+
Names of releases to plot, or sentinel to include all.
|
|
862
|
+
releases : ReleaseCollector
|
|
863
|
+
Object containing valid release names.
|
|
893
864
|
|
|
894
865
|
Returns
|
|
895
866
|
-------
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
assigned based on the order of releases in the valid releases list.
|
|
867
|
+
list of str
|
|
868
|
+
Validated and truncated list of release names.
|
|
899
869
|
|
|
900
870
|
Raises
|
|
901
871
|
------
|
|
902
872
|
ValueError
|
|
903
|
-
If
|
|
904
|
-
|
|
905
|
-
Notes
|
|
906
|
-
-----
|
|
907
|
-
The colormap is chosen dynamically based on the number of valid releases:
|
|
908
|
-
|
|
909
|
-
- If there are 10 or fewer releases, the "tab10" colormap is used.
|
|
910
|
-
- If there are more than 10 but fewer than or equal to 20 releases, the "tab20" colormap is used.
|
|
911
|
-
- For more than 20 releases, the "tab20b" colormap is used.
|
|
873
|
+
If any names are invalid.
|
|
912
874
|
"""
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
# Ensure the number of releases doesn't exceed the available colormap capacity
|
|
922
|
-
if len(valid_releases) > color_map.N:
|
|
923
|
-
raise ValueError(
|
|
924
|
-
f"Too many releases. The selected colormap supports up to {color_map.N} releases."
|
|
925
|
-
)
|
|
926
|
-
|
|
927
|
-
# Create a dictionary of colors based on the release indices
|
|
928
|
-
colors = {name: color_map(i) for i, name in enumerate(valid_releases)}
|
|
929
|
-
|
|
930
|
-
return colors
|
|
875
|
+
return validate_names(
|
|
876
|
+
release_names,
|
|
877
|
+
[r.name for r in releases],
|
|
878
|
+
INCLUDE_ALL_RELEASE_NAMES,
|
|
879
|
+
MAX_RELEASES_TO_PLOT,
|
|
880
|
+
label="release",
|
|
881
|
+
)
|
|
931
882
|
|
|
932
883
|
|
|
933
884
|
def _validate_release_location(grid, release: Release):
|
|
@@ -1106,73 +1057,3 @@ def _map_3d_gaussian(
|
|
|
1106
1057
|
distribution_3d /= distribution_3d.sum()
|
|
1107
1058
|
|
|
1108
1059
|
return distribution_3d
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
def _plot_location(
|
|
1112
|
-
grid: Grid,
|
|
1113
|
-
releases: ReleaseCollector,
|
|
1114
|
-
ax: Axes,
|
|
1115
|
-
colors: dict[str, tuple] | None = None,
|
|
1116
|
-
include_legend: bool = True,
|
|
1117
|
-
) -> None:
|
|
1118
|
-
"""Plot the center location of each release on a top-down map view.
|
|
1119
|
-
|
|
1120
|
-
Each release is represented as a point on the map, with its color
|
|
1121
|
-
determined by the `colors` dictionary.
|
|
1122
|
-
|
|
1123
|
-
Parameters
|
|
1124
|
-
----------
|
|
1125
|
-
grid : Grid
|
|
1126
|
-
The grid object defining the spatial extent and coordinate system for the plot.
|
|
1127
|
-
|
|
1128
|
-
releases : ReleaseCollector
|
|
1129
|
-
Collection of `Release` objects to plot. Each `Release` must have `.lat`, `.lon`,
|
|
1130
|
-
and `.name` attributes.
|
|
1131
|
-
|
|
1132
|
-
ax : matplotlib.axes.Axes
|
|
1133
|
-
The Matplotlib axis object to plot on.
|
|
1134
|
-
|
|
1135
|
-
colors : dict of str to tuple, optional
|
|
1136
|
-
Optional dictionary mapping release names to RGBA color tuples. If not provided,
|
|
1137
|
-
all releases are plotted in a default color (`"#dd1c77"`).
|
|
1138
|
-
|
|
1139
|
-
include_legend : bool, default True
|
|
1140
|
-
Whether to include a legend showing release names.
|
|
1141
|
-
|
|
1142
|
-
Returns
|
|
1143
|
-
-------
|
|
1144
|
-
None
|
|
1145
|
-
"""
|
|
1146
|
-
lon_deg = grid.ds.lon_rho
|
|
1147
|
-
lat_deg = grid.ds.lat_rho
|
|
1148
|
-
if grid.straddle:
|
|
1149
|
-
lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
|
|
1150
|
-
trans = get_projection(lon_deg, lat_deg)
|
|
1151
|
-
|
|
1152
|
-
proj = ccrs.PlateCarree()
|
|
1153
|
-
|
|
1154
|
-
for release in releases:
|
|
1155
|
-
# transform coordinates to projected space
|
|
1156
|
-
transformed_lon, transformed_lat = trans.transform_point(
|
|
1157
|
-
release.lon,
|
|
1158
|
-
release.lat,
|
|
1159
|
-
proj,
|
|
1160
|
-
)
|
|
1161
|
-
|
|
1162
|
-
if colors is not None:
|
|
1163
|
-
color = colors[release.name]
|
|
1164
|
-
else:
|
|
1165
|
-
color = "k"
|
|
1166
|
-
|
|
1167
|
-
ax.plot(
|
|
1168
|
-
transformed_lon,
|
|
1169
|
-
transformed_lat,
|
|
1170
|
-
marker="x",
|
|
1171
|
-
markersize=8,
|
|
1172
|
-
markeredgewidth=2,
|
|
1173
|
-
label=release.name,
|
|
1174
|
-
color=color,
|
|
1175
|
-
)
|
|
1176
|
-
|
|
1177
|
-
if include_legend:
|
|
1178
|
-
ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
|
|
@@ -4,14 +4,18 @@ from dataclasses import dataclass, field
|
|
|
4
4
|
from datetime import datetime
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
|
-
import cartopy.crs as ccrs
|
|
8
|
-
import matplotlib.cm as cm
|
|
9
7
|
import matplotlib.pyplot as plt
|
|
10
8
|
import numpy as np
|
|
11
9
|
import xarray as xr
|
|
12
10
|
|
|
13
11
|
from roms_tools import Grid
|
|
14
|
-
from roms_tools.
|
|
12
|
+
from roms_tools.constants import MAX_DISTINCT_COLORS
|
|
13
|
+
from roms_tools.plot import (
|
|
14
|
+
assign_category_colors,
|
|
15
|
+
get_projection,
|
|
16
|
+
plot_2d_horizontal_field,
|
|
17
|
+
plot_location,
|
|
18
|
+
)
|
|
15
19
|
from roms_tools.setup.datasets import (
|
|
16
20
|
DaiRiverDataset,
|
|
17
21
|
get_indices_of_nearest_grid_cell_for_rivers,
|
|
@@ -26,10 +30,14 @@ from roms_tools.setup.utils import (
|
|
|
26
30
|
get_variable_metadata,
|
|
27
31
|
substitute_nans_by_fillvalue,
|
|
28
32
|
to_dict,
|
|
33
|
+
validate_names,
|
|
29
34
|
write_to_yaml,
|
|
30
35
|
)
|
|
31
36
|
from roms_tools.utils import save_datasets
|
|
32
37
|
|
|
38
|
+
INCLUDE_ALL_RIVER_NAMES = "all"
|
|
39
|
+
MAX_RIVERS_TO_PLOT = 20 # must be <= MAX_DISTINCT_COLORS
|
|
40
|
+
|
|
33
41
|
|
|
34
42
|
@dataclass(kw_only=True)
|
|
35
43
|
class RiverForcing:
|
|
@@ -672,8 +680,24 @@ class RiverForcing:
|
|
|
672
680
|
"`convert_to_climatology = 'if_any_missing'` to automatically fill missing values with climatological data."
|
|
673
681
|
)
|
|
674
682
|
|
|
675
|
-
def plot_locations(self):
|
|
676
|
-
"""Plots the original and updated river locations on a map projection.
|
|
683
|
+
def plot_locations(self, river_names: list[str] | str = INCLUDE_ALL_RIVER_NAMES):
|
|
684
|
+
"""Plots the original and updated river locations on a map projection.
|
|
685
|
+
|
|
686
|
+
Parameters
|
|
687
|
+
----------
|
|
688
|
+
river_names : list[str], or str, optional
|
|
689
|
+
A list of river names to plot.
|
|
690
|
+
If a string equal to "all", all rivers will be plotted.
|
|
691
|
+
Defaults to "all".
|
|
692
|
+
|
|
693
|
+
"""
|
|
694
|
+
valid_river_names = list(self.indices.keys())
|
|
695
|
+
river_names = _validate_river_names(river_names, valid_river_names)
|
|
696
|
+
if len(valid_river_names) > MAX_DISTINCT_COLORS:
|
|
697
|
+
colors = assign_category_colors(river_names)
|
|
698
|
+
else:
|
|
699
|
+
colors = assign_category_colors(valid_river_names)
|
|
700
|
+
|
|
677
701
|
field = self.grid.ds.mask_rho
|
|
678
702
|
lon_deg = self.grid.ds.lon_rho
|
|
679
703
|
lat_deg = self.grid.ds.lat_rho
|
|
@@ -695,53 +719,37 @@ class RiverForcing:
|
|
|
695
719
|
for ax in axs:
|
|
696
720
|
plot_2d_horizontal_field(field, kwargs=kwargs, ax=ax, add_colorbar=False)
|
|
697
721
|
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
proj,
|
|
721
|
-
)
|
|
722
|
-
|
|
723
|
-
if name not in added_labels:
|
|
724
|
-
added_labels.add(name)
|
|
725
|
-
label = name
|
|
726
|
-
else:
|
|
727
|
-
label = "_None"
|
|
728
|
-
|
|
729
|
-
ax.plot(
|
|
730
|
-
transformed_lon,
|
|
731
|
-
transformed_lat,
|
|
732
|
-
marker="x",
|
|
733
|
-
markersize=8,
|
|
734
|
-
markeredgewidth=2,
|
|
735
|
-
label=label,
|
|
736
|
-
color=colors[name],
|
|
737
|
-
)
|
|
722
|
+
points = {}
|
|
723
|
+
for j, (ax, indices) in enumerate(
|
|
724
|
+
[(ax, ind) for ax, ind in zip(axs, [self.original_indices, self.indices])]
|
|
725
|
+
):
|
|
726
|
+
for name in river_names:
|
|
727
|
+
if name in indices:
|
|
728
|
+
for i, (eta_index, xi_index) in enumerate(indices[name]):
|
|
729
|
+
lon = self.grid.ds.lon_rho[eta_index, xi_index].item()
|
|
730
|
+
lat = self.grid.ds.lat_rho[eta_index, xi_index].item()
|
|
731
|
+
key = name if i == 0 else f"_{name}_{i}"
|
|
732
|
+
points[key] = {
|
|
733
|
+
"lon": lon,
|
|
734
|
+
"lat": lat,
|
|
735
|
+
"color": colors[name],
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
plot_location(
|
|
739
|
+
grid_ds=self.grid.ds,
|
|
740
|
+
points=points,
|
|
741
|
+
ax=ax,
|
|
742
|
+
include_legend=(j == 1),
|
|
743
|
+
)
|
|
738
744
|
|
|
739
745
|
axs[0].set_title("Original river locations")
|
|
740
746
|
axs[1].set_title("Updated river locations")
|
|
741
747
|
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
748
|
+
def plot(
|
|
749
|
+
self,
|
|
750
|
+
var_name: str = "river_volume",
|
|
751
|
+
river_names: list[str] | str = INCLUDE_ALL_RIVER_NAMES,
|
|
752
|
+
):
|
|
745
753
|
"""Plots the river flux (e.g., volume, temperature, or salinity) over time for
|
|
746
754
|
all rivers.
|
|
747
755
|
|
|
@@ -791,8 +799,19 @@ class RiverForcing:
|
|
|
791
799
|
- 'river_diazFe' : river diazFe (from river_tracer).
|
|
792
800
|
|
|
793
801
|
The default is 'river_volume'.
|
|
802
|
+
|
|
803
|
+
river_names : list[str], or str, optional
|
|
804
|
+
A list of river names to plot.
|
|
805
|
+
If a string equal to "all", all rivers will be plotted.
|
|
806
|
+
Defaults to "all".
|
|
807
|
+
|
|
794
808
|
"""
|
|
795
|
-
|
|
809
|
+
valid_river_names = list(self.indices.keys())
|
|
810
|
+
river_names = _validate_river_names(river_names, valid_river_names)
|
|
811
|
+
if len(valid_river_names) > MAX_DISTINCT_COLORS:
|
|
812
|
+
colors = assign_category_colors(river_names)
|
|
813
|
+
else:
|
|
814
|
+
colors = assign_category_colors(valid_river_names)
|
|
796
815
|
|
|
797
816
|
if self.climatology:
|
|
798
817
|
xticks = self.ds.month.values
|
|
@@ -814,15 +833,19 @@ class RiverForcing:
|
|
|
814
833
|
units = d[var_name_wo_river]["units"]
|
|
815
834
|
long_name = f"River {d[var_name_wo_river]['long_name']}"
|
|
816
835
|
|
|
817
|
-
|
|
836
|
+
fig, ax = plt.subplots(1, 1, figsize=(9, 5))
|
|
837
|
+
for name in river_names:
|
|
838
|
+
nriver = np.where(self.ds["river_name"].values == name)[0].item()
|
|
839
|
+
|
|
818
840
|
ax.plot(
|
|
819
841
|
xticks,
|
|
820
|
-
field.isel(nriver=
|
|
842
|
+
field.isel(nriver=nriver),
|
|
821
843
|
marker="x",
|
|
822
844
|
markersize=8,
|
|
823
845
|
markeredgewidth=2,
|
|
824
846
|
lw=2,
|
|
825
|
-
label=
|
|
847
|
+
label=name,
|
|
848
|
+
color=colors[name],
|
|
826
849
|
)
|
|
827
850
|
|
|
828
851
|
ax.set_xticks(xticks)
|
|
@@ -965,3 +988,38 @@ def check_river_locations_are_along_coast(mask, indices):
|
|
|
965
988
|
raise ValueError(
|
|
966
989
|
f"River `{key}` is not located on the coast at grid cell ({eta_rho}, {xi_rho})."
|
|
967
990
|
)
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
def _validate_river_names(
    river_names: list[str] | str, valid_river_names: list[str]
) -> list[str]:
    """
    Check the requested river names and cap how many of them are plotted.

    Thin wrapper around `validate_names` that supplies the river-specific
    sentinel (`INCLUDE_ALL_RIVER_NAMES`), the plotting limit
    (`MAX_RIVERS_TO_PLOT`) and the "river" label used in messages.

    Parameters
    ----------
    river_names : list of str or INCLUDE_ALL_RIVER_NAMES
        Rivers requested for plotting, or the sentinel meaning "every river".
    valid_river_names : list of str
        Names that are accepted.

    Returns
    -------
    list of str
        The accepted names, cut down to at most `MAX_RIVERS_TO_PLOT` entries.

    Raises
    ------
    ValueError
        If any names are invalid.
    """
    return validate_names(
        names=river_names,
        valid_names=valid_river_names,
        include_all_sentinel=INCLUDE_ALL_RIVER_NAMES,
        max_to_plot=MAX_RIVERS_TO_PLOT,
        label="river",
    )
|
roms_tools/setup/utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import importlib.metadata
|
|
2
|
+
import logging
|
|
2
3
|
from collections.abc import Sequence
|
|
3
4
|
from dataclasses import asdict, fields, is_dataclass
|
|
4
5
|
from datetime import datetime
|
|
@@ -1794,3 +1795,59 @@ def to_float(val):
|
|
|
1794
1795
|
if isinstance(val, list):
|
|
1795
1796
|
return [float(v) for v in val]
|
|
1796
1797
|
return float(val)
|
|
1798
|
+
|
|
1799
|
+
|
|
1800
|
+
def validate_names(
|
|
1801
|
+
names: list[str] | str,
|
|
1802
|
+
valid_names: list[str],
|
|
1803
|
+
include_all_sentinel: str,
|
|
1804
|
+
max_to_plot: int,
|
|
1805
|
+
label: str = "item",
|
|
1806
|
+
) -> list[str]:
|
|
1807
|
+
"""
|
|
1808
|
+
Generic validation and filtering for a list of names.
|
|
1809
|
+
|
|
1810
|
+
Parameters
|
|
1811
|
+
----------
|
|
1812
|
+
names : list of str or sentinel
|
|
1813
|
+
Names to validate, or sentinel value to include all valid names.
|
|
1814
|
+
valid_names : list of str
|
|
1815
|
+
List of valid names to check against.
|
|
1816
|
+
include_all_sentinel : str
|
|
1817
|
+
Sentinel value to indicate all names should be included.
|
|
1818
|
+
max_to_plot : int
|
|
1819
|
+
Maximum number of names to return.
|
|
1820
|
+
label : str, default "item"
|
|
1821
|
+
Label to use in error and warning messages.
|
|
1822
|
+
|
|
1823
|
+
Returns
|
|
1824
|
+
-------
|
|
1825
|
+
list of str
|
|
1826
|
+
Validated and possibly truncated list of names.
|
|
1827
|
+
|
|
1828
|
+
Raises
|
|
1829
|
+
------
|
|
1830
|
+
ValueError
|
|
1831
|
+
If any names are invalid or input is not a list of strings.
|
|
1832
|
+
"""
|
|
1833
|
+
if names == include_all_sentinel:
|
|
1834
|
+
names = valid_names
|
|
1835
|
+
|
|
1836
|
+
if isinstance(names, list):
|
|
1837
|
+
if not all(isinstance(n, str) for n in names):
|
|
1838
|
+
raise ValueError(f"All elements in `{label}_names` must be strings.")
|
|
1839
|
+
else:
|
|
1840
|
+
raise ValueError(f"`{label}_names` should be a list of strings.")
|
|
1841
|
+
|
|
1842
|
+
invalid = [n for n in names if n not in valid_names]
|
|
1843
|
+
if invalid:
|
|
1844
|
+
raise ValueError(f"Invalid {label}s: {', '.join(invalid)}")
|
|
1845
|
+
|
|
1846
|
+
if len(names) > max_to_plot:
|
|
1847
|
+
logging.warning(
|
|
1848
|
+
f"Only the first {max_to_plot} {label}s will be plotted "
|
|
1849
|
+
f"(received {len(names)})."
|
|
1850
|
+
)
|
|
1851
|
+
names = names[:max_to_plot]
|
|
1852
|
+
|
|
1853
|
+
return names
|
|
@@ -9,7 +9,7 @@ from pydantic import ValidationError
|
|
|
9
9
|
|
|
10
10
|
from conftest import calculate_file_hash
|
|
11
11
|
from roms_tools import CDRForcing, Grid, TracerPerturbation, VolumeRelease
|
|
12
|
-
from roms_tools.constants import NUM_TRACERS
|
|
12
|
+
from roms_tools.constants import MAX_DISTINCT_COLORS, NUM_TRACERS
|
|
13
13
|
from roms_tools.setup.cdr_forcing import (
|
|
14
14
|
CDRForcingDatasetBuilder,
|
|
15
15
|
ReleaseCollector,
|
|
@@ -725,6 +725,8 @@ class TestCDRForcing:
|
|
|
725
725
|
rot=0,
|
|
726
726
|
N=3,
|
|
727
727
|
)
|
|
728
|
+
self.grid = grid
|
|
729
|
+
|
|
728
730
|
grid_that_straddles = Grid(
|
|
729
731
|
nx=18,
|
|
730
732
|
ny=18,
|
|
@@ -817,8 +819,13 @@ class TestCDRForcing:
|
|
|
817
819
|
self.volume_release_cdr_forcing_with_straddling_grid,
|
|
818
820
|
]:
|
|
819
821
|
cdr.plot_volume_flux()
|
|
822
|
+
cdr.plot_volume_flux(release_names=["first_release"])
|
|
823
|
+
|
|
820
824
|
cdr.plot_tracer_concentration("ALK")
|
|
825
|
+
cdr.plot_tracer_concentration("ALK", release_names=["first_release"])
|
|
826
|
+
|
|
821
827
|
cdr.plot_tracer_concentration("DIC")
|
|
828
|
+
cdr.plot_tracer_concentration("DIC", release_names=["first_release"])
|
|
822
829
|
|
|
823
830
|
self.volume_release_cdr_forcing.plot_locations()
|
|
824
831
|
self.volume_release_cdr_forcing.plot_locations(release_names=["first_release"])
|
|
@@ -830,13 +837,56 @@ class TestCDRForcing:
|
|
|
830
837
|
self.tracer_perturbation_cdr_forcing_with_straddling_grid,
|
|
831
838
|
]:
|
|
832
839
|
cdr.plot_tracer_flux("ALK")
|
|
840
|
+
cdr.plot_tracer_flux("ALK", release_names=["first_release"])
|
|
841
|
+
|
|
833
842
|
cdr.plot_tracer_flux("DIC")
|
|
843
|
+
cdr.plot_tracer_flux("DIC", release_names=["first_release"])
|
|
834
844
|
|
|
835
845
|
self.tracer_perturbation_cdr_forcing.plot_locations()
|
|
836
846
|
self.tracer_perturbation_cdr_forcing.plot_locations(
|
|
837
847
|
release_names=["first_release"]
|
|
838
848
|
)
|
|
839
849
|
|
|
850
|
+
def test_plot_max_releases(self, caplog):
    """Check that plotting too many releases logs a truncation warning.

    A CDRForcing object is built with one release more than
    ``MAX_DISTINCT_COLORS``; every plotting method that accepts
    ``release_names`` must then warn that only the first
    ``MAX_DISTINCT_COLORS`` releases are plotted.
    """
    # One release more than the plotting limit, each with a unique name.
    releases = [
        self.first_volume_release.__replace__(name=f"release_{i}")
        for i in range(MAX_DISTINCT_COLORS + 1)
    ]

    cdr_forcing = CDRForcing(
        grid=self.grid,
        start_time=self.start_time,
        end_time=self.end_time,
        releases=releases,
    )

    release_names = [r.name for r in releases]
    expected = f"Only the first {MAX_DISTINCT_COLORS} releases will be plotted"

    plot_methods_with_release_names = (
        cdr_forcing.plot_locations,
        cdr_forcing.plot_volume_flux,
    )

    for plot_func in plot_methods_with_release_names:
        # Start from an empty log so a warning emitted by a previous method
        # cannot satisfy the assertion for this one.
        caplog.clear()
        with caplog.at_level("WARNING"):
            plot_func(release_names=release_names)
        assert any(
            expected in message for message in caplog.messages
        ), f"Warning not raised by {plot_func.__name__}"
|
|
889
|
+
|
|
840
890
|
@pytest.mark.skipif(xesmf is None, reason="xesmf required")
|
|
841
891
|
def test_plot_distribution(self):
|
|
842
892
|
self.volume_release_cdr_forcing.plot_distribution("first_release")
|
|
@@ -856,10 +906,10 @@ class TestCDRForcing:
|
|
|
856
906
|
with pytest.raises(ValueError, match="Invalid releases"):
|
|
857
907
|
self.volume_release_cdr_forcing.plot_locations(release_names=["fake"])
|
|
858
908
|
|
|
859
|
-
with pytest.raises(ValueError, match="should be a
|
|
909
|
+
with pytest.raises(ValueError, match="should be a list"):
|
|
860
910
|
self.volume_release_cdr_forcing.plot_locations(release_names=4)
|
|
861
911
|
|
|
862
|
-
with pytest.raises(ValueError, match="
|
|
912
|
+
with pytest.raises(ValueError, match="must be strings"):
|
|
863
913
|
self.volume_release_cdr_forcing.plot_locations(release_names=[4])
|
|
864
914
|
|
|
865
915
|
def test_cdr_forcing_save(self, tmp_path):
|
|
@@ -9,6 +9,7 @@ import xarray as xr
|
|
|
9
9
|
|
|
10
10
|
from conftest import calculate_file_hash
|
|
11
11
|
from roms_tools import Grid, RiverForcing
|
|
12
|
+
from roms_tools.constants import MAX_DISTINCT_COLORS
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
@pytest.fixture
|
|
@@ -57,6 +58,29 @@ def river_forcing_for_grid_that_straddles_dateline():
|
|
|
57
58
|
)
|
|
58
59
|
|
|
59
60
|
|
|
61
|
+
@pytest.fixture
def river_forcing_for_gulf_of_mexico():
    """Build a RiverForcing object on a coarse Gulf of Mexico grid for January 2012.

    The original note states that this domain contains 45 rivers.
    """
    gulf_grid = Grid(
        nx=20,
        ny=15,
        size_x=2000,
        size_y=1500,
        center_lon=-89,
        center_lat=24,
        rot=0,
        N=3,
    )

    return RiverForcing(
        grid=gulf_grid,
        start_time=datetime(2012, 1, 1),
        end_time=datetime(2012, 1, 31),
    )
|
|
82
|
+
|
|
83
|
+
|
|
60
84
|
@pytest.fixture
|
|
61
85
|
def single_cell_indices():
|
|
62
86
|
# These are the indices that the `river_forcing` fixture generates automatically.
|
|
@@ -247,13 +271,46 @@ class TestRiverForcingGeneral:
|
|
|
247
271
|
)
|
|
248
272
|
|
|
249
273
|
def test_river_forcing_plot(self, river_forcing_with_bgc):
|
|
250
|
-
"""Test plot
|
|
274
|
+
"""Test plot methods with and without specifying river_names."""
|
|
275
|
+
river_names = list(river_forcing_with_bgc.indices.keys())[0:2]
|
|
276
|
+
|
|
277
|
+
# Test plot_locations
|
|
251
278
|
river_forcing_with_bgc.plot_locations()
|
|
252
|
-
river_forcing_with_bgc.
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
279
|
+
river_forcing_with_bgc.plot_locations(river_names=river_names)
|
|
280
|
+
|
|
281
|
+
# Fields to test
|
|
282
|
+
variables = [
|
|
283
|
+
"river_volume",
|
|
284
|
+
"river_temp",
|
|
285
|
+
"river_salt",
|
|
286
|
+
"river_ALK",
|
|
287
|
+
"river_PO4",
|
|
288
|
+
]
|
|
289
|
+
|
|
290
|
+
for var in variables:
|
|
291
|
+
river_forcing_with_bgc.plot(var)
|
|
292
|
+
river_forcing_with_bgc.plot(var, river_names=river_names)
|
|
293
|
+
|
|
294
|
+
def test_plot_max_releases(self, caplog, river_forcing_for_gulf_of_mexico):
    """Check that plotting too many rivers logs a truncation warning.

    Both ``plot_locations`` (default river selection) and ``plot`` (explicit
    ``river_names``) must warn that only the first ``MAX_DISTINCT_COLORS``
    rivers are plotted.
    """
    river_names = list(river_forcing_for_gulf_of_mexico.indices.keys())
    expected = f"Only the first {MAX_DISTINCT_COLORS} rivers will be plotted"

    caplog.clear()
    with caplog.at_level("WARNING"):
        river_forcing_for_gulf_of_mexico.plot_locations()
    assert any(expected in message for message in caplog.messages)

    # Reset the captured records: otherwise the warning already logged by
    # plot_locations() would satisfy the next assertion even if plot()
    # emitted nothing.
    caplog.clear()
    with caplog.at_level("WARNING"):
        river_forcing_for_gulf_of_mexico.plot(
            "river_volume", river_names=river_names
        )
    assert any(expected in message for message in caplog.messages)
|
|
257
314
|
|
|
258
315
|
@pytest.mark.parametrize(
|
|
259
316
|
"river_forcing_fixture",
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from datetime import datetime
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
|
|
@@ -7,9 +8,7 @@ import xarray as xr
|
|
|
7
8
|
from roms_tools import BoundaryForcing, Grid
|
|
8
9
|
from roms_tools.download import download_test_data
|
|
9
10
|
from roms_tools.setup.datasets import ERA5Correction
|
|
10
|
-
from roms_tools.setup.utils import
|
|
11
|
-
interpolate_from_climatology,
|
|
12
|
-
)
|
|
11
|
+
from roms_tools.setup.utils import interpolate_from_climatology, validate_names
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
def test_interpolate_from_climatology(use_dask):
|
|
@@ -71,3 +70,53 @@ def test_roundtrip_yaml(
|
|
|
71
70
|
|
|
72
71
|
filepath = Path(filepath)
|
|
73
72
|
filepath.unlink()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# test validate_names function
|
|
76
|
+
|
|
77
|
+
VALID_NAMES = ["a", "b", "c", "d"]
|
|
78
|
+
SENTINEL = "ALL"
|
|
79
|
+
MAX_TO_PLOT = 3
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def test_valid_names_no_truncation():
    """A valid list shorter than the limit is returned unchanged."""
    requested = ["a", "b"]
    assert (
        validate_names(requested, VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
        == requested
    )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def test_valid_names_with_truncation(caplog):
    """A list longer than ``max_to_plot`` is cut to size and a warning is logged."""
    requested = ["a", "b", "c", "d"]
    limit = 2

    with caplog.at_level(logging.WARNING):
        truncated = validate_names(
            requested, VALID_NAMES, SENTINEL, max_to_plot=limit, label="test"
        )

    assert truncated == ["a", "b"]
    assert "Only the first 2 tests will be plotted" in caplog.text
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def test_include_all_sentinel():
    """Passing the sentinel selects all valid names, capped at the plotting limit."""
    expected = VALID_NAMES[:MAX_TO_PLOT]
    selected = validate_names(
        SENTINEL, VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test"
    )
    assert selected == expected
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_invalid_name_raises():
    """A name missing from the valid list triggers a ValueError naming it."""
    requested = ["a", "z"]
    with pytest.raises(ValueError, match="Invalid tests: z"):
        validate_names(requested, VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def test_non_list_input_raises():
    """A bare string that is not the sentinel is rejected as a non-list input."""
    expected_message = "`test_names` should be a list of strings."
    with pytest.raises(ValueError, match=expected_message):
        validate_names("a", VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_non_string_elements_in_list_raises():
    """A list containing a non-string element is rejected."""
    expected_message = "All elements in `test_names` must be strings."
    mixed = ["a", 2]
    with pytest.raises(ValueError, match=expected_message):
        validate_names(mixed, VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def test_custom_label_in_errors():
    """The ``label`` argument is used verbatim (pluralized with "s") in the error text."""
    unknown = ["z"]
    with pytest.raises(ValueError, match="Invalid foozs: z"):
        validate_names(unknown, VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="fooz")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: roms-tools
|
|
3
|
-
Version: 3.1.
|
|
3
|
+
Version: 3.1.1
|
|
4
4
|
Summary: Tools for running and analysing UCLA-ROMS simulations
|
|
5
5
|
Author-email: Nora Loose <nora.loose@gmail.com>, Thomas Nicholas <tom@cworthy.org>, Scott Eilerman <scott.eilerman@cworthy.org>
|
|
6
6
|
License: Apache-2
|
|
@@ -36,9 +36,11 @@ Requires-Dist: numba>=0.61.2
|
|
|
36
36
|
Requires-Dist: pydantic<3,>2
|
|
37
37
|
Provides-Extra: dask
|
|
38
38
|
Requires-Dist: dask[diagnostics]; extra == "dask"
|
|
39
|
+
Requires-Dist: zarr; extra == "dask"
|
|
39
40
|
Provides-Extra: stream
|
|
40
41
|
Requires-Dist: dask[diagnostics]; extra == "stream"
|
|
41
42
|
Requires-Dist: gcsfs; extra == "stream"
|
|
43
|
+
Requires-Dist: zarr; extra == "stream"
|
|
42
44
|
Dynamic: license-file
|
|
43
45
|
|
|
44
46
|
# ROMS-Tools
|
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
ci/environment-with-xesmf.yml,sha256=1QF0gdRsjisydUNCCZTrsrybh3cCuHrwnLAT0Z1bqmk,234
|
|
2
2
|
ci/environment.yml,sha256=jAi1xo_ZoFNrWevxDRkiKIjMGm1FxzPKocTZwqToT9Y,224
|
|
3
3
|
roms_tools/__init__.py,sha256=XXDoj86gV6gP_sFeKCW0Y6amL8wfDz-iC98VFQGSkfs,1164
|
|
4
|
-
roms_tools/constants.py,sha256=
|
|
4
|
+
roms_tools/constants.py,sha256=VxhoT2dE_Urqgp7bBdxIfdFDgR7WR83hY2IE7oqIYDE,371
|
|
5
5
|
roms_tools/download.py,sha256=Yc7bi1vb0VM-099MQoT-JcPAGwhsQ4QeB0K7CzyRQMo,8372
|
|
6
|
-
roms_tools/plot.py,sha256=
|
|
6
|
+
roms_tools/plot.py,sha256=R7qjRJ6throsTtFTbNbUopuONUh93ln3kFa7mFCUUS0,39234
|
|
7
7
|
roms_tools/regrid.py,sha256=LQhhM5JpjzpgIGSPsj7mr7b_TDUgdyGf5XuQ_GQY9tg,10471
|
|
8
8
|
roms_tools/utils.py,sha256=kAVGVCeTXzi-5euhjbSMbK30aFigvLBfqDSvbQEX9Ls,25321
|
|
9
9
|
roms_tools/vertical_coordinate.py,sha256=081LONzUSe8tL9H6XniAmY1tIyehc87rMnRZ-kYQ0FI,7417
|
|
10
10
|
roms_tools/analysis/roms_output.py,sha256=L2yqhbgHgHkPsross4LgINNY0f__g2ffUMh_FY-5dlE,25223
|
|
11
11
|
roms_tools/setup/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
12
|
roms_tools/setup/boundary_forcing.py,sha256=tc8NYOEaeEWAX5zHO2xQoRrtFP03MhmVnkm-Zc2QjOg,45043
|
|
13
|
-
roms_tools/setup/cdr_forcing.py,sha256=
|
|
13
|
+
roms_tools/setup/cdr_forcing.py,sha256=DpLzeclPNvCP9Ad4f8Q2vQmdOEuHLKx7xWTEpOfjxwY,36610
|
|
14
14
|
roms_tools/setup/cdr_release.py,sha256=TEN_DFLCSJ72UBnZJ5X8Vp8crXtzZ-GjtaymsHG51so,19481
|
|
15
15
|
roms_tools/setup/datasets.py,sha256=5ffqSh0dAeaelny5ucVoYk6vkVisfFVlDocHk9baRWI,109255
|
|
16
16
|
roms_tools/setup/fill.py,sha256=eM5bFqwHcKIQCGBTPi7XOhJiSoCPYsjShbr6w10lIMs,11117
|
|
@@ -18,11 +18,11 @@ roms_tools/setup/grid.py,sha256=vtc7sZHEBScxB7GLGmj4amtZXT6jOfddz5Z_TNVsMRU,5447
|
|
|
18
18
|
roms_tools/setup/initial_conditions.py,sha256=dXuKwiPAhDDS6vJqADxJFaoYfnsY5d_8CcbOE3tMB7g,32245
|
|
19
19
|
roms_tools/setup/mask.py,sha256=MaVfTEc0YhVzuZMLFwuQ-uRKJeQT2bMl3QvVz6dq1P0,3414
|
|
20
20
|
roms_tools/setup/nesting.py,sha256=-tLnp9s_hEI7SM60xJ-fK1FKJ2PSCmHtZgD00Z_MwCo,27012
|
|
21
|
-
roms_tools/setup/river_forcing.py,sha256=
|
|
21
|
+
roms_tools/setup/river_forcing.py,sha256=MJKRZRawSK4YNzV9lBCDoEyKT4YhNRYkgWWjtA5SaCc,41900
|
|
22
22
|
roms_tools/setup/surface_forcing.py,sha256=q1UBpOfER15SoDOd1BweN-Lx4vivwtu7Z43p6ntt1NQ,29514
|
|
23
23
|
roms_tools/setup/tides.py,sha256=ofnDS5MqKI_mqV-dCxKvtw7LDXBaJEq27Zyh2uAzJ04,16287
|
|
24
24
|
roms_tools/setup/topography.py,sha256=W17vUZK1t3QE_w43r4ucKVGHHE6d6lYwECdQ_FCH2OE,14498
|
|
25
|
-
roms_tools/setup/utils.py,sha256=
|
|
25
|
+
roms_tools/setup/utils.py,sha256=qCak3jCNSnsBlLGQ3lUl94LhYZLm_0PJd_63_YeM28w,62397
|
|
26
26
|
roms_tools/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
27
27
|
roms_tools/tests/test_regrid.py,sha256=EUeFXcsIUZ3Z1usH5hPEQdkIX4b2XSiBM5zPGG3HXFo,4686
|
|
28
28
|
roms_tools/tests/test_utils.py,sha256=uDddcU-MrCV-7EBYYlThbT_WN3OnpmO-N_cSx6n0GIY,697
|
|
@@ -30,17 +30,17 @@ roms_tools/tests/test_vertical_coordinate.py,sha256=_L4FGDJGhnDbMhV7g3fc3SGoRt_1
|
|
|
30
30
|
roms_tools/tests/test_analysis/test_roms_output.py,sha256=VS2JckETcO_nEDKMpJXkQfcmgABqM8yUYDU1FAUrtEs,21818
|
|
31
31
|
roms_tools/tests/test_setup/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
32
|
roms_tools/tests/test_setup/test_boundary_forcing.py,sha256=DnKKvcSZLCU7r8j7iNIRqqyWPA0uRHOHca02CcmFFD0,25137
|
|
33
|
-
roms_tools/tests/test_setup/test_cdr_forcing.py,sha256=
|
|
33
|
+
roms_tools/tests/test_setup/test_cdr_forcing.py,sha256=E0aX3-U-hroxaHcV8H79OZwmI9ZGRulu-FM3UOrSUOM,36487
|
|
34
34
|
roms_tools/tests/test_setup/test_cdr_release.py,sha256=p2gONUYeZ1lbUegR8NMrDQ9ZsiTdGMru7DhERU8hvZ4,14370
|
|
35
35
|
roms_tools/tests/test_setup/test_datasets.py,sha256=A8pxkU22eouQnESqrbaxyAJpzJYeLbZKdOfUKPyY-n0,21709
|
|
36
36
|
roms_tools/tests/test_setup/test_fill.py,sha256=NvMV-k2J0fRMEKI9D7kjlaERwJ9x-XggJfMd9Vst_7U,3734
|
|
37
37
|
roms_tools/tests/test_setup/test_grid.py,sha256=sjBO6XbrvZ3dvaL31UkiGv-rpqFlH_8-m1xIqtQHGRY,22428
|
|
38
38
|
roms_tools/tests/test_setup/test_initial_conditions.py,sha256=ERjIVcRWW49jeMHdZR3bPbv2AMTeFX1y4UZDOw0Rf_4,19777
|
|
39
39
|
roms_tools/tests/test_setup/test_nesting.py,sha256=75UxhfzfsINBolBzBHJsw6zGebwAG8A8NJUuXTryQpc,18791
|
|
40
|
-
roms_tools/tests/test_setup/test_river_forcing.py,sha256=
|
|
40
|
+
roms_tools/tests/test_setup/test_river_forcing.py,sha256=Egj1BzRwk4_o4651p6v2ENNX1kQolZVXgReAZ0Yo7ic,34620
|
|
41
41
|
roms_tools/tests/test_setup/test_surface_forcing.py,sha256=336-qf-A05SxWIXLIXgEpUhAiLedRHyRXau3uVQKT4Q,31534
|
|
42
42
|
roms_tools/tests/test_setup/test_tides.py,sha256=Rh5pkI9-z_TRtgLSTook5GM0OmynAkm6H6dQoqNWwxE,10537
|
|
43
|
-
roms_tools/tests/test_setup/test_utils.py,sha256=
|
|
43
|
+
roms_tools/tests/test_setup/test_utils.py,sha256=SK_Dn8wZjljAWn23-DYGF0f3hqgWX19xCm8MSE0pKqs,3869
|
|
44
44
|
roms_tools/tests/test_setup/test_validation.py,sha256=aqi3g-c7yhAhBAFSv6aPyxlDLbQxk2u76BQBDUYknl8,3861
|
|
45
45
|
roms_tools/tests/test_setup/test_data/bgc_boundary_forcing_from_climatology.zarr/zarr.json,sha256=gERlygacAzytRigJbNCKS9LonhqTzBL73cRBgSYM2aU,160667
|
|
46
46
|
roms_tools/tests/test_setup/test_data/bgc_boundary_forcing_from_climatology.zarr/ALK_ALT_CO2_east/zarr.json,sha256=KSNZ-MO2SPQq_XFwju_mKbmEiVMFHBMf-uw063Ow4gc,888
|
|
@@ -1101,8 +1101,8 @@ roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/v_Re/zarr.json,sha256=i
|
|
|
1101
1101
|
roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/v_Re/c/0/0/0,sha256=auJ0X3zKvmJprw4ucmJHfspezi6QWHLZzSe9zHEbK0c,89
|
|
1102
1102
|
roms_tools/tests/test_tiling/test_partition.py,sha256=zGxXd0LyihKk-puCUq0_rBmJMPE92r1gKleoPXRAV9g,8569
|
|
1103
1103
|
roms_tools/tiling/partition.py,sha256=evpZ7EUTm4jOkBYo3Ub1OywmQ7KHfJEHNMmK4tRamn8,13154
|
|
1104
|
-
roms_tools-3.1.
|
|
1105
|
-
roms_tools-3.1.
|
|
1106
|
-
roms_tools-3.1.
|
|
1107
|
-
roms_tools-3.1.
|
|
1108
|
-
roms_tools-3.1.
|
|
1104
|
+
roms_tools-3.1.1.dist-info/licenses/LICENSE,sha256=yiff76E4xRioW2bHhlPpyYpstmePQBx2bF8HhgQhSsg,11318
|
|
1105
|
+
roms_tools-3.1.1.dist-info/METADATA,sha256=oNudt3YdkRr4s9TVXhK-0vfjGXxYok2TSeYMPAtQ0kA,4951
|
|
1106
|
+
roms_tools-3.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
1107
|
+
roms_tools-3.1.1.dist-info/top_level.txt,sha256=aAf4T4nYQSkay5iKJ9kmTjlDgd4ETdp9OSlB4sJdt8Y,19
|
|
1108
|
+
roms_tools-3.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|