roms-tools 3.1.0__py3-none-any.whl → 3.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roms_tools/__init__.py +5 -1
- roms_tools/constants.py +1 -0
- roms_tools/plot.py +164 -9
- roms_tools/regrid.py +6 -1
- roms_tools/setup/boundary_forcing.py +55 -30
- roms_tools/setup/cdr_forcing.py +84 -209
- roms_tools/setup/datasets.py +96 -14
- roms_tools/setup/grid.py +29 -2
- roms_tools/setup/river_forcing.py +110 -52
- roms_tools/setup/surface_forcing.py +12 -4
- roms_tools/setup/utils.py +57 -0
- roms_tools/tests/test_setup/test_boundary_forcing.py +57 -0
- roms_tools/tests/test_setup/test_cdr_forcing.py +53 -3
- roms_tools/tests/test_setup/test_datasets.py +76 -0
- roms_tools/tests/test_setup/test_grid.py +16 -6
- roms_tools/tests/test_setup/test_river_forcing.py +63 -6
- roms_tools/tests/test_setup/test_surface_forcing.py +26 -2
- roms_tools/tests/test_setup/test_utils.py +52 -3
- roms_tools/tests/test_setup/test_validation.py +21 -15
- roms_tools/tests/test_tiling/test_partition.py +45 -0
- roms_tools/tests/test_utils.py +101 -1
- roms_tools/tiling/partition.py +44 -30
- roms_tools/utils.py +426 -131
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/METADATA +6 -3
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/RECORD +28 -28
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/WHEEL +0 -0
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/licenses/LICENSE +0 -0
- {roms_tools-3.1.0.dist-info → roms_tools-3.1.2.dist-info}/top_level.txt +0 -0
roms_tools/setup/cdr_forcing.py
CHANGED
|
@@ -6,12 +6,10 @@ from datetime import datetime
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import Annotated
|
|
8
8
|
|
|
9
|
-
import cartopy.crs as ccrs
|
|
10
9
|
import matplotlib.gridspec as gridspec
|
|
11
10
|
import matplotlib.pyplot as plt
|
|
12
11
|
import numpy as np
|
|
13
12
|
import xarray as xr
|
|
14
|
-
from matplotlib.axes import Axes
|
|
15
13
|
from pydantic import (
|
|
16
14
|
BaseModel,
|
|
17
15
|
Field,
|
|
@@ -22,7 +20,14 @@ from pydantic import (
|
|
|
22
20
|
)
|
|
23
21
|
|
|
24
22
|
from roms_tools import Grid
|
|
25
|
-
from roms_tools.
|
|
23
|
+
from roms_tools.constants import MAX_DISTINCT_COLORS
|
|
24
|
+
from roms_tools.plot import (
|
|
25
|
+
assign_category_colors,
|
|
26
|
+
get_projection,
|
|
27
|
+
plot,
|
|
28
|
+
plot_2d_horizontal_field,
|
|
29
|
+
plot_location,
|
|
30
|
+
)
|
|
26
31
|
from roms_tools.setup.cdr_release import (
|
|
27
32
|
Release,
|
|
28
33
|
ReleaseType,
|
|
@@ -36,6 +41,7 @@ from roms_tools.setup.utils import (
|
|
|
36
41
|
gc_dist,
|
|
37
42
|
get_target_coords,
|
|
38
43
|
to_dict,
|
|
44
|
+
validate_names,
|
|
39
45
|
write_to_yaml,
|
|
40
46
|
)
|
|
41
47
|
from roms_tools.utils import (
|
|
@@ -45,6 +51,7 @@ from roms_tools.utils import (
|
|
|
45
51
|
from roms_tools.vertical_coordinate import compute_depth_coordinates
|
|
46
52
|
|
|
47
53
|
INCLUDE_ALL_RELEASE_NAMES = "all"
|
|
54
|
+
MAX_RELEASES_TO_PLOT = 20 # must be <= MAX_DISTINCT_COLORS
|
|
48
55
|
|
|
49
56
|
|
|
50
57
|
class ReleaseSimulationManager(BaseModel):
|
|
@@ -389,7 +396,10 @@ class CDRForcing(BaseModel):
|
|
|
389
396
|
return self._ds
|
|
390
397
|
|
|
391
398
|
def plot_volume_flux(
|
|
392
|
-
self,
|
|
399
|
+
self,
|
|
400
|
+
start: datetime | None = None,
|
|
401
|
+
end: datetime | None = None,
|
|
402
|
+
release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
|
|
393
403
|
):
|
|
394
404
|
"""Plot the volume flux for each specified release within the given time range.
|
|
395
405
|
|
|
@@ -419,12 +429,7 @@ class CDRForcing(BaseModel):
|
|
|
419
429
|
start = start or self.start_time
|
|
420
430
|
end = end or self.end_time
|
|
421
431
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
if release_names == INCLUDE_ALL_RELEASE_NAMES:
|
|
425
|
-
release_names = valid_release_names
|
|
426
|
-
|
|
427
|
-
_validate_release_input(release_names, valid_release_names)
|
|
432
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
428
433
|
|
|
429
434
|
data = self.ds["cdr_volume"]
|
|
430
435
|
|
|
@@ -440,9 +445,9 @@ class CDRForcing(BaseModel):
|
|
|
440
445
|
def plot_tracer_concentration(
|
|
441
446
|
self,
|
|
442
447
|
tracer_name: str,
|
|
443
|
-
start=None,
|
|
444
|
-
end=None,
|
|
445
|
-
release_names=INCLUDE_ALL_RELEASE_NAMES,
|
|
448
|
+
start: datetime | None = None,
|
|
449
|
+
end: datetime | None = None,
|
|
450
|
+
release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
|
|
446
451
|
):
|
|
447
452
|
"""Plot the concentration of a given tracer for each specified release within
|
|
448
453
|
the given time range.
|
|
@@ -476,12 +481,7 @@ class CDRForcing(BaseModel):
|
|
|
476
481
|
start = start or self.start_time
|
|
477
482
|
end = end or self.end_time
|
|
478
483
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
if release_names == INCLUDE_ALL_RELEASE_NAMES:
|
|
482
|
-
release_names = valid_release_names
|
|
483
|
-
|
|
484
|
-
_validate_release_input(release_names, valid_release_names)
|
|
484
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
485
485
|
|
|
486
486
|
tracer_names = list(self.ds["tracer_name"].values)
|
|
487
487
|
if tracer_name not in tracer_names:
|
|
@@ -511,9 +511,9 @@ class CDRForcing(BaseModel):
|
|
|
511
511
|
def plot_tracer_flux(
|
|
512
512
|
self,
|
|
513
513
|
tracer_name: str,
|
|
514
|
-
start=None,
|
|
515
|
-
end=None,
|
|
516
|
-
release_names=INCLUDE_ALL_RELEASE_NAMES,
|
|
514
|
+
start: datetime | None = None,
|
|
515
|
+
end: datetime | None = None,
|
|
516
|
+
release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES,
|
|
517
517
|
):
|
|
518
518
|
"""Plot the flux of a given tracer for each specified release within the given
|
|
519
519
|
time range.
|
|
@@ -547,12 +547,7 @@ class CDRForcing(BaseModel):
|
|
|
547
547
|
start = start or self.start_time
|
|
548
548
|
end = end or self.end_time
|
|
549
549
|
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
if release_names == INCLUDE_ALL_RELEASE_NAMES:
|
|
553
|
-
release_names = valid_release_names
|
|
554
|
-
|
|
555
|
-
_validate_release_input(release_names, valid_release_names)
|
|
550
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
556
551
|
|
|
557
552
|
tracer_names = list(self.ds["tracer_name"].values)
|
|
558
553
|
if tracer_name not in tracer_names:
|
|
@@ -577,7 +572,10 @@ class CDRForcing(BaseModel):
|
|
|
577
572
|
def _plot_line(self, data, release_names, start, end, title="", ylabel=""):
|
|
578
573
|
"""Plots a line graph for the specified releases and time range."""
|
|
579
574
|
valid_release_names = [r.name for r in self.releases]
|
|
580
|
-
|
|
575
|
+
if len(valid_release_names) > MAX_DISTINCT_COLORS:
|
|
576
|
+
colors = assign_category_colors(release_names)
|
|
577
|
+
else:
|
|
578
|
+
colors = assign_category_colors(valid_release_names)
|
|
581
579
|
|
|
582
580
|
fig, ax = plt.subplots(1, 1, figsize=(7, 4))
|
|
583
581
|
for name in release_names:
|
|
@@ -596,7 +594,9 @@ class CDRForcing(BaseModel):
|
|
|
596
594
|
ax.set(title=title, ylabel=ylabel, xlabel="time")
|
|
597
595
|
ax.set_xlim([start, end])
|
|
598
596
|
|
|
599
|
-
def plot_locations(
|
|
597
|
+
def plot_locations(
|
|
598
|
+
self, release_names: list[str] | str = INCLUDE_ALL_RELEASE_NAMES
|
|
599
|
+
):
|
|
600
600
|
"""Plot centers of release locations in top-down view.
|
|
601
601
|
|
|
602
602
|
Parameters
|
|
@@ -619,12 +619,7 @@ class CDRForcing(BaseModel):
|
|
|
619
619
|
"A grid must be provided for plotting. Please pass a valid `Grid` object."
|
|
620
620
|
)
|
|
621
621
|
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
if release_names == "all":
|
|
625
|
-
release_names = valid_release_names
|
|
626
|
-
|
|
627
|
-
_validate_release_input(release_names, valid_release_names)
|
|
622
|
+
release_names = _validate_release_names(release_names, self.releases)
|
|
628
623
|
|
|
629
624
|
lon_deg = self.grid.ds.lon_rho
|
|
630
625
|
lat_deg = self.grid.ds.lat_rho
|
|
@@ -645,12 +640,22 @@ class CDRForcing(BaseModel):
|
|
|
645
640
|
plot_2d_horizontal_field(field, kwargs=kwargs, ax=ax, add_colorbar=False)
|
|
646
641
|
|
|
647
642
|
# Plot release locations
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
643
|
+
valid_release_names = [r.name for r in self.releases]
|
|
644
|
+
if len(valid_release_names) > MAX_DISTINCT_COLORS:
|
|
645
|
+
colors = assign_category_colors(release_names)
|
|
646
|
+
else:
|
|
647
|
+
colors = assign_category_colors(valid_release_names)
|
|
648
|
+
plot_location(
|
|
649
|
+
grid_ds=self.grid.ds,
|
|
650
|
+
points={
|
|
651
|
+
name: {
|
|
652
|
+
"lat": self.releases[name].lat,
|
|
653
|
+
"lon": self.releases[name].lon,
|
|
654
|
+
"color": colors.get(name, "k"),
|
|
655
|
+
}
|
|
656
|
+
for name in release_names
|
|
657
|
+
},
|
|
652
658
|
ax=ax,
|
|
653
|
-
colors=colors,
|
|
654
659
|
)
|
|
655
660
|
|
|
656
661
|
def plot_distribution(self, release_name: str, mark_release_center: bool = True):
|
|
@@ -680,8 +685,13 @@ class CDRForcing(BaseModel):
|
|
|
680
685
|
"A grid must be provided for plotting. Please pass a valid `Grid` object."
|
|
681
686
|
)
|
|
682
687
|
|
|
683
|
-
|
|
684
|
-
|
|
688
|
+
if not isinstance(release_name, str):
|
|
689
|
+
raise ValueError(
|
|
690
|
+
f"Only a single release name (string) is allowed. Got: {release_name!r}"
|
|
691
|
+
)
|
|
692
|
+
|
|
693
|
+
release_name = _validate_release_names([release_name], self.releases)[0]
|
|
694
|
+
|
|
685
695
|
release = self.releases[release_name]
|
|
686
696
|
|
|
687
697
|
# Prepare grid coordinates
|
|
@@ -713,8 +723,16 @@ class CDRForcing(BaseModel):
|
|
|
713
723
|
title="Depth-integrated distribution",
|
|
714
724
|
)
|
|
715
725
|
if mark_release_center:
|
|
716
|
-
|
|
717
|
-
|
|
726
|
+
plot_location(
|
|
727
|
+
grid_ds=self.grid.ds,
|
|
728
|
+
points={
|
|
729
|
+
release.name: {
|
|
730
|
+
"lat": release.lat,
|
|
731
|
+
"lon": release.lon,
|
|
732
|
+
}
|
|
733
|
+
},
|
|
734
|
+
ax=ax0,
|
|
735
|
+
include_legend=False,
|
|
718
736
|
)
|
|
719
737
|
|
|
720
738
|
# Spread horizontal Gaussian field into the vertical
|
|
@@ -828,106 +846,39 @@ class CDRForcing(BaseModel):
|
|
|
828
846
|
return cls(grid=grid, **params)
|
|
829
847
|
|
|
830
848
|
|
|
831
|
-
def
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
This method ensures that the `releases` parameter is either a single release name (string) or a list
|
|
836
|
-
of release names (strings), and checks that each release exists in the set of valid releases.
|
|
837
|
-
|
|
838
|
-
Parameters
|
|
839
|
-
----------
|
|
840
|
-
releases : str or list of str
|
|
841
|
-
A single release name as a string, or a list of release names (strings) to validate.
|
|
842
|
-
|
|
843
|
-
list_allowed : bool, optional
|
|
844
|
-
If `True`, a list of release names is allowed. If `False`, only a single release name (string)
|
|
845
|
-
is allowed. Default is `True`.
|
|
846
|
-
|
|
847
|
-
Raises
|
|
848
|
-
------
|
|
849
|
-
ValueError
|
|
850
|
-
If `releases` is not a string or list of strings, or if any release name is invalid (not in `self.releases`).
|
|
851
|
-
|
|
852
|
-
Notes
|
|
853
|
-
-----
|
|
854
|
-
This method checks that the `releases` input is in a valid format (either a string or a list of strings),
|
|
855
|
-
and ensures each release is present in the set of valid releases defined in `self.releases`. Invalid releases
|
|
856
|
-
are reported in the error message.
|
|
857
|
-
|
|
858
|
-
If `list_allowed` is set to `False`, only a single release name (string) will be accepted. Otherwise, a
|
|
859
|
-
list of release names is also acceptable.
|
|
849
|
+
def _validate_release_names(
|
|
850
|
+
release_names: list[str] | str, releases: ReleaseCollector
|
|
851
|
+
) -> list[str]:
|
|
860
852
|
"""
|
|
861
|
-
|
|
862
|
-
if not list_allowed and not isinstance(releases, str):
|
|
863
|
-
raise ValueError(
|
|
864
|
-
f"Only a single release name (string) is allowed. Got: {releases}"
|
|
865
|
-
)
|
|
853
|
+
Validate and filter a list of release names.
|
|
866
854
|
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
elif isinstance(releases, list):
|
|
870
|
-
if not all(isinstance(r, str) for r in releases):
|
|
871
|
-
raise ValueError("All elements in `releases` list must be strings.")
|
|
872
|
-
else:
|
|
873
|
-
raise ValueError(
|
|
874
|
-
"`releases` should be a string (single release name) or a list of strings (release names)."
|
|
875
|
-
)
|
|
876
|
-
|
|
877
|
-
# Validate that the specified releases exist in self.releases
|
|
878
|
-
invalid_releases = [
|
|
879
|
-
release for release in releases if release not in valid_releases
|
|
880
|
-
]
|
|
881
|
-
if invalid_releases:
|
|
882
|
-
raise ValueError(f"Invalid releases: {', '.join(invalid_releases)}")
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
def _get_release_colors(valid_releases: list[str]) -> dict[str, tuple]:
|
|
886
|
-
"""Returns a dictionary of colors for the valid releases, based on a consistent
|
|
887
|
-
colormap.
|
|
855
|
+
Ensures that each release name exists in `releases` and limits the list
|
|
856
|
+
to `MAX_RELEASES_TO_PLOT` entries with a warning if truncated.
|
|
888
857
|
|
|
889
858
|
Parameters
|
|
890
859
|
----------
|
|
891
|
-
|
|
892
|
-
|
|
860
|
+
release_names : list of str or INCLUDE_ALL_RELEASE_NAMES
|
|
861
|
+
Names of releases to plot, or sentinel to include all.
|
|
862
|
+
releases : ReleaseCollector
|
|
863
|
+
Object containing valid release names.
|
|
893
864
|
|
|
894
865
|
Returns
|
|
895
866
|
-------
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
assigned based on the order of releases in the valid releases list.
|
|
867
|
+
list of str
|
|
868
|
+
Validated and truncated list of release names.
|
|
899
869
|
|
|
900
870
|
Raises
|
|
901
871
|
------
|
|
902
872
|
ValueError
|
|
903
|
-
If
|
|
904
|
-
|
|
905
|
-
Notes
|
|
906
|
-
-----
|
|
907
|
-
The colormap is chosen dynamically based on the number of valid releases:
|
|
908
|
-
|
|
909
|
-
- If there are 10 or fewer releases, the "tab10" colormap is used.
|
|
910
|
-
- If there are more than 10 but fewer than or equal to 20 releases, the "tab20" colormap is used.
|
|
911
|
-
- For more than 20 releases, the "tab20b" colormap is used.
|
|
873
|
+
If any names are invalid.
|
|
912
874
|
"""
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
# Ensure the number of releases doesn't exceed the available colormap capacity
|
|
922
|
-
if len(valid_releases) > color_map.N:
|
|
923
|
-
raise ValueError(
|
|
924
|
-
f"Too many releases. The selected colormap supports up to {color_map.N} releases."
|
|
925
|
-
)
|
|
926
|
-
|
|
927
|
-
# Create a dictionary of colors based on the release indices
|
|
928
|
-
colors = {name: color_map(i) for i, name in enumerate(valid_releases)}
|
|
929
|
-
|
|
930
|
-
return colors
|
|
875
|
+
return validate_names(
|
|
876
|
+
release_names,
|
|
877
|
+
[r.name for r in releases],
|
|
878
|
+
INCLUDE_ALL_RELEASE_NAMES,
|
|
879
|
+
MAX_RELEASES_TO_PLOT,
|
|
880
|
+
label="release",
|
|
881
|
+
)
|
|
931
882
|
|
|
932
883
|
|
|
933
884
|
def _validate_release_location(grid, release: Release):
|
|
@@ -1088,91 +1039,15 @@ def _map_3d_gaussian(
|
|
|
1088
1039
|
# Stack 2D distribution at that vertical level
|
|
1089
1040
|
distribution_3d[{"s_rho": vertical_idx}] = distribution_2d
|
|
1090
1041
|
else:
|
|
1091
|
-
# Compute layer thickness
|
|
1092
|
-
depth_interface = compute_depth_coordinates(
|
|
1093
|
-
grid.ds, zeta=0, depth_type="interface", location="rho"
|
|
1094
|
-
)
|
|
1095
|
-
dz = depth_interface.diff("s_w").rename({"s_w": "s_rho"})
|
|
1096
|
-
|
|
1097
1042
|
# Compute vertical Gaussian shape
|
|
1098
1043
|
exponent = -(((depth - release.depth) / release.vsc) ** 2)
|
|
1099
1044
|
vertical_profile = np.exp(exponent)
|
|
1100
1045
|
|
|
1101
1046
|
# Apply vertical Gaussian scaling
|
|
1102
|
-
distribution_3d = distribution_2d * vertical_profile
|
|
1047
|
+
distribution_3d = distribution_2d * vertical_profile
|
|
1103
1048
|
|
|
1104
1049
|
# Normalize
|
|
1105
1050
|
distribution_3d /= release.vsc * np.sqrt(np.pi)
|
|
1106
1051
|
distribution_3d /= distribution_3d.sum()
|
|
1107
1052
|
|
|
1108
1053
|
return distribution_3d
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
def _plot_location(
|
|
1112
|
-
grid: Grid,
|
|
1113
|
-
releases: ReleaseCollector,
|
|
1114
|
-
ax: Axes,
|
|
1115
|
-
colors: dict[str, tuple] | None = None,
|
|
1116
|
-
include_legend: bool = True,
|
|
1117
|
-
) -> None:
|
|
1118
|
-
"""Plot the center location of each release on a top-down map view.
|
|
1119
|
-
|
|
1120
|
-
Each release is represented as a point on the map, with its color
|
|
1121
|
-
determined by the `colors` dictionary.
|
|
1122
|
-
|
|
1123
|
-
Parameters
|
|
1124
|
-
----------
|
|
1125
|
-
grid : Grid
|
|
1126
|
-
The grid object defining the spatial extent and coordinate system for the plot.
|
|
1127
|
-
|
|
1128
|
-
releases : ReleaseCollector
|
|
1129
|
-
Collection of `Release` objects to plot. Each `Release` must have `.lat`, `.lon`,
|
|
1130
|
-
and `.name` attributes.
|
|
1131
|
-
|
|
1132
|
-
ax : matplotlib.axes.Axes
|
|
1133
|
-
The Matplotlib axis object to plot on.
|
|
1134
|
-
|
|
1135
|
-
colors : dict of str to tuple, optional
|
|
1136
|
-
Optional dictionary mapping release names to RGBA color tuples. If not provided,
|
|
1137
|
-
all releases are plotted in a default color (`"#dd1c77"`).
|
|
1138
|
-
|
|
1139
|
-
include_legend : bool, default True
|
|
1140
|
-
Whether to include a legend showing release names.
|
|
1141
|
-
|
|
1142
|
-
Returns
|
|
1143
|
-
-------
|
|
1144
|
-
None
|
|
1145
|
-
"""
|
|
1146
|
-
lon_deg = grid.ds.lon_rho
|
|
1147
|
-
lat_deg = grid.ds.lat_rho
|
|
1148
|
-
if grid.straddle:
|
|
1149
|
-
lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
|
|
1150
|
-
trans = get_projection(lon_deg, lat_deg)
|
|
1151
|
-
|
|
1152
|
-
proj = ccrs.PlateCarree()
|
|
1153
|
-
|
|
1154
|
-
for release in releases:
|
|
1155
|
-
# transform coordinates to projected space
|
|
1156
|
-
transformed_lon, transformed_lat = trans.transform_point(
|
|
1157
|
-
release.lon,
|
|
1158
|
-
release.lat,
|
|
1159
|
-
proj,
|
|
1160
|
-
)
|
|
1161
|
-
|
|
1162
|
-
if colors is not None:
|
|
1163
|
-
color = colors[release.name]
|
|
1164
|
-
else:
|
|
1165
|
-
color = "k"
|
|
1166
|
-
|
|
1167
|
-
ax.plot(
|
|
1168
|
-
transformed_lon,
|
|
1169
|
-
transformed_lat,
|
|
1170
|
-
marker="x",
|
|
1171
|
-
markersize=8,
|
|
1172
|
-
markeredgewidth=2,
|
|
1173
|
-
label=release.name,
|
|
1174
|
-
color=color,
|
|
1175
|
-
)
|
|
1176
|
-
|
|
1177
|
-
if include_legend:
|
|
1178
|
-
ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
|
roms_tools/setup/datasets.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
1
|
+
import importlib.util
|
|
1
2
|
import logging
|
|
2
3
|
import time
|
|
3
4
|
from collections import Counter, defaultdict
|
|
5
|
+
from collections.abc import Callable
|
|
4
6
|
from dataclasses import dataclass, field
|
|
5
7
|
from datetime import datetime, timedelta
|
|
6
8
|
from pathlib import Path
|
|
9
|
+
from types import ModuleType
|
|
10
|
+
from typing import ClassVar
|
|
7
11
|
|
|
8
12
|
import numpy as np
|
|
9
13
|
import xarray as xr
|
|
@@ -25,7 +29,7 @@ from roms_tools.setup.utils import (
|
|
|
25
29
|
interpolate_from_climatology,
|
|
26
30
|
one_dim_fill,
|
|
27
31
|
)
|
|
28
|
-
from roms_tools.utils import _has_gcsfs, _load_data
|
|
32
|
+
from roms_tools.utils import _get_pkg_error_msg, _has_gcsfs, _load_data
|
|
29
33
|
|
|
30
34
|
# lat-lon datasets
|
|
31
35
|
|
|
@@ -96,17 +100,18 @@ class Dataset:
|
|
|
96
100
|
use_dask: bool | None = False
|
|
97
101
|
apply_post_processing: bool | None = True
|
|
98
102
|
read_zarr: bool | None = False
|
|
103
|
+
ds_loader_fn: Callable[[], xr.Dataset] | None = None
|
|
99
104
|
|
|
100
105
|
is_global: bool = field(init=False, repr=False)
|
|
101
106
|
ds: xr.Dataset = field(init=False, repr=False)
|
|
102
107
|
|
|
103
|
-
def __post_init__(self):
|
|
104
|
-
"""
|
|
105
|
-
|
|
108
|
+
def __post_init__(self) -> None:
|
|
109
|
+
"""Perform post-initialization processing.
|
|
110
|
+
|
|
106
111
|
1. Loads the dataset from the specified filename.
|
|
107
|
-
2. Applies time filtering based on start_time and end_time if provided.
|
|
108
|
-
3. Selects relevant fields as specified by var_names
|
|
109
|
-
4. Ensures latitude
|
|
112
|
+
2. Applies time filtering based on start_time and end_time (if provided).
|
|
113
|
+
3. Selects relevant fields as specified by `var_names`.
|
|
114
|
+
4. Ensures latitude, longitude, and depth values are in ascending order.
|
|
110
115
|
5. Checks if the dataset covers the entire globe and adjusts if necessary.
|
|
111
116
|
"""
|
|
112
117
|
# Validate start_time and end_time
|
|
@@ -168,7 +173,11 @@ class Dataset:
|
|
|
168
173
|
If a list of files is provided but self.dim_names["time"] is not available or use_dask=False.
|
|
169
174
|
"""
|
|
170
175
|
ds = _load_data(
|
|
171
|
-
self.filename,
|
|
176
|
+
self.filename,
|
|
177
|
+
self.dim_names,
|
|
178
|
+
self.use_dask or False,
|
|
179
|
+
read_zarr=self.read_zarr or False,
|
|
180
|
+
ds_loader_fn=self.ds_loader_fn,
|
|
172
181
|
)
|
|
173
182
|
|
|
174
183
|
return ds
|
|
@@ -1075,6 +1084,83 @@ class GLORYSDataset(Dataset):
|
|
|
1075
1084
|
self.ds["mask_vel"] = mask_vel
|
|
1076
1085
|
|
|
1077
1086
|
|
|
1087
|
+
@dataclass(kw_only=True)
|
|
1088
|
+
class GLORYSDefaultDataset(GLORYSDataset):
|
|
1089
|
+
"""A GLORYS dataset that is loaded from the Copernicus Marine Data Store."""
|
|
1090
|
+
|
|
1091
|
+
dataset_name: ClassVar[str] = "cmems_mod_glo_phy_my_0.083deg_P1D-m"
|
|
1092
|
+
"""The GLORYS dataset-id for requests to the Copernicus Marine Toolkit"""
|
|
1093
|
+
_tk_module: ModuleType | None = None
|
|
1094
|
+
"""The dynamically imported Copernicus Marine module."""
|
|
1095
|
+
|
|
1096
|
+
def __post_init__(self) -> None:
|
|
1097
|
+
"""Configure attributes to ensure use of the correct upstream data-source."""
|
|
1098
|
+
self.read_zarr = True
|
|
1099
|
+
self.use_dask = True
|
|
1100
|
+
self.filename = self.dataset_name
|
|
1101
|
+
self.ds_loader_fn = self._load_from_copernicus
|
|
1102
|
+
|
|
1103
|
+
super().__post_init__()
|
|
1104
|
+
|
|
1105
|
+
def _check_auth(self, package_name: str) -> None:
|
|
1106
|
+
"""Check the local credential hierarchy for auth credentials.
|
|
1107
|
+
|
|
1108
|
+
Raises
|
|
1109
|
+
------
|
|
1110
|
+
RuntimeError
|
|
1111
|
+
If auth credentials cannot be found.
|
|
1112
|
+
"""
|
|
1113
|
+
if self._tk_module and not self._tk_module.login(check_credentials_valid=True):
|
|
1114
|
+
msg = f"Authenticate with `{package_name} login` to retrieve GLORYS data."
|
|
1115
|
+
raise RuntimeError(msg)
|
|
1116
|
+
|
|
1117
|
+
def _load_copernicus(self) -> ModuleType:
|
|
1118
|
+
"""Dynamically load the optional Copernicus Marine Toolkit dependency.
|
|
1119
|
+
|
|
1120
|
+
Raises
|
|
1121
|
+
------
|
|
1122
|
+
RuntimeError
|
|
1123
|
+
- If the toolkit module is not available or cannot be imported.
|
|
1124
|
+
- If auth credentials cannot be found.
|
|
1125
|
+
"""
|
|
1126
|
+
package_name = "copernicusmarine"
|
|
1127
|
+
if self._tk_module:
|
|
1128
|
+
self._check_auth(package_name)
|
|
1129
|
+
return self._tk_module
|
|
1130
|
+
|
|
1131
|
+
spec = importlib.util.find_spec(package_name)
|
|
1132
|
+
if not spec:
|
|
1133
|
+
msg = _get_pkg_error_msg("cloud-based GLORYS data", package_name, "stream")
|
|
1134
|
+
raise RuntimeError(msg)
|
|
1135
|
+
|
|
1136
|
+
try:
|
|
1137
|
+
self._tk_module = importlib.import_module(package_name)
|
|
1138
|
+
except ImportError as e:
|
|
1139
|
+
msg = f"Package `{package_name}` was found but could not be loaded."
|
|
1140
|
+
raise RuntimeError(msg) from e
|
|
1141
|
+
|
|
1142
|
+
self._check_auth(package_name)
|
|
1143
|
+
return self._tk_module
|
|
1144
|
+
|
|
1145
|
+
def _load_from_copernicus(self) -> xr.Dataset:
|
|
1146
|
+
"""Load a GLORYS dataset supporting streaming.
|
|
1147
|
+
|
|
1148
|
+
Returns
|
|
1149
|
+
-------
|
|
1150
|
+
xr.Dataset
|
|
1151
|
+
The streaming dataset
|
|
1152
|
+
"""
|
|
1153
|
+
copernicusmarine = self._load_copernicus()
|
|
1154
|
+
return copernicusmarine.open_dataset(
|
|
1155
|
+
self.dataset_name,
|
|
1156
|
+
start_datetime=self.start_time,
|
|
1157
|
+
end_datetime=self.end_time,
|
|
1158
|
+
service="arco-geo-series",
|
|
1159
|
+
coordinates_selection_method="inside",
|
|
1160
|
+
chunk_size_limit=2,
|
|
1161
|
+
)
|
|
1162
|
+
|
|
1163
|
+
|
|
1078
1164
|
@dataclass(kw_only=True)
|
|
1079
1165
|
class UnifiedDataset(Dataset):
|
|
1080
1166
|
"""Represents unified BGC data on original grid.
|
|
@@ -1549,12 +1635,8 @@ class ERA5ARCODataset(ERA5Dataset):
|
|
|
1549
1635
|
def __post_init__(self):
|
|
1550
1636
|
self.read_zarr = True
|
|
1551
1637
|
if not _has_gcsfs():
|
|
1552
|
-
|
|
1553
|
-
|
|
1554
|
-
" • `pip install roms-tools[stream]` or\n"
|
|
1555
|
-
" • `conda install gcsfs`\n"
|
|
1556
|
-
"Alternatively, install `roms-tools` with conda to include all dependencies."
|
|
1557
|
-
)
|
|
1638
|
+
msg = _get_pkg_error_msg("cloud-based ERA5 data", "gcsfs", "stream")
|
|
1639
|
+
raise RuntimeError(msg)
|
|
1558
1640
|
|
|
1559
1641
|
super().__post_init__()
|
|
1560
1642
|
|
roms_tools/setup/grid.py
CHANGED
|
@@ -415,30 +415,57 @@ class Grid:
|
|
|
415
415
|
|
|
416
416
|
def plot(
|
|
417
417
|
self,
|
|
418
|
+
lat: float | None = None,
|
|
419
|
+
lon: float | None = None,
|
|
418
420
|
with_dim_names: bool = False,
|
|
419
421
|
save_path: str | None = None,
|
|
420
422
|
) -> None:
|
|
421
|
-
"""Plot the grid.
|
|
423
|
+
"""Plot the grid with bathymetry.
|
|
424
|
+
|
|
425
|
+
Depending on the arguments, this will either:
|
|
426
|
+
* Plot the full horizontal grid (if both `lat` and `lon` are None),
|
|
427
|
+
* Plot a zonal (east-west) vertical section at a given latitude (`lat`),
|
|
428
|
+
* Plot a meridional (south-north) vertical section at a given longitude (`lon`).
|
|
422
429
|
|
|
423
430
|
Parameters
|
|
424
431
|
----------
|
|
432
|
+
lat : float, optional
|
|
433
|
+
Latitude in degrees at which to plot a vertical (zonal) section. Cannot be
|
|
434
|
+
provided together with `lon`. Default is None.
|
|
435
|
+
|
|
436
|
+
lon : float, optional
|
|
437
|
+
Longitude in degrees at which to plot a vertical (meridional) section. Cannot be
|
|
438
|
+
provided together with `lat`. Default is None.
|
|
439
|
+
|
|
425
440
|
with_dim_names : bool, optional
|
|
426
|
-
|
|
441
|
+
If True and no section is requested (i.e., both `lat` and `lon` are None), annotate
|
|
442
|
+
the plot with the underlying dimension names. Default is False.
|
|
427
443
|
|
|
428
444
|
save_path : str, optional
|
|
429
445
|
Path to save the generated plot. If None, the plot is shown interactively.
|
|
430
446
|
Default is None.
|
|
431
447
|
|
|
448
|
+
Raises
|
|
449
|
+
------
|
|
450
|
+
ValueError
|
|
451
|
+
If both `lat` and `lon` are specified simultaneously.
|
|
452
|
+
|
|
432
453
|
Returns
|
|
433
454
|
-------
|
|
434
455
|
None
|
|
435
456
|
This method does not return any value. It generates and displays a plot.
|
|
436
457
|
"""
|
|
458
|
+
if lat is not None and lon is not None:
|
|
459
|
+
raise ValueError("Specify either `lat` or `lon`, not both.")
|
|
460
|
+
|
|
437
461
|
field = self.ds["h"]
|
|
438
462
|
|
|
439
463
|
plot(
|
|
440
464
|
field=field,
|
|
441
465
|
grid_ds=self.ds,
|
|
466
|
+
lat=lat,
|
|
467
|
+
lon=lon,
|
|
468
|
+
yincrease=False,
|
|
442
469
|
with_dim_names=with_dim_names,
|
|
443
470
|
save_path=save_path,
|
|
444
471
|
cmap_name="YlGnBu",
|