roms-tools 2.2.1__py3-none-any.whl → 2.3.0__py3-none-any.whl
- roms_tools/__init__.py +1 -0
- roms_tools/analysis/roms_output.py +586 -0
- roms_tools/{setup/download.py → download.py} +3 -0
- roms_tools/{setup/plot.py → plot.py} +34 -28
- roms_tools/setup/boundary_forcing.py +23 -12
- roms_tools/setup/datasets.py +2 -135
- roms_tools/setup/grid.py +54 -15
- roms_tools/setup/initial_conditions.py +105 -149
- roms_tools/setup/nesting.py +4 -4
- roms_tools/setup/river_forcing.py +7 -9
- roms_tools/setup/surface_forcing.py +14 -14
- roms_tools/setup/tides.py +24 -21
- roms_tools/setup/topography.py +1 -1
- roms_tools/setup/utils.py +20 -154
- roms_tools/tests/test_analysis/test_roms_output.py +269 -0
- roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
- roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
- roms_tools/tests/test_setup/test_datasets.py +1 -1
- roms_tools/tests/test_setup/test_grid.py +1 -1
- roms_tools/tests/test_setup/test_initial_conditions.py +1 -1
- roms_tools/tests/test_setup/test_river_forcing.py +1 -1
- roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
- roms_tools/tests/test_setup/test_tides.py +1 -1
- roms_tools/tests/test_setup/test_topography.py +1 -1
- roms_tools/tests/test_setup/test_utils.py +56 -1
- roms_tools/utils.py +301 -0
- roms_tools/vertical_coordinate.py +306 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
- roms_tools/setup/vertical_coordinate.py +0 -109
- /roms_tools/{setup/regrid.py → regrid.py} +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
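Several helper modules move out of the `setup` subpackage in this release (see the renames above), and `setup/vertical_coordinate.py` is replaced by a new top-level module. A minimal migration sketch, assuming the old `roms_tools.setup.*` paths for these modules no longer exist in 2.3.0:

import roms_tools.plot                 # was roms_tools.setup.plot
import roms_tools.regrid               # was roms_tools.setup.regrid
import roms_tools.download             # was roms_tools.setup.download
import roms_tools.vertical_coordinate  # replaces roms_tools.setup.vertical_coordinate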
roms_tools/__init__.py
CHANGED
@@ -15,6 +15,7 @@ from roms_tools.setup.initial_conditions import InitialConditions  # noqa: F401
 from roms_tools.setup.boundary_forcing import BoundaryForcing  # noqa: F401
 from roms_tools.setup.river_forcing import RiverForcing  # noqa: F401
 from roms_tools.setup.nesting import Nesting  # noqa: F401
+from roms_tools.analysis.roms_output import ROMSOutput  # noqa: F401

 # Configure logging when the package is imported
 logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
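With this one-line addition, the new analysis class is exported at the package root. A quick smoke test, assuming roms-tools 2.3.0 is installed:

from roms_tools import ROMSOutput  # new top-level export in 2.3.0
print(ROMSOutput)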
roms_tools/analysis/roms_output.py
ADDED
@@ -0,0 +1,586 @@
+import xarray as xr
+import numpy as np
+import matplotlib.pyplot as plt
+from roms_tools.utils import _load_data
+from dataclasses import dataclass, field
+from typing import Union, Optional, List
+from pathlib import Path
+import os
+import re
+import logging
+from datetime import datetime, timedelta
+from roms_tools import Grid
+from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
+from roms_tools.vertical_coordinate import (
+    add_depth_coordinates_to_dataset,
+    compute_depth_coordinates,
+)
+
+
+@dataclass(frozen=True, kw_only=True)
+class ROMSOutput:
+    """Represents ROMS model output.
+
+    Parameters
+    ----------
+    grid : Grid
+        Object representing the grid information.
+    path : Union[str, Path, List[Union[str, Path]]]
+        Directory, filename, or list of filenames with model output.
+    type : str
+        Specifies the type of model output. Options are:
+
+        - "restart": for restart files.
+        - "average": for time-averaged files.
+        - "snapshot": for snapshot files.
+
+    model_reference_date : datetime, optional
+        If not specified, it is inferred from the metadata of the model output.
+        If specified but inconsistent with the metadata, a ValueError is raised.
+    use_dask : bool, optional
+        Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
+    """
+
+    grid: Grid
+    path: Union[str, Path, List[Union[str, Path]]]
+    type: str
+    use_dask: bool = False
+    model_reference_date: Optional[datetime] = None
+    ds: xr.Dataset = field(init=False, repr=False)
+
+    def __post_init__(self):
+        # Validate `type`
+        if self.type not in {"restart", "average", "snapshot"}:
+            raise ValueError(
+                f"Invalid type '{self.type}'. Must be one of 'restart', 'average', or 'snapshot'."
+            )
+
+        ds = self._load_model_output()
+        self._infer_model_reference_date_from_metadata(ds)
+        self._check_vertical_coordinate(ds)
+        ds = self._add_absolute_time(ds)
+        ds = self._add_lat_lon_coords(ds)
+        object.__setattr__(self, "ds", ds)
+
+    def plot(
+        self,
+        var_name,
+        time=0,
+        s=None,
+        eta=None,
+        xi=None,
+        depth_contours=False,
+        layer_contours=False,
+        ax=None,
+    ) -> None:
+        """Plot a ROMS output field sliced along a vertical (s_rho) or horizontal
+        (eta, xi) index.
+
+        Parameters
+        ----------
+        var_name : str
+            Name of the variable to plot. Supported options include:
+
+            - Oceanographic fields: "temp", "salt", "zeta", "u", "v", "w", etc.
+            - Biogeochemical tracers: "PO4", "NO3", "O2", "DIC", "ALK", etc.
+
+        time : int, optional
+            Index of the time dimension to plot. Default is 0.
+        s : int, optional
+            The index of the vertical layer (`s_rho`) to plot. If not specified, a
+            vertical section along the given eta- or xi-index is plotted instead.
+            Default is None.
+        eta : int, optional
+            The eta-index to plot. Used for vertical sections or horizontal slices.
+            Default is None.
+        xi : int, optional
+            The xi-index to plot. Used for vertical sections or horizontal slices.
+            Default is None.
+        depth_contours : bool, optional
+            If True, depth contours will be overlaid on the plot, showing lines of constant
+            depth. This is typically used for plots that show a single vertical layer.
+            Default is False.
+        layer_contours : bool, optional
+            If True, contour lines representing the boundaries between vertical layers will
+            be added to the plot. This is particularly useful in vertical sections to
+            visualize the layering of the water column. For clarity, the number of layer
+            contours displayed is limited to a maximum of 10. Default is False.
+        ax : matplotlib.axes.Axes, optional
+            The axes to plot on. If None, a new figure is created. Note that this
+            argument does not work for horizontal plots that display the eta- and
+            xi-dimensions at the same time.
+
+        Returns
+        -------
+        None
+            This method does not return any value. It generates and displays a plot.
+
+        Raises
+        ------
+        ValueError
+            If the specified `var_name` is not one of the valid options.
+            If the field specified by `var_name` is 3D and none of `s`, `eta`, or `xi` are specified.
+            If the field specified by `var_name` is 2D and both `eta` and `xi` are specified.
+        """
+
+        # Input checks
+        if var_name not in self.ds:
+            raise ValueError(f"Variable '{var_name}' is not found in dataset.")
+
+        if "time" in self.ds[var_name].dims:
+            if time >= len(self.ds[var_name].time):
+                raise ValueError(
+                    f"Invalid time index: The specified time index ({time}) exceeds the maximum index "
+                    f"({len(self.ds[var_name].time) - 1}) for the 'time' dimension in variable '{var_name}'."
+                )
+            field = self.ds[var_name].isel(time=time)
+        else:
+            if time > 0:
+                raise ValueError(
+                    f"Invalid input: The variable '{var_name}' does not have a 'time' dimension, "
+                    f"but a time index ({time}) greater than 0 was provided."
+                )
+            field = self.ds[var_name]
+
+        if len(field.dims) == 3:
+            if not any([s is not None, eta is not None, xi is not None]):
+                raise ValueError(
+                    "Invalid input: For 3D fields, you must specify at least one of the dimensions 's', 'eta', or 'xi'."
+                )
+            if all([s is not None, eta is not None, xi is not None]):
+                raise ValueError(
+                    "Ambiguous input: For 3D fields, specify at most two of 's', 'eta', or 'xi'. Specifying all three is not allowed."
+                )
+
+        if len(field.dims) == 2 and all([eta is not None, xi is not None]):
+            raise ValueError(
+                "Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both."
+            )
+
+        # Load the data
+        if self.use_dask:
+            from dask.diagnostics import ProgressBar
+
+            with ProgressBar():
+                field.load()
+
+        # Get correct mask and spatial coordinates
+        if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
+            loc = "rho"
+        elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
+            loc = "u"
+        elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
+            loc = "v"
+        else:
+            raise ValueError("Provided field does not have two horizontal dimensions.")
+
+        mask = self.grid.ds[f"mask_{loc}"]
+        lat_deg = self.grid.ds[f"lat_{loc}"]
+        lon_deg = self.grid.ds[f"lon_{loc}"]
+
+        if self.grid.straddle:
+            lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
+
+        field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
+
+        # Retrieve depth coordinates
+        compute_layer_depth = (depth_contours or s is None) and len(field.dims) > 2
+        compute_interface_depth = layer_contours and s is None
+
+        if compute_layer_depth:
+            layer_depth = compute_depth_coordinates(
+                self.ds.isel(time=time),
+                self.grid.ds,
+                depth_type="layer",
+                location=loc,
+                s=s,
+                eta=eta,
+                xi=xi,
+            )
+        if compute_interface_depth:
+            interface_depth = compute_depth_coordinates(
+                self.ds.isel(time=time),
+                self.grid.ds,
+                depth_type="interface",
+                location=loc,
+                s=s,
+                eta=eta,
+                xi=xi,
+            )
+
+        # Slice the field as desired
+        title = field.long_name
+        if s is not None:
+            title = title + f", s_rho = {field.s_rho[s].item()}"
+            field = field.isel(s_rho=s)
+        else:
+            depth_contours = False
+
+        def _process_dimension(field, mask, dim_name, dim_values, idx, title):
+            if dim_name in field.dims:
+                title = title + f", {dim_name} = {dim_values[idx].item()}"
+                field = field.isel(**{dim_name: idx})
+                mask = mask.isel(**{dim_name: idx})
+            else:
+                raise ValueError(
+                    f"Expected dimension '{dim_name}' not found in field."
+                )
+            return field, mask, title
+
+        if eta is not None:
+            field, mask, title = _process_dimension(
+                field,
+                mask,
+                "eta_rho" if "eta_rho" in field.dims else "eta_v",
+                field.eta_rho if "eta_rho" in field.dims else field.eta_v,
+                eta,
+                title,
+            )
+
+        if xi is not None:
+            field, mask, title = _process_dimension(
+                field,
+                mask,
+                "xi_rho" if "xi_rho" in field.dims else "xi_u",
+                field.xi_rho if "xi_rho" in field.dims else field.xi_u,
+                xi,
+                title,
+            )
+
+        # Format to exclude seconds
+        formatted_time = np.datetime_as_string(field.abs_time.values, unit="m")
+        title = title + f", time: {formatted_time}"
+
+        if compute_layer_depth:
+            field = field.assign_coords({"layer_depth": layer_depth})
+
+        # Choose colorbar
+        if var_name in ["u", "v", "w", "ubar", "vbar", "zeta"]:
+            vmax = max(field.where(mask).max().values, -field.where(mask).min().values)
+            vmin = -vmax
+            cmap = plt.colormaps.get_cmap("RdBu_r")
+        else:
+            vmax = field.where(mask).max().values
+            vmin = field.where(mask).min().values
+            if var_name in ["temp", "salt"]:
+                cmap = plt.colormaps.get_cmap("YlOrRd")
+            else:
+                cmap = plt.colormaps.get_cmap("YlGn")
+        cmap.set_bad(color="gray")
+        kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
+
+        # Plotting
+        if eta is None and xi is None:
+            _plot(
+                field=field.where(mask),
+                depth_contours=depth_contours,
+                title=title,
+                kwargs=kwargs,
+                c="g",
+            )
+        else:
+            if len(field.dims) == 2:
+                if not layer_contours:
+                    interface_depth = None
+                else:
+                    # restrict number of layer_contours to 10 for the sake of plot clarity
+                    nr_layers = len(interface_depth["s_w"])
+                    selected_layers = np.linspace(
+                        0, nr_layers - 1, min(nr_layers, 10), dtype=int
+                    )
+                    interface_depth = interface_depth.isel(s_w=selected_layers)
+                _section_plot(
+                    field.where(mask),
+                    interface_depth=interface_depth,
+                    title=title,
+                    kwargs=kwargs,
+                    ax=ax,
+                )
+            else:
+                if "s_rho" in field.dims:
+                    _profile_plot(field.where(mask), title=title, ax=ax)
+                else:
+                    _line_plot(field.where(mask), title=title, ax=ax)
+
+    def compute_depth_coordinates(self, depth_type="layer", locations=["rho"]):
+        """Compute and update vertical depth coordinates.
+
+        Calculates vertical depth coordinates (layer or interface) for specified locations (e.g., rho, u, v points)
+        and updates them in the dataset (`self.ds`).
+
+        Parameters
+        ----------
+        depth_type : str
+            The type of depth coordinate to compute. Valid options:
+            - "layer": Compute layer depth coordinates.
+            - "interface": Compute interface depth coordinates.
+        locations : list[str], optional
+            Locations for which to compute depth coordinates. Default is ["rho"].
+            Valid options include:
+            - "rho": Depth coordinates at rho points.
+            - "u": Depth coordinates at u points.
+            - "v": Depth coordinates at v points.
+
+        Updates
+        -------
+        self.ds : xarray.Dataset
+            The dataset (`self.ds`) is updated with the following depth coordinate variables:
+            - f"{depth_type}_depth_rho": Depth coordinates at rho points.
+            - f"{depth_type}_depth_u": Depth coordinates at u points (if included in `locations`).
+            - f"{depth_type}_depth_v": Depth coordinates at v points (if included in `locations`).
+
+        Notes
+        -----
+        This method uses the `add_depth_coordinates_to_dataset` function to perform calculations and updates.
+        """
+
+        add_depth_coordinates_to_dataset(self.ds, self.grid.ds, depth_type, locations)
+
+    def _load_model_output(self) -> xr.Dataset:
+        """Load the model output based on the type."""
+        if isinstance(self.path, list):
+            filetype = "list"
+            force_combine_nested = True
+            # Check if all items in the list are files
+            if not all(Path(item).is_file() for item in self.path):
+                raise FileNotFoundError(
+                    "All items in the provided list must be valid files."
+                )
+        elif Path(self.path).is_file():
+            filetype = "file"
+            force_combine_nested = False
+        elif Path(self.path).is_dir():
+            filetype = "dir"
+            force_combine_nested = True
+        else:
+            raise FileNotFoundError(
+                f"The specified path '{self.path}' is neither a file, nor a list of files, nor a directory."
+            )
+
+        time_chunking = True
+        if self.type == "restart":
+            time_chunking = False
+            filename = _validate_and_set_filenames(self.path, filetype, "rst")
+        elif self.type == "average":
+            filename = _validate_and_set_filenames(self.path, filetype, "avg")
+        elif self.type == "snapshot":
+            filename = _validate_and_set_filenames(self.path, filetype, "his")
+        else:
+            raise ValueError(f"Unsupported type '{self.type}'.")
+
+        # Load the dataset
+        ds = _load_data(
+            filename,
+            dim_names={"time": "time"},
+            use_dask=self.use_dask,
+            time_chunking=time_chunking,
+            force_combine_nested=force_combine_nested,
+        )
+
+        return ds
+
+    def _infer_model_reference_date_from_metadata(self, ds: xr.Dataset) -> None:
+        """Infer and validate the model reference date from `ocean_time` metadata.
+
+        Parameters
+        ----------
+        ds : xr.Dataset
+            Dataset with an `ocean_time` variable and a `long_name` attribute
+            in the format `Time since YYYY/MM/DD`.
+
+        Raises
+        ------
+        ValueError
+            If `self.model_reference_date` is not set and the reference date cannot
+            be inferred, or if the inferred date does not match `self.model_reference_date`.
+
+        Warns
+        -----
+        UserWarning
+            If `self.model_reference_date` is set but the reference date cannot be inferred.
+        """
+        # Check if 'long_name' exists in the attributes of 'ocean_time'
+        if "long_name" in ds.ocean_time.attrs:
+            input_string = ds.ocean_time.attrs["long_name"]
+            match = re.search(r"(\d{4})/(\d{2})/(\d{2})", input_string)
+
+            if match:
+                # If a match is found, extract year, month, day and create the inferred date
+                year, month, day = map(int, match.groups())
+                inferred_date = datetime(year, month, day)
+
+                if hasattr(self, "model_reference_date") and self.model_reference_date:
+                    # Check if the inferred date matches the provided model reference date
+                    if self.model_reference_date != inferred_date:
+                        raise ValueError(
+                            f"Mismatch between `self.model_reference_date` ({self.model_reference_date}) "
+                            f"and inferred reference date ({inferred_date})."
+                        )
+                else:
+                    # Set the model reference date if not already set
+                    object.__setattr__(self, "model_reference_date", inferred_date)
+            else:
+                # Handle case where no match is found
+                if hasattr(self, "model_reference_date") and self.model_reference_date:
+                    logging.warning(
+                        "Could not infer the model reference date from the metadata. "
+                        "`self.model_reference_date` will be used.",
+                    )
+                else:
+                    raise ValueError(
+                        "Model reference date could not be inferred from the metadata, "
+                        "and `self.model_reference_date` is not set."
+                    )
+        else:
+            # Handle case where 'long_name' attribute doesn't exist
+            if hasattr(self, "model_reference_date") and self.model_reference_date:
+                logging.warning(
+                    "`long_name` attribute not found in ocean_time. "
+                    "`self.model_reference_date` will be used instead.",
+                )
+            else:
+                raise ValueError(
+                    "Model reference date could not be inferred from the metadata, "
+                    "and `self.model_reference_date` is not set."
+                )
+
+    def _check_vertical_coordinate(self, ds: xr.Dataset) -> None:
+        """Check that the vertical coordinate parameters in the dataset are consistent
+        with the model grid.
+
+        This method compares the vertical coordinate parameters (`theta_s`, `theta_b`, `hc`, `Cs_r`, `Cs_w`) in
+        the provided dataset (`ds`) with those in the model grid (`self.grid`). The first three parameters are
+        checked for exact equality, while the last two are checked for numerical closeness.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            The dataset containing vertical coordinate parameters in its attributes, such as `theta_s`, `theta_b`,
+            `hc`, `Cs_r`, and `Cs_w`.
+
+        Raises
+        ------
+        ValueError
+            If the vertical coordinate parameters do not match the expected values (based on exact or approximate equality).
+
+        Notes
+        -----
+        - `theta_s`, `theta_b`, and `hc` are checked for exact equality using `np.array_equal`.
+        - `Cs_r` and `Cs_w` are checked for numerical closeness using `np.allclose`.
+        """
+
+        # Check exact equality for theta_s, theta_b, and hc
+        if not np.array_equal(self.grid.theta_s, ds.attrs["theta_s"]):
+            raise ValueError(
+                f"theta_s from grid ({self.grid.theta_s}) does not match dataset ({ds.attrs['theta_s']})."
+            )
+
+        if not np.array_equal(self.grid.theta_b, ds.attrs["theta_b"]):
+            raise ValueError(
+                f"theta_b from grid ({self.grid.theta_b}) does not match dataset ({ds.attrs['theta_b']})."
+            )
+
+        if not np.array_equal(self.grid.hc, ds.attrs["hc"]):
+            raise ValueError(
+                f"hc from grid ({self.grid.hc}) does not match dataset ({ds.attrs['hc']})."
+            )
+
+        # Check numerical closeness for Cs_r and Cs_w
+        if not np.allclose(self.grid.ds.Cs_r, ds.attrs["Cs_r"]):
+            raise ValueError(
+                f"Cs_r from grid ({self.grid.ds.Cs_r}) is not close to dataset ({ds.attrs['Cs_r']})."
+            )
+
+        if not np.allclose(self.grid.ds.Cs_w, ds.attrs["Cs_w"]):
+            raise ValueError(
+                f"Cs_w from grid ({self.grid.ds.Cs_w}) is not close to dataset ({ds.attrs['Cs_w']})."
+            )
+
+    def _add_absolute_time(self, ds: xr.Dataset) -> xr.Dataset:
+        """Add absolute time as a coordinate to the dataset.
+
+        Computes "abs_time" based on "ocean_time" and a reference date,
+        and adds it as a coordinate.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Dataset containing "ocean_time" in seconds since the model reference date.
+
+        Returns
+        -------
+        xarray.Dataset
+            Dataset with "abs_time" added and "time" removed.
+        """
+        ocean_time_seconds = ds["ocean_time"].values
+
+        abs_time = np.array(
+            [
+                self.model_reference_date + timedelta(seconds=seconds)
+                for seconds in ocean_time_seconds
+            ]
+        )
+
+        abs_time = xr.DataArray(
+            abs_time, dims=["time"], coords={"time": ds["ocean_time"]}
+        )
+        abs_time.attrs["long_name"] = "absolute time"
+        ds = ds.assign_coords({"abs_time": abs_time})
+        ds = ds.drop_vars("time")
+
+        return ds
+
+    def _add_lat_lon_coords(self, ds: xr.Dataset) -> xr.Dataset:
+        """Add latitude and longitude coordinates to the dataset.
+
+        Adds "lat_rho" and "lon_rho" from the grid object to the dataset.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Dataset to update.
+
+        Returns
+        -------
+        xarray.Dataset
+            Dataset with "lat_rho" and "lon_rho" coordinates added.
+        """
+        ds = ds.assign_coords(
+            {"lat_rho": self.grid.ds["lat_rho"], "lon_rho": self.grid.ds["lon_rho"]}
+        )
+
+        return ds
+
+
+def _validate_and_set_filenames(
+    filenames: Union[str, list], filetype: str, string: str
+) -> Union[str, list]:
+    """Validates and adjusts the filename or list of filenames based on the specified
+    type and checks for the presence of a string in the filename.
+
+    Parameters
+    ----------
+    filenames : Union[str, list]
+        A single filename (str), a list of filenames, or a directory path.
+    filetype : str
+        The type of input: 'file' for a single file, 'list' for a list of files, or 'dir' for a directory.
+    string : str
+        The string that should be present in each filename.
+
+    Returns
+    -------
+    Union[str, list]
+        The validated filename(s). If a directory is provided, the function returns the adjusted file pattern.
+    """
+    if filetype == "file":
+        if string not in os.path.basename(filenames):
+            logging.warning(
+                f"The file '{filenames}' does not appear to contain '{string}' in the name."
+            )
+    elif filetype == "list":
+        for file in filenames:
+            if string not in os.path.basename(file):
+                logging.warning(
+                    f"The file '{file}' does not appear to contain '{string}' in the name."
+                )
+    elif filetype == "dir":
+        filenames = os.path.join(filenames, f"*{string}.*.nc")
+
+    return filenames
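Taken together, `ROMSOutput` loads the output files, validates their vertical coordinate parameters against the grid, attaches absolute time and lat/lon coordinates, and exposes plotting and depth-coordinate helpers. A minimal usage sketch; the grid YAML and output paths below are hypothetical placeholders (and `Grid.from_yaml` is assumed to be the usual roms-tools round-trip for grids saved with `to_yaml`):

from datetime import datetime
from roms_tools import Grid, ROMSOutput

# Hypothetical inputs; substitute your own grid and model output.
grid = Grid.from_yaml("my_grid.yaml")
output = ROMSOutput(
    grid=grid,
    path="output/eastpac25km_avg.19980106000000.nc",  # file, list of files, or directory
    type="average",
    model_reference_date=datetime(1995, 1, 1),  # optional; otherwise inferred from metadata
)

# Horizontal slice of temperature at the top s-level, then a vertical
# section at a fixed eta-index with layer contours overlaid.
output.plot("temp", time=0, s=-1)
output.plot("temp", time=0, eta=50, layer_contours=True)

# Attach layer depths at rho points to output.ds in place.
output.compute_depth_coordinates(depth_type="layer", locations=["rho"])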
roms_tools/{setup/download.py → download.py}
RENAMED
@@ -59,6 +59,9 @@ pup_test_data = pooch.create(
         "grid_created_with_matlab.nc": "fd537ef8159fabb18e38495ec8d44e2fa1b7fb615fcb1417dd4c0e1bb5f4e41d",
         "etopo5_coarsened_and_shifted.nc": "9a5cb4b38c779d22ddb0ad069b298b9722db34ca85a89273eccca691e89e6f96",
         "srtm15_coarsened.nc": "48bc8f4beecfdca9c192b13f4cbeef1455f49d8261a82563aaec5757e100dff9",
+        "eastpac25km_rst.19980106000000.nc": "8f56d72bd8daf72eb736cc6705f93f478f4ad0ae4a95e98c4c9393a38e032f4c",
+        "eastpac25km_rst.19980126000000.nc": "20ad9007c980d211d1e108c50589183120c42a2d96811264cf570875107269e4",
+        "epac25km_grd.nc": "ec26c69cda4c4e96abde5b7756c955a7e1074931ab5a0641f598b099778fb617",
     },
 )
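The three new registry entries add East Pacific test data (two restart files and a grid) to the pooch download registry. A sketch of how such a registry is typically used; it assumes `pup_test_data` is importable from `roms_tools.download` (it is created at module level in the hunk above) and that the registry's base URL is configured elsewhere in that module:

from roms_tools.download import pup_test_data

# Download (or reuse a cached copy of) one of the new test files; pooch
# verifies the SHA-256 hash listed in the registry after download.
local_path = pup_test_data.fetch("epac25km_grd.nc")
print(local_path)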