roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl
This diff shows the contents of publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
- roms_tools/__init__.py +8 -1
- roms_tools/analysis/cdr_analysis.py +203 -0
- roms_tools/analysis/cdr_ensemble.py +198 -0
- roms_tools/analysis/roms_output.py +80 -46
- roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
- roms_tools/download.py +4 -0
- roms_tools/plot.py +131 -30
- roms_tools/regrid.py +6 -1
- roms_tools/setup/boundary_forcing.py +94 -44
- roms_tools/setup/cdr_forcing.py +123 -15
- roms_tools/setup/cdr_release.py +161 -8
- roms_tools/setup/datasets.py +709 -341
- roms_tools/setup/grid.py +167 -139
- roms_tools/setup/initial_conditions.py +113 -48
- roms_tools/setup/mask.py +63 -7
- roms_tools/setup/nesting.py +67 -42
- roms_tools/setup/river_forcing.py +45 -19
- roms_tools/setup/surface_forcing.py +16 -10
- roms_tools/setup/tides.py +1 -2
- roms_tools/setup/topography.py +4 -4
- roms_tools/setup/utils.py +134 -22
- roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
- roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
- roms_tools/tests/test_analysis/test_roms_output.py +61 -3
- roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
- roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
- roms_tools/tests/test_setup/test_cdr_release.py +118 -1
- roms_tools/tests/test_setup/test_datasets.py +458 -34
- roms_tools/tests/test_setup/test_grid.py +238 -121
- roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
- roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
- roms_tools/tests/test_setup/test_utils.py +91 -1
- roms_tools/tests/test_setup/test_validation.py +21 -15
- roms_tools/tests/test_setup/utils.py +71 -0
- roms_tools/tests/test_tiling/test_join.py +241 -0
- roms_tools/tests/test_tiling/test_partition.py +45 -0
- roms_tools/tests/test_utils.py +224 -2
- roms_tools/tiling/join.py +189 -0
- roms_tools/tiling/partition.py +44 -30
- roms_tools/utils.py +488 -161
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
- {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
roms_tools/tests/test_utils.py
CHANGED
@@ -1,7 +1,20 @@
+from collections.abc import Callable
+from pathlib import Path
+from unittest import mock
+
 import numpy as np
 import pytest
+import xarray as xr
 
-from roms_tools.utils import _generate_focused_coordinate_range
+from roms_tools.utils import (
+    _path_list_from_input,
+    generate_focused_coordinate_range,
+    get_dask_chunks,
+    has_copernicus,
+    has_dask,
+    has_gcsfs,
+    load_data,
+)
 
 
 @pytest.mark.parametrize(
@@ -14,8 +27,217 @@ from roms_tools.utils import _generate_focused_coordinate_range
     ],
 )
 def test_coordinate_range_monotonicity(min_val, max_val, center, sc, N):
-    centers, faces = _generate_focused_coordinate_range(
+    centers, faces = generate_focused_coordinate_range(
         min_val=min_val, max_val=max_val, center=center, sc=sc, N=N
     )
     assert np.all(np.diff(faces) > 0), "faces is not strictly increasing"
     assert np.all(np.diff(centers) > 0), "centers is not strictly increasing"
+
+
+class TestPathListFromInput:
+    """A collection of tests for the _path_list_from_input function."""
+
+    # Test cases that don't require I/O
+    def test_list_of_strings(self):
+        """Test with a list of file paths as strings."""
+        files_list = ["path/to/file1.txt", "path/to/file2.txt"]
+        result = _path_list_from_input(files_list)
+        assert len(result) == 2
+        assert result[0] == Path("path/to/file1.txt")
+        assert result[1] == Path("path/to/file2.txt")
+
+    def test_list_of_path_objects(self):
+        """Test with a list of pathlib.Path objects."""
+        files_list = [Path("file_a.txt"), Path("file_b.txt")]
+        result = _path_list_from_input(files_list)
+        assert len(result) == 2
+        assert result[0] == Path("file_a.txt")
+        assert result[1] == Path("file_b.txt")
+
+    def test_single_path_object(self):
+        """Test with a single pathlib.Path object."""
+        file_path = Path("a_single_file.csv")
+        result = _path_list_from_input(file_path)
+        assert len(result) == 1
+        assert result[0] == file_path
+
+    def test_invalid_input_type_raises(self):
+        """Test that an invalid input type raises a TypeError."""
+        with pytest.raises(TypeError, match="'files' should be str, Path, or List"):
+            _path_list_from_input(123)
+
+    # Test cases that require I/O and `tmp_path`
+    def test_single_file_as_str(self, tmp_path):
+        """Test with a single file given as a string, requiring a file to exist."""
+        p = tmp_path / "test_file.txt"
+        p.touch()
+        result = _path_list_from_input(str(p))
+        assert len(result) == 1
+        assert result[0] == p
+
+    def test_wildcard_pattern(self, tmp_path, monkeypatch):
+        """Test with a wildcard pattern, requiring files to exist, using monkeypatch."""
+        # Setup
+        d = tmp_path / "data"
+        d.mkdir()
+        (d / "file1.csv").touch()
+        (d / "file2.csv").touch()
+        (d / "other_file.txt").touch()
+
+        # Action: Temporarily change the current working directory
+        monkeypatch.chdir(tmp_path)
+
+        result = _path_list_from_input("data/*.csv")
+
+        # Assertion
+        assert len(result) == 2
+        assert result[0].name == "file1.csv"
+        assert result[1].name == "file2.csv"
+
+    def test_non_matching_pattern_raises(self, tmp_path):
+        """Test that a non-matching pattern raises a FileNotFoundError."""
+        with pytest.raises(FileNotFoundError, match="No files matched"):
+            _path_list_from_input(str(tmp_path / "non_existent_file_*.txt"))
+
+
+def test_has_dask() -> None:
+    """Verify that dask existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert has_dask()
+
+
+def test_has_dask_error_when_missing() -> None:
+    """Verify that dask existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not has_dask()
+
+
+def test_has_gcfs() -> None:
+    """Verify that GCFS existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert has_gcsfs()
+
+
+def test_has_gcfs_error_when_missing() -> None:
+    """Verify that GCFS existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not has_gcsfs()
+
+
+def test_has_copernicus() -> None:
+    """Verify that copernicus existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert has_copernicus()
+
+
+def test_has_copernicus_error_when_missing() -> None:
+    """Verify that copernicus existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not has_copernicus()
+
+
+def test_load_data_dask_not_found() -> None:
+    """Verify that load data raises an exception when dask is requested and missing."""
+    with (
+        mock.patch("roms_tools.utils.has_dask", return_value=False),
+        pytest.raises(RuntimeError),
+    ):
+        load_data("foo.zarr", {"a": "a"}, use_dask=True)
+
+
+def test_load_data_open_zarr_without_dask() -> None:
+    """Verify that load data raises an exception when zarr is requested without dask."""
+    with (
+        mock.patch("roms_tools.utils.has_dask", return_value=False),
+        pytest.raises(ValueError),
+    ):
+        # read_zarr should require use_dask to be True
+        load_data("foo.zarr", {"a": ""}, use_dask=False, read_zarr=True)
+
+
+@pytest.mark.parametrize(
+    ("dataset_name", "expected_dim"),
+    [
+        ("surface_forcing", "time"),
+        ("bgc_surface_forcing", "time"),
+        ("tidal_forcing", "eta_rho"),
+        ("coarse_surface_forcing", "eta_rho"),
+    ],
+)
+def test_load_data_open_dataset(
+    dataset_name: str,
+    expected_dim: str,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    """Verify that a zarr file is correctly loaded when not using Dask.
+
+    This must use xr.open_dataset
+    """
+    ds_path = get_test_data_path(dataset_name)
+
+    with mock.patch(
+        "roms_tools.utils.xr.open_dataset",
+        wraps=xr.open_dataset,
+    ) as fn_od:
+        ds = load_data(
+            ds_path,
+            {"latitude": "latitude"},
+            use_dask=False,
+        )
+    assert fn_od.called
+
+    assert expected_dim in ds.dims
+
+
+# test get_dask_chunks
+
+
+def test_latlon_default_chunks():
+    dim_names = {"latitude": "lat", "longitude": "lon"}
+    expected = {"lat": -1, "lon": -1}
+    result = get_dask_chunks(dim_names)
+    assert result == expected
+
+
+def test_latlon_with_depth_and_time():
+    dim_names = {"latitude": "lat", "longitude": "lon", "depth": "z", "time": "t"}
+    expected = {"lat": -1, "lon": -1, "z": -1, "t": 1}
+    result = get_dask_chunks(dim_names)
+    assert result == expected
+
+
+def test_latlon_with_time_chunking_false():
+    dim_names = {"latitude": "lat", "longitude": "lon", "time": "t"}
+    expected = {"lat": -1, "lon": -1}
+    result = get_dask_chunks(dim_names, time_chunking=False)
+    assert result == expected
+
+
+def test_roms_default_chunks():
+    dim_names = {}
+    expected_keys = {"eta_rho", "eta_v", "xi_rho", "xi_u", "s_rho"}
+    result = get_dask_chunks(dim_names)
+    assert set(result.keys()) == expected_keys
+    assert all(v == -1 for v in result.values())
+
+
+def test_roms_with_depth_and_time():
+    dim_names = {"depth": "s_rho", "time": "ocean_time"}
+    result = get_dask_chunks(dim_names)
+    # ROMS default keys + depth + time
+    expected_keys = {"eta_rho", "eta_v", "xi_rho", "xi_u", "s_rho", "ocean_time"}
+    assert set(result.keys()) == expected_keys
+    assert result["ocean_time"] == 1
+    assert result["s_rho"] == -1
+
+
+def test_roms_with_ntides():
+    dim_names = {"ntides": "nt"}
+    result = get_dask_chunks(dim_names)
+    assert result["nt"] == 1
+
+
+def test_time_chunking_false_roms():
+    dim_names = {"time": "ocean_time"}
+    result = get_dask_chunks(dim_names, time_chunking=False)
+    assert "ocean_time" not in result
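The chunking tests above pin down the contract of the new get_dask_chunks helper: spatial dimensions are kept whole (chunk size -1), the time dimension is chunked one step at a time unless time_chunking=False, and an empty dim_names mapping falls back to the ROMS grid dimensions. A minimal sketch of relying on that contract, built only from the behavior these tests assert (the dimension names are illustrative):

from roms_tools.utils import get_dask_chunks

# Lat/lon source data with a time axis: space stays whole, time is
# chunked per step (mirrors test_latlon_with_depth_and_time)
chunks = get_dask_chunks({"latitude": "lat", "longitude": "lon", "time": "t"})
assert chunks == {"lat": -1, "lon": -1, "t": 1}

# Disabling time chunking drops the time entry entirely
# (mirrors test_latlon_with_time_chunking_false)
chunks = get_dask_chunks(
    {"latitude": "lat", "longitude": "lon", "time": "t"}, time_chunking=False
)
assert chunks == {"lat": -1, "lon": -1}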
roms_tools/tiling/join.py
ADDED
@@ -0,0 +1,189 @@
+from collections.abc import Sequence
+from pathlib import Path
+from typing import Literal, cast
+
+import xarray as xr
+
+from roms_tools.utils import FilePaths, _path_list_from_input
+
+
+def open_partitions(files: FilePaths) -> xr.Dataset:
+    """
+    Open partitioned ROMS netCDF files as a single dataset.
+
+    Parameters
+    ----------
+    files: str | List[str | Path]
+        List or wildcard pattern describing files to join,
+        e.g. "roms_rst.20121209133435.*.nc"
+
+    Returns
+    -------
+    xarray.Dataset
+        Dataset containing unified partitioned datasets
+    """
+    filepaths = _path_list_from_input(files)
+    datasets = [xr.open_dataset(p, decode_timedelta=True) for p in sorted(filepaths)]
+    joined = join_datasets(datasets)
+    return joined
+
+
+def join_netcdf(files: FilePaths, output_path: Path | None = None) -> Path:
+    """
+    Join partitioned NetCDFs into a single dataset.
+
+    Parameters
+    ----------
+    files : str | List[str | Path]
+        List or wildcard pattern describing files to join,
+        e.g. "roms_rst.20121209133435.*.nc"
+
+    output_path : Path, optional
+        If provided, the joined dataset will be saved to this path.
+        Otherwise, the common base of pattern (e.g. roms_rst.20121209133435.nc)
+        will be used.
+
+    Returns
+    -------
+    Path
+        The path of the saved file
+    """
+    filepaths = _path_list_from_input(files)
+    # Determine output path if not provided
+    if output_path is None:
+        # e.g. roms_rst.20120101120000.023.nc -> roms_rst.20120101120000.nc
+        output_path = filepaths[0].with_suffix("").with_suffix(".nc")
+
+    joined = open_partitions(cast(FilePaths, filepaths))
+    joined.to_netcdf(output_path)
+    print(f"Saved joined dataset to: {output_path}")
+
+    return output_path
+
+
+def _find_transitions(dim_sizes: list[int]) -> list[int]:
+    """Finds the indices of all transitions in a list of dimension sizes.
+
+    A transition is a point where the dimension size changes from the previous one.
+    This function is used to determine the number of partitions (e.g., np_eta or np_xi).
+
+    Parameters
+    ----------
+    dim_sizes : list[int]
+        A list of integer sizes for a given dimension across multiple datasets.
+
+    Returns
+    -------
+    List[int]
+        A list of indices where a transition was detected.
+    """
+    transitions: list[int] = []
+    if len(dim_sizes) < 2:
+        return transitions
+
+    for i in range(1, len(dim_sizes)):
+        if dim_sizes[i] != dim_sizes[i - 1]:
+            transitions.append(i)
+    return transitions
+
+
+def _find_common_dims(
+    direction: Literal["xi", "eta"], datasets: Sequence[xr.Dataset]
+) -> list[str]:
+    """Finds all common dimensions along the xi or eta direction amongst a list of Datasets.
+
+    Parameters
+    ----------
+    direction: str ("xi" or "eta")
+        The direction in which to seek a common dimension
+    datasets: Sequence[xr.Dataset]:
+        The datasets in which to look
+
+    Returns
+    -------
+    common_dim: list[str]
+        The dimensions common to all specified datasets along 'direction'
+    """
+    if direction not in ["xi", "eta"]:
+        raise ValueError("'direction' must be 'xi' or 'eta'")
+    dims = []
+    for point in ["rho", "u", "v"]:
+        if all(f"{direction}_{point}" in d.dims for d in datasets):
+            dims.append(f"{direction}_{point}")
+    if not dims:
+        raise ValueError(f"No common point found along direction {direction}")
+    return dims
+
+
+def _infer_partition_layout_from_datasets(
+    datasets: Sequence[xr.Dataset],
+) -> tuple[int, int]:
+    """Infer np_eta, np_xi from datasets."""
+    nd = len(datasets)
+    if nd == 1:
+        return 1, 1
+
+    eta_dims = _find_common_dims("eta", datasets)
+    first_eta_transition = nd
+
+    for eta_dim in eta_dims:
+        dim_sizes = [ds.sizes.get(eta_dim, 0) for ds in datasets]
+        eta_transitions = _find_transitions(dim_sizes)
+        if eta_transitions and (min(eta_transitions) < first_eta_transition):
+            first_eta_transition = min(eta_transitions)
+    if first_eta_transition < nd:
+        np_xi = first_eta_transition
+        np_eta = nd // np_xi
+        return np_xi, np_eta
+    # If we did not successfully find np_xi,np_eta using eta points
+    # then we have a single-column grid:
+
+    return nd, 1
+
+
+def join_datasets(datasets: Sequence[xr.Dataset]) -> xr.Dataset:
+    """Take a sequence of partitioned Datasets and return a joined Dataset."""
+    np_xi, np_eta = _infer_partition_layout_from_datasets(datasets)
+
+    # Arrange into grid
+    grid = [[datasets[j + i * np_xi] for j in range(np_xi)] for i in range(np_eta)]
+
+    # Join each row (along xi_*)
+    rows_joined = []
+    for row in grid:
+        all_vars = set().union(*(ds.data_vars for ds in row))
+        row_dataset = xr.Dataset()
+
+        for varname in all_vars:
+            var_slices = [ds[varname] for ds in row if varname in ds]
+            xi_dims = [dim for dim in var_slices[0].dims if dim.startswith("xi_")]
+
+            if not xi_dims:
+                row_dataset[varname] = var_slices[0]
+            else:
+                xi_dim = xi_dims[0]
+                row_dataset[varname] = xr.concat(
+                    var_slices, dim=xi_dim, combine_attrs="override"
+                )
+
+        rows_joined.append(row_dataset)
+
+    # Join all rows (along eta_*)
+    final_dataset = xr.Dataset()
+    all_vars = set().union(*(ds.data_vars for ds in rows_joined))
+
+    for varname in all_vars:
+        var_slices = [ds[varname] for ds in rows_joined if varname in ds]
+        eta_dims = [dim for dim in var_slices[0].dims if dim.startswith("eta_")]
+
+        if not eta_dims:
+            final_dataset[varname] = var_slices[0]
+        else:
+            eta_dim = eta_dims[0]
+            final_dataset[varname] = xr.concat(
+                var_slices, dim=eta_dim, combine_attrs="override"
+            )
+    # Copy attributes from first dataset
+    final_dataset.attrs = datasets[0].attrs
+
+    return final_dataset
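The two public entry points in this new module cover both join workflows: open_partitions stitches the tiles together in memory, while join_netcdf additionally writes the result, deriving the output name by stripping the partition index from the pattern when output_path is omitted. A short usage sketch, assuming the functions are importable from roms_tools.tiling.join as the file path suggests, and reusing the restart-file pattern from the docstrings:

from roms_tools.tiling.join import join_netcdf, open_partitions

# Stitch all tiles matching the pattern into one in-memory Dataset
ds = open_partitions("roms_rst.20121209133435.*.nc")

# Join and save; with no output_path this drops the partition index,
# writes roms_rst.20121209133435.nc, and returns that Path
saved_path = join_netcdf("roms_rst.20121209133435.*.nc")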
roms_tools/tiling/partition.py
CHANGED
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from numbers import Integral
 from pathlib import Path
 
@@ -296,20 +297,21 @@ def partition(
 
 
 def partition_netcdf(
-    filepath: str | Path,
+    filepath: str | Path | Sequence[str | Path],
     np_eta: int = 1,
     np_xi: int = 1,
+    output_dir: str | Path | None = None,
     include_coarse_dims: bool = True,
-) ->
-    """Partition
+) -> list[Path]:
+    """Partition one or more ROMS NetCDF files into smaller spatial tiles and save them to disk.
 
-    This function divides
+    This function divides each dataset into `np_eta` by `np_xi` tiles.
     Each tile is saved as a separate NetCDF file.
 
     Parameters
     ----------
-    filepath :
-
+    filepath : str | Path | Sequence[str | Path]
+        A path or list of paths to input NetCDF files.
 
     np_eta : int, optional
        The number of partitions along the `eta` direction. Must be a positive integer. Default is 1.
@@ -317,6 +319,10 @@ def partition_netcdf(
     np_xi : int, optional
        The number of partitions along the `xi` direction. Must be a positive integer. Default is 1.
 
+    output_dir : str | Path | None, optional
+        Directory or base path to save partitioned files.
+        If None, files are saved alongside the input file.
+
     include_coarse_dims : bool, optional
        Whether to include coarse grid dimensions (`eta_coarse`, `xi_coarse`) in the partitioning.
        If False, these dimensions will not be split. Relevant if none of the coarse resolution variables are actually used by ROMS.
@@ -324,31 +330,39 @@ def partition_netcdf(
 
     Returns
     -------
-
+    list[Path]
        A list of Path objects for the filenames that were saved.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if isinstance(filepath, str | Path):
+        filepaths = [Path(filepath)]
+    else:
+        filepaths = [Path(fp) for fp in filepath]
+
+    all_saved_filenames = []
+
+    for fp in filepaths:
+        input_file = fp.with_suffix(".nc")
+        ds = xr.open_dataset(input_file, decode_timedelta=False)
+
+        file_numbers, partitioned_datasets = partition(
+            ds, np_eta=np_eta, np_xi=np_xi, include_coarse_dims=include_coarse_dims
+        )
+
+        if output_dir:
+            output_dir = Path(output_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            base_filepath = output_dir / fp.stem
+        else:
+            base_filepath = fp.with_suffix("")
 
-
-
-
-
+        ndigits = len(str(max(file_numbers)))
+        paths_to_partitioned_files = [
+            Path(f"{base_filepath}.{num:0{ndigits}d}") for num in file_numbers
+        ]
+
+        saved = save_datasets(
+            partitioned_datasets, paths_to_partitioned_files, verbose=False
+        )
+        all_saved_filenames.extend(saved)
 
-    return
+    return all_saved_filenames