roms-tools 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. roms_tools/__init__.py +2 -1
  2. roms_tools/setup/boundary_forcing.py +21 -30
  3. roms_tools/setup/datasets.py +13 -21
  4. roms_tools/setup/grid.py +253 -139
  5. roms_tools/setup/initial_conditions.py +21 -3
  6. roms_tools/setup/mask.py +50 -4
  7. roms_tools/setup/nesting.py +575 -0
  8. roms_tools/setup/plot.py +214 -55
  9. roms_tools/setup/river_forcing.py +125 -29
  10. roms_tools/setup/surface_forcing.py +21 -8
  11. roms_tools/setup/tides.py +21 -3
  12. roms_tools/setup/topography.py +168 -35
  13. roms_tools/setup/utils.py +127 -21
  14. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -3
  15. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -2
  16. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/.zarray +1 -1
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/0 +0 -0
  18. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zmetadata +5 -6
  19. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zarray +2 -2
  20. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zattrs +1 -2
  21. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/0.0.0 +0 -0
  22. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zarray +2 -2
  23. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_name/0 +0 -0
  24. roms_tools/tests/test_setup/test_datasets.py +2 -2
  25. roms_tools/tests/test_setup/test_nesting.py +489 -0
  26. roms_tools/tests/test_setup/test_river_forcing.py +50 -13
  27. roms_tools/tests/test_setup/test_surface_forcing.py +1 -0
  28. roms_tools/tests/test_setup/test_validation.py +2 -2
  29. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/METADATA +8 -4
  30. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/RECORD +51 -50
  31. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/WHEEL +1 -1
  32. roms_tools/_version.py +0 -2
  33. roms_tools/tests/test_setup/test_data/river_forcing.zarr/river_tracer/0.0.0 +0 -0
  34. roms_tools/tests/test_setup/test_data/river_forcing.zarr/tracer_name/0 +0 -0
  35. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zattrs +0 -0
  36. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zgroup +0 -0
  37. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zarray +0 -0
  38. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zattrs +0 -0
  39. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/0 +0 -0
  40. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zarray +0 -0
  41. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zattrs +0 -0
  42. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/0 +0 -0
  43. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zarray +0 -0
  44. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zattrs +0 -0
  45. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/0 +0 -0
  46. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zarray +0 -0
  47. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zattrs +0 -0
  48. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/0 +0 -0
  49. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zarray +0 -0
  50. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zattrs +0 -0
  51. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/0.0 +0 -0
  52. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zattrs +0 -0
  53. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/LICENSE +0 -0
  54. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/top_level.txt +0 -0
roms_tools/setup/nesting.py (new file)
@@ -0,0 +1,575 @@
+ import numpy as np
+ import xarray as xr
+ from scipy.interpolate import griddata
+ from dataclasses import dataclass, field
+ from typing import Dict, Union
+ from pathlib import Path
+ from roms_tools.setup.grid import Grid
+ from roms_tools.setup.utils import (
+     interpolate_from_rho_to_u,
+     interpolate_from_rho_to_v,
+     get_boundary_coords,
+     wrap_longitudes,
+     save_datasets,
+     _to_yaml,
+     _from_yaml,
+ )
+ from roms_tools.setup.plot import _plot_nesting
+ import logging
+ from scipy.interpolate import interp1d
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class Nesting:
+     """Represents the relationship between a parent and a child grid in nested ROMS
+     simulations. This class facilitates mapping the boundaries of the child grid onto
+     the parent grid indices and modifying the child grid topography such that it matches
+     the parent topography at the boundaries.
+
+     Parameters
+     ----------
+     parent_grid : Grid
+         The parent grid object, containing information about the larger domain.
+     child_grid : Grid
+         The child grid object, containing information about the nested domain.
+     boundaries : Dict[str, bool], optional
+         Dictionary specifying which boundaries of the child grid are used
+         in the nesting process. Keys are "south", "east", "north", and "west",
+         with boolean values indicating inclusion. Defaults to all boundaries (True).
+     child_prefix : str, optional
+         Prefix added to variable names in the generated dataset to distinguish
+         child grid information. Defaults to "child".
+     period : float, optional
+         The output period, in seconds, for boundary variables in the child grid.
+         Defaults to hourly (3600.0).
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         An xarray Dataset containing the index mappings between the child and parent grids
+         for each specified boundary. Includes metadata about grid points, angles,
+         and boundary variable outputs.
+     """
+
+     parent_grid: Grid
+     child_grid: Grid
+     boundaries: Dict[str, bool] = field(
+         default_factory=lambda: {
+             "south": True,
+             "east": True,
+             "north": True,
+             "west": True,
+         }
+     )
+     child_prefix: str = "child"
+     period: float = 3600.0
+
+     def __post_init__(self):
+
+         parent_grid_ds = self.parent_grid.ds
+         child_grid_ds = self.child_grid.ds
+
+         # Adjust longitude for dateline crossing to prevent interpolation artifacts
+         for grid_ds in [parent_grid_ds, child_grid_ds]:
+             grid_ds = wrap_longitudes(grid_ds, straddle=self.parent_grid.straddle)
+
+         # Map child boundaries onto parent grid indices
+         ds = map_child_boundaries_onto_parent_grid_indices(
+             parent_grid_ds,
+             child_grid_ds,
+             self.boundaries,
+             self.child_prefix,
+             self.period,
+         )
+         object.__setattr__(self, "ds", ds)
+
+         # Modify child topography and mask to match the parent grid
+         child_grid_ds = modify_child_topography_and_mask(
+             parent_grid_ds, child_grid_ds, self.boundaries
+         )
+
+         # Convert longitudes back to [0, 360] range
+         for grid_ds in [parent_grid_ds, child_grid_ds]:
+             grid_ds = wrap_longitudes(grid_ds, straddle=False)
+         object.__setattr__(self.parent_grid, "ds", parent_grid_ds)
+         object.__setattr__(self.child_grid, "ds", child_grid_ds)
+
+     def plot(self, with_dim_names=False) -> None:
+         """Plot the parent and child grids in a single figure.
+
+         Returns
+         -------
+         None
+             This method does not return any value. It generates and displays a plot.
+         """
+
+         _plot_nesting(
+             self.parent_grid.ds,
+             self.child_grid.ds,
+             self.parent_grid.straddle,
+             with_dim_names,
+         )
+
+     def save(
+         self,
+         filepath: Union[str, Path],
+         filepath_child_grid: Union[str, Path],
+         np_eta: int = None,
+         np_xi: int = None,
+     ) -> None:
+         """Save the nesting and child grid files to netCDF4 files. The child grid file is
+         required because the topography and mask of the child grid have been modified.
+
+         This method saves the nesting data and the child grid data, each either as a single file or partitioned into multiple files, based on the provided options. Each dataset can be saved in two modes:
+
+         1. **Single File Mode (default)**:
+             - If both `np_eta` and `np_xi` are `None`, the entire dataset is saved as a single netCDF4 file.
+             - The file is named based on the provided `filepath`, with `.nc` automatically appended to the filename.
+
+         2. **Partitioned Mode**:
+             - If either `np_eta` or `np_xi` is specified, the dataset is partitioned spatially along the `eta` and `xi` axes into tiles.
+             - Each tile is saved as a separate netCDF4 file. Filenames will be modified with an index to represent each partition, e.g., `"filepath.0.nc"`, `"filepath.1.nc"`, etc.
+
+         Parameters
+         ----------
+         filepath : Union[str, Path]
+             The base path and filename for the output files. The filenames will include the specified path and the `.nc` extension.
+             If partitioning is used, additional indices will be appended to the filenames, e.g., `"filepath.0.nc"`, `"filepath.1.nc"`, etc.
+
+         filepath_child_grid : Union[str, Path]
+             The base path and filename for saving the child grid file.
+
+         np_eta : int, optional
+             The number of partitions along the `eta` direction. If `None`, no spatial partitioning is performed along the `eta` axis.
+
+         np_xi : int, optional
+             The number of partitions along the `xi` direction. If `None`, no spatial partitioning is performed along the `xi` axis.
+
+         Returns
+         -------
+         List[Path]
+             A list of `Path` objects for the saved files. Each element in the list corresponds to a file that was saved.
+         """
+
+         # Ensure filepath is a Path object
+         filepath = Path(filepath)
+         filepath_child_grid = Path(filepath_child_grid)
+
+         # Remove ".nc" suffix if present
+         if filepath.suffix == ".nc":
+             filepath = filepath.with_suffix("")
+         if filepath_child_grid.suffix == ".nc":
+             filepath_child_grid = filepath_child_grid.with_suffix("")
+
+         dataset_list = [self.ds, self.child_grid.ds]
+         output_filenames = [str(filepath), str(filepath_child_grid)]
+
+         saved_filenames = save_datasets(
+             dataset_list, output_filenames, np_eta=np_eta, np_xi=np_xi
+         )
+
+         return saved_filenames
+
+     def to_yaml(self, filepath: Union[str, Path]) -> None:
+         """Export the parameters of the class to a YAML file, including the version of
+         roms-tools.
+
+         Parameters
+         ----------
+         filepath : Union[str, Path]
+             The path to the YAML file where the parameters will be saved.
+         """
+
+         _to_yaml(self, filepath)
+
+     @classmethod
+     def from_yaml(cls, filepath: Union[str, Path]) -> "Nesting":
+         """Create an instance of the Nesting class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : Union[str, Path]
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         Nesting
+             An instance of the Nesting class.
+         """
+         filepath = Path(filepath)
+
+         parent_grid = Grid.from_yaml(filepath, "ParentGrid")
+         child_grid = Grid.from_yaml(filepath, "ChildGrid")
+         params = _from_yaml(cls, filepath)
+
+         return cls(parent_grid=parent_grid, child_grid=child_grid, **params)
+
+
+ def map_child_boundaries_onto_parent_grid_indices(
+     parent_grid_ds,
+     child_grid_ds,
+     boundaries={"south": True, "east": True, "north": True, "west": True},
+     child_prefix="child",
+     period=3600.0,
+     update_land_indices=True,
+ ):
+
+     bdry_coords_dict = get_boundary_coords()
+
+     # add angles at u- and v-points
+     child_grid_ds["angle_u"] = interpolate_from_rho_to_u(child_grid_ds["angle"])
+     child_grid_ds["angle_v"] = interpolate_from_rho_to_v(child_grid_ds["angle"])
+
+     ds = xr.Dataset()
+
+     for direction in ["south", "east", "north", "west"]:
+         if boundaries[direction]:
+             for grid_location in ["rho", "u", "v"]:
+                 names = {
+                     "latitude": f"lat_{grid_location}",
+                     "longitude": f"lon_{grid_location}",
+                     "mask": f"mask_{grid_location}",
+                     "angle": f"angle_{grid_location}",
+                 }
+                 bdry_coords = bdry_coords_dict[grid_location]
+                 if grid_location == "rho":
+                     suffix = "r"
+                 else:
+                     suffix = grid_location
+
+                 lon_child = child_grid_ds[names["longitude"]].isel(
+                     **bdry_coords[direction]
+                 )
+                 lat_child = child_grid_ds[names["latitude"]].isel(
+                     **bdry_coords[direction]
+                 )
+
+                 mask_child = child_grid_ds[names["mask"]].isel(**bdry_coords[direction])
+
+                 i_eta, i_xi = interpolate_indices(
+                     parent_grid_ds, lon_child, lat_child, mask_child
+                 )
+
+                 if update_land_indices:
+                     i_eta, i_xi = update_indices_if_on_parent_land(
+                         i_eta, i_xi, grid_location, parent_grid_ds
+                     )
+
+                 var_name = f"{child_prefix}_{direction}_{suffix}"
+                 if grid_location == "rho":
+                     ds[var_name] = xr.concat([i_xi, i_eta], dim="two")
+                     ds[var_name].attrs[
+                         "long_name"
+                     ] = f"{grid_location}-points of {direction}ern child boundary mapped onto parent (absolute) grid indices"
+                     ds[var_name].attrs["units"] = "non-dimensional"
+                     ds[var_name].attrs["output_vars"] = "zeta, temp, salt"
+                 else:
+                     angle_child = child_grid_ds[names["angle"]].isel(
+                         **bdry_coords[direction]
+                     )
+                     ds[var_name] = xr.concat([i_xi, i_eta, angle_child], dim="three")
+                     ds[var_name].attrs[
+                         "long_name"
+                     ] = f"{grid_location}-points of {direction}ern child boundary mapped onto parent grid (absolute) indices and angle"
+                     ds[var_name].attrs["units"] = "non-dimensional and radian"
+
+                     if grid_location == "u":
+                         ds[var_name].attrs["output_vars"] = "ubar, u, up"
+                     elif grid_location == "v":
+                         ds[var_name].attrs["output_vars"] = "vbar, v, vp"
+
+                 ds[var_name].attrs["output_period"] = period
+
+     vars_to_drop = ["lat_rho", "lon_rho", "lat_u", "lon_u", "lat_v", "lon_v"]
+     vars_to_drop_existing = [var for var in vars_to_drop if var in ds]
+     ds = ds.drop_vars(vars_to_drop_existing)
+
+     # Rename dimensions
+     dims_to_rename = {
+         dim: f"{child_prefix}_{dim}" for dim in ds.dims if dim not in ["two", "three"]
+     }
+     ds = ds.rename(dims_to_rename)
+
+     ds = ds.assign_coords(
+         {
+             "indices_rho": ("two", ["xi", "eta"]),
+             "indices_vel": ("three", ["xi", "eta", "angle"]),
+         }
+     )
+
+     return ds
+
+
+ def interpolate_indices(parent_grid_ds, lon, lat, mask):
+     """Interpolate the parent indices to the child grid.
+
+     Parameters
+     ----------
+     parent_grid_ds : xarray.Dataset
+         Grid information of parent grid.
+     lon : xarray.DataArray
+         Longitudes of the child grid where interpolation is desired.
+     lat : xarray.DataArray
+         Latitudes of the child grid where interpolation is desired.
+     mask : xarray.DataArray
+         Mask for the child longitudes and latitudes under consideration.
+
+     Returns
+     -------
+     i : xarray.DataArray
+         Interpolated i-indices for the child grid.
+     j : xarray.DataArray
+         Interpolated j-indices for the child grid.
+     """
+     i_eta = np.arange(-0.5, len(parent_grid_ds.eta_rho) - 0.5, 1)
+     i_xi = np.arange(-0.5, len(parent_grid_ds.xi_rho) - 0.5, 1)
+
+     parent_grid_ds = parent_grid_ds.assign_coords(
+         i_eta=("eta_rho", i_eta)
+     ).assign_coords(i_xi=("xi_rho", i_xi))
+
+     lon_parent = parent_grid_ds.lon_rho
+     lat_parent = parent_grid_ds.lat_rho
+     i_parent = parent_grid_ds.i_eta
+     j_parent = parent_grid_ds.i_xi
+
+     # Create meshgrid
+     j_parent, i_parent = np.meshgrid(j_parent.values, i_parent.values)
+
+     # Flatten the input coordinates and indices for griddata
+     points = np.column_stack((lon_parent.values.ravel(), lat_parent.values.ravel()))
+     i_parent_flat = i_parent.ravel()
+     j_parent_flat = j_parent.ravel()
+
+     # Interpolate the i and j indices
+     i = griddata(points, i_parent_flat, (lon.values, lat.values), method="linear")
+     j = griddata(points, j_parent_flat, (lon.values, lat.values), method="linear")
+
+     i = xr.DataArray(i, dims=lon.dims)
+     j = xr.DataArray(j, dims=lon.dims)
+
+     # Check for NaN values
+     if np.sum(np.isnan(i)) > 0 or np.sum(np.isnan(j)) > 0:
+         raise ValueError(
+             "Some points are outside the grid. Please choose either a bigger parent grid or a smaller child grid."
+         )
+
+     # Check whether indices are close to border of parent grid
+     nxp, nyp = lon_parent.shape
+     if np.min(i) < 0 or np.max(i) > nxp - 2:
+         logging.warning(
+             "Some boundary points of the child grid are very close to the boundary of the parent grid."
+         )
+     if np.min(j) < 0 or np.max(j) > nyp - 2:
+         logging.warning(
+             "Some boundary points of the child grid are very close to the boundary of the parent grid."
+         )
+
+     return i, j
+
+
+ def update_indices_if_on_parent_land(i_eta, i_xi, grid_location, parent_grid_ds):
+     """Finds points that are in the parent land mask but not land masked in the child
+     and replaces parent indices with nearest neighbor wet points.
+
+     Parameters
+     ----------
+     i_eta : xarray.DataArray
+         Interpolated i_eta-indices for the child grid.
+     i_xi : xarray.DataArray
+         Interpolated i_xi-indices for the child grid.
+     grid_location : str
+         Location type ('rho', 'u', 'v').
+     parent_grid_ds : xarray.Dataset
+         Grid information of parent grid.
+
+     Returns
+     -------
+     i_eta : xarray.DataArray
+         Updated i_eta-indices for the child grid.
+     i_xi : xarray.DataArray
+         Updated i_xi-indices for the child grid.
+     """
+
+     if grid_location == "rho":
+         i_eta_rho = i_eta + 0.5
+         i_xi_rho = i_xi + 0.5
+         mask_rho = parent_grid_ds.mask_rho
+         summed_mask = np.zeros_like(i_eta_rho)
+
+         for i in range(len(i_eta_rho)):
+             i_eta_lower = int(np.floor(i_eta_rho[i]))
+             i_xi_lower = int(np.floor(i_xi_rho[i]))
+             mask = mask_rho.isel(
+                 eta_rho=slice(i_eta_lower, i_eta_lower + 2),
+                 xi_rho=slice(i_xi_lower, i_xi_lower + 2),
+             )
+             summed_mask[i] = np.sum(mask)
+
+     elif grid_location in ["u", "v"]:
+         i_eta_u = i_eta + 0.5
+         i_xi_u = i_xi
+
+         mask_u = parent_grid_ds.mask_u
+         summed_mask_u = np.zeros_like(i_eta_u)
+
+         for i in range(len(i_eta_u)):
+             i_eta_lower = int(np.floor(i_eta_u[i]))
+             i_xi_lower = int(np.floor(i_xi_u[i]))
+             mask = mask_u.isel(
+                 eta_rho=slice(i_eta_lower, i_eta_lower + 2),
+                 xi_u=slice(i_xi_lower, i_xi_lower + 2),
+             )
+             summed_mask_u[i] = np.sum(mask)
+
+         i_eta_v = i_eta
+         i_xi_v = i_xi + 0.5
+
+         mask_v = parent_grid_ds.mask_v
+         summed_mask_v = np.zeros_like(i_xi_v)
+
+         for i in range(len(i_eta_v)):
+             i_eta_lower = int(np.floor(i_eta_v[i]))
+             i_xi_lower = int(np.floor(i_xi_v[i]))
+             mask = mask_v.isel(
+                 eta_v=slice(i_eta_lower, i_eta_lower + 2),
+                 xi_rho=slice(i_xi_lower, i_xi_lower + 2),
+             )
+             summed_mask_v[i] = np.sum(mask)
+
+         summed_mask = summed_mask_u * summed_mask_v
+
+     # Filter out points where summed_mask is 0
+     valid_points = summed_mask != 0
+     x_mod = np.arange(len(summed_mask))[valid_points]
+     i_eta_mod = i_eta[valid_points]
+     i_xi_mod = i_xi[valid_points]
+
+     # Handle indices where summed_mask is 0
+     indx = np.where(summed_mask == 0)[0]
+     if len(indx) > 0:
+         i_eta_interp = interp1d(
+             x_mod, i_eta_mod, kind="nearest", fill_value="extrapolate"
+         )
+         i_xi_interp = interp1d(
+             x_mod, i_xi_mod, kind="nearest", fill_value="extrapolate"
+         )
+
+         i_eta[indx] = i_eta_interp(indx)
+         i_xi[indx] = i_xi_interp(indx)
+
+     return i_eta, i_xi
+
+
+ def modify_child_topography_and_mask(
+     parent_grid_ds,
+     child_grid_ds,
+     boundaries={"south": True, "east": True, "north": True, "west": True},
+ ):
+
+     # regrid parent topography and mask onto child grid
+     points = np.column_stack(
+         (parent_grid_ds.lon_rho.values.ravel(), parent_grid_ds.lat_rho.values.ravel())
+     )
+     xi = (child_grid_ds.lon_rho.values, child_grid_ds.lat_rho.values)
+
+     values = parent_grid_ds["h"].values.ravel()
+     h_parent_interpolated = griddata(points, values, xi, method="linear")
+     h_parent_interpolated = xr.DataArray(
+         h_parent_interpolated, dims=("eta_rho", "xi_rho")
+     )
+
+     values = parent_grid_ds["mask_rho"].values.ravel()
+     mask_parent_interpolated = griddata(points, values, xi, method="linear")
+     mask_parent_interpolated = xr.DataArray(
+         mask_parent_interpolated, dims=("eta_rho", "xi_rho")
+     )
+
+     # compute weight based on distance
+     alpha = compute_boundary_distance(child_grid_ds["mask_rho"], boundaries)
+     # update child topography and mask to be weighted sum between original child and interpolated parent
+     child_grid_ds["h"] = (
+         alpha * child_grid_ds["h"] + (1 - alpha) * h_parent_interpolated
+     )
+
+     child_mask = (
+         alpha * child_grid_ds["mask_rho"] + (1 - alpha) * mask_parent_interpolated
+     )
+     child_grid_ds["mask_rho"] = xr.where(child_mask >= 0.5, 1, 0)
+
+     return child_grid_ds
+
+
+ def compute_boundary_distance(
+     child_mask, boundaries={"south": True, "east": True, "north": True, "west": True}
+ ):
+     """Computes a normalized distance field from the boundaries of a grid, given a mask
+     and boundary conditions. The normalized distance values range from 0 (boundary) to 1
+     (inner grid).
+
+     Parameters
+     ----------
+     child_mask : xr.DataArray
+         A 2D xarray DataArray representing the land/sea mask of the grid (1 for sea, 0 for land),
+         with dimensions ("eta_rho", "xi_rho").
+     boundaries : dict, optional
+         A dictionary specifying which boundaries are open. Keys are "south", "east", "north", "west",
+         with boolean values indicating whether the boundary is open.
+
+     Returns
+     -------
+     xr.DataArray
+         A 2D DataArray with normalized distance values.
+     """
+     dist = np.full_like(child_mask, 1e6, dtype=float)
+     nx, ny = child_mask.shape
+     n = max(nx, ny)
+
+     x = np.arange(nx) / n
+     y = np.arange(ny) / n
+     x, y = np.meshgrid(x, y, indexing="ij")
+
+     trans = 0.05
+     width = int(np.ceil(n * trans))
+
+     if boundaries["south"]:
+         bx = x[:, 0][child_mask[:, 0] > 0]
+         by = y[:, 0][child_mask[:, 0] > 0]
+         for i in range(len(bx)):
+             dtmp = (x[:, :width] - bx[i]) ** 2 + (y[:, :width] - by[i]) ** 2
+             dist[:, :width] = np.minimum(dist[:, :width], dtmp)
+
+     if boundaries["east"]:
+         bx = x[-1, :][child_mask[-1, :] > 0]
+         by = y[-1, :][child_mask[-1, :] > 0]
+         for i in range(len(bx)):
+             dtmp = (x[nx - width : nx, :] - bx[i]) ** 2 + (
+                 y[nx - width : nx, :] - by[i]
+             ) ** 2
+             dist[nx - width : nx, :] = np.minimum(dist[nx - width : nx, :], dtmp)
+
+     if boundaries["north"]:
+         bx = x[:, -1][child_mask[:, -1] > 0]
+         by = y[:, -1][child_mask[:, -1] > 0]
+         for i in range(len(bx)):
+             dtmp = (x[:, ny - width : ny] - bx[i]) ** 2 + (
+                 y[:, ny - width : ny] - by[i]
+             ) ** 2
+             dist[:, ny - width : ny] = np.minimum(dist[:, ny - width : ny], dtmp)
+
+     if boundaries["west"]:
+         bx = x[0, :][child_mask[0, :] > 0]
+         by = y[0, :][child_mask[0, :] > 0]
+         for i in range(len(bx)):
+             dtmp = (x[:width, :] - bx[i]) ** 2 + (y[:width, :] - by[i]) ** 2
+             dist[:width, :] = np.minimum(dist[:width, :], dtmp)
+
+     dist = np.sqrt(dist)
+     dist[dist > trans] = trans
+     dist = dist / trans
+     alpha = 0.5 - 0.5 * np.cos(np.pi * dist)
+
+     alpha = xr.DataArray(alpha, dims=("eta_rho", "xi_rho"))
+
+     return alpha
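
For orientation, below is a minimal usage sketch of the new Nesting class based on the API visible in this diff. The Grid constructor arguments are illustrative assumptions and are not taken from this diff; only the Nesting constructor and its plot, save, and to_yaml methods are defined above.

from roms_tools.setup.grid import Grid
from roms_tools.setup.nesting import Nesting

# Hypothetical parent and child grids; these constructor arguments are
# assumptions for illustration, not part of this diff.
parent_grid = Grid(
    nx=100, ny=100, size_x=2000, size_y=2000, center_lon=-21, center_lat=61, rot=0
)
child_grid = Grid(
    nx=100, ny=100, size_x=500, size_y=500, center_lon=-21, center_lat=61, rot=0
)

# Map the child boundaries onto parent grid indices and blend the child
# topography and mask toward the parent near the boundaries.
nesting = Nesting(parent_grid=parent_grid, child_grid=child_grid, period=3600.0)

nesting.plot()                                # overview plot of both grids
nesting.to_yaml("nesting.yaml")               # record parameters for reproducibility
nesting.save("child_bdry_map", "child_grid")  # writes child_bdry_map.nc and child_grid.nc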