roms-tools 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (34)
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +586 -0
  3. roms_tools/{setup/download.py → download.py} +3 -0
  4. roms_tools/{setup/plot.py → plot.py} +34 -28
  5. roms_tools/setup/boundary_forcing.py +23 -12
  6. roms_tools/setup/datasets.py +2 -135
  7. roms_tools/setup/grid.py +54 -15
  8. roms_tools/setup/initial_conditions.py +105 -149
  9. roms_tools/setup/nesting.py +4 -4
  10. roms_tools/setup/river_forcing.py +7 -9
  11. roms_tools/setup/surface_forcing.py +14 -14
  12. roms_tools/setup/tides.py +24 -21
  13. roms_tools/setup/topography.py +1 -1
  14. roms_tools/setup/utils.py +19 -143
  15. roms_tools/tests/test_analysis/test_roms_output.py +269 -0
  16. roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
  17. roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
  18. roms_tools/tests/test_setup/test_datasets.py +1 -1
  19. roms_tools/tests/test_setup/test_grid.py +1 -1
  20. roms_tools/tests/test_setup/test_initial_conditions.py +8 -4
  21. roms_tools/tests/test_setup/test_river_forcing.py +1 -1
  22. roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
  23. roms_tools/tests/test_setup/test_tides.py +1 -1
  24. roms_tools/tests/test_setup/test_topography.py +1 -1
  25. roms_tools/tests/test_setup/test_utils.py +56 -1
  26. roms_tools/utils.py +301 -0
  27. roms_tools/vertical_coordinate.py +306 -0
  28. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
  29. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
  30. roms_tools/setup/vertical_coordinate.py +0 -109
  31. /roms_tools/{setup/regrid.py → regrid.py} +0 -0
  32. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
  33. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
  34. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
roms_tools/{setup/plot.py → plot.py} CHANGED
@@ -1,37 +1,35 @@
1
1
  import cartopy.crs as ccrs
2
2
  import matplotlib.pyplot as plt
3
3
  import xarray as xr
4
+ import numpy as np
4
5
 
5
6
 
6
7
  def _plot(
7
- grid_ds,
8
- field=None,
8
+ field,
9
9
  depth_contours=False,
10
- straddle=False,
11
10
  c="red",
12
11
  title="",
13
12
  with_dim_names=False,
13
+ plot_data=True,
14
14
  kwargs={},
15
15
  ):
16
16
  """Plots a grid or field on a map with optional depth contours.
17
17
 
18
18
  This function plots a map using Cartopy projections. It supports plotting a grid, a field, and adding depth contours if desired.
19
- The projection can be customized, and the grid can be adjusted for domains straddling the 180° meridian.
20
19
 
21
20
  Parameters
22
21
  ----------
23
- grid_ds : xarray.Dataset
24
- The grid dataset containing coordinates (`lon_rho`, `lat_rho`).
25
22
  field : xarray.DataArray, optional
26
23
  The field to plot. If None, only the grid is plotted.
27
24
  depth_contours : bool, optional
28
25
  If True, adds depth contours to the plot.
29
- straddle : bool, optional
30
- If True, adjusts longitude values to straddle across the 180° meridian.
31
26
  c : str, optional
32
27
  Color for the boundary plot (default is 'red').
33
28
  title : str, optional
34
29
  Title of the plot.
30
+ plot_data : bool, optional
31
+ If True, plots the provided field data on the map. If False, only the grid
32
+ boundaries and optional depth contours are plotted. Default is True.
35
33
  kwargs : dict, optional
36
34
  Additional keyword arguments to pass to `pcolormesh` (e.g., colormap or color limits).
37
35
 
@@ -40,24 +38,15 @@ def _plot(
40
38
  The function raises a `NotImplementedError` if the domain contains the North or South Pole.
41
39
  """
42
40
 
43
- if field is None:
44
- lon_deg = grid_ds["lon_rho"]
45
- lat_deg = grid_ds["lat_rho"]
46
-
47
- else:
48
-
49
- field = field.squeeze()
50
- lon_deg = field.lon
51
- lat_deg = field.lat
41
+ field = field.squeeze()
42
+ lon_deg = field.lon
43
+ lat_deg = field.lat
52
44
 
53
- # check if North or South pole are in domain
54
- if lat_deg.max().values > 89 or lat_deg.min().values < -89:
55
- raise NotImplementedError(
56
- "Plotting is not implemented for the case that the domain contains the North or South pole."
57
- )
58
-
59
- if straddle:
60
- lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
45
+ # check if North or South pole are in domain
46
+ if lat_deg.max().values > 89 or lat_deg.min().values < -89:
47
+ raise NotImplementedError(
48
+ "Plotting is not implemented for the case that the domain contains the North or South pole."
49
+ )
61
50
 
62
51
  trans = _get_projection(lon_deg, lat_deg)
63
52
 
@@ -71,7 +60,7 @@ def _plot(
71
60
  ax, lon_deg, lat_deg, trans, c, with_dim_names=with_dim_names
72
61
  )
73
62
 
74
- if field is not None:
63
+ if plot_data:
75
64
  _add_field_to_ax(ax, lon_deg, lat_deg, field, depth_contours, kwargs=kwargs)
76
65
 
77
66
  ax.coastlines(
@@ -287,7 +276,7 @@ def _section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
287
276
  field.plot(**kwargs, **more_kwargs, ax=ax)
288
277
 
289
278
  if interface_depth is not None:
290
- layer_key = "s_rho" if "s_rho" in interface_depth else "s_w"
279
+ layer_key = "s_rho" if "s_rho" in interface_depth.dims else "s_w"
291
280
 
292
281
  for i in range(len(interface_depth[layer_key])):
293
282
  ax.plot(
@@ -339,7 +328,8 @@ def _profile_plot(field, title="", ax=None):
339
328
 
340
329
 
341
330
  def _line_plot(field, title="", ax=None):
342
- """Plots a line graph of the given field.
331
+ """Plots a line graph of the given field, with grey vertical bars where NaNs are
332
+ located.
343
333
 
344
334
  Parameters
345
335
  ----------
@@ -358,6 +348,22 @@ def _line_plot(field, title="", ax=None):
358
348
  if ax is None:
359
349
  fig, ax = plt.subplots(1, 1, figsize=(7, 4))
360
350
  field.plot(ax=ax)
351
+
352
+ # Loop through the NaNs in the field and add grey vertical bars
353
+ nan_mask = np.isnan(field.values)
354
+ nan_indices = np.where(nan_mask)[0]
355
+
356
+ if len(nan_indices) > 0:
357
+ # Add grey vertical bars for each NaN region
358
+ start_idx = nan_indices[0]
359
+ for idx in range(1, len(nan_indices)):
360
+ if nan_indices[idx] != nan_indices[idx - 1] + 1:
361
+ ax.axvspan(start_idx, nan_indices[idx - 1] + 1, color="gray", alpha=0.3)
362
+ start_idx = nan_indices[idx]
363
+ # Add the last region of NaNs
364
+ ax.axvspan(start_idx, nan_indices[-1] + 1, color="gray", alpha=0.3)
365
+
366
+ # Set plot title and grid
361
367
  ax.set_title(title)
362
368
  ax.grid()
363
369
 
roms_tools/setup/boundary_forcing.py CHANGED
@@ -5,11 +5,19 @@ import logging
5
5
  import importlib.metadata
6
6
  from typing import Dict, Union, List
7
7
  from dataclasses import dataclass, field
8
- from roms_tools.setup.grid import Grid
9
- from roms_tools.setup.regrid import LateralRegrid, VerticalRegrid
10
8
  from datetime import datetime
9
+ import matplotlib.pyplot as plt
10
+ from pathlib import Path
11
+ from roms_tools import Grid
12
+ from roms_tools.regrid import LateralRegrid, VerticalRegrid
13
+ from roms_tools.vertical_coordinate import compute_depth
14
+ from roms_tools.plot import _section_plot, _line_plot
15
+ from roms_tools.utils import (
16
+ interpolate_from_rho_to_u,
17
+ interpolate_from_rho_to_v,
18
+ transpose_dimensions,
19
+ )
11
20
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
12
- from roms_tools.setup.vertical_coordinate import compute_depth
13
21
  from roms_tools.setup.utils import (
14
22
  get_variable_metadata,
15
23
  group_dataset,
@@ -17,20 +25,14 @@ from roms_tools.setup.utils import (
17
25
  get_target_coords,
18
26
  rotate_velocities,
19
27
  compute_barotropic_velocity,
20
- transpose_dimensions,
21
28
  one_dim_fill,
22
29
  nan_check,
23
30
  substitute_nans_by_fillvalue,
24
- interpolate_from_rho_to_u,
25
- interpolate_from_rho_to_v,
26
31
  convert_to_roms_time,
27
32
  get_boundary_coords,
28
33
  _to_yaml,
29
34
  _from_yaml,
30
35
  )
31
- from roms_tools.setup.plot import _section_plot, _line_plot
32
- import matplotlib.pyplot as plt
33
- from pathlib import Path
34
36
 
35
37
 
36
38
  @dataclass(frozen=True, kw_only=True)
@@ -827,11 +829,20 @@ class BoundaryForcing:
827
829
  var_name_wo_direction, direction = var_name.split("_")
828
830
  location = self.variable_info[var_name_wo_direction]["location"]
829
831
 
832
+ # Find correct mask
833
+ if location == "rho":
834
+ mask = self.grid.ds.mask_rho
835
+ elif location == "u":
836
+ mask = self.grid.ds.mask_u
837
+ elif location == "v":
838
+ mask = self.grid.ds.mask_v
839
+
840
+ mask = mask.isel(**self.bdry_coords[location][direction])
841
+
830
842
  if "s_rho" in field.dims:
831
843
  field = field.assign_coords(
832
844
  {"layer_depth": self.grid.ds[f"layer_depth_{location}_{direction}"]}
833
845
  )
834
- # chose colorbar
835
846
  if var_name.startswith(("u", "v", "ubar", "vbar", "zeta")):
836
847
  vmax = max(field.max().values, -field.min().values)
837
848
  vmin = -vmax
@@ -872,14 +883,14 @@ class BoundaryForcing:
872
883
  interface_depth = None
873
884
 
874
885
  _section_plot(
875
- field,
886
+ field.where(mask),
876
887
  interface_depth=interface_depth,
877
888
  title=title,
878
889
  kwargs=kwargs,
879
890
  ax=ax,
880
891
  )
881
892
  else:
882
- _line_plot(field, title=title, ax=ax)
893
+ _line_plot(field.where(mask), title=title, ax=ax)
883
894
 
884
895
  def save(
885
896
  self,
roms_tools/setup/datasets.py CHANGED
@@ -1,13 +1,12 @@
1
1
  import time
2
- import re
3
2
  import xarray as xr
4
3
  from dataclasses import dataclass, field
5
- import glob
6
4
  from datetime import datetime, timedelta
7
5
  import numpy as np
8
6
  from typing import Dict, Optional, Union, List
9
7
  from pathlib import Path
10
8
  import logging
9
+ from roms_tools.utils import _load_data
11
10
  from roms_tools.setup.utils import (
12
11
  assign_dates_to_climatology,
13
12
  interpolate_from_climatology,
@@ -16,7 +15,7 @@ from roms_tools.setup.utils import (
16
15
  one_dim_fill,
17
16
  gc_dist,
18
17
  )
19
- from roms_tools.setup.download import (
18
+ from roms_tools.download import (
20
19
  download_correction_data,
21
20
  download_topo,
22
21
  download_river_data,
@@ -1945,138 +1944,6 @@ class DaiRiverDataset(RiverDataset):
1945
1944
  # shared functions
1946
1945
 
1947
1946
 
1948
- def _load_data(filename, dim_names, use_dask, decode_times=True):
1949
- """Load dataset from the specified file.
1950
-
1951
- Parameters
1952
- ----------
1953
- filename : Union[str, Path, List[Union[str, Path]]]
1954
- The path to the data file(s). Can be a single string (with or without wildcards), a single Path object,
1955
- or a list of strings or Path objects containing multiple files.
1956
- dim_names: Dict[str, str], optional
1957
- Dictionary specifying the names of dimensions in the dataset.
1958
- use_dask: bool
1959
- Indicates whether to use dask for chunking. If True, data is loaded with dask; if False, data is loaded eagerly. Defaults to False.
1960
- decode_times: bool, optional
1961
- If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers.
1962
- Defaults to True.
1963
-
1964
- Returns
1965
- -------
1966
- ds : xr.Dataset
1967
- The loaded xarray Dataset containing the forcing data.
1968
-
1969
- Raises
1970
- ------
1971
- FileNotFoundError
1972
- If the specified file does not exist.
1973
- ValueError
1974
- If a list of files is provided but dim_names["time"] is not available or use_dask=False.
1975
- """
1976
-
1977
- # Precompile the regex for matching wildcard characters
1978
- wildcard_regex = re.compile(r"[\*\?\[\]]")
1979
-
1980
- # Convert Path objects to strings
1981
- if isinstance(filename, (str, Path)):
1982
- filename_str = str(filename)
1983
- elif isinstance(filename, list):
1984
- filename_str = [str(f) for f in filename]
1985
- else:
1986
- raise ValueError("filename must be a string, Path, or a list of strings/Paths.")
1987
- # Handle the case when filename is a string
1988
- contains_wildcard = False
1989
- if isinstance(filename_str, str):
1990
- contains_wildcard = bool(wildcard_regex.search(filename_str))
1991
- if contains_wildcard:
1992
- matching_files = glob.glob(filename_str)
1993
- if not matching_files:
1994
- raise FileNotFoundError(
1995
- f"No files found matching the pattern '{filename_str}'."
1996
- )
1997
- else:
1998
- matching_files = [filename_str]
1999
-
2000
- # Handle the case when filename is a list
2001
- elif isinstance(filename_str, list):
2002
- contains_wildcard = any(wildcard_regex.search(f) for f in filename_str)
2003
- if contains_wildcard:
2004
- matching_files = []
2005
- for f in filename_str:
2006
- files = glob.glob(f)
2007
- if not files:
2008
- raise FileNotFoundError(
2009
- f"No files found matching the pattern '{f}'."
2010
- )
2011
- matching_files.extend(files)
2012
- else:
2013
- matching_files = filename_str
2014
-
2015
- # Check if time dimension is available when multiple files are provided
2016
- if isinstance(filename_str, list) and "time" not in dim_names:
2017
- raise ValueError(
2018
- "A list of files is provided, but time dimension is not available. "
2019
- "A time dimension must be available to concatenate the files."
2020
- )
2021
-
2022
- # Determine the kwargs for combining datasets
2023
- if contains_wildcard or len(matching_files) == 1:
2024
- # If there is a wildcard or just one file, use by_coords
2025
- kwargs = {"combine": "by_coords"}
2026
- else:
2027
- # Otherwise, use nested combine based on time
2028
- kwargs = {"combine": "nested", "concat_dim": dim_names["time"]}
2029
-
2030
- # Base kwargs used for dataset combination
2031
- combine_kwargs = {
2032
- "coords": "minimal",
2033
- "compat": "override",
2034
- "combine_attrs": "override",
2035
- }
2036
-
2037
- if use_dask:
2038
-
2039
- chunks = {
2040
- dim_names["latitude"]: -1,
2041
- dim_names["longitude"]: -1,
2042
- }
2043
- if "depth" in dim_names:
2044
- chunks[dim_names["depth"]] = -1
2045
- if "time" in dim_names:
2046
- chunks[dim_names["time"]] = 1
2047
-
2048
- ds = xr.open_mfdataset(
2049
- matching_files,
2050
- decode_times=decode_times,
2051
- chunks=chunks,
2052
- **combine_kwargs,
2053
- **kwargs,
2054
- )
2055
-
2056
- # Rechunk the dataset along the tidal constituent dimension ("ntides") after loading
2057
- # because the original dataset does not have a chunk size of 1 along this dimension.
2058
- if "ntides" in dim_names:
2059
- ds = ds.chunk({dim_names["ntides"]: 1})
2060
-
2061
- else:
2062
- ds_list = []
2063
- for file in matching_files:
2064
- ds = xr.open_dataset(file, decode_times=decode_times, chunks=None)
2065
- ds_list.append(ds)
2066
-
2067
- if kwargs["combine"] == "by_coords":
2068
- ds = xr.combine_by_coords(ds_list, **combine_kwargs)
2069
- elif kwargs["combine"] == "nested":
2070
- ds = xr.combine_nested(
2071
- ds_list, concat_dim=kwargs["concat_dim"], **combine_kwargs
2072
- )
2073
-
2074
- if "time" in dim_names and dim_names["time"] not in ds.dims:
2075
- ds = ds.expand_dims(dim_names["time"])
2076
-
2077
- return ds
2078
-
2079
-
2080
1947
  def _check_dataset(
2081
1948
  ds: xr.Dataset,
2082
1949
  dim_names: Dict[str, str],
roms_tools/setup/grid.py CHANGED
@@ -10,14 +10,18 @@ import importlib.metadata
10
10
  from typing import Dict, Union, List
11
11
  from roms_tools.setup.topography import _add_topography
12
12
  from roms_tools.setup.mask import _add_mask, _add_velocity_masks
13
- from roms_tools.setup.plot import _plot, _section_plot
13
+ from roms_tools.vertical_coordinate import (
14
+ sigma_stretch,
15
+ compute_depth,
16
+ add_depth_coordinates_to_dataset,
17
+ )
18
+ from roms_tools.plot import _plot, _section_plot
14
19
  from roms_tools.setup.utils import (
15
20
  interpolate_from_rho_to_u,
16
21
  interpolate_from_rho_to_v,
17
22
  get_target_coords,
18
23
  gc_dist,
19
24
  )
20
- from roms_tools.setup.vertical_coordinate import sigma_stretch, compute_depth
21
25
  from roms_tools.setup.utils import extract_single_value, save_datasets
22
26
  from pathlib import Path
23
27
 
@@ -412,13 +416,16 @@ class Grid:
412
416
  This method does not return any value. It generates and displays a plot.
413
417
  """
414
418
 
419
+ field = self.ds.h.where(self.ds.mask_rho)
420
+ lat_deg = self.ds.lat_rho
421
+ lon_deg = self.ds.lon_rho
422
+ if self.straddle:
423
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
424
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
425
+
415
426
  if bathymetry:
416
427
  if title is None:
417
428
  title = "ROMS grid and bathymetry"
418
- field = self.ds.h.where(self.ds.mask_rho)
419
- field = field.assign_coords(
420
- {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
421
- )
422
429
 
423
430
  vmax = field.max().values
424
431
  vmin = field.min().values
@@ -427,9 +434,7 @@ class Grid:
427
434
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
428
435
 
429
436
  _plot(
430
- self.ds,
431
437
  field=field,
432
- straddle=self.straddle,
433
438
  title=title,
434
439
  with_dim_names=with_dim_names,
435
440
  kwargs=kwargs,
@@ -438,12 +443,44 @@ class Grid:
438
443
  if title is None:
439
444
  title = "ROMS grid"
440
445
  _plot(
441
- self.ds,
442
- straddle=self.straddle,
443
- title=title,
444
- with_dim_names=with_dim_names,
446
+ field=field, title=title, with_dim_names=with_dim_names, plot_data=False
445
447
  )
446
448
 
449
+ def compute_depth_coordinates(
450
+ self, depth_type: str, locations: list[str] = ["rho", "u", "v"]
451
+ ):
452
+ """Compute and update vertical depth coordinates.
453
+
454
+ Calculates vertical depth coordinates (layer or interface) for specified locations (e.g., rho, u, v points)
455
+ and updates them in the dataset (`self.ds`).
456
+
457
+ Parameters
458
+ ----------
459
+ depth_type : str
460
+ The type of depth coordinate to compute. Valid options:
461
+ - "layer": Compute layer depth coordinates.
462
+ - "interface": Compute interface depth coordinates.
463
+ locations : list[str], optional
464
+ Locations for which to compute depth coordinates. Default is ["rho", "u", "v"].
465
+ Valid options include:
466
+ - "rho": Depth coordinates at rho points.
467
+ - "u": Depth coordinates at u points.
468
+ - "v": Depth coordinates at v points.
469
+
470
+ Updates
471
+ -------
472
+ self.ds : xarray.Dataset
473
+ The dataset (`self.ds`) is updated with the following depth coordinate variables:
474
+ - f"{depth_type}_depth_rho": Depth coordinates at rho points.
475
+ - f"{depth_type}_depth_u": Depth coordinates at u points (if included in `locations`).
476
+ - f"{depth_type}_depth_v": Depth coordinates at v points (if included in `locations`).
477
+
478
+ Notes
479
+ -----
480
+ This method uses the `compute_and_update_depth_coordinates` function to perform calculations and updates.
481
+ """
482
+ add_depth_coordinates_to_dataset(self.ds, self.ds, depth_type, locations)
483
+
447
484
  def plot_vertical_coordinate(
448
485
  self,
449
486
  s=None,
@@ -480,7 +517,11 @@ class Grid:
480
517
  raise ValueError("Exactly one of s, eta, or xi must be specified.")
481
518
 
482
519
  h = self.ds["h"]
483
- h = h.assign_coords({"lon": self.ds.lon_rho, "lat": self.ds.lat_rho})
520
+ lat_deg = self.ds.lat_rho
521
+ lon_deg = self.ds.lon_rho
522
+ if self.straddle:
523
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
524
+ h = h.assign_coords({"lon": lon_deg, "lat": lat_deg})
484
525
 
485
526
  # slice the bathymetry as desired
486
527
  if eta is not None:
@@ -505,9 +546,7 @@ class Grid:
505
546
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
506
547
 
507
548
  _plot(
508
- self.ds,
509
549
  field=layer_depth.where(self.ds.mask_rho),
510
- straddle=self.straddle,
511
550
  depth_contours=False,
512
551
  title=title,
513
552
  kwargs=kwargs,