roms-tools 3.1.1__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/__init__.py CHANGED
@@ -20,5 +20,9 @@ from roms_tools.setup.surface_forcing import SurfaceForcing # noqa: F401
 from roms_tools.setup.tides import TidalForcing # noqa: F401
 from roms_tools.tiling.partition import partition_netcdf # noqa: F401
 
+
 # Configure logging when the package is imported
-logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
+LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
+DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
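Package log records are now timestamped. A minimal sketch of the effect (the log message itself is illustrative):

```python
import logging

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)

# Old format: "INFO - message"
# New format: "2024-01-15 09:30:00 - INFO - message"
logging.info("message")
```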
roms_tools/plot.py CHANGED
@@ -212,7 +212,14 @@ def plot_nesting(parent_grid_ds, child_grid_ds, parent_straddle, with_dim_names=
     return fig
 
 
-def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
+def section_plot(
+    field: xr.DataArray,
+    interface_depth: xr.DataArray | None = None,
+    title: str = "",
+    yincrease: bool | None = False,
+    kwargs: dict = {},
+    ax: Axes | None = None,
+):
     """Plots a vertical section of a field with optional interface depths.
 
     Parameters
@@ -224,6 +231,11 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
         Defaults to None.
     title : str, optional
         Title of the plot. Defaults to an empty string.
+    yincrease : bool or None, optional
+        Whether to orient the y-axis with increasing values upward.
+        If True, y-values increase upward (standard).
+        If False, y-values decrease upward (inverted).
+        If None (default), behavior is equivalent to False (inverted axis).
     kwargs : dict, optional
         Additional keyword arguments to pass to `xarray.plot`. Defaults to an empty dictionary.
     ax : matplotlib.axes.Axes, optional
@@ -248,6 +260,8 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
     """
     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=(9, 5))
+    if yincrease is None:
+        yincrease = False
 
     dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
     try:
@@ -279,7 +293,7 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
     # Handle NaNs on either horizontal end
     field = field.where(~field[depth_label].isnull(), drop=True)
 
-    more_kwargs = {"x": xdim, "y": depth_label, "yincrease": False}
+    more_kwargs = {"x": xdim, "y": depth_label, "yincrease": yincrease}
 
     field.plot(**kwargs, **more_kwargs, ax=ax)
 
@@ -313,7 +327,12 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
     return fig
 
 
-def profile_plot(field, title="", ax=None):
+def profile_plot(
+    field: xr.DataArray,
+    title: str = "",
+    yincrease: bool | None = False,
+    ax: Axes | None = None,
+):
     """Plots a vertical profile of the given field against depth.
 
     This function generates a profile plot by plotting the field values against
@@ -326,6 +345,11 @@ def profile_plot(field, title="", ax=None):
         The field to plot, typically representing vertical profile data.
     title : str, optional
         Title of the plot. Defaults to an empty string.
+    yincrease : bool or None, optional
+        Whether to orient the y-axis with increasing values upward.
+        If True, y-values increase upward (standard).
+        If False, y-values decrease upward (inverted).
+        If None (default), behavior is equivalent to False (inverted axis).
     ax : matplotlib.axes.Axes, optional
         Pre-existing axes to draw the plot on. If None, a new figure and axes are created.
 
@@ -343,6 +367,9 @@ def profile_plot(field, title="", ax=None):
     -----
     - The y-axis is inverted to ensure that depth increases downward.
     """
+    if yincrease is None:
+        yincrease = False
+
     depths_to_check = [
         "layer_depth",
         "interface_depth",
@@ -360,8 +387,8 @@ def profile_plot(field, title="", ax=None):
 
     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=(4, 7))
-    kwargs = {"y": depth_label, "yincrease": False}
-    field.plot(**kwargs, linewidth=2)
+    kwargs = {"y": depth_label, "yincrease": yincrease}
+    field.plot(ax=ax, linewidth=2, **kwargs)
     ax.set_title(title)
     ax.set_ylabel("Depth [m]")
     ax.grid()
@@ -370,7 +397,12 @@ def profile_plot(field, title="", ax=None):
     return fig
 
 
-def line_plot(field, title="", ax=None):
+def line_plot(
+    field: xr.DataArray,
+    title: str = "",
+    ax: Axes | None = None,
+    yincrease: bool | None = False,
+):
     """Plots a line graph of the given field with grey vertical bars indicating NaN
     regions.
 
@@ -382,6 +414,11 @@ def line_plot(field, title="", ax=None):
         Title of the plot. Defaults to an empty string.
     ax : matplotlib.axes.Axes, optional
         Pre-existing axes to draw the plot on. If None, a new figure and axes are created.
+    yincrease : bool, optional
+        Whether to orient the y-axis with increasing values upward.
+        If True, y-values increase upward (standard).
+        If False, y-values decrease upward (inverted).
+        If None (default), behavior is equivalent to True (standard axis).
 
     Returns
     -------
@@ -399,10 +436,12 @@ def line_plot(field, title="", ax=None):
     -----
     - NaN regions are identified and marked using `axvspan` with a grey shade.
     """
+    if yincrease is None:
+        yincrease = True
     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=(7, 4))
 
-    field.plot(ax=ax, linewidth=2)
+    field.plot(ax=ax, linewidth=2, yincrease=yincrease)
 
     # Loop through the NaNs in the field and add grey vertical bars
     dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
@@ -775,6 +814,7 @@ def plot(
     depth_contours: bool = False,
     layer_contours: bool = False,
     max_nr_layer_contours: int | None = 10,
+    yincrease: bool | None = None,
     use_coarse_grid: bool = False,
     with_dim_names: bool = False,
     ax: Axes | None = None,
@@ -838,6 +878,12 @@ def plot(
     max_nr_layer_contours : int, optional
         Maximum number of vertical layer contours to draw. Default is 10.
 
+    yincrease: bool, optional
+        If True, the y-axis values increase upward (standard orientation).
+        If False, the y-axis values decrease upward (inverted axis).
+        If None (default), the orientation is determined by the default behavior
+        of the underlying plotting function.
+
     use_coarse_grid : bool, optional
         Use precomputed coarse-resolution grid. Default is False.
 
@@ -1086,14 +1132,15 @@ def plot(
             field,
             interface_depth=interface_depth,
             title=title,
+            yincrease=yincrease,
             kwargs={**kwargs, "add_colorbar": add_colorbar},
             ax=ax,
         )
     else:
         if "s_rho" in field.dims:
-            fig = profile_plot(field, title=title, ax=ax)
+            fig = profile_plot(field, title=title, yincrease=yincrease, ax=ax)
        else:
-            fig = line_plot(field, title=title, ax=ax)
+            fig = line_plot(field, title=title, ax=ax, yincrease=yincrease)
 
     if save_path:
         plt.savefig(save_path, dpi=300, bbox_inches="tight")
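All three helpers now thread a `yincrease` flag through to `xarray`'s plotting wrapper instead of hard-coding the axis orientation. A standalone sketch of the semantics on a synthetic profile (data and names here are illustrative, not taken from roms-tools):

```python
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Synthetic vertical profile: temperature decaying with depth
depth = np.linspace(0, 500, 50)
temp = xr.DataArray(
    20 * np.exp(-depth / 200),
    coords={"layer_depth": ("z", depth)},
    dims="z",
    name="temp",
)

fig, ax = plt.subplots(figsize=(4, 7))
# yincrease=False inverts the y-axis so depth increases downward,
# reproducing the previously hard-coded behavior of profile_plot.
temp.plot(ax=ax, y="layer_depth", yincrease=False, linewidth=2)
ax.set_ylabel("Depth [m]")
```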
roms_tools/regrid.py CHANGED
@@ -251,7 +251,12 @@ class VerticalRegridFromROMS:
         ds : xarray.Dataset
             The dataset containing the ROMS output data, which must include the vertical coordinate `s_rho`.
         """
-        self.grid = xgcm.Grid(ds, coords={"s_rho": {"center": "s_rho"}}, periodic=False)
+        self.grid = xgcm.Grid(
+            ds,
+            coords={"s_rho": {"center": "s_rho"}},
+            periodic=False,
+            autoparse_metadata=False,
+        )
 
     def apply(self, da, depth_coords, target_depth_levels, mask_edges=True):
         """Applies vertical regridding from ROMS to the specified target depth levels.
roms_tools/setup/boundary_forcing.py CHANGED
@@ -1,5 +1,6 @@
 import importlib.metadata
 import logging
+from collections import defaultdict
 from dataclasses import dataclass, field
 from datetime import datetime
 from pathlib import Path
@@ -12,7 +13,13 @@ from scipy.ndimage import label
 from roms_tools import Grid
 from roms_tools.plot import line_plot, section_plot
 from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
-from roms_tools.setup.datasets import CESMBGCDataset, GLORYSDataset, UnifiedBGCDataset
+from roms_tools.setup.datasets import (
+    CESMBGCDataset,
+    Dataset,
+    GLORYSDataset,
+    GLORYSDefaultDataset,
+    UnifiedBGCDataset,
+)
 from roms_tools.setup.utils import (
     add_time_info_to_ds,
     compute_barotropic_velocity,
@@ -181,8 +188,8 @@ class BoundaryForcing:
             }
         )
 
-        for direction in ["south", "east", "north", "west"]:
-            if self.boundaries[direction]:
+        for direction, is_enabled in self.boundaries.items():
+            if is_enabled:
                 bdry_target_coords = {
                     "lat": target_coords["lat"].isel(
                         **self.bdry_coords["vector"][direction]
@@ -403,7 +410,10 @@ class BoundaryForcing:
         if "name" not in self.source:
             raise ValueError("`source` must include a 'name'.")
         if "path" not in self.source:
-            raise ValueError("`source` must include a 'path'.")
+            if self.source["name"] != "GLORYS":
+                raise ValueError("`source` must include a 'path'.")
+
+            self.source["path"] = GLORYSDefaultDataset.dataset_name
 
         # Set 'climatology' to False if not provided in 'source'
         self.source = {
@@ -425,34 +435,49 @@ class BoundaryForcing:
                 "Sea surface height will NOT be used to adjust depth coordinates."
             )
 
-    def _get_data(self):
-        data_dict = {
-            "filename": self.source["path"],
-            "start_time": self.start_time,
-            "end_time": self.end_time,
-            "climatology": self.source["climatology"],
-            "use_dask": self.use_dask,
+    def _get_data(self) -> Dataset:
+        """Determine the correct `Dataset` type and return an instance.
+
+        Returns
+        -------
+        Dataset
+            The `Dataset` instance
+
+        """
+        dataset_map: dict[str, dict[str, dict[str, type[Dataset]]]] = {
+            "physics": {
+                "GLORYS": {
+                    "external": GLORYSDataset,
+                    "default": GLORYSDefaultDataset,
+                },
+            },
+            "bgc": {
+                "CESM_REGRIDDED": defaultdict(lambda: CESMBGCDataset),
+                "UNIFIED": defaultdict(lambda: UnifiedBGCDataset),
+            },
         }
 
-        if self.type == "physics":
-            if self.source["name"] == "GLORYS":
-                data = GLORYSDataset(**data_dict)
-            else:
-                raise ValueError(
-                    'Only "GLORYS" is a valid option for source["name"] when type is "physics".'
-                )
+        source_name = str(self.source["name"])
+        if source_name not in dataset_map[self.type]:
+            tpl = 'Valid options for source["name"] for type {} include: {}'
+            msg = tpl.format(self.type, " and ".join(dataset_map[self.type].keys()))
+            raise ValueError(msg)
 
-        elif self.type == "bgc":
-            if self.source["name"] == "CESM_REGRIDDED":
-                data = CESMBGCDataset(**data_dict)
-            elif self.source["name"] == "UNIFIED":
-                data = UnifiedBGCDataset(**data_dict)
-            else:
-                raise ValueError(
-                    'Only "CESM_REGRIDDED" and "UNIFIED" are valid options for source["name"] when type is "bgc".'
-                )
+        has_no_path = "path" not in self.source
+        has_default_path = self.source.get("path") == GLORYSDefaultDataset.dataset_name
+        use_default = has_no_path or has_default_path
+
+        variant = "default" if use_default else "external"
+
+        data_type = dataset_map[self.type][source_name][variant]
 
-        return data
+        return data_type(
+            filename=self.source["path"],
+            start_time=self.start_time,
+            end_time=self.end_time,
+            climatology=self.source["climatology"],
+            use_dask=self.use_dask,
+        )  # type: ignore
 
     def _set_variable_info(self, data):
         """Sets up a dictionary with metadata for variables based on the type of data
@@ -797,8 +822,8 @@ class BoundaryForcing:
         elif location == "v":
             mask = self.grid.ds.mask_v
 
-        for direction in ["south", "east", "north", "west"]:
-            if self.boundaries[direction]:
+        for direction, is_enabled in self.boundaries.items():
+            if is_enabled:
                 bdry_var_name = f"{var_name}_{direction}"
 
                 # Check for NaN values at the first time step using the nan_check function
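The table-driven `_get_data` replaces two if/elif chains: the outer keys are the forcing `type`, the inner keys the source name, and the innermost level selects the "default" (streamed) or "external" (local file) variant. BGC sources have no streamed default, so a `defaultdict` maps every variant to the same class. A toy sketch of the lookup pattern (stand-in classes, not the real datasets):

```python
from collections import defaultdict


class ExternalDS:  # stand-in for GLORYSDataset, CESMBGCDataset, ...
    pass


class DefaultDS:  # stand-in for GLORYSDefaultDataset
    pass


dataset_map = {
    "physics": {"GLORYS": {"external": ExternalDS, "default": DefaultDS}},
    # Any variant key resolves to the same class for BGC sources.
    "bgc": {"CESM_REGRIDDED": defaultdict(lambda: ExternalDS)},
}

assert dataset_map["physics"]["GLORYS"]["default"] is DefaultDS
assert dataset_map["bgc"]["CESM_REGRIDDED"]["default"] is ExternalDS
```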
@@ -1039,18 +1039,12 @@ def _map_3d_gaussian(
         # Stack 2D distribution at that vertical level
         distribution_3d[{"s_rho": vertical_idx}] = distribution_2d
     else:
-        # Compute layer thickness
-        depth_interface = compute_depth_coordinates(
-            grid.ds, zeta=0, depth_type="interface", location="rho"
-        )
-        dz = depth_interface.diff("s_w").rename({"s_w": "s_rho"})
-
         # Compute vertical Gaussian shape
         exponent = -(((depth - release.depth) / release.vsc) ** 2)
         vertical_profile = np.exp(exponent)
 
         # Apply vertical Gaussian scaling
-        distribution_3d = distribution_2d * vertical_profile * dz
+        distribution_3d = distribution_2d * vertical_profile
 
     # Normalize
     distribution_3d /= release.vsc * np.sqrt(np.pi)
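With the layer-thickness (`dz`) weighting removed, the remaining division by `release.vsc * np.sqrt(np.pi)` is consistent with the analytic integral of the unweighted Gaussian, since ∫ exp(-((z - z0)/vsc)²) dz = vsc·√π. A quick numerical check of that identity (values illustrative):

```python
import numpy as np

vsc, z0 = 50.0, 200.0  # illustrative width and center
z = np.linspace(-5000.0, 5000.0, 200_001)
profile = np.exp(-(((z - z0) / vsc) ** 2))

# Riemann-sum approximation of the integral over z
integral = profile.sum() * (z[1] - z[0])
assert np.isclose(integral, vsc * np.sqrt(np.pi), rtol=1e-6)
```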
roms_tools/setup/datasets.py CHANGED
@@ -1,9 +1,13 @@
+import importlib.util
 import logging
 import time
 from collections import Counter, defaultdict
+from collections.abc import Callable
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from pathlib import Path
+from types import ModuleType
+from typing import ClassVar
 
 import numpy as np
 import xarray as xr
@@ -25,7 +29,7 @@ from roms_tools.setup.utils import (
     interpolate_from_climatology,
     one_dim_fill,
 )
-from roms_tools.utils import _has_gcsfs, _load_data
+from roms_tools.utils import _get_pkg_error_msg, _has_gcsfs, _load_data
 
 # lat-lon datasets
 
@@ -96,17 +100,18 @@ class Dataset:
     use_dask: bool | None = False
     apply_post_processing: bool | None = True
     read_zarr: bool | None = False
+    ds_loader_fn: Callable[[], xr.Dataset] | None = None
 
     is_global: bool = field(init=False, repr=False)
     ds: xr.Dataset = field(init=False, repr=False)
 
-    def __post_init__(self):
-        """
-        Post-initialization processing:
+    def __post_init__(self) -> None:
+        """Perform post-initialization processing.
+
         1. Loads the dataset from the specified filename.
-        2. Applies time filtering based on start_time and end_time if provided.
-        3. Selects relevant fields as specified by var_names.
-        4. Ensures latitude values and depth values are in ascending order.
+        2. Applies time filtering based on start_time and end_time (if provided).
+        3. Selects relevant fields as specified by `var_names`.
+        4. Ensures latitude, longitude, and depth values are in ascending order.
         5. Checks if the dataset covers the entire globe and adjusts if necessary.
         """
         # Validate start_time and end_time
@@ -168,7 +173,11 @@ class Dataset:
            If a list of files is provided but self.dim_names["time"] is not available or use_dask=False.
         """
         ds = _load_data(
-            self.filename, self.dim_names, self.use_dask, read_zarr=self.read_zarr
+            self.filename,
+            self.dim_names,
+            self.use_dask or False,
+            read_zarr=self.read_zarr or False,
+            ds_loader_fn=self.ds_loader_fn,
         )
 
         return ds
@@ -1075,6 +1084,83 @@ class GLORYSDataset(Dataset):
         self.ds["mask_vel"] = mask_vel
 
 
+@dataclass(kw_only=True)
+class GLORYSDefaultDataset(GLORYSDataset):
+    """A GLORYS dataset that is loaded from the Copernicus Marine Data Store."""
+
+    dataset_name: ClassVar[str] = "cmems_mod_glo_phy_my_0.083deg_P1D-m"
+    """The GLORYS dataset-id for requests to the Copernicus Marine Toolkit"""
+    _tk_module: ModuleType | None = None
+    """The dynamically imported Copernicus Marine module."""
+
+    def __post_init__(self) -> None:
+        """Configure attributes to ensure use of the correct upstream data-source."""
+        self.read_zarr = True
+        self.use_dask = True
+        self.filename = self.dataset_name
+        self.ds_loader_fn = self._load_from_copernicus
+
+        super().__post_init__()
+
+    def _check_auth(self, package_name: str) -> None:
+        """Check the local credential hierarchy for auth credentials.
+
+        Raises
+        ------
+        RuntimeError
+            If auth credentials cannot be found.
+        """
+        if self._tk_module and not self._tk_module.login(check_credentials_valid=True):
+            msg = f"Authenticate with `{package_name} login` to retrieve GLORYS data."
+            raise RuntimeError(msg)
+
+    def _load_copernicus(self) -> ModuleType:
+        """Dynamically load the optional Copernicus Marine Toolkit dependency.
+
+        Raises
+        ------
+        RuntimeError
+            - If the toolkit module is not available or cannot be imported.
+            - If auth credentials cannot be found.
+        """
+        package_name = "copernicusmarine"
+        if self._tk_module:
+            self._check_auth(package_name)
+            return self._tk_module
+
+        spec = importlib.util.find_spec(package_name)
+        if not spec:
+            msg = _get_pkg_error_msg("cloud-based GLORYS data", package_name, "stream")
+            raise RuntimeError(msg)
+
+        try:
+            self._tk_module = importlib.import_module(package_name)
+        except ImportError as e:
+            msg = f"Package `{package_name}` was found but could not be loaded."
+            raise RuntimeError(msg) from e
+
+        self._check_auth(package_name)
+        return self._tk_module
+
+    def _load_from_copernicus(self) -> xr.Dataset:
+        """Load a GLORYS dataset supporting streaming.
+
+        Returns
+        -------
+        xr.Dataset
+            The streaming dataset
+        """
+        copernicusmarine = self._load_copernicus()
+        return copernicusmarine.open_dataset(
+            self.dataset_name,
+            start_datetime=self.start_time,
+            end_datetime=self.end_time,
+            service="arco-geo-series",
+            coordinates_selection_method="inside",
+            chunk_size_limit=2,
+        )
+
+
 @dataclass(kw_only=True)
 class UnifiedDataset(Dataset):
     """Represents unified BGC data on original grid.
@@ -1549,12 +1635,8 @@ class ERA5ARCODataset(ERA5Dataset):
     def __post_init__(self):
         self.read_zarr = True
         if not _has_gcsfs():
-            raise RuntimeError(
-                "To use cloud-based ERA5 data, GCSFS is required but not installed. Install it with:\n"
-                " • `pip install roms-tools[stream]` or\n"
-                " • `conda install gcsfs`\n"
-                "Alternatively, install `roms-tools` with conda to include all dependencies."
-            )
+            msg = _get_pkg_error_msg("cloud-based ERA5 data", "gcsfs", "stream")
+            raise RuntimeError(msg)
 
         super().__post_init__()
 
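Together with the `BoundaryForcing` changes above, `GLORYSDefaultDataset` makes a local GLORYS file optional: omitting `path` for `source={"name": "GLORYS"}` streams `cmems_mod_glo_phy_my_0.083deg_P1D-m` from the Copernicus Marine Data Store. A hedged usage sketch (grid parameters are illustrative; assumes `pip install roms-tools[stream]` and a prior `copernicusmarine login`):

```python
from datetime import datetime

from roms_tools import BoundaryForcing, Grid

grid = Grid(
    nx=20, ny=20, size_x=500, size_y=500,
    center_lon=-18, center_lat=61, rot=0,
)

# No "path" in source: _get_data resolves to GLORYSDefaultDataset,
# which streams the data via the Copernicus Marine Toolkit.
bf = BoundaryForcing(
    grid=grid,
    source={"name": "GLORYS"},
    type="physics",
    start_time=datetime(2012, 1, 1),
    end_time=datetime(2012, 1, 2),
    use_dask=True,
)
```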
roms_tools/setup/grid.py CHANGED
@@ -415,30 +415,57 @@ class Grid:
 
     def plot(
         self,
+        lat: float | None = None,
+        lon: float | None = None,
         with_dim_names: bool = False,
         save_path: str | None = None,
     ) -> None:
-        """Plot the grid.
+        """Plot the grid with bathymetry.
+
+        Depending on the arguments, this will either:
+        * Plot the full horizontal grid (if both `lat` and `lon` are None),
+        * Plot a zonal (east-west) vertical section at a given latitude (`lat`),
+        * Plot a meridional (south-north) vertical section at a given longitude (`lon`).
 
         Parameters
         ----------
+        lat : float, optional
+            Latitude in degrees at which to plot a vertical (zonal) section. Cannot be
+            provided together with `lon`. Default is None.
+
+        lon : float, optional
+            Longitude in degrees at which to plot a vertical (meridional) section. Cannot be
+            provided together with `lat`. Default is None.
+
         with_dim_names : bool, optional
-            Whether or not to plot the dimension names. Default is False.
+            If True and no section is requested (i.e., both `lat` and `lon` are None), annotate
+            the plot with the underlying dimension names. Default is False.
 
         save_path : str, optional
             Path to save the generated plot. If None, the plot is shown interactively.
             Default is None.
 
+        Raises
+        ------
+        ValueError
+            If both `lat` and `lon` are specified simultaneously.
+
         Returns
        -------
         None
             This method does not return any value. It generates and displays a plot.
         """
+        if lat is not None and lon is not None:
+            raise ValueError("Specify either `lat` or `lon`, not both.")
+
         field = self.ds["h"]
 
         plot(
             field=field,
             grid_ds=self.ds,
+            lat=lat,
+            lon=lon,
+            yincrease=False,
             with_dim_names=with_dim_names,
             save_path=save_path,
             cmap_name="YlGnBu",
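A short sketch of the new section options on `Grid.plot` (grid parameters illustrative):

```python
from roms_tools import Grid

grid = Grid(
    nx=50, ny=50, size_x=1000, size_y=1000,
    center_lon=-21, center_lat=61, rot=20,
)

grid.plot()           # full horizontal grid with bathymetry
grid.plot(lat=61.0)   # zonal bathymetry section at 61°N, depth increasing downward
grid.plot(lon=-21.0)  # meridional section at 21°W
# grid.plot(lat=61.0, lon=-21.0) raises ValueError: only one of the two is allowed.
```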
roms_tools/setup/surface_forcing.py CHANGED
@@ -150,12 +150,20 @@ class SurfaceForcing:
             use_coarse_grid = False
         elif self.coarse_grid_mode == "auto":
             use_coarse_grid = self._determine_coarse_grid_usage(data)
-            if use_coarse_grid:
-                logging.info("Data will be interpolated onto grid coarsened by factor 2.")
-            else:
-                logging.info("Data will be interpolated onto fine grid.")
         self.use_coarse_grid = use_coarse_grid
 
+        opt_file = "bulk_frc.opt" if self.type == "physics" else "bgc.opt"
+        grid_desc = "grid coarsened by factor 2" if use_coarse_grid else "fine grid"
+        interp_flag = 1 if use_coarse_grid else 0
+
+        logging.info(
+            "Data will be interpolated onto the %s. "
+            "Remember to set `interp_frc = %d` in your `%s` ROMS option file.",
+            grid_desc,
+            interp_flag,
+            opt_file,
+        )
+
         target_coords = get_target_coords(self.grid, self.use_coarse_grid)
         self.target_coords = target_coords
 
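The rewritten message relies on `logging`'s lazy %-style formatting: placeholders are filled by the logging machinery only when the record is actually emitted. A minimal sketch of the pattern:

```python
import logging

logging.basicConfig(level=logging.INFO)

use_coarse_grid = True
grid_desc = "grid coarsened by factor 2" if use_coarse_grid else "fine grid"

# Arguments after the format string are interpolated lazily; if INFO
# were disabled, the string would never be built.
logging.info("Data will be interpolated onto the %s.", grid_desc)
```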
@@ -1,7 +1,9 @@
 import logging
+import os
 import textwrap
 from datetime import datetime
 from pathlib import Path
+from unittest import mock
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -758,3 +760,58 @@ def test_from_yaml_missing_boundary_forcing(tmp_path, use_dask):
 
     yaml_filepath = Path(yaml_filepath)
     yaml_filepath.unlink()
+
+
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_copernicus
+def test_default_glorys_dataset_loading(tiny_grid: Grid) -> None:
+    """Verify the default GLORYS dataset is loaded when a path is not provided."""
+    start_time = datetime(2010, 2, 1)
+    end_time = datetime(2010, 3, 1)
+
+    with mock.patch.dict(
+        os.environ, {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}, clear=True
+    ):
+        bf = BoundaryForcing(
+            grid=tiny_grid,
+            source={"name": "GLORYS"},
+            type="physics",
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=True,
+            bypass_validation=True,
+        )
+
+    expected_vars = {"u_south", "v_south", "temp_south", "salt_south"}
+    assert set(bf.ds.data_vars).issuperset(expected_vars)
+
+
+@pytest.mark.parametrize(
+    "use_dask",
+    [pytest.param(True, marks=pytest.mark.use_dask), False],
+)
+def test_nondefault_glorys_dataset_loading(small_grid: Grid, use_dask: bool) -> None:
+    """Verify a non-default GLORYS dataset is loaded when a path is provided."""
+    start_time = datetime(2012, 1, 1)
+    end_time = datetime(2012, 12, 31)
+
+    local_path = Path(download_test_data("GLORYS_NA_20120101.nc"))
+
+    with mock.patch.dict(
+        os.environ, {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}, clear=True
+    ):
+        bf = BoundaryForcing(
+            grid=small_grid,
+            source={
+                "name": "GLORYS",
+                "path": local_path,
+            },
+            type="physics",
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=use_dask,
+        )
+
+    expected_vars = {"u_south", "v_south", "temp_south", "salt_south"}
+    assert set(bf.ds.data_vars).issuperset(expected_vars)