roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
roms_tools/setup/grid.py CHANGED
@@ -1,9 +1,9 @@
1
1
  import importlib.metadata
2
2
  import logging
3
3
  import re
4
- import time
5
4
  from dataclasses import asdict, dataclass, field
6
5
  from pathlib import Path
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import xarray as xr
@@ -12,9 +12,10 @@ from matplotlib.axes import Axes
12
12
 
13
13
  from roms_tools.constants import MAXIMUM_GRID_SIZE, R_EARTH
14
14
  from roms_tools.plot import plot
15
- from roms_tools.setup.mask import _add_mask, _add_velocity_masks
16
- from roms_tools.setup.topography import _add_topography
15
+ from roms_tools.setup.mask import add_mask, add_velocity_masks
16
+ from roms_tools.setup.topography import add_topography
17
17
  from roms_tools.setup.utils import (
18
+ Timed,
18
19
  extract_single_value,
19
20
  gc_dist,
20
21
  get_target_coords,
@@ -64,6 +65,8 @@ class Grid:
64
65
  - "path" (Union[str, Path, List[Union[str, Path]]]): The path to the raw data file. Can be a string or a Path object.
65
66
 
66
67
  The default is "ETOPO5", which does not require a path.
68
+ mask_shapefile: str | Path | None, optional
69
+ Path to a custom shapefile to use to determine the land mask; if None, use NaturalEarth 10m.
67
70
  hmin : float, optional
68
71
  The minimum ocean depth (in meters). The default is 5.0.
69
72
  N : int, optional
@@ -106,8 +109,10 @@ class Grid:
106
109
  """The bottom control parameter."""
107
110
  hc: float = 300.0
108
111
  """The critical depth (in meters)."""
109
- topography_source: dict[str, str | Path | list[str | Path]] = None
112
+ topography_source: dict[str, str | Path | list[str | Path]] | None = None
110
113
  """Dictionary specifying the source of the topography data."""
114
+ mask_shapefile: str | Path | None = None
115
+ """Path to a custom shapefile to use to determine the landmask; if None, use NaturalEarth 10m."""
111
116
  hmin: float = 5.0
112
117
  """The minimum ocean depth (in meters)."""
113
118
  verbose: bool = False
@@ -129,7 +134,7 @@ class Grid:
129
134
  self._straddle()
130
135
 
131
136
  # Mask
132
- self._create_mask(verbose=self.verbose)
137
+ self.update_mask(mask_shapefile=self.mask_shapefile, verbose=self.verbose)
133
138
 
134
139
  # Coarsen the dataset if needed
135
140
  self._coarsen()
@@ -165,19 +170,35 @@ class Grid:
165
170
  "`topography_source` must include a 'path' key when the 'name' is not 'ETOPO5'."
166
171
  )
167
172
 
168
- def _create_mask(self, verbose=False) -> None:
169
- if verbose:
170
- start_time = time.time()
171
- logging.info("=== Creating the mask ===")
172
- ds = _add_mask(self.ds)
173
+ def update_mask(
174
+ self, mask_shapefile: str | Path | None = None, verbose: bool = False
175
+ ) -> None:
176
+ """
177
+ Update the land mask of the current grid dataset.
173
178
 
174
- if verbose:
175
- logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
176
- logging.info(
177
- "========================================================================================================"
178
- )
179
+ This method generates a land mask based on the provided coastline
180
+ shapefile, fills enclosed basins with lands and updates the dataset
181
+ stored in `self.ds`. If no shapefile is provided, a default dataset (Natural
182
+ Earth 10m) is used. The operation is optionally timed and logged.
179
183
 
180
- self.ds = ds
184
+ Parameters
185
+ ----------
186
+ mask_shapefile : str or Path, optional
187
+ Path to a coastal shapefile to derive the land mask. If `None`,
188
+ the default Natural Earth 10m coastline dataset is used.
189
+ verbose : bool, default False
190
+ If True, prints timing and progress information.
191
+
192
+ Returns
193
+ -------
194
+ None
195
+ Updates the `self.ds` attribute in place with the new mask.
196
+
197
+ """
198
+ with Timed("=== Deriving the mask from coastlines ===", verbose=verbose):
199
+ ds = add_mask(self.ds, shapefile=mask_shapefile)
200
+ self.ds = ds
201
+ self.mask_shapefile = mask_shapefile
181
202
 
182
203
  def update_topography(
183
204
  self, topography_source=None, hmin=None, verbose=False
@@ -218,33 +239,22 @@ class Grid:
218
239
  # Extract target coordinates for processing
219
240
  target_coords = get_target_coords(self)
220
241
 
221
- # If verbose is enabled, start the timer and print the start message
222
- if verbose:
223
- start_time = time.time()
224
- logging.info(
225
- f"=== Generating the topography using {topography_source['name']} data and hmin = {hmin} meters ==="
226
- )
227
-
228
- # Add topography to the dataset
229
- ds = _add_topography(
230
- ds=self.ds,
231
- target_coords=target_coords,
232
- topography_source=topography_source,
233
- hmin=hmin,
242
+ with Timed(
243
+ f"=== Generating the topography using {topography_source['name']} data and hmin = {hmin} meters ===",
234
244
  verbose=verbose,
235
- )
236
-
237
- # If verbose is enabled, print elapsed time and a separator
238
- if verbose:
239
- logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
240
- logging.info(
241
- "========================================================================================================"
245
+ ):
246
+ ds = add_topography(
247
+ ds=self.ds,
248
+ target_coords=target_coords,
249
+ topography_source=topography_source,
250
+ hmin=hmin,
251
+ verbose=verbose,
242
252
  )
243
253
 
244
- # Update the grid's dataset and related attributes
245
- self.ds = ds
246
- self.topography_source = topography_source
247
- self.hmin = hmin
254
+ # Update the grid's dataset and related attributes
255
+ self.ds = ds
256
+ self.topography_source = topography_source
257
+ self.hmin = hmin
248
258
 
249
259
  def update_vertical_coordinate(
250
260
  self, N=None, theta_s=None, theta_b=None, hc=None, verbose=False
@@ -281,69 +291,61 @@ class Grid:
281
291
  theta_b = theta_b or self.theta_b
282
292
  hc = hc or self.hc
283
293
 
284
- if verbose:
285
- start_time = time.time()
286
- logging.info(
287
- f"=== Preparing the vertical coordinate system using N = {N}, theta_s = {theta_s}, theta_b = {theta_b}, hc = {hc} ==="
288
- )
294
+ with Timed(
295
+ f"=== Preparing the vertical coordinate system using N = {N}, theta_s = {theta_s}, theta_b = {theta_b}, hc = {hc} ===",
296
+ verbose=verbose,
297
+ ):
298
+ ds = self.ds
299
+ # need to drop vertical coordinates because they could cause conflict if N changed
300
+ vars_to_drop = [
301
+ "layer_depth_rho",
302
+ "layer_depth_u",
303
+ "layer_depth_v",
304
+ "interface_depth_rho",
305
+ "interface_depth_u",
306
+ "interface_depth_v",
307
+ "sigma_r",
308
+ "sigma_w",
309
+ "Cs_w",
310
+ "Cs_r",
311
+ ]
289
312
 
290
- ds = self.ds
291
- # need to drop vertical coordinates because they could cause conflict if N changed
292
- vars_to_drop = [
293
- "layer_depth_rho",
294
- "layer_depth_u",
295
- "layer_depth_v",
296
- "interface_depth_rho",
297
- "interface_depth_u",
298
- "interface_depth_v",
299
- "sigma_r",
300
- "sigma_w",
301
- "Cs_w",
302
- "Cs_r",
303
- ]
304
-
305
- for var in vars_to_drop:
306
- if var in ds.variables:
307
- ds = ds.drop_vars(var)
308
-
309
- cs_r, sigma_r = sigma_stretch(theta_s, theta_b, N, "r")
310
- cs_w, sigma_w = sigma_stretch(theta_s, theta_b, N, "w")
311
-
312
- ds["sigma_r"] = sigma_r.astype(np.float32)
313
- ds["sigma_r"].attrs["long_name"] = (
314
- "Fractional vertical stretching coordinate at rho-points"
315
- )
316
- ds["sigma_r"].attrs["units"] = "nondimensional"
313
+ for var in vars_to_drop:
314
+ if var in ds.variables:
315
+ ds = ds.drop_vars(var)
317
316
 
318
- ds["Cs_r"] = cs_r.astype(np.float32)
319
- ds["Cs_r"].attrs["long_name"] = "Vertical stretching function at rho-points"
320
- ds["Cs_r"].attrs["units"] = "nondimensional"
317
+ cs_r, sigma_r = sigma_stretch(theta_s, theta_b, N, "r")
318
+ cs_w, sigma_w = sigma_stretch(theta_s, theta_b, N, "w")
321
319
 
322
- ds["sigma_w"] = sigma_w.astype(np.float32)
323
- ds["sigma_w"].attrs["long_name"] = (
324
- "Fractional vertical stretching coordinate at w-points"
325
- )
326
- ds["sigma_w"].attrs["units"] = "nondimensional"
327
-
328
- ds["Cs_w"] = cs_w.astype(np.float32)
329
- ds["Cs_w"].attrs["long_name"] = "Vertical stretching function at w-points"
330
- ds["Cs_w"].attrs["units"] = "nondimensional"
320
+ ds["sigma_r"] = sigma_r.astype(np.float32)
321
+ ds["sigma_r"].attrs["long_name"] = (
322
+ "Fractional vertical stretching coordinate at rho-points"
323
+ )
324
+ ds["sigma_r"].attrs["units"] = "nondimensional"
331
325
 
332
- ds.attrs["theta_s"] = np.float32(theta_s)
333
- ds.attrs["theta_b"] = np.float32(theta_b)
334
- ds.attrs["hc"] = np.float32(hc)
326
+ ds["Cs_r"] = cs_r.astype(np.float32)
327
+ ds["Cs_r"].attrs["long_name"] = "Vertical stretching function at rho-points"
328
+ ds["Cs_r"].attrs["units"] = "nondimensional"
335
329
 
336
- if verbose:
337
- logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
338
- logging.info(
339
- "========================================================================================================"
330
+ ds["sigma_w"] = sigma_w.astype(np.float32)
331
+ ds["sigma_w"].attrs["long_name"] = (
332
+ "Fractional vertical stretching coordinate at w-points"
340
333
  )
334
+ ds["sigma_w"].attrs["units"] = "nondimensional"
341
335
 
342
- self.ds = ds
343
- self.theta_s = theta_s
344
- self.theta_b = theta_b
345
- self.hc = hc
346
- self.N = N
336
+ ds["Cs_w"] = cs_w.astype(np.float32)
337
+ ds["Cs_w"].attrs["long_name"] = "Vertical stretching function at w-points"
338
+ ds["Cs_w"].attrs["units"] = "nondimensional"
339
+
340
+ ds.attrs["theta_s"] = np.float32(theta_s)
341
+ ds.attrs["theta_b"] = np.float32(theta_b)
342
+ ds.attrs["hc"] = np.float32(hc)
343
+
344
+ self.ds = ds
345
+ self.theta_s = theta_s
346
+ self.theta_b = theta_b
347
+ self.hc = hc
348
+ self.N = N
347
349
 
348
350
  def _straddle(self) -> None:
349
351
  """Check if the Greenwich meridian goes through the domain.
@@ -415,30 +417,57 @@ class Grid:
415
417
 
416
418
  def plot(
417
419
  self,
420
+ lat: float | None = None,
421
+ lon: float | None = None,
418
422
  with_dim_names: bool = False,
419
423
  save_path: str | None = None,
420
424
  ) -> None:
421
- """Plot the grid.
425
+ """Plot the grid with bathymetry.
426
+
427
+ Depending on the arguments, this will either:
428
+ * Plot the full horizontal grid (if both `lat` and `lon` are None),
429
+ * Plot a zonal (east-west) vertical section at a given latitude (`lat`),
430
+ * Plot a meridional (south-north) vertical section at a given longitude (`lon`).
422
431
 
423
432
  Parameters
424
433
  ----------
434
+ lat : float, optional
435
+ Latitude in degrees at which to plot a vertical (zonal) section. Cannot be
436
+ provided together with `lon`. Default is None.
437
+
438
+ lon : float, optional
439
+ Longitude in degrees at which to plot a vertical (meridional) section. Cannot be
440
+ provided together with `lat`. Default is None.
441
+
425
442
  with_dim_names : bool, optional
426
- Whether or not to plot the dimension names. Default is False.
443
+ If True and no section is requested (i.e., both `lat` and `lon` are None), annotate
444
+ the plot with the underlying dimension names. Default is False.
427
445
 
428
446
  save_path : str, optional
429
447
  Path to save the generated plot. If None, the plot is shown interactively.
430
448
  Default is None.
431
449
 
450
+ Raises
451
+ ------
452
+ ValueError
453
+ If both `lat` and `lon` are specified simultaneously.
454
+
432
455
  Returns
433
456
  -------
434
457
  None
435
458
  This method does not return any value. It generates and displays a plot.
436
459
  """
460
+ if lat is not None and lon is not None:
461
+ raise ValueError("Specify either `lat` or `lon`, not both.")
462
+
437
463
  field = self.ds["h"]
438
464
 
439
465
  plot(
440
466
  field=field,
441
467
  grid_ds=self.ds,
468
+ lat=lat,
469
+ lon=lon,
470
+ yincrease=False,
442
471
  with_dim_names=with_dim_names,
443
472
  save_path=save_path,
444
473
  cmap_name="YlGnBu",
@@ -576,7 +605,7 @@ class Grid:
576
605
  ds = xr.open_dataset(filepath)
577
606
 
578
607
  if not all(mask in ds for mask in ["mask_u", "mask_v"]):
579
- ds = _add_velocity_masks(ds)
608
+ ds = add_velocity_masks(ds)
580
609
 
581
610
  # Create a new Grid instance without calling __init__ and __post_init__
582
611
  grid = cls.__new__(cls)
@@ -731,24 +760,30 @@ class Grid:
731
760
  "hmin",
732
761
  ]:
733
762
  if attr in ds.attrs:
734
- a = float(ds.attrs[attr])
763
+ value = float(ds.attrs[attr])
735
764
  else:
736
- a = None
765
+ value = None
737
766
 
738
- object.__setattr__(grid, attr, a)
767
+ object.__setattr__(grid, attr, value)
739
768
 
740
769
  if "topography_source_name" in ds.attrs:
741
770
  if "topography_source_path" in ds.attrs:
742
- a = {
771
+ topo_source = {
743
772
  "name": ds.attrs["topography_source_name"],
744
773
  "path": ds.attrs["topography_source_path"],
745
774
  }
746
775
  else:
747
- a = {"name": ds.attrs["topography_source_name"]}
776
+ topo_source = {"name": ds.attrs["topography_source_name"]}
777
+ else:
778
+ topo_source = None
779
+ grid.topography_source = topo_source
780
+
781
+ if "mask_shapefile" in ds.attrs:
782
+ mask_shapefile = ds.attrs["mask_shapefile"]
748
783
  else:
749
- a = None
784
+ mask_shapefile = None
750
785
 
751
- object.__setattr__(grid, "topography_source", a)
786
+ grid.mask_shapefile = mask_shapefile
752
787
 
753
788
  return grid
754
789
 
@@ -769,10 +804,7 @@ class Grid:
769
804
 
770
805
  @classmethod
771
806
  def from_yaml(
772
- cls,
773
- filepath: str | Path,
774
- section_name: str = "Grid",
775
- verbose: bool = False,
807
+ cls, filepath: str | Path, verbose: bool = False, **kwargs: Any
776
808
  ) -> "Grid":
777
809
  """Create an instance of the class from a YAML file.
778
810
 
@@ -780,10 +812,13 @@ class Grid:
780
812
  ----------
781
813
  filepath : Union[str, Path]
782
814
  The path to the YAML file from which the parameters will be read.
783
- section_name : str, optional
784
- The name of the YAML section containing the grid configuration. Defaults to "Grid".
785
815
  verbose : bool, optional
786
816
  Indicates whether to print grid generation steps with timing. Defaults to False.
817
+ **kwargs : Any
818
+ Additional keyword arguments:
819
+
820
+ - section_name : str, optional (default: "Grid")
821
+ The name of the YAML section containing the grid configuration.
787
822
 
788
823
  Returns
789
824
  -------
@@ -801,6 +836,8 @@ class Grid:
801
836
  Issues a warning if the ROMS-Tools version in the YAML header does not match the
802
837
  currently installed version.
803
838
  """
839
+ section_name: str = kwargs.pop("section_name", None) or "Grid"
840
+
804
841
  filepath = Path(filepath)
805
842
  # Read the entire file content
806
843
  with filepath.open("r") as file:
@@ -854,7 +891,7 @@ class Grid:
854
891
  attr_str = ", ".join(f"{k}={v!r}" for k, v in attr_dict.items())
855
892
  return f"{cls_name}({attr_str})"
856
893
 
857
- def _create_horizontal_grid(self) -> xr.Dataset():
894
+ def _create_horizontal_grid(self) -> xr.Dataset:
858
895
  """Create the horizontal grid based on a Mercator projection and store it in the
859
896
  'ds' attribute.
860
897
 
@@ -872,41 +909,32 @@ class Grid:
872
909
  - Longitude values are adjusted to fall within the range [0, 360].
873
910
  - Grid rotation and translation are applied based on the specified parameters.
874
911
  """
875
- if self.verbose:
876
- start_time = time.time()
877
- logging.info("=== Creating the horizontal grid ===")
878
-
879
- self._raise_if_domain_size_too_large()
912
+ with Timed("=== Creating the horizontal grid ===", verbose=self.verbose):
913
+ self._raise_if_domain_size_too_large()
880
914
 
881
- coords = self._make_initial_lon_lat_ds()
915
+ coords = self._make_initial_lon_lat_ds()
882
916
 
883
- # rotate coordinate system
884
- coords = _rotate(coords, self.rot)
917
+ # rotate coordinate system
918
+ coords = _rotate(coords, self.rot)
885
919
 
886
- # translate coordinate system
887
- coords = _translate(coords, self.center_lat, self.center_lon)
920
+ # translate coordinate system
921
+ coords = _translate(coords, self.center_lat, self.center_lon)
888
922
 
889
- # compute 1/dx and 1/dy
890
- coords["pm"], coords["pn"] = _compute_coordinate_metrics(coords)
923
+ # compute 1/dx and 1/dy
924
+ coords["pm"], coords["pn"] = _compute_coordinate_metrics(coords)
891
925
 
892
- # compute angle of local grid positive x-axis relative to east
893
- coords["angle"] = _compute_angle(coords)
926
+ # compute angle of local grid positive x-axis relative to east
927
+ coords["angle"] = _compute_angle(coords)
894
928
 
895
- # make sure lons are in [0, 360] range
896
- for lon in ["lon", "lonu", "lonv", "lonq"]:
897
- coords[lon][coords[lon] < 0] = coords[lon][coords[lon] < 0] + 2 * np.pi
929
+ # make sure lons are in [0, 360] range
930
+ for lon in ["lon", "lonu", "lonv", "lonq"]:
931
+ coords[lon][coords[lon] < 0] = coords[lon][coords[lon] < 0] + 2 * np.pi
898
932
 
899
- ds = self._create_grid_ds(coords)
933
+ ds = self._create_grid_ds(coords)
900
934
 
901
- ds = self._add_global_metadata(ds)
935
+ ds = self._add_global_metadata(ds)
902
936
 
903
- if self.verbose:
904
- logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
905
- logging.info(
906
- "========================================================================================================"
907
- )
908
-
909
- self.ds = ds
937
+ self.ds = ds
910
938
 
911
939
  def _add_global_metadata(self, ds):
912
940
  """Add global metadata and attributes to the dataset.