roms-tools 3.1.1-py3-none-any.whl → 3.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -2,13 +2,103 @@ import logging
 from datetime import datetime
 from pathlib import Path
 
+import numpy as np
 import pytest
 import xarray as xr
 
 from roms_tools import BoundaryForcing, Grid
 from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import ERA5Correction
-from roms_tools.setup.utils import interpolate_from_climatology, validate_names
+from roms_tools.setup.utils import (
+    get_target_coords,
+    interpolate_from_climatology,
+    validate_names,
+)
+
+
+class DummyGrid:
+    """Lightweight grid wrapper mimicking the real Grid API for testing."""
+
+    def __init__(self, ds: xr.Dataset, straddle: bool):
+        """Initialize grid wrapper."""
+        self.ds = ds
+        self.straddle = straddle
+
+
+class TestGetTargetCoords:
+    def make_rho_grid(self, lons, lats, with_mask=False):
+        """Helper to create a minimal rho grid dataset."""
+        eta, xi = len(lats), len(lons)
+        lon_rho, lat_rho = np.meshgrid(lons, lats)
+        ds = xr.Dataset(
+            {
+                "lon_rho": (("eta_rho", "xi_rho"), lon_rho),
+                "lat_rho": (("eta_rho", "xi_rho"), lat_rho),
+                "angle": (("eta_rho", "xi_rho"), np.zeros_like(lon_rho)),
+            },
+            coords={"eta_rho": np.arange(eta), "xi_rho": np.arange(xi)},
+        )
+        if with_mask:
+            ds["mask_rho"] = (("eta_rho", "xi_rho"), np.ones_like(lon_rho))
+        return ds
+
+    def test_basic_rho_grid(self):
+        ds = self.make_rho_grid(lons=[-10, -5, 0, 5, 10], lats=[50, 55])
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        assert "lat" in result and "lon" in result
+        assert np.allclose(result["lon"], ds.lon_rho)
+
+    def test_wrap_longitudes_to_minus180_180(self):
+        ds = self.make_rho_grid(lons=[190, 200], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        # longitudes >180 should wrap to -170, -160
+        expected = np.array([[-170, -160], [-170, -160]])
+        assert np.allclose(result["lon"].values, expected)
+
+    def test_convert_to_0_360_if_far_from_greenwich(self):
+        ds = self.make_rho_grid(lons=[-170, -160], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=False)
+        result = get_target_coords(grid)
+        # Should convert to 190, 200 since domain is far from Greenwich
+        expected = np.array([[190, 200], [190, 200]])
+        assert np.allclose(result["lon"].values, expected)
+        assert result["straddle"] is False
+
+    def test_close_to_greenwich_stays_minus180_180(self):
+        ds = self.make_rho_grid(lons=[-2, -1], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=False)
+        result = get_target_coords(grid)
+        # Should remain unchanged (-2, -1), not converted to 358, 359
+        expected = np.array([[-2, -1], [-2, -1]])
+        assert np.allclose(result["lon"].values, expected)
+        assert result["straddle"] is True
+
+    def test_includes_optional_fields(self):
+        ds = self.make_rho_grid(lons=[-10, -5], lats=[0, 1], with_mask=True)
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        assert result["mask"] is not None
+
+    def test_coarse_grid_selection(self):
+        lon = np.array([[190, 200]])
+        lat = np.array([[10, 10]])
+        ds = xr.Dataset(
+            {
+                "lon_coarse": (("eta_coarse", "xi_coarse"), lon),
+                "lat_coarse": (("eta_coarse", "xi_coarse"), lat),
+                "angle_coarse": (("eta_coarse", "xi_coarse"), np.zeros_like(lon)),
+                "mask_coarse": (("eta_coarse", "xi_coarse"), np.ones_like(lon)),
+            },
+            coords={"eta_coarse": [0], "xi_coarse": [0, 1]},
+        )
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid, use_coarse_grid=True)
+        # Should wrap longitudes to -170, -160
+        expected = np.array([[-170, -160]])
+        assert np.allclose(result["lon"].values, expected)
+        assert "mask" in result
 
 
 def test_interpolate_from_climatology(use_dask):
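The TestGetTargetCoords cases above pin down a longitude convention: straddling domains wrap to [-180, 180], non-straddling domains far from Greenwich use [0, 360], and domains near Greenwich keep their negative longitudes. A minimal sketch of that convention, for orientation only (this is not the roms_tools implementation, and the `margin` threshold for "close to Greenwich" is an assumption):

import numpy as np

def normalize_target_lon(lon: np.ndarray, straddle: bool, margin: float = 5.0) -> np.ndarray:
    """Illustrative longitude convention asserted by the tests above.

    The real logic lives in roms_tools.setup.utils.get_target_coords;
    `margin` is an assumed parameter, not the library's.
    """
    lon = np.asarray(lon, dtype=float)
    if straddle:
        # Domains straddling the dateline use [-180, 180]: 190 -> -170.
        return np.where(lon > 180, lon - 360, lon)
    if np.abs(lon).min() < margin:
        # Close to Greenwich: keep [-180, 180], so -2, -1 stay unchanged.
        return lon
    # Far from Greenwich: use [0, 360], so -170 -> 190.
    return np.where(lon < 0, lon + 360, lon)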
@@ -1,15 +1,11 @@
-import os
 import shutil
+from collections.abc import Callable
+from pathlib import Path
 
 import pytest
 import xarray as xr
 
 
-def _get_fname(name):
-    dirname = os.path.dirname(__file__)
-    return os.path.join(dirname, "test_data", f"{name}.zarr")
-
-
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
@@ -34,7 +30,11 @@ def _get_fname(name):
 # this test will not be run by default
 # to run it and overwrite the test data, invoke pytest as follows
 # pytest --overwrite=tidal_forcing --overwrite=boundary_forcing
-def test_save_results(forcing_fixture, request):
+def test_save_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
     overwrite = request.config.getoption("--overwrite")
 
     # Skip the test if the fixture isn't marked for overwriting, unless 'all' is specified
@@ -42,10 +42,10 @@ def test_save_results(forcing_fixture, request):
         pytest.skip(f"Skipping overwrite for {forcing_fixture}")
 
     forcing = request.getfixturevalue(forcing_fixture)
-    fname = _get_fname(forcing_fixture)
+    fname = get_test_data_path(forcing_fixture)
 
     # Check if the Zarr directory exists and delete it if it does
-    if os.path.exists(fname):
+    if fname.exists():
         shutil.rmtree(fname)
 
     forcing.ds.to_zarr(fname)
@@ -72,8 +72,12 @@ def test_save_results(forcing_fixture, request):
         "river_forcing_no_climatology",
     ],
 )
-def test_check_results(forcing_fixture, request):
-    fname = _get_fname(forcing_fixture)
+def test_check_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    fname = get_test_data_path(forcing_fixture)
     expected_forcing_ds = xr.open_zarr(fname, decode_timedelta=False)
     forcing = request.getfixturevalue(forcing_fixture)
@@ -83,6 +87,7 @@ def test_check_results(forcing_fixture, request):
     )
 
 
+@pytest.mark.use_dask
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
@@ -97,11 +102,12 @@ def test_check_results(forcing_fixture, request):
         "bgc_boundary_forcing_from_climatology",
     ],
 )
-def test_dask_vs_no_dask(forcing_fixture, request, tmp_path, use_dask):
+def test_dask_vs_no_dask(
+    forcing_fixture: str,
+    request: pytest.FixtureRequest,
+    tmp_path: Path,
+) -> None:
     """Test comparing the forcing created with and without Dask on the same platform."""
-    if not use_dask:
-        pytest.skip("Test only runs when --use_dask is specified")
-
     # Get the forcing with Dask
     forcing_with_dask = request.getfixturevalue(forcing_fixture)
@@ -0,0 +1,71 @@
+from datetime import datetime
+from pathlib import Path
+
+from roms_tools import Grid, get_glorys_bounds
+
+try:
+    import copernicusmarine  # type: ignore
+except ImportError:
+    copernicusmarine = None
+
+
+def download_regional_and_bigger(
+    tmp_path: Path,
+    grid: Grid,
+    start_time: datetime,
+    variables: list[str] = ["thetao", "so", "uo", "vo", "zos"],
+) -> tuple[Path, Path]:
+    """
+    Helper: download minimal and slightly bigger GLORYS subsets.
+
+    Parameters
+    ----------
+    tmp_path : Path
+        Directory to store the downloaded NetCDF files.
+    grid : Grid
+        ROMS-Tools Grid object defining the target domain.
+    start_time : datetime
+        Start time of the requested subset.
+    variables : list[str]
+        Variables to download.
+
+    Returns
+    -------
+    tuple[Path, Path]
+        Paths to the minimal and slightly bigger GLORYS subset files.
+    """
+    bounds = get_glorys_bounds(grid)
+
+    # minimal dataset
+    regional_file = tmp_path / "regional_GLORYS.nc"
+    copernicusmarine.subset(
+        dataset_id="cmems_mod_glo_phy_my_0.083deg_P1D-m",
+        variables=variables,
+        **bounds,
+        start_datetime=start_time,
+        end_datetime=start_time,
+        coordinates_selection_method="outside",
+        output_filename=str(regional_file),
+    )
+
+    # slightly bigger dataset
+    for key, delta in {
+        "minimum_latitude": -1,
+        "minimum_longitude": -1,
+        "maximum_latitude": +1,
+        "maximum_longitude": +1,
+    }.items():
+        bounds[key] += delta
+
+    bigger_regional_file = tmp_path / "bigger_regional_GLORYS.nc"
+    copernicusmarine.subset(
+        dataset_id="cmems_mod_glo_phy_my_0.083deg_P1D-m",
+        variables=variables,
+        **bounds,
+        start_datetime=start_time,
+        end_datetime=start_time,
+        coordinates_selection_method="outside",
+        output_filename=str(bigger_regional_file),
+    )
+
+    return regional_file, bigger_regional_file
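For reference, a hedged sketch of how a test might call this helper. The `grid` fixture and the start date are assumptions, and running it requires the optional copernicusmarine dependency plus CMEMS credentials and network access:

from datetime import datetime
from pathlib import Path

import pytest

def test_regional_glorys_download(tmp_path: Path, grid) -> None:  # hypothetical test
    pytest.importorskip("copernicusmarine")  # skip without the optional dependency
    small, bigger = download_regional_and_bigger(
        tmp_path, grid, start_time=datetime(2012, 1, 1)  # date is arbitrary
    )
    assert small.exists() and bigger.exists()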
@@ -0,0 +1,241 @@
+from collections.abc import Callable
+from pathlib import Path
+
+import pytest
+import xarray as xr
+
+from roms_tools import Grid
+from roms_tools.tiling.join import (
+    _find_common_dims,
+    _find_transitions,
+    _infer_partition_layout_from_datasets,
+    join_netcdf,
+    open_partitions,
+)
+from roms_tools.tiling.partition import partition_netcdf
+
+
+@pytest.fixture
+def partitioned_grid_factory(
+    tmp_path, large_grid
+) -> Callable[[int, int], tuple[Grid, list[Path]]]:
+    """
+    A fixture factory that returns a function to generate partitioned files
+    with a configurable layout.
+    """
+
+    def _partitioned_files(np_xi: int, np_eta: int) -> tuple[Grid, list[Path]]:
+        partable_grid = large_grid
+        partable_grid.save(tmp_path / "test_grid.nc")
+        parted_files = partition_netcdf(
+            tmp_path / "test_grid.nc", np_xi=np_xi, np_eta=np_eta
+        )
+        return partable_grid, parted_files
+
+    return _partitioned_files
+
+
+@pytest.fixture
+def partitioned_ic_factory(
+    tmp_path, initial_conditions_on_large_grid
+) -> Callable[[int, int], tuple[Path, list[Path]]]:
+    def _partitioned_files(np_xi: int, np_eta: int) -> tuple[Path, list[Path]]:
+        whole_ics = initial_conditions_on_large_grid
+        whole_path = whole_ics.save(tmp_path / "test_ic.nc")[0]
+        parted_paths = partition_netcdf(whole_path, np_xi=np_xi, np_eta=np_eta)
+        return whole_path, parted_paths
+
+    return _partitioned_files
+
+
+class TestHelperFunctions:
+    def test_find_common_dims(self):
+        """Test _find_common_dims with different datasets and directions.
+
+        Tests for a valid common dimension, a valid dimension that is not common,
+        an invalid direction, and a ValueError when no common dimension is found.
+        """
+        # Create mock xarray.Dataset objects for testing
+        ds1 = xr.Dataset(coords={"xi_rho": [0], "xi_u": [0]})
+        ds2 = xr.Dataset(coords={"xi_rho": [0], "xi_u": [0]})
+        ds3 = xr.Dataset(coords={"xi_rho": [0], "xi_v": [0]})
+        datasets_common = [ds1, ds2]
+        datasets_not_common = [ds1, ds3]
+
+        # Test case with common dimensions ("xi_rho" and "xi_u"):
+        # the function should find both and return them in a list.
+        assert _find_common_dims("xi", datasets_common) == ["xi_rho", "xi_u"]
+
+        # Test case where a dimension is not common to all datasets:
+        # only "xi_rho" is found, as "xi_u" is not in ds3.
+        assert _find_common_dims("xi", datasets_not_common) == ["xi_rho"]
+
+        # Test case for an invalid direction; should raise a ValueError
+        with pytest.raises(ValueError, match="'direction' must be 'xi' or 'eta'"):
+            _find_common_dims("zeta", datasets_common)
+
+        # Test case where no common dimensions exist
+        ds_no_common1 = xr.Dataset(coords={"xi_rho": [0]})
+        ds_no_common2 = xr.Dataset(coords={"xi_v": [0]})
+        with pytest.raises(
+            ValueError, match="No common point found along direction xi"
+        ):
+            _find_common_dims("xi", [ds_no_common1, ds_no_common2])
+
+    def test_find_transitions(self):
+        """Test _find_transitions with various input lists.
+
+        Test cases include lists with no transitions, a single transition,
+        multiple transitions, and edge cases like empty and single-element lists.
+        """
+        # Test case with no transitions
+        assert _find_transitions([10, 10, 10, 10]) == []
+
+        # Test case with a single transition
+        assert _find_transitions([10, 10, 12, 12]) == [2]
+
+        # Test case with multiple transitions
+        assert _find_transitions([10, 12, 12, 14, 14, 14]) == [1, 3]
+
+        # Test case with transitions on every element
+        assert _find_transitions([10, 12, 14, 16]) == [1, 2, 3]
+
+        # Edge case: empty list
+        assert _find_transitions([]) == []
+
+        # Edge case: single-element list
+        assert _find_transitions([10]) == []
+
+    def test_infer_partition_layout_from_datasets(self):
+        """Test _infer_partition_layout_from_datasets with various layouts.
+
+        Tests include a single dataset, a 2x2 grid, a 4x1 grid (single row),
+        and a 1x4 grid (single column).
+        """
+        # Test case 1: single dataset (1x1 partition)
+        ds1 = xr.Dataset(coords={"eta_rho": [0], "xi_rho": [0]})
+        assert _infer_partition_layout_from_datasets([ds1]) == (1, 1)
+
+        # Test case 2: 2x2 grid partition.
+        # The eta dimension transitions after the second dataset (np_xi=2).
+        ds_2x2_1 = xr.Dataset(coords={"eta_rho": [0] * 20, "xi_rho": [0] * 10})
+        ds_2x2_2 = xr.Dataset(coords={"eta_rho": [0] * 20, "xi_rho": [0] * 10})
+        ds_2x2_3 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 10})
+        ds_2x2_4 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 10})
+        datasets_2x2 = [ds_2x2_1, ds_2x2_2, ds_2x2_3, ds_2x2_4]
+        assert _infer_partition_layout_from_datasets(datasets_2x2) == (2, 2)
+
+        # Test case 3: 4x1 grid partition (single row).
+        # The eta dimension sizes are all the same, so no transition is detected.
+        # The function falls back to returning (nd, 1).
+        ds_4x1_1 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_2 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_3 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_4 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        datasets_4x1 = [ds_4x1_1, ds_4x1_2, ds_4x1_3, ds_4x1_4]
+        assert _infer_partition_layout_from_datasets(datasets_4x1) == (4, 1)
+
+        # Test case 4: 1x4 grid partition (single column).
+        # The eta dimension is partitioned, so the eta sizes change at every step.
+        ds_1x4_1 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 20})
+        ds_1x4_2 = xr.Dataset(coords={"eta_rho": [0] * 12, "xi_rho": [0] * 20})
+        ds_1x4_3 = xr.Dataset(coords={"eta_rho": [0] * 14, "xi_rho": [0] * 20})
+        ds_1x4_4 = xr.Dataset(coords={"eta_rho": [0] * 16, "xi_rho": [0] * 20})
+        datasets_1x4 = [ds_1x4_1, ds_1x4_2, ds_1x4_3, ds_1x4_4]
+        # Here `_find_transitions` for eta finds a transition at index 1, so np_xi=1,
+        # and the function correctly returns (1, 4).
+        assert _infer_partition_layout_from_datasets(datasets_1x4) == (1, 4)
+
+
+class TestJoinROMSData:
+    @pytest.mark.parametrize(
+        "np_xi, np_eta",
+        [
+            (1, 1),
+            (1, 6),
+            (6, 1),
+            (2, 2),
+            (3, 3),
+            (3, 4),
+            (4, 3),  # (12,24)
+            # All possible:
+            # single-partition grid
+            # (1, 1),
+            # single-row grids
+            # (2, 1), (3, 1), (4, 1), (6, 1), (12, 1),
+            # single-column grids
+            # (1, 2), (1, 3), (1, 4), (1, 6), (1, 8), (1, 12), (1, 24),
+            # multi-row, multi-column grids
+            # (2, 2), (2, 3), (2, 4), (2, 6), (2, 8), (2, 12), (2, 24),
+            # (3, 2), (3, 3), (3, 4), (3, 6), (3, 8), (3, 12), (3, 24),
+            # (4, 2), (4, 3), (4, 4), (4, 6), (4, 8), (4, 12), (4, 24),
+            # (6, 2), (6, 3), (6, 4), (6, 6), (6, 8), (6, 12), (6, 24),
+            # (12, 2), (12, 3), (12, 4), (12, 6), (12, 8), (12, 12), (12, 24)
+        ],
+    )
+    def test_open_grid_partitions(self, partitioned_grid_factory, np_xi, np_eta):
+        grid, partitions = partitioned_grid_factory(np_xi=np_xi, np_eta=np_eta)
+        joined_grid = open_partitions(partitions)
+
+        for v in grid.ds.variables:
+            assert (grid.ds[v].values == joined_grid[v].values).all(), (
+                f"{v} does not match in joined dataset"
+            )
+        assert grid.ds.attrs == joined_grid.attrs
+
+    def test_join_grid_netcdf(self, partitioned_grid_factory):
+        grid, partitions = partitioned_grid_factory(np_xi=3, np_eta=4)
+        joined_netcdf = join_netcdf(
+            partitions, output_path=partitions[0].parent / "joined_grid.nc"
+        )
+        assert joined_netcdf.exists()
+        joined_grid = xr.open_dataset(joined_netcdf)
+
+        for v in grid.ds.variables:
+            assert (grid.ds[v].values == joined_grid[v].values).all(), (
+                f"{v} does not match in joined dataset"
+            )
+        assert grid.ds.attrs == joined_grid.attrs
+
+    @pytest.mark.parametrize(
+        "np_xi, np_eta",
+        [
+            (1, 1),
+            (1, 6),
+            (6, 1),
+            (2, 2),
+            (3, 3),
+            (3, 4),
+            (4, 3),  # (12,24)
+        ],
+    )
+    def test_open_initial_condition_partitions(
+        self, partitioned_ic_factory, np_xi, np_eta
+    ):
+        whole_file, partitioned_files = partitioned_ic_factory(
+            np_xi=np_xi, np_eta=np_eta
+        )
+        joined_ics = open_partitions(partitioned_files)
+        whole_ics = xr.open_dataset(whole_file, decode_timedelta=True)
+
+        for v in whole_ics.variables:
+            assert (whole_ics[v].values == joined_ics[v].values).all(), (
+                f"{v} does not match in joined dataset: {joined_ics[v].values} vs {whole_ics[v].values}"
+            )
+        assert whole_ics.attrs == joined_ics.attrs
+
+    def test_join_initial_condition_netcdf(self, tmp_path, partitioned_ic_factory):
+        whole_file, partitioned_files = partitioned_ic_factory(np_xi=3, np_eta=4)
+        whole_ics = xr.open_dataset(whole_file, decode_timedelta=True)
+
+        joined_netcdf = join_netcdf(
+            partitioned_files, output_path=partitioned_files[0].parent / "joined_ics.nc"
+        )
+        assert joined_netcdf.exists()
+        joined_ics = xr.open_dataset(joined_netcdf, decode_timedelta=True)
+
+        for v in whole_ics.variables:
+            assert (whole_ics[v].values == joined_ics[v].values).all(), (
+                f"{v} does not match in joined dataset: {joined_ics[v].values} vs {whole_ics[v].values}"
+            )
+        assert whole_ics.attrs == joined_ics.attrs
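Stripped of test scaffolding, the partition/join round trip these tests assert looks like this (a sketch; the file names are placeholders):

from pathlib import Path

from roms_tools.tiling.join import join_netcdf, open_partitions
from roms_tools.tiling.partition import partition_netcdf

# Split one NetCDF file into 3 x 4 = 12 tiles; the tile paths are returned in order.
tiles = partition_netcdf(Path("my_grid.nc"), np_xi=3, np_eta=4)

# Either stitch the tiles back together in memory as one xarray.Dataset ...
ds = open_partitions(tiles)

# ... or write the joined result to a new file and get its path back.
joined_path = join_netcdf(tiles, output_path=Path("joined_grid.nc"))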
@@ -297,3 +297,48 @@ class TestPartitionNetcdf:
     for expected_filepath in expected_filepath_list:
         assert expected_filepath.exists()
         expected_filepath.unlink()
+
+    def test_partition_netcdf_with_output_dir(self, grid, tmp_path):
+        # Save the input file
+        input_file = tmp_path / "input_grid.nc"
+        grid.save(input_file)
+
+        # Create a custom output directory
+        output_dir = tmp_path / "custom_output"
+        output_dir.mkdir()
+
+        saved_filenames = partition_netcdf(
+            input_file, np_eta=3, np_xi=5, output_dir=output_dir
+        )
+
+        base_name = input_file.stem  # "input_grid"
+        expected_filenames = [
+            output_dir / f"{base_name}.{i:02d}.nc" for i in range(15)
+        ]
+
+        assert saved_filenames == expected_filenames
+
+        for f in expected_filenames:
+            assert f.exists()
+            f.unlink()
+
+    def test_partition_netcdf_multiple_files(self, grid, tmp_path):
+        # Create two test input files
+        file1 = tmp_path / "grid1.nc"
+        file2 = tmp_path / "grid2.nc"
+        grid.save(file1)
+        grid.save(file2)
+
+        # Run partitioning with 3x5 tiles on both files
+        saved_filenames = partition_netcdf([file1, file2], np_eta=3, np_xi=5)
+
+        # Expect 15 tiles per file → 30 output files in total
+        expected_filepaths = []
+        for file in [file1, file2]:
+            base = file.with_suffix("")
+            expected_filepaths += [Path(f"{base}.{i:02d}.nc") for i in range(15)]
+
+        assert len(saved_filenames) == 30
+        assert saved_filenames == expected_filepaths
+
+        for path in expected_filepaths:
+            assert path.exists()
+            path.unlink()
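Both tests encode the same output naming convention: `<stem>.<NN>.nc` with a zero-padded tile index, written next to the input file unless `output_dir` is given. A hypothetical helper summarizing what the assertions check (this is not part of the package's API):

from pathlib import Path

def expected_tile_paths(
    input_file: Path, np_eta: int, np_xi: int, output_dir: Path | None = None
) -> list[Path]:
    # Tiles are numbered 00 .. np_eta * np_xi - 1; without an explicit
    # output_dir they land in the input file's directory.
    out = output_dir if output_dir is not None else input_file.parent
    return [out / f"{input_file.stem}.{i:02d}.nc" for i in range(np_eta * np_xi)]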