roms-tools 2.0.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. roms_tools/__init__.py +2 -1
  2. roms_tools/setup/boundary_forcing.py +22 -32
  3. roms_tools/setup/datasets.py +19 -21
  4. roms_tools/setup/grid.py +253 -139
  5. roms_tools/setup/initial_conditions.py +29 -6
  6. roms_tools/setup/mask.py +50 -4
  7. roms_tools/setup/nesting.py +575 -0
  8. roms_tools/setup/plot.py +214 -55
  9. roms_tools/setup/river_forcing.py +125 -29
  10. roms_tools/setup/surface_forcing.py +33 -12
  11. roms_tools/setup/tides.py +31 -6
  12. roms_tools/setup/topography.py +168 -35
  13. roms_tools/setup/utils.py +137 -21
  14. roms_tools/tests/test_setup/test_boundary_forcing.py +7 -5
  15. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -3
  16. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -2
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/.zarray +1 -1
  18. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/0 +0 -0
  19. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zmetadata +5 -6
  20. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zarray +2 -2
  21. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zattrs +1 -2
  22. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/0.0.0 +0 -0
  23. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zarray +2 -2
  24. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_name/0 +0 -0
  25. roms_tools/tests/test_setup/test_datasets.py +2 -2
  26. roms_tools/tests/test_setup/test_initial_conditions.py +6 -6
  27. roms_tools/tests/test_setup/test_nesting.py +489 -0
  28. roms_tools/tests/test_setup/test_river_forcing.py +50 -13
  29. roms_tools/tests/test_setup/test_surface_forcing.py +9 -8
  30. roms_tools/tests/test_setup/test_tides.py +5 -5
  31. roms_tools/tests/test_setup/test_validation.py +2 -2
  32. {roms_tools-2.0.0.dist-info → roms_tools-2.2.0.dist-info}/METADATA +9 -5
  33. {roms_tools-2.0.0.dist-info → roms_tools-2.2.0.dist-info}/RECORD +54 -53
  34. {roms_tools-2.0.0.dist-info → roms_tools-2.2.0.dist-info}/WHEEL +1 -1
  35. roms_tools/_version.py +0 -2
  36. roms_tools/tests/test_setup/test_data/river_forcing.zarr/river_tracer/0.0.0 +0 -0
  37. roms_tools/tests/test_setup/test_data/river_forcing.zarr/tracer_name/0 +0 -0
  38. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zattrs +0 -0
  39. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zgroup +0 -0
  40. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zarray +0 -0
  41. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zattrs +0 -0
  42. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/0 +0 -0
  43. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zarray +0 -0
  44. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zattrs +0 -0
  45. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/0 +0 -0
  46. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zarray +0 -0
  47. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zattrs +0 -0
  48. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/0 +0 -0
  49. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zarray +0 -0
  50. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zattrs +0 -0
  51. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/0 +0 -0
  52. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zarray +0 -0
  53. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zattrs +0 -0
  54. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/0.0 +0 -0
  55. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zattrs +0 -0
  56. {roms_tools-2.0.0.dist-info → roms_tools-2.2.0.dist-info}/LICENSE +0 -0
  57. {roms_tools-2.0.0.dist-info → roms_tools-2.2.0.dist-info}/top_level.txt +0 -0
@@ -66,6 +66,10 @@ class SurfaceForcing:
66
66
  Reference date for the model. Default is January 1, 2000.
67
67
  use_dask: bool, optional
68
68
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
69
+ bypass_validation: bool, optional
70
+ Indicates whether to skip validation checks in the processed data. When set to True,
71
+ the validation process that ensures no NaN values exist at wet points
72
+ in the processed dataset is bypassed. Defaults to False.
69
73
 
70
74
  Examples
71
75
  --------
@@ -88,6 +92,7 @@ class SurfaceForcing:
88
92
  use_coarse_grid: bool = False
89
93
  model_reference_date: datetime = datetime(2000, 1, 1)
90
94
  use_dask: bool = False
95
+ bypass_validation: bool = False
91
96
 
92
97
  ds: xr.Dataset = field(init=False, repr=False)
93
98
 
@@ -117,7 +122,7 @@ class SurfaceForcing:
117
122
  data.ds[data.var_names[var_name]]
118
123
  )
119
124
 
120
- # rotation of velocities and interpolation to u/v points
125
+ # rotation of velocities
121
126
  if "uwnd" in self.variable_info and "vwnd" in self.variable_info:
122
127
  processed_fields["uwnd"], processed_fields["vwnd"] = rotate_velocities(
123
128
  processed_fields["uwnd"],
@@ -134,7 +139,8 @@ class SurfaceForcing:
134
139
 
135
140
  ds = self._write_into_dataset(processed_fields, data, d_meta)
136
141
 
137
- self._validate(ds)
142
+ if not self.bypass_validation:
143
+ self._validate(ds)
138
144
 
139
145
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
140
146
  for var_name in ds.data_vars:
@@ -262,10 +268,8 @@ class SurfaceForcing:
262
268
  correction_data = self._get_correction_data()
263
269
  # choose same subdomain as forcing data so that we can use same mask
264
270
  coords_correction = {
265
- correction_data.dim_names["latitude"]: data.ds[data.dim_names["latitude"]],
266
- correction_data.dim_names["longitude"]: data.ds[
267
- data.dim_names["longitude"]
268
- ],
271
+ "lat": data.ds[data.dim_names["latitude"]],
272
+ "lon": data.ds[data.dim_names["longitude"]],
269
273
  }
270
274
  correction_data.choose_subdomain(
271
275
  coords_correction, straddle=self.target_coords["straddle"]
@@ -287,6 +291,8 @@ class SurfaceForcing:
287
291
 
288
292
  processed_fields["swrad"] = processed_fields["swrad"] * corr_factor
289
293
 
294
+ del corr_factor
295
+
290
296
  return processed_fields
291
297
 
292
298
  def _write_into_dataset(self, processed_fields, data, d_meta):
@@ -294,8 +300,9 @@ class SurfaceForcing:
294
300
  # save in new dataset
295
301
  ds = xr.Dataset()
296
302
 
297
- for var_name in processed_fields.keys():
303
+ for var_name in list(processed_fields.keys()):
298
304
  ds[var_name] = processed_fields[var_name].astype(np.float32)
305
+ del processed_fields[var_name]
299
306
  ds[var_name].attrs["long_name"] = d_meta[var_name]["long_name"]
300
307
  ds[var_name].attrs["units"] = d_meta[var_name]["units"]
301
308
 
@@ -349,9 +356,14 @@ class SurfaceForcing:
349
356
  """
350
357
 
351
358
  for var_name in ds.data_vars:
352
- # Only validate variables based on "validate" flag if use_dask is False
353
- if not self.use_dask or self.variable_info[var_name]["validate"]:
354
- nan_check(ds[var_name].isel(time=0), self.target_coords["mask"])
359
+ if self.variable_info[var_name]["validate"]:
360
+ if self.variable_info[var_name]["location"] == "rho":
361
+ mask = self.target_coords["mask"]
362
+ elif self.variable_info[var_name]["location"] == "u":
363
+ mask = self.target_coords["mask_u"]
364
+ elif self.variable_info[var_name]["location"] == "v":
365
+ mask = self.target_coords["mask_v"]
366
+ nan_check(ds[var_name].isel(time=0), mask)
355
367
 
356
368
  def _add_global_metadata(self, ds=None):
357
369
 
@@ -541,7 +553,10 @@ class SurfaceForcing:
541
553
 
542
554
  @classmethod
543
555
  def from_yaml(
544
- cls, filepath: Union[str, Path], use_dask: bool = False
556
+ cls,
557
+ filepath: Union[str, Path],
558
+ use_dask: bool = False,
559
+ bypass_validation: bool = False,
545
560
  ) -> "SurfaceForcing":
546
561
  """Create an instance of the SurfaceForcing class from a YAML file.
547
562
 
@@ -551,6 +566,10 @@ class SurfaceForcing:
551
566
  The path to the YAML file from which the parameters will be read.
552
567
  use_dask: bool, optional
553
568
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
569
+ bypass_validation: bool, optional
570
+ Indicates whether to skip validation checks in the processed data. When set to True,
571
+ the validation process that ensures no NaN values exist at wet points
572
+ in the processed dataset is bypassed. Defaults to False.
554
573
 
555
574
  Returns
556
575
  -------
@@ -562,4 +581,6 @@ class SurfaceForcing:
562
581
  grid = Grid.from_yaml(filepath)
563
582
  params = _from_yaml(cls, filepath)
564
583
 
565
- return cls(grid=grid, **params, use_dask=use_dask)
584
+ return cls(
585
+ grid=grid, **params, use_dask=use_dask, bypass_validation=bypass_validation
586
+ )
roms_tools/setup/tides.py CHANGED
@@ -51,6 +51,10 @@ class TidalForcing:
51
51
  The reference date for the ROMS simulation. Default is datetime(2000, 1, 1).
52
52
  use_dask: bool, optional
53
53
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
54
+ bypass_validation: bool, optional
55
+ Indicates whether to skip validation checks in the processed data. When set to True,
56
+ the validation process that ensures no NaN values exist at wet points
57
+ in the processed dataset is bypassed. Defaults to False.
54
58
 
55
59
  Examples
56
60
  --------
@@ -65,6 +69,7 @@ class TidalForcing:
65
69
  allan_factor: float = 2.0
66
70
  model_reference_date: datetime = datetime(2000, 1, 1)
67
71
  use_dask: bool = False
72
+ bypass_validation: bool = False
68
73
 
69
74
  ds: xr.Dataset = field(init=False, repr=False)
70
75
 
@@ -129,7 +134,8 @@ class TidalForcing:
129
134
 
130
135
  ds = self._add_global_metadata(ds)
131
136
 
132
- self._validate(ds)
137
+ if not self.bypass_validation:
138
+ self._validate(ds)
133
139
 
134
140
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
135
141
  for var_name in ds.data_vars:
@@ -264,9 +270,16 @@ class TidalForcing:
264
270
  The method utilizes `self.grid.ds.mask_rho` to determine the wet points in the domain.
265
271
  """
266
272
  for var_name in ds.data_vars:
267
- # only validate variables based on "validate" flag if use_dask is false
268
- if not self.use_dask or self.variable_info[var_name]["validate"]:
269
- nan_check(ds[var_name].isel(ntides=0), self.grid.ds.mask_rho)
273
+ if self.variable_info[var_name]["validate"]:
274
+ if self.variable_info[var_name]["location"] == "rho":
275
+ mask = self.grid.ds.mask_rho
276
+ elif self.variable_info[var_name]["location"] == "u":
277
+ mask = self.grid.ds.mask_u
278
+ elif self.variable_info[var_name]["location"] == "v":
279
+ mask = self.grid.ds.mask_v
280
+
281
+ da = ds[var_name].isel(ntides=0)
282
+ nan_check(da, mask)
270
283
 
271
284
  def plot(self, var_name, ntides=0) -> None:
272
285
  """Plot the specified tidal forcing variable for a given tidal constituent.
@@ -420,7 +433,10 @@ class TidalForcing:
420
433
 
421
434
  @classmethod
422
435
  def from_yaml(
423
- cls, filepath: Union[str, Path], use_dask: bool = False
436
+ cls,
437
+ filepath: Union[str, Path],
438
+ use_dask: bool = False,
439
+ bypass_validation: bool = False,
424
440
  ) -> "TidalForcing":
425
441
  """Create an instance of the TidalForcing class from a YAML file.
426
442
 
@@ -430,6 +446,10 @@ class TidalForcing:
430
446
  The path to the YAML file from which the parameters will be read.
431
447
  use_dask: bool, optional
432
448
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
449
+ bypass_validation: bool, optional
450
+ Indicates whether to skip validation checks in the processed data. When set to True,
451
+ the validation process that ensures no NaN values exist at wet points
452
+ in the processed dataset is bypassed. Defaults to False.
433
453
 
434
454
  Returns
435
455
  -------
@@ -440,7 +460,12 @@ class TidalForcing:
440
460
 
441
461
  grid = Grid.from_yaml(filepath)
442
462
  tidal_forcing_params = _from_yaml(cls, filepath)
443
- return cls(grid=grid, **tidal_forcing_params, use_dask=use_dask)
463
+ return cls(
464
+ grid=grid,
465
+ **tidal_forcing_params,
466
+ use_dask=use_dask,
467
+ bypass_validation=bypass_validation
468
+ )
444
469
 
445
470
  def _correct_tides(self, data):
446
471
  """Apply tidal corrections to the dataset. This method corrects the dataset for
@@ -94,7 +94,18 @@ def _add_topography(
94
94
 
95
95
 
96
96
  def _get_topography_data(source):
97
+ """Load topography data based on the specified source.
97
98
 
99
+ Parameters
100
+ ----------
101
+ source : dict
102
+ A dictionary containing the source details (e.g., "name" and "path").
103
+
104
+ Returns
105
+ -------
106
+ data : object
107
+ The loaded topography dataset (ETOPO5 or SRTM15).
108
+ """
98
109
  kwargs = {"use_dask": False}
99
110
 
100
111
  if source["name"] == "ETOPO5":
@@ -115,7 +126,24 @@ def _get_topography_data(source):
115
126
  def _make_raw_topography(
116
127
  data, target_coords, method="linear", verbose=False
117
128
  ) -> xr.DataArray:
129
+ """Regrid topography data to match target coordinates.
130
+
131
+ Parameters
132
+ ----------
133
+ data : object
134
+ The dataset object containing the topography data.
135
+ target_coords : object
136
+ The target coordinates to which the data will be regridded.
137
+ method : str, optional
138
+ The regridding method to use, by default "linear".
139
+ verbose : bool, optional
140
+ If True, logs the time taken for regridding, by default False.
118
141
 
142
+ Returns
143
+ -------
144
+ xr.DataArray
145
+ The regridded topography data with the sign flipped (bathymetry positive).
146
+ """
119
147
  data.choose_subdomain(target_coords, buffer_points=3, verbose=verbose)
120
148
 
121
149
  if verbose:
@@ -134,9 +162,22 @@ def _make_raw_topography(
134
162
 
135
163
 
136
164
  def _smooth_topography_globally(hraw, factor) -> xr.DataArray:
165
+ """Apply global smoothing to the topography using a Gaussian filter.
166
+
167
+ Parameters
168
+ ----------
169
+ hraw : xr.DataArray
170
+ The raw topography data to be smoothed.
171
+ factor : float
172
+ The smoothing factor (controls the width of the Gaussian filter).
173
+
174
+ Returns
175
+ -------
176
+ xr.DataArray
177
+ The smoothed topography data.
178
+ """
137
179
  # since GCM-Filters assumes periodic domain, we extend the domain by one grid cell in each dimension
138
180
  # and set that margin to land
139
-
140
181
  mask = xr.ones_like(hraw)
141
182
  margin_mask = xr.concat([mask, 0 * mask.isel(eta_rho=-1)], dim="eta_rho")
142
183
  margin_mask = xr.concat(
@@ -164,7 +205,27 @@ def _smooth_topography_globally(hraw, factor) -> xr.DataArray:
164
205
 
165
206
 
166
207
  def _smooth_topography_locally(h, hmin=5, rmax=0.2):
167
- """Smoothes topography locally to satisfy r < rmax."""
208
+ """Smooths topography locally to ensure the slope (r-factor) is below the specified
209
+ threshold.
210
+
211
+ This function applies a logarithmic transformation to the topography and iteratively smooths
212
+ it in four directions (eta, xi, and two diagonals) until the maximum slope parameter (r) is
213
+ below `rmax`. A threshold `hmin` is applied to prevent values from going below a minimum height.
214
+
215
+ Parameters
216
+ ----------
217
+ h : xarray.DataArray
218
+ The topography data to be smoothed.
219
+ hmin : float, optional
220
+ The minimum height threshold. Default is 5.
221
+ rmax : float, optional
222
+ The maximum allowable slope parameter (r-factor). Default is 0.2.
223
+
224
+ Returns
225
+ -------
226
+ xarray.DataArray
227
+ The smoothed topography data.
228
+ """
168
229
  # Compute rmax_log
169
230
  if rmax > 0.0:
170
231
  rmax_log = np.log((1.0 + rmax * 0.9) / (1.0 - rmax * 0.9))
@@ -174,65 +235,90 @@ def _smooth_topography_locally(h, hmin=5, rmax=0.2):
174
235
  # Apply hmin threshold
175
236
  h = xr.where(h < hmin, hmin, h)
176
237
 
177
- # We will smooth logarithmically
238
+ # Perform logarithmic transformation of the height field
178
239
  h_log = np.log(h / hmin)
179
240
 
180
- cf1 = 1.0 / 6
181
- cf2 = 0.25
241
+ # Constants for smoothing
242
+ smoothing_factor_1 = 1.0 / 6
243
+ smoothing_factor_2 = 0.25
182
244
 
245
+ # Iterate until convergence
183
246
  for iter in count():
184
- # Compute gradients in domain interior
247
+ # Compute gradients and smoothing for eta, xi, and diagonal directions
185
248
 
186
- # in eta-direction
187
- cff = h_log.diff("eta_rho").isel(xi_rho=slice(1, -1))
188
- cr = np.abs(cff)
249
+ # Gradient in eta-direction
250
+ delta_eta = h_log.diff("eta_rho").isel(xi_rho=slice(1, -1))
251
+ abs_eta_gradient = np.abs(delta_eta)
189
252
  with warnings.catch_warnings():
190
253
  warnings.simplefilter("ignore") # Ignore division by zero warning
191
- Op1 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
192
-
193
- # in xi-direction
194
- cff = h_log.diff("xi_rho").isel(eta_rho=slice(1, -1))
195
- cr = np.abs(cff)
254
+ eta_correction = xr.where(
255
+ abs_eta_gradient < rmax_log,
256
+ 0,
257
+ delta_eta * (1 - rmax_log / abs_eta_gradient),
258
+ )
259
+
260
+ # Gradient in xi-direction
261
+ delta_xi = h_log.diff("xi_rho").isel(eta_rho=slice(1, -1))
262
+ abs_xi_gradient = np.abs(delta_xi)
196
263
  with warnings.catch_warnings():
197
264
  warnings.simplefilter("ignore") # Ignore division by zero warning
198
- Op2 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
199
-
200
- # in diagonal direction
201
- cff = (h_log - h_log.shift(eta_rho=1, xi_rho=1)).isel(
265
+ xi_correction = xr.where(
266
+ abs_xi_gradient < rmax_log,
267
+ 0,
268
+ delta_xi * (1 - rmax_log / abs_xi_gradient),
269
+ )
270
+
271
+ # Gradient in first diagonal direction
272
+ delta_diag_1 = (h_log - h_log.shift(eta_rho=1, xi_rho=1)).isel(
202
273
  eta_rho=slice(1, None), xi_rho=slice(1, None)
203
274
  )
204
- cr = np.abs(cff)
275
+ abs_diag_1_gradient = np.abs(delta_diag_1)
205
276
  with warnings.catch_warnings():
206
277
  warnings.simplefilter("ignore") # Ignore division by zero warning
207
- Op3 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
208
-
209
- # in the other diagonal direction
210
- cff = (h_log.shift(eta_rho=1) - h_log.shift(xi_rho=1)).isel(
278
+ diag_1_correction = xr.where(
279
+ abs_diag_1_gradient < rmax_log,
280
+ 0,
281
+ delta_diag_1 * (1 - rmax_log / abs_diag_1_gradient),
282
+ )
283
+
284
+ # Gradient in second diagonal direction
285
+ delta_diag_2 = (h_log.shift(eta_rho=1) - h_log.shift(xi_rho=1)).isel(
211
286
  eta_rho=slice(1, None), xi_rho=slice(1, None)
212
287
  )
213
- cr = np.abs(cff)
288
+ abs_diag_2_gradient = np.abs(delta_diag_2)
214
289
  with warnings.catch_warnings():
215
290
  warnings.simplefilter("ignore") # Ignore division by zero warning
216
- Op4 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
291
+ diag_2_correction = xr.where(
292
+ abs_diag_2_gradient < rmax_log,
293
+ 0,
294
+ delta_diag_2 * (1 - rmax_log / abs_diag_2_gradient),
295
+ )
217
296
 
218
297
  # Update h_log in domain interior
219
- h_log[1:-1, 1:-1] += cf1 * (
220
- Op1[1:, :]
221
- - Op1[:-1, :]
222
- + Op2[:, 1:]
223
- - Op2[:, :-1]
224
- + cf2 * (Op3[1:, 1:] - Op3[:-1, :-1] + Op4[:-1, 1:] - Op4[1:, :-1])
298
+ h_log[1:-1, 1:-1] += smoothing_factor_1 * (
299
+ eta_correction[1:, :]
300
+ - eta_correction[:-1, :]
301
+ + xi_correction[:, 1:]
302
+ - xi_correction[:, :-1]
303
+ + smoothing_factor_2
304
+ * (
305
+ diag_1_correction[1:, 1:]
306
+ - diag_1_correction[:-1, :-1]
307
+ + diag_2_correction[:-1, 1:]
308
+ - diag_2_correction[1:, :-1]
309
+ )
225
310
  )
226
311
 
227
312
  # No gradient at the domain boundaries
228
313
  h_log = handle_boundaries(h_log)
229
314
 
230
- # Update h
315
+ # Recompute the topography after smoothing
231
316
  h = hmin * np.exp(h_log)
317
+
232
318
  # Apply hmin threshold again
233
319
  h = xr.where(h < hmin, hmin, h)
234
320
 
235
- # compute maximum slope parameter r
321
+ # Compute maximum slope parameter r
236
322
  r_eta, r_xi = _compute_rfactor(h)
237
323
  rmax0 = np.max([r_eta.max(), r_xi.max()])
238
324
  if rmax0 < rmax:
@@ -242,8 +328,23 @@ def _smooth_topography_locally(h, hmin=5, rmax=0.2):
242
328
 
243
329
 
244
330
  def _compute_rfactor(h):
245
- """Computes slope parameter (or r-factor) r = |Delta h| / 2h in both horizontal grid
246
- directions."""
331
+ """Computes the slope parameter (r-factor) in both horizontal directions.
332
+
333
+ The r-factor is calculated as |Δh| / (2h) in the eta and xi directions:
334
+ - r_eta = |h_i - h_{i-1}| / (h_i + h_{i+1})
335
+ - r_xi = |h_i - h_{i-1}| / (h_i + h_{i+1})
336
+
337
+ Parameters
338
+ ----------
339
+ h : xarray.DataArray
340
+ The topography data.
341
+
342
+ Returns
343
+ -------
344
+ tuple of xarray.DataArray
345
+ r_eta : r-factor in the eta direction.
346
+ r_xi : r-factor in the xi direction.
347
+ """
247
348
  # compute r_{i-1/2} = |h_i - h_{i-1}| / (h_i + h_{i+1})
248
349
  r_eta = np.abs(h.diff("eta_rho")) / (h + h.shift(eta_rho=1)).isel(
249
350
  eta_rho=slice(1, None)
@@ -256,6 +357,26 @@ def _compute_rfactor(h):
256
357
 
257
358
 
258
359
  def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
360
+ """Adds topography metadata to the dataset.
361
+
362
+ Parameters
363
+ ----------
364
+ ds : xarray.Dataset
365
+ Dataset to update.
366
+ topography_source : dict
367
+ Dictionary with topography source information (requires 'name' key).
368
+ smooth_factor : float
369
+ Smoothing factor (unused in this function).
370
+ hmin : float
371
+ Minimum height threshold for smoothing.
372
+ rmax : float
373
+ Maximum slope parameter (unused in this function).
374
+
375
+ Returns
376
+ -------
377
+ xarray.Dataset
378
+ Updated dataset with added metadata.
379
+ """
259
380
  ds.attrs["topography_source"] = topography_source["name"]
260
381
  ds.attrs["hmin"] = hmin
261
382
 
@@ -263,6 +384,18 @@ def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
263
384
 
264
385
 
265
386
  def nan_check(hraw):
387
+ """Checks for NaN values in the topography data.
388
+
389
+ Parameters
390
+ ----------
391
+ hraw : xarray.DataArray
392
+ Input topography data to check for NaN values.
393
+
394
+ Raises
395
+ ------
396
+ ValueError
397
+ If NaN values are found in the data, raises an error with a descriptive message.
398
+ """
266
399
  error_message = (
267
400
  "NaN values found in regridded topography. This likely occurs because the ROMS grid, including "
268
401
  "a small safety margin for interpolation, is not fully contained within the topography dataset's longitude/latitude range. Please ensure that the "