roms-tools 3.1.2__py3-none-any.whl → 3.2.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- roms_tools/__init__.py +3 -0
- roms_tools/analysis/cdr_analysis.py +203 -0
- roms_tools/analysis/cdr_ensemble.py +198 -0
- roms_tools/analysis/roms_output.py +80 -46
- roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
- roms_tools/download.py +4 -0
- roms_tools/plot.py +75 -21
- roms_tools/setup/boundary_forcing.py +44 -19
- roms_tools/setup/cdr_forcing.py +122 -8
- roms_tools/setup/cdr_release.py +161 -8
- roms_tools/setup/datasets.py +626 -340
- roms_tools/setup/grid.py +138 -137
- roms_tools/setup/initial_conditions.py +113 -48
- roms_tools/setup/mask.py +63 -7
- roms_tools/setup/nesting.py +67 -42
- roms_tools/setup/river_forcing.py +45 -19
- roms_tools/setup/surface_forcing.py +4 -6
- roms_tools/setup/tides.py +1 -2
- roms_tools/setup/topography.py +4 -4
- roms_tools/setup/utils.py +134 -22
- roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
- roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
- roms_tools/tests/test_analysis/test_roms_output.py +61 -3
- roms_tools/tests/test_setup/test_boundary_forcing.py +54 -52
- roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
- roms_tools/tests/test_setup/test_cdr_release.py +118 -1
- roms_tools/tests/test_setup/test_datasets.py +392 -44
- roms_tools/tests/test_setup/test_grid.py +222 -115
- roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
- roms_tools/tests/test_setup/test_surface_forcing.py +2 -1
- roms_tools/tests/test_setup/test_utils.py +91 -1
- roms_tools/tests/test_setup/utils.py +71 -0
- roms_tools/tests/test_tiling/test_join.py +241 -0
- roms_tools/tests/test_utils.py +139 -17
- roms_tools/tiling/join.py +189 -0
- roms_tools/utils.py +131 -99
- {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/METADATA +12 -2
- {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/RECORD +41 -33
- {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
- {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
- {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
roms_tools/tests/test_setup/test_utils.py CHANGED

```diff
@@ -2,13 +2,103 @@ import logging
 from datetime import datetime
 from pathlib import Path
 
+import numpy as np
 import pytest
 import xarray as xr
 
 from roms_tools import BoundaryForcing, Grid
 from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import ERA5Correction
-from roms_tools.setup.utils import
+from roms_tools.setup.utils import (
+    get_target_coords,
+    interpolate_from_climatology,
+    validate_names,
+)
+
+
+class DummyGrid:
+    """Lightweight grid wrapper mimicking the real Grid API for testing."""
+
+    def __init__(self, ds: xr.Dataset, straddle: bool):
+        """Initialize grid wrapper."""
+        self.ds = ds
+        self.straddle = straddle
+
+
+class TestGetTargetCoords:
+    def make_rho_grid(self, lons, lats, with_mask=False):
+        """Helper to create a minimal rho grid dataset."""
+        eta, xi = len(lats), len(lons)
+        lon_rho, lat_rho = np.meshgrid(lons, lats)
+        ds = xr.Dataset(
+            {
+                "lon_rho": (("eta_rho", "xi_rho"), lon_rho),
+                "lat_rho": (("eta_rho", "xi_rho"), lat_rho),
+                "angle": (("eta_rho", "xi_rho"), np.zeros_like(lon_rho)),
+            },
+            coords={"eta_rho": np.arange(eta), "xi_rho": np.arange(xi)},
+        )
+        if with_mask:
+            ds["mask_rho"] = (("eta_rho", "xi_rho"), np.ones_like(lon_rho))
+        return ds
+
+    def test_basic_rho_grid(self):
+        ds = self.make_rho_grid(lons=[-10, -5, 0, 5, 10], lats=[50, 55])
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        assert "lat" in result and "lon" in result
+        assert np.allclose(result["lon"], ds.lon_rho)
+
+    def test_wrap_longitudes_to_minus180_180(self):
+        ds = self.make_rho_grid(lons=[190, 200], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        # longitudes >180 should wrap to -170, -160
+        expected = np.array([[-170, -160], [-170, -160]])
+        assert np.allclose(result["lon"].values, expected)
+
+    def test_convert_to_0_360_if_far_from_greenwich(self):
+        ds = self.make_rho_grid(lons=[-170, -160], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=False)
+        result = get_target_coords(grid)
+        # Should convert to 190, 200 since domain is far from Greenwich
+        expected = np.array([[190, 200], [190, 200]])
+        assert np.allclose(result["lon"].values, expected)
+        assert result["straddle"] is False
+
+    def test_close_to_greenwich_stays_minus180_180(self):
+        ds = self.make_rho_grid(lons=[-2, -1], lats=[0, 1])
+        grid = DummyGrid(ds, straddle=False)
+        result = get_target_coords(grid)
+        # Should remain unchanged (-2, -1), not converted to 358, 359
+        expected = np.array([[-2, -1], [-2, -1]])
+        assert np.allclose(result["lon"].values, expected)
+        assert result["straddle"] is True
+
+    def test_includes_optional_fields(self):
+        ds = self.make_rho_grid(lons=[-10, -5], lats=[0, 1], with_mask=True)
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid)
+        assert result["mask"] is not None
+
+    def test_coarse_grid_selection(self):
+        lon = np.array([[190, 200]])
+        lat = np.array([[10, 10]])
+        ds = xr.Dataset(
+            {
+                "lon_coarse": (("eta_coarse", "xi_coarse"), lon),
+                "lat_coarse": (("eta_coarse", "xi_coarse"), lat),
+                "angle_coarse": (("eta_coarse", "xi_coarse"), np.zeros_like(lon)),
+                "mask_coarse": (("eta_coarse", "xi_coarse"), np.ones_like(lon)),
+            },
+            coords={"eta_coarse": [0], "xi_coarse": [0, 1]},
+        )
+        grid = DummyGrid(ds, straddle=True)
+        result = get_target_coords(grid, use_coarse_grid=True)
+        # Should wrap longitudes to -170, -160
+        expected = np.array([[-170, -160]])
+        assert np.allclose(result["lon"].values, expected)
+        assert "mask" in result
 
 
 def test_interpolate_from_climatology(use_dask):
```
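The tests above pin down the longitude conventions `get_target_coords` is expected to follow: wrap to [-180, 180) when the domain straddles the dateline, switch to [0, 360) when it does not and sits far from Greenwich, and leave longitudes alone when the domain is near Greenwich. A minimal standalone sketch of just that wrapping rule (the "far from Greenwich" threshold here is an assumption for illustration, not taken from the library):

```python
import numpy as np


def wrap_longitudes(lon: np.ndarray, straddle: bool) -> np.ndarray:
    """Sketch of the conventions asserted in TestGetTargetCoords (illustrative only)."""
    lon = np.asarray(lon, dtype=float)
    wrapped = np.where(lon > 180, lon - 360, lon)  # [-180, 180) convention
    if straddle:
        return wrapped
    # Assumed threshold: "near Greenwich" means within 5 degrees of lon = 0.
    if np.abs(wrapped).min() < 5:
        return wrapped  # keep -2, -1 rather than converting to 358, 359
    return np.where(wrapped < 0, wrapped + 360, wrapped)  # [0, 360) convention


assert (wrap_longitudes(np.array([190, 200]), straddle=True) == [-170, -160]).all()
assert (wrap_longitudes(np.array([-170, -160]), straddle=False) == [190, 200]).all()
assert (wrap_longitudes(np.array([-2, -1]), straddle=False) == [-2, -1]).all()
```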
roms_tools/tests/test_setup/utils.py ADDED

```diff
@@ -0,0 +1,71 @@
+from datetime import datetime
+from pathlib import Path
+
+from roms_tools import Grid, get_glorys_bounds
+
+try:
+    import copernicusmarine  # type: ignore
+except ImportError:
+    copernicusmarine = None
+
+
+def download_regional_and_bigger(
+    tmp_path: Path,
+    grid: Grid,
+    start_time: datetime,
+    variables: list[str] = ["thetao", "so", "uo", "vo", "zos"],
+) -> tuple[Path, Path]:
+    """
+    Helper: download minimal and slightly bigger GLORYS subsets.
+
+    Parameters
+    ----------
+    tmp_path : Path
+        Directory to store the downloaded NetCDF files.
+    grid : Grid
+        ROMS-Tools Grid object defining the target domain.
+    start_time : datetime
+        Start time of the requested subset.
+    variables : list[str]
+        What variables to download.
+
+    Returns
+    -------
+    Tuple[Path, Path]
+        Paths to the minimal and slightly bigger GLORYS subset files.
+    """
+    bounds = get_glorys_bounds(grid)
+
+    # minimal dataset
+    regional_file = tmp_path / "regional_GLORYS.nc"
+    copernicusmarine.subset(
+        dataset_id="cmems_mod_glo_phy_my_0.083deg_P1D-m",
+        variables=variables,
+        **bounds,
+        start_datetime=start_time,
+        end_datetime=start_time,
+        coordinates_selection_method="outside",
+        output_filename=str(regional_file),
+    )
+
+    # slightly bigger dataset
+    for key, delta in {
+        "minimum_latitude": -1,
+        "minimum_longitude": -1,
+        "maximum_latitude": +1,
+        "maximum_longitude": +1,
+    }.items():
+        bounds[key] += delta
+
+    bigger_regional_file = tmp_path / "bigger_regional_GLORYS.nc"
+    copernicusmarine.subset(
+        dataset_id="cmems_mod_glo_phy_my_0.083deg_P1D-m",
+        variables=variables,
+        **bounds,
+        start_datetime=start_time,
+        end_datetime=start_time,
+        coordinates_selection_method="outside",
+        output_filename=str(bigger_regional_file),
+    )
+
+    return regional_file, bigger_regional_file
```
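This helper leans on `get_glorys_bounds`, newly importable from the top-level `roms_tools` namespace, to translate a grid's extent into the bounding-box keywords that `copernicusmarine.subset` expects. A rough usage sketch, with illustrative grid parameters (the bound keys are taken from the helper's adjustment loop above):

```python
from roms_tools import Grid, get_glorys_bounds

# Illustrative grid parameters; any small domain works.
grid = Grid(
    nx=5, ny=5, size_x=300, size_y=300, center_lon=-20, center_lat=60, rot=0
)

bounds = get_glorys_bounds(grid)
# The mapping carries the min/max lat/lon keywords, so it can be splatted
# straight into copernicusmarine.subset(**bounds) together with a dataset_id,
# variables, and start/end datetimes, exactly as the helper above does.
assert {
    "minimum_latitude",
    "maximum_latitude",
    "minimum_longitude",
    "maximum_longitude",
} <= set(bounds)
```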
roms_tools/tests/test_tiling/test_join.py ADDED

```diff
@@ -0,0 +1,241 @@
+from collections.abc import Callable
+from pathlib import Path
+
+import pytest
+import xarray as xr
+
+from roms_tools import Grid
+from roms_tools.tiling.join import (
+    _find_common_dims,
+    _find_transitions,
+    _infer_partition_layout_from_datasets,
+    join_netcdf,
+    open_partitions,
+)
+from roms_tools.tiling.partition import partition_netcdf
+
+
+@pytest.fixture
+def partitioned_grid_factory(
+    tmp_path, large_grid
+) -> Callable[[int, int], tuple[Grid, list[Path]]]:
+    """
+    A fixture factory that returns a function to generate partitioned files
+    with a configurable layout.
+    """
+
+    def _partitioned_files(np_xi: int, np_eta: int) -> tuple[Grid, list[Path]]:
+        partable_grid = large_grid
+        partable_grid.save(tmp_path / "test_grid.nc")
+        parted_files = partition_netcdf(
+            tmp_path / "test_grid.nc", np_xi=np_xi, np_eta=np_eta
+        )
+        return partable_grid, parted_files
+
+    return _partitioned_files
+
+
+@pytest.fixture
+def partitioned_ic_factory(
+    tmp_path, initial_conditions_on_large_grid
+) -> Callable[[int, int], tuple[Path, list[Path]]]:
+    def _partitioned_files(np_xi: int, np_eta: int) -> tuple[Path, list[Path]]:
+        whole_ics = initial_conditions_on_large_grid
+        whole_path = whole_ics.save(tmp_path / "test_ic.nc")[0]
+        parted_paths = partition_netcdf(whole_path, np_xi=np_xi, np_eta=np_eta)
+        return whole_path, parted_paths
+
+    return _partitioned_files
+
+
+class TestHelperFunctions:
+    def test_find_common_dims(self):
+        """Test _find_common_dims with different datasets and directions.
+
+        Tests for a valid common dimension, a valid dimension that is not common,
+        an invalid direction, and a ValueError when no common dimension is found.
+        """
+        # Create mock xarray.Dataset objects for testing
+        ds1 = xr.Dataset(coords={"xi_rho": [0], "xi_u": [0]})
+        ds2 = xr.Dataset(coords={"xi_rho": [0], "xi_u": [0]})
+        ds3 = xr.Dataset(coords={"xi_rho": [0], "xi_v": [0]})
+        datasets_common = [ds1, ds2]
+        datasets_not_common = [ds1, ds3]
+
+        # Test case with a common dimension ("xi_rho" and "xi_u")
+        # The function should find 'xi_rho' and 'xi_u' and return them in a list.
+        assert _find_common_dims("xi", datasets_common) == ["xi_rho", "xi_u"]
+
+        # Test case where a dimension is not common to all datasets
+        # The function should find only "xi_rho" as "xi_u" is not in ds3.
+        assert _find_common_dims("xi", datasets_not_common) == ["xi_rho"]
+
+        # Test case for an invalid direction, should raise a ValueError
+        with pytest.raises(ValueError, match="'direction' must be 'xi' or 'eta'"):
+            _find_common_dims("zeta", datasets_common)
+
+        # Test case where no common dimensions exist
+        ds_no_common1 = xr.Dataset(coords={"xi_rho": [0]})
+        ds_no_common2 = xr.Dataset(coords={"xi_v": [0]})
+        with pytest.raises(
+            ValueError, match="No common point found along direction xi"
+        ):
+            _find_common_dims("xi", [ds_no_common1, ds_no_common2])
+
+    def test_find_transitions(self):
+        """Test _find_transitions with various input lists.
+
+        Test cases include lists with no transitions, a single transition,
+        multiple transitions, and edge cases like empty and single-element lists.
+        """
+        # Test case with no transitions
+        assert _find_transitions([10, 10, 10, 10]) == []
+
+        # Test case with a single transition
+        assert _find_transitions([10, 10, 12, 12]) == [2]
+
+        # Test case with multiple transitions
+        assert _find_transitions([10, 12, 12, 14, 14, 14]) == [1, 3]
+
+        # Test case with transitions on every element
+        assert _find_transitions([10, 12, 14, 16]) == [1, 2, 3]
+
+        # Edge case: empty list
+        assert _find_transitions([]) == []
+
+        # Edge case: single-element list
+        assert _find_transitions([10]) == []
+
+    def test_infer_partition_layout_from_datasets(self):
+        """Test _infer_partition_layout_from_datasets with various layouts.
+
+        Tests include a single dataset, a 2x2 grid, a 4x1 grid (single row),
+        and a 1x4 grid (single column).
+        """
+        # Test case 1: Single dataset (1x1 partition)
+        ds1 = xr.Dataset(coords={"eta_rho": [0], "xi_rho": [0]})
+        assert _infer_partition_layout_from_datasets([ds1]) == (1, 1)
+
+        # Test case 2: 2x2 grid partition.
+        # The eta dimension will transition after the second dataset (np_xi=2).
+        ds_2x2_1 = xr.Dataset(coords={"eta_rho": [0] * 20, "xi_rho": [0] * 10})
+        ds_2x2_2 = xr.Dataset(coords={"eta_rho": [0] * 20, "xi_rho": [0] * 10})
+        ds_2x2_3 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 10})
+        ds_2x2_4 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 10})
+        datasets_2x2 = [ds_2x2_1, ds_2x2_2, ds_2x2_3, ds_2x2_4]
+        assert _infer_partition_layout_from_datasets(datasets_2x2) == (2, 2)
+
+        # Test case 3: 4x1 grid partition (single row).
+        # The eta dimension sizes are all the same, so no transition is detected.
+        # The function falls back to returning nd, 1.
+        ds_4x1_1 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_2 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_3 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        ds_4x1_4 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 5})
+        datasets_4x1 = [ds_4x1_1, ds_4x1_2, ds_4x1_3, ds_4x1_4]
+        assert _infer_partition_layout_from_datasets(datasets_4x1) == (4, 1)
+
+        # Test case 4: 1x4 grid partition (single column).
+        # The xi dimension is partitioned, so the eta dimensions must change at every step.
+        ds_1x4_1 = xr.Dataset(coords={"eta_rho": [0] * 10, "xi_rho": [0] * 20})
+        ds_1x4_2 = xr.Dataset(coords={"eta_rho": [0] * 12, "xi_rho": [0] * 20})
+        ds_1x4_3 = xr.Dataset(coords={"eta_rho": [0] * 14, "xi_rho": [0] * 20})
+        ds_1x4_4 = xr.Dataset(coords={"eta_rho": [0] * 16, "xi_rho": [0] * 20})
+        datasets_1x4 = [ds_1x4_1, ds_1x4_2, ds_1x4_3, ds_1x4_4]
+        # In this case, `_find_transitions` for eta will find a transition at index 1, so np_xi=1.
+        # This will correctly return (1, 4).
+        assert _infer_partition_layout_from_datasets(datasets_1x4) == (1, 4)
+
+
+class TestJoinROMSData:
+    @pytest.mark.parametrize(
+        "np_xi, np_eta",
+        [
+            (1, 1),
+            (1, 6),
+            (6, 1),
+            (2, 2),
+            (3, 3),
+            (3, 4),
+            (4, 3),  # (12,24)
+            # # All possible:
+            # # Single-partition grid
+            # (1,1)
+            # # Single-row grids
+            # (2, 1), (3, 1), (4, 1), (6, 1), (12, 1),
+            # # Single-column grids
+            # (1, 2), (1, 3), (1, 4), (1, 6), (1, 8), (1, 12), (1, 24),
+            # # Multi-row, multi-column grids
+            # (2, 2), (2, 3), (2, 4), (2, 6), (2, 8), (2, 12), (2, 24),
+            # (3, 2), (3, 3), (3, 4), (3, 6), (3, 8), (3, 12), (3, 24),
+            # (4, 2), (4, 3), (4, 4), (4, 6), (4, 8), (4, 12), (4, 24),
+            # (6, 2), (6, 3), (6, 4), (6, 6), (6, 8), (6, 12), (6, 24),
+            # (12, 2), (12, 3), (12, 4), (12, 6), (12, 8), (12, 12), (12, 24)
+        ],
+    )
+    def test_open_grid_partitions(self, partitioned_grid_factory, np_xi, np_eta):
+        grid, partitions = partitioned_grid_factory(np_xi=np_xi, np_eta=np_eta)
+        joined_grid = open_partitions(partitions)
+
+        for v in grid.ds.variables:
+            assert (grid.ds[v].values == joined_grid[v].values).all(), (
+                f"{v} does not match in joined dataset"
+            )
+        assert grid.ds.attrs == joined_grid.attrs
+
+    def test_join_grid_netcdf(self, partitioned_grid_factory):
+        grid, partitions = partitioned_grid_factory(np_xi=3, np_eta=4)
+        joined_netcdf = join_netcdf(
+            partitions, output_path=partitions[0].parent / "joined_grid.nc"
+        )
+        assert joined_netcdf.exists()
+        joined_grid = xr.open_dataset(joined_netcdf)
+
+        for v in grid.ds.variables:
+            assert (grid.ds[v].values == joined_grid[v].values).all(), (
+                f"{v} does not match in joined dataset"
+            )
+        assert grid.ds.attrs == joined_grid.attrs
+
+    @pytest.mark.parametrize(
+        "np_xi, np_eta",
+        [
+            (1, 1),
+            (1, 6),
+            (6, 1),
+            (2, 2),
+            (3, 3),
+            (3, 4),
+            (4, 3),  # (12,24)
+        ],
+    )
+    def test_open_initial_condition_partitions(
+        self, partitioned_ic_factory, np_xi, np_eta
+    ):
+        whole_file, partitioned_files = partitioned_ic_factory(
+            np_xi=np_xi, np_eta=np_eta
+        )
+        joined_ics = open_partitions(partitioned_files)
+        whole_ics = xr.open_dataset(whole_file, decode_timedelta=True)
+
+        for v in whole_ics.variables:
+            assert (whole_ics[v].values == joined_ics[v].values).all(), (
+                f"{v} does not match in joined dataset: {joined_ics[v].values} vs {whole_ics[v].values}"
+            )
+        assert whole_ics.attrs == joined_ics.attrs
+
+    def test_join_initial_condition_netcdf(self, tmp_path, partitioned_ic_factory):
+        whole_file, partitioned_files = partitioned_ic_factory(np_xi=3, np_eta=4)
+        whole_ics = xr.open_dataset(whole_file, decode_timedelta=True)
+
+        joined_netcdf = join_netcdf(
+            partitioned_files, output_path=partitioned_files[0].parent / "joined_ics.nc"
+        )
+        assert joined_netcdf.exists()
+        joined_ics = xr.open_dataset(joined_netcdf, decode_timedelta=True)
+
+        for v in whole_ics.variables:
+            assert (whole_ics[v].values == joined_ics[v].values).all(), (
+                f"{v} does not match in joined dataset: {joined_ics[v].values} vs {whole_ics[v].values}"
+            )
+        assert whole_ics.attrs == joined_ics.attrs
```
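Together, these tests exercise the round trip offered by the new `roms_tools.tiling.join` module: tile a ROMS NetCDF file with `partition_netcdf`, then reassemble it either in memory (`open_partitions`) or on disk (`join_netcdf`). Condensed into a sketch (grid parameters are illustrative; the test suite uses a larger fixture grid so it divides evenly into the parametrized layouts):

```python
import xarray as xr

from roms_tools import Grid
from roms_tools.tiling.join import join_netcdf, open_partitions
from roms_tools.tiling.partition import partition_netcdf

# Illustrative grid; pick nx/ny divisible by the partition counts.
grid = Grid(
    nx=24, ny=24, size_x=600, size_y=600, center_lon=-20, center_lat=60, rot=0
)
grid.save("grid.nc")

parts = partition_netcdf("grid.nc", np_xi=3, np_eta=4)  # 12 tile files

joined = open_partitions(parts)  # joined xr.Dataset in memory
out_path = join_netcdf(parts, output_path="joined_grid.nc")  # joined file on disk

# The joined data should match the original, variable by variable,
# just as the tests above assert.
original = xr.open_dataset("grid.nc")
for v in original.variables:
    assert (original[v].values == joined[v].values).all()
```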
roms_tools/tests/test_utils.py CHANGED

```diff
@@ -7,11 +7,13 @@ import pytest
 import xarray as xr
 
 from roms_tools.utils import (
-
-
-
-
-
+    _path_list_from_input,
+    generate_focused_coordinate_range,
+    get_dask_chunks,
+    has_copernicus,
+    has_dask,
+    has_gcsfs,
+    load_data,
 )
 
 
```
```diff
@@ -25,66 +27,132 @@ from roms_tools.utils import (
     ],
 )
 def test_coordinate_range_monotonicity(min_val, max_val, center, sc, N):
-    centers, faces =
+    centers, faces = generate_focused_coordinate_range(
         min_val=min_val, max_val=max_val, center=center, sc=sc, N=N
     )
     assert np.all(np.diff(faces) > 0), "faces is not strictly increasing"
     assert np.all(np.diff(centers) > 0), "centers is not strictly increasing"
 
 
+class TestPathListFromInput:
+    """A collection of tests for the _path_list_from_input function."""
+
+    # Test cases that don't require I/O
+    def test_list_of_strings(self):
+        """Test with a list of file paths as strings."""
+        files_list = ["path/to/file1.txt", "path/to/file2.txt"]
+        result = _path_list_from_input(files_list)
+        assert len(result) == 2
+        assert result[0] == Path("path/to/file1.txt")
+        assert result[1] == Path("path/to/file2.txt")
+
+    def test_list_of_path_objects(self):
+        """Test with a list of pathlib.Path objects."""
+        files_list = [Path("file_a.txt"), Path("file_b.txt")]
+        result = _path_list_from_input(files_list)
+        assert len(result) == 2
+        assert result[0] == Path("file_a.txt")
+        assert result[1] == Path("file_b.txt")
+
+    def test_single_path_object(self):
+        """Test with a single pathlib.Path object."""
+        file_path = Path("a_single_file.csv")
+        result = _path_list_from_input(file_path)
+        assert len(result) == 1
+        assert result[0] == file_path
+
+    def test_invalid_input_type_raises(self):
+        """Test that an invalid input type raises a TypeError."""
+        with pytest.raises(TypeError, match="'files' should be str, Path, or List"):
+            _path_list_from_input(123)
+
+    # Test cases that require I/O and `tmp_path`
+    def test_single_file_as_str(self, tmp_path):
+        """Test with a single file given as a string, requiring a file to exist."""
+        p = tmp_path / "test_file.txt"
+        p.touch()
+        result = _path_list_from_input(str(p))
+        assert len(result) == 1
+        assert result[0] == p
+
+    def test_wildcard_pattern(self, tmp_path, monkeypatch):
+        """Test with a wildcard pattern, requiring files to exist, using monkeypatch."""
+        # Setup
+        d = tmp_path / "data"
+        d.mkdir()
+        (d / "file1.csv").touch()
+        (d / "file2.csv").touch()
+        (d / "other_file.txt").touch()
+
+        # Action: Temporarily change the current working directory
+        monkeypatch.chdir(tmp_path)
+
+        result = _path_list_from_input("data/*.csv")
+
+        # Assertion
+        assert len(result) == 2
+        assert result[0].name == "file1.csv"
+        assert result[1].name == "file2.csv"
+
+    def test_non_matching_pattern_raises(self, tmp_path):
+        """Test that a non-matching pattern raises a FileNotFoundError."""
+        with pytest.raises(FileNotFoundError, match="No files matched"):
+            _path_list_from_input(str(tmp_path / "non_existent_file_*.txt"))
+
+
 def test_has_dask() -> None:
     """Verify that dask existence is correctly reported when found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
-        assert
+        assert has_dask()
 
 
 def test_has_dask_error_when_missing() -> None:
     """Verify that dask existence is correctly reported when not found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=None):
-        assert not
+        assert not has_dask()
 
 
 def test_has_gcfs() -> None:
     """Verify that GCFS existence is correctly reported when found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
-        assert
+        assert has_gcsfs()
 
 
 def test_has_gcfs_error_when_missing() -> None:
     """Verify that GCFS existence is correctly reported when not found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=None):
-        assert not
+        assert not has_gcsfs()
 
 
 def test_has_copernicus() -> None:
     """Verify that copernicus existence is correctly reported when found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
-        assert
+        assert has_copernicus()
 
 
 def test_has_copernicus_error_when_missing() -> None:
     """Verify that copernicus existence is correctly reported when not found."""
     with mock.patch("roms_tools.utils.find_spec", return_value=None):
-        assert not
+        assert not has_copernicus()
 
 
 def test_load_data_dask_not_found() -> None:
     """Verify that load data raises an exception when dask is requested and missing."""
     with (
-        mock.patch("roms_tools.utils.
+        mock.patch("roms_tools.utils.has_dask", return_value=False),
         pytest.raises(RuntimeError),
     ):
-
+        load_data("foo.zarr", {"a": "a"}, use_dask=True)
 
 
 def test_load_data_open_zarr_without_dask() -> None:
     """Verify that load data raises an exception when zarr is requested without dask."""
     with (
-        mock.patch("roms_tools.utils.
+        mock.patch("roms_tools.utils.has_dask", return_value=False),
         pytest.raises(ValueError),
     ):
         # read_zarr should require use_dask to be True
-
+        load_data("foo.zarr", {"a": ""}, use_dask=False, read_zarr=True)
 
 
 @pytest.mark.parametrize(
```
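The `TestPathListFromInput` cases above describe the contract of `_path_list_from_input` (a private helper, so subject to change): a `Path` or a list passes through as `Path` objects without touching the filesystem, a string is resolved as a glob and must match at least one existing file, and anything else raises. In sketch form:

```python
from pathlib import Path

from roms_tools.utils import _path_list_from_input

# Lists and Path objects are normalized to list[Path] without any I/O.
assert _path_list_from_input(Path("a.nc")) == [Path("a.nc")]
assert _path_list_from_input(["a.nc", "b.nc"]) == [Path("a.nc"), Path("b.nc")]

# Strings go through globbing, so a non-matching pattern raises
# FileNotFoundError, and a non-str/Path/list input raises TypeError,
# exactly as the tests above assert.
```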
```diff
@@ -111,7 +179,7 @@ def test_load_data_open_dataset(
         "roms_tools.utils.xr.open_dataset",
         wraps=xr.open_dataset,
     ) as fn_od:
-        ds =
+        ds = load_data(
             ds_path,
             {"latitude": "latitude"},
             use_dask=False,
```
```diff
@@ -119,3 +187,57 @@ def test_load_data_open_dataset(
     assert fn_od.called
 
     assert expected_dim in ds.dims
+
+
+# test get_dask_chunks
+
+
+def test_latlon_default_chunks():
+    dim_names = {"latitude": "lat", "longitude": "lon"}
+    expected = {"lat": -1, "lon": -1}
+    result = get_dask_chunks(dim_names)
+    assert result == expected
+
+
+def test_latlon_with_depth_and_time():
+    dim_names = {"latitude": "lat", "longitude": "lon", "depth": "z", "time": "t"}
+    expected = {"lat": -1, "lon": -1, "z": -1, "t": 1}
+    result = get_dask_chunks(dim_names)
+    assert result == expected
+
+
+def test_latlon_with_time_chunking_false():
+    dim_names = {"latitude": "lat", "longitude": "lon", "time": "t"}
+    expected = {"lat": -1, "lon": -1}
+    result = get_dask_chunks(dim_names, time_chunking=False)
+    assert result == expected
+
+
+def test_roms_default_chunks():
+    dim_names = {}
+    expected_keys = {"eta_rho", "eta_v", "xi_rho", "xi_u", "s_rho"}
+    result = get_dask_chunks(dim_names)
+    assert set(result.keys()) == expected_keys
+    assert all(v == -1 for v in result.values())
+
+
+def test_roms_with_depth_and_time():
+    dim_names = {"depth": "s_rho", "time": "ocean_time"}
+    result = get_dask_chunks(dim_names)
+    # ROMS default keys + depth + time
+    expected_keys = {"eta_rho", "eta_v", "xi_rho", "xi_u", "s_rho", "ocean_time"}
+    assert set(result.keys()) == expected_keys
+    assert result["ocean_time"] == 1
+    assert result["s_rho"] == -1
+
+
+def test_roms_with_ntides():
+    dim_names = {"ntides": "nt"}
+    result = get_dask_chunks(dim_names)
+    assert result["nt"] == 1
+
+
+def test_time_chunking_false_roms():
+    dim_names = {"time": "ocean_time"}
+    result = get_dask_chunks(dim_names, time_chunking=False)
+    assert "ocean_time" not in result
```