roms-tools 3.1.1-py3-none-any.whl → 3.1.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- roms_tools/__init__.py +5 -1
- roms_tools/plot.py +56 -9
- roms_tools/regrid.py +6 -1
- roms_tools/setup/boundary_forcing.py +55 -30
- roms_tools/setup/cdr_forcing.py +1 -7
- roms_tools/setup/datasets.py +96 -14
- roms_tools/setup/grid.py +29 -2
- roms_tools/setup/surface_forcing.py +12 -4
- roms_tools/tests/test_setup/test_boundary_forcing.py +57 -0
- roms_tools/tests/test_setup/test_datasets.py +76 -0
- roms_tools/tests/test_setup/test_grid.py +16 -6
- roms_tools/tests/test_setup/test_surface_forcing.py +26 -2
- roms_tools/tests/test_setup/test_validation.py +21 -15
- roms_tools/tests/test_tiling/test_partition.py +45 -0
- roms_tools/tests/test_utils.py +101 -1
- roms_tools/tiling/partition.py +44 -30
- roms_tools/utils.py +426 -131
- {roms_tools-3.1.1.dist-info → roms_tools-3.1.2.dist-info}/METADATA +4 -3
- {roms_tools-3.1.1.dist-info → roms_tools-3.1.2.dist-info}/RECORD +22 -22
- {roms_tools-3.1.1.dist-info → roms_tools-3.1.2.dist-info}/WHEEL +0 -0
- {roms_tools-3.1.1.dist-info → roms_tools-3.1.2.dist-info}/licenses/LICENSE +0 -0
- {roms_tools-3.1.1.dist-info → roms_tools-3.1.2.dist-info}/top_level.txt +0 -0

roms_tools/tests/test_setup/test_datasets.py
CHANGED

@@ -2,6 +2,7 @@ import logging
 from collections import OrderedDict
 from datetime import datetime
 from pathlib import Path
+from unittest import mock

 import numpy as np
 import pytest
@@ -11,11 +12,14 @@ from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import (
     CESMBGCDataset,
     Dataset,
+    ERA5ARCODataset,
     ERA5Correction,
     GLORYSDataset,
+    GLORYSDefaultDataset,
     RiverDataset,
     TPXODataset,
 )
+from roms_tools.setup.surface_forcing import DEFAULT_ERA5_ARCO_PATH


 @pytest.fixture
@@ -437,6 +441,78 @@ def test_era5_correction_choose_subdomain(use_dask):
     assert (data.ds["longitude"] == lons).all()


+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading_without_dask() -> None:
+    """Verify that loading the default ERA5 dataset fails if use_dask is not True."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    with pytest.raises(ValueError):
+        _ = ERA5ARCODataset(
+            filename=DEFAULT_ERA5_ARCO_PATH,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=False,
+        )
+
+
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading() -> None:
+    """Verify the default ERA5 dataset is loaded correctly."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    ds = ERA5ARCODataset(
+        filename=DEFAULT_ERA5_ARCO_PATH,
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
+    assert set(ds.var_names).issuperset(expected_vars)
+
+
+@pytest.mark.use_copernicus
+def test_default_glorys_dataset_loading_dask_not_installed() -> None:
+    """Verify that loading the default GLORYS dataset fails if dask is not available."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    with (
+        pytest.raises(RuntimeError),
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+    ):
+        _ = GLORYSDefaultDataset(
+            filename=GLORYSDefaultDataset.dataset_name,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=True,
+        )
+
+
+@pytest.mark.stream
+@pytest.mark.use_copernicus
+@pytest.mark.use_dask
+def test_default_glorys_dataset_loading() -> None:
+    """Verify the default GLORYS dataset is loaded correctly."""
+    start_time = datetime(2012, 1, 1)
+    end_time = datetime(2013, 1, 1)
+
+    ds = GLORYSDefaultDataset(
+        filename=GLORYSDefaultDataset.dataset_name,
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"temp", "salt", "u", "v", "zeta"}
+    assert set(ds.var_names).issuperset(expected_vars)
+
+
 def test_data_concatenation(use_dask):
     fname = download_test_data("GLORYS_NA_2012.nc")
     data = GLORYSDataset(
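
The new default-dataset tests above are gated behind custom pytest markers (`stream`, `use_dask`, `use_gcsfs`, `use_copernicus`), so they only run when explicitly requested. The conftest wiring is not part of this diff; the sketch below shows one common way such opt-in markers are registered and skipped by default. The `--run-<marker>` option names are illustrative, not roms-tools' actual flags.

```python
# Hypothetical conftest.py sketch for opt-in markers; option names are assumptions.
import pytest

OPTIONAL_MARKERS = ("stream", "use_dask", "use_gcsfs", "use_copernicus")


def pytest_configure(config):
    # Register each marker so pytest does not warn about unknown marks.
    for marker in OPTIONAL_MARKERS:
        config.addinivalue_line("markers", f"{marker}: opt-in test gated by --run-{marker}")


def pytest_addoption(parser):
    # One opt-in flag per marker, e.g. `pytest --run-use_gcsfs`.
    for marker in OPTIONAL_MARKERS:
        parser.addoption(f"--run-{marker}", action="store_true", default=False)


def pytest_collection_modifyitems(config, items):
    # Skip any marked test unless its flag was passed on the command line.
    for item in items:
        for marker in OPTIONAL_MARKERS:
            if marker in item.keywords and not config.getoption(f"--run-{marker}"):
                item.add_marker(pytest.mark.skip(reason=f"needs --run-{marker}"))
```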

roms_tools/tests/test_setup/test_grid.py
CHANGED

@@ -21,6 +21,11 @@ from roms_tools.constants import (
 from roms_tools.download import download_test_data
 from roms_tools.setup.topography import _compute_rfactor

+try:
+    import xesmf  # type: ignore
+except ImportError:
+    xesmf = None
+

 @pytest.fixture()
 def counter_clockwise_rotated_grid():
@@ -177,13 +182,18 @@ def test_successful_initialization_with_topography(grid_fixture, request):
     assert grid is not None


-def test_plot():
-
-
-
+def test_plot(grid_that_straddles_180_degree_meridian):
+    grid_that_straddles_180_degree_meridian.plot(with_dim_names=False)
+    grid_that_straddles_180_degree_meridian.plot(with_dim_names=True)
+
+
+@pytest.mark.skipif(xesmf is None, reason="xesmf required")
+def test_plot_along_lat_lon(grid_that_straddles_180_degree_meridian):
+    grid_that_straddles_180_degree_meridian.plot(lat=61)
+    grid_that_straddles_180_degree_meridian.plot(lon=180)

-
-
+    with pytest.raises(ValueError, match="Specify either `lat` or `lon`, not both"):
+        grid_that_straddles_180_degree_meridian.plot(lat=61, lon=180)


 def test_save(tmp_path):
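
The reworked tests document `Grid.plot`'s keywords: `with_dim_names` for map views, plus mutually exclusive `lat`/`lon` section arguments that require xesmf. An illustrative call pattern, where the grid constructor arguments are made-up values rather than anything taken from this diff:

```python
# Illustrative use of the Grid.plot keywords exercised above; the grid
# parameters are assumptions, and lat/lon sections need xesmf installed.
from roms_tools import Grid

grid = Grid(
    nx=20, ny=20, size_x=1000, size_y=1000,
    center_lon=180, center_lat=61, rot=0,
)
grid.plot(with_dim_names=True)  # map view with dimension labels
grid.plot(lat=61)               # section along latitude 61 (needs xesmf)
grid.plot(lon=180)              # section along longitude 180
# grid.plot(lat=61, lon=180)    # ValueError: specify either lat or lon, not both
```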

roms_tools/tests/test_setup/test_surface_forcing.py
CHANGED

@@ -187,12 +187,12 @@ def _test_successful_initialization(
     if coarse_grid_mode == "always":
         assert sfc_forcing.use_coarse_grid
         assert (
-            "Data will be interpolated onto grid coarsened by factor 2."
+            "Data will be interpolated onto the grid coarsened by factor 2."
             in caplog.text
         )
     elif coarse_grid_mode == "never":
         assert not sfc_forcing.use_coarse_grid
-        assert "Data will be interpolated onto fine grid." in caplog.text
+        assert "Data will be interpolated onto the fine grid." in caplog.text

     assert isinstance(sfc_forcing.ds, xr.Dataset)
     assert "uwnd" in sfc_forcing.ds
@@ -902,7 +902,9 @@ def test_from_yaml_missing_surface_forcing(tmp_path, use_dask):
     yaml_filepath.unlink()


+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
 @pytest.mark.stream
+@pytest.mark.use_dask
 def test_surface_forcing_arco(surface_forcing_arco, tmp_path):
     """One big integration test for cloud-based ERA5 data because the streaming takes a
     long time.
@@ -932,3 +934,25 @@ def test_surface_forcing_arco(surface_forcing_arco, tmp_path):
     yaml_filepath.unlink()
     Path(expected_filepath1).unlink()
     Path(expected_filepath2).unlink()
+
+
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading(small_grid: Grid) -> None:
+    """Verify the default ERA5 dataset is loaded when a path is not provided."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    sf = SurfaceForcing(
+        grid=small_grid,
+        source={"name": "ERA5"},
+        type="physics",
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
+    assert set(sf.ds.data_vars).issuperset(expected_vars)

roms_tools/tests/test_setup/test_validation.py
CHANGED

@@ -1,15 +1,11 @@
-import os
 import shutil
+from collections.abc import Callable
+from pathlib import Path

 import pytest
 import xarray as xr


-def _get_fname(name):
-    dirname = os.path.dirname(__file__)
-    return os.path.join(dirname, "test_data", f"{name}.zarr")
-
-
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
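
The removed `_get_fname` helper is replaced by a `get_test_data_path` fixture, injected below with type `Callable[[str], Path]`. Its definition lives in a conftest.py outside this diff; a minimal sketch consistent with the old helper's path logic might look like this, where only the fixture name and signature are taken from the diff and the body is an assumption:

```python
# Hypothetical conftest.py sketch: a factory fixture standing in for the
# removed _get_fname helper; the path logic mirrors the old helper.
from collections.abc import Callable
from pathlib import Path

import pytest


@pytest.fixture
def get_test_data_path() -> Callable[[str], Path]:
    def _path_for(name: str) -> Path:
        # Old behavior: <tests dir>/test_data/<name>.zarr
        return Path(__file__).parent / "test_data" / f"{name}.zarr"

    return _path_for
```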
@@ -34,7 +30,11 @@ def _get_fname(name):
 # this test will not be run by default
 # to run it and overwrite the test data, invoke pytest as follows
 # pytest --overwrite=tidal_forcing --overwrite=boundary_forcing
-def test_save_results(forcing_fixture, request):
+def test_save_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
     overwrite = request.config.getoption("--overwrite")

     # Skip the test if the fixture isn't marked for overwriting, unless 'all' is specified
@@ -42,10 +42,10 @@ def test_save_results(forcing_fixture, request):
         pytest.skip(f"Skipping overwrite for {forcing_fixture}")

     forcing = request.getfixturevalue(forcing_fixture)
-    fname =
+    fname = get_test_data_path(forcing_fixture)

     # Check if the Zarr directory exists and delete it if it does
-    if
+    if fname.exists():
         shutil.rmtree(fname)

     forcing.ds.to_zarr(fname)
@@ -72,8 +72,12 @@ def test_save_results(forcing_fixture, request):
         "river_forcing_no_climatology",
     ],
 )
-def test_check_results(forcing_fixture, request):
-
+def test_check_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    fname = get_test_data_path(forcing_fixture)
     expected_forcing_ds = xr.open_zarr(fname, decode_timedelta=False)
     forcing = request.getfixturevalue(forcing_fixture)

@@ -83,6 +87,7 @@ def test_check_results(forcing_fixture, request):
     )


+@pytest.mark.use_dask
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
@@ -97,11 +102,12 @@ def test_check_results(forcing_fixture, request):
         "bgc_boundary_forcing_from_climatology",
     ],
 )
-def test_dask_vs_no_dask(
+def test_dask_vs_no_dask(
+    forcing_fixture: str,
+    request: pytest.FixtureRequest,
+    tmp_path: Path,
+) -> None:
     """Test comparing the forcing created with and without Dask on same platform."""
-    if not use_dask:
-        pytest.skip("Test only runs when --use_dask is specified")
-
     # Get the forcing with Dask
     forcing_with_dask = request.getfixturevalue(forcing_fixture)


roms_tools/tests/test_tiling/test_partition.py
CHANGED

@@ -297,3 +297,48 @@ class TestPartitionNetcdf:
         for expected_filepath in expected_filepath_list:
            assert expected_filepath.exists()
            expected_filepath.unlink()
+
+    def test_partition_netcdf_with_output_dir(self, grid, tmp_path):
+        # Save the input file
+        input_file = tmp_path / "input_grid.nc"
+        grid.save(input_file)
+
+        # Create a custom output directory
+        output_dir = tmp_path / "custom_output"
+        output_dir.mkdir()
+
+        saved_filenames = partition_netcdf(
+            input_file, np_eta=3, np_xi=5, output_dir=output_dir
+        )
+
+        base_name = input_file.stem  # "input_grid"
+        expected_filenames = [output_dir / f"{base_name}.{i:02d}.nc" for i in range(15)]
+
+        assert saved_filenames == expected_filenames
+
+        for f in expected_filenames:
+            assert f.exists()
+            f.unlink()
+
+    def test_partition_netcdf_multiple_files(self, grid, tmp_path):
+        # Create two test input files
+        file1 = tmp_path / "grid1.nc"
+        file2 = tmp_path / "grid2.nc"
+        grid.save(file1)
+        grid.save(file2)
+
+        # Run partitioning with 3x5 tiles on both files
+        saved_filenames = partition_netcdf([file1, file2], np_eta=3, np_xi=5)
+
+        # Expect 15 tiles per file → 30 total output files
+        expected_filepaths = []
+        for file in [file1, file2]:
+            base = file.with_suffix("")
+            expected_filepaths += [Path(f"{base}.{i:02d}.nc") for i in range(15)]
+
+        assert len(saved_filenames) == 30
+        assert saved_filenames == expected_filepaths
+
+        for path in expected_filepaths:
+            assert path.exists()
+            path.unlink()
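
These tests pin down the naming scheme for partitioned output: with `np_eta=3` and `np_xi=5` there are 15 tiles numbered 0 through 14, and the suffix width follows the largest tile number (two digits here), with the `.nc` extension added when the files are written. A small worked sketch of that rule, using an illustrative file name:

```python
# Worked example of the zero-padded tile suffixes asserted above.
# "input_grid" is an illustrative base name, not from the diff.
np_eta, np_xi = 3, 5
file_numbers = list(range(np_eta * np_xi))  # [0, 1, ..., 14]
ndigits = len(str(max(file_numbers)))       # max is 14 -> 2 digits
names = [f"input_grid.{num:0{ndigits}d}.nc" for num in file_numbers]
print(names[0], names[-1])  # input_grid.00.nc input_grid.14.nc
```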
roms_tools/tests/test_utils.py
CHANGED

@@ -1,7 +1,18 @@
+from collections.abc import Callable
+from pathlib import Path
+from unittest import mock
+
 import numpy as np
 import pytest
+import xarray as xr

-from roms_tools.utils import
+from roms_tools.utils import (
+    _generate_focused_coordinate_range,
+    _has_copernicus,
+    _has_dask,
+    _has_gcsfs,
+    _load_data,
+)


 @pytest.mark.parametrize(
@@ -19,3 +30,92 @@ def test_coordinate_range_monotonicity(min_val, max_val, center, sc, N):
     )
     assert np.all(np.diff(faces) > 0), "faces is not strictly increasing"
     assert np.all(np.diff(centers) > 0), "centers is not strictly increasing"
+
+
+def test_has_dask() -> None:
+    """Verify that dask existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_dask()
+
+
+def test_has_dask_error_when_missing() -> None:
+    """Verify that dask existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_dask()
+
+
+def test_has_gcfs() -> None:
+    """Verify that GCFS existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_gcsfs()
+
+
+def test_has_gcfs_error_when_missing() -> None:
+    """Verify that GCFS existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_gcsfs()
+
+
+def test_has_copernicus() -> None:
+    """Verify that copernicus existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_copernicus()
+
+
+def test_has_copernicus_error_when_missing() -> None:
+    """Verify that copernicus existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_copernicus()
+
+
+def test_load_data_dask_not_found() -> None:
+    """Verify that load data raises an exception when dask is requested and missing."""
+    with (
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        pytest.raises(RuntimeError),
+    ):
+        _load_data("foo.zarr", {"a": "a"}, use_dask=True)
+
+
+def test_load_data_open_zarr_without_dask() -> None:
+    """Verify that load data raises an exception when zarr is requested without dask."""
+    with (
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        pytest.raises(ValueError),
+    ):
+        # read_zarr should require use_dask to be True
+        _load_data("foo.zarr", {"a": ""}, use_dask=False, read_zarr=True)
+
+
+@pytest.mark.parametrize(
+    ("dataset_name", "expected_dim"),
+    [
+        ("surface_forcing", "time"),
+        ("bgc_surface_forcing", "time"),
+        ("tidal_forcing", "eta_rho"),
+        ("coarse_surface_forcing", "eta_rho"),
+    ],
+)
+def test_load_data_open_dataset(
+    dataset_name: str,
+    expected_dim: str,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    """Verify that a zarr file is correctly loaded when not using Dask.
+
+    This must use xr.open_dataset.
+    """
+    ds_path = get_test_data_path(dataset_name)
+
+    with mock.patch(
+        "roms_tools.utils.xr.open_dataset",
+        wraps=xr.open_dataset,
+    ) as fn_od:
+        ds = _load_data(
+            ds_path,
+            {"latitude": "latitude"},
+            use_dask=False,
+        )
+    assert fn_od.called
+
+    assert expected_dim in ds.dims
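
The `_has_*` tests patch `roms_tools.utils.find_spec`, which implies these helpers probe `importlib.util.find_spec` for optional dependencies. Their actual bodies are not shown in this diff; a minimal sketch consistent with the tests follows, where the probed module names are assumptions:

```python
# Minimal sketch of the _has_* helpers implied by the tests above: each
# reports whether an optional dependency is importable. The module names
# (e.g. "copernicusmarine") are assumptions, not taken from this diff.
from importlib.util import find_spec


def _has_dask() -> bool:
    # find_spec returns None when the package cannot be found
    return find_spec("dask") is not None


def _has_gcsfs() -> bool:
    return find_spec("gcsfs") is not None


def _has_copernicus() -> bool:
    return find_spec("copernicusmarine") is not None
```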
roms_tools/tiling/partition.py
CHANGED

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from numbers import Integral
 from pathlib import Path

@@ -296,20 +297,21 @@ def partition(


 def partition_netcdf(
-    filepath: str | Path,
+    filepath: str | Path | Sequence[str | Path],
     np_eta: int = 1,
     np_xi: int = 1,
+    output_dir: str | Path | None = None,
     include_coarse_dims: bool = True,
-) ->
-    """Partition
+) -> list[Path]:
+    """Partition one or more ROMS NetCDF files into smaller spatial tiles and save them to disk.

-    This function divides
+    This function divides each dataset into `np_eta` by `np_xi` tiles.
     Each tile is saved as a separate NetCDF file.

     Parameters
     ----------
-    filepath :
-
+    filepath : str | Path | Sequence[str | Path]
+        A path or list of paths to input NetCDF files.

     np_eta : int, optional
         The number of partitions along the `eta` direction. Must be a positive integer. Default is 1.
@@ -317,6 +319,10 @@ def partition_netcdf(
     np_xi : int, optional
         The number of partitions along the `xi` direction. Must be a positive integer. Default is 1.

+    output_dir : str | Path | None, optional
+        Directory or base path to save partitioned files.
+        If None, files are saved alongside the input file.
+
     include_coarse_dims : bool, optional
         Whether to include coarse grid dimensions (`eta_coarse`, `xi_coarse`) in the partitioning.
         If False, these dimensions will not be split. Relevant if none of the coarse resolution variables are actually used by ROMS.
@@ -324,31 +330,39 @@ def partition_netcdf(

     Returns
     -------
-
+    list[Path]
         A list of Path objects for the filenames that were saved.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if isinstance(filepath, str | Path):
+        filepaths = [Path(filepath)]
+    else:
+        filepaths = [Path(fp) for fp in filepath]
+
+    all_saved_filenames = []
+
+    for fp in filepaths:
+        input_file = fp.with_suffix(".nc")
+        ds = xr.open_dataset(input_file, decode_timedelta=False)
+
+        file_numbers, partitioned_datasets = partition(
+            ds, np_eta=np_eta, np_xi=np_xi, include_coarse_dims=include_coarse_dims
+        )
+
+        if output_dir:
+            output_dir = Path(output_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            base_filepath = output_dir / fp.stem
+        else:
+            base_filepath = fp.with_suffix("")

-
-
-
-
+        ndigits = len(str(max(file_numbers)))
+        paths_to_partitioned_files = [
+            Path(f"{base_filepath}.{num:0{ndigits}d}") for num in file_numbers
+        ]
+
+        saved = save_datasets(
+            partitioned_datasets, paths_to_partitioned_files, verbose=False
+        )
+        all_saved_filenames.extend(saved)

-    return
+    return all_saved_filenames
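
Taken together with the tests above, the extended `partition_netcdf` API can be exercised like this; the grid file names are illustrative:

```python
# Usage sketch for the extended partition_netcdf signature, based on the
# docstring and tests in this diff; file names here are made up.
from pathlib import Path

from roms_tools.tiling.partition import partition_netcdf

# One file: 3 x 5 = 15 tiles written next to it as grid.00.nc ... grid.14.nc
saved = partition_netcdf("grid.nc", np_eta=3, np_xi=5)

# Several files at once, tiles collected into one output directory
saved = partition_netcdf(
    ["grid1.nc", "grid2.nc"],
    np_eta=3,
    np_xi=5,
    output_dir=Path("tiles"),  # created if it does not exist
)
print(len(saved))  # 30 Path objects, 15 per input file
```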