roms-tools 3.1.2__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. roms_tools/__init__.py +3 -0
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +75 -21
  8. roms_tools/setup/boundary_forcing.py +44 -19
  9. roms_tools/setup/cdr_forcing.py +122 -8
  10. roms_tools/setup/cdr_release.py +161 -8
  11. roms_tools/setup/datasets.py +626 -340
  12. roms_tools/setup/grid.py +138 -137
  13. roms_tools/setup/initial_conditions.py +113 -48
  14. roms_tools/setup/mask.py +63 -7
  15. roms_tools/setup/nesting.py +67 -42
  16. roms_tools/setup/river_forcing.py +45 -19
  17. roms_tools/setup/surface_forcing.py +4 -6
  18. roms_tools/setup/tides.py +1 -2
  19. roms_tools/setup/topography.py +4 -4
  20. roms_tools/setup/utils.py +134 -22
  21. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  22. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  23. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  24. roms_tools/tests/test_setup/test_boundary_forcing.py +54 -52
  25. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  26. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  27. roms_tools/tests/test_setup/test_datasets.py +392 -44
  28. roms_tools/tests/test_setup/test_grid.py +222 -115
  29. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  30. roms_tools/tests/test_setup/test_surface_forcing.py +2 -1
  31. roms_tools/tests/test_setup/test_utils.py +91 -1
  32. roms_tools/tests/test_setup/utils.py +71 -0
  33. roms_tools/tests/test_tiling/test_join.py +241 -0
  34. roms_tools/tests/test_utils.py +139 -17
  35. roms_tools/tiling/join.py +189 -0
  36. roms_tools/utils.py +131 -99
  37. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/METADATA +12 -2
  38. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/RECORD +41 -33
  39. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  40. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  41. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -2,13 +2,103 @@ import logging
 from datetime import datetime
 from pathlib import Path
 
+import numpy as np
 import pytest
 import xarray as xr
 
 from roms_tools import BoundaryForcing, Grid
 from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import ERA5Correction
-from roms_tools.setup.utils import interpolate_from_climatology, validate_names
+from roms_tools.setup.utils import (
+    get_target_coords,
+    interpolate_from_climatology,
+    validate_names,
+)
+
+
+class DummyGrid:
+    """Lightweight grid wrapper mimicking the real Grid API for testing."""
+
+    def __init__(self, ds: xr.Dataset, straddle: bool):
+        """Initialize grid wrapper."""
+        self.ds = ds
+        self.straddle = straddle
+
+
+class TestGetTargetCoords:
+    def make_rho_grid(self, lons, lats, with_mask=False):
+        """Helper to create a minimal rho grid dataset."""
+        eta, xi = len(lats), len(lons)
+        lon_rho, lat_rho = np.meshgrid(lons, lats)
+        ds = xr.Dataset(
+            {
+                "lon_rho": (("eta_rho", "xi_rho"), lon_rho),
+                "lat_rho": (("eta_rho", "xi_rho"), lat_rho),
+                "angle": (("eta_rho", "xi_rho"), np.zeros_like(lon_rho)),
+            },
+            coords={"eta_rho": np.arange(eta), "xi_rho": np.arange(xi)},
+        )
+        if with_mask:
+            ds["mask_rho"] = (("eta_rho", "xi_rho"), np.ones_like(lon_rho))
+        return ds
+
+    def test_basic_rho_grid(self):
+        ds = self.make_rho_grid(lons=[-10, -5, 0, 5, 10], lats=[50, 55])
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        assert "lat" in result and "lon" in result
+        assert np.allclose(result["lon"], ds.lon_rho)
+
+    def test_wrap_longitudes_to_minus180_180(self):
+        ds = self.make_rho_grid(lons=[190, 200], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        # longitudes > 180 should wrap to -170, -160
+        expected = np.array([[-170, -160], [-170, -160]])
+        assert np.allclose(result["lon"].values, expected)
+
+    def test_convert_to_0_360_if_far_from_greenwich(self):
+        ds = self.make_rho_grid(lons=[-170, -160], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=False)
+        result = get_target_coords(grid)
+        # Should convert to 190, 200 since the domain is far from Greenwich
+        expected = np.array([[190, 200], [190, 200]])
+        assert np.allclose(result["lon"].values, expected)
+        assert result["straddle"] is False
+
+    def test_close_to_greenwich_stays_minus180_180(self):
+        ds = self.make_rho_grid(lons=[-2, -1], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=False)
+        result = get_target_coords(grid)
+        # Should remain unchanged (-2, -1), not converted to 358, 359
+        expected = np.array([[-2, -1], [-2, -1]])
+        assert np.allclose(result["lon"].values, expected)
+        assert result["straddle"] is True
+
+    def test_includes_optional_fields(self):
+        ds = self.make_rho_grid(lons=[-10, -5], lats=[0, 1], with_mask=True)
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        assert result["mask"] is not None
+
+    def test_coarse_grid_selection(self):
+        lon = np.array([[190, 200]])
+        lat = np.array([[10, 10]])
+        ds = xr.Dataset(
+            {
+                "lon_coarse": (("eta_coarse", "xi_coarse"), lon),
+                "lat_coarse": (("eta_coarse", "xi_coarse"), lat),
+                "angle_coarse": (("eta_coarse", "xi_coarse"), np.zeros_like(lon)),
+                "mask_coarse": (("eta_coarse", "xi_coarse"), np.ones_like(lon)),
+            },
+            coords={"eta_coarse": [0], "xi_coarse": [0, 1]},
+        )
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid, use_coarse_grid=True)
+        # Should wrap longitudes to -170, -160
+        expected = np.array([[-170, -160]])
+        assert np.allclose(result["lon"].values, expected)
+        assert "mask" in result
 
 
 def test_interpolate_from_climatology(use_dask):
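Note on the helper under test: the assertions above pin down the longitude-wrapping contract of get_target_coords. As a minimal usage sketch, assuming only the behavior these tests confirm (the _Grid stand-in and its coordinate values are illustrative, not part of the diff):

import numpy as np
import xarray as xr

from roms_tools.setup.utils import get_target_coords

# Tiny rho grid east of the dateline, mirroring DummyGrid above.
lon_rho, lat_rho = np.meshgrid([190.0, 200.0], [0.0, 1.0])
ds = xr.Dataset(
    {
        "lon_rho": (("eta_rho", "xi_rho"), lon_rho),
        "lat_rho": (("eta_rho", "xi_rho"), lat_rho),
        "angle": (("eta_rho", "xi_rho"), np.zeros_like(lon_rho)),
    }
)

class _Grid:
    """Stand-in exposing the two attributes the helper reads."""

    def __init__(self, ds, straddle):
        self.ds, self.straddle = ds, straddle

coords = get_target_coords(_Grid(ds, straddle=True))
print(coords["lon"].values)  # longitudes > 180 wrap to -170, -160
print(coords["straddle"])    # True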
@@ -0,0 +1,71 @@
+from datetime import datetime
+from pathlib import Path
+
+from roms_tools import Grid, get_glorys_bounds
+
+try:
+    import copernicusmarine  # type: ignore
+except ImportError:
+    copernicusmarine = None
+
+
+def download_regional_and_bigger(
+    tmp_path: Path,
+    grid: Grid,
+    start_time: datetime,
+    variables: list[str] = ["thetao", "so", "uo", "vo", "zos"],
+) -> tuple[Path, Path]:
+    """
+    Helper: download minimal and slightly bigger GLORYS subsets.
+
+    Parameters
+    ----------
+    tmp_path : Path
+        Directory to store the downloaded NetCDF files.
+    grid : Grid
+        ROMS-Tools Grid object defining the target domain.
+    start_time : datetime
+        Start time of the requested subset.
+    variables : list[str]
+        Which variables to download.
+
+    Returns
+    -------
+    tuple[Path, Path]
+        Paths to the minimal and slightly bigger GLORYS subset files.
+    """
+    bounds = get_glorys_bounds(grid)
+
+    # minimal dataset
+    regional_file = tmp_path / "regional_GLORYS.nc"
+    copernicusmarine.subset(
+        dataset_id="cmems_mod_glo_phy_my_0.083deg_P1D-m",
+        variables=variables,
+        **bounds,
+        start_datetime=start_time,
+        end_datetime=start_time,
+        coordinates_selection_method="outside",
+        output_filename=str(regional_file),
+    )
+
+    # slightly bigger dataset
+    for key, delta in {
+        "minimum_latitude": -1,
+        "minimum_longitude": -1,
+        "maximum_latitude": +1,
+        "maximum_longitude": +1,
+    }.items():
+        bounds[key] += delta
+
+    bigger_regional_file = tmp_path / "bigger_regional_GLORYS.nc"
+    copernicusmarine.subset(
+        dataset_id="cmems_mod_glo_phy_my_0.083deg_P1D-m",
+        variables=variables,
+        **bounds,
+        start_datetime=start_time,
+        end_datetime=start_time,
+        coordinates_selection_method="outside",
+        output_filename=str(bigger_regional_file),
+    )
+
+    return regional_file, bigger_regional_file
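A hedged sketch of how a test might call this helper, assuming copernicusmarine is installed and that Grid accepts the size/center keyword arguments shown (the domain values below are placeholders, not taken from this diff):

from datetime import datetime
from pathlib import Path

from roms_tools import Grid
from roms_tools.tests.test_setup.utils import download_regional_and_bigger

# Placeholder domain; only the helper's signature comes from the code above.
grid = Grid(nx=20, ny=20, size_x=500, size_y=500, center_lon=-25, center_lat=61, rot=0)
small, bigger = download_regional_and_bigger(
    Path("/tmp/glorys"), grid, start_time=datetime(2012, 1, 1)
)
print(small.name, bigger.name)  # regional_GLORYS.nc bigger_regional_GLORYS.nc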
@@ -0,0 +1,241 @@
+from collections.abc import Callable
+from pathlib import Path
+
+import pytest
+import xarray as xr
+
+from roms_tools import Grid
+from roms_tools.tiling.join import (
+    _find_common_dims,
+    _find_transitions,
+    _infer_partition_layout_from_datasets,
+    join_netcdf,
+    open_partitions,
+)
+from roms_tools.tiling.partition import partition_netcdf
+
+
+@pytest.fixture
+def partitioned_grid_factory(
+    tmp_path, large_grid
+) -> Callable[[int, int], tuple[Grid, list[Path]]]:
+    """
+    A fixture factory that returns a function to generate partitioned files
+    with a configurable layout.
+    """
+
+    def _partitioned_files(np_xi: int, np_eta: int) -> tuple[Grid, list[Path]]:
+        partable_grid = large_grid
+        partable_grid.save(tmp_path / "test_grid.nc")
+        parted_files = partition_netcdf(
+            tmp_path / "test_grid.nc", np_xi=np_xi, np_eta=np_eta
+        )
+        return partable_grid, parted_files
+
+    return _partitioned_files
+
+
+@pytest.fixture
+def partitioned_ic_factory(
+    tmp_path, initial_conditions_on_large_grid
+) -> Callable[[int, int], tuple[Path, list[Path]]]:
+    def _partitioned_files(np_xi: int, np_eta: int) -> tuple[Path, list[Path]]:
+        whole_ics = initial_conditions_on_large_grid
+        whole_path = whole_ics.save(tmp_path / "test_ic.nc")[0]
+        parted_paths = partition_netcdf(whole_path, np_xi=np_xi, np_eta=np_eta)
+        return whole_path, parted_paths
+
+    return _partitioned_files
+
+
+class TestHelperFunctions:
+    def test_find_common_dims(self):
+        """Test _find_common_dims with different datasets and directions.
+
+        Tests for a valid common dimension, a valid dimension that is not common,
+        an invalid direction, and a ValueError when no common dimension is found.
+        """
+        # Create mock xarray.Dataset objects for testing
+        ds1 = xr.Dataset(coords={"xi_rho": [0], "xi_u": [0]})
+        ds2 = xr.Dataset(coords={"xi_rho": [0], "xi_u": [0]})
+        ds3 = xr.Dataset(coords={"xi_rho": [0], "xi_v": [0]})
+        datasets_common = [ds1, ds2]
+        datasets_not_common = [ds1, ds3]
+
+        # Test case with common dimensions ("xi_rho" and "xi_u"):
+        # the function should find both and return them in a list.
+        assert _find_common_dims("xi", datasets_common) == ["xi_rho", "xi_u"]
+
+        # Test case where a dimension is not common to all datasets:
+        # only "xi_rho" is found, as "xi_u" is not in ds3.
+        assert _find_common_dims("xi", datasets_not_common) == ["xi_rho"]
+
+        # Test case for an invalid direction, should raise a ValueError
+        with pytest.raises(ValueError, match="'direction' must be 'xi' or 'eta'"):
+            _find_common_dims("zeta", datasets_common)
+
+        # Test case where no common dimensions exist
+        ds_no_common1 = xr.Dataset(coords={"xi_rho": [0]})
+        ds_no_common2 = xr.Dataset(coords={"xi_v": [0]})
+        with pytest.raises(
+            ValueError, match="No common point found along direction xi"
+        ):
+            _find_common_dims("xi", [ds_no_common1, ds_no_common2])
+
+    def test_find_transitions(self):
+        """Test _find_transitions with various input lists.
+
+        Test cases include lists with no transitions, a single transition,
+        multiple transitions, and edge cases like empty and single-element lists.
+        """
+        # Test case with no transitions
+        assert _find_transitions([10, 10, 10, 10]) == []
+
+        # Test case with a single transition
+        assert _find_transitions([10, 10, 12, 12]) == [2]
+
+        # Test case with multiple transitions
+        assert _find_transitions([10, 12, 12, 14, 14, 14]) == [1, 3]
+
+        # Test case with transitions on every element
+        assert _find_transitions([10, 12, 14, 16]) == [1, 2, 3]
+
+        # Edge case: empty list
+        assert _find_transitions([]) == []
+
+        # Edge case: single-element list
+        assert _find_transitions([10]) == []
+
+    def test_infer_partition_layout_from_datasets(self):
+        """Test _infer_partition_layout_from_datasets with various layouts.
+
+        Tests include a single dataset, a 2x2 grid, a 4x1 grid (single row),
+        and a 1x4 grid (single column).
+        """
+        # Test case 1: single dataset (1x1 partition)
+        ds1 = xr.Dataset(coords={"eta_rho": [0], "xi_rho": [0]})
+        assert _infer_partition_layout_from_datasets([ds1]) == (1, 1)
+
+        # Test case 2: 2x2 grid partition.
+        # The eta dimension will transition after the second dataset (np_xi=2).
+        ds_2x2_1 = xr.Dataset(coords={"eta_rho": [0] * 20, "xi_rho": [0] * 10})
+        ds_2x2_2 = xr.Dataset(coords={"eta_rho": [0] * 20, "xi_rho": [0] * 10})
+        ds_2x2_3 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 10})
+        ds_2x2_4 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 10})
+        datasets_2x2 = [ds_2x2_1, ds_2x2_2, ds_2x2_3, ds_2x2_4]
+        assert _infer_partition_layout_from_datasets(datasets_2x2) == (2, 2)
+
+        # Test case 3: 4x1 grid partition (single row).
+        # The eta dimension sizes are all the same, so no transition is detected.
+        # The function falls back to returning (nd, 1).
+        ds_4x1_1 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_2 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_3 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_4 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        datasets_4x1 = [ds_4x1_1, ds_4x1_2, ds_4x1_3, ds_4x1_4]
+        assert _infer_partition_layout_from_datasets(datasets_4x1) == (4, 1)
+
+        # Test case 4: 1x4 grid partition (single column).
+        # The xi dimension is partitioned, so the eta dimensions must change at every step.
+        ds_1x4_1 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 20})
+        ds_1x4_2 = xr.Dataset(coords={"eta_rho": [0] * 12, "xi_rho": [0] * 20})
+        ds_1x4_3 = xr.Dataset(coords={"eta_rho": [0] * 14, "xi_rho": [0] * 20})
+        ds_1x4_4 = xr.Dataset(coords={"eta_rho": [0] * 16, "xi_rho": [0] * 20})
+        datasets_1x4 = [ds_1x4_1, ds_1x4_2, ds_1x4_3, ds_1x4_4]
+        # Here `_find_transitions` for eta finds a transition at index 1, so np_xi=1.
+        # This correctly returns (1, 4).
+        assert _infer_partition_layout_from_datasets(datasets_1x4) == (1, 4)
+
+
+class TestJoinROMSData:
+    @pytest.mark.parametrize(
+        "np_xi, np_eta",
+        [
+            (1, 1),
+            (1, 6),
+            (6, 1),
+            (2, 2),
+            (3, 3),
+            (3, 4),
+            (4, 3),  # (12, 24)
+            # All possible layouts:
+            # Single-partition grid:
+            #   (1, 1),
+            # Single-row grids:
+            #   (2, 1), (3, 1), (4, 1), (6, 1), (12, 1),
+            # Single-column grids:
+            #   (1, 2), (1, 3), (1, 4), (1, 6), (1, 8), (1, 12), (1, 24),
+            # Multi-row, multi-column grids:
+            #   (2, 2), (2, 3), (2, 4), (2, 6), (2, 8), (2, 12), (2, 24),
+            #   (3, 2), (3, 3), (3, 4), (3, 6), (3, 8), (3, 12), (3, 24),
+            #   (4, 2), (4, 3), (4, 4), (4, 6), (4, 8), (4, 12), (4, 24),
+            #   (6, 2), (6, 3), (6, 4), (6, 6), (6, 8), (6, 12), (6, 24),
+            #   (12, 2), (12, 3), (12, 4), (12, 6), (12, 8), (12, 12), (12, 24),
+        ],
+    )
+    def test_open_grid_partitions(self, partitioned_grid_factory, np_xi, np_eta):
+        grid, partitions = partitioned_grid_factory(np_xi=np_xi, np_eta=np_eta)
+        joined_grid = open_partitions(partitions)
+
+        for v in grid.ds.variables:
+            assert (grid.ds[v].values == joined_grid[v].values).all(), (
+                f"{v} does not match in joined dataset"
+            )
+        assert grid.ds.attrs == joined_grid.attrs
+
+    def test_join_grid_netcdf(self, partitioned_grid_factory):
+        grid, partitions = partitioned_grid_factory(np_xi=3, np_eta=4)
+        joined_netcdf = join_netcdf(
+            partitions, output_path=partitions[0].parent / "joined_grid.nc"
+        )
+        assert joined_netcdf.exists()
+        joined_grid = xr.open_dataset(joined_netcdf)
+
+        for v in grid.ds.variables:
+            assert (grid.ds[v].values == joined_grid[v].values).all(), (
+                f"{v} does not match in joined dataset"
+            )
+        assert grid.ds.attrs == joined_grid.attrs
+
+    @pytest.mark.parametrize(
+        "np_xi, np_eta",
+        [
+            (1, 1),
+            (1, 6),
+            (6, 1),
+            (2, 2),
+            (3, 3),
+            (3, 4),
+            (4, 3),  # (12, 24)
+        ],
+    )
+    def test_open_initial_condition_partitions(
+        self, partitioned_ic_factory, np_xi, np_eta
+    ):
+        whole_file, partitioned_files = partitioned_ic_factory(
+            np_xi=np_xi, np_eta=np_eta
+        )
+        joined_ics = open_partitions(partitioned_files)
+        whole_ics = xr.open_dataset(whole_file, decode_timedelta=True)
+
+        for v in whole_ics.variables:
+            assert (whole_ics[v].values == joined_ics[v].values).all(), (
+                f"{v} does not match in joined dataset: {joined_ics[v].values} vs {whole_ics[v].values}"
+            )
+        assert whole_ics.attrs == joined_ics.attrs
+
+    def test_join_initial_condition_netcdf(self, tmp_path, partitioned_ic_factory):
+        whole_file, partitioned_files = partitioned_ic_factory(np_xi=3, np_eta=4)
+        whole_ics = xr.open_dataset(whole_file, decode_timedelta=True)
+
+        joined_netcdf = join_netcdf(
+            partitioned_files, output_path=partitioned_files[0].parent / "joined_ics.nc"
+        )
+        assert joined_netcdf.exists()
+        joined_ics = xr.open_dataset(joined_netcdf, decode_timedelta=True)
+
+        for v in whole_ics.variables:
+            assert (whole_ics[v].values == joined_ics[v].values).all(), (
+                f"{v} does not match in joined dataset: {joined_ics[v].values} vs {whole_ics[v].values}"
+            )
+        assert whole_ics.attrs == joined_ics.attrs
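Taken together, these tests describe a partition/join round trip. A short sketch of that workflow, using only the functions imported at the top of this test module (file names are illustrative):

from pathlib import Path

import xarray as xr

from roms_tools.tiling.join import join_netcdf, open_partitions
from roms_tools.tiling.partition import partition_netcdf

# Split a saved ROMS NetCDF file into a 3x4 tile layout.
parts = partition_netcdf(Path("grid.nc"), np_xi=3, np_eta=4)

# Stitch the tiles back together, either in memory or as a new file on disk.
ds_joined = open_partitions(parts)
out_path = join_netcdf(parts, output_path=Path("grid_joined.nc"))
ds_roundtrip = xr.open_dataset(out_path)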
@@ -7,11 +7,13 @@ import pytest
 import xarray as xr
 
 from roms_tools.utils import (
-    _generate_focused_coordinate_range,
-    _has_copernicus,
-    _has_dask,
-    _has_gcsfs,
-    _load_data,
+    _path_list_from_input,
+    generate_focused_coordinate_range,
+    get_dask_chunks,
+    has_copernicus,
+    has_dask,
+    has_gcsfs,
+    load_data,
 )
 
 
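This hunk promotes several previously private helpers to public names (has_dask, has_gcsfs, has_copernicus, load_data, generate_focused_coordinate_range) and imports the new _path_list_from_input and get_dask_chunks. A minimal migration sketch, assuming nothing beyond the renames visible here:

# 3.1.x (private names):
#   from roms_tools.utils import _load_data, _has_dask
# 3.2.0 (public names):
from roms_tools.utils import has_dask, load_data

if has_dask():
    # dim_names mapping as exercised by the tests below
    ds = load_data("forcing.nc", {"latitude": "latitude"}, use_dask=True)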
@@ -25,66 +27,132 @@ from roms_tools.utils import (
     ],
 )
 def test_coordinate_range_monotonicity(min_val, max_val, center, sc, N):
-    centers, faces = _generate_focused_coordinate_range(
+    centers, faces = generate_focused_coordinate_range(
         min_val=min_val, max_val=max_val, center=center, sc=sc, N=N
     )
     assert np.all(np.diff(faces) > 0), "faces is not strictly increasing"
     assert np.all(np.diff(centers) > 0), "centers is not strictly increasing"
 
 
+class TestPathListFromInput:
+    """A collection of tests for the _path_list_from_input function."""
+
+    # Test cases that don't require I/O
+    def test_list_of_strings(self):
+        """Test with a list of file paths as strings."""
+        files_list = ["path/to/file1.txt", "path/to/file2.txt"]
+        result = _path_list_from_input(files_list)
+        assert len(result) == 2
+        assert result[0] == Path("path/to/file1.txt")
+        assert result[1] == Path("path/to/file2.txt")
+
+    def test_list_of_path_objects(self):
+        """Test with a list of pathlib.Path objects."""
+        files_list = [Path("file_a.txt"), Path("file_b.txt")]
+        result = _path_list_from_input(files_list)
+        assert len(result) == 2
+        assert result[0] == Path("file_a.txt")
+        assert result[1] == Path("file_b.txt")
+
+    def test_single_path_object(self):
+        """Test with a single pathlib.Path object."""
+        file_path = Path("a_single_file.csv")
+        result = _path_list_from_input(file_path)
+        assert len(result) == 1
+        assert result[0] == file_path
+
+    def test_invalid_input_type_raises(self):
+        """Test that an invalid input type raises a TypeError."""
+        with pytest.raises(TypeError, match="'files' should be str, Path, or List"):
+            _path_list_from_input(123)
+
+    # Test cases that require I/O and `tmp_path`
+    def test_single_file_as_str(self, tmp_path):
+        """Test with a single file given as a string, requiring the file to exist."""
+        p = tmp_path / "test_file.txt"
+        p.touch()
+        result = _path_list_from_input(str(p))
+        assert len(result) == 1
+        assert result[0] == p
+
+    def test_wildcard_pattern(self, tmp_path, monkeypatch):
+        """Test with a wildcard pattern, requiring files to exist, using monkeypatch."""
+        # Setup
+        d = tmp_path / "data"
+        d.mkdir()
+        (d / "file1.csv").touch()
+        (d / "file2.csv").touch()
+        (d / "other_file.txt").touch()
+
+        # Action: temporarily change the current working directory
+        monkeypatch.chdir(tmp_path)
+
+        result = _path_list_from_input("data/*.csv")
+
+        # Assertion
+        assert len(result) == 2
+        assert result[0].name == "file1.csv"
+        assert result[1].name == "file2.csv"
+
+    def test_non_matching_pattern_raises(self, tmp_path):
+        """Test that a non-matching pattern raises a FileNotFoundError."""
+        with pytest.raises(FileNotFoundError, match="No files matched"):
+            _path_list_from_input(str(tmp_path / "non_existent_file_*.txt"))
+
+
 def test_has_dask() -> None:
     """Verify that dask existence is correctly reported when found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
-        assert _has_dask()
+        assert has_dask()
 
 
 def test_has_dask_error_when_missing() -> None:
     """Verify that dask existence is correctly reported when not found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=None):
-        assert not _has_dask()
+        assert not has_dask()
 
 
 def test_has_gcfs() -> None:
     """Verify that GCFS existence is correctly reported when found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
-        assert _has_gcsfs()
+        assert has_gcsfs()
 
 
 def test_has_gcfs_error_when_missing() -> None:
     """Verify that GCFS existence is correctly reported when not found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=None):
-        assert not _has_gcsfs()
+        assert not has_gcsfs()
 
 
 def test_has_copernicus() -> None:
     """Verify that copernicus existence is correctly reported when found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
-        assert _has_copernicus()
+        assert has_copernicus()
 
 
 def test_has_copernicus_error_when_missing() -> None:
     """Verify that copernicus existence is correctly reported when not found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=None):
-        assert not _has_copernicus()
+        assert not has_copernicus()
 
 
 def test_load_data_dask_not_found() -> None:
     """Verify that load data raises an exception when dask is requested and missing."""
     with (
-        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        mock.patch("roms_tools.utils.has_dask", return_value=False),
         pytest.raises(RuntimeError),
     ):
-        _load_data("foo.zarr", {"a": "a"}, use_dask=True)
+        load_data("foo.zarr", {"a": "a"}, use_dask=True)
 
 
 def test_load_data_open_zarr_without_dask() -> None:
     """Verify that load data raises an exception when zarr is requested without dask."""
     with (
-        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        mock.patch("roms_tools.utils.has_dask", return_value=False),
         pytest.raises(ValueError),
     ):
         # read_zarr should require use_dask to be True
-        _load_data("foo.zarr", {"a": ""}, use_dask=False, read_zarr=True)
+        load_data("foo.zarr", {"a": ""}, use_dask=False, read_zarr=True)
 
 
 @pytest.mark.parametrize(
@@ -111,7 +179,7 @@ def test_load_data_open_dataset(
         "roms_tools.utils.xr.open_dataset",
         wraps=xr.open_dataset,
     ) as fn_od:
-        ds = _load_data(
+        ds = load_data(
             ds_path,
             {"latitude": "latitude"},
             use_dask=False,
@@ -119,3 +187,57 @@ def test_load_data_open_dataset(
     assert fn_od.called
 
     assert expected_dim in ds.dims
+
+
+# tests for get_dask_chunks
+
+
+def test_latlon_default_chunks():
+    dim_names = {"latitude": "lat", "longitude": "lon"}
+    expected = {"lat": -1, "lon": -1}
+    result = get_dask_chunks(dim_names)
+    assert result == expected
+
+
+def test_latlon_with_depth_and_time():
+    dim_names = {"latitude": "lat", "longitude": "lon", "depth": "z", "time": "t"}
+    expected = {"lat": -1, "lon": -1, "z": -1, "t": 1}
+    result = get_dask_chunks(dim_names)
+    assert result == expected
+
+
+def test_latlon_with_time_chunking_false():
+    dim_names = {"latitude": "lat", "longitude": "lon", "time": "t"}
+    expected = {"lat": -1, "lon": -1}
+    result = get_dask_chunks(dim_names, time_chunking=False)
+    assert result == expected
+
+
+def test_roms_default_chunks():
+    dim_names = {}
+    expected_keys = {"eta_rho", "eta_v", "xi_rho", "xi_u", "s_rho"}
+    result = get_dask_chunks(dim_names)
+    assert set(result.keys()) == expected_keys
+    assert all(v == -1 for v in result.values())
+
+
+def test_roms_with_depth_and_time():
+    dim_names = {"depth": "s_rho", "time": "ocean_time"}
+    result = get_dask_chunks(dim_names)
+    # ROMS default keys + depth + time
+    expected_keys = {"eta_rho", "eta_v", "xi_rho", "xi_u", "s_rho", "ocean_time"}
+    assert set(result.keys()) == expected_keys
+    assert result["ocean_time"] == 1
+    assert result["s_rho"] == -1
+
+
+def test_roms_with_ntides():
+    dim_names = {"ntides": "nt"}
+    result = get_dask_chunks(dim_names)
+    assert result["nt"] == 1
+
+
+def test_time_chunking_false_roms():
+    dim_names = {"time": "ocean_time"}
+    result = get_dask_chunks(dim_names, time_chunking=False)
+    assert "ocean_time" not in result