roms-tools 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +586 -0
  3. roms_tools/{setup/download.py → download.py} +3 -0
  4. roms_tools/{setup/plot.py → plot.py} +34 -28
  5. roms_tools/setup/boundary_forcing.py +23 -12
  6. roms_tools/setup/datasets.py +2 -135
  7. roms_tools/setup/grid.py +54 -15
  8. roms_tools/setup/initial_conditions.py +105 -149
  9. roms_tools/setup/nesting.py +4 -4
  10. roms_tools/setup/river_forcing.py +7 -9
  11. roms_tools/setup/surface_forcing.py +14 -14
  12. roms_tools/setup/tides.py +24 -21
  13. roms_tools/setup/topography.py +1 -1
  14. roms_tools/setup/utils.py +19 -143
  15. roms_tools/tests/test_analysis/test_roms_output.py +269 -0
  16. roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
  17. roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
  18. roms_tools/tests/test_setup/test_datasets.py +1 -1
  19. roms_tools/tests/test_setup/test_grid.py +1 -1
  20. roms_tools/tests/test_setup/test_initial_conditions.py +8 -4
  21. roms_tools/tests/test_setup/test_river_forcing.py +1 -1
  22. roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
  23. roms_tools/tests/test_setup/test_tides.py +1 -1
  24. roms_tools/tests/test_setup/test_topography.py +1 -1
  25. roms_tools/tests/test_setup/test_utils.py +56 -1
  26. roms_tools/utils.py +301 -0
  27. roms_tools/vertical_coordinate.py +306 -0
  28. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
  29. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
  30. roms_tools/setup/vertical_coordinate.py +0 -109
  31. /roms_tools/{setup/regrid.py → regrid.py} +0 -0
  32. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
  33. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
  34. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
roms_tools/__init__.py CHANGED
@@ -15,6 +15,7 @@ from roms_tools.setup.initial_conditions import InitialConditions # noqa: F401
15
15
  from roms_tools.setup.boundary_forcing import BoundaryForcing # noqa: F401
16
16
  from roms_tools.setup.river_forcing import RiverForcing # noqa: F401
17
17
  from roms_tools.setup.nesting import Nesting # noqa: F401
18
+ from roms_tools.analysis.roms_output import ROMSOutput # noqa: F401
18
19
 
19
20
  # Configure logging when the package is imported
20
21
  logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
@@ -0,0 +1,586 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from roms_tools.utils import _load_data
5
+ from dataclasses import dataclass, field
6
+ from typing import Union, Optional
7
+ from pathlib import Path
8
+ import os
9
+ import re
10
+ import logging
11
+ from datetime import datetime, timedelta
12
+ from roms_tools import Grid
13
+ from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
14
+ from roms_tools.vertical_coordinate import (
15
+ add_depth_coordinates_to_dataset,
16
+ compute_depth_coordinates,
17
+ )
18
+
19
+
20
@dataclass(frozen=True, kw_only=True)
class ROMSOutput:
    """Represents ROMS model output.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information.
    path : Union[str, Path, List[Union[str, Path]]]
        Directory, filename, or list of filenames with model output.
    type : str
        Specifies the type of model output. Options are:

          - "restart": for restart files.
          - "average": for time-averaged files.
          - "snapshot": for snapshot files.

    model_reference_date : datetime, optional
        If not specified, this is inferred from the ``long_name`` attribute of
        ``ocean_time`` in the model output (format ``.../YYYY/MM/DD...``).
        If specified and it does not coincide with the inferred date, a
        ``ValueError`` is raised. If specified but no date can be inferred from
        the metadata, the specified date is used and a warning is logged.
    use_dask: bool, optional
        Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
    """

    grid: Grid
    path: Union[str, Path, list[Union[str, Path]]]
    type: str
    use_dask: bool = False
    model_reference_date: Optional[datetime] = None
    # Loaded model output; populated in __post_init__, not an init argument.
    ds: xr.Dataset = field(init=False, repr=False)

    def __post_init__(self):
        # Validate `type`
        if self.type not in {"restart", "average", "snapshot"}:
            raise ValueError(
                f"Invalid type '{self.type}'. Must be one of 'restart', 'average', or 'snapshot'."
            )

        ds = self._load_model_output()
        self._infer_model_reference_date_from_metadata(ds)
        self._check_vertical_coordinate(ds)
        ds = self._add_absolute_time(ds)
        ds = self._add_lat_lon_coords(ds)
        # The dataclass is frozen, so bypass __setattr__ to store the dataset.
        object.__setattr__(self, "ds", ds)

    def plot(
        self,
        var_name,
        time=0,
        s=None,
        eta=None,
        xi=None,
        depth_contours=False,
        layer_contours=False,
        ax=None,
    ) -> None:
        """Plot a ROMS output field for a given vertical (s_rho) or horizontal (eta, xi)
        slice.

        Parameters
        ----------
        var_name : str
            Name of the variable to plot. Supported options include:

            - Oceanographic fields: "temp", "salt", "zeta", "u", "v", "w", etc.
            - Biogeochemical tracers: "PO4", "NO3", "O2", "DIC", "ALK", etc.

        time : int, optional
            Index of the time dimension to plot. Default is 0.
        s : int, optional
            The index of the vertical layer (`s_rho`) to plot. If not specified, the plot
            will represent a horizontal slice (eta- or xi- plane). Default is None.
        eta : int, optional
            The eta-index to plot. Used for vertical sections or horizontal slices.
            Default is None.
        xi : int, optional
            The xi-index to plot. Used for vertical sections or horizontal slices.
            Default is None.
        depth_contours : bool, optional
            If True, depth contours will be overlaid on the plot, showing lines of constant
            depth. This is typically used for plots that show a single vertical layer.
            Default is False.
        layer_contours : bool, optional
            If True, contour lines representing the boundaries between vertical layers will
            be added to the plot. This is particularly useful in vertical sections to
            visualize the layering of the water column. For clarity, the number of layer
            contours displayed is limited to a maximum of 10. Default is False.
        ax : matplotlib.axes.Axes, optional
            The axes to plot on. If None, a new figure is created. Note that this argument does not work for horizontal plots that display the eta- and xi-dimensions at the same time.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        ValueError
            If `var_name` is not found in the dataset.
            If `time` is out of range for a time-dependent variable, or a positive
            `time` is given for a variable without a 'time' dimension.
            If the field specified by `var_name` is 3D and none of `s`, `eta`, or `xi` are specified.
            If the field specified by `var_name` is 3D and all of `s`, `eta`, and `xi` are specified.
            If the field specified by `var_name` is 2D and both `eta` and `xi` are specified.
            If the field is not on rho-, u-, or v-points (i.e. lacks two horizontal dimensions).
        """
        # --- Input checks ---
        if var_name not in self.ds:
            raise ValueError(f"Variable '{var_name}' is not found in dataset.")

        var = self.ds[var_name]
        if "time" in var.dims:
            n_times = var.sizes["time"]
            if time >= n_times:
                raise ValueError(
                    f"Invalid time index: The specified time index ({time}) exceeds the maximum index "
                    f"({n_times - 1}) for the 'time' dimension in variable '{var_name}'."
                )
            data = var.isel(time=time)
        else:
            if time > 0:
                raise ValueError(
                    f"Invalid input: The variable '{var_name}' does not have a 'time' dimension, "
                    f"but a time index ({time}) greater than 0 was provided."
                )
            data = var

        if data.ndim == 3:
            n_specified = sum(idx is not None for idx in (s, eta, xi))
            if n_specified == 0:
                raise ValueError(
                    "Invalid input: For 3D fields, you must specify at least one of the dimensions 's', 'eta', or 'xi'."
                )
            if n_specified == 3:
                raise ValueError(
                    "Ambiguous input: For 3D fields, specify at most two of 's', 'eta', or 'xi'. Specifying all three is not allowed."
                )

        if data.ndim == 2 and eta is not None and xi is not None:
            raise ValueError(
                "Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both."
            )

        # --- Load the data (showing a progress bar when dask is used) ---
        if self.use_dask:
            from dask.diagnostics import ProgressBar

            with ProgressBar():
                data.load()

        # --- Determine grid location, mask, and spatial coordinates ---
        if "eta_rho" in data.dims and "xi_rho" in data.dims:
            loc = "rho"
        elif "eta_rho" in data.dims and "xi_u" in data.dims:
            loc = "u"
        elif "eta_v" in data.dims and "xi_rho" in data.dims:
            loc = "v"
        else:
            # Must raise here: otherwise `loc` would be unbound below.
            raise ValueError(
                "Provided field does not have two horizontal dimensions."
            )

        mask = self.grid.ds[f"mask_{loc}"]
        lat_deg = self.grid.ds[f"lat_{loc}"]
        lon_deg = self.grid.ds[f"lon_{loc}"]

        if self.grid.straddle:
            # Map longitudes > 180 to negative values (presumably so a domain
            # straddling the prime meridian is plotted contiguously).
            lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)

        data = data.assign_coords({"lon": lon_deg, "lat": lat_deg})

        # --- Retrieve depth coordinates ---
        compute_layer_depth = (depth_contours or s is None) and data.ndim > 2
        compute_interface_depth = layer_contours and s is None

        layer_depth = None
        interface_depth = None
        if compute_layer_depth:
            layer_depth = compute_depth_coordinates(
                self.ds.isel(time=time),
                self.grid.ds,
                depth_type="layer",
                location=loc,
                s=s,
                eta=eta,
                xi=xi,
            )
        if compute_interface_depth:
            interface_depth = compute_depth_coordinates(
                self.ds.isel(time=time),
                self.grid.ds,
                depth_type="interface",
                location=loc,
                s=s,
                eta=eta,
                xi=xi,
            )

        # --- Slice the field as desired ---
        title = data.long_name
        if s is not None:
            title = title + f", s_rho = {data.s_rho[s].item()}"
            data = data.isel(s_rho=s)
        else:
            # Depth contours are only drawn for single-layer plots.
            depth_contours = False

        def _select_index(data, mask, dim_name, idx, title):
            """Select `idx` along `dim_name` in both `data` and `mask`; extend `title`."""
            if dim_name not in data.dims:
                raise ValueError(
                    f"None of the expected dimensions ({dim_name}) found in field."
                )
            title = title + f", {dim_name} = {data[dim_name][idx].item()}"
            return (
                data.isel(**{dim_name: idx}),
                mask.isel(**{dim_name: idx}),
                title,
            )

        if eta is not None:
            eta_dim = "eta_rho" if "eta_rho" in data.dims else "eta_v"
            data, mask, title = _select_index(data, mask, eta_dim, eta, title)

        if xi is not None:
            xi_dim = "xi_rho" if "xi_rho" in data.dims else "xi_u"
            data, mask, title = _select_index(data, mask, xi_dim, xi, title)

        # Format to exclude seconds
        formatted_time = np.datetime_as_string(data.abs_time.values, unit="m")
        title = title + f", time: {formatted_time}"

        if compute_layer_depth:
            data = data.assign_coords({"layer_depth": layer_depth})

        # Land points become NaN; computed once and reused below.
        masked = data.where(mask)

        # --- Choose colorbar ---
        if var_name in ["u", "v", "w", "ubar", "vbar", "zeta"]:
            # Signed fields: symmetric limits around zero with a diverging colormap.
            vmax = max(masked.max().values, -masked.min().values)
            vmin = -vmax
            cmap = plt.colormaps.get_cmap("RdBu_r")
        else:
            vmax = masked.max().values
            vmin = masked.min().values
            if var_name in ["temp", "salt"]:
                cmap = plt.colormaps.get_cmap("YlOrRd")
            else:
                cmap = plt.colormaps.get_cmap("YlGn")
        cmap.set_bad(color="gray")
        kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

        # --- Plotting ---
        if eta is None and xi is None:
            _plot(
                field=masked,
                depth_contours=depth_contours,
                title=title,
                kwargs=kwargs,
                c="g",
            )
        elif data.ndim == 2:
            if layer_contours and interface_depth is not None:
                # Restrict the number of layer contours to 10 for plot clarity.
                nr_layers = len(interface_depth["s_w"])
                selected_layers = np.linspace(
                    0, nr_layers - 1, min(nr_layers, 10), dtype=int
                )
                interface_depth = interface_depth.isel(s_w=selected_layers)
            else:
                interface_depth = None
            _section_plot(
                masked,
                interface_depth=interface_depth,
                title=title,
                kwargs=kwargs,
                ax=ax,
            )
        elif "s_rho" in data.dims:
            _profile_plot(masked, title=title, ax=ax)
        else:
            _line_plot(masked, title=title, ax=ax)

    def compute_depth_coordinates(self, depth_type="layer", locations=None):
        """Compute and update vertical depth coordinates.

        Calculates vertical depth coordinates (layer or interface) for specified locations (e.g., rho, u, v points)
        and updates them in the dataset (`self.ds`).

        Parameters
        ----------
        depth_type : str
            The type of depth coordinate to compute. Valid options:
            - "layer": Compute layer depth coordinates.
            - "interface": Compute interface depth coordinates.
        locations : list[str], optional
            Locations for which to compute depth coordinates. Default is ["rho"].
            Valid options include:
            - "rho": Depth coordinates at rho points.
            - "u": Depth coordinates at u points.
            - "v": Depth coordinates at v points.

        Updates
        -------
        self.ds : xarray.Dataset
            The dataset (`self.ds`) is updated with the following depth coordinate variables:
            - f"{depth_type}_depth_rho": Depth coordinates at rho points (if included in `locations`).
            - f"{depth_type}_depth_u": Depth coordinates at u points (if included in `locations`).
            - f"{depth_type}_depth_v": Depth coordinates at v points (if included in `locations`).

        Notes
        -----
        This method delegates to `add_depth_coordinates_to_dataset`. Its return
        value is discarded here, so that helper presumably modifies `self.ds`
        in place (verify against `roms_tools.vertical_coordinate`).
        """
        # Avoid a shared mutable default argument.
        if locations is None:
            locations = ["rho"]

        add_depth_coordinates_to_dataset(self.ds, self.grid.ds, depth_type, locations)

    def _load_model_output(self) -> xr.Dataset:
        """Load the model output based on `self.path` and `self.type`.

        Returns
        -------
        xr.Dataset
            The loaded model output.

        Raises
        ------
        FileNotFoundError
            If `self.path` is not an existing file or directory, or if any item of a
            list of paths is not an existing file.
        ValueError
            If `self.type` is not one of the supported output types.
        """
        if isinstance(self.path, list):
            filetype = "list"
            force_combine_nested = True
            # Check if all items in the list are files
            if not all(Path(item).is_file() for item in self.path):
                raise FileNotFoundError(
                    "All items in the provided list must be valid files."
                )
        elif Path(self.path).is_file():
            filetype = "file"
            force_combine_nested = False
        elif Path(self.path).is_dir():
            filetype = "dir"
            force_combine_nested = True
        else:
            raise FileNotFoundError(
                f"The specified path '{self.path}' is neither a file, nor a list of files, nor a directory."
            )

        # Tag expected in the filenames of each output type.
        tag_by_type = {"restart": "rst", "average": "avg", "snapshot": "his"}
        if self.type not in tag_by_type:
            raise ValueError(f"Unsupported type '{self.type}'.")
        filename = _validate_and_set_filenames(
            self.path, filetype, tag_by_type[self.type]
        )

        # Time chunking is disabled for restart files.
        time_chunking = self.type != "restart"

        # Load the dataset
        ds = _load_data(
            filename,
            dim_names={"time": "time"},
            use_dask=self.use_dask,
            time_chunking=time_chunking,
            force_combine_nested=force_combine_nested,
        )

        return ds

    def _infer_model_reference_date_from_metadata(self, ds: xr.Dataset) -> None:
        """Infer and validate the model reference date from `ocean_time` metadata.

        If `self.model_reference_date` is unset and a date is inferred, it is stored
        on the (frozen) instance.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset with an `ocean_time` variable and a `long_name` attribute
            containing a date in the format `YYYY/MM/DD` (e.g. `Time since YYYY/MM/DD`).

        Raises
        ------
        ValueError
            If `self.model_reference_date` is not set and the reference date cannot
            be inferred, or if the inferred date does not match `self.model_reference_date`.

        Notes
        -----
        If `self.model_reference_date` is set but no date can be inferred, a warning
        is emitted via `logging.warning` (not the `warnings` module) and the
        user-provided date is kept.
        """
        user_date_given = self.model_reference_date is not None

        # Case 1: no 'long_name' attribute on ocean_time.
        if "long_name" not in ds.ocean_time.attrs:
            if user_date_given:
                logging.warning(
                    "`long_name` attribute not found in ocean_time. "
                    "`self.model_reference_date` will be used instead.",
                )
                return
            raise ValueError(
                "Model reference date could not be inferred from the metadata, "
                "and `self.model_reference_date` is not set."
            )

        # Case 2: 'long_name' present but no YYYY/MM/DD date in it.
        input_string = ds.ocean_time.attrs["long_name"]
        match = re.search(r"(\d{4})/(\d{2})/(\d{2})", input_string)
        if match is None:
            if user_date_given:
                logging.warning(
                    "Could not infer the model reference date from the metadata. "
                    "`self.model_reference_date` will be used.",
                )
                return
            raise ValueError(
                "Model reference date could not be inferred from the metadata, "
                "and `self.model_reference_date` is not set."
            )

        # Case 3: a date was found; adopt it or check it against the user's date.
        year, month, day = map(int, match.groups())
        inferred_date = datetime(year, month, day)

        if not user_date_given:
            object.__setattr__(self, "model_reference_date", inferred_date)
        elif self.model_reference_date != inferred_date:
            raise ValueError(
                f"Mismatch between `self.model_reference_date` ({self.model_reference_date}) "
                f"and inferred reference date ({inferred_date})."
            )

    def _check_vertical_coordinate(self, ds: xr.Dataset) -> None:
        """Check that the vertical coordinate parameters in the dataset are consistent
        with the model grid.

        This method compares the vertical coordinate parameters (`theta_s`, `theta_b`, `hc`, `Cs_r`, `Cs_w`) in
        the provided dataset (`ds`) with those in the model grid (`self.grid`). The first three parameters are
        checked for exact equality, while the last two are checked for numerical closeness.

        Parameters
        ----------
        ds : xarray.Dataset
            The dataset containing vertical coordinate parameters in its attributes, such as `theta_s`, `theta_b`,
            `hc`, `Cs_r`, and `Cs_w`.

        Raises
        ------
        ValueError
            If the vertical coordinate parameters do not match the expected values (based on exact or approximate equality).
        KeyError
            If one of the parameters is missing from `ds.attrs`.

        Notes
        -----
        - `theta_s`, `theta_b`, and `hc` are checked for exact equality using `np.array_equal`.
        - `Cs_r` and `Cs_w` are checked for numerical closeness using `np.allclose`.
        """
        # Scalar stretching parameters (attributes of the Grid object): exact match.
        for name in ("theta_s", "theta_b", "hc"):
            grid_value = getattr(self.grid, name)
            if not np.array_equal(grid_value, ds.attrs[name]):
                raise ValueError(
                    f"{name} from grid ({grid_value}) does not match dataset ({ds.attrs[name]})."
                )

        # Stretching curves (variables of the grid dataset): numerical closeness.
        for name in ("Cs_r", "Cs_w"):
            grid_value = self.grid.ds[name]
            if not np.allclose(grid_value, ds.attrs[name]):
                raise ValueError(
                    f"{name} from grid ({grid_value}) is not close to dataset ({ds.attrs[name]})."
                )

    def _add_absolute_time(self, ds: xr.Dataset) -> xr.Dataset:
        """Add absolute time as a coordinate to the dataset.

        Computes "abs_time" based on "ocean_time" and `self.model_reference_date`,
        and adds it as a coordinate along the "time" dimension.

        Parameters
        ----------
        ds : xarray.Dataset
            Dataset containing "ocean_time" in seconds since the model reference date.

        Returns
        -------
        xarray.Dataset
            Dataset with "abs_time" added and the "time" coordinate variable removed.
        """
        ocean_time_seconds = ds["ocean_time"].values

        abs_time = np.array(
            [
                self.model_reference_date + timedelta(seconds=seconds)
                for seconds in ocean_time_seconds
            ]
        )

        abs_time = xr.DataArray(
            abs_time, dims=["time"], coords={"time": ds["ocean_time"]}
        )
        abs_time.attrs["long_name"] = "absolute time"
        ds = ds.assign_coords({"abs_time": abs_time})
        # Drop the "time" coordinate introduced above (a copy of ocean_time).
        ds = ds.drop_vars("time")

        return ds

    def _add_lat_lon_coords(self, ds: xr.Dataset) -> xr.Dataset:
        """Add latitude and longitude coordinates to the dataset.

        Adds "lat_rho" and "lon_rho" from the grid object to the dataset.

        Parameters
        ----------
        ds : xarray.Dataset
            Dataset to update.

        Returns
        -------
        xarray.Dataset
            Dataset with "lat_rho" and "lon_rho" coordinates added.
        """
        ds = ds.assign_coords(
            {"lat_rho": self.grid.ds["lat_rho"], "lon_rho": self.grid.ds["lon_rho"]}
        )

        return ds
550
+
551
+
552
+ def _validate_and_set_filenames(
553
+ filenames: Union[str, list], filetype: str, string: str
554
+ ) -> Union[str, list]:
555
+ """Validates and adjusts the filename or list of filenames based on the specified
556
+ type and checks for the presence of a string in the filename.
557
+
558
+ Parameters
559
+ ----------
560
+ filenames : Union[str, list]
561
+ A single filename (str), a list of filenames, or a directory path.
562
+ filetype : str
563
+ The type of input: 'file' for a single file, 'list' for a list of files, or 'dir' for a directory.
564
+ string : str
565
+ The string that should be present in each filename.
566
+
567
+ Returns
568
+ -------
569
+ Union[str, list]
570
+ The validated filename(s). If a directory is provided, the function returns the adjusted file pattern.
571
+ """
572
+ if filetype == "file":
573
+ if string not in os.path.basename(filenames):
574
+ logging.warning(
575
+ f"The file '{filenames}' does not appear to contain '{string}' in the name."
576
+ )
577
+ elif filetype == "list":
578
+ for file in filenames:
579
+ if string not in os.path.basename(file):
580
+ logging.warning(
581
+ f"The file '{file}' does not appear to contain '{string}' in the name."
582
+ )
583
+ elif filetype == "dir":
584
+ filenames = os.path.join(filenames, f"*{string}.*.nc")
585
+
586
+ return filenames
@@ -59,6 +59,9 @@ pup_test_data = pooch.create(
59
59
  "grid_created_with_matlab.nc": "fd537ef8159fabb18e38495ec8d44e2fa1b7fb615fcb1417dd4c0e1bb5f4e41d",
60
60
  "etopo5_coarsened_and_shifted.nc": "9a5cb4b38c779d22ddb0ad069b298b9722db34ca85a89273eccca691e89e6f96",
61
61
  "srtm15_coarsened.nc": "48bc8f4beecfdca9c192b13f4cbeef1455f49d8261a82563aaec5757e100dff9",
62
+ "eastpac25km_rst.19980106000000.nc": "8f56d72bd8daf72eb736cc6705f93f478f4ad0ae4a95e98c4c9393a38e032f4c",
63
+ "eastpac25km_rst.19980126000000.nc": "20ad9007c980d211d1e108c50589183120c42a2d96811264cf570875107269e4",
64
+ "epac25km_grd.nc": "ec26c69cda4c4e96abde5b7756c955a7e1074931ab5a0641f598b099778fb617",
62
65
  },
63
66
  )
64
67