roms-tools 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. roms_tools/__init__.py +2 -1
  2. roms_tools/setup/boundary_forcing.py +21 -30
  3. roms_tools/setup/datasets.py +13 -21
  4. roms_tools/setup/grid.py +253 -139
  5. roms_tools/setup/initial_conditions.py +21 -3
  6. roms_tools/setup/mask.py +50 -4
  7. roms_tools/setup/nesting.py +575 -0
  8. roms_tools/setup/plot.py +214 -55
  9. roms_tools/setup/river_forcing.py +125 -29
  10. roms_tools/setup/surface_forcing.py +21 -8
  11. roms_tools/setup/tides.py +21 -3
  12. roms_tools/setup/topography.py +168 -35
  13. roms_tools/setup/utils.py +127 -21
  14. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -3
  15. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -2
  16. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/.zarray +1 -1
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/0 +0 -0
  18. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zmetadata +5 -6
  19. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zarray +2 -2
  20. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zattrs +1 -2
  21. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/0.0.0 +0 -0
  22. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zarray +2 -2
  23. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_name/0 +0 -0
  24. roms_tools/tests/test_setup/test_datasets.py +2 -2
  25. roms_tools/tests/test_setup/test_nesting.py +489 -0
  26. roms_tools/tests/test_setup/test_river_forcing.py +50 -13
  27. roms_tools/tests/test_setup/test_surface_forcing.py +1 -0
  28. roms_tools/tests/test_setup/test_validation.py +2 -2
  29. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/METADATA +8 -4
  30. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/RECORD +51 -50
  31. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/WHEEL +1 -1
  32. roms_tools/_version.py +0 -2
  33. roms_tools/tests/test_setup/test_data/river_forcing.zarr/river_tracer/0.0.0 +0 -0
  34. roms_tools/tests/test_setup/test_data/river_forcing.zarr/tracer_name/0 +0 -0
  35. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zattrs +0 -0
  36. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zgroup +0 -0
  37. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zarray +0 -0
  38. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zattrs +0 -0
  39. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/0 +0 -0
  40. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zarray +0 -0
  41. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zattrs +0 -0
  42. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/0 +0 -0
  43. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zarray +0 -0
  44. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zattrs +0 -0
  45. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/0 +0 -0
  46. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zarray +0 -0
  47. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zattrs +0 -0
  48. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/0 +0 -0
  49. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zarray +0 -0
  50. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zattrs +0 -0
  51. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/0.0 +0 -0
  52. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zattrs +0 -0
  53. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/LICENSE +0 -0
  54. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/top_level.txt +0 -0
roms_tools/setup/tides.py CHANGED
@@ -51,6 +51,10 @@ class TidalForcing:
51
51
  The reference date for the ROMS simulation. Default is datetime(2000, 1, 1).
52
52
  use_dask: bool, optional
53
53
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
54
+ bypass_validation: bool, optional
55
+ Indicates whether to skip validation checks in the processed data. When set to True,
56
+ the validation process that ensures no NaN values exist at wet points
57
+ in the processed dataset is bypassed. Defaults to False.
54
58
 
55
59
  Examples
56
60
  --------
@@ -65,6 +69,7 @@ class TidalForcing:
65
69
  allan_factor: float = 2.0
66
70
  model_reference_date: datetime = datetime(2000, 1, 1)
67
71
  use_dask: bool = False
72
+ bypass_validation: bool = False
68
73
 
69
74
  ds: xr.Dataset = field(init=False, repr=False)
70
75
 
@@ -129,7 +134,8 @@ class TidalForcing:
129
134
 
130
135
  ds = self._add_global_metadata(ds)
131
136
 
132
- self._validate(ds)
137
+ if not self.bypass_validation:
138
+ self._validate(ds)
133
139
 
134
140
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
135
141
  for var_name in ds.data_vars:
@@ -420,7 +426,10 @@ class TidalForcing:
420
426
 
421
427
  @classmethod
422
428
  def from_yaml(
423
- cls, filepath: Union[str, Path], use_dask: bool = False
429
+ cls,
430
+ filepath: Union[str, Path],
431
+ use_dask: bool = False,
432
+ bypass_validation: bool = False,
424
433
  ) -> "TidalForcing":
425
434
  """Create an instance of the TidalForcing class from a YAML file.
426
435
 
@@ -430,6 +439,10 @@ class TidalForcing:
430
439
  The path to the YAML file from which the parameters will be read.
431
440
  use_dask: bool, optional
432
441
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
442
+ bypass_validation: bool, optional
443
+ Indicates whether to skip validation checks in the processed data. When set to True,
444
+ the validation process that ensures no NaN values exist at wet points
445
+ in the processed dataset is bypassed. Defaults to False.
433
446
 
434
447
  Returns
435
448
  -------
@@ -440,7 +453,12 @@ class TidalForcing:
440
453
 
441
454
  grid = Grid.from_yaml(filepath)
442
455
  tidal_forcing_params = _from_yaml(cls, filepath)
443
- return cls(grid=grid, **tidal_forcing_params, use_dask=use_dask)
456
+ return cls(
457
+ grid=grid,
458
+ **tidal_forcing_params,
459
+ use_dask=use_dask,
460
+ bypass_validation=bypass_validation
461
+ )
444
462
 
445
463
  def _correct_tides(self, data):
446
464
  """Apply tidal corrections to the dataset. This method corrects the dataset for
@@ -94,7 +94,18 @@ def _add_topography(
94
94
 
95
95
 
96
96
  def _get_topography_data(source):
97
+ """Load topography data based on the specified source.
97
98
 
99
+ Parameters
100
+ ----------
101
+ source : dict
102
+ A dictionary containing the source details (e.g., "name" and "path").
103
+
104
+ Returns
105
+ -------
106
+ data : object
107
+ The loaded topography dataset (ETOPO5 or SRTM15).
108
+ """
98
109
  kwargs = {"use_dask": False}
99
110
 
100
111
  if source["name"] == "ETOPO5":
@@ -115,7 +126,24 @@ def _get_topography_data(source):
115
126
  def _make_raw_topography(
116
127
  data, target_coords, method="linear", verbose=False
117
128
  ) -> xr.DataArray:
129
+ """Regrid topography data to match target coordinates.
130
+
131
+ Parameters
132
+ ----------
133
+ data : object
134
+ The dataset object containing the topography data.
135
+ target_coords : object
136
+ The target coordinates to which the data will be regridded.
137
+ method : str, optional
138
+ The regridding method to use, by default "linear".
139
+ verbose : bool, optional
140
+ If True, logs the time taken for regridding, by default False.
118
141
 
142
+ Returns
143
+ -------
144
+ xr.DataArray
145
+ The regridded topography data with the sign flipped (bathymetry positive).
146
+ """
119
147
  data.choose_subdomain(target_coords, buffer_points=3, verbose=verbose)
120
148
 
121
149
  if verbose:
@@ -134,9 +162,22 @@ def _make_raw_topography(
134
162
 
135
163
 
136
164
  def _smooth_topography_globally(hraw, factor) -> xr.DataArray:
165
+ """Apply global smoothing to the topography using a Gaussian filter.
166
+
167
+ Parameters
168
+ ----------
169
+ hraw : xr.DataArray
170
+ The raw topography data to be smoothed.
171
+ factor : float
172
+ The smoothing factor (controls the width of the Gaussian filter).
173
+
174
+ Returns
175
+ -------
176
+ xr.DataArray
177
+ The smoothed topography data.
178
+ """
137
179
  # since GCM-Filters assumes periodic domain, we extend the domain by one grid cell in each dimension
138
180
  # and set that margin to land
139
-
140
181
  mask = xr.ones_like(hraw)
141
182
  margin_mask = xr.concat([mask, 0 * mask.isel(eta_rho=-1)], dim="eta_rho")
142
183
  margin_mask = xr.concat(
@@ -164,7 +205,27 @@ def _smooth_topography_globally(hraw, factor) -> xr.DataArray:
164
205
 
165
206
 
166
207
  def _smooth_topography_locally(h, hmin=5, rmax=0.2):
167
- """Smoothes topography locally to satisfy r < rmax."""
208
+ """Smooths topography locally to ensure the slope (r-factor) is below the specified
209
+ threshold.
210
+
211
+ This function applies a logarithmic transformation to the topography and iteratively smooths
212
+ it in four directions (eta, xi, and two diagonals) until the maximum slope parameter (r) is
213
+ below `rmax`. A threshold `hmin` is applied to prevent values from going below a minimum height.
214
+
215
+ Parameters
216
+ ----------
217
+ h : xarray.DataArray
218
+ The topography data to be smoothed.
219
+ hmin : float, optional
220
+ The minimum height threshold. Default is 5.
221
+ rmax : float, optional
222
+ The maximum allowable slope parameter (r-factor). Default is 0.2.
223
+
224
+ Returns
225
+ -------
226
+ xarray.DataArray
227
+ The smoothed topography data.
228
+ """
168
229
  # Compute rmax_log
169
230
  if rmax > 0.0:
170
231
  rmax_log = np.log((1.0 + rmax * 0.9) / (1.0 - rmax * 0.9))
@@ -174,65 +235,90 @@ def _smooth_topography_locally(h, hmin=5, rmax=0.2):
174
235
  # Apply hmin threshold
175
236
  h = xr.where(h < hmin, hmin, h)
176
237
 
177
- # We will smooth logarithmically
238
+ # Perform logarithmic transformation of the height field
178
239
  h_log = np.log(h / hmin)
179
240
 
180
- cf1 = 1.0 / 6
181
- cf2 = 0.25
241
+ # Constants for smoothing
242
+ smoothing_factor_1 = 1.0 / 6
243
+ smoothing_factor_2 = 0.25
182
244
 
245
+ # Iterate until convergence
183
246
  for iter in count():
184
- # Compute gradients in domain interior
247
+ # Compute gradients and smoothing for eta, xi, and diagonal directions
185
248
 
186
- # in eta-direction
187
- cff = h_log.diff("eta_rho").isel(xi_rho=slice(1, -1))
188
- cr = np.abs(cff)
249
+ # Gradient in eta-direction
250
+ delta_eta = h_log.diff("eta_rho").isel(xi_rho=slice(1, -1))
251
+ abs_eta_gradient = np.abs(delta_eta)
189
252
  with warnings.catch_warnings():
190
253
  warnings.simplefilter("ignore") # Ignore division by zero warning
191
- Op1 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
192
-
193
- # in xi-direction
194
- cff = h_log.diff("xi_rho").isel(eta_rho=slice(1, -1))
195
- cr = np.abs(cff)
254
+ eta_correction = xr.where(
255
+ abs_eta_gradient < rmax_log,
256
+ 0,
257
+ delta_eta * (1 - rmax_log / abs_eta_gradient),
258
+ )
259
+
260
+ # Gradient in xi-direction
261
+ delta_xi = h_log.diff("xi_rho").isel(eta_rho=slice(1, -1))
262
+ abs_xi_gradient = np.abs(delta_xi)
196
263
  with warnings.catch_warnings():
197
264
  warnings.simplefilter("ignore") # Ignore division by zero warning
198
- Op2 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
199
-
200
- # in diagonal direction
201
- cff = (h_log - h_log.shift(eta_rho=1, xi_rho=1)).isel(
265
+ xi_correction = xr.where(
266
+ abs_xi_gradient < rmax_log,
267
+ 0,
268
+ delta_xi * (1 - rmax_log / abs_xi_gradient),
269
+ )
270
+
271
+ # Gradient in first diagonal direction
272
+ delta_diag_1 = (h_log - h_log.shift(eta_rho=1, xi_rho=1)).isel(
202
273
  eta_rho=slice(1, None), xi_rho=slice(1, None)
203
274
  )
204
- cr = np.abs(cff)
275
+ abs_diag_1_gradient = np.abs(delta_diag_1)
205
276
  with warnings.catch_warnings():
206
277
  warnings.simplefilter("ignore") # Ignore division by zero warning
207
- Op3 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
208
-
209
- # in the other diagonal direction
210
- cff = (h_log.shift(eta_rho=1) - h_log.shift(xi_rho=1)).isel(
278
+ diag_1_correction = xr.where(
279
+ abs_diag_1_gradient < rmax_log,
280
+ 0,
281
+ delta_diag_1 * (1 - rmax_log / abs_diag_1_gradient),
282
+ )
283
+
284
+ # Gradient in second diagonal direction
285
+ delta_diag_2 = (h_log.shift(eta_rho=1) - h_log.shift(xi_rho=1)).isel(
211
286
  eta_rho=slice(1, None), xi_rho=slice(1, None)
212
287
  )
213
- cr = np.abs(cff)
288
+ abs_diag_2_gradient = np.abs(delta_diag_2)
214
289
  with warnings.catch_warnings():
215
290
  warnings.simplefilter("ignore") # Ignore division by zero warning
216
- Op4 = xr.where(cr < rmax_log, 0, 1.0 * cff * (1 - rmax_log / cr))
291
+ diag_2_correction = xr.where(
292
+ abs_diag_2_gradient < rmax_log,
293
+ 0,
294
+ delta_diag_2 * (1 - rmax_log / abs_diag_2_gradient),
295
+ )
217
296
 
218
297
  # Update h_log in domain interior
219
- h_log[1:-1, 1:-1] += cf1 * (
220
- Op1[1:, :]
221
- - Op1[:-1, :]
222
- + Op2[:, 1:]
223
- - Op2[:, :-1]
224
- + cf2 * (Op3[1:, 1:] - Op3[:-1, :-1] + Op4[:-1, 1:] - Op4[1:, :-1])
298
+ h_log[1:-1, 1:-1] += smoothing_factor_1 * (
299
+ eta_correction[1:, :]
300
+ - eta_correction[:-1, :]
301
+ + xi_correction[:, 1:]
302
+ - xi_correction[:, :-1]
303
+ + smoothing_factor_2
304
+ * (
305
+ diag_1_correction[1:, 1:]
306
+ - diag_1_correction[:-1, :-1]
307
+ + diag_2_correction[:-1, 1:]
308
+ - diag_2_correction[1:, :-1]
309
+ )
225
310
  )
226
311
 
227
312
  # No gradient at the domain boundaries
228
313
  h_log = handle_boundaries(h_log)
229
314
 
230
- # Update h
315
+ # Recompute the topography after smoothing
231
316
  h = hmin * np.exp(h_log)
317
+
232
318
  # Apply hmin threshold again
233
319
  h = xr.where(h < hmin, hmin, h)
234
320
 
235
- # compute maximum slope parameter r
321
+ # Compute maximum slope parameter r
236
322
  r_eta, r_xi = _compute_rfactor(h)
237
323
  rmax0 = np.max([r_eta.max(), r_xi.max()])
238
324
  if rmax0 < rmax:
@@ -242,8 +328,23 @@ def _smooth_topography_locally(h, hmin=5, rmax=0.2):
242
328
 
243
329
 
244
330
  def _compute_rfactor(h):
245
- """Computes slope parameter (or r-factor) r = |Delta h| / 2h in both horizontal grid
246
- directions."""
331
+ """Computes the slope parameter (r-factor) in both horizontal directions.
332
+
333
+ The r-factor is calculated as |Δh| / (2h) in the eta and xi directions:
334
+ - r_eta = |h_i - h_{i-1}| / (h_i + h_{i+1})
335
+ - r_xi = |h_i - h_{i-1}| / (h_i + h_{i+1})
336
+
337
+ Parameters
338
+ ----------
339
+ h : xarray.DataArray
340
+ The topography data.
341
+
342
+ Returns
343
+ -------
344
+ tuple of xarray.DataArray
345
+ r_eta : r-factor in the eta direction.
346
+ r_xi : r-factor in the xi direction.
347
+ """
247
348
  # compute r_{i-1/2} = |h_i - h_{i-1}| / (h_i + h_{i+1})
248
349
  r_eta = np.abs(h.diff("eta_rho")) / (h + h.shift(eta_rho=1)).isel(
249
350
  eta_rho=slice(1, None)
@@ -256,6 +357,26 @@ def _compute_rfactor(h):
256
357
 
257
358
 
258
359
  def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
360
+ """Adds topography metadata to the dataset.
361
+
362
+ Parameters
363
+ ----------
364
+ ds : xarray.Dataset
365
+ Dataset to update.
366
+ topography_source : dict
367
+ Dictionary with topography source information (requires 'name' key).
368
+ smooth_factor : float
369
+ Smoothing factor (unused in this function).
370
+ hmin : float
371
+ Minimum height threshold for smoothing.
372
+ rmax : float
373
+ Maximum slope parameter (unused in this function).
374
+
375
+ Returns
376
+ -------
377
+ xarray.Dataset
378
+ Updated dataset with added metadata.
379
+ """
259
380
  ds.attrs["topography_source"] = topography_source["name"]
260
381
  ds.attrs["hmin"] = hmin
261
382
 
@@ -263,6 +384,18 @@ def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
263
384
 
264
385
 
265
386
  def nan_check(hraw):
387
+ """Checks for NaN values in the topography data.
388
+
389
+ Parameters
390
+ ----------
391
+ hraw : xarray.DataArray
392
+ Input topography data to check for NaN values.
393
+
394
+ Raises
395
+ ------
396
+ ValueError
397
+ If NaN values are found in the data, raises an error with a descriptive message.
398
+ """
266
399
  error_message = (
267
400
  "NaN values found in regridded topography. This likely occurs because the ROMS grid, including "
268
401
  "a small safety margin for interpolation, is not fully contained within the topography dataset's longitude/latitude range. Please ensure that the "
roms_tools/setup/utils.py CHANGED
@@ -927,37 +927,40 @@ def get_vector_pairs(variable_info):
927
927
  return vector_pairs
928
928
 
929
929
 
930
- def gc_dist(lon1, lat1, lon2, lat2):
931
- """Calculate the great circle distance between two points on the Earth's surface.
932
- Latitude and longitude must be provided in degrees (they will be converted to
933
- radians).
930
+ def gc_dist(lon1, lat1, lon2, lat2, input_in_degrees=True):
931
+ """Calculate the great circle distance between two points on the Earth's surface
932
+ using the Haversine formula.
934
933
 
935
- The function uses the Haversine formula to compute the shortest distance
936
- along the surface of a sphere (Earth), assuming the Earth is a perfect sphere.
934
+ Latitude and longitude are assumed to be in degrees by default. If `input_in_degrees` is set to `False`,
935
+ the input is assumed to already be in radians.
937
936
 
938
937
  Parameters
939
938
  ----------
940
939
  lon1, lat1 : float
941
- Longitude and latitude of the first point in degrees.
940
+ Longitude and latitude of the first point.
942
941
  lon2, lat2 : float
943
- Longitude and latitude of the second point in degrees.
942
+ Longitude and latitude of the second point.
943
+ input_in_degrees : bool, optional
944
+ If True (default), the input coordinates are assumed to be in degrees and will be converted to radians.
945
+ If False, the input is assumed to be in radians and no conversion is applied.
944
946
 
945
947
  Returns
946
948
  -------
947
949
  dis : float
948
950
  The great circle distance between the two points in meters.
949
- This is the shortest distance along the surface of a sphere (Earth).
950
951
 
951
952
  Notes
952
953
  -----
953
954
  The radius of the Earth is taken to be 6371315 meters.
954
955
  """
956
+
955
957
  # Convert degrees to radians
956
- d2r = np.pi / 180
957
- lon1 = lon1 * d2r
958
- lat1 = lat1 * d2r
959
- lon2 = lon2 * d2r
960
- lat2 = lat2 * d2r
958
+ if input_in_degrees:
959
+ d2r = np.pi / 180
960
+ lon1 = lon1 * d2r
961
+ lat1 = lat1 * d2r
962
+ lon2 = lon2 * d2r
963
+ lat2 = lat2 * d2r
961
964
 
962
965
  # Difference in latitudes and longitudes
963
966
  dlat = lat2 - lat1
@@ -1058,12 +1061,17 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1058
1061
  filepath = Path(filepath)
1059
1062
 
1060
1063
  # Step 1: Serialize Grid data
1061
- # Convert the grid attribute to a dictionary and remove non-serializable fields
1062
- grid_data = asdict(forcing_object.grid)
1063
- grid_data.pop("ds", None) # Remove 'ds' attribute (non-serializable)
1064
- grid_data.pop("straddle", None)
1065
- grid_data.pop("verbose", None)
1066
- grid_yaml_data = {"Grid": grid_data}
1064
+ # Check if the forcing_object has a grid attribute
1065
+ if hasattr(forcing_object, "grid") and forcing_object.grid is not None:
1066
+ grid_data = asdict(forcing_object.grid)
1067
+ grid_yaml_data = {"Grid": _pop_grid_data(grid_data)}
1068
+ else:
1069
+ parent_grid_data = asdict(forcing_object.parent_grid)
1070
+ parent_grid_yaml_data = {"ParentGrid": _pop_grid_data(parent_grid_data)}
1071
+ child_grid_data = asdict(forcing_object.child_grid)
1072
+ child_grid_yaml_data = {"ChildGrid": _pop_grid_data(child_grid_data)}
1073
+
1074
+ grid_yaml_data = {**parent_grid_yaml_data, **child_grid_yaml_data}
1067
1075
 
1068
1076
  # Step 2: Get ROMS Tools version
1069
1077
  # Fetch the version of the 'roms-tools' package for inclusion in the YAML header
@@ -1082,7 +1090,16 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1082
1090
  filtered_field_names = [
1083
1091
  param
1084
1092
  for param in field_names
1085
- if param not in ("grid", "ds", "use_dask", "climatology")
1093
+ if param
1094
+ not in (
1095
+ "grid",
1096
+ "parent_grid",
1097
+ "child_grid",
1098
+ "ds",
1099
+ "use_dask",
1100
+ "bypass_validation",
1101
+ "climatology",
1102
+ )
1086
1103
  ]
1087
1104
 
1088
1105
  for field_name in filtered_field_names:
@@ -1111,6 +1128,14 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1111
1128
  yaml.dump(yaml_data, file, default_flow_style=False, sort_keys=False)
1112
1129
 
1113
1130
 
1131
+ def _pop_grid_data(grid_data):
1132
+ grid_data.pop("ds", None) # Remove 'ds' attribute (non-serializable)
1133
+ grid_data.pop("straddle", None)
1134
+ grid_data.pop("verbose", None)
1135
+
1136
+ return grid_data
1137
+
1138
+
1114
1139
  def _from_yaml(forcing_object: Type, filepath: Union[str, Path]) -> Dict[str, Any]:
1115
1140
  """Extract the configuration data for a given forcing object from a YAML file.
1116
1141
 
@@ -1203,3 +1228,84 @@ def handle_boundaries(field):
1203
1228
  field[:, -1] = field[:, -2]
1204
1229
 
1205
1230
  return field
1231
+
1232
+
1233
+ def get_boundary_coords():
1234
+ """This function determines the boundary points for the grid variables by specifying
1235
+ the indices for the south, east, north, and west boundaries.
1236
+
1237
+ Returns
1238
+ -------
1239
+ dict
1240
+ A dictionary containing the boundary coordinates for different variable types.
1241
+ The dictionary has the following structure:
1242
+ - Keys: Variable types ("rho", "u", "v", "vector").
1243
+ - Values: Nested dictionaries that map each direction ("south", "east", "north", "west")
1244
+ to another dictionary specifying the boundary coordinates, represented by grid indices
1245
+ for the respective variable types. For example:
1246
+ - "rho" variables (e.g., `eta_rho`, `xi_rho`)
1247
+ - "u" variables (e.g., `xi_u`)
1248
+ - "v" variables (e.g., `eta_v`)
1249
+ - "vector" variables with lists of indices for multiple grid points (e.g., `eta_rho`, `xi_rho`).
1250
+ """
1251
+
1252
+ bdry_coords = {
1253
+ "rho": {
1254
+ "south": {"eta_rho": 0},
1255
+ "east": {"xi_rho": -1},
1256
+ "north": {"eta_rho": -1},
1257
+ "west": {"xi_rho": 0},
1258
+ },
1259
+ "u": {
1260
+ "south": {"eta_rho": 0},
1261
+ "east": {"xi_u": -1},
1262
+ "north": {"eta_rho": -1},
1263
+ "west": {"xi_u": 0},
1264
+ },
1265
+ "v": {
1266
+ "south": {"eta_v": 0},
1267
+ "east": {"xi_rho": -1},
1268
+ "north": {"eta_v": -1},
1269
+ "west": {"xi_rho": 0},
1270
+ },
1271
+ "vector": {
1272
+ "south": {"eta_rho": [0, 1]},
1273
+ "east": {"xi_rho": [-2, -1]},
1274
+ "north": {"eta_rho": [-2, -1]},
1275
+ "west": {"xi_rho": [0, 1]},
1276
+ },
1277
+ }
1278
+
1279
+ return bdry_coords
1280
+
1281
+
1282
+ def wrap_longitudes(grid_ds, straddle):
1283
+ """Adjusts longitude values in a dataset to handle dateline crossing.
1284
+
1285
+ Parameters
1286
+ ----------
1287
+ grid_ds : xr.Dataset
1288
+ The dataset containing longitude variables to adjust.
1289
+ straddle : bool
1290
+ If True, adjusts longitudes to the range [-180, 180] for datasets
1291
+ that straddle the dateline. If False, adjusts longitudes to the
1292
+ range [0, 360].
1293
+
1294
+ Returns
1295
+ -------
1296
+ xr.Dataset
1297
+ The dataset with adjusted longitude values.
1298
+ """
1299
+ for lon_dim in ["lon_rho", "lon_u", "lon_v"]:
1300
+ if straddle:
1301
+ grid_ds[lon_dim] = xr.where(
1302
+ grid_ds[lon_dim] > 180,
1303
+ grid_ds[lon_dim] - 360,
1304
+ grid_ds[lon_dim],
1305
+ )
1306
+ else:
1307
+ grid_ds[lon_dim] = xr.where(
1308
+ grid_ds[lon_dim] < 0, grid_ds[lon_dim] + 360, grid_ds[lon_dim]
1309
+ )
1310
+
1311
+ return grid_ds
@@ -120,8 +120,7 @@
120
120
  "nriver"
121
121
  ],
122
122
  "coordinates": "abs_time river_name tracer_name",
123
- "long_name": "River tracer data",
124
- "units": "degrees C [temperature]; psu [salinity]"
123
+ "long_name": "River tracer data"
125
124
  },
126
125
  "river_volume/.zarray": {
127
126
  "chunks": [
@@ -165,7 +164,7 @@
165
164
  "id": "blosc",
166
165
  "shuffle": 1
167
166
  },
168
- "dtype": "<U11",
167
+ "dtype": "<U4",
169
168
  "fill_value": null,
170
169
  "filters": null,
171
170
  "order": "C",
@@ -5,6 +5,5 @@
5
5
  "nriver"
6
6
  ],
7
7
  "coordinates": "abs_time river_name tracer_name",
8
- "long_name": "River tracer data",
9
- "units": "degrees C [temperature]; psu [salinity]"
8
+ "long_name": "River tracer data"
10
9
  }
@@ -9,7 +9,7 @@
9
9
  "id": "blosc",
10
10
  "shuffle": 1
11
11
  },
12
- "dtype": "<U11",
12
+ "dtype": "<U4",
13
13
  "fill_value": null,
14
14
  "filters": null,
15
15
  "order": "C",
@@ -121,7 +121,7 @@
121
121
  "river_tracer/.zarray": {
122
122
  "chunks": [
123
123
  12,
124
- 2,
124
+ 34,
125
125
  6
126
126
  ],
127
127
  "compressor": {
@@ -137,7 +137,7 @@
137
137
  "order": "C",
138
138
  "shape": [
139
139
  12,
140
- 2,
140
+ 34,
141
141
  6
142
142
  ],
143
143
  "zarr_format": 2
@@ -149,8 +149,7 @@
149
149
  "nriver"
150
150
  ],
151
151
  "coordinates": "abs_time month river_name tracer_name",
152
- "long_name": "River tracer data",
153
- "units": "degrees C [temperature]; psu [salinity]"
152
+ "long_name": "River tracer data"
154
153
  },
155
154
  "river_volume/.zarray": {
156
155
  "chunks": [
@@ -185,7 +184,7 @@
185
184
  },
186
185
  "tracer_name/.zarray": {
187
186
  "chunks": [
188
- 2
187
+ 34
189
188
  ],
190
189
  "compressor": {
191
190
  "blocksize": 0,
@@ -199,7 +198,7 @@
199
198
  "filters": null,
200
199
  "order": "C",
201
200
  "shape": [
202
- 2
201
+ 34
203
202
  ],
204
203
  "zarr_format": 2
205
204
  },
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "chunks": [
3
3
  12,
4
- 2,
4
+ 34,
5
5
  6
6
6
  ],
7
7
  "compressor": {
@@ -17,7 +17,7 @@
17
17
  "order": "C",
18
18
  "shape": [
19
19
  12,
20
- 2,
20
+ 34,
21
21
  6
22
22
  ],
23
23
  "zarr_format": 2
@@ -5,6 +5,5 @@
5
5
  "nriver"
6
6
  ],
7
7
  "coordinates": "abs_time month river_name tracer_name",
8
- "long_name": "River tracer data",
9
- "units": "degrees C [temperature]; psu [salinity]"
8
+ "long_name": "River tracer data"
10
9
  }