roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
1
1
  import importlib.metadata
2
2
  import logging
3
+ from collections import defaultdict
3
4
  from dataclasses import dataclass, field
4
5
  from datetime import datetime
5
6
  from pathlib import Path
7
+ from typing import Literal
6
8
 
7
9
  import numpy as np
8
10
  import xarray as xr
@@ -11,7 +13,14 @@ from matplotlib.axes import Axes
11
13
  from roms_tools import Grid
12
14
  from roms_tools.plot import plot
13
15
  from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
14
- from roms_tools.setup.datasets import CESMBGCDataset, GLORYSDataset, UnifiedBGCDataset
16
+ from roms_tools.setup.datasets import (
17
+ CESMBGCDataset,
18
+ Dataset,
19
+ GLORYSDataset,
20
+ GLORYSDefaultDataset,
21
+ RawDataSource,
22
+ UnifiedBGCDataset,
23
+ )
15
24
  from roms_tools.setup.utils import (
16
25
  compute_barotropic_velocity,
17
26
  compute_missing_bgc_variables,
@@ -25,7 +34,6 @@ from roms_tools.setup.utils import (
25
34
  write_to_yaml,
26
35
  )
27
36
  from roms_tools.utils import (
28
- get_dask_chunks,
29
37
  interpolate_from_rho_to_u,
30
38
  interpolate_from_rho_to_v,
31
39
  save_datasets,
@@ -48,7 +56,7 @@ class InitialConditions:
48
56
  ini_time : datetime
49
57
  The date and time at which the initial conditions are set.
50
58
  If no exact match is found, the closest time entry to `ini_time` within the time range [ini_time, ini_time + 24 hours] is selected.
51
- source : Dict[str, Union[str, Path, List[Union[str, Path]]], bool]
59
+ source : RawDataSource
52
60
 
53
61
  Dictionary specifying the source of the physical initial condition data. Keys include:
54
62
 
@@ -57,10 +65,12 @@ class InitialConditions:
57
65
 
58
66
  - A single string (with or without wildcards).
59
67
  - A single Path object.
60
- - A list of strings or Path objects containing multiple files.
68
+ - A list of strings or Path objects.
69
+ If omitted, the data will be streamed via the Copernicus Marine Toolkit.
70
+ Note: streaming is currently not recommended due to performance limitations.
61
71
  - "climatology" (bool): Indicates if the data is climatology data. Defaults to False.
62
72
 
63
- bgc_source : Dict[str, Union[str, Path, List[Union[str, Path]]], bool]
73
+ bgc_source : RawDataSource, optional
64
74
  Dictionary specifying the source of the biogeochemical (BGC) initial condition data. Keys include:
65
75
 
66
76
  - "name" (str): Name of the data source (e.g., "CESM_REGRIDDED").
@@ -78,6 +88,13 @@ class InitialConditions:
78
88
  The reference date for the model. Defaults to January 1, 2000.
79
89
  use_dask: bool, optional
80
90
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
91
+ allow_flex_time: bool, optional
92
+ Controls how strictly `ini_time` is handled:
93
+
94
+ - If False (default): requires an exact match to `ini_time`. Raises a ValueError if no match exists.
95
+ - If True: allows a +24h search window after `ini_time` and selects the closest available
96
+ time entry within that window. Raises a ValueError if none are found.
97
+
81
98
  horizontal_chunk_size : int, optional
82
99
  The chunk size used for horizontal partitioning for the vertical regridding when `use_dask = True`. Defaults to 50.
83
100
  A larger number results in a bigger memory footprint but faster computations.
@@ -105,9 +122,9 @@ class InitialConditions:
105
122
  """Object representing the grid information."""
106
123
  ini_time: datetime
107
124
  """The date and time at which the initial conditions are set."""
108
- source: dict[str, str | Path | list[str | Path]]
125
+ source: RawDataSource
109
126
  """Dictionary specifying the source of the physical initial condition data."""
110
- bgc_source: dict[str, str | Path | list[str | Path]] | None = None
127
+ bgc_source: RawDataSource | None = None
111
128
  """Dictionary specifying the source of the biogeochemical (BGC) initial condition
112
129
  data."""
113
130
  model_reference_date: datetime = datetime(2000, 1, 1)
@@ -115,6 +132,8 @@ class InitialConditions:
115
132
  adjust_depth_for_sea_surface_height: bool = False
116
133
  """Whether to account for sea surface height variations when computing depth
117
134
  coordinates."""
135
+ allow_flex_time: bool = False
136
+ """Whether to handle ini_time flexibly."""
118
137
  use_dask: bool = False
119
138
  """Whether to use dask for processing."""
120
139
  horizontal_chunk_size: int = 50
@@ -161,14 +180,10 @@ class InitialConditions:
161
180
  def _process_data(self, processed_fields, type="physics"):
162
181
  target_coords = get_target_coords(self.grid)
163
182
 
164
- if type == "physics":
165
- data = self._get_data()
166
- else:
167
- data = self._get_bgc_data()
183
+ data = self._get_data(forcing_type=type)
168
184
 
169
185
  data.choose_subdomain(
170
186
  target_coords,
171
- buffer_points=20, # lateral fill needs good buffer from data margin
172
187
  )
173
188
  # Enforce double precision to ensure reproducibility
174
189
  data.convert_to_float64()
@@ -255,7 +270,7 @@ class InitialConditions:
255
270
  field = processed_fields[var_name]
256
271
  if self.use_dask:
257
272
  field = field.chunk(
258
- get_dask_chunks(location, self.horizontal_chunk_size)
273
+ _set_dask_chunks(location, self.horizontal_chunk_size)
259
274
  )
260
275
  processed_fields[var_name] = vertical_regrid.apply(field)
261
276
 
@@ -280,7 +295,11 @@ class InitialConditions:
280
295
  if "name" not in self.source.keys():
281
296
  raise ValueError("`source` must include a 'name'.")
282
297
  if "path" not in self.source.keys():
283
- raise ValueError("`source` must include a 'path'.")
298
+ if self.source["name"] != "GLORYS":
299
+ raise ValueError("`source` must include a 'path'.")
300
+
301
+ self.source["path"] = GLORYSDefaultDataset.dataset_name
302
+
284
303
  # set self.source["climatology"] to False if not provided
285
304
  self.source = {
286
305
  **self.source,
@@ -307,40 +326,63 @@ class InitialConditions:
307
326
  "Sea surface height will NOT be used to adjust depth coordinates."
308
327
  )
309
328
 
310
- def _get_data(self):
311
- if self.source["name"] == "GLORYS":
312
- data = GLORYSDataset(
313
- filename=self.source["path"],
314
- start_time=self.ini_time,
315
- climatology=self.source["climatology"],
316
- use_dask=self.use_dask,
317
- )
318
- else:
319
- raise ValueError('Only "GLORYS" is a valid option for source["name"].')
320
- return data
321
-
322
- def _get_bgc_data(self):
323
- if self.bgc_source["name"] == "CESM_REGRIDDED":
324
- data = CESMBGCDataset(
325
- filename=self.bgc_source["path"],
326
- start_time=self.ini_time,
327
- climatology=self.bgc_source["climatology"],
328
- use_dask=self.use_dask,
329
- )
330
- elif self.bgc_source["name"] == "UNIFIED":
331
- data = UnifiedBGCDataset(
332
- filename=self.bgc_source["path"],
333
- start_time=self.ini_time,
334
- climatology=self.bgc_source["climatology"],
335
- use_dask=self.use_dask,
336
- )
329
+ def _get_data(self, forcing_type=Literal["physics", "bgc"]) -> Dataset:
330
+ """Determine the correct `Dataset` type and return an instance.
337
331
 
338
- else:
339
- raise ValueError(
340
- 'Only "CESM_REGRIDDED" and "UNIFIED" are valid options for bgc_source["name"].'
332
+ forcing_type : str
333
+ Specifies the type of forcing data. Options are:
334
+
335
+ - "physics": for physical atmospheric forcing.
336
+ - "bgc": for biogeochemical forcing.
337
+ Returns
338
+ -------
339
+ Dataset
340
+ The `Dataset` instance
341
+ """
342
+ dataset_map: dict[str, dict[str, dict[str, type[Dataset]]]] = {
343
+ "physics": {
344
+ "GLORYS": {
345
+ "external": GLORYSDataset,
346
+ "default": GLORYSDefaultDataset,
347
+ },
348
+ },
349
+ "bgc": {
350
+ "CESM_REGRIDDED": defaultdict(lambda: CESMBGCDataset),
351
+ "UNIFIED": defaultdict(lambda: UnifiedBGCDataset),
352
+ },
353
+ }
354
+
355
+ source_dict = self.source if forcing_type == "physics" else self.bgc_source
356
+
357
+ if source_dict is None:
358
+ raise ValueError(f"{forcing_type} source is not set")
359
+
360
+ source_name = str(source_dict["name"])
361
+ if source_name not in dataset_map[forcing_type]:
362
+ tpl = 'Valid options for source["name"] for type {} include: {}'
363
+ msg = tpl.format(
364
+ forcing_type, " and ".join(dataset_map[forcing_type].keys())
341
365
  )
366
+ raise ValueError(msg)
342
367
 
343
- return data
368
+ has_no_path = "path" not in source_dict
369
+ has_default_path = source_dict.get("path") == GLORYSDefaultDataset.dataset_name
370
+ use_default = has_no_path or has_default_path
371
+
372
+ variant = "default" if use_default else "external"
373
+
374
+ data_type = dataset_map[forcing_type][source_name][variant]
375
+
376
+ if isinstance(source_dict["path"], bool):
377
+ raise ValueError('source["path"] cannot be a boolean here')
378
+
379
+ return data_type(
380
+ filename=source_dict["path"],
381
+ start_time=self.ini_time,
382
+ climatology=source_dict["climatology"], # type: ignore
383
+ allow_flex_time=self.allow_flex_time,
384
+ use_dask=self.use_dask,
385
+ )
344
386
 
345
387
  def _set_variable_info(self, data, type="physics"):
346
388
  """Sets up a dictionary with metadata for variables based on the type.
@@ -483,10 +525,10 @@ class InitialConditions:
483
525
  zeta = interpolate_from_rho_to_v(zeta)
484
526
 
485
527
  if self.use_dask:
486
- h = h.chunk(get_dask_chunks(location, self.horizontal_chunk_size))
487
- if self.adjust_depth_for_sea_surface_height:
528
+ h = h.chunk(_set_dask_chunks(location, self.horizontal_chunk_size))
529
+ if isinstance(zeta, xr.DataArray):
488
530
  zeta = zeta.chunk(
489
- get_dask_chunks(location, self.horizontal_chunk_size)
531
+ _set_dask_chunks(location, self.horizontal_chunk_size)
490
532
  )
491
533
  depth = compute_depth(zeta, h, self.grid.ds.attrs["hc"], Cs, sigma)
492
534
  self.ds_depth_coords[key] = depth
@@ -816,3 +858,26 @@ class InitialConditions:
816
858
  **initial_conditions_params,
817
859
  use_dask=use_dask,
818
860
  )
861
+
862
+
863
+ def _set_dask_chunks(location: str, chunk_size: int):
864
+ """Returns the appropriate Dask chunking dictionary based on grid location.
865
+
866
+ Parameters
867
+ ----------
868
+ location : str
869
+ The grid location, one of "rho", "u", or "v".
870
+ chunk_size : int
871
+ The chunk size to apply.
872
+
873
+ Returns
874
+ -------
875
+ dict
876
+ Dictionary specifying the chunking strategy.
877
+ """
878
+ chunk_mapping = {
879
+ "rho": {"eta_rho": chunk_size, "xi_rho": chunk_size},
880
+ "u": {"eta_rho": chunk_size, "xi_u": chunk_size},
881
+ "v": {"eta_v": chunk_size, "xi_rho": chunk_size},
882
+ }
883
+ return chunk_mapping.get(location, {})
roms_tools/setup/mask.py CHANGED
@@ -1,5 +1,8 @@
1
+ import logging
1
2
  import warnings
3
+ from pathlib import Path
2
4
 
5
+ import geopandas as gpd
3
6
  import numpy as np
4
7
  import regionmask
5
8
  import xarray as xr
@@ -12,7 +15,7 @@ from roms_tools.setup.utils import (
12
15
  )
13
16
 
14
17
 
15
- def _add_mask(ds):
18
+ def add_mask(ds: xr.Dataset, shapefile: str | Path | None = None) -> xr.Dataset:
16
19
  """Adds a land/water mask to the dataset at rho-points.
17
20
 
18
21
  Parameters
@@ -20,20 +23,47 @@ def _add_mask(ds):
20
23
  ds : xarray.Dataset
21
24
  Input dataset containing latitude and longitude coordinates at rho-points.
22
25
 
26
+ shapefile: str or Path | None
27
+ Path to a coastal shapefile to determine the land mask. If None, NaturalEarth 10m is used.
28
+
23
29
  Returns
24
30
  -------
25
31
  xarray.Dataset
26
32
  The original dataset with an added 'mask_rho' variable, representing land/water mask.
27
33
  """
28
- land = regionmask.defined_regions.natural_earth_v5_0_0.land_10
29
-
30
34
  # Suppress specific warning
31
35
  with warnings.catch_warnings():
32
36
  warnings.filterwarnings(
33
37
  "ignore", message="No gridpoint belongs to any region.*"
34
38
  )
35
- land_mask = land.mask(ds["lon_rho"], ds["lat_rho"])
36
- mask = land_mask.isnull()
39
+
40
+ if shapefile:
41
+ coast = gpd.read_file(shapefile)
42
+
43
+ try:
44
+ # 3D method: returns a boolean array for each region, then take max along the region dimension
45
+ # Pros: more memory-efficient for high-res grids if number of regions isn't extreme
46
+ mask = ~regionmask.mask_3D_geopandas(
47
+ coast, ds["lon_rho"], ds["lat_rho"]
48
+ ).max(dim="region")
49
+
50
+ except MemoryError:
51
+ logging.info(
52
+ "MemoryError encountered with 3D mask; falling back to 2D method."
53
+ )
54
+ # 2D method: returns a single array with integer codes for each region, using np.nan for points not in any region
55
+ # Pros: works well for small/medium grids
56
+ # Cons: can use a large float64 array internally for very high-resolution grids
57
+ mask_2d = regionmask.mask_geopandas(coast, ds["lon_rho"], ds["lat_rho"])
58
+ mask = mask_2d.isnull()
59
+
60
+ else:
61
+ # Use Natural Earth 10m land polygons if no shapefile is provided
62
+ land = regionmask.defined_regions.natural_earth_v5_0_0.land_10
63
+ land_mask = land.mask(ds["lon_rho"], ds["lat_rho"])
64
+ mask = land_mask.isnull()
65
+
66
+ ds = _add_coastlines_metadata(ds, shapefile)
37
67
 
38
68
  # fill enclosed basins with land
39
69
  mask = _fill_enclosed_basins(mask.values)
@@ -45,7 +75,33 @@ def _add_mask(ds):
45
75
  "long_name": "Mask at rho-points",
46
76
  "units": "land/water (0/1)",
47
77
  }
48
- ds = _add_velocity_masks(ds)
78
+
79
+ ds = add_velocity_masks(ds)
80
+
81
+ return ds
82
+
83
+
84
+ def _add_coastlines_metadata(
85
+ ds: xr.Dataset,
86
+ shapefile: str | Path | None = None,
87
+ ) -> xr.Dataset:
88
+ """
89
+ Add coastline metadata to a dataset.
90
+
91
+ Parameters
92
+ ----------
93
+ ds : xarray.Dataset
94
+ Dataset to be updated.
95
+ shapefile : str or pathlib.Path or None, optional
96
+ Path to the shapefile used for land/ocean masking.
97
+
98
+ Returns
99
+ -------
100
+ xarray.Dataset
101
+ Dataset with updated coastline-related metadata.
102
+ """
103
+ if shapefile is not None:
104
+ ds.attrs["mask_shapefile"] = str(shapefile)
49
105
 
50
106
  return ds
51
107
 
@@ -85,7 +141,7 @@ def _fill_enclosed_basins(mask) -> np.ndarray:
85
141
  return mask
86
142
 
87
143
 
88
- def _add_velocity_masks(ds):
144
+ def add_velocity_masks(ds):
89
145
  """Adds velocity masks for u- and v-points based on the rho-point mask.
90
146
 
91
147
  This function generates masks for u- and v-points by interpolating the rho-point land/water mask.
@@ -9,8 +9,9 @@ from scipy.interpolate import griddata, interp1d
9
9
 
10
10
  from roms_tools import Grid
11
11
  from roms_tools.plot import plot_nesting
12
- from roms_tools.setup.topography import _clip_depth
12
+ from roms_tools.setup.topography import clip_depth
13
13
  from roms_tools.setup.utils import (
14
+ Timed,
14
15
  from_yaml,
15
16
  get_boundary_coords,
16
17
  interpolate_from_rho_to_u,
@@ -48,6 +49,8 @@ class ChildGrid(Grid):
48
49
 
49
50
  - `"prefix"` (str): Prefix for variable names in `ds_nesting`. Defaults to `"child"`.
50
51
  - `"period"` (float): Temporal resolution for boundary outputs in seconds. Defaults to 3600 (hourly).
52
+ verbose: bool, optional
53
+ Indicates whether to print grid generation steps with timing. Defaults to False.
51
54
  """
52
55
 
53
56
  parent_grid: Grid
@@ -67,6 +70,8 @@ class ChildGrid(Grid):
67
70
  default_factory=lambda: {"prefix": "child", "period": 3600.0}
68
71
  )
69
72
  """Dictionary configuring the boundary nesting process."""
73
+ verbose: bool = False
74
+ """Whether to print grid generation steps with timing."""
70
75
 
71
76
  ds: xr.Dataset = field(init=False, repr=False)
72
77
  """An xarray Dataset containing child grid variables aligned with the
@@ -77,44 +82,48 @@ class ChildGrid(Grid):
77
82
 
78
83
  def __post_init__(self):
79
84
  super().__post_init__()
80
- self._map_child_boundaries_onto_parent_grid_indices()
81
- self._modify_child_topography_and_mask()
85
+ self._map_child_boundaries_onto_parent_grid_indices(verbose=self.verbose)
86
+ self._modify_child_topography_and_mask(verbose=self.verbose)
82
87
 
83
- def _map_child_boundaries_onto_parent_grid_indices(self):
88
+ def _map_child_boundaries_onto_parent_grid_indices(self, verbose: bool = False):
84
89
  """Maps child grid boundary points onto absolute indices of the parent grid."""
85
- # Prepare parent and child grid datasets by adjusting longitudes for dateline crossing
86
- parent_grid_ds, child_grid_ds = self._prepare_grid_datasets()
87
-
88
- # Map child boundaries onto parent grid indices
89
- ds_nesting = map_child_boundaries_onto_parent_grid_indices(
90
- parent_grid_ds,
91
- child_grid_ds,
92
- self.boundaries,
93
- self.metadata["prefix"],
94
- self.metadata["period"],
95
- )
90
+ with Timed(
91
+ "=== Mapping the child grid boundary points onto the indices of the parent grid ===",
92
+ verbose=verbose,
93
+ ):
94
+ # Prepare parent and child grid datasets by adjusting longitudes for dateline crossing
95
+ parent_grid_ds, child_grid_ds = self._prepare_grid_datasets()
96
+
97
+ # Map child boundaries onto parent grid indices
98
+ ds_nesting = map_child_boundaries_onto_parent_grid_indices(
99
+ parent_grid_ds,
100
+ child_grid_ds,
101
+ self.boundaries,
102
+ self.metadata["prefix"],
103
+ self.metadata["period"],
104
+ )
96
105
 
97
- self.ds_nesting = ds_nesting
106
+ self.ds_nesting = ds_nesting
98
107
 
99
- def _modify_child_topography_and_mask(self):
108
+ def _modify_child_topography_and_mask(self, verbose: bool = False):
100
109
  """Adjust the topography and mask of the child grid to align with the parent grid.
101
110
 
102
111
  Uses a weighted sum based on boundary distance and clips depth values to a
103
112
  minimum.
104
113
  """
105
- # Prepare parent and child grid datasets by adjusting longitudes for dateline crossing
106
- parent_grid_ds, child_grid_ds = self._prepare_grid_datasets()
107
-
108
- child_grid_ds = modify_child_topography_and_mask(
109
- parent_grid_ds, child_grid_ds, self.boundaries, self.hmin
110
- )
114
+ with Timed("=== Modifying child topography and mask ===", verbose=verbose):
115
+ # Prepare parent and child grid datasets by adjusting longitudes for dateline crossing
116
+ parent_grid_ds, child_grid_ds = self._prepare_grid_datasets()
117
+ child_grid_ds = modify_child_topography_and_mask(
118
+ parent_grid_ds, child_grid_ds, self.boundaries, self.hmin
119
+ )
111
120
 
112
- # Finalize grid datasets by adjusting longitudes back to [0, 360] range
113
- parent_grid_ds, child_grid_ds = self._finalize_grid_datasets(
114
- parent_grid_ds, child_grid_ds
115
- )
121
+ # Finalize grid datasets by adjusting longitudes back to [0, 360] range
122
+ parent_grid_ds, child_grid_ds = self._finalize_grid_datasets(
123
+ parent_grid_ds, child_grid_ds
124
+ )
116
125
 
117
- self.ds = child_grid_ds
126
+ self.ds = child_grid_ds
118
127
 
119
128
  def update_topography(
120
129
  self, topography_source=None, hmin=None, verbose=False
@@ -155,7 +164,7 @@ class ChildGrid(Grid):
155
164
  )
156
165
 
157
166
  # Modify child topography and mask to match the parent grid
158
- self._modify_child_topography_and_mask()
167
+ self._modify_child_topography_and_mask(verbose=verbose)
159
168
 
160
169
  def plot_nesting(self, with_dim_names=False) -> None:
161
170
  """Plot the parent and child grids in a single figure.
@@ -216,13 +225,19 @@ class ChildGrid(Grid):
216
225
  write_to_yaml(forcing_dict, filepath)
217
226
 
218
227
  @classmethod
219
- def from_yaml(cls, filepath: str | Path) -> "ChildGrid":
228
+ def from_yaml(
229
+ cls, filepath: str | Path, verbose: bool = False, **kwargs: Any
230
+ ) -> "ChildGrid":
220
231
  """Create an instance of the ChildGrid class from a YAML file.
221
232
 
222
233
  Parameters
223
234
  ----------
224
235
  filepath : Union[str, Path]
225
236
  The path to the YAML file from which the parameters will be read.
237
+ verbose : bool, optional
238
+ Indicates whether to print grid generation steps with timing. Defaults to False.
239
+ **kwargs : Any
240
+ Additional keyword arguments passed to Grid.from_yaml.
226
241
 
227
242
  Returns
228
243
  -------
@@ -231,10 +246,12 @@ class ChildGrid(Grid):
231
246
  """
232
247
  filepath = Path(filepath)
233
248
 
234
- parent_grid = Grid.from_yaml(filepath, "ParentGrid")
249
+ parent_grid = Grid.from_yaml(
250
+ filepath, verbose=verbose, section_name="ParentGrid"
251
+ )
235
252
  params = from_yaml(cls, filepath)
236
253
 
237
- return cls(parent_grid=parent_grid, **params)
254
+ return cls(parent_grid=parent_grid, **params, verbose=verbose)
238
255
 
239
256
  def _prepare_grid_datasets(self) -> tuple[xr.Dataset, xr.Dataset]:
240
257
  """Prepare parent and child grid datasets by adjusting longitudes for dateline
@@ -258,7 +275,7 @@ class ChildGrid(Grid):
258
275
 
259
276
  def _finalize_grid_datasets(
260
277
  self, parent_grid_ds: xr.Dataset, child_grid_ds: xr.Dataset
261
- ) -> None:
278
+ ) -> tuple[xr.Dataset, xr.Dataset]:
262
279
  """Finalize the grid datasets by converting longitudes back to the [0, 360]
263
280
  range.
264
281
 
@@ -269,13 +286,28 @@ class ChildGrid(Grid):
269
286
 
270
287
  child_grid_ds : xr.Dataset
271
288
  The child grid dataset after modifications.
289
+
290
+ Returns
291
+ -------
292
+ tuple[xr.Dataset, xr.Dataset]
293
+ The finalized parent and child grid datasets with longitudes wrapped to [0, 360].
294
+
272
295
  """
273
296
  parent_grid_ds = wrap_longitudes(parent_grid_ds, straddle=False)
274
297
  child_grid_ds = wrap_longitudes(child_grid_ds, straddle=False)
298
+
275
299
  return parent_grid_ds, child_grid_ds
276
300
 
277
301
  @classmethod
278
- def from_file(cls, filepath: str | Path, verbose: bool = False) -> "ChildGrid":
302
+ def from_file(
303
+ cls,
304
+ filepath: str | Path,
305
+ theta_s: float | None = None,
306
+ theta_b: float | None = None,
307
+ hc: float | None = None,
308
+ N: int | None = None,
309
+ verbose: bool = False,
310
+ ) -> "ChildGrid":
279
311
  """This method is disabled in this subclass.
280
312
 
281
313
  .. noindex::
@@ -410,13 +442,6 @@ def map_child_boundaries_onto_parent_grid_indices(
410
442
  }
411
443
  ds = ds.rename(dims_to_rename)
412
444
 
413
- ds = ds.assign_coords(
414
- {
415
- "indices_rho": ("two", ["xi", "eta"]),
416
- "indices_vel": ("three", ["xi", "eta", "angle"]),
417
- }
418
- )
419
-
420
445
  return ds
421
446
 
422
447
 
@@ -643,7 +668,7 @@ def modify_child_topography_and_mask(
643
668
  alpha * child_grid_ds["h"] + (1 - alpha) * h_parent_interpolated
644
669
  )
645
670
  # Clip depth on modified child topography
646
- child_grid_ds["h"] = _clip_depth(child_grid_ds["h"], hmin)
671
+ child_grid_ds["h"] = clip_depth(child_grid_ds["h"], hmin)
647
672
 
648
673
  child_mask = (
649
674
  alpha * child_grid_ds["mask_rho"] + (1 - alpha) * mask_parent_interpolated