roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
roms_tools/plot.py CHANGED
@@ -1,6 +1,7 @@
 from typing import Any, Literal

 import cartopy.crs as ccrs
+import matplotlib.dates as mdates
 import matplotlib.pyplot as plt
 import numpy as np
 import xarray as xr
@@ -9,18 +10,16 @@ from matplotlib.figure import Figure

 from roms_tools.regrid import LateralRegridFromROMS, VerticalRegridFromROMS
 from roms_tools.utils import (
-    _generate_coordinate_range,
-    _remove_edge_nans,
+    generate_coordinate_range,
     infer_nominal_horizontal_resolution,
     normalize_longitude,
+    remove_edge_nans,
 )
 from roms_tools.vertical_coordinate import compute_depth_coordinates

 LABEL_COLOR = "k"
 LABEL_SZ = 10
 FONT_SZ = 10
-EDGE_POS_START = "start"
-EDGE_POS_END = "end"


 def _add_gridlines(ax: Axes) -> None:
@@ -212,7 +211,14 @@ def plot_nesting(parent_grid_ds, child_grid_ds, parent_straddle, with_dim_names=
     return fig


-def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
+def section_plot(
+    field: xr.DataArray,
+    interface_depth: xr.DataArray | None = None,
+    title: str = "",
+    yincrease: bool | None = False,
+    kwargs: dict = {},
+    ax: Axes | None = None,
+):
     """Plots a vertical section of a field with optional interface depths.

     Parameters
@@ -224,6 +230,11 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
         Defaults to None.
     title : str, optional
         Title of the plot. Defaults to an empty string.
+    yincrease : bool or None, optional
+        Whether to orient the y-axis with increasing values upward.
+        If True, y-values increase upward (standard).
+        If False, y-values decrease upward (inverted).
+        If None (default), behavior is equivalent to False (inverted axis).
     kwargs : dict, optional
         Additional keyword arguments to pass to `xarray.plot`. Defaults to an empty dictionary.
     ax : matplotlib.axes.Axes, optional
@@ -248,6 +259,8 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
     """
     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=(9, 5))
+    if yincrease is None:
+        yincrease = False

     dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
     try:
@@ -279,7 +292,7 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
     # Handle NaNs on either horizontal end
     field = field.where(~field[depth_label].isnull(), drop=True)

-    more_kwargs = {"x": xdim, "y": depth_label, "yincrease": False}
+    more_kwargs = {"x": xdim, "y": depth_label, "yincrease": yincrease}

     field.plot(**kwargs, **more_kwargs, ax=ax)

@@ -313,7 +326,12 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
     return fig


-def profile_plot(field, title="", ax=None):
+def profile_plot(
+    field: xr.DataArray,
+    title: str = "",
+    yincrease: bool | None = False,
+    ax: Axes | None = None,
+):
     """Plots a vertical profile of the given field against depth.

     This function generates a profile plot by plotting the field values against
@@ -326,6 +344,11 @@ def profile_plot(field, title="", ax=None):
         The field to plot, typically representing vertical profile data.
     title : str, optional
         Title of the plot. Defaults to an empty string.
+    yincrease : bool or None, optional
+        Whether to orient the y-axis with increasing values upward.
+        If True, y-values increase upward (standard).
+        If False, y-values decrease upward (inverted).
+        If None (default), behavior is equivalent to False (inverted axis).
     ax : matplotlib.axes.Axes, optional
         Pre-existing axes to draw the plot on. If None, a new figure and axes are created.

@@ -343,6 +366,9 @@ def profile_plot(field, title="", ax=None):
     -----
     - The y-axis is inverted to ensure that depth increases downward.
     """
+    if yincrease is None:
+        yincrease = False
+
     depths_to_check = [
         "layer_depth",
         "interface_depth",
@@ -360,8 +386,8 @@ def profile_plot(field, title="", ax=None):

     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=(4, 7))
-    kwargs = {"y": depth_label, "yincrease": False}
-    field.plot(**kwargs, linewidth=2)
+    kwargs = {"y": depth_label, "yincrease": yincrease}
+    field.plot(ax=ax, linewidth=2, **kwargs)
     ax.set_title(title)
     ax.set_ylabel("Depth [m]")
     ax.grid()
@@ -370,7 +396,12 @@ def profile_plot(field, title="", ax=None):
     return fig


-def line_plot(field, title="", ax=None):
+def line_plot(
+    field: xr.DataArray,
+    title: str = "",
+    ax: Axes | None = None,
+    yincrease: bool | None = False,
+):
     """Plots a line graph of the given field with grey vertical bars indicating NaN
     regions.

@@ -382,6 +413,11 @@ def line_plot(field, title="", ax=None):
         Title of the plot. Defaults to an empty string.
     ax : matplotlib.axes.Axes, optional
         Pre-existing axes to draw the plot on. If None, a new figure and axes are created.
+    yincrease : bool, optional
+        Whether to orient the y-axis with increasing values upward.
+        If True, y-values increase upward (standard).
+        If False, y-values decrease upward (inverted).
+        If None (default), behavior is equivalent to True (standard axis).

     Returns
     -------
@@ -399,10 +435,12 @@ def line_plot(field, title="", ax=None):
     -----
     - NaN regions are identified and marked using `axvspan` with a grey shade.
     """
+    if yincrease is None:
+        yincrease = True
     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=(7, 4))

-    field.plot(ax=ax, linewidth=2)
+    field.plot(ax=ax, linewidth=2, yincrease=yincrease)

     # Loop through the NaNs in the field and add grey vertical bars
     dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
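The new `yincrease` flag in `section_plot`, `profile_plot`, and `line_plot` is forwarded straight to xarray's plotting call (see `field.plot(ax=ax, linewidth=2, yincrease=yincrease)` above). A minimal standalone sketch of the effect, using only xarray and matplotlib with invented profile data:

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Illustrative profile: temperature decaying with depth (values are made up).
depth = xr.DataArray(np.linspace(0.0, 500.0, 50), dims="depth", name="depth")
temp = xr.DataArray(
    20.0 * np.exp(-depth / 150.0), dims="depth", coords={"depth": depth}, name="temp"
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
temp.plot(ax=ax1, y="depth", yincrease=True)   # depth increases upward
temp.plot(ax=ax2, y="depth", yincrease=False)  # depth increases downward (inverted axis)
plt.show()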
@@ -458,16 +496,16 @@


 def _get_edge(
-    arr: xr.DataArray, dim_name: str, pos: Literal[EDGE_POS_START, EDGE_POS_END]
+    arr: xr.DataArray, dim_name: str, pos: Literal["start", "end"]
 ) -> xr.DataArray:
     """Extract the first ("start") or last ("end") slice along the given dimension."""
-    if pos == EDGE_POS_START:
+    if pos == "start":
         return arr.isel({dim_name: 0})

-    if pos == EDGE_POS_END:
+    if pos == "end":
         return arr.isel({dim_name: -1})

-    raise ValueError(f"pos must be {EDGE_POS_START} or {EDGE_POS_END}")
+    raise ValueError("pos must be `start` or `end`")


 def _add_boundary_to_ax(
@@ -499,23 +537,23 @@ def _add_boundary_to_ax(

     edges = [
         (
-            _get_edge(lon_deg, xi_dim, EDGE_POS_START),
-            _get_edge(lat_deg, xi_dim, EDGE_POS_START),
+            _get_edge(lon_deg, xi_dim, "start"),
+            _get_edge(lat_deg, xi_dim, "start"),
             r"$\eta$",
         ), # left
         (
-            _get_edge(lon_deg, xi_dim, EDGE_POS_END),
-            _get_edge(lat_deg, xi_dim, EDGE_POS_END),
+            _get_edge(lon_deg, xi_dim, "end"),
+            _get_edge(lat_deg, xi_dim, "end"),
             r"$\eta$",
         ), # right
         (
-            _get_edge(lon_deg, eta_dim, EDGE_POS_START),
-            _get_edge(lat_deg, eta_dim, EDGE_POS_START),
+            _get_edge(lon_deg, eta_dim, "start"),
+            _get_edge(lat_deg, eta_dim, "start"),
             r"$\xi$",
         ), # bottom
         (
-            _get_edge(lon_deg, eta_dim, EDGE_POS_END),
-            _get_edge(lat_deg, eta_dim, EDGE_POS_END),
+            _get_edge(lon_deg, eta_dim, "end"),
+            _get_edge(lat_deg, eta_dim, "end"),
             r"$\xi$",
         ), # top
     ]
@@ -775,11 +813,12 @@ def plot(
     depth_contours: bool = False,
     layer_contours: bool = False,
     max_nr_layer_contours: int | None = 10,
+    yincrease: bool | None = None,
     use_coarse_grid: bool = False,
     with_dim_names: bool = False,
     ax: Axes | None = None,
     save_path: str | None = None,
-    cmap_name: str | None = "YlOrRd",
+    cmap_name: str = "YlOrRd",
     add_colorbar: bool = True,
 ) -> None:
     """Generate a plot of a 2D or 3D ROMS field for a horizontal or vertical slice.
@@ -838,6 +877,12 @@ def plot(
     max_nr_layer_contours : int, optional
         Maximum number of vertical layer contours to draw. Default is 10.

+    yincrease: bool, optional
+        If True, the y-axis values increase upward (standard orientation).
+        If False, the y-axis values decrease upward (inverted axis).
+        If None (default), the orientation is determined by the default behavior
+        of the underlying plotting function.
+
     use_coarse_grid : bool, optional
         Use precomputed coarse-resolution grid. Default is False.

@@ -1006,7 +1051,7 @@ def plot(
             title = title + f", lat = {lat}°N"
         else:
             resolution = infer_nominal_horizontal_resolution(grid_ds)
-            lats = _generate_coordinate_range(
+            lats = generate_coordinate_range(
                 field.lat.min().values, field.lat.max().values, resolution
             )
             lats = xr.DataArray(lats, dims=["lat"], attrs={"units": "°N"})
@@ -1016,7 +1061,7 @@ def plot(
             title = title + f", lon = {lon}°E"
         else:
             resolution = infer_nominal_horizontal_resolution(grid_ds, lat)
-            lons = _generate_coordinate_range(
+            lons = generate_coordinate_range(
                 field.lon.min().values, field.lon.max().values, resolution
             )
             lons = xr.DataArray(lons, dims=["lon"], attrs={"units": "°E"})
@@ -1033,11 +1078,11 @@ def plot(
         field = field.assign_coords({"layer_depth": layer_depth})

     if lat is not None:
-        field, layer_depth = _remove_edge_nans(
+        field, layer_depth = remove_edge_nans(
             field, "lon", layer_depth if "layer_depth" in locals() else None
         )
     if lon is not None:
-        field, layer_depth = _remove_edge_nans(
+        field, layer_depth = remove_edge_nans(
             field, "lat", layer_depth if "layer_depth" in locals() else None
         )

@@ -1086,14 +1131,15 @@ def plot(
             field,
             interface_depth=interface_depth,
             title=title,
+            yincrease=yincrease,
             kwargs={**kwargs, "add_colorbar": add_colorbar},
             ax=ax,
         )
     else:
         if "s_rho" in field.dims:
-            fig = profile_plot(field, title=title, ax=ax)
+            fig = profile_plot(field, title=title, yincrease=yincrease, ax=ax)
         else:
-            fig = line_plot(field, title=title, ax=ax)
+            fig = line_plot(field, title=title, ax=ax, yincrease=yincrease)

     if save_path:
         plt.savefig(save_path, dpi=300, bbox_inches="tight")
@@ -1205,3 +1251,58 @@ def plot_location(

     if include_legend:
         ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
+
+
+def plot_uptake_efficiency(ds: xr.Dataset) -> None:
+    """
+    Plot Carbon Dioxide Removal (CDR) uptake efficiency over time.
+
+    This function plots two estimates of uptake efficiency stored in the dataset:
+    1. `cdr_efficiency`, computed from CO2 flux differences.
+    2. `cdr_efficiency_from_delta_diff`, computed from DIC differences.
+
+    The x-axis shows absolute time, formatted as YYYY-MM-DD, and the y-axis shows
+    the uptake efficiency values. The plot includes a legend and grid for clarity.
+
+    Parameters
+    ----------
+    ds : xarray.Dataset
+        Dataset containing the following variables:
+        - "abs_time": array of timestamps (datetime-like)
+        - "cdr_efficiency": uptake efficiency from flux differences
+        - "cdr_efficiency_from_delta_diff": uptake efficiency from DIC differences
+
+    Raises
+    ------
+    ValueError
+        If required variables are missing or empty.
+
+    Returns
+    -------
+    None
+    """
+    required_vars = ["abs_time", "cdr_efficiency", "cdr_efficiency_from_delta_diff"]
+    for var in required_vars:
+        if var not in ds or ds[var].size == 0:
+            raise ValueError(f"Dataset must contain non-empty variable '{var}'.")
+
+    times = ds["abs_time"]
+
+    # Check for monotonically increasing times
+    if not np.all(times[1:] >= times[:-1]):
+        raise ValueError("abs_time must be strictly increasing.")
+
+    fig, ax = plt.subplots(figsize=(10, 4))
+
+    ax.plot(times, ds["cdr_efficiency"], label="from CO2 flux differences", lw=2)
+    ax.plot(
+        times, ds["cdr_efficiency_from_delta_diff"], label="from DIC differences", lw=2
+    )
+    ax.grid()
+    ax.set_title("CDR uptake efficiency")
+    ax.legend()
+
+    # Format x-axis as YYYY-MM-DD
+    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
+    fig.autofmt_xdate()
+    plt.show()
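The new `plot_uptake_efficiency` helper only requires a dataset carrying the three variables named in its docstring. A self-contained usage sketch with synthetic values (the numbers are invented purely for illustration):

import numpy as np
import pandas as pd
import xarray as xr

from roms_tools.plot import plot_uptake_efficiency

# Fabricated CDR efficiency time series for demonstration only.
times = pd.date_range("2012-01-01", periods=30, freq="D")
efficiency = np.linspace(0.0, 0.6, 30)
ds = xr.Dataset(
    {
        "abs_time": ("time", times),
        "cdr_efficiency": ("time", efficiency),
        "cdr_efficiency_from_delta_diff": ("time", 0.95 * efficiency),
    }
)

plot_uptake_efficiency(ds)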
roms_tools/regrid.py CHANGED
@@ -251,7 +251,12 @@ class VerticalRegridFromROMS:
         ds : xarray.Dataset
             The dataset containing the ROMS output data, which must include the vertical coordinate `s_rho`.
         """
-        self.grid = xgcm.Grid(ds, coords={"s_rho": {"center": "s_rho"}}, periodic=False)
+        self.grid = xgcm.Grid(
+            ds,
+            coords={"s_rho": {"center": "s_rho"}},
+            periodic=False,
+            autoparse_metadata=False,
+        )

     def apply(self, da, depth_coords, target_depth_levels, mask_edges=True):
         """Applies vertical regridding from ROMS to the specified target depth levels.
roms_tools/setup/boundary_forcing.py CHANGED
@@ -1,5 +1,6 @@
 import importlib.metadata
 import logging
+from collections import defaultdict
 from dataclasses import dataclass, field
 from datetime import datetime
 from pathlib import Path
@@ -12,7 +13,13 @@ from scipy.ndimage import label
 from roms_tools import Grid
 from roms_tools.plot import line_plot, section_plot
 from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
-from roms_tools.setup.datasets import CESMBGCDataset, GLORYSDataset, UnifiedBGCDataset
+from roms_tools.setup.datasets import (
+    CESMBGCDataset,
+    GLORYSDataset,
+    GLORYSDefaultDataset,
+    RawDataSource,
+    UnifiedBGCDataset,
+)
 from roms_tools.setup.utils import (
     add_time_info_to_ds,
     compute_barotropic_velocity,
@@ -56,7 +63,7 @@ class BoundaryForcing:
         If no time filtering is desired, set it to None. Default is None.
     boundaries : Dict[str, bool], optional
         Dictionary specifying which boundaries are forced (south, east, north, west). Default is all True.
-    source : Dict[str, Union[str, Path, List[Union[str, Path]]], bool]
+    source : RawDataSource
         Dictionary specifying the source of the boundary forcing data. Keys include:

         - "name" (str): Name of the data source (e.g., "GLORYS").
@@ -64,7 +71,9 @@ class BoundaryForcing:

           - A single string (with or without wildcards).
           - A single Path object.
-          - A list of strings or Path objects containing multiple files.
+          - A list of strings or Path objects.
+          If omitted, the data will be streamed via the Copernicus Marine Toolkit.
+          Note: streaming is currently not recommended due to performance limitations.

         - "climatology" (bool): Indicates if the data is climatology data. Defaults to False.

     type : str
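As the updated docstring states, "path" may now be omitted for GLORYS physics data, in which case the default dataset is streamed through the Copernicus Marine Toolkit (and, per the note above, this is currently slow). A hedged sketch of such a call; the Grid parameters and the exact BoundaryForcing keyword set are illustrative assumptions, not taken from this diff:

from datetime import datetime

from roms_tools import BoundaryForcing, Grid

# Illustrative grid; the parameter values are placeholders.
grid = Grid(
    nx=100, ny=100, size_x=1800, size_y=2400,
    center_lon=-21, center_lat=61, rot=20,
)

# No "path" given: for name="GLORYS" the default dataset is used and the
# data are streamed via the Copernicus Marine Toolkit (requires credentials).
boundary_forcing = BoundaryForcing(
    grid=grid,
    start_time=datetime(2012, 1, 1),
    end_time=datetime(2012, 1, 2),
    source={"name": "GLORYS"},
    type="physics",
)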
@@ -117,7 +126,7 @@ class BoundaryForcing:
        }
    )
    """Dictionary specifying which boundaries are forced (south, east, north, west)."""
-    source: dict[str, str | Path | list[str | Path]]
+    source: RawDataSource
    """Dictionary specifying the source of the boundary forcing data."""
    type: str = "physics"
    """Specifies the type of forcing data ("physics", "bgc")."""
@@ -150,7 +159,6 @@
        if self.apply_2d_horizontal_fill:
            data.choose_subdomain(
                target_coords,
-                buffer_points=20,  # lateral fill needs good buffer from data margin
            )
        # Enforce double precision to ensure reproducibility
        data.convert_to_float64()
@@ -181,8 +189,8 @@
            }
        )

-        for direction in ["south", "east", "north", "west"]:
-            if self.boundaries[direction]:
+        for direction, is_enabled in self.boundaries.items():
+            if is_enabled:
                bdry_target_coords = {
                    "lat": target_coords["lat"].isel(
                        **self.bdry_coords["vector"][direction]
@@ -290,14 +298,12 @@
                zeta_v = zeta_v.isel(**self.bdry_coords["v"][direction])

            if not self.apply_2d_horizontal_fill and bdry_data.needs_lateral_fill:
-                logging.info(
-                    f"Applying 1D horizontal fill to {direction}ern boundary."
-                )
-                self._validate_1d_fill(
-                    processed_fields,
-                    direction,
-                    bdry_data.dim_names["depth"],
-                )
+                if not self.bypass_validation:
+                    self._validate_1d_fill(
+                        processed_fields,
+                        direction,
+                        bdry_data.dim_names["depth"],
+                    )
                for var_name in processed_fields:
                    processed_fields[var_name] = apply_1d_horizontal_fill(
                        processed_fields[var_name]
@@ -403,7 +409,10 @@
        if "name" not in self.source:
            raise ValueError("`source` must include a 'name'.")
        if "path" not in self.source:
-            raise ValueError("`source` must include a 'path'.")
+            if self.source["name"] != "GLORYS":
+                raise ValueError("`source` must include a 'path'.")
+
+            self.source["path"] = GLORYSDefaultDataset.dataset_name

        # Set 'climatology' to False if not provided in 'source'
        self.source = {
@@ -425,34 +434,68 @@
                "Sea surface height will NOT be used to adjust depth coordinates."
            )

-    def _get_data(self):
-        data_dict = {
-            "filename": self.source["path"],
-            "start_time": self.start_time,
-            "end_time": self.end_time,
-            "climatology": self.source["climatology"],
-            "use_dask": self.use_dask,
+    def _get_data(
+        self,
+    ) -> GLORYSDataset | GLORYSDefaultDataset | CESMBGCDataset | UnifiedBGCDataset:
+        """Determine the correct `Dataset` type and return an instance.
+
+        Returns
+        -------
+        Dataset
+            The `Dataset` instance
+
+        """
+        dataset_map: dict[
+            str,
+            dict[
+                str,
+                dict[
+                    str,
+                    type[
+                        GLORYSDataset
+                        | GLORYSDefaultDataset
+                        | CESMBGCDataset
+                        | UnifiedBGCDataset
+                    ],
+                ],
+            ],
+        ] = {
+            "physics": {
+                "GLORYS": {
+                    "external": GLORYSDataset,
+                    "default": GLORYSDefaultDataset,
+                },
+            },
+            "bgc": {
+                "CESM_REGRIDDED": defaultdict(lambda: CESMBGCDataset),
+                "UNIFIED": defaultdict(lambda: UnifiedBGCDataset),
+            },
        }

-        if self.type == "physics":
-            if self.source["name"] == "GLORYS":
-                data = GLORYSDataset(**data_dict)
-            else:
-                raise ValueError(
-                    'Only "GLORYS" is a valid option for source["name"] when type is "physics".'
-                )
+        source_name = str(self.source["name"])
+        if source_name not in dataset_map[self.type]:
+            tpl = 'Valid options for source["name"] for type {} include: {}'
+            msg = tpl.format(self.type, " and ".join(dataset_map[self.type].keys()))
+            raise ValueError(msg)

-        elif self.type == "bgc":
-            if self.source["name"] == "CESM_REGRIDDED":
-                data = CESMBGCDataset(**data_dict)
-            elif self.source["name"] == "UNIFIED":
-                data = UnifiedBGCDataset(**data_dict)
-            else:
-                raise ValueError(
-                    'Only "CESM_REGRIDDED" and "UNIFIED" are valid options for source["name"] when type is "bgc".'
-                )
+        has_no_path = "path" not in self.source
+        has_default_path = self.source.get("path") == GLORYSDefaultDataset.dataset_name
+        use_default = has_no_path or has_default_path
+
+        variant = "default" if use_default else "external"

-        return data
+        data_type = dataset_map[self.type][source_name][variant]
+
+        if isinstance(self.source["path"], bool):
+            raise ValueError('source["path"] cannot be a boolean here')
+
+        return data_type(
+            filename=self.source["path"],
+            start_time=self.start_time,
+            end_time=self.end_time,
+            climatology=self.source["climatology"],  # type: ignore[arg-type]
+            use_dask=self.use_dask,
+        )

    def _set_variable_info(self, data):
@@ -731,6 +774,9 @@
        None
            If a boundary is divided by land, a warning is issued. No return value is provided.
        """
+        if not hasattr(self, "_warned_directions"):
+            self._warned_directions = set()
+
        for var_name in processed_fields.keys():
            if self.variable_info[var_name]["validate"]:
                location = self.variable_info[var_name]["location"]
@@ -753,16 +799,20 @@
                    wet_nans = xr.where(da.where(mask).isnull(), 1, 0)
                    # Apply label to find connected components of wet NaNs
                    labeled_array, num_features = label(wet_nans)
+
                    left_margin = labeled_array[0]
                    right_margin = labeled_array[-1]
                    if left_margin != 0:
                        num_features = num_features - 1
                    if right_margin != 0:
                        num_features = num_features - 1
-                    if num_features > 0:
+
+                    if num_features > 0 and direction not in self._warned_directions:
                        logging.warning(
-                            f"For {var_name}, the {direction}ern boundary is divided by land. It would be safer (but slower) to use `apply_2d_horizontal_fill = True`."
+                            f"The {direction}ern boundary is divided by land. "
+                            "It would be safer (but slower and more memory-intensive) to use `apply_2d_horizontal_fill = True`."
                        )
+                        self._warned_directions.add(direction)

    def _validate(self, ds):
        """Validate the dataset for NaN values at the first time step (bry_time=0) for
@@ -797,8 +847,8 @@
                elif location == "v":
                    mask = self.grid.ds.mask_v

-                for direction in ["south", "east", "north", "west"]:
-                    if self.boundaries[direction]:
+                for direction, is_enabled in self.boundaries.items():
+                    if is_enabled:
                        bdry_var_name = f"{var_name}_{direction}"

                        # Check for NaN values at the first time step using the nan_check function