roms-tools 1.7.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. roms_tools/__init__.py +2 -1
  2. roms_tools/setup/boundary_forcing.py +246 -146
  3. roms_tools/setup/datasets.py +229 -69
  4. roms_tools/setup/download.py +13 -17
  5. roms_tools/setup/grid.py +777 -614
  6. roms_tools/setup/initial_conditions.py +168 -32
  7. roms_tools/setup/mask.py +115 -0
  8. roms_tools/setup/nesting.py +575 -0
  9. roms_tools/setup/plot.py +218 -63
  10. roms_tools/setup/regrid.py +4 -2
  11. roms_tools/setup/river_forcing.py +125 -29
  12. roms_tools/setup/surface_forcing.py +31 -25
  13. roms_tools/setup/tides.py +29 -14
  14. roms_tools/setup/topography.py +250 -153
  15. roms_tools/setup/utils.py +174 -44
  16. roms_tools/setup/vertical_coordinate.py +5 -16
  17. roms_tools/tests/test_setup/test_boundary_forcing.py +10 -5
  18. roms_tools/tests/test_setup/test_data/grid.zarr/.zattrs +0 -1
  19. roms_tools/tests/test_setup/test_data/grid.zarr/.zmetadata +56 -201
  20. roms_tools/tests/test_setup/test_data/grid.zarr/Cs_r/.zattrs +1 -1
  21. roms_tools/tests/test_setup/test_data/grid.zarr/Cs_w/.zattrs +1 -1
  22. roms_tools/tests/test_setup/test_data/grid.zarr/{layer_depth_rho → sigma_r}/.zarray +2 -6
  23. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_r/.zattrs +7 -0
  24. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_r/0 +0 -0
  25. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_w/.zarray +20 -0
  26. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_w/.zattrs +7 -0
  27. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_w/0 +0 -0
  28. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zattrs +1 -2
  29. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zmetadata +58 -203
  30. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/Cs_r/.zattrs +1 -1
  31. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/Cs_w/.zattrs +1 -1
  32. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/h/.zattrs +1 -1
  33. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/h/0.0 +0 -0
  34. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_coarse/0.0 +0 -0
  35. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_rho/0.0 +0 -0
  36. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_u/0.0 +0 -0
  37. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_v/0.0 +0 -0
  38. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_r/.zarray +20 -0
  39. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_r/.zattrs +7 -0
  40. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_r/0 +0 -0
  41. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_w/.zarray +20 -0
  42. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_w/.zattrs +7 -0
  43. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_w/0 +0 -0
  44. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -3
  45. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -2
  46. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/.zarray +1 -1
  47. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/0 +0 -0
  48. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zmetadata +5 -6
  49. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zarray +2 -2
  50. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zattrs +1 -2
  51. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/0.0.0 +0 -0
  52. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zarray +2 -2
  53. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_name/0 +0 -0
  54. roms_tools/tests/test_setup/test_datasets.py +2 -2
  55. roms_tools/tests/test_setup/test_grid.py +110 -12
  56. roms_tools/tests/test_setup/test_initial_conditions.py +2 -1
  57. roms_tools/tests/test_setup/test_nesting.py +489 -0
  58. roms_tools/tests/test_setup/test_river_forcing.py +53 -15
  59. roms_tools/tests/test_setup/test_surface_forcing.py +3 -22
  60. roms_tools/tests/test_setup/test_tides.py +2 -1
  61. roms_tools/tests/test_setup/test_topography.py +106 -1
  62. roms_tools/tests/test_setup/test_validation.py +2 -2
  63. {roms_tools-1.7.0.dist-info → roms_tools-2.1.0.dist-info}/LICENSE +1 -1
  64. {roms_tools-1.7.0.dist-info → roms_tools-2.1.0.dist-info}/METADATA +9 -4
  65. {roms_tools-1.7.0.dist-info → roms_tools-2.1.0.dist-info}/RECORD +85 -108
  66. {roms_tools-1.7.0.dist-info → roms_tools-2.1.0.dist-info}/WHEEL +1 -1
  67. roms_tools/_version.py +0 -2
  68. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_rho/.zarray +0 -24
  69. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_rho/.zattrs +0 -9
  70. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_rho/0.0.0 +0 -0
  71. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_u/.zarray +0 -24
  72. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_u/.zattrs +0 -9
  73. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_u/0.0.0 +0 -0
  74. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_v/.zarray +0 -24
  75. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_v/.zattrs +0 -9
  76. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_v/0.0.0 +0 -0
  77. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_rho/.zattrs +0 -9
  78. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_rho/0.0.0 +0 -0
  79. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/.zarray +0 -24
  80. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/.zattrs +0 -9
  81. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/0.0.0 +0 -0
  82. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/.zarray +0 -24
  83. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/.zattrs +0 -9
  84. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/0.0.0 +0 -0
  85. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/.zarray +0 -24
  86. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/.zattrs +0 -9
  87. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/0.0.0 +0 -0
  88. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/.zarray +0 -24
  89. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/.zattrs +0 -9
  90. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/0.0.0 +0 -0
  91. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/.zarray +0 -24
  92. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/.zattrs +0 -9
  93. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/0.0.0 +0 -0
  94. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/.zarray +0 -24
  95. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/.zattrs +0 -9
  96. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/0.0.0 +0 -0
  97. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/.zarray +0 -24
  98. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/.zattrs +0 -9
  99. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/0.0.0 +0 -0
  100. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/.zarray +0 -24
  101. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/.zattrs +0 -9
  102. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/0.0.0 +0 -0
  103. roms_tools/tests/test_setup/test_data/river_forcing.zarr/river_tracer/0.0.0 +0 -0
  104. roms_tools/tests/test_setup/test_data/river_forcing.zarr/tracer_name/0 +0 -0
  105. roms_tools/tests/test_setup/test_vertical_coordinate.py +0 -91
  106. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zattrs +0 -0
  107. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zgroup +0 -0
  108. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zarray +0 -0
  109. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zattrs +0 -0
  110. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/0 +0 -0
  111. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zarray +0 -0
  112. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zattrs +0 -0
  113. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/0 +0 -0
  114. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zarray +0 -0
  115. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zattrs +0 -0
  116. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/0 +0 -0
  117. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zarray +0 -0
  118. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zattrs +0 -0
  119. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/0 +0 -0
  120. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zarray +0 -0
  121. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zattrs +0 -0
  122. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/0.0 +0 -0
  123. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zattrs +0 -0
  124. {roms_tools-1.7.0.dist-info → roms_tools-2.1.0.dist-info}/top_level.txt +0 -0
roms_tools/setup/grid.py CHANGED
@@ -1,4 +1,5 @@
1
- import copy
1
+ import time
2
+ import logging
2
3
  from dataclasses import dataclass, field, asdict
3
4
 
4
5
  import numpy as np
@@ -6,24 +7,32 @@ import xarray as xr
6
7
  import matplotlib.pyplot as plt
7
8
  import yaml
8
9
  import importlib.metadata
9
-
10
- from typing import Union
11
- from roms_tools.setup.topography import _add_topography_and_mask, _add_velocity_masks
12
- from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
13
- from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
10
+ from typing import Dict, Union, List
11
+ from roms_tools.setup.topography import _add_topography
12
+ from roms_tools.setup.mask import _add_mask, _add_velocity_masks
13
+ from roms_tools.setup.plot import _plot, _section_plot
14
+ from roms_tools.setup.utils import (
15
+ interpolate_from_rho_to_u,
16
+ interpolate_from_rho_to_v,
17
+ get_target_coords,
18
+ gc_dist,
19
+ )
14
20
  from roms_tools.setup.vertical_coordinate import sigma_stretch, compute_depth
15
21
  from roms_tools.setup.utils import extract_single_value, save_datasets
16
- import logging
17
22
  from pathlib import Path
18
23
 
19
- RADIUS_OF_EARTH = 6371315.0 # in m
20
-
21
24
 
22
25
  @dataclass(frozen=True, kw_only=True)
23
26
  class Grid:
24
- """A single ROMS grid.
27
+ """A single ROMS grid, used for creating, plotting, and then saving a new ROMS
28
+ domain grid.
29
+
30
+ The grid generation consists of four steps:
25
31
 
26
- Used for creating, plotting, and then saving a new ROMS domain grid.
32
+ 1. Creating the horizontal grid
33
+ 2. Creating the mask
34
+ 3. Generating the topography
35
+ 4. Preparing the vertical coordinate system
27
36
 
28
37
  Parameters
29
38
  ----------
@@ -43,6 +52,15 @@ class Grid:
43
52
  Rotation of grid x-direction from lines of constant latitude, measured in degrees.
44
53
  Positive values represent a counterclockwise rotation.
45
54
  The default is 0, which means that the x-direction of the grid is aligned with lines of constant latitude.
55
+ topography_source : Dict[str, Union[str, Path]], optional
56
+ Dictionary specifying the source of the topography data:
57
+
58
+ - "name" (str): The name of the topography data source (e.g., "SRTM15").
59
+ - "path" (Union[str, Path, List[Union[str, Path]]]): The path to the raw data file. Can be a string or a Path object.
60
+
61
+ The default is "ETOPO5", which does not require a path.
62
+ hmin : float, optional
63
+ The minimum ocean depth (in meters). The default is 5.0.
46
64
  N : int, optional
47
65
  The number of vertical levels. The default is 100.
48
66
  theta_s : float, optional
@@ -51,11 +69,8 @@ class Grid:
51
69
  The bottom control parameter. Must satisfy 0 < theta_b <= 4. The default is 2.0.
52
70
  hc : float, optional
53
71
  The critical depth (in meters). The default is 300.0.
54
- topography_source : str, optional
55
- Specifies the data source to use for the topography. Options are
56
- "ETOPO5". The default is "ETOPO5".
57
- hmin : float, optional
58
- The minimum ocean depth (in meters). The default is 5.0.
72
+ verbose: bool, optional
73
+ Indicates whether to print grid generation steps with timing. Defaults to False.
59
74
 
60
75
  Raises
61
76
  ------
@@ -74,161 +89,166 @@ class Grid:
74
89
  theta_s: float = 5.0
75
90
  theta_b: float = 2.0
76
91
  hc: float = 300.0
77
- topography_source: str = "ETOPO5"
92
+ topography_source: Dict[str, Union[str, Path, List[Union[str, Path]]]] = None
78
93
  hmin: float = 5.0
94
+ verbose: bool = False
79
95
  ds: xr.Dataset = field(init=False, repr=False)
80
96
  straddle: bool = field(init=False, repr=False)
81
97
 
82
98
  def __post_init__(self):
83
- ds = _make_grid_ds(
84
- nx=self.nx,
85
- ny=self.ny,
86
- size_x=self.size_x,
87
- size_y=self.size_y,
88
- center_lon=self.center_lon,
89
- center_lat=self.center_lat,
90
- rot=self.rot,
91
- )
92
- # Calling object.__setattr__ is ugly but apparently this really is the best (current) way to combine __post_init__ with a frozen dataclass
93
- # see https://stackoverflow.com/questions/53756788/how-to-set-the-value-of-dataclass-field-in-post-init-when-frozen-true
94
- object.__setattr__(self, "ds", ds)
95
99
 
96
- # Update self.ds with topography and mask information
97
- self.update_topography_and_mask(
98
- topography_source=self.topography_source,
99
- hmin=self.hmin,
100
- )
100
+ self._input_checks()
101
+
102
+ # Horizontal grid
103
+ self._create_horizontal_grid()
101
104
 
102
105
  # Check if the Greenwich meridian goes through the domain.
103
106
  self._straddle()
104
107
 
105
- object.__setattr__(self, "ds", ds)
108
+ # Mask
109
+ self._create_mask(verbose=self.verbose)
106
110
 
107
- # Update the grid by adding grid variables that are coarsened versions of the original grid variables
111
+ # Coarsen the dataset if needed
108
112
  self._coarsen()
109
113
 
114
+ # Topography and mask
115
+ self.update_topography(
116
+ topography_source=self.topography_source,
117
+ hmin=self.hmin,
118
+ verbose=self.verbose,
119
+ )
120
+
121
+ # Vertical coordinate system
110
122
  self.update_vertical_coordinate(
111
- N=self.N, theta_s=self.theta_s, theta_b=self.theta_b, hc=self.hc
123
+ N=self.N,
124
+ theta_s=self.theta_s,
125
+ theta_b=self.theta_b,
126
+ hc=self.hc,
127
+ verbose=self.verbose,
112
128
  )
113
129
 
114
- def update_topography_and_mask(self, hmin, topography_source="ETOPO5") -> None:
115
- """Update the grid dataset by adding or overwriting the topography and land/sea
116
- mask.
130
+ def _input_checks(self):
131
+ if self.topography_source is None:
132
+ object.__setattr__(self, "topography_source", {"name": "ETOPO5"})
117
133
 
118
- This method processes the topography data and generates a land/sea mask.
119
- It applies several steps, including interpolating topography, smoothing
120
- the topography over the entire domain and locally, and filling in enclosed basins. The
121
- processed topography and mask are added to the grid's dataset as new variables.
134
+ if "name" not in self.topography_source:
135
+ raise ValueError(
136
+ "`topography_source` must include a 'name' key specifying the data source."
137
+ )
122
138
 
123
- Parameters
124
- ----------
125
- hmin : float
126
- The minimum ocean depth (in meters).
127
- topography_source : str
128
- Specifies the data source to use for the topography. Options are
129
- "ETOPO5". Default is "ETOPO5".
139
+ if self.topography_source["name"] != "ETOPO5":
140
+ if "path" not in self.topography_source:
141
+ raise ValueError(
142
+ "`topography_source` must include a 'path' key when the 'name' is not 'ETOPO5'."
143
+ )
130
144
 
131
- Returns
132
- -------
133
- None
134
- This method modifies the dataset in place and does not return a value.
135
- """
145
+ def _create_mask(self, verbose=False) -> None:
146
+
147
+ if verbose:
148
+ start_time = time.time()
149
+ logging.info("=== Creating the mask ===")
150
+ ds = _add_mask(self.ds)
151
+
152
+ if verbose:
153
+ logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
154
+ logging.info(
155
+ "========================================================================================================"
156
+ )
136
157
 
137
- ds = _add_topography_and_mask(self.ds, topography_source, hmin)
138
- # Assign the updated dataset back to the frozen dataclass
139
158
  object.__setattr__(self, "ds", ds)
140
- object.__setattr__(self, "topography_source", topography_source)
141
- object.__setattr__(self, "hmin", hmin)
142
159
 
143
- def _straddle(self) -> None:
144
- """Check if the Greenwich meridian goes through the domain.
160
+ def update_topography(
161
+ self, topography_source=None, hmin=None, verbose=False
162
+ ) -> None:
163
+ """Update the grid dataset with processed topography.
145
164
 
146
- This method sets the `straddle` attribute to `True` if the Greenwich meridian
147
- (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
148
- the `straddle` attribute to `False`.
165
+ This method performs several key operations, including regridding the topography, smoothing the
166
+ topography over the entire domain and locally.
167
+ The processed topography is then added to the grid's dataset.
149
168
 
150
- The check is based on whether the longitudinal differences between adjacent
151
- points exceed 300 degrees, indicating a potential wraparound of longitude.
152
- """
169
+ Parameters
170
+ ----------
171
+ topography_source : dict, optional
172
+ A dictionary specifying the source of the topography data. The dictionary should
173
+ contain the following keys:
174
+ - "name" (str): The name of the topography data source (e.g., "SRTM15").
175
+ - "path" (Union[str, Path): The path to the raw data file.
153
176
 
154
- if (
155
- np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
156
- or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
157
- ):
158
- object.__setattr__(self, "straddle", True)
159
- else:
160
- object.__setattr__(self, "straddle", False)
177
+ If not provided, `topography_source` will remain unchanged (i.e., the existing value will not be overwritten).
161
178
 
162
- def _coarsen(self):
163
- """Update the grid by adding grid variables that are coarsened versions of the
164
- original fine-resoluion grid variables. The coarsening is by a factor of two.
179
+ hmin : float, optional
180
+ The minimum ocean depth (in meters).
181
+ If not provided, `hmin` will remain unchanged (i.e., the existing value will not be overwritten).
165
182
 
166
- The specific variables being coarsened are:
167
- - `lon_rho` -> `lon_coarse`: Longitude at rho points.
168
- - `lat_rho` -> `lat_coarse`: Latitude at rho points.
169
- - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
170
- - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
183
+ verbose : bool, optional
184
+ If True, the method will print detailed information about the grid generation process,
185
+ including the timing of each step. Defaults to False.
171
186
 
172
187
  Returns
173
188
  -------
174
189
  None
175
-
176
- Modifies
177
- --------
178
- self.ds : xr.Dataset
179
- The dataset attribute of the Grid instance is updated with the new coarser variables.
190
+ This method updates the internal dataset (`self.ds`) in place by adding or overwriting the
191
+ topography variable. It does not return any value.
180
192
  """
181
- d = {
182
- "angle": "angle_coarse",
183
- "mask_rho": "mask_coarse",
184
- "lat_rho": "lat_coarse",
185
- "lon_rho": "lon_coarse",
186
- }
187
193
 
188
- for fine_var, coarse_var in d.items():
189
- fine_field = self.ds[fine_var]
190
- if self.straddle and fine_var == "lon_rho":
191
- fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
194
+ topography_source = topography_source or self.topography_source
195
+ hmin = hmin or self.hmin
192
196
 
193
- coarse_field = _f2c(fine_field)
194
- if fine_var == "lon_rho":
195
- coarse_field = xr.where(
196
- coarse_field < 0, coarse_field + 360, coarse_field
197
- )
198
- if coarse_var in ["lon_coarse", "lat_coarse"]:
199
- ds = self.ds.assign_coords({coarse_var: coarse_field})
200
- object.__setattr__(self, "ds", ds)
201
- else:
202
- self.ds[coarse_var] = coarse_field
197
+ # Extract target coordinates for processing
198
+ target_coords = get_target_coords(self)
203
199
 
204
- self.ds["mask_coarse"] = xr.where(self.ds["mask_coarse"] > 0.5, 1, 0).astype(
205
- np.int32
200
+ # If verbose is enabled, start the timer and print the start message
201
+ if verbose:
202
+ start_time = time.time()
203
+ logging.info(
204
+ f"=== Generating the topography using {topography_source['name']} data and hmin = {hmin} meters ==="
205
+ )
206
+
207
+ # Add topography and mask to the dataset
208
+ ds = _add_topography(
209
+ ds=self.ds,
210
+ target_coords=target_coords,
211
+ topography_source=topography_source,
212
+ hmin=hmin,
213
+ verbose=verbose,
206
214
  )
207
215
 
208
- for fine_var, coarse_var in d.items():
209
- self.ds[coarse_var].attrs[
210
- "long_name"
211
- ] = f"{self.ds[fine_var].attrs['long_name']} on coarsened grid"
212
- self.ds[coarse_var].attrs["units"] = self.ds[fine_var].attrs["units"]
216
+ # If verbose is enabled, print elapsed time and a separator
217
+ if verbose:
218
+ logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
219
+ logging.info(
220
+ "========================================================================================================"
221
+ )
213
222
 
214
- def update_vertical_coordinate(self, N, theta_s, theta_b, hc) -> None:
223
+ # Update the grid's dataset and related attributes
224
+ object.__setattr__(self, "ds", ds)
225
+ object.__setattr__(self, "topography_source", topography_source)
226
+ object.__setattr__(self, "hmin", hmin)
227
+
228
+ def update_vertical_coordinate(
229
+ self, N=None, theta_s=None, theta_b=None, hc=None, verbose=False
230
+ ) -> None:
215
231
  """Create vertical coordinate variables for the ROMS grid.
216
232
 
217
- This method computes the S-coordinate stretching curves and depths
218
- at various grid points (rho, u, v) using the specified parameters.
219
- The computed depths and stretching curves are added to the dataset
220
- as new coordinates, along with their corresponding attributes.
233
+ This method computes the S-coordinate stretching curves at rho- and
234
+ w-points.
221
235
 
222
236
  Parameters
223
237
  ----------
224
238
  N : int
225
239
  Number of vertical levels.
240
+ If not provided, `N` will remain unchanged (i.e., the existing value will not be overwritten).
226
241
  theta_s : float
227
242
  S-coordinate surface control parameter.
243
+ If not provided, `theta_s` will remain unchanged (i.e., the existing value will not be overwritten).
228
244
  theta_b : float
229
245
  S-coordinate bottom control parameter.
246
+ If not provided, `theta_b` will remain unchanged (i.e., the existing value will not be overwritten).
230
247
  hc : float
231
248
  Critical depth (m) used in ROMS vertical coordinate stretching.
249
+ If not provided, `hc` will remain unchanged (i.e., the existing value will not be overwritten).
250
+ verbose: bool, optional
251
+ Indicates whether to print vertical coordinate generation steps with timing. Defaults to False.
232
252
 
233
253
  Returns
234
254
  -------
@@ -236,6 +256,17 @@ class Grid:
236
256
  This method modifies the dataset in place by adding vertical coordinate variables.
237
257
  """
238
258
 
259
+ N = N or self.N
260
+ theta_s = theta_s or self.theta_s
261
+ theta_b = theta_b or self.theta_b
262
+ hc = hc or self.hc
263
+
264
+ if verbose:
265
+ start_time = time.time()
266
+ logging.info(
267
+ f"=== Preparing the vertical coordinate system using N = {N}, theta_s = {theta_s}, theta_b = {theta_b}, hc = {hc} ==="
268
+ )
269
+
239
270
  ds = self.ds
240
271
  # need to drop vertical coordinates because they could cause conflict if N changed
241
272
  vars_to_drop = [
@@ -245,6 +276,8 @@ class Grid:
245
276
  "interface_depth_rho",
246
277
  "interface_depth_u",
247
278
  "interface_depth_v",
279
+ "sigma_r",
280
+ "sigma_w",
248
281
  "Cs_w",
249
282
  "Cs_r",
250
283
  ]
@@ -253,74 +286,125 @@ class Grid:
253
286
  if var in ds.variables:
254
287
  ds = ds.drop_vars(var)
255
288
 
256
- h = ds.h
257
-
258
289
  cs_r, sigma_r = sigma_stretch(theta_s, theta_b, N, "r")
259
- zr = compute_depth(h * 0, h, hc, cs_r, sigma_r)
260
290
  cs_w, sigma_w = sigma_stretch(theta_s, theta_b, N, "w")
261
- zw = compute_depth(h * 0, h, hc, cs_w, sigma_w)
291
+
292
+ ds["sigma_r"] = sigma_r.astype(np.float32)
293
+ ds["sigma_r"].attrs[
294
+ "long_name"
295
+ ] = "Fractional vertical stretching coordinate at rho-points"
296
+ ds["sigma_r"].attrs["units"] = "nondimensional"
262
297
 
263
298
  ds["Cs_r"] = cs_r.astype(np.float32)
264
- ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
299
+ ds["Cs_r"].attrs["long_name"] = "Vertical stretching function at rho-points"
265
300
  ds["Cs_r"].attrs["units"] = "nondimensional"
266
301
 
302
+ ds["sigma_w"] = sigma_w.astype(np.float32)
303
+ ds["sigma_w"].attrs[
304
+ "long_name"
305
+ ] = "Fractional vertical stretching coordinate at w-points"
306
+ ds["sigma_w"].attrs["units"] = "nondimensional"
307
+
267
308
  ds["Cs_w"] = cs_w.astype(np.float32)
268
- ds["Cs_w"].attrs["long_name"] = "S-coordinate stretching curves at w-points"
309
+ ds["Cs_w"].attrs["long_name"] = "Vertical stretching function at w-points"
269
310
  ds["Cs_w"].attrs["units"] = "nondimensional"
270
311
 
271
312
  ds.attrs["theta_s"] = np.float32(theta_s)
272
313
  ds.attrs["theta_b"] = np.float32(theta_b)
273
314
  ds.attrs["hc"] = np.float32(hc)
274
315
 
275
- depth = -zr
276
- depth.attrs["long_name"] = "Layer depth at rho-points"
277
- depth.attrs["units"] = "m"
316
+ if verbose:
317
+ logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
318
+ logging.info(
319
+ "========================================================================================================"
320
+ )
278
321
 
279
- depth_u = interpolate_from_rho_to_u(depth)
280
- depth_u.attrs["long_name"] = "Layer depth at u-points"
281
- depth_u.attrs["units"] = "m"
322
+ object.__setattr__(self, "ds", ds)
323
+ object.__setattr__(self, "theta_s", theta_s)
324
+ object.__setattr__(self, "theta_b", theta_b)
325
+ object.__setattr__(self, "hc", hc)
326
+ object.__setattr__(self, "N", N)
282
327
 
283
- depth_v = interpolate_from_rho_to_v(depth)
284
- depth_v.attrs["long_name"] = "Layer depth at v-points"
285
- depth_v.attrs["units"] = "m"
328
+ def _straddle(self) -> None:
329
+ """Check if the Greenwich meridian goes through the domain.
286
330
 
287
- interface_depth = -zw
288
- interface_depth.attrs["long_name"] = "Interface depth at rho-points"
289
- interface_depth.attrs["units"] = "m"
331
+ This method sets the `straddle` attribute to `True` if the Greenwich meridian
332
+ (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
333
+ the `straddle` attribute to `False`.
334
+
335
+ The check is based on whether the longitudinal differences between adjacent
336
+ points exceed 300 degrees, indicating a potential wraparound of longitude.
337
+ """
338
+
339
+ if (
340
+ np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
341
+ or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
342
+ ):
343
+ object.__setattr__(self, "straddle", True)
344
+ else:
345
+ object.__setattr__(self, "straddle", False)
290
346
 
291
- interface_depth_u = interpolate_from_rho_to_u(interface_depth)
292
- interface_depth_u.attrs["long_name"] = "Interface depth at u-points"
293
- interface_depth_u.attrs["units"] = "m"
347
+ def _coarsen(self):
348
+ """Update the grid by adding grid variables that are coarsened versions of the
349
+ original fine-resoluion grid variables. The coarsening is by a factor of two.
294
350
 
295
- interface_depth_v = interpolate_from_rho_to_v(interface_depth)
296
- interface_depth_v.attrs["long_name"] = "Interface depth at v-points"
297
- interface_depth_v.attrs["units"] = "m"
351
+ The specific variables being coarsened are:
352
+ - `lon_rho` -> `lon_coarse`: Longitude at rho points.
353
+ - `lat_rho` -> `lat_coarse`: Latitude at rho points.
354
+ - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
355
+ - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
356
+ """
357
+ d = {
358
+ "angle": "angle_coarse",
359
+ "mask_rho": "mask_coarse",
360
+ "lat_rho": "lat_coarse",
361
+ "lon_rho": "lon_coarse",
362
+ }
298
363
 
299
- ds = ds.assign_coords(
300
- {
301
- "layer_depth_rho": depth.astype(np.float32),
302
- "layer_depth_u": depth_u.astype(np.float32),
303
- "layer_depth_v": depth_v.astype(np.float32),
304
- "interface_depth_rho": interface_depth.astype(np.float32),
305
- "interface_depth_u": interface_depth_u.astype(np.float32),
306
- "interface_depth_v": interface_depth_v.astype(np.float32),
307
- }
308
- )
309
- ds = ds.drop_vars(["eta_rho", "xi_rho"])
364
+ ds = self.ds
365
+
366
+ for fine_var, coarse_var in d.items():
367
+ fine_field = ds[fine_var]
368
+ if self.straddle and fine_var == "lon_rho":
369
+ fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
370
+
371
+ coarse_field = _f2c(fine_field)
372
+ if fine_var == "lon_rho":
373
+ coarse_field = xr.where(
374
+ coarse_field < 0, coarse_field + 360, coarse_field
375
+ )
376
+ if coarse_var in ["lon_coarse", "lat_coarse"]:
377
+ ds = ds.assign_coords({coarse_var: coarse_field})
378
+ else:
379
+ ds[coarse_var] = coarse_field
380
+
381
+ del fine_field, coarse_field
382
+
383
+ ds["mask_coarse"] = xr.where(ds["mask_coarse"] > 0.5, 1, 0).astype(np.int32)
384
+
385
+ for fine_var, coarse_var in d.items():
386
+ long_name = ds[fine_var].attrs.get(
387
+ "long_name", ds[fine_var].attrs.get("Long_name", "")
388
+ )
389
+ ds[coarse_var].attrs["long_name"] = f"{long_name} on coarsened grid"
390
+ ds[coarse_var].attrs["units"] = ds[fine_var].attrs["units"]
310
391
 
311
392
  object.__setattr__(self, "ds", ds)
312
- object.__setattr__(self, "theta_s", theta_s)
313
- object.__setattr__(self, "theta_b", theta_b)
314
- object.__setattr__(self, "hc", hc)
315
- object.__setattr__(self, "N", N)
316
393
 
317
- def plot(self, bathymetry: bool = False) -> None:
394
+ def plot(
395
+ self, bathymetry: bool = False, title: str = None, with_dim_names: bool = False
396
+ ) -> None:
318
397
  """Plot the grid.
319
398
 
320
399
  Parameters
321
400
  ----------
322
- bathymetry : bool
401
+ bathymetry : bool, optional
323
402
  Whether or not to plot the bathymetry. Default is False.
403
+ title : str, optional
404
+ The title of the plot. If not provided, it will be set to a default.
405
+ with_dim_names : bool, optional
406
+ Whether or not to plot the dimension names. Default is False.
407
+
324
408
 
325
409
  Returns
326
410
  -------
@@ -329,6 +413,8 @@ class Grid:
329
413
  """
330
414
 
331
415
  if bathymetry:
416
+ if title is None:
417
+ title = "ROMS grid and bathymetry"
332
418
  field = self.ds.h.where(self.ds.mask_rho)
333
419
  field = field.assign_coords(
334
420
  {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
@@ -344,28 +430,30 @@ class Grid:
344
430
  self.ds,
345
431
  field=field,
346
432
  straddle=self.straddle,
433
+ title=title,
434
+ with_dim_names=with_dim_names,
347
435
  kwargs=kwargs,
348
436
  )
349
437
  else:
350
- _plot(self.ds, straddle=self.straddle)
438
+ if title is None:
439
+ title = "ROMS grid"
440
+ _plot(
441
+ self.ds,
442
+ straddle=self.straddle,
443
+ title=title,
444
+ with_dim_names=with_dim_names,
445
+ )
351
446
 
352
447
  def plot_vertical_coordinate(
353
- self, varname="layer_depth_rho", s=None, eta=None, xi=None, ax=None
448
+ self,
449
+ s=None,
450
+ eta=None,
451
+ xi=None,
354
452
  ) -> None:
355
- """Plot the vertical coordinate system for a given eta-, xi-, or s-slice.
453
+ """Plot the layer depth for a given eta-, xi-, or s-slice.
356
454
 
357
455
  Parameters
358
456
  ----------
359
- varname : str, optional
360
- The vertical coordinate field to plot. Options include:
361
-
362
- - "layer_depth_rho": Layer depth at rho-points.
363
- - "layer_depth_u": Layer depth at u-points.
364
- - "layer_depth_v": Layer depth at v-points.
365
- - "interface_depth_rho": Interface depth at rho-points.
366
- - "interface_depth_u": Interface depth at u-points.
367
- - "interface_depth_v": Interface depth at v-points.
368
-
369
457
  s: int, optional
370
458
  The s-index to plot. Default is None.
371
459
  eta : int, optional
@@ -383,105 +471,67 @@ class Grid:
383
471
  Raises
384
472
  ------
385
473
  ValueError
386
- If the specified varname is not one of the valid options.
387
- If none of s, eta, xi are specified.
474
+ If not exactly one of s, eta, xi is specified.
388
475
  """
389
476
 
390
- if not any([s is not None, eta is not None, xi is not None]):
391
- raise ValueError("At least one of s, eta, or xi must be specified.")
477
+ title = "Layer depth at rho-points"
392
478
 
393
- self.ds[varname].load()
394
- field = self.ds[varname].squeeze()
479
+ if sum(s is not None for s in [s, eta, xi]) != 1:
480
+ raise ValueError("Exactly one of s, eta, or xi must be specified.")
395
481
 
396
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
397
- interface_depth = self.ds.interface_depth_rho
398
- field = field.where(self.ds.mask_rho)
399
- field = field.assign_coords(
400
- {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
401
- )
402
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
403
- interface_depth = self.ds.interface_depth_u
404
- field = field.where(self.ds.mask_u)
405
- field = field.assign_coords({"lon": self.ds.lon_u, "lat": self.ds.lat_u})
406
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
407
- interface_depth = self.ds.interface_depth_v
408
- field = field.where(self.ds.mask_v)
409
- field = field.assign_coords({"lon": self.ds.lon_v, "lat": self.ds.lat_v})
410
-
411
- # slice the field as desired
412
- title = field.long_name
413
- if s is not None:
414
- if "s_rho" in field.dims:
415
- title = title + f", s_rho = {field.s_rho[s].item()}"
416
- field = field.isel(s_rho=s)
417
- elif "s_w" in field.dims:
418
- title = title + f", s_w = {field.s_w[s].item()}"
419
- field = field.isel(s_w=s)
420
- else:
421
- raise ValueError(
422
- f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
423
- )
482
+ h = self.ds["h"]
483
+ h = h.assign_coords({"lon": self.ds.lon_rho, "lat": self.ds.lat_rho})
424
484
 
485
+ # slice the bathymetry as desired
425
486
  if eta is not None:
426
- if "eta_rho" in field.dims:
427
- title = title + f", eta_rho = {field.eta_rho[eta].item()}"
428
- field = field.isel(eta_rho=eta)
429
- interface_depth = interface_depth.isel(eta_rho=eta)
430
- elif "eta_v" in field.dims:
431
- title = title + f", eta_v = {field.eta_v[eta].item()}"
432
- field = field.isel(eta_v=eta)
433
- interface_depth = interface_depth.isel(eta_v=eta)
434
- else:
435
- raise ValueError(
436
- f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
437
- )
487
+ title = title + f", eta_rho = {h.eta_rho[eta].item()}"
488
+ h = h.isel(eta_rho=eta)
438
489
  if xi is not None:
439
- if "xi_rho" in field.dims:
440
- title = title + f", xi_rho = {field.xi_rho[xi].item()}"
441
- field = field.isel(xi_rho=xi)
442
- interface_depth = interface_depth.isel(xi_rho=xi)
443
- elif "xi_u" in field.dims:
444
- title = title + f", xi_u = {field.xi_u[xi].item()}"
445
- field = field.isel(xi_u=xi)
446
- interface_depth = interface_depth.isel(xi_u=xi)
447
- else:
448
- raise ValueError(
449
- f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
450
- )
490
+ title = title + f", xi_rho = {h.xi_rho[xi].item()}"
491
+ h = h.isel(xi_rho=xi)
451
492
 
452
493
  if eta is None and xi is None:
453
- vmax = field.max().values
454
- vmin = field.min().values
494
+ layer_depth = compute_depth(0, h, self.hc, self.ds.Cs_r, self.ds.sigma_r)
495
+ title = title + f", s_rho = {layer_depth.s_rho[s].item()}"
496
+ layer_depth = layer_depth.isel(s_rho=s)
497
+
498
+ layer_depth.attrs["long_name"] = "Layer depth"
499
+ layer_depth.attrs["units"] = "m"
500
+
501
+ vmax = layer_depth.max().values
502
+ vmin = layer_depth.min().values
455
503
  cmap = plt.colormaps.get_cmap("YlGnBu")
456
504
  cmap.set_bad(color="gray")
457
505
  kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
458
506
 
459
507
  _plot(
460
508
  self.ds,
461
- field=field,
509
+ field=layer_depth.where(self.ds.mask_rho),
462
510
  straddle=self.straddle,
463
511
  depth_contours=False,
464
512
  title=title,
465
513
  kwargs=kwargs,
466
514
  )
467
515
  else:
468
- if len(field.dims) == 2:
469
- cmap = plt.colormaps.get_cmap("YlGnBu")
470
- cmap.set_bad(color="gray")
471
- kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}
472
-
473
- _section_plot(
474
- xr.zeros_like(field),
475
- interface_depth=interface_depth,
476
- title=title,
477
- kwargs=kwargs,
478
- ax=ax,
479
- )
480
- else:
481
- if "s_rho" in field.dims or "s_w" in field.dims:
482
- _profile_plot(field, title=title, ax=ax)
483
- else:
484
- _line_plot(field, title=title, ax=ax)
516
+ layer_depth = compute_depth(0, h, self.hc, self.ds.Cs_r, self.ds.sigma_r)
517
+ layer_depth.attrs["long_name"] = "Layer depth"
518
+ layer_depth.attrs["units"] = "m"
519
+ field = xr.zeros_like(layer_depth)
520
+ field = field.assign_coords({"layer_depth": layer_depth})
521
+
522
+ interface_depth = compute_depth(
523
+ 0, h, self.hc, self.ds.Cs_w, self.ds.sigma_w
524
+ )
525
+ cmap = plt.colormaps.get_cmap("YlGnBu")
526
+ cmap.set_bad(color="gray")
527
+ kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}
528
+
529
+ _section_plot(
530
+ field=field,
531
+ interface_depth=interface_depth,
532
+ title=title,
533
+ kwargs=kwargs,
534
+ )
485
535
 
486
536
  def save(
487
537
  self, filepath: Union[str, Path], np_eta: int = None, np_xi: int = None
@@ -532,13 +582,15 @@ class Grid:
532
582
  return saved_filenames
533
583
 
534
584
  @classmethod
535
- def from_file(cls, filepath: Union[str, Path]) -> "Grid":
585
+ def from_file(cls, filepath: Union[str, Path], verbose: bool = False) -> "Grid":
536
586
  """Create a Grid instance from an existing file.
537
587
 
538
588
  Parameters
539
589
  ----------
540
590
  filepath : Union[str, Path]
541
591
  Path to the file containing the grid information.
592
+ verbose: bool, optional
593
+ Indicates whether to print grid generation steps with timing. Defaults to False.
542
594
 
543
595
  Returns
544
596
  -------
@@ -584,13 +636,14 @@ class Grid:
584
636
 
585
637
  # Update vertical coordinate if necessary
586
638
  if not all(var in grid.ds for var in ["Cs_r", "Cs_w"]):
639
+ logging.warning("Vertical coordinates (Cs_r, Cs_w) not found in grid file.")
587
640
  N = 100
588
641
  theta_s = 5.0
589
642
  theta_b = 2.0
590
643
  hc = 300.0
591
644
 
592
645
  grid.update_vertical_coordinate(
593
- N=N, theta_s=theta_s, theta_b=theta_b, hc=hc
646
+ N=N, theta_s=theta_s, theta_b=theta_b, hc=hc, verbose=True
594
647
  )
595
648
  else:
596
649
  object.__setattr__(grid, "theta_s", ds.attrs["theta_s"].item())
@@ -639,7 +692,10 @@ class Grid:
639
692
  "hmin",
640
693
  ]:
641
694
  if attr in ds.attrs:
642
- a = ds.attrs[attr]
695
+ if attr == "topography_source":
696
+ a = {"name": ds.attrs[attr]}
697
+ else:
698
+ a = ds.attrs[attr]
643
699
  else:
644
700
  a = None
645
701
  object.__setattr__(grid, attr, a)
@@ -661,6 +717,7 @@ class Grid:
661
717
  data = asdict(self)
662
718
  data.pop("ds", None)
663
719
  data.pop("straddle", None)
720
+ data.pop("verbose", None)
664
721
 
665
722
  # Include the version of roms-tools
666
723
  try:
@@ -681,18 +738,38 @@ class Grid:
681
738
  yaml.dump(yaml_data, file, default_flow_style=False, sort_keys=False)
682
739
 
683
740
  @classmethod
684
- def from_yaml(cls, filepath: Union[str, Path]) -> "Grid":
741
+ def from_yaml(
742
+ cls,
743
+ filepath: Union[str, Path],
744
+ section_name: str = "Grid",
745
+ verbose: bool = False,
746
+ ) -> "Grid":
685
747
  """Create an instance of the class from a YAML file.
686
748
 
687
749
  Parameters
688
750
  ----------
689
751
  filepath : Union[str, Path]
690
752
  The path to the YAML file from which the parameters will be read.
753
+ section_name : str, optional
754
+ The name of the YAML section containing the grid configuration. Defaults to "Grid".
755
+ verbose : bool, optional
756
+ Indicates whether to print grid generation steps with timing. Defaults to False.
691
757
 
692
758
  Returns
693
759
  -------
694
760
  Grid
695
- An instance of the Grid class.
761
+ An instance of the Grid class initialized with the parameters from the YAML file.
762
+
763
+ Raises
764
+ ------
765
+ ValueError
766
+ If the ROMS-Tools version is not found in the YAML file or if the specified section
767
+ does not exist in the file.
768
+
769
+ Warnings
770
+ --------
771
+ Issues a warning if the ROMS-Tools version in the YAML header does not match the
772
+ currently installed version.
696
773
  """
697
774
 
698
775
  filepath = Path(filepath)
@@ -712,8 +789,8 @@ class Grid:
712
789
  continue
713
790
  if "roms_tools_version" in doc:
714
791
  header_data = doc
715
- elif "Grid" in doc:
716
- grid_data = doc["Grid"]
792
+ elif section_name in doc:
793
+ grid_data = doc[section_name]
717
794
 
718
795
  if header_data is None:
719
796
  raise ValueError("Version of ROMS-Tools not found in the YAML file.")
@@ -733,164 +810,341 @@ class Grid:
733
810
 
734
811
  if grid_data is None:
735
812
  raise ValueError("No Grid configuration found in the YAML file.")
813
+ return cls(**grid_data, verbose=verbose)
736
814
 
737
- return cls(**grid_data)
738
-
739
- # override __repr__ method to only print attributes that are actually set
740
815
  def __repr__(self) -> str:
816
+ """Return a string representation of the object with non-None attributes,
817
+ excluding 'ds'."""
741
818
  cls = self.__class__
742
819
  cls_name = cls.__name__
743
- # Create a dictionary of attribute names and values, filtering out those that are not set and 'ds'
820
+ # Filter attributes to exclude 'ds' and those with None values
744
821
  attr_dict = {
745
822
  k: v for k, v in self.__dict__.items() if k != "ds" and v is not None
746
823
  }
747
824
  attr_str = ", ".join(f"{k}={v!r}" for k, v in attr_dict.items())
748
825
  return f"{cls_name}({attr_str})"
749
826
 
827
+ def _create_horizontal_grid(self) -> xr.Dataset():
828
+ """Create the horizontal grid based on a Mercator projection and store it in the
829
+ 'ds' attribute.
750
830
 
751
- def _make_grid_ds(
752
- nx: int,
753
- ny: int,
754
- size_x: float,
755
- size_y: float,
756
- center_lon: float,
757
- center_lat: float,
758
- rot: float,
759
- ) -> xr.Dataset:
760
- _raise_if_domain_size_too_large(size_x, size_y)
761
-
762
- initial_lon_lat_vars = _make_initial_lon_lat_ds(size_x, size_y, nx, ny)
763
-
764
- # rotate coordinate system
765
- rotated_lon_lat_vars = _rotate(*initial_lon_lat_vars, rot)
766
-
767
- # translate coordinate system
768
- translated_lon_lat_vars = _translate(*rotated_lon_lat_vars, center_lat, center_lon)
769
- lon, lat, lonu, latu, lonv, latv, lonq, latq = translated_lon_lat_vars
770
-
771
- # compute 1/dx and 1/dy
772
- pm, pn = _compute_coordinate_metrics(lon, lonu, latu, lonv, latv)
773
-
774
- # compute angle of local grid positive x-axis relative to east
775
- ang = _compute_angle(lon, lonu, latu, lonq)
776
-
777
- # make sure lons are in [0, 360] range
778
- lon[lon < 0] = lon[lon < 0] + 2 * np.pi
779
- lonu[lonu < 0] = lonu[lonu < 0] + 2 * np.pi
780
- lonv[lonv < 0] = lonv[lonv < 0] + 2 * np.pi
781
- lonq[lonq < 0] = lonq[lonq < 0] + 2 * np.pi
782
-
783
- ds = _create_grid_ds(
784
- lon,
785
- lat,
786
- lonu,
787
- latu,
788
- lonv,
789
- latv,
790
- lonq,
791
- latq,
792
- pm,
793
- pn,
794
- ang,
795
- rot,
796
- center_lon,
797
- center_lat,
798
- )
831
+ Parameters
832
+ ----------
833
+ None
799
834
 
800
- ds = _add_global_metadata(ds, size_x, size_y, center_lon, center_lat, rot)
835
+ Returns
836
+ -------
837
+ xr.Dataset
838
+ The created horizontal grid dataset, including coordinates, grid metrics, angles, and metadata.
801
839
 
802
- return ds
840
+ Notes
841
+ -----
842
+ - Longitude values are adjusted to fall within the range [0, 360].
843
+ - Grid rotation and translation are applied based on the specified parameters.
844
+ """
845
+ if self.verbose:
846
+ start_time = time.time()
847
+ logging.info("=== Creating the horizontal grid ===")
803
848
 
849
+ self._raise_if_domain_size_too_large()
804
850
 
805
- def _raise_if_domain_size_too_large(size_x, size_y):
806
- threshold = 20000
807
- if size_x > threshold or size_y > threshold:
808
- raise ValueError("Domain size has to be smaller than %g km" % threshold)
851
+ coords = self._make_initial_lon_lat_ds()
809
852
 
853
+ # rotate coordinate system
854
+ coords = _rotate(coords, self.rot)
810
855
 
811
- def _make_initial_lon_lat_ds(size_x, size_y, nx, ny):
812
- # Mercator projection around the equator
856
+ # translate coordinate system
857
+ coords = _translate(coords, self.center_lat, self.center_lon)
813
858
 
814
- # initially define the domain to be longer in x-direction (dimension "length")
815
- # than in y-direction (dimension "width") to keep grid distortion minimal
816
- if size_y > size_x:
817
- domain_length, domain_width = size_y * 1e3, size_x * 1e3 # in m
818
- nl, nw = ny, nx
819
- else:
820
- domain_length, domain_width = size_x * 1e3, size_y * 1e3 # in m
821
- nl, nw = nx, ny
859
+ # compute 1/dx and 1/dy
860
+ coords["pm"], coords["pn"] = _compute_coordinate_metrics(coords)
822
861
 
823
- domain_length_in_degrees = domain_length / RADIUS_OF_EARTH
824
- domain_width_in_degrees = domain_width / RADIUS_OF_EARTH
862
+ # compute angle of local grid positive x-axis relative to east
863
+ coords["angle"] = _compute_angle(coords)
825
864
 
826
- # 1d array describing the longitudes at cell centers
827
- x = np.arange(-0.5, nl + 1.5, 1)
828
- lon_array_1d_in_degrees = (
829
- domain_length_in_degrees * x / nl - domain_length_in_degrees / 2
830
- )
831
- # 1d array describing the longitudes at cell corners (or vorticity points "q")
832
- xq = np.arange(-1, nl + 2, 1)
833
- lonq_array_1d_in_degrees_q = (
834
- domain_length_in_degrees * xq / nl - domain_length_in_degrees / 2
835
- )
865
+ # make sure lons are in [0, 360] range
866
+ for lon in ["lon", "lonu", "lonv", "lonq"]:
867
+ coords[lon][coords[lon] < 0] = coords[lon][coords[lon] < 0] + 2 * np.pi
868
+
869
+ ds = self._create_grid_ds(coords)
870
+
871
+ ds = self._add_global_metadata(ds)
872
+
873
+ if self.verbose:
874
+ logging.info(f"Total time: {time.time() - start_time:.3f} seconds")
875
+ logging.info(
876
+ "========================================================================================================"
877
+ )
878
+
879
+ object.__setattr__(self, "ds", ds)
880
+
881
+ def _add_global_metadata(self, ds):
882
+ """Add global metadata and attributes to the dataset.
883
+
884
+ Parameters
885
+ ----------
886
+ ds : xr.Dataset
887
+ Dataset to which global metadata and attributes will be added.
888
+
889
+ Returns
890
+ -------
891
+ xr.Dataset
892
+ The dataset with added global metadata, including grid type, tool version,
893
+ grid dimensions, center coordinates, and rotation.
894
+
895
+ Notes
896
+ -----
897
+ - The "spherical" attribute indicates the grid type and is set to "T" (spherical).
898
+ - The ROMS-Tools version is included as "roms_tools_version". If unavailable, it defaults to "unknown".
899
+ """
900
+ ds["spherical"] = xr.DataArray(np.array("T", dtype="S1"))
901
+ ds["spherical"].attrs["Long_name"] = "Grid type logical switch"
902
+ ds["spherical"].attrs["option_T"] = "spherical"
836
903
 
837
- # convert degrees latitude to y-coordinate using Mercator projection
838
- y1 = np.log(np.tan(np.pi / 4 - domain_width_in_degrees / 4))
839
- y2 = np.log(np.tan(np.pi / 4 + domain_width_in_degrees / 4))
904
+ ds.attrs["title"] = "ROMS grid created by ROMS-Tools"
840
905
 
841
- # linearly space points in y-space
842
- y = (y2 - y1) * np.arange(-0.5, nw + 1.5, 1) / nw + y1
843
- yq = (y2 - y1) * np.arange(-1, nw + 2) / nw + y1
906
+ # Include the version of roms-tools
907
+ try:
908
+ roms_tools_version = importlib.metadata.version("roms-tools")
909
+ except importlib.metadata.PackageNotFoundError:
910
+ roms_tools_version = "unknown"
911
+
912
+ ds.attrs["roms_tools_version"] = roms_tools_version
913
+ ds.attrs["size_x"] = self.size_x
914
+ ds.attrs["size_y"] = self.size_y
915
+ ds.attrs["center_lon"] = self.center_lon
916
+ ds.attrs["center_lat"] = self.center_lat
917
+ ds.attrs["rot"] = self.rot
918
+
919
+ return ds
920
+
921
+ def _raise_if_domain_size_too_large(self):
922
+ """Raise a ValueError if the domain size exceeds the allowable threshold.
923
+
924
+ Checks if either the x or y domain size exceeds 20,000 km and raises an error
925
+ with appropriate details if the threshold is surpassed.
926
+ """
927
+ threshold = 20000
928
+ if self.size_x > threshold or self.size_y > threshold:
929
+ raise ValueError(
930
+ f"Domain size exceeds the allowable limit of {threshold} km. "
931
+ f"Received dimensions: size_x = {self.size_x} km, size_y = {self.size_y} km. "
932
+ "Please reduce the domain size to meet the threshold."
933
+ )
934
+
935
+ def _make_initial_lon_lat_ds(self):
936
+ """Generate initial longitude and latitude arrays with Mercator projection
937
+ around the equator.
938
+
939
+ Returns
940
+ -------
941
+ dict
942
+ A dictionary containing the following arrays:
943
+ - lon, lat: 2D arrays of longitudes and latitudes at cell centers.
944
+ - lonu, latu: 2D arrays of longitudes and latitudes at u-points.
945
+ - lonv, latv: 2D arrays of longitudes and latitudes at v-points.
946
+ - lonq, latq: 2D arrays of longitudes and latitudes at cell corners.
947
+ """
948
+
949
+ r_earth = 6371315.0
950
+
951
+ # initially define the domain to be longer in x-direction (dimension "length")
952
+ # than in y-direction (dimension "width") to keep grid distortion minimal
953
+ if self.size_y > self.size_x:
954
+ domain_length, domain_width = self.size_y * 1e3, self.size_x * 1e3 # in m
955
+ nl, nw = self.ny, self.nx
956
+ else:
957
+ domain_length, domain_width = self.size_x * 1e3, self.size_y * 1e3 # in m
958
+ nl, nw = self.nx, self.ny
959
+
960
+ domain_length_in_degrees = domain_length / r_earth
961
+ domain_width_in_degrees = domain_width / r_earth
962
+
963
+ # Generate 1D longitude arrays at cell centers and corners
964
+ lon_array_1d_in_degrees = domain_length_in_degrees * (
965
+ np.arange(-0.5, nl + 1.5) / nl - 0.5
966
+ )
967
+ lonq_array_1d_in_degrees_q = domain_length_in_degrees * (
968
+ np.arange(-1, nl + 2) / nl - 0.5
969
+ )
970
+
971
+ # Mercator projection for latitude
972
+ y1 = np.log(np.tan(np.pi / 4 - domain_width_in_degrees / 4))
973
+ y2 = np.log(np.tan(np.pi / 4 + domain_width_in_degrees / 4))
974
+
975
+ # Generate 1D latitude arrays with inverse Mercator projection
976
+ lat_array_1d_in_degrees = np.arctan(
977
+ np.sinh((y2 - y1) * (np.arange(-0.5, nw + 1.5) / nw) + y1)
978
+ )
979
+ latq_array_1d_in_degrees = np.arctan(
980
+ np.sinh((y2 - y1) * (np.arange(-1, nw + 2) / nw) + y1)
981
+ )
982
+
983
+ # 2D grids for cell centers and corners
984
+ lon, lat = np.meshgrid(lon_array_1d_in_degrees, lat_array_1d_in_degrees)
985
+ lonq, latq = np.meshgrid(lonq_array_1d_in_degrees_q, latq_array_1d_in_degrees)
986
+
987
+ if self.size_y > self.size_x:
988
+ # Rotate grid by 90 degrees because until here the grid has been defined
989
+ # to be longer in x-direction than in y-direction
990
+
991
+ lon, lat = _rot_sphere(lon, lat, 90)
992
+ lonq, latq = _rot_sphere(lonq, latq, 90)
993
+
994
+ lon = np.transpose(np.flip(lon, 0))
995
+ lat = np.transpose(np.flip(lat, 0))
996
+ lonq = np.transpose(np.flip(lonq, 0))
997
+ latq = np.transpose(np.flip(latq, 0))
998
+
999
+ # Inference for u- and v-point coordinates
1000
+ lonu = 0.5 * (lon[:, :-1] + lon[:, 1:])
1001
+ latu = 0.5 * (lat[:, :-1] + lat[:, 1:])
1002
+ lonv = 0.5 * (lon[:-1, :] + lon[1:, :])
1003
+ latv = 0.5 * (lat[:-1, :] + lat[1:, :])
1004
+
1005
+ coords = {
1006
+ "lon": lon,
1007
+ "lat": lat,
1008
+ "lonu": lonu,
1009
+ "latu": latu,
1010
+ "lonv": lonv,
1011
+ "latv": latv,
1012
+ "lonq": lonq,
1013
+ "latq": latq,
1014
+ }
1015
+
1016
+ return coords
1017
+
1018
+ def _create_grid_ds(self, coords):
1019
+ """Create an xarray Dataset with grid coordinates and metrics.
1020
+
1021
+ Parameters
1022
+ ----------
1023
+ coords : dict
1024
+ Dictionary containing:
1025
+ - lon, lat, lonu, latu, lonv, latv : 1d arrays of coordinates (degrees)
1026
+ - angle : 2d array (radians)
1027
+ - pm, pn : 2d arrays (meter^-1)
844
1028
 
845
- # inverse Mercator projections
846
- lat_array_1d_in_degrees = np.arctan(np.sinh(y))
847
- latq_array_1d_in_degrees = np.arctan(np.sinh(yq))
1029
+ Returns
1030
+ -------
1031
+ xarray.Dataset
1032
+ Dataset with variables: lon_rho, lat_rho, lon_u, lat_u, lon_v, lat_v,
1033
+ angle, f (Coriolis parameter), pm, pn.
1034
+ """
1035
+
1036
+ ds = xr.Dataset()
1037
+
1038
+ lon_rho = xr.Variable(
1039
+ data=coords["lon"] * 180 / np.pi,
1040
+ dims=["eta_rho", "xi_rho"],
1041
+ attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
1042
+ )
1043
+ lat_rho = xr.Variable(
1044
+ data=coords["lat"] * 180 / np.pi,
1045
+ dims=["eta_rho", "xi_rho"],
1046
+ attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
1047
+ )
1048
+ lon_u = xr.Variable(
1049
+ data=coords["lonu"] * 180 / np.pi,
1050
+ dims=["eta_rho", "xi_u"],
1051
+ attrs={"long_name": "longitude of u-points", "units": "degrees East"},
1052
+ )
1053
+ lat_u = xr.Variable(
1054
+ data=coords["latu"] * 180 / np.pi,
1055
+ dims=["eta_rho", "xi_u"],
1056
+ attrs={"long_name": "latitude of u-points", "units": "degrees North"},
1057
+ )
1058
+ lon_v = xr.Variable(
1059
+ data=coords["lonv"] * 180 / np.pi,
1060
+ dims=["eta_v", "xi_rho"],
1061
+ attrs={"long_name": "longitude of v-points", "units": "degrees East"},
1062
+ )
1063
+ lat_v = xr.Variable(
1064
+ data=coords["latv"] * 180 / np.pi,
1065
+ dims=["eta_v", "xi_rho"],
1066
+ attrs={"long_name": "latitude of v-points", "units": "degrees North"},
1067
+ )
1068
+ # lon_q = xr.Variable(
1069
+ # data=coords["lonq"] * 180 / np.pi,
1070
+ # dims=["eta_psi", "xi_psi"],
1071
+ # attrs={"long_name": "longitude of psi-points", "units": "degrees East"},
1072
+ # )
1073
+ # lat_q = xr.Variable(
1074
+ # data=coords["latq"] * 180 / np.pi,
1075
+ # dims=["eta_psi", "xi_psi"],
1076
+ # attrs={"long_name": "latitude of psi-points", "units": "degrees North"},
1077
+ # )
848
1078
 
849
- # 2d grid at cell centers
850
- lon, lat = np.meshgrid(lon_array_1d_in_degrees, lat_array_1d_in_degrees)
851
- # 2d grid at cell corners
852
- lonq, latq = np.meshgrid(lonq_array_1d_in_degrees_q, latq_array_1d_in_degrees)
1079
+ ds = ds.assign_coords(
1080
+ {
1081
+ "lat_rho": lat_rho,
1082
+ "lon_rho": lon_rho,
1083
+ "lat_u": lat_u,
1084
+ "lon_u": lon_u,
1085
+ "lat_v": lat_v,
1086
+ "lon_v": lon_v,
1087
+ # "lat_psi": lat_q,
1088
+ # "lon_psi": lon_q,
1089
+ }
1090
+ )
853
1091
 
854
- if size_y > size_x:
855
- # Rotate grid by 90 degrees because until here the grid has been defined
856
- # to be longer in x-direction than in y-direction
1092
+ ds["angle"] = xr.Variable(
1093
+ data=coords["angle"],
1094
+ dims=["eta_rho", "xi_rho"],
1095
+ attrs={"long_name": "Angle between xi axis and east", "units": "radians"},
1096
+ )
857
1097
 
858
- lon, lat = _rot_sphere(lon, lat, 90)
859
- lonq, latq = _rot_sphere(lonq, latq, 90)
1098
+ # Coriolis frequency
1099
+ f0 = 4 * np.pi * np.sin(coords["lat"]) / (24 * 3600)
860
1100
 
861
- lon = np.transpose(np.flip(lon, 0))
862
- lat = np.transpose(np.flip(lat, 0))
863
- lonq = np.transpose(np.flip(lonq, 0))
864
- latq = np.transpose(np.flip(latq, 0))
1101
+ ds["f"] = xr.Variable(
1102
+ data=f0,
1103
+ dims=["eta_rho", "xi_rho"],
1104
+ attrs={
1105
+ "long_name": "Coriolis parameter at rho-points",
1106
+ "units": "second-1",
1107
+ },
1108
+ )
865
1109
 
866
- # infer longitudes and latitudes at u- and v-points
867
- lonu = 0.5 * (lon[:, :-1] + lon[:, 1:])
868
- latu = 0.5 * (lat[:, :-1] + lat[:, 1:])
869
- lonv = 0.5 * (lon[:-1, :] + lon[1:, :])
870
- latv = 0.5 * (lat[:-1, :] + lat[1:, :])
1110
+ ds["pm"] = xr.Variable(
1111
+ data=coords["pm"],
1112
+ dims=["eta_rho", "xi_rho"],
1113
+ attrs={
1114
+ "long_name": "Curvilinear coordinate metric in xi-direction",
1115
+ "units": "meter-1",
1116
+ },
1117
+ )
1118
+ ds["pn"] = xr.Variable(
1119
+ data=coords["pn"],
1120
+ dims=["eta_rho", "xi_rho"],
1121
+ attrs={
1122
+ "long_name": "Curvilinear coordinate metric in eta-direction",
1123
+ "units": "meter-1",
1124
+ },
1125
+ )
871
1126
 
872
- # TODO wrap up into temporary container Dataset object?
873
- return lon, lat, lonu, latu, lonv, latv, lonq, latq
1127
+ return ds
874
1128
 
875
1129
 
876
- def _rotate(lon, lat, lonu, latu, lonv, latv, lonq, latq, rot):
1130
+ def _rotate(coords, rot):
877
1131
  """Rotate grid counterclockwise relative to surface of Earth by rot degrees."""
878
1132
 
879
- (lon, lat) = _rot_sphere(lon, lat, rot)
880
- (lonu, latu) = _rot_sphere(lonu, latu, rot)
881
- (lonv, latv) = _rot_sphere(lonv, latv, rot)
882
- (lonq, latq) = _rot_sphere(lonq, latq, rot)
1133
+ (coords["lon"], coords["lat"]) = _rot_sphere(coords["lon"], coords["lat"], rot)
1134
+ (coords["lonu"], coords["latu"]) = _rot_sphere(coords["lonu"], coords["latu"], rot)
1135
+ (coords["lonv"], coords["latv"]) = _rot_sphere(coords["lonv"], coords["latv"], rot)
1136
+ (coords["lonq"], coords["latq"]) = _rot_sphere(coords["lonq"], coords["latq"], rot)
883
1137
 
884
- return lon, lat, lonu, latu, lonv, latv, lonq, latq
1138
+ return coords
885
1139
 
886
1140
 
887
- def _translate(lon, lat, lonu, latu, lonv, latv, lonq, latq, tra_lat, tra_lon):
1141
+ def _translate(coords, tra_lat, tra_lon):
888
1142
  """Translate grid so that the centre lies at the position (tra_lat, tra_lon)"""
889
1143
 
890
- (lon, lat) = _tra_sphere(lon, lat, tra_lat)
891
- (lonu, latu) = _tra_sphere(lonu, latu, tra_lat)
892
- (lonv, latv) = _tra_sphere(lonv, latv, tra_lat)
893
- (lonq, latq) = _tra_sphere(lonq, latq, tra_lat)
1144
+ (lon, lat) = _tra_sphere(coords["lon"], coords["lat"], tra_lat)
1145
+ (lonu, latu) = _tra_sphere(coords["lonu"], coords["latu"], tra_lat)
1146
+ (lonv, latv) = _tra_sphere(coords["lonv"], coords["latv"], tra_lat)
1147
+ (lonq, latq) = _tra_sphere(coords["lonq"], coords["latq"], tra_lat)
894
1148
 
895
1149
  lon = lon + tra_lon * np.pi / 180
896
1150
  lonu = lonu + tra_lon * np.pi / 180
@@ -902,133 +1156,171 @@ def _translate(lon, lat, lonu, latu, lonv, latv, lonq, latq, tra_lat, tra_lon):
902
1156
  lonv[lonv < -np.pi] = lonv[lonv < -np.pi] + 2 * np.pi
903
1157
  lonq[lonq < -np.pi] = lonq[lonq < -np.pi] + 2 * np.pi
904
1158
 
905
- return lon, lat, lonu, latu, lonv, latv, lonq, latq
1159
+ coords = {
1160
+ "lon": lon,
1161
+ "lat": lat,
1162
+ "lonu": lonu,
1163
+ "latu": latu,
1164
+ "lonv": lonv,
1165
+ "latv": latv,
1166
+ "lonq": lonq,
1167
+ "latq": latq,
1168
+ }
1169
+
1170
+ return coords
906
1171
 
907
1172
 
908
1173
  def _rot_sphere(lon, lat, rot):
909
- (n, m) = np.shape(lon)
910
- # convert rotation angle from degrees to radians
1174
+ """Rotate longitude and latitude coordinates on a sphere.
1175
+
1176
+ Parameters
1177
+ ----------
1178
+ lon : ndarray
1179
+ 2D array of longitudes in radians.
1180
+ lat : ndarray
1181
+ 2D array of latitudes in radians.
1182
+ rot : float
1183
+ Rotation angle in degrees.
1184
+
1185
+ Returns
1186
+ -------
1187
+ tuple
1188
+ Rotated longitude and latitude arrays (lon, lat) in radians.
1189
+ """
1190
+ # Convert rotation angle from degrees to radians
911
1191
  rot = rot * np.pi / 180
912
1192
 
913
- # translate into Cartesian coordinates x,y,z
914
- # conventions: (lon,lat) = (0,0) corresponds to (x,y,z) = ( 0,-r, 0)
915
- # (lon,lat) = (0,90) corresponds to (x,y,z) = ( 0, 0, r)
1193
+ # Convert spherical coordinates to Cartesian coordinates (x, y, z)
916
1194
  x1 = np.sin(lon) * np.cos(lat)
917
1195
  y1 = np.cos(lon) * np.cos(lat)
918
1196
  z1 = np.sin(lat)
919
1197
 
920
- # We will rotate these points around the small circle defined by
921
- # the intersection of the sphere and the plane that
922
- # is orthogonal to the line through (lon,lat) (0,0) and (180,0)
923
-
924
- # The rotation is in that plane around its intersection with
925
- # aforementioned line.
926
-
927
- # Since the plane is orthogonal to the y-axis (in my definition at least),
928
- # Rotations in the plane of the small circle maintain constant y and are around
929
- # (x,y,z) = (0,y1,0)
930
-
1198
+ # Calculate the radial distance in the x-z plane
931
1199
  rp1 = np.sqrt(x1**2 + z1**2)
932
1200
 
933
- ap1 = np.pi / 2 * np.ones((n, m))
934
- ap1[np.abs(x1) > 1e-7] = np.arctan(
935
- np.abs(z1[np.abs(x1) > 1e-7] / x1[np.abs(x1) > 1e-7])
936
- )
1201
+ # Compute azimuthal angle
1202
+ ap1 = np.arctan2(np.abs(z1), np.abs(x1))
937
1203
  ap1[x1 < 0] = np.pi - ap1[x1 < 0]
938
1204
  ap1[z1 < 0] = -ap1[z1 < 0]
939
1205
 
1206
+ # Apply rotation to the azimuthal angle
940
1207
  ap2 = ap1 + rot
941
1208
  x2 = rp1 * np.cos(ap2)
942
1209
  y2 = y1
943
1210
  z2 = rp1 * np.sin(ap2)
944
1211
 
945
- lon = np.pi / 2 * np.ones((n, m))
946
- lon[abs(y2) > 1e-7] = np.arctan(
947
- np.abs(x2[np.abs(y2) > 1e-7] / y2[np.abs(y2) > 1e-7])
948
- )
949
- lon[y2 < 0] = np.pi - lon[y2 < 0]
950
- lon[x2 < 0] = -lon[x2 < 0]
1212
+ # Recompute longitude and latitude
1213
+ lon_rot = np.arctan2(np.abs(x2), np.abs(y2))
1214
+ lon_rot[y2 < 0] = np.pi - lon_rot[y2 < 0]
1215
+ lon_rot[x2 < 0] = -lon_rot[x2 < 0]
951
1216
 
952
1217
  pr2 = np.sqrt(x2**2 + y2**2)
953
- lat = np.pi / 2 * np.ones((n, m))
954
- lat[np.abs(pr2) > 1e-7] = np.arctan(
955
- np.abs(z2[np.abs(pr2) > 1e-7] / pr2[np.abs(pr2) > 1e-7])
956
- )
957
- lat[z2 < 0] = -lat[z2 < 0]
1218
+ lat_rot = np.arctan2(np.abs(z2), pr2)
1219
+ lat_rot[z2 < 0] = -lat_rot[z2 < 0]
958
1220
 
959
- return (lon, lat)
1221
+ return lon_rot, lat_rot
960
1222
 
961
1223
 
962
1224
  def _tra_sphere(lon, lat, tra):
963
- (n, m) = np.shape(lon)
964
- tra = tra * np.pi / 180 # translation in latitude direction
1225
+ """Translate longitude and latitude coordinates on a sphere in the latitude
1226
+ direction.
965
1227
 
966
- # translate into x,y,z
967
- # conventions: (lon,lat) = (0,0) corresponds to (x,y,z) = ( 0,-r, 0)
968
- # (lon,lat) = (0,90) corresponds to (x,y,z) = ( 0, 0, r)
969
- x1 = np.sin(lon) * np.cos(lat)
970
- y1 = np.cos(lon) * np.cos(lat)
971
- z1 = np.sin(lat)
1228
+ Parameters
1229
+ ----------
1230
+ lon : ndarray
1231
+ 2D array of longitudes in radians.
1232
+ lat : ndarray
1233
+ 2D array of latitudes in radians.
1234
+ tra : float
1235
+ Translation angle in degrees.
972
1236
 
973
- # We will rotate these points around the small circle defined by
974
- # the intersection of the sphere and the plane that
975
- # is orthogonal to the line through (lon,lat) (90,0) and (-90,0)
1237
+ Returns
1238
+ -------
1239
+ tuple
1240
+ Translated longitude and latitude arrays (lon, lat) in radians.
1241
+ """
976
1242
 
977
- # The rotation is in that plane around its intersection with
978
- # aforementioned line.
1243
+ # Convert translation angle from degrees to radians
1244
+ tra = tra * np.pi / 180
979
1245
 
980
- # Since the plane is orthogonal to the x-axis (in my definition at least),
981
- # Rotations in the plane of the small circle maintain constant x and are around
982
- # (x,y,z) = (x1,0,0)
1246
+ # Convert spherical coordinates to Cartesian coordinates (x, y, z)
1247
+ x1 = np.sin(lon) * np.cos(lat)
1248
+ y1 = np.cos(lon) * np.cos(lat)
1249
+ z1 = np.sin(lat)
983
1250
 
1251
+ # Radial distance in the y-z plane
984
1252
  rp1 = np.sqrt(y1**2 + z1**2)
985
1253
 
986
- ap1 = np.pi / 2 * np.ones((n, m))
987
- ap1[np.abs(y1) > 1e-7] = np.arctan(
988
- np.abs(z1[np.abs(y1) > 1e-7] / y1[np.abs(y1) > 1e-7])
989
- )
1254
+ # Compute azimuthal angle in the y-z plane
1255
+ ap1 = np.arctan2(np.abs(z1), np.abs(y1))
990
1256
  ap1[y1 < 0] = np.pi - ap1[y1 < 0]
991
1257
  ap1[z1 < 0] = -ap1[z1 < 0]
992
1258
 
1259
+ # Apply translation in the azimuthal angle
993
1260
  ap2 = ap1 + tra
994
- x2 = x1
995
1261
  y2 = rp1 * np.cos(ap2)
996
1262
  z2 = rp1 * np.sin(ap2)
997
1263
 
998
- ## transformation from (x,y,z) to (lat,lon)
999
- lon = np.pi / 2 * np.ones((n, m))
1000
- lon[np.abs(y2) > 1e-7] = np.arctan(
1001
- np.abs(x2[np.abs(y2) > 1e-7] / y2[np.abs(y2) > 1e-7])
1002
- )
1003
- lon[y2 < 0] = np.pi - lon[y2 < 0]
1004
- lon[x2 < 0] = -lon[x2 < 0]
1264
+ # Convert back to spherical coordinates
1265
+ lon_rot = np.arctan2(np.abs(x1), np.abs(y2))
1266
+ lon_rot[y2 < 0] = np.pi - lon_rot[y2 < 0]
1267
+ lon_rot[x1 < 0] = -lon_rot[x1 < 0]
1005
1268
 
1006
- pr2 = np.sqrt(x2**2 + y2**2)
1007
- lat = np.pi / (2 * np.ones((n, m)))
1008
- lat[np.abs(pr2) > 1e-7] = np.arctan(
1009
- np.abs(z2[np.abs(pr2) > 1e-7] / pr2[np.abs(pr2) > 1e-7])
1010
- )
1011
- lat[z2 < 0] = -lat[z2 < 0]
1269
+ pr2 = np.sqrt(x1**2 + y2**2)
1270
+ lat_rot = np.arctan2(np.abs(z2), pr2)
1271
+ lat_rot[z2 < 0] = -lat_rot[z2 < 0]
1272
+
1273
+ return lon_rot, lat_rot
1012
1274
 
1013
- return (lon, lat)
1014
1275
 
1276
+ def _compute_coordinate_metrics(coords):
1277
+ """Compute the reciprocal of grid spacing (`pn` and `pm`) in the latitude and
1278
+ longitude directions.
1279
+
1280
+ Parameters
1281
+ ----------
1282
+ coords : dict
1283
+ A dictionary containing coordinate arrays 'lonu', 'latu', 'lonv', and 'latv' for the u- and v-velocity points.
1015
1284
 
1016
- def _compute_coordinate_metrics(lon, lonu, latu, lonv, latv):
1017
- """Compute the curvilinear coordinate metrics pn and pm, defined as 1/grid
1018
- spacing."""
1285
+ Returns
1286
+ -------
1287
+ pn : ndarray
1288
+ The metric for the latitude direction (1/dy).
1289
+
1290
+ pm : ndarray
1291
+ The metric for the longitude direction (1/dx).
1292
+
1293
+ Notes
1294
+ -----
1295
+ Boundary values of `pn` and `pm` are copied from adjacent interior values.
1296
+ """
1019
1297
 
1020
1298
  # pm = 1/dx
1021
- pmu = gc_dist(lonu[:, :-1], latu[:, :-1], lonu[:, 1:], latu[:, 1:])
1022
- pm = 0 * lon
1299
+ pmu = gc_dist(
1300
+ coords["lonu"][:, :-1],
1301
+ coords["latu"][:, :-1],
1302
+ coords["lonu"][:, 1:],
1303
+ coords["latu"][:, 1:],
1304
+ input_in_degrees=False,
1305
+ )
1306
+ pm = np.zeros_like(coords["lon"])
1023
1307
  pm[:, 1:-1] = pmu
1308
+ # Handle boundary conditions
1024
1309
  pm[:, 0] = pm[:, 1]
1025
1310
  pm[:, -1] = pm[:, -2]
1026
1311
  pm = 1 / pm
1027
1312
 
1028
1313
  # pn = 1/dy
1029
- pnv = gc_dist(lonv[:-1, :], latv[:-1, :], lonv[1:, :], latv[1:, :])
1030
- pn = 0 * lon
1314
+ pnv = gc_dist(
1315
+ coords["lonv"][:-1, :],
1316
+ coords["latv"][:-1, :],
1317
+ coords["lonv"][1:, :],
1318
+ coords["latv"][1:, :],
1319
+ input_in_degrees=False,
1320
+ )
1321
+ pn = np.zeros_like(coords["lon"])
1031
1322
  pn[1:-1, :] = pnv
1323
+ # Handle boundary conditions
1032
1324
  pn[0, :] = pn[1, :]
1033
1325
  pn[-1, :] = pn[-2, :]
1034
1326
  pn = 1 / pn
@@ -1036,179 +1328,50 @@ def _compute_coordinate_metrics(lon, lonu, latu, lonv, latv):
1036
1328
  return pn, pm
1037
1329
 
1038
1330
 
1039
- def gc_dist(lon1, lat1, lon2, lat2):
1040
- # Distance between 2 points along a great circle
1041
- # lat and lon in radians!!
1042
- # 2008, Jeroen Molemaker, UCLA
1043
-
1044
- dlat = lat2 - lat1
1045
- dlon = lon2 - lon1
1331
+ def _compute_angle(coords):
1332
+ """Compute angles of the local grid's positive x-axis relative to east.
1046
1333
 
1047
- dang = 2 * np.arcsin(
1048
- np.sqrt(
1049
- np.sin(dlat / 2) ** 2 + np.cos(lat2) * np.cos(lat1) * np.sin(dlon / 2) ** 2
1050
- )
1051
- ) # haversine function
1334
+ The angle is computed for each grid cell using the latitude and longitude
1335
+ differences between neighboring grid points. The result is wrapped to
1336
+ the range [-π, π] and adjusted based on longitude and latitude conditions.
1052
1337
 
1053
- dis = RADIUS_OF_EARTH * dang
1338
+ Parameters
1339
+ ----------
1340
+ coords : dict
1341
+ A dictionary containing 'latu' (latitudes) and 'lonu' (longitudes) arrays.
1054
1342
 
1055
- return dis
1343
+ Returns
1344
+ -------
1345
+ ang : ndarray
1346
+ An array of angles (in radians) of the local grid's positive x-axis
1347
+ relative to east for each grid point.
1348
+ """
1056
1349
 
1350
+ # Compute differences in latitudes and longitudes
1351
+ dellat = coords["latu"][:, 1:] - coords["latu"][:, :-1]
1352
+ dellon = coords["lonu"][:, 1:] - coords["lonu"][:, :-1]
1057
1353
 
1058
- def _compute_angle(lon, lonu, latu, lonq):
1059
- """Compute angles of local grid positive x-axis relative to east."""
1354
+ # Normalize longitude differences to the range [-π, π]
1355
+ dellon = (dellon + np.pi) % (2 * np.pi) - np.pi
1356
+ dellon *= np.cos(0.5 * (coords["latu"][:, 1:] + coords["latu"][:, :-1]))
1060
1357
 
1061
- dellat = latu[:, 1:] - latu[:, :-1]
1062
- dellon = lonu[:, 1:] - lonu[:, :-1]
1063
- dellon[dellon > np.pi] = dellon[dellon > np.pi] - 2 * np.pi
1064
- dellon[dellon < -np.pi] = dellon[dellon < -np.pi] + 2 * np.pi
1065
- dellon = dellon * np.cos(0.5 * (latu[:, 1:] + latu[:, :-1]))
1358
+ # Compute the angle in radians
1359
+ ang_s = np.arctan2(dellat, dellon)
1066
1360
 
1067
- ang = copy.copy(lon)
1068
- ang_s = np.arctan(dellat / (dellon + 1e-16))
1069
- ang_s[(dellon < 0) & (dellat < 0)] = ang_s[(dellon < 0) & (dellat < 0)] - np.pi
1070
- ang_s[(dellon < 0) & (dellat >= 0)] = ang_s[(dellon < 0) & (dellat >= 0)] + np.pi
1071
- ang_s[ang_s > np.pi] = ang_s[ang_s > np.pi] - np.pi
1072
- ang_s[ang_s < -np.pi] = ang_s[ang_s < -np.pi] + np.pi
1361
+ # Adjust angles based on longitude and latitude conditions
1362
+ ang_s[(dellon < 0) & (dellat < 0)] -= np.pi
1363
+ ang_s[(dellon < 0) & (dellat >= 0)] += np.pi
1364
+ ang_s = np.mod(ang_s + np.pi, 2 * np.pi) - np.pi # Ensure angles are in [-π, π]
1073
1365
 
1366
+ # Create output array and set angles
1367
+ ang = np.zeros_like(coords["lon"])
1074
1368
  ang[:, 1:-1] = ang_s
1075
- ang[:, 0] = ang[:, 1]
1076
- ang[:, -1] = ang[:, -2]
1369
+ ang[:, 0] = ang[:, 1] # Set first column to the second column
1370
+ ang[:, -1] = ang[:, -2] # Set last column to the second-to-last column
1077
1371
 
1078
1372
  return ang
1079
1373
 
1080
1374
 
1081
- def _create_grid_ds(
1082
- lon,
1083
- lat,
1084
- lonu,
1085
- latu,
1086
- lonv,
1087
- latv,
1088
- lonq,
1089
- latq,
1090
- pm,
1091
- pn,
1092
- angle,
1093
- rot,
1094
- center_lon,
1095
- center_lat,
1096
- ):
1097
- ds = xr.Dataset()
1098
-
1099
- lon_rho = xr.Variable(
1100
- data=lon * 180 / np.pi,
1101
- dims=["eta_rho", "xi_rho"],
1102
- attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
1103
- )
1104
- lat_rho = xr.Variable(
1105
- data=lat * 180 / np.pi,
1106
- dims=["eta_rho", "xi_rho"],
1107
- attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
1108
- )
1109
- lon_u = xr.Variable(
1110
- data=lonu * 180 / np.pi,
1111
- dims=["eta_rho", "xi_u"],
1112
- attrs={"long_name": "longitude of u-points", "units": "degrees East"},
1113
- )
1114
- lat_u = xr.Variable(
1115
- data=latu * 180 / np.pi,
1116
- dims=["eta_rho", "xi_u"],
1117
- attrs={"long_name": "latitude of u-points", "units": "degrees North"},
1118
- )
1119
- lon_v = xr.Variable(
1120
- data=lonv * 180 / np.pi,
1121
- dims=["eta_v", "xi_rho"],
1122
- attrs={"long_name": "longitude of v-points", "units": "degrees East"},
1123
- )
1124
- lat_v = xr.Variable(
1125
- data=latv * 180 / np.pi,
1126
- dims=["eta_v", "xi_rho"],
1127
- attrs={"long_name": "latitude of v-points", "units": "degrees North"},
1128
- )
1129
- # lon_q = xr.Variable(
1130
- # data=lonq * 180 / np.pi,
1131
- # dims=["eta_psi", "xi_psi"],
1132
- # attrs={"long_name": "longitude of psi-points", "units": "degrees East"},
1133
- # )
1134
- # lat_q = xr.Variable(
1135
- # data=latq * 180 / np.pi,
1136
- # dims=["eta_psi", "xi_psi"],
1137
- # attrs={"long_name": "latitude of psi-points", "units": "degrees North"},
1138
- # )
1139
-
1140
- ds = ds.assign_coords(
1141
- {
1142
- "lat_rho": lat_rho,
1143
- "lon_rho": lon_rho,
1144
- "lat_u": lat_u,
1145
- "lon_u": lon_u,
1146
- "lat_v": lat_v,
1147
- "lon_v": lon_v,
1148
- # "lat_psi": lat_q,
1149
- # "lon_psi": lon_q,
1150
- }
1151
- )
1152
-
1153
- ds["angle"] = xr.Variable(
1154
- data=angle,
1155
- dims=["eta_rho", "xi_rho"],
1156
- attrs={"long_name": "Angle between xi axis and east", "units": "radians"},
1157
- )
1158
-
1159
- # Coriolis frequency
1160
- f0 = 4 * np.pi * np.sin(lat) / (24 * 3600)
1161
-
1162
- ds["f"] = xr.Variable(
1163
- data=f0,
1164
- dims=["eta_rho", "xi_rho"],
1165
- attrs={"long_name": "Coriolis parameter at rho-points", "units": "second-1"},
1166
- )
1167
-
1168
- ds["pm"] = xr.Variable(
1169
- data=pm,
1170
- dims=["eta_rho", "xi_rho"],
1171
- attrs={
1172
- "long_name": "Curvilinear coordinate metric in xi-direction",
1173
- "units": "meter-1",
1174
- },
1175
- )
1176
- ds["pn"] = xr.Variable(
1177
- data=pn,
1178
- dims=["eta_rho", "xi_rho"],
1179
- attrs={
1180
- "long_name": "Curvilinear coordinate metric in eta-direction",
1181
- "units": "meter-1",
1182
- },
1183
- )
1184
-
1185
- return ds
1186
-
1187
-
1188
- def _add_global_metadata(ds, size_x, size_y, center_lon, center_lat, rot):
1189
-
1190
- ds["spherical"] = xr.DataArray(np.array("T", dtype="S1"))
1191
- ds["spherical"].attrs["Long_name"] = "Grid type logical switch"
1192
- ds["spherical"].attrs["option_T"] = "spherical"
1193
-
1194
- ds.attrs["title"] = "ROMS grid created by ROMS-Tools"
1195
-
1196
- # Include the version of roms-tools
1197
- try:
1198
- roms_tools_version = importlib.metadata.version("roms-tools")
1199
- except importlib.metadata.PackageNotFoundError:
1200
- roms_tools_version = "unknown"
1201
-
1202
- ds.attrs["roms_tools_version"] = roms_tools_version
1203
- ds.attrs["size_x"] = size_x
1204
- ds.attrs["size_y"] = size_y
1205
- ds.attrs["center_lon"] = center_lon
1206
- ds.attrs["center_lat"] = center_lat
1207
- ds.attrs["rot"] = rot
1208
-
1209
- return ds
1210
-
1211
-
1212
1375
  def _f2c(f):
1213
1376
  """Coarsen input xarray DataArray f in both x- and y-direction.
1214
1377