roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roms_tools/__init__.py +8 -1
- roms_tools/analysis/cdr_analysis.py +203 -0
- roms_tools/analysis/cdr_ensemble.py +198 -0
- roms_tools/analysis/roms_output.py +80 -46
- roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
- roms_tools/download.py +4 -0
- roms_tools/plot.py +131 -30
- roms_tools/regrid.py +6 -1
- roms_tools/setup/boundary_forcing.py +94 -44
- roms_tools/setup/cdr_forcing.py +123 -15
- roms_tools/setup/cdr_release.py +161 -8
- roms_tools/setup/datasets.py +709 -341
- roms_tools/setup/grid.py +167 -139
- roms_tools/setup/initial_conditions.py +113 -48
- roms_tools/setup/mask.py +63 -7
- roms_tools/setup/nesting.py +67 -42
- roms_tools/setup/river_forcing.py +45 -19
- roms_tools/setup/surface_forcing.py +16 -10
- roms_tools/setup/tides.py +1 -2
- roms_tools/setup/topography.py +4 -4
- roms_tools/setup/utils.py +134 -22
- roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
- roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
- roms_tools/tests/test_analysis/test_roms_output.py +61 -3
- roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
- roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
- roms_tools/tests/test_setup/test_cdr_release.py +118 -1
- roms_tools/tests/test_setup/test_datasets.py +458 -34
- roms_tools/tests/test_setup/test_grid.py +238 -121
- roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
- roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
- roms_tools/tests/test_setup/test_utils.py +91 -1
- roms_tools/tests/test_setup/test_validation.py +21 -15
- roms_tools/tests/test_setup/utils.py +71 -0
- roms_tools/tests/test_tiling/test_join.py +241 -0
- roms_tools/tests/test_tiling/test_partition.py +45 -0
- roms_tools/tests/test_utils.py +224 -2
- roms_tools/tiling/join.py +189 -0
- roms_tools/tiling/partition.py +44 -30
- roms_tools/utils.py +488 -161
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
roms_tools/plot.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Any, Literal
|
|
2
2
|
|
|
3
3
|
import cartopy.crs as ccrs
|
|
4
|
+
import matplotlib.dates as mdates
|
|
4
5
|
import matplotlib.pyplot as plt
|
|
5
6
|
import numpy as np
|
|
6
7
|
import xarray as xr
|
|
@@ -9,18 +10,16 @@ from matplotlib.figure import Figure
|
|
|
9
10
|
|
|
10
11
|
from roms_tools.regrid import LateralRegridFromROMS, VerticalRegridFromROMS
|
|
11
12
|
from roms_tools.utils import (
|
|
12
|
-
|
|
13
|
-
_remove_edge_nans,
|
|
13
|
+
generate_coordinate_range,
|
|
14
14
|
infer_nominal_horizontal_resolution,
|
|
15
15
|
normalize_longitude,
|
|
16
|
+
remove_edge_nans,
|
|
16
17
|
)
|
|
17
18
|
from roms_tools.vertical_coordinate import compute_depth_coordinates
|
|
18
19
|
|
|
19
20
|
LABEL_COLOR = "k"
|
|
20
21
|
LABEL_SZ = 10
|
|
21
22
|
FONT_SZ = 10
|
|
22
|
-
EDGE_POS_START = "start"
|
|
23
|
-
EDGE_POS_END = "end"
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def _add_gridlines(ax: Axes) -> None:
|
|
@@ -212,7 +211,14 @@ def plot_nesting(parent_grid_ds, child_grid_ds, parent_straddle, with_dim_names=
|
|
|
212
211
|
return fig
|
|
213
212
|
|
|
214
213
|
|
|
215
|
-
def section_plot(
|
|
214
|
+
def section_plot(
|
|
215
|
+
field: xr.DataArray,
|
|
216
|
+
interface_depth: xr.DataArray | None = None,
|
|
217
|
+
title: str = "",
|
|
218
|
+
yincrease: bool | None = False,
|
|
219
|
+
kwargs: dict = {},
|
|
220
|
+
ax: Axes | None = None,
|
|
221
|
+
):
|
|
216
222
|
"""Plots a vertical section of a field with optional interface depths.
|
|
217
223
|
|
|
218
224
|
Parameters
|
|
@@ -224,6 +230,11 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
|
|
|
224
230
|
Defaults to None.
|
|
225
231
|
title : str, optional
|
|
226
232
|
Title of the plot. Defaults to an empty string.
|
|
233
|
+
yincrease : bool or None, optional
|
|
234
|
+
Whether to orient the y-axis with increasing values upward.
|
|
235
|
+
If True, y-values increase upward (standard).
|
|
236
|
+
If False, y-values decrease upward (inverted).
|
|
237
|
+
If None (default), behavior is equivalent to False (inverted axis).
|
|
227
238
|
kwargs : dict, optional
|
|
228
239
|
Additional keyword arguments to pass to `xarray.plot`. Defaults to an empty dictionary.
|
|
229
240
|
ax : matplotlib.axes.Axes, optional
|
|
@@ -248,6 +259,8 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
|
|
|
248
259
|
"""
|
|
249
260
|
if ax is None:
|
|
250
261
|
fig, ax = plt.subplots(1, 1, figsize=(9, 5))
|
|
262
|
+
if yincrease is None:
|
|
263
|
+
yincrease = False
|
|
251
264
|
|
|
252
265
|
dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
|
|
253
266
|
try:
|
|
@@ -279,7 +292,7 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
|
|
|
279
292
|
# Handle NaNs on either horizontal end
|
|
280
293
|
field = field.where(~field[depth_label].isnull(), drop=True)
|
|
281
294
|
|
|
282
|
-
more_kwargs = {"x": xdim, "y": depth_label, "yincrease":
|
|
295
|
+
more_kwargs = {"x": xdim, "y": depth_label, "yincrease": yincrease}
|
|
283
296
|
|
|
284
297
|
field.plot(**kwargs, **more_kwargs, ax=ax)
|
|
285
298
|
|
|
@@ -313,7 +326,12 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
|
|
|
313
326
|
return fig
|
|
314
327
|
|
|
315
328
|
|
|
316
|
-
def profile_plot(
|
|
329
|
+
def profile_plot(
|
|
330
|
+
field: xr.DataArray,
|
|
331
|
+
title: str = "",
|
|
332
|
+
yincrease: bool | None = False,
|
|
333
|
+
ax: Axes | None = None,
|
|
334
|
+
):
|
|
317
335
|
"""Plots a vertical profile of the given field against depth.
|
|
318
336
|
|
|
319
337
|
This function generates a profile plot by plotting the field values against
|
|
@@ -326,6 +344,11 @@ def profile_plot(field, title="", ax=None):
|
|
|
326
344
|
The field to plot, typically representing vertical profile data.
|
|
327
345
|
title : str, optional
|
|
328
346
|
Title of the plot. Defaults to an empty string.
|
|
347
|
+
yincrease : bool or None, optional
|
|
348
|
+
Whether to orient the y-axis with increasing values upward.
|
|
349
|
+
If True, y-values increase upward (standard).
|
|
350
|
+
If False, y-values decrease upward (inverted).
|
|
351
|
+
If None (default), behavior is equivalent to False (inverted axis).
|
|
329
352
|
ax : matplotlib.axes.Axes, optional
|
|
330
353
|
Pre-existing axes to draw the plot on. If None, a new figure and axes are created.
|
|
331
354
|
|
|
@@ -343,6 +366,9 @@ def profile_plot(field, title="", ax=None):
|
|
|
343
366
|
-----
|
|
344
367
|
- The y-axis is inverted to ensure that depth increases downward.
|
|
345
368
|
"""
|
|
369
|
+
if yincrease is None:
|
|
370
|
+
yincrease = False
|
|
371
|
+
|
|
346
372
|
depths_to_check = [
|
|
347
373
|
"layer_depth",
|
|
348
374
|
"interface_depth",
|
|
@@ -360,8 +386,8 @@ def profile_plot(field, title="", ax=None):
|
|
|
360
386
|
|
|
361
387
|
if ax is None:
|
|
362
388
|
fig, ax = plt.subplots(1, 1, figsize=(4, 7))
|
|
363
|
-
kwargs = {"y": depth_label, "yincrease":
|
|
364
|
-
field.plot(
|
|
389
|
+
kwargs = {"y": depth_label, "yincrease": yincrease}
|
|
390
|
+
field.plot(ax=ax, linewidth=2, **kwargs)
|
|
365
391
|
ax.set_title(title)
|
|
366
392
|
ax.set_ylabel("Depth [m]")
|
|
367
393
|
ax.grid()
|
|
@@ -370,7 +396,12 @@ def profile_plot(field, title="", ax=None):
|
|
|
370
396
|
return fig
|
|
371
397
|
|
|
372
398
|
|
|
373
|
-
def line_plot(
|
|
399
|
+
def line_plot(
|
|
400
|
+
field: xr.DataArray,
|
|
401
|
+
title: str = "",
|
|
402
|
+
ax: Axes | None = None,
|
|
403
|
+
yincrease: bool | None = False,
|
|
404
|
+
):
|
|
374
405
|
"""Plots a line graph of the given field with grey vertical bars indicating NaN
|
|
375
406
|
regions.
|
|
376
407
|
|
|
@@ -382,6 +413,11 @@ def line_plot(field, title="", ax=None):
|
|
|
382
413
|
Title of the plot. Defaults to an empty string.
|
|
383
414
|
ax : matplotlib.axes.Axes, optional
|
|
384
415
|
Pre-existing axes to draw the plot on. If None, a new figure and axes are created.
|
|
416
|
+
yincrease : bool, optional
|
|
417
|
+
Whether to orient the y-axis with increasing values upward.
|
|
418
|
+
If True, y-values increase upward (standard).
|
|
419
|
+
If False, y-values decrease upward (inverted).
|
|
420
|
+
If None (default), behavior is equivalent to True (standard axis).
|
|
385
421
|
|
|
386
422
|
Returns
|
|
387
423
|
-------
|
|
@@ -399,10 +435,12 @@ def line_plot(field, title="", ax=None):
|
|
|
399
435
|
-----
|
|
400
436
|
- NaN regions are identified and marked using `axvspan` with a grey shade.
|
|
401
437
|
"""
|
|
438
|
+
if yincrease is None:
|
|
439
|
+
yincrease = True
|
|
402
440
|
if ax is None:
|
|
403
441
|
fig, ax = plt.subplots(1, 1, figsize=(7, 4))
|
|
404
442
|
|
|
405
|
-
field.plot(ax=ax, linewidth=2)
|
|
443
|
+
field.plot(ax=ax, linewidth=2, yincrease=yincrease)
|
|
406
444
|
|
|
407
445
|
# Loop through the NaNs in the field and add grey vertical bars
|
|
408
446
|
dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
|
|
@@ -458,16 +496,16 @@ def line_plot(field, title="", ax=None):
|
|
|
458
496
|
|
|
459
497
|
|
|
460
498
|
def _get_edge(
|
|
461
|
-
arr: xr.DataArray, dim_name: str, pos: Literal[
|
|
499
|
+
arr: xr.DataArray, dim_name: str, pos: Literal["start", "end"]
|
|
462
500
|
) -> xr.DataArray:
|
|
463
501
|
"""Extract the first ("start") or last ("end") slice along the given dimension."""
|
|
464
|
-
if pos ==
|
|
502
|
+
if pos == "start":
|
|
465
503
|
return arr.isel({dim_name: 0})
|
|
466
504
|
|
|
467
|
-
if pos ==
|
|
505
|
+
if pos == "end":
|
|
468
506
|
return arr.isel({dim_name: -1})
|
|
469
507
|
|
|
470
|
-
raise ValueError(
|
|
508
|
+
raise ValueError("pos must be `start` or `end`")
|
|
471
509
|
|
|
472
510
|
|
|
473
511
|
def _add_boundary_to_ax(
|
|
@@ -499,23 +537,23 @@ def _add_boundary_to_ax(
|
|
|
499
537
|
|
|
500
538
|
edges = [
|
|
501
539
|
(
|
|
502
|
-
_get_edge(lon_deg, xi_dim,
|
|
503
|
-
_get_edge(lat_deg, xi_dim,
|
|
540
|
+
_get_edge(lon_deg, xi_dim, "start"),
|
|
541
|
+
_get_edge(lat_deg, xi_dim, "start"),
|
|
504
542
|
r"$\eta$",
|
|
505
543
|
), # left
|
|
506
544
|
(
|
|
507
|
-
_get_edge(lon_deg, xi_dim,
|
|
508
|
-
_get_edge(lat_deg, xi_dim,
|
|
545
|
+
_get_edge(lon_deg, xi_dim, "end"),
|
|
546
|
+
_get_edge(lat_deg, xi_dim, "end"),
|
|
509
547
|
r"$\eta$",
|
|
510
548
|
), # right
|
|
511
549
|
(
|
|
512
|
-
_get_edge(lon_deg, eta_dim,
|
|
513
|
-
_get_edge(lat_deg, eta_dim,
|
|
550
|
+
_get_edge(lon_deg, eta_dim, "start"),
|
|
551
|
+
_get_edge(lat_deg, eta_dim, "start"),
|
|
514
552
|
r"$\xi$",
|
|
515
553
|
), # bottom
|
|
516
554
|
(
|
|
517
|
-
_get_edge(lon_deg, eta_dim,
|
|
518
|
-
_get_edge(lat_deg, eta_dim,
|
|
555
|
+
_get_edge(lon_deg, eta_dim, "end"),
|
|
556
|
+
_get_edge(lat_deg, eta_dim, "end"),
|
|
519
557
|
r"$\xi$",
|
|
520
558
|
), # top
|
|
521
559
|
]
|
|
@@ -775,11 +813,12 @@ def plot(
|
|
|
775
813
|
depth_contours: bool = False,
|
|
776
814
|
layer_contours: bool = False,
|
|
777
815
|
max_nr_layer_contours: int | None = 10,
|
|
816
|
+
yincrease: bool | None = None,
|
|
778
817
|
use_coarse_grid: bool = False,
|
|
779
818
|
with_dim_names: bool = False,
|
|
780
819
|
ax: Axes | None = None,
|
|
781
820
|
save_path: str | None = None,
|
|
782
|
-
cmap_name: str
|
|
821
|
+
cmap_name: str = "YlOrRd",
|
|
783
822
|
add_colorbar: bool = True,
|
|
784
823
|
) -> None:
|
|
785
824
|
"""Generate a plot of a 2D or 3D ROMS field for a horizontal or vertical slice.
|
|
@@ -838,6 +877,12 @@ def plot(
|
|
|
838
877
|
max_nr_layer_contours : int, optional
|
|
839
878
|
Maximum number of vertical layer contours to draw. Default is 10.
|
|
840
879
|
|
|
880
|
+
yincrease: bool, optional
|
|
881
|
+
If True, the y-axis values increase upward (standard orientation).
|
|
882
|
+
If False, the y-axis values decrease upward (inverted axis).
|
|
883
|
+
If None (default), the orientation is determined by the default behavior
|
|
884
|
+
of the underlying plotting function.
|
|
885
|
+
|
|
841
886
|
use_coarse_grid : bool, optional
|
|
842
887
|
Use precomputed coarse-resolution grid. Default is False.
|
|
843
888
|
|
|
@@ -1006,7 +1051,7 @@ def plot(
|
|
|
1006
1051
|
title = title + f", lat = {lat}°N"
|
|
1007
1052
|
else:
|
|
1008
1053
|
resolution = infer_nominal_horizontal_resolution(grid_ds)
|
|
1009
|
-
lats =
|
|
1054
|
+
lats = generate_coordinate_range(
|
|
1010
1055
|
field.lat.min().values, field.lat.max().values, resolution
|
|
1011
1056
|
)
|
|
1012
1057
|
lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
|
|
@@ -1016,7 +1061,7 @@ def plot(
|
|
|
1016
1061
|
title = title + f", lon = {lon}°E"
|
|
1017
1062
|
else:
|
|
1018
1063
|
resolution = infer_nominal_horizontal_resolution(grid_ds, lat)
|
|
1019
|
-
lons =
|
|
1064
|
+
lons = generate_coordinate_range(
|
|
1020
1065
|
field.lon.min().values, field.lon.max().values, resolution
|
|
1021
1066
|
)
|
|
1022
1067
|
lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
|
|
@@ -1033,11 +1078,11 @@ def plot(
|
|
|
1033
1078
|
field = field.assign_coords({"layer_depth": layer_depth})
|
|
1034
1079
|
|
|
1035
1080
|
if lat is not None:
|
|
1036
|
-
field, layer_depth =
|
|
1081
|
+
field, layer_depth = remove_edge_nans(
|
|
1037
1082
|
field, "lon", layer_depth if "layer_depth" in locals() else None
|
|
1038
1083
|
)
|
|
1039
1084
|
if lon is not None:
|
|
1040
|
-
field, layer_depth =
|
|
1085
|
+
field, layer_depth = remove_edge_nans(
|
|
1041
1086
|
field, "lat", layer_depth if "layer_depth" in locals() else None
|
|
1042
1087
|
)
|
|
1043
1088
|
|
|
@@ -1086,14 +1131,15 @@ def plot(
|
|
|
1086
1131
|
field,
|
|
1087
1132
|
interface_depth=interface_depth,
|
|
1088
1133
|
title=title,
|
|
1134
|
+
yincrease=yincrease,
|
|
1089
1135
|
kwargs={**kwargs, "add_colorbar": add_colorbar},
|
|
1090
1136
|
ax=ax,
|
|
1091
1137
|
)
|
|
1092
1138
|
else:
|
|
1093
1139
|
if "s_rho" in field.dims:
|
|
1094
|
-
fig = profile_plot(field, title=title, ax=ax)
|
|
1140
|
+
fig = profile_plot(field, title=title, yincrease=yincrease, ax=ax)
|
|
1095
1141
|
else:
|
|
1096
|
-
fig = line_plot(field, title=title, ax=ax)
|
|
1142
|
+
fig = line_plot(field, title=title, ax=ax, yincrease=yincrease)
|
|
1097
1143
|
|
|
1098
1144
|
if save_path:
|
|
1099
1145
|
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
|
@@ -1205,3 +1251,58 @@ def plot_location(
|
|
|
1205
1251
|
|
|
1206
1252
|
if include_legend:
|
|
1207
1253
|
ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
|
|
1254
|
+
|
|
1255
|
+
|
|
1256
|
+
def plot_uptake_efficiency(ds: xr.Dataset) -> None:
|
|
1257
|
+
"""
|
|
1258
|
+
Plot Carbon Dioxide Removal (CDR) uptake efficiency over time.
|
|
1259
|
+
|
|
1260
|
+
This function plots two estimates of uptake efficiency stored in the dataset:
|
|
1261
|
+
1. `cdr_efficiency`, computed from CO2 flux differences.
|
|
1262
|
+
2. `cdr_efficiency_from_delta_diff`, computed from DIC differences.
|
|
1263
|
+
|
|
1264
|
+
The x-axis shows absolute time, formatted as YYYY-MM-DD, and the y-axis shows
|
|
1265
|
+
the uptake efficiency values. The plot includes a legend and grid for clarity.
|
|
1266
|
+
|
|
1267
|
+
Parameters
|
|
1268
|
+
----------
|
|
1269
|
+
ds : xarray.Dataset
|
|
1270
|
+
Dataset containing the following variables:
|
|
1271
|
+
- "abs_time": array of timestamps (datetime-like)
|
|
1272
|
+
- "cdr_efficiency": uptake efficiency from flux differences
|
|
1273
|
+
- "cdr_efficiency_from_delta_diff": uptake efficiency from DIC differences
|
|
1274
|
+
|
|
1275
|
+
Raises
|
|
1276
|
+
------
|
|
1277
|
+
ValueError
|
|
1278
|
+
If required variables are missing or empty.
|
|
1279
|
+
|
|
1280
|
+
Returns
|
|
1281
|
+
-------
|
|
1282
|
+
None
|
|
1283
|
+
"""
|
|
1284
|
+
required_vars = ["abs_time", "cdr_efficiency", "cdr_efficiency_from_delta_diff"]
|
|
1285
|
+
for var in required_vars:
|
|
1286
|
+
if var not in ds or ds[var].size == 0:
|
|
1287
|
+
raise ValueError(f"Dataset must contain non-empty variable '{var}'.")
|
|
1288
|
+
|
|
1289
|
+
times = ds["abs_time"]
|
|
1290
|
+
|
|
1291
|
+
# Check for monotonically increasing times
|
|
1292
|
+
if not np.all(times[1:] >= times[:-1]):
|
|
1293
|
+
raise ValueError("abs_time must be strictly increasing.")
|
|
1294
|
+
|
|
1295
|
+
fig, ax = plt.subplots(figsize=(10, 4))
|
|
1296
|
+
|
|
1297
|
+
ax.plot(times, ds["cdr_efficiency"], label="from CO2 flux differences", lw=2)
|
|
1298
|
+
ax.plot(
|
|
1299
|
+
times, ds["cdr_efficiency_from_delta_diff"], label="from DIC differences", lw=2
|
|
1300
|
+
)
|
|
1301
|
+
ax.grid()
|
|
1302
|
+
ax.set_title("CDR uptake efficiency")
|
|
1303
|
+
ax.legend()
|
|
1304
|
+
|
|
1305
|
+
# Format x-axis as YYYY-MM-DD
|
|
1306
|
+
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
|
|
1307
|
+
fig.autofmt_xdate()
|
|
1308
|
+
plt.show()
|
roms_tools/regrid.py
CHANGED
|
@@ -251,7 +251,12 @@ class VerticalRegridFromROMS:
|
|
|
251
251
|
ds : xarray.Dataset
|
|
252
252
|
The dataset containing the ROMS output data, which must include the vertical coordinate `s_rho`.
|
|
253
253
|
"""
|
|
254
|
-
self.grid = xgcm.Grid(
|
|
254
|
+
self.grid = xgcm.Grid(
|
|
255
|
+
ds,
|
|
256
|
+
coords={"s_rho": {"center": "s_rho"}},
|
|
257
|
+
periodic=False,
|
|
258
|
+
autoparse_metadata=False,
|
|
259
|
+
)
|
|
255
260
|
|
|
256
261
|
def apply(self, da, depth_coords, target_depth_levels, mask_edges=True):
|
|
257
262
|
"""Applies vertical regridding from ROMS to the specified target depth levels.
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import importlib.metadata
|
|
2
2
|
import logging
|
|
3
|
+
from collections import defaultdict
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from datetime import datetime
|
|
5
6
|
from pathlib import Path
|
|
@@ -12,7 +13,13 @@ from scipy.ndimage import label
|
|
|
12
13
|
from roms_tools import Grid
|
|
13
14
|
from roms_tools.plot import line_plot, section_plot
|
|
14
15
|
from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
|
|
15
|
-
from roms_tools.setup.datasets import
|
|
16
|
+
from roms_tools.setup.datasets import (
|
|
17
|
+
CESMBGCDataset,
|
|
18
|
+
GLORYSDataset,
|
|
19
|
+
GLORYSDefaultDataset,
|
|
20
|
+
RawDataSource,
|
|
21
|
+
UnifiedBGCDataset,
|
|
22
|
+
)
|
|
16
23
|
from roms_tools.setup.utils import (
|
|
17
24
|
add_time_info_to_ds,
|
|
18
25
|
compute_barotropic_velocity,
|
|
@@ -56,7 +63,7 @@ class BoundaryForcing:
|
|
|
56
63
|
If no time filtering is desired, set it to None. Default is None.
|
|
57
64
|
boundaries : Dict[str, bool], optional
|
|
58
65
|
Dictionary specifying which boundaries are forced (south, east, north, west). Default is all True.
|
|
59
|
-
source :
|
|
66
|
+
source : RawDataSource
|
|
60
67
|
Dictionary specifying the source of the boundary forcing data. Keys include:
|
|
61
68
|
|
|
62
69
|
- "name" (str): Name of the data source (e.g., "GLORYS").
|
|
@@ -64,7 +71,9 @@ class BoundaryForcing:
|
|
|
64
71
|
|
|
65
72
|
- A single string (with or without wildcards).
|
|
66
73
|
- A single Path object.
|
|
67
|
-
- A list of strings or Path objects
|
|
74
|
+
- A list of strings or Path objects.
|
|
75
|
+
If omitted, the data will be streamed via the Copernicus Marine Toolkit.
|
|
76
|
+
Note: streaming is currently not recommended due to performance limitations.
|
|
68
77
|
- "climatology" (bool): Indicates if the data is climatology data. Defaults to False.
|
|
69
78
|
|
|
70
79
|
type : str
|
|
@@ -117,7 +126,7 @@ class BoundaryForcing:
|
|
|
117
126
|
}
|
|
118
127
|
)
|
|
119
128
|
"""Dictionary specifying which boundaries are forced (south, east, north, west)."""
|
|
120
|
-
source:
|
|
129
|
+
source: RawDataSource
|
|
121
130
|
"""Dictionary specifying the source of the boundary forcing data."""
|
|
122
131
|
type: str = "physics"
|
|
123
132
|
"""Specifies the type of forcing data ("physics", "bgc")."""
|
|
@@ -150,7 +159,6 @@ class BoundaryForcing:
|
|
|
150
159
|
if self.apply_2d_horizontal_fill:
|
|
151
160
|
data.choose_subdomain(
|
|
152
161
|
target_coords,
|
|
153
|
-
buffer_points=20, # lateral fill needs good buffer from data margin
|
|
154
162
|
)
|
|
155
163
|
# Enforce double precision to ensure reproducibility
|
|
156
164
|
data.convert_to_float64()
|
|
@@ -181,8 +189,8 @@ class BoundaryForcing:
|
|
|
181
189
|
}
|
|
182
190
|
)
|
|
183
191
|
|
|
184
|
-
for direction
|
|
185
|
-
if
|
|
192
|
+
for direction, is_enabled in self.boundaries.items():
|
|
193
|
+
if is_enabled:
|
|
186
194
|
bdry_target_coords = {
|
|
187
195
|
"lat": target_coords["lat"].isel(
|
|
188
196
|
**self.bdry_coords["vector"][direction]
|
|
@@ -290,14 +298,12 @@ class BoundaryForcing:
|
|
|
290
298
|
zeta_v = zeta_v.isel(**self.bdry_coords["v"][direction])
|
|
291
299
|
|
|
292
300
|
if not self.apply_2d_horizontal_fill and bdry_data.needs_lateral_fill:
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
bdry_data.dim_names["depth"],
|
|
300
|
-
)
|
|
301
|
+
if not self.bypass_validation:
|
|
302
|
+
self._validate_1d_fill(
|
|
303
|
+
processed_fields,
|
|
304
|
+
direction,
|
|
305
|
+
bdry_data.dim_names["depth"],
|
|
306
|
+
)
|
|
301
307
|
for var_name in processed_fields:
|
|
302
308
|
processed_fields[var_name] = apply_1d_horizontal_fill(
|
|
303
309
|
processed_fields[var_name]
|
|
@@ -403,7 +409,10 @@ class BoundaryForcing:
|
|
|
403
409
|
if "name" not in self.source:
|
|
404
410
|
raise ValueError("`source` must include a 'name'.")
|
|
405
411
|
if "path" not in self.source:
|
|
406
|
-
|
|
412
|
+
if self.source["name"] != "GLORYS":
|
|
413
|
+
raise ValueError("`source` must include a 'path'.")
|
|
414
|
+
|
|
415
|
+
self.source["path"] = GLORYSDefaultDataset.dataset_name
|
|
407
416
|
|
|
408
417
|
# Set 'climatology' to False if not provided in 'source'
|
|
409
418
|
self.source = {
|
|
@@ -425,34 +434,68 @@ class BoundaryForcing:
|
|
|
425
434
|
"Sea surface height will NOT be used to adjust depth coordinates."
|
|
426
435
|
)
|
|
427
436
|
|
|
428
|
-
def _get_data(
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
437
|
+
def _get_data(
|
|
438
|
+
self,
|
|
439
|
+
) -> GLORYSDataset | GLORYSDefaultDataset | CESMBGCDataset | UnifiedBGCDataset:
|
|
440
|
+
"""Determine the correct `Dataset` type and return an instance.
|
|
441
|
+
|
|
442
|
+
Returns
|
|
443
|
+
-------
|
|
444
|
+
Dataset
|
|
445
|
+
The `Dataset` instance
|
|
446
|
+
|
|
447
|
+
"""
|
|
448
|
+
dataset_map: dict[
|
|
449
|
+
str,
|
|
450
|
+
dict[
|
|
451
|
+
str,
|
|
452
|
+
dict[
|
|
453
|
+
str,
|
|
454
|
+
type[
|
|
455
|
+
GLORYSDataset
|
|
456
|
+
| GLORYSDefaultDataset
|
|
457
|
+
| CESMBGCDataset
|
|
458
|
+
| UnifiedBGCDataset
|
|
459
|
+
],
|
|
460
|
+
],
|
|
461
|
+
],
|
|
462
|
+
] = {
|
|
463
|
+
"physics": {
|
|
464
|
+
"GLORYS": {
|
|
465
|
+
"external": GLORYSDataset,
|
|
466
|
+
"default": GLORYSDefaultDataset,
|
|
467
|
+
},
|
|
468
|
+
},
|
|
469
|
+
"bgc": {
|
|
470
|
+
"CESM_REGRIDDED": defaultdict(lambda: CESMBGCDataset),
|
|
471
|
+
"UNIFIED": defaultdict(lambda: UnifiedBGCDataset),
|
|
472
|
+
},
|
|
435
473
|
}
|
|
436
474
|
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
'Only "GLORYS" is a valid option for source["name"] when type is "physics".'
|
|
443
|
-
)
|
|
475
|
+
source_name = str(self.source["name"])
|
|
476
|
+
if source_name not in dataset_map[self.type]:
|
|
477
|
+
tpl = 'Valid options for source["name"] for type {} include: {}'
|
|
478
|
+
msg = tpl.format(self.type, " and ".join(dataset_map[self.type].keys()))
|
|
479
|
+
raise ValueError(msg)
|
|
444
480
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
else:
|
|
451
|
-
raise ValueError(
|
|
452
|
-
'Only "CESM_REGRIDDED" and "UNIFIED" are valid options for source["name"] when type is "bgc".'
|
|
453
|
-
)
|
|
481
|
+
has_no_path = "path" not in self.source
|
|
482
|
+
has_default_path = self.source.get("path") == GLORYSDefaultDataset.dataset_name
|
|
483
|
+
use_default = has_no_path or has_default_path
|
|
484
|
+
|
|
485
|
+
variant = "default" if use_default else "external"
|
|
454
486
|
|
|
455
|
-
|
|
487
|
+
data_type = dataset_map[self.type][source_name][variant]
|
|
488
|
+
|
|
489
|
+
if isinstance(self.source["path"], bool):
|
|
490
|
+
raise ValueError('source["path"] cannot be a boolean here')
|
|
491
|
+
|
|
492
|
+
return data_type(
|
|
493
|
+
filename=self.source["path"],
|
|
494
|
+
start_time=self.start_time,
|
|
495
|
+
end_time=self.end_time,
|
|
496
|
+
climatology=self.source["climatology"], # type: ignore[arg-type]
|
|
497
|
+
use_dask=self.use_dask,
|
|
498
|
+
)
|
|
456
499
|
|
|
457
500
|
def _set_variable_info(self, data):
|
|
458
501
|
"""Sets up a dictionary with metadata for variables based on the type of data
|
|
@@ -731,6 +774,9 @@ class BoundaryForcing:
|
|
|
731
774
|
None
|
|
732
775
|
If a boundary is divided by land, a warning is issued. No return value is provided.
|
|
733
776
|
"""
|
|
777
|
+
if not hasattr(self, "_warned_directions"):
|
|
778
|
+
self._warned_directions = set()
|
|
779
|
+
|
|
734
780
|
for var_name in processed_fields.keys():
|
|
735
781
|
if self.variable_info[var_name]["validate"]:
|
|
736
782
|
location = self.variable_info[var_name]["location"]
|
|
@@ -753,16 +799,20 @@ class BoundaryForcing:
|
|
|
753
799
|
wet_nans = xr.where(da.where(mask).isnull(), 1, 0)
|
|
754
800
|
# Apply label to find connected components of wet NaNs
|
|
755
801
|
labeled_array, num_features = label(wet_nans)
|
|
802
|
+
|
|
756
803
|
left_margin = labeled_array[0]
|
|
757
804
|
right_margin = labeled_array[-1]
|
|
758
805
|
if left_margin != 0:
|
|
759
806
|
num_features = num_features - 1
|
|
760
807
|
if right_margin != 0:
|
|
761
808
|
num_features = num_features - 1
|
|
762
|
-
|
|
809
|
+
|
|
810
|
+
if num_features > 0 and direction not in self._warned_directions:
|
|
763
811
|
logging.warning(
|
|
764
|
-
f"
|
|
812
|
+
f"The {direction}ern boundary is divided by land. "
|
|
813
|
+
"It would be safer (but slower and more memory-intensive) to use `apply_2d_horizontal_fill = True`."
|
|
765
814
|
)
|
|
815
|
+
self._warned_directions.add(direction)
|
|
766
816
|
|
|
767
817
|
def _validate(self, ds):
|
|
768
818
|
"""Validate the dataset for NaN values at the first time step (bry_time=0) for
|
|
@@ -797,8 +847,8 @@ class BoundaryForcing:
|
|
|
797
847
|
elif location == "v":
|
|
798
848
|
mask = self.grid.ds.mask_v
|
|
799
849
|
|
|
800
|
-
for direction
|
|
801
|
-
if
|
|
850
|
+
for direction, is_enabled in self.boundaries.items():
|
|
851
|
+
if is_enabled:
|
|
802
852
|
bdry_var_name = f"{var_name}_{direction}"
|
|
803
853
|
|
|
804
854
|
# Check for NaN values at the first time step using the nan_check function
|