roms-tools 2.2.1__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roms_tools/__init__.py +1 -0
- roms_tools/analysis/roms_output.py +586 -0
- roms_tools/{setup/download.py → download.py} +3 -0
- roms_tools/{setup/plot.py → plot.py} +34 -28
- roms_tools/setup/boundary_forcing.py +23 -12
- roms_tools/setup/datasets.py +2 -135
- roms_tools/setup/grid.py +54 -15
- roms_tools/setup/initial_conditions.py +105 -149
- roms_tools/setup/nesting.py +4 -4
- roms_tools/setup/river_forcing.py +7 -9
- roms_tools/setup/surface_forcing.py +14 -14
- roms_tools/setup/tides.py +24 -21
- roms_tools/setup/topography.py +1 -1
- roms_tools/setup/utils.py +20 -154
- roms_tools/tests/test_analysis/test_roms_output.py +269 -0
- roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
- roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
- roms_tools/tests/test_setup/test_datasets.py +1 -1
- roms_tools/tests/test_setup/test_grid.py +1 -1
- roms_tools/tests/test_setup/test_initial_conditions.py +1 -1
- roms_tools/tests/test_setup/test_river_forcing.py +1 -1
- roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
- roms_tools/tests/test_setup/test_tides.py +1 -1
- roms_tools/tests/test_setup/test_topography.py +1 -1
- roms_tools/tests/test_setup/test_utils.py +56 -1
- roms_tools/utils.py +301 -0
- roms_tools/vertical_coordinate.py +306 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
- roms_tools/setup/vertical_coordinate.py +0 -109
- /roms_tools/{setup/regrid.py → regrid.py} +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
roms_tools/{setup/plot.py → plot.py}
RENAMED
@@ -1,37 +1,35 @@
 import cartopy.crs as ccrs
 import matplotlib.pyplot as plt
 import xarray as xr
+import numpy as np


 def _plot(
-    grid_ds,
-    field=None,
+    field,
     depth_contours=False,
-    straddle=False,
     c="red",
     title="",
     with_dim_names=False,
+    plot_data=True,
     kwargs={},
 ):
     """Plots a grid or field on a map with optional depth contours.

     This function plots a map using Cartopy projections. It supports plotting a grid, a field, and adding depth contours if desired.
-    The projection can be customized, and the grid can be adjusted for domains straddling the 180° meridian.

     Parameters
     ----------
-    grid_ds : xarray.Dataset
-        The grid dataset containing coordinates (`lon_rho`, `lat_rho`).
     field : xarray.DataArray, optional
         The field to plot. If None, only the grid is plotted.
     depth_contours : bool, optional
         If True, adds depth contours to the plot.
-    straddle : bool, optional
-        If True, adjusts longitude values to straddle across the 180° meridian.
     c : str, optional
         Color for the boundary plot (default is 'red').
     title : str, optional
         Title of the plot.
+    plot_data : bool, optional
+        If True, plots the provided field data on the map. If False, only the grid
+        boundaries and optional depth contours are plotted. Default is True.
     kwargs : dict, optional
         Additional keyword arguments to pass to `pcolormesh` (e.g., colormap or color limits).

@@ -40,24 +38,15 @@ def _plot(
     The function raises a `NotImplementedError` if the domain contains the North or South Pole.
     """

-
-
-
-
-    else:
-
-        field = field.squeeze()
-        lon_deg = field.lon
-        lat_deg = field.lat
+    field = field.squeeze()
+    lon_deg = field.lon
+    lat_deg = field.lat

-
-
-
-
-
-
-    if straddle:
-        lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
+    # check if North or South pole are in domain
+    if lat_deg.max().values > 89 or lat_deg.min().values < -89:
+        raise NotImplementedError(
+            "Plotting is not implemented for the case that the domain contains the North or South pole."
+        )

     trans = _get_projection(lon_deg, lat_deg)

@@ -71,7 +60,7 @@ def _plot(
         ax, lon_deg, lat_deg, trans, c, with_dim_names=with_dim_names
     )

-    if field is not None:
+    if plot_data:
         _add_field_to_ax(ax, lon_deg, lat_deg, field, depth_contours, kwargs=kwargs)

     ax.coastlines(
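Note on the refactor above: `_plot` no longer receives the grid dataset or a `straddle` flag. Callers now attach `lon`/`lat` coordinates to the field themselves and use `plot_data` to control whether the field values are drawn at all. A minimal sketch of the new caller-side pattern, assuming `grid` is an already constructed `roms_tools.Grid` (the variable names and colormap are illustrative):

    import xarray as xr
    from roms_tools.plot import _plot

    # bathymetry masked to wet points, with lon/lat attached by the caller
    field = grid.ds.h.where(grid.ds.mask_rho)
    lon = grid.ds.lon_rho
    if grid.straddle:
        # shift longitudes into [-180, 180] so a domain crossing 180° stays contiguous
        lon = xr.where(lon > 180, lon - 360, lon)
    field = field.assign_coords({"lon": lon, "lat": grid.ds.lat_rho})

    _plot(field=field, title="ROMS grid and bathymetry", kwargs={"cmap": "YlGnBu"})
    _plot(field=field, title="ROMS grid", plot_data=False)  # outline only, no field data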
@@ -287,7 +276,7 @@ def _section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
     field.plot(**kwargs, **more_kwargs, ax=ax)

     if interface_depth is not None:
-        layer_key = "s_rho" if "s_rho" in interface_depth else "s_w"
+        layer_key = "s_rho" if "s_rho" in interface_depth.dims else "s_w"

         for i in range(len(interface_depth[layer_key])):
             ax.plot(
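The one-character fix above matters because, for an xarray.DataArray, `"s_rho" in interface_depth` tests membership in the array's values rather than its dimension names, so the old expression could not reliably detect the vertical coordinate; `interface_depth.dims` can. A small self-contained illustration (the toy array is not from the package):

    import numpy as np
    import xarray as xr

    interface_depth = xr.DataArray(
        np.zeros((4, 3)), dims=("s_w", "eta_rho"), name="interface_depth"
    )

    print(interface_depth.dims)  # ('s_w', 'eta_rho')
    layer_key = "s_rho" if "s_rho" in interface_depth.dims else "s_w"
    print(layer_key)  # 's_w', since this array holds interface (w-level) depths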
@@ -339,7 +328,8 @@ def _profile_plot(field, title="", ax=None):


 def _line_plot(field, title="", ax=None):
-    """Plots a line graph of the given field
+    """Plots a line graph of the given field, with grey vertical bars where NaNs are
+    located.

     Parameters
     ----------
@@ -358,6 +348,22 @@ def _line_plot(field, title="", ax=None):
     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=(7, 4))
     field.plot(ax=ax)
+
+    # Loop through the NaNs in the field and add grey vertical bars
+    nan_mask = np.isnan(field.values)
+    nan_indices = np.where(nan_mask)[0]
+
+    if len(nan_indices) > 0:
+        # Add grey vertical bars for each NaN region
+        start_idx = nan_indices[0]
+        for idx in range(1, len(nan_indices)):
+            if nan_indices[idx] != nan_indices[idx - 1] + 1:
+                ax.axvspan(start_idx, nan_indices[idx - 1] + 1, color="gray", alpha=0.3)
+                start_idx = nan_indices[idx]
+        # Add the last region of NaNs
+        ax.axvspan(start_idx, nan_indices[-1] + 1, color="gray", alpha=0.3)
+
+    # Set plot title and grid
     ax.set_title(title)
     ax.grid()

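The block added to `_line_plot` groups consecutive NaN indices into regions and shades each region with `axvspan`. The same logic as a runnable standalone sketch (the sample data is made up):

    import matplotlib.pyplot as plt
    import numpy as np
    import xarray as xr

    values = np.sin(np.linspace(0, 4 * np.pi, 50))
    values[10:14] = np.nan  # a gap of four missing points
    values[30] = np.nan     # an isolated missing point
    field = xr.DataArray(values, dims="time", name="zeta")

    fig, ax = plt.subplots(1, 1, figsize=(7, 4))
    field.plot(ax=ax)

    nan_indices = np.where(np.isnan(field.values))[0]
    if len(nan_indices) > 0:
        start_idx = nan_indices[0]
        for idx in range(1, len(nan_indices)):
            if nan_indices[idx] != nan_indices[idx - 1] + 1:
                # the previous run of NaNs ended; shade it and start a new run
                ax.axvspan(start_idx, nan_indices[idx - 1] + 1, color="gray", alpha=0.3)
                start_idx = nan_indices[idx]
        ax.axvspan(start_idx, nan_indices[-1] + 1, color="gray", alpha=0.3)

    ax.set_title("line plot with NaN regions shaded")
    ax.grid()
    plt.show()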
roms_tools/setup/boundary_forcing.py
CHANGED
@@ -5,11 +5,19 @@ import logging
 import importlib.metadata
 from typing import Dict, Union, List
 from dataclasses import dataclass, field
-from roms_tools.setup.grid import Grid
-from roms_tools.setup.regrid import LateralRegrid, VerticalRegrid
 from datetime import datetime
+import matplotlib.pyplot as plt
+from pathlib import Path
+from roms_tools import Grid
+from roms_tools.regrid import LateralRegrid, VerticalRegrid
+from roms_tools.vertical_coordinate import compute_depth
+from roms_tools.plot import _section_plot, _line_plot
+from roms_tools.utils import (
+    interpolate_from_rho_to_u,
+    interpolate_from_rho_to_v,
+    transpose_dimensions,
+)
 from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
-from roms_tools.setup.vertical_coordinate import compute_depth
 from roms_tools.setup.utils import (
     get_variable_metadata,
     group_dataset,
@@ -17,20 +25,14 @@ from roms_tools.setup.utils import (
     get_target_coords,
     rotate_velocities,
     compute_barotropic_velocity,
-    transpose_dimensions,
     one_dim_fill,
     nan_check,
     substitute_nans_by_fillvalue,
-    interpolate_from_rho_to_u,
-    interpolate_from_rho_to_v,
     convert_to_roms_time,
     get_boundary_coords,
     _to_yaml,
     _from_yaml,
 )
-from roms_tools.setup.plot import _section_plot, _line_plot
-import matplotlib.pyplot as plt
-from pathlib import Path


 @dataclass(frozen=True, kw_only=True)
@@ -827,11 +829,20 @@ class BoundaryForcing:
         var_name_wo_direction, direction = var_name.split("_")
         location = self.variable_info[var_name_wo_direction]["location"]

+        # Find correct mask
+        if location == "rho":
+            mask = self.grid.ds.mask_rho
+        elif location == "u":
+            mask = self.grid.ds.mask_u
+        elif location == "v":
+            mask = self.grid.ds.mask_v
+
+        mask = mask.isel(**self.bdry_coords[location][direction])
+
         if "s_rho" in field.dims:
             field = field.assign_coords(
                 {"layer_depth": self.grid.ds[f"layer_depth_{location}_{direction}"]}
             )
-        # chose colorbar
         if var_name.startswith(("u", "v", "ubar", "vbar", "zeta")):
             vmax = max(field.max().values, -field.min().values)
             vmin = -vmax
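The mask assembled above is applied in the next hunk via `field.where(mask)`, so land points are blanked in boundary plots. The lookup can be written compactly as below; `boundary_mask` and the example indexer are illustrative helpers, not part of the package API:

    import xarray as xr

    def boundary_mask(grid_ds: xr.Dataset, bdry_coords: dict, location: str, direction: str) -> xr.DataArray:
        # pick the mask on the same staggered grid as the plotted variable
        mask = {"rho": grid_ds.mask_rho, "u": grid_ds.mask_u, "v": grid_ds.mask_v}[location]
        # restrict it to the requested open boundary, e.g. an indexer like {"eta_rho": 0}
        return mask.isel(**bdry_coords[location][direction])

    # sketch of the call site inside BoundaryForcing.plot:
    # mask = boundary_mask(self.grid.ds, self.bdry_coords, location, direction)
    # _section_plot(field.where(mask), ...)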
@@ -872,14 +883,14 @@ class BoundaryForcing:
                 interface_depth = None

             _section_plot(
-                field,
+                field.where(mask),
                 interface_depth=interface_depth,
                 title=title,
                 kwargs=kwargs,
                 ax=ax,
             )
         else:
-            _line_plot(field, title=title, ax=ax)
+            _line_plot(field.where(mask), title=title, ax=ax)

     def save(
         self,
roms_tools/setup/datasets.py
CHANGED
@@ -1,13 +1,12 @@
 import time
-import re
 import xarray as xr
 from dataclasses import dataclass, field
-import glob
 from datetime import datetime, timedelta
 import numpy as np
 from typing import Dict, Optional, Union, List
 from pathlib import Path
 import logging
+from roms_tools.utils import _load_data
 from roms_tools.setup.utils import (
     assign_dates_to_climatology,
     interpolate_from_climatology,
@@ -16,7 +15,7 @@ from roms_tools.setup.utils import (
     one_dim_fill,
     gc_dist,
 )
-from roms_tools.setup.download import (
+from roms_tools.download import (
     download_correction_data,
     download_topo,
     download_river_data,
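The hunk below removes `_load_data` from this module; it was not deleted but moved to the new `roms_tools/utils.py` (the +301-line file in the summary), which is also why the `re` and `glob` imports disappear above. Its core behaviour, expanding wildcards and combining the matches with xarray, boils down to the following sketch; the name `open_forcing_files` and the simplified chunking are illustrative only:

    import glob
    import xarray as xr

    def open_forcing_files(pattern, time_dim="time", use_dask=True):
        matching_files = sorted(glob.glob(str(pattern)))
        if not matching_files:
            raise FileNotFoundError(f"No files found matching the pattern '{pattern}'.")

        combine_kwargs = {"coords": "minimal", "compat": "override", "combine_attrs": "override"}

        if use_dask:
            # lazy open: one chunk per time step, whole horizontal slabs per chunk
            return xr.open_mfdataset(
                matching_files, combine="by_coords", chunks={time_dim: 1}, **combine_kwargs
            )

        # eager fallback: open each file and align on coordinates
        ds_list = [xr.open_dataset(f, chunks=None) for f in matching_files]
        return xr.combine_by_coords(ds_list, **combine_kwargs)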
@@ -1945,138 +1944,6 @@ class DaiRiverDataset(RiverDataset):
 # shared functions


-def _load_data(filename, dim_names, use_dask, decode_times=True):
-    """Load dataset from the specified file.
-
-    Parameters
-    ----------
-    filename : Union[str, Path, List[Union[str, Path]]]
-        The path to the data file(s). Can be a single string (with or without wildcards), a single Path object,
-        or a list of strings or Path objects containing multiple files.
-    dim_names: Dict[str, str], optional
-        Dictionary specifying the names of dimensions in the dataset.
-    use_dask: bool
-        Indicates whether to use dask for chunking. If True, data is loaded with dask; if False, data is loaded eagerly. Defaults to False.
-    decode_times: bool, optional
-        If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers.
-        Defaults to True.
-
-    Returns
-    -------
-    ds : xr.Dataset
-        The loaded xarray Dataset containing the forcing data.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the specified file does not exist.
-    ValueError
-        If a list of files is provided but dim_names["time"] is not available or use_dask=False.
-    """
-
-    # Precompile the regex for matching wildcard characters
-    wildcard_regex = re.compile(r"[\*\?\[\]]")
-
-    # Convert Path objects to strings
-    if isinstance(filename, (str, Path)):
-        filename_str = str(filename)
-    elif isinstance(filename, list):
-        filename_str = [str(f) for f in filename]
-    else:
-        raise ValueError("filename must be a string, Path, or a list of strings/Paths.")
-    # Handle the case when filename is a string
-    contains_wildcard = False
-    if isinstance(filename_str, str):
-        contains_wildcard = bool(wildcard_regex.search(filename_str))
-        if contains_wildcard:
-            matching_files = glob.glob(filename_str)
-            if not matching_files:
-                raise FileNotFoundError(
-                    f"No files found matching the pattern '{filename_str}'."
-                )
-        else:
-            matching_files = [filename_str]
-
-    # Handle the case when filename is a list
-    elif isinstance(filename_str, list):
-        contains_wildcard = any(wildcard_regex.search(f) for f in filename_str)
-        if contains_wildcard:
-            matching_files = []
-            for f in filename_str:
-                files = glob.glob(f)
-                if not files:
-                    raise FileNotFoundError(
-                        f"No files found matching the pattern '{f}'."
-                    )
-                matching_files.extend(files)
-        else:
-            matching_files = filename_str
-
-    # Check if time dimension is available when multiple files are provided
-    if isinstance(filename_str, list) and "time" not in dim_names:
-        raise ValueError(
-            "A list of files is provided, but time dimension is not available. "
-            "A time dimension must be available to concatenate the files."
-        )
-
-    # Determine the kwargs for combining datasets
-    if contains_wildcard or len(matching_files) == 1:
-        # If there is a wildcard or just one file, use by_coords
-        kwargs = {"combine": "by_coords"}
-    else:
-        # Otherwise, use nested combine based on time
-        kwargs = {"combine": "nested", "concat_dim": dim_names["time"]}
-
-    # Base kwargs used for dataset combination
-    combine_kwargs = {
-        "coords": "minimal",
-        "compat": "override",
-        "combine_attrs": "override",
-    }
-
-    if use_dask:
-
-        chunks = {
-            dim_names["latitude"]: -1,
-            dim_names["longitude"]: -1,
-        }
-        if "depth" in dim_names:
-            chunks[dim_names["depth"]] = -1
-        if "time" in dim_names:
-            chunks[dim_names["time"]] = 1
-
-        ds = xr.open_mfdataset(
-            matching_files,
-            decode_times=decode_times,
-            chunks=chunks,
-            **combine_kwargs,
-            **kwargs,
-        )
-
-        # Rechunk the dataset along the tidal constituent dimension ("ntides") after loading
-        # because the original dataset does not have a chunk size of 1 along this dimension.
-        if "ntides" in dim_names:
-            ds = ds.chunk({dim_names["ntides"]: 1})
-
-    else:
-        ds_list = []
-        for file in matching_files:
-            ds = xr.open_dataset(file, decode_times=decode_times, chunks=None)
-            ds_list.append(ds)
-
-        if kwargs["combine"] == "by_coords":
-            ds = xr.combine_by_coords(ds_list, **combine_kwargs)
-        elif kwargs["combine"] == "nested":
-            ds = xr.combine_nested(
-                ds_list, concat_dim=kwargs["concat_dim"], **combine_kwargs
-            )
-
-    if "time" in dim_names and dim_names["time"] not in ds.dims:
-        ds = ds.expand_dims(dim_names["time"])
-
-    return ds
-
-
 def _check_dataset(
     ds: xr.Dataset,
     dim_names: Dict[str, str],
roms_tools/setup/grid.py
CHANGED
@@ -10,14 +10,18 @@ import importlib.metadata
 from typing import Dict, Union, List
 from roms_tools.setup.topography import _add_topography
 from roms_tools.setup.mask import _add_mask, _add_velocity_masks
-from roms_tools.setup.plot import _plot, _section_plot
+from roms_tools.vertical_coordinate import (
+    sigma_stretch,
+    compute_depth,
+    add_depth_coordinates_to_dataset,
+)
+from roms_tools.plot import _plot, _section_plot
 from roms_tools.setup.utils import (
     interpolate_from_rho_to_u,
     interpolate_from_rho_to_v,
     get_target_coords,
     gc_dist,
 )
-from roms_tools.setup.vertical_coordinate import sigma_stretch, compute_depth
 from roms_tools.setup.utils import extract_single_value, save_datasets
 from pathlib import Path

@@ -412,13 +416,16 @@ class Grid:
         This method does not return any value. It generates and displays a plot.
         """

+        field = self.ds.h.where(self.ds.mask_rho)
+        lat_deg = self.ds.lat_rho
+        lon_deg = self.ds.lon_rho
+        if self.straddle:
+            lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
+        field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
+
         if bathymetry:
             if title is None:
                 title = "ROMS grid and bathymetry"
-            field = self.ds.h.where(self.ds.mask_rho)
-            field = field.assign_coords(
-                {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
-            )

             vmax = field.max().values
             vmin = field.min().values
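The longitude shift added above takes over what `_plot` used to do internally behind the `straddle` flag: when a domain crosses the 180° meridian, longitudes above 180° are mapped into the western hemisphere before plotting so the domain stays contiguous. In isolation:

    import numpy as np
    import xarray as xr

    lon_deg = xr.DataArray(np.array([170.0, 175.0, 180.0, 185.0, 190.0]), dims="xi_rho")
    straddle = True  # True when the domain crosses the 180° meridian

    if straddle:
        lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)

    print(lon_deg.values)  # [ 170.  175.  180. -175. -170.]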
@@ -427,9 +434,7 @@ class Grid:
             kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

             _plot(
-                self.ds,
                 field=field,
-                straddle=self.straddle,
                 title=title,
                 with_dim_names=with_dim_names,
                 kwargs=kwargs,
@@ -438,12 +443,44 @@ class Grid:
             if title is None:
                 title = "ROMS grid"
             _plot(
-                self.ds,
-                straddle=self.straddle,
-                title=title,
-                with_dim_names=with_dim_names,
+                field=field, title=title, with_dim_names=with_dim_names, plot_data=False
             )

+    def compute_depth_coordinates(
+        self, depth_type: str, locations: list[str] = ["rho", "u", "v"]
+    ):
+        """Compute and update vertical depth coordinates.
+
+        Calculates vertical depth coordinates (layer or interface) for specified locations (e.g., rho, u, v points)
+        and updates them in the dataset (`self.ds`).
+
+        Parameters
+        ----------
+        depth_type : str
+            The type of depth coordinate to compute. Valid options:
+            - "layer": Compute layer depth coordinates.
+            - "interface": Compute interface depth coordinates.
+        locations : list[str], optional
+            Locations for which to compute depth coordinates. Default is ["rho", "u", "v"].
+            Valid options include:
+            - "rho": Depth coordinates at rho points.
+            - "u": Depth coordinates at u points.
+            - "v": Depth coordinates at v points.
+
+        Updates
+        -------
+        self.ds : xarray.Dataset
+            The dataset (`self.ds`) is updated with the following depth coordinate variables:
+            - f"{depth_type}_depth_rho": Depth coordinates at rho points.
+            - f"{depth_type}_depth_u": Depth coordinates at u points (if included in `locations`).
+            - f"{depth_type}_depth_v": Depth coordinates at v points (if included in `locations`).
+
+        Notes
+        -----
+        This method uses the `compute_and_update_depth_coordinates` function to perform calculations and updates.
+        """
+        add_depth_coordinates_to_dataset(self.ds, self.ds, depth_type, locations)
+
     def plot_vertical_coordinate(
         self,
         s=None,
|
|
|
480
517
|
raise ValueError("Exactly one of s, eta, or xi must be specified.")
|
|
481
518
|
|
|
482
519
|
h = self.ds["h"]
|
|
483
|
-
|
|
520
|
+
lat_deg = self.ds.lat_rho
|
|
521
|
+
lon_deg = self.ds.lon_rho
|
|
522
|
+
if self.straddle:
|
|
523
|
+
lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
|
|
524
|
+
h = h.assign_coords({"lon": lon_deg, "lat": lat_deg})
|
|
484
525
|
|
|
485
526
|
# slice the bathymetry as desired
|
|
486
527
|
if eta is not None:
|
|
@@ -505,9 +546,7 @@ class Grid:
|
|
|
505
546
|
kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
|
|
506
547
|
|
|
507
548
|
_plot(
|
|
508
|
-
self.ds,
|
|
509
549
|
field=layer_depth.where(self.ds.mask_rho),
|
|
510
|
-
straddle=self.straddle,
|
|
511
550
|
depth_contours=False,
|
|
512
551
|
title=title,
|
|
513
552
|
kwargs=kwargs,
|