roms-tools 3.1.0__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,7 @@ import xarray as xr
 
 from conftest import calculate_file_hash
 from roms_tools import Grid, RiverForcing
+from roms_tools.constants import MAX_DISTINCT_COLORS
 
 
 @pytest.fixture
@@ -57,6 +58,29 @@ def river_forcing_for_grid_that_straddles_dateline():
     )
 
 
+@pytest.fixture
+def river_forcing_for_gulf_of_mexico():
+    """Fixture for creating a RiverForcing object for the Gulf of Mexico with 45 rivers."""
+    grid = Grid(
+        nx=20,
+        ny=15,
+        size_x=2000,
+        size_y=1500,
+        center_lon=-89,
+        center_lat=24,
+        rot=0,
+        N=3,
+    )
+    start_time = datetime(2012, 1, 1)
+    end_time = datetime(2012, 1, 31)
+
+    return RiverForcing(
+        grid=grid,
+        start_time=start_time,
+        end_time=end_time,
+    )
+
+
 @pytest.fixture
 def single_cell_indices():
     # These are the indices that the `river_forcing` fixture generates automatically.
@@ -247,13 +271,46 @@ class TestRiverForcingGeneral:
         )
 
     def test_river_forcing_plot(self, river_forcing_with_bgc):
-        """Test plot method."""
+        """Test plot methods with and without specifying river_names."""
+        river_names = list(river_forcing_with_bgc.indices.keys())[0:2]
+
+        # Test plot_locations
         river_forcing_with_bgc.plot_locations()
-        river_forcing_with_bgc.plot("river_volume")
-        river_forcing_with_bgc.plot("river_temp")
-        river_forcing_with_bgc.plot("river_salt")
-        river_forcing_with_bgc.plot("river_ALK")
-        river_forcing_with_bgc.plot("river_PO4")
+        river_forcing_with_bgc.plot_locations(river_names=river_names)
+
+        # Fields to test
+        variables = [
+            "river_volume",
+            "river_temp",
+            "river_salt",
+            "river_ALK",
+            "river_PO4",
+        ]
+
+        for var in variables:
+            river_forcing_with_bgc.plot(var)
+            river_forcing_with_bgc.plot(var, river_names=river_names)
+
+    def test_plot_max_releases(self, caplog, river_forcing_for_gulf_of_mexico):
+        river_names = list(river_forcing_for_gulf_of_mexico.indices.keys())
+
+        caplog.clear()
+        with caplog.at_level("WARNING"):
+            river_forcing_for_gulf_of_mexico.plot_locations()
+        assert any(
+            f"Only the first {MAX_DISTINCT_COLORS} rivers will be plotted" in message
+            for message in caplog.messages
+        )
+
+        with caplog.at_level("WARNING"):
+            river_forcing_for_gulf_of_mexico.plot(
+                "river_volume", river_names=river_names
+            )
+
+        assert any(
+            f"Only the first {MAX_DISTINCT_COLORS} rivers will be plotted" in message
+            for message in caplog.messages
+        )
 
     @pytest.mark.parametrize(
         "river_forcing_fixture",
@@ -187,12 +187,12 @@ def _test_successful_initialization(
     if coarse_grid_mode == "always":
         assert sfc_forcing.use_coarse_grid
         assert (
-            "Data will be interpolated onto grid coarsened by factor 2."
+            "Data will be interpolated onto the grid coarsened by factor 2."
             in caplog.text
         )
     elif coarse_grid_mode == "never":
         assert not sfc_forcing.use_coarse_grid
-        assert "Data will be interpolated onto fine grid." in caplog.text
+        assert "Data will be interpolated onto the fine grid." in caplog.text
 
     assert isinstance(sfc_forcing.ds, xr.Dataset)
     assert "uwnd" in sfc_forcing.ds
@@ -902,7 +902,9 @@ def test_from_yaml_missing_surface_forcing(tmp_path, use_dask):
     yaml_filepath.unlink()
 
 
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
 @pytest.mark.stream
+@pytest.mark.use_dask
 def test_surface_forcing_arco(surface_forcing_arco, tmp_path):
     """One big integration test for cloud-based ERA5 data because the streaming takes a
     long time.
@@ -932,3 +934,25 @@ def test_surface_forcing_arco(surface_forcing_arco, tmp_path):
     yaml_filepath.unlink()
     Path(expected_filepath1).unlink()
     Path(expected_filepath2).unlink()
+
+
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading(small_grid: Grid) -> None:
+    """Verify the default ERA5 dataset is loaded when a path is not provided."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    sf = SurfaceForcing(
+        grid=small_grid,
+        source={"name": "ERA5"},
+        type="physics",
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
+    assert set(sf.ds.data_vars).issuperset(expected_vars)
@@ -1,3 +1,4 @@
+import logging
 from datetime import datetime
 from pathlib import Path
 
@@ -7,9 +8,7 @@ import xarray as xr
 from roms_tools import BoundaryForcing, Grid
 from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import ERA5Correction
-from roms_tools.setup.utils import (
-    interpolate_from_climatology,
-)
+from roms_tools.setup.utils import interpolate_from_climatology, validate_names
 
 
 def test_interpolate_from_climatology(use_dask):
@@ -71,3 +70,53 @@ def test_roundtrip_yaml(
 
     filepath = Path(filepath)
     filepath.unlink()
+
+
+# test validate_names function
+
+VALID_NAMES = ["a", "b", "c", "d"]
+SENTINEL = "ALL"
+MAX_TO_PLOT = 3
+
+
+def test_valid_names_no_truncation():
+    names = ["a", "b"]
+    result = validate_names(names, VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
+    assert result == names
+
+
+def test_valid_names_with_truncation(caplog):
+    names = ["a", "b", "c", "d"]
+    with caplog.at_level(logging.WARNING):
+        result = validate_names(
+            names, VALID_NAMES, SENTINEL, max_to_plot=2, label="test"
+        )
+    assert result == ["a", "b"]
+    assert "Only the first 2 tests will be plotted" in caplog.text
+
+
+def test_include_all_sentinel():
+    result = validate_names(SENTINEL, VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
+    assert result == VALID_NAMES[:MAX_TO_PLOT]
+
+
+def test_invalid_name_raises():
+    with pytest.raises(ValueError, match="Invalid tests: z"):
+        validate_names(["a", "z"], VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
+
+
+def test_non_list_input_raises():
+    with pytest.raises(ValueError, match="`test_names` should be a list of strings."):
+        validate_names("a", VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
+
+
+def test_non_string_elements_in_list_raises():
+    with pytest.raises(
+        ValueError, match="All elements in `test_names` must be strings."
+    ):
+        validate_names(["a", 2], VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="test")
+
+
+def test_custom_label_in_errors():
+    with pytest.raises(ValueError, match="Invalid foozs: z"):
+        validate_names(["z"], VALID_NAMES, SENTINEL, MAX_TO_PLOT, label="fooz")
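
The new tests above pin down the contract of `validate_names`. A minimal sketch consistent with those assertions, offered as an inference from the tests rather than the package's actual implementation in `roms_tools.setup.utils`:

    import logging

    logger = logging.getLogger(__name__)


    def validate_names(names, valid_names, sentinel, max_to_plot, label="name"):
        """Return the names to plot, truncated to max_to_plot; raise on invalid input."""
        if names == sentinel:
            names = list(valid_names)
        elif not isinstance(names, list):
            raise ValueError(f"`{label}_names` should be a list of strings.")
        elif not all(isinstance(n, str) for n in names):
            raise ValueError(f"All elements in `{label}_names` must be strings.")

        invalid = [n for n in names if n not in valid_names]
        if invalid:
            raise ValueError(f"Invalid {label}s: {', '.join(invalid)}")

        if len(names) > max_to_plot:
            logger.warning(f"Only the first {max_to_plot} {label}s will be plotted.")
            names = names[:max_to_plot]
        return names
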
@@ -1,15 +1,11 @@
-import os
 import shutil
+from collections.abc import Callable
+from pathlib import Path
 
 import pytest
 import xarray as xr
 
 
-def _get_fname(name):
-    dirname = os.path.dirname(__file__)
-    return os.path.join(dirname, "test_data", f"{name}.zarr")
-
-
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
@@ -34,7 +30,11 @@ def _get_fname(name):
 # this test will not be run by default
 # to run it and overwrite the test data, invoke pytest as follows
 # pytest --overwrite=tidal_forcing --overwrite=boundary_forcing
-def test_save_results(forcing_fixture, request):
+def test_save_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
     overwrite = request.config.getoption("--overwrite")
 
     # Skip the test if the fixture isn't marked for overwriting, unless 'all' is specified
@@ -42,10 +42,10 @@ def test_save_results(forcing_fixture, request):
         pytest.skip(f"Skipping overwrite for {forcing_fixture}")
 
     forcing = request.getfixturevalue(forcing_fixture)
-    fname = _get_fname(forcing_fixture)
+    fname = get_test_data_path(forcing_fixture)
 
     # Check if the Zarr directory exists and delete it if it does
-    if os.path.exists(fname):
+    if fname.exists():
         shutil.rmtree(fname)
 
     forcing.ds.to_zarr(fname)
@@ -72,8 +72,12 @@ def test_save_results(forcing_fixture, request):
         "river_forcing_no_climatology",
     ],
 )
-def test_check_results(forcing_fixture, request):
-    fname = _get_fname(forcing_fixture)
+def test_check_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    fname = get_test_data_path(forcing_fixture)
     expected_forcing_ds = xr.open_zarr(fname, decode_timedelta=False)
     forcing = request.getfixturevalue(forcing_fixture)
 
@@ -83,6 +87,7 @@ def test_check_results(forcing_fixture, request):
     )
 
 
+@pytest.mark.use_dask
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
@@ -97,11 +102,12 @@ def test_check_results(forcing_fixture, request):
         "bgc_boundary_forcing_from_climatology",
     ],
 )
-def test_dask_vs_no_dask(forcing_fixture, request, tmp_path, use_dask):
+def test_dask_vs_no_dask(
+    forcing_fixture: str,
+    request: pytest.FixtureRequest,
+    tmp_path: Path,
+) -> None:
     """Test comparing the forcing created with and without Dask on same platform."""
-    if not use_dask:
-        pytest.skip("Test only runs when --use_dask is specified")
-
     # Get the forcing with Dask
     forcing_with_dask = request.getfixturevalue(forcing_fixture)
 
@@ -297,3 +297,48 @@ class TestPartitionNetcdf:
         for expected_filepath in expected_filepath_list:
             assert expected_filepath.exists()
             expected_filepath.unlink()
+
+    def test_partition_netcdf_with_output_dir(self, grid, tmp_path):
+        # Save the input file
+        input_file = tmp_path / "input_grid.nc"
+        grid.save(input_file)
+
+        # Create a custom output directory
+        output_dir = tmp_path / "custom_output"
+        output_dir.mkdir()
+
+        saved_filenames = partition_netcdf(
+            input_file, np_eta=3, np_xi=5, output_dir=output_dir
+        )
+
+        base_name = input_file.stem  # "input_grid"
+        expected_filenames = [output_dir / f"{base_name}.{i:02d}.nc" for i in range(15)]
+
+        assert saved_filenames == expected_filenames
+
+        for f in expected_filenames:
+            assert f.exists()
+            f.unlink()
+
+    def test_partition_netcdf_multiple_files(self, grid, tmp_path):
+        # Create two test input files
+        file1 = tmp_path / "grid1.nc"
+        file2 = tmp_path / "grid2.nc"
+        grid.save(file1)
+        grid.save(file2)
+
+        # Run partitioning with 3x5 tiles on both files
+        saved_filenames = partition_netcdf([file1, file2], np_eta=3, np_xi=5)
+
+        # Expect 15 tiles per file → 30 total output files
+        expected_filepaths = []
+        for file in [file1, file2]:
+            base = file.with_suffix("")
+            expected_filepaths += [Path(f"{base}.{i:02d}.nc") for i in range(15)]
+
+        assert len(saved_filenames) == 30
+        assert saved_filenames == expected_filepaths
+
+        for path in expected_filepaths:
+            assert path.exists()
+            path.unlink()
@@ -1,7 +1,18 @@
+from collections.abc import Callable
+from pathlib import Path
+from unittest import mock
+
 import numpy as np
 import pytest
+import xarray as xr
 
-from roms_tools.utils import _generate_focused_coordinate_range
+from roms_tools.utils import (
+    _generate_focused_coordinate_range,
+    _has_copernicus,
+    _has_dask,
+    _has_gcsfs,
+    _load_data,
+)
 
 
 @pytest.mark.parametrize(
@@ -19,3 +30,92 @@ def test_coordinate_range_monotonicity(min_val, max_val, center, sc, N):
     )
     assert np.all(np.diff(faces) > 0), "faces is not strictly increasing"
     assert np.all(np.diff(centers) > 0), "centers is not strictly increasing"
+
+
+def test_has_dask() -> None:
+    """Verify that dask existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_dask()
+
+
+def test_has_dask_error_when_missing() -> None:
+    """Verify that dask existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_dask()
+
+
+def test_has_gcsfs() -> None:
+    """Verify that gcsfs existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_gcsfs()
+
+
+def test_has_gcsfs_error_when_missing() -> None:
+    """Verify that gcsfs existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_gcsfs()
+
+
+def test_has_copernicus() -> None:
+    """Verify that copernicus existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_copernicus()
+
+
+def test_has_copernicus_error_when_missing() -> None:
+    """Verify that copernicus existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_copernicus()
+
+
+def test_load_data_dask_not_found() -> None:
+    """Verify that _load_data raises an exception when dask is requested but missing."""
+    with (
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        pytest.raises(RuntimeError),
+    ):
+        _load_data("foo.zarr", {"a": "a"}, use_dask=True)
+
+
+def test_load_data_open_zarr_without_dask() -> None:
+    """Verify that _load_data raises an exception when zarr is requested without dask."""
+    with (
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        pytest.raises(ValueError),
+    ):
+        # read_zarr should require use_dask to be True
+        _load_data("foo.zarr", {"a": ""}, use_dask=False, read_zarr=True)
+
+
+@pytest.mark.parametrize(
+    ("dataset_name", "expected_dim"),
+    [
+        ("surface_forcing", "time"),
+        ("bgc_surface_forcing", "time"),
+        ("tidal_forcing", "eta_rho"),
+        ("coarse_surface_forcing", "eta_rho"),
+    ],
+)
+def test_load_data_open_dataset(
+    dataset_name: str,
+    expected_dim: str,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    """Verify that a zarr file is correctly loaded when not using Dask.
+
+    This must use xr.open_dataset.
+    """
+    ds_path = get_test_data_path(dataset_name)
+
+    with mock.patch(
+        "roms_tools.utils.xr.open_dataset",
+        wraps=xr.open_dataset,
+    ) as fn_od:
+        ds = _load_data(
+            ds_path,
+            {"latitude": "latitude"},
+            use_dask=False,
+        )
+        assert fn_od.called
+
+    assert expected_dim in ds.dims
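
These tests patch `roms_tools.utils.find_spec`, which suggests the optional-dependency probes are thin wrappers around `importlib.util.find_spec`. A sketch of that pattern, assumed rather than copied from the package:

    from importlib.util import find_spec


    def _has_dask() -> bool:
        # True when the dask package is importable in the current environment
        return find_spec("dask") is not None
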
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from numbers import Integral
 from pathlib import Path
 
@@ -296,20 +297,21 @@ def partition(
 
 
 def partition_netcdf(
-    filepath: str | Path,
+    filepath: str | Path | Sequence[str | Path],
     np_eta: int = 1,
     np_xi: int = 1,
+    output_dir: str | Path | None = None,
     include_coarse_dims: bool = True,
-) -> None:
-    """Partition a ROMS NetCDF file into smaller spatial tiles and save them to disk.
+) -> list[Path]:
+    """Partition one or more ROMS NetCDF files into smaller spatial tiles and save them to disk.
 
-    This function divides the dataset in the specified NetCDF file into `np_eta` by `np_xi` tiles.
+    This function divides each dataset into `np_eta` by `np_xi` tiles.
     Each tile is saved as a separate NetCDF file.
 
     Parameters
     ----------
-    filepath : Union[str, Path]
-        The path to the input NetCDF file.
+    filepath : str | Path | Sequence[str | Path]
+        A path or list of paths to input NetCDF files.
 
     np_eta : int, optional
         The number of partitions along the `eta` direction. Must be a positive integer. Default is 1.
@@ -317,6 +319,10 @@ def partition_netcdf(
     np_xi : int, optional
         The number of partitions along the `xi` direction. Must be a positive integer. Default is 1.
 
+    output_dir : str | Path | None, optional
+        Directory or base path to save partitioned files.
+        If None, files are saved alongside each input file.
+
     include_coarse_dims : bool, optional
         Whether to include coarse grid dimensions (`eta_coarse`, `xi_coarse`) in the partitioning.
         If False, these dimensions will not be split. Relevant if none of the coarse resolution variables are actually used by ROMS.
@@ -324,31 +330,39 @@ def partition_netcdf(
 
     Returns
     -------
-    List[Path]
+    list[Path]
         A list of Path objects for the filenames that were saved.
     """
-    # Ensure filepath is a Path object
-    filepath = Path(filepath)
-
-    # Open the dataset
-    ds = xr.open_dataset(filepath.with_suffix(".nc"), decode_timedelta=False)
-
-    # Partition the dataset
-    file_numbers, partitioned_datasets = partition(
-        ds, np_eta=np_eta, np_xi=np_xi, include_coarse_dims=include_coarse_dims
-    )
-
-    # Generate paths to the partitioned files
-    base_filepath = filepath.with_suffix("")
-    ndigits = len(str(max(np.array(file_numbers))))
-    paths_to_partitioned_files = [
-        Path(f"{base_filepath}.{file_number:0{ndigits}d}")
-        for file_number in file_numbers
-    ]
+    if isinstance(filepath, str | Path):
+        filepaths = [Path(filepath)]
+    else:
+        filepaths = [Path(fp) for fp in filepath]
+
+    all_saved_filenames = []
+
+    for fp in filepaths:
+        input_file = fp.with_suffix(".nc")
+        ds = xr.open_dataset(input_file, decode_timedelta=False)
+
+        file_numbers, partitioned_datasets = partition(
+            ds, np_eta=np_eta, np_xi=np_xi, include_coarse_dims=include_coarse_dims
+        )
+
+        if output_dir:
+            output_dir = Path(output_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            base_filepath = output_dir / fp.stem
+        else:
+            base_filepath = fp.with_suffix("")
 
-    # Save the partitioned datasets to files
-    saved_filenames = save_datasets(
-        partitioned_datasets, paths_to_partitioned_files, verbose=False
-    )
+        ndigits = len(str(max(file_numbers)))
+        paths_to_partitioned_files = [
+            Path(f"{base_filepath}.{num:0{ndigits}d}") for num in file_numbers
+        ]
+
+        saved = save_datasets(
+            partitioned_datasets, paths_to_partitioned_files, verbose=False
+        )
+        all_saved_filenames.extend(saved)
 
-    return saved_filenames
+    return all_saved_filenames
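
Taken together with the new tests above, `partition_netcdf` now accepts a single path or a list of paths, an optional `output_dir`, and returns the saved tile paths. A usage sketch; the import path is assumed, so adjust it to wherever the package exposes `partition_netcdf`:

    from roms_tools.utils import partition_netcdf  # import path assumed

    # Partition two grid files into 3 x 5 = 15 tiles each, writing all tiles
    # into a dedicated directory instead of next to the inputs.
    saved = partition_netcdf(
        ["grid1.nc", "grid2.nc"],
        np_eta=3,
        np_xi=5,
        output_dir="partitioned",
    )
    print(len(saved))  # 30 paths, e.g. partitioned/grid1.00.nc ... partitioned/grid2.14.nc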