roms-tools 3.1.0__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/__init__.py CHANGED
@@ -20,5 +20,9 @@ from roms_tools.setup.surface_forcing import SurfaceForcing # noqa: F401
20
20
  from roms_tools.setup.tides import TidalForcing # noqa: F401
21
21
  from roms_tools.tiling.partition import partition_netcdf # noqa: F401
22
22
 
23
+
23
24
  # Configure logging when the package is imported
24
- logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
25
+ LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
26
+ DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
27
+
28
+ logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
roms_tools/constants.py CHANGED
@@ -3,3 +3,4 @@ MAXIMUM_GRID_SIZE = 25000 # in km
3
3
  UPPER_BOUND_THETA_S = 10 # upper bound for surface vertical stretching parameter
4
4
  UPPER_BOUND_THETA_B = 10 # upper bound for bottom vertical stretching parameter
5
5
  NUM_TRACERS = 34 # Number of tracers (temperature, salinity, BGC tracers)
6
+ MAX_DISTINCT_COLORS = 20 # Based on tab20 colormap
roms_tools/plot.py CHANGED
@@ -212,7 +212,14 @@ def plot_nesting(parent_grid_ds, child_grid_ds, parent_straddle, with_dim_names=
212
212
  return fig
213
213
 
214
214
 
215
- def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
215
+ def section_plot(
216
+ field: xr.DataArray,
217
+ interface_depth: xr.DataArray | None = None,
218
+ title: str = "",
219
+ yincrease: bool | None = False,
220
+ kwargs: dict = {},
221
+ ax: Axes | None = None,
222
+ ):
216
223
  """Plots a vertical section of a field with optional interface depths.
217
224
 
218
225
  Parameters
@@ -224,6 +231,11 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
224
231
  Defaults to None.
225
232
  title : str, optional
226
233
  Title of the plot. Defaults to an empty string.
234
+ yincrease : bool or None, optional
235
+ Whether to orient the y-axis with increasing values upward.
236
+ If True, y-values increase upward (standard).
237
+ If False, y-values decrease upward (inverted).
238
+ If None (default), behavior is equivalent to False (inverted axis).
227
239
  kwargs : dict, optional
228
240
  Additional keyword arguments to pass to `xarray.plot`. Defaults to an empty dictionary.
229
241
  ax : matplotlib.axes.Axes, optional
@@ -248,6 +260,8 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
248
260
  """
249
261
  if ax is None:
250
262
  fig, ax = plt.subplots(1, 1, figsize=(9, 5))
263
+ if yincrease is None:
264
+ yincrease = False
251
265
 
252
266
  dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
253
267
  try:
@@ -279,7 +293,7 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
279
293
  # Handle NaNs on either horizontal end
280
294
  field = field.where(~field[depth_label].isnull(), drop=True)
281
295
 
282
- more_kwargs = {"x": xdim, "y": depth_label, "yincrease": False}
296
+ more_kwargs = {"x": xdim, "y": depth_label, "yincrease": yincrease}
283
297
 
284
298
  field.plot(**kwargs, **more_kwargs, ax=ax)
285
299
 
@@ -313,7 +327,12 @@ def section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
313
327
  return fig
314
328
 
315
329
 
316
- def profile_plot(field, title="", ax=None):
330
+ def profile_plot(
331
+ field: xr.DataArray,
332
+ title: str = "",
333
+ yincrease: bool | None = False,
334
+ ax: Axes | None = None,
335
+ ):
317
336
  """Plots a vertical profile of the given field against depth.
318
337
 
319
338
  This function generates a profile plot by plotting the field values against
@@ -326,6 +345,11 @@ def profile_plot(field, title="", ax=None):
326
345
  The field to plot, typically representing vertical profile data.
327
346
  title : str, optional
328
347
  Title of the plot. Defaults to an empty string.
348
+ yincrease : bool or None, optional
349
+ Whether to orient the y-axis with increasing values upward.
350
+ If True, y-values increase upward (standard).
351
+ If False, y-values decrease upward (inverted).
352
+ If None (default), behavior is equivalent to False (inverted axis).
329
353
  ax : matplotlib.axes.Axes, optional
330
354
  Pre-existing axes to draw the plot on. If None, a new figure and axes are created.
331
355
 
@@ -343,6 +367,9 @@ def profile_plot(field, title="", ax=None):
343
367
  -----
344
368
  - The y-axis is inverted to ensure that depth increases downward.
345
369
  """
370
+ if yincrease is None:
371
+ yincrease = False
372
+
346
373
  depths_to_check = [
347
374
  "layer_depth",
348
375
  "interface_depth",
@@ -360,8 +387,8 @@ def profile_plot(field, title="", ax=None):
360
387
 
361
388
  if ax is None:
362
389
  fig, ax = plt.subplots(1, 1, figsize=(4, 7))
363
- kwargs = {"y": depth_label, "yincrease": False}
364
- field.plot(**kwargs, linewidth=2)
390
+ kwargs = {"y": depth_label, "yincrease": yincrease}
391
+ field.plot(ax=ax, linewidth=2, **kwargs)
365
392
  ax.set_title(title)
366
393
  ax.set_ylabel("Depth [m]")
367
394
  ax.grid()
@@ -370,7 +397,12 @@ def profile_plot(field, title="", ax=None):
370
397
  return fig
371
398
 
372
399
 
373
- def line_plot(field, title="", ax=None):
400
+ def line_plot(
401
+ field: xr.DataArray,
402
+ title: str = "",
403
+ ax: Axes | None = None,
404
+ yincrease: bool | None = False,
405
+ ):
374
406
  """Plots a line graph of the given field with grey vertical bars indicating NaN
375
407
  regions.
376
408
 
@@ -382,6 +414,11 @@ def line_plot(field, title="", ax=None):
382
414
  Title of the plot. Defaults to an empty string.
383
415
  ax : matplotlib.axes.Axes, optional
384
416
  Pre-existing axes to draw the plot on. If None, a new figure and axes are created.
417
+ yincrease : bool, optional
418
+ Whether to orient the y-axis with increasing values upward.
419
+ If True, y-values increase upward (standard).
420
+ If False, y-values decrease upward (inverted).
421
+ If None (default), behavior is equivalent to True (standard axis).
385
422
 
386
423
  Returns
387
424
  -------
@@ -399,10 +436,12 @@ def line_plot(field, title="", ax=None):
399
436
  -----
400
437
  - NaN regions are identified and marked using `axvspan` with a grey shade.
401
438
  """
439
+ if yincrease is None:
440
+ yincrease = True
402
441
  if ax is None:
403
442
  fig, ax = plt.subplots(1, 1, figsize=(7, 4))
404
443
 
405
- field.plot(ax=ax, linewidth=2)
444
+ field.plot(ax=ax, linewidth=2, yincrease=yincrease)
406
445
 
407
446
  # Loop through the NaNs in the field and add grey vertical bars
408
447
  dims_to_check = ["eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon"]
@@ -775,6 +814,7 @@ def plot(
775
814
  depth_contours: bool = False,
776
815
  layer_contours: bool = False,
777
816
  max_nr_layer_contours: int | None = 10,
817
+ yincrease: bool | None = None,
778
818
  use_coarse_grid: bool = False,
779
819
  with_dim_names: bool = False,
780
820
  ax: Axes | None = None,
@@ -838,6 +878,12 @@ def plot(
838
878
  max_nr_layer_contours : int, optional
839
879
  Maximum number of vertical layer contours to draw. Default is 10.
840
880
 
881
+ yincrease: bool, optional
882
+ If True, the y-axis values increase upward (standard orientation).
883
+ If False, the y-axis values decrease upward (inverted axis).
884
+ If None (default), the orientation is determined by the default behavior
885
+ of the underlying plotting function.
886
+
841
887
  use_coarse_grid : bool, optional
842
888
  Use precomputed coarse-resolution grid. Default is False.
843
889
 
@@ -1086,14 +1132,123 @@ def plot(
1086
1132
  field,
1087
1133
  interface_depth=interface_depth,
1088
1134
  title=title,
1135
+ yincrease=yincrease,
1089
1136
  kwargs={**kwargs, "add_colorbar": add_colorbar},
1090
1137
  ax=ax,
1091
1138
  )
1092
1139
  else:
1093
1140
  if "s_rho" in field.dims:
1094
- fig = profile_plot(field, title=title, ax=ax)
1141
+ fig = profile_plot(field, title=title, yincrease=yincrease, ax=ax)
1095
1142
  else:
1096
- fig = line_plot(field, title=title, ax=ax)
1143
+ fig = line_plot(field, title=title, ax=ax, yincrease=yincrease)
1097
1144
 
1098
1145
  if save_path:
1099
1146
  plt.savefig(save_path, dpi=300, bbox_inches="tight")
1147
+
1148
+
1149
def assign_category_colors(names: list[str]) -> dict[str, tuple]:
    """
    Assign a distinct color to each name using a Matplotlib categorical colormap.

    Parameters
    ----------
    names : list[str]
        List of category names (e.g., releases, rivers, etc.) to assign colors to.

    Returns
    -------
    dict[str, tuple]
        Dictionary mapping each name to a unique RGBA color, in the order the
        names appear in ``names``.

    Raises
    ------
    ValueError
        If the number of names exceeds the selected colormap's capacity.

    Notes
    -----
    Colormap selection is based on the number of items:
    - <= 10: 'tab10'
    - <= 20: 'tab20'
    - > 20 : 'tab20b'

    'tab20b' also provides only 20 distinct entries, so more than 20 names
    always results in a ``ValueError``.
    """
    n = len(names)

    if n <= 10:
        cmap = plt.get_cmap("tab10")
    elif n <= 20:
        cmap = plt.get_cmap("tab20")
    else:
        # Always bind `cmap` so the capacity check below raises the documented
        # ValueError (rather than an UnboundLocalError) for n > 20.
        cmap = plt.get_cmap("tab20b")

    if n > cmap.N:
        raise ValueError(
            f"Too many categories ({n}) for selected colormap ({cmap.name}) "
            f"which supports only {cmap.N} distinct entries."
        )

    # Integer inputs to a Colormap index its discrete entries directly.
    return {name: cmap(i) for i, name in enumerate(names)}
1189
+
1190
+
1191
def plot_location(
    grid_ds: xr.Dataset,
    points: dict[str, dict],
    ax: Axes,
    include_legend: bool = True,
) -> None:
    """Draw named geographic points as markers on a map-view axis.

    Every entry of ``points`` becomes one "x" marker, labelled with its key.
    The map projection is derived from the grid's rho-point coordinates, and
    each point's (lon, lat) is transformed from PlateCarree into it.

    Parameters
    ----------
    grid_ds : xr.Dataset
        Grid dataset providing ``lon_rho``, ``lat_rho`` and a ``straddle``
        attribute (the string ``"True"`` enables longitude wrapping).
    points : dict[str, dict]
        Mapping of point name to a dict holding ``"lon"`` and ``"lat"`` in
        degrees, plus an optional matplotlib ``"color"`` (default black).
    ax : matplotlib.axes.Axes
        Axis the markers are drawn on.
    include_legend : bool, default True
        If True, add a legend to the right of the axis listing point names.

    Returns
    -------
    None

    Raises
    ------
    AttributeError
        If ``grid_ds`` has no ``straddle`` attribute.
    """
    grid_lon = grid_ds.lon_rho
    grid_lat = grid_ds.lat_rho

    if "straddle" not in grid_ds.attrs:
        raise AttributeError("Grid dataset must have a 'straddle' attribute.")

    if grid_ds.attrs["straddle"] == "True":
        # Shift longitudes above 180 down by 360 before building the projection.
        grid_lon = xr.where(grid_lon > 180, grid_lon - 360, grid_lon)

    map_crs = get_projection(grid_lon, grid_lat)
    source_crs = ccrs.PlateCarree()

    marker_style = {"marker": "x", "markersize": 8, "markeredgewidth": 2}
    for point_name, spec in points.items():
        px, py = map_crs.transform_point(spec["lon"], spec["lat"], source_crs)
        ax.plot(
            px,
            py,
            **marker_style,
            label=point_name,
            color=spec.get("color", "k"),
        )

    if include_legend:
        ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
roms_tools/regrid.py CHANGED
@@ -251,7 +251,12 @@ class VerticalRegridFromROMS:
251
251
  ds : xarray.Dataset
252
252
  The dataset containing the ROMS output data, which must include the vertical coordinate `s_rho`.
253
253
  """
254
- self.grid = xgcm.Grid(ds, coords={"s_rho": {"center": "s_rho"}}, periodic=False)
254
+ self.grid = xgcm.Grid(
255
+ ds,
256
+ coords={"s_rho": {"center": "s_rho"}},
257
+ periodic=False,
258
+ autoparse_metadata=False,
259
+ )
255
260
 
256
261
  def apply(self, da, depth_coords, target_depth_levels, mask_edges=True):
257
262
  """Applies vertical regridding from ROMS to the specified target depth levels.
@@ -1,5 +1,6 @@
1
1
  import importlib.metadata
2
2
  import logging
3
+ from collections import defaultdict
3
4
  from dataclasses import dataclass, field
4
5
  from datetime import datetime
5
6
  from pathlib import Path
@@ -12,7 +13,13 @@ from scipy.ndimage import label
12
13
  from roms_tools import Grid
13
14
  from roms_tools.plot import line_plot, section_plot
14
15
  from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
15
- from roms_tools.setup.datasets import CESMBGCDataset, GLORYSDataset, UnifiedBGCDataset
16
+ from roms_tools.setup.datasets import (
17
+ CESMBGCDataset,
18
+ Dataset,
19
+ GLORYSDataset,
20
+ GLORYSDefaultDataset,
21
+ UnifiedBGCDataset,
22
+ )
16
23
  from roms_tools.setup.utils import (
17
24
  add_time_info_to_ds,
18
25
  compute_barotropic_velocity,
@@ -181,8 +188,8 @@ class BoundaryForcing:
181
188
  }
182
189
  )
183
190
 
184
- for direction in ["south", "east", "north", "west"]:
185
- if self.boundaries[direction]:
191
+ for direction, is_enabled in self.boundaries.items():
192
+ if is_enabled:
186
193
  bdry_target_coords = {
187
194
  "lat": target_coords["lat"].isel(
188
195
  **self.bdry_coords["vector"][direction]
@@ -403,7 +410,10 @@ class BoundaryForcing:
403
410
  if "name" not in self.source:
404
411
  raise ValueError("`source` must include a 'name'.")
405
412
  if "path" not in self.source:
406
- raise ValueError("`source` must include a 'path'.")
413
+ if self.source["name"] != "GLORYS":
414
+ raise ValueError("`source` must include a 'path'.")
415
+
416
+ self.source["path"] = GLORYSDefaultDataset.dataset_name
407
417
 
408
418
  # Set 'climatology' to False if not provided in 'source'
409
419
  self.source = {
@@ -425,34 +435,49 @@ class BoundaryForcing:
425
435
  "Sea surface height will NOT be used to adjust depth coordinates."
426
436
  )
427
437
 
428
def _get_data(self) -> Dataset:
    """Instantiate the `Dataset` subclass matching ``self.type`` and ``self.source``.

    The subclass is looked up by ``self.type`` and then by
    ``self.source["name"]``. Within that entry a variant is picked:
    ``"default"`` when ``self.source`` has no ``"path"`` key or its path equals
    ``GLORYSDefaultDataset.dataset_name``, and ``"external"`` otherwise. The BGC
    sources resolve every variant to the same class.

    Returns
    -------
    Dataset
        The `Dataset` instance, constructed from ``self.source["path"]``,
        ``self.start_time``, ``self.end_time``, ``self.source["climatology"]``
        and ``self.use_dask``.

    Raises
    ------
    ValueError
        If ``self.source["name"]`` is not a valid option for ``self.type``.
    """
    registry: dict[str, dict[str, dict[str, type[Dataset]]]] = {
        "physics": {
            "GLORYS": {
                "external": GLORYSDataset,
                "default": GLORYSDefaultDataset,
            },
        },
        "bgc": {
            # One class per BGC source, whatever variant is requested.
            "CESM_REGRIDDED": defaultdict(lambda: CESMBGCDataset),
            "UNIFIED": defaultdict(lambda: UnifiedBGCDataset),
        },
    }

    name = str(self.source["name"])
    candidates = registry[self.type]
    if name not in candidates:
        raise ValueError(
            'Valid options for source["name"] for type {} include: {}'.format(
                self.type, " and ".join(candidates.keys())
            )
        )

    variant = "external"
    if (
        "path" not in self.source
        or self.source.get("path") == GLORYSDefaultDataset.dataset_name
    ):
        variant = "default"

    dataset_cls = candidates[name][variant]
    init_kwargs = {
        "filename": self.source["path"],
        "start_time": self.start_time,
        "end_time": self.end_time,
        "climatology": self.source["climatology"],
        "use_dask": self.use_dask,
    }
    return dataset_cls(**init_kwargs)  # type: ignore
456
481
 
457
482
  def _set_variable_info(self, data):
458
483
  """Sets up a dictionary with metadata for variables based on the type of data
@@ -797,8 +822,8 @@ class BoundaryForcing:
797
822
  elif location == "v":
798
823
  mask = self.grid.ds.mask_v
799
824
 
800
- for direction in ["south", "east", "north", "west"]:
801
- if self.boundaries[direction]:
825
+ for direction, is_enabled in self.boundaries.items():
826
+ if is_enabled:
802
827
  bdry_var_name = f"{var_name}_{direction}"
803
828
 
804
829
  # Check for NaN values at the first time step using the nan_check function