roms-tools 3.1.1__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ import logging
 from collections import OrderedDict
 from datetime import datetime
 from pathlib import Path
+from unittest import mock
 
 import numpy as np
 import pytest
@@ -11,11 +12,14 @@ from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import (
     CESMBGCDataset,
     Dataset,
+    ERA5ARCODataset,
     ERA5Correction,
     GLORYSDataset,
+    GLORYSDefaultDataset,
     RiverDataset,
     TPXODataset,
 )
+from roms_tools.setup.surface_forcing import DEFAULT_ERA5_ARCO_PATH
 
 
 @pytest.fixture
@@ -437,6 +441,78 @@ def test_era5_correction_choose_subdomain(use_dask):
     assert (data.ds["longitude"] == lons).all()
 
 
+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading_without_dask() -> None:
+    """Verify that loading the default ERA5 dataset fails if use_dask is not True."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    with pytest.raises(ValueError):
+        _ = ERA5ARCODataset(
+            filename=DEFAULT_ERA5_ARCO_PATH,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=False,
+        )
+
+
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading() -> None:
+    """Verify the default ERA5 dataset is loaded correctly."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    ds = ERA5ARCODataset(
+        filename=DEFAULT_ERA5_ARCO_PATH,
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
+    assert set(ds.var_names).issuperset(expected_vars)
+
+
+@pytest.mark.use_copernicus
+def test_default_glorys_dataset_loading_dask_not_installed() -> None:
+    """Verify that loading the default GLORYS dataset fails if dask is not available."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    with (
+        pytest.raises(RuntimeError),
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+    ):
+        _ = GLORYSDefaultDataset(
+            filename=GLORYSDefaultDataset.dataset_name,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=True,
+        )
+
+
+@pytest.mark.stream
+@pytest.mark.use_copernicus
+@pytest.mark.use_dask
+def test_default_glorys_dataset_loading() -> None:
+    """Verify the default GLORYS dataset is loaded correctly."""
+    start_time = datetime(2012, 1, 1)
+    end_time = datetime(2013, 1, 1)
+
+    ds = GLORYSDefaultDataset(
+        filename=GLORYSDefaultDataset.dataset_name,
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"temp", "salt", "u", "v", "zeta"}
+    assert set(ds.var_names).issuperset(expected_vars)
+
+
 def test_data_concatenation(use_dask):
     fname = download_test_data("GLORYS_NA_2012.nc")
     data = GLORYSDataset(
@@ -21,6 +21,11 @@ from roms_tools.constants import (
 from roms_tools.download import download_test_data
 from roms_tools.setup.topography import _compute_rfactor
 
+try:
+    import xesmf  # type: ignore
+except ImportError:
+    xesmf = None
+
 
 @pytest.fixture()
 def counter_clockwise_rotated_grid():
@@ -177,13 +182,18 @@ def test_successful_initialization_with_topography(grid_fixture, request):
     assert grid is not None
 
 
-def test_plot():
-    grid = Grid(
-        nx=20, ny=20, size_x=100, size_y=100, center_lon=-20, center_lat=0, rot=0
-    )
+def test_plot(grid_that_straddles_180_degree_meridian):
+    grid_that_straddles_180_degree_meridian.plot(with_dim_names=False)
+    grid_that_straddles_180_degree_meridian.plot(with_dim_names=True)
+
+
+@pytest.mark.skipif(xesmf is None, reason="xesmf required")
+def test_plot_along_lat_lon(grid_that_straddles_180_degree_meridian):
+    grid_that_straddles_180_degree_meridian.plot(lat=61)
+    grid_that_straddles_180_degree_meridian.plot(lon=180)
 
-    grid.plot(with_dim_names=False)
-    grid.plot(with_dim_names=True)
+    with pytest.raises(ValueError, match="Specify either `lat` or `lon`, not both"):
+        grid_that_straddles_180_degree_meridian.plot(lat=61, lon=180)
 
 
 def test_save(tmp_path):
@@ -187,12 +187,12 @@ def _test_successful_initialization(
     if coarse_grid_mode == "always":
         assert sfc_forcing.use_coarse_grid
         assert (
-            "Data will be interpolated onto grid coarsened by factor 2."
+            "Data will be interpolated onto the grid coarsened by factor 2."
            in caplog.text
        )
    elif coarse_grid_mode == "never":
        assert not sfc_forcing.use_coarse_grid
-        assert "Data will be interpolated onto fine grid." in caplog.text
+        assert "Data will be interpolated onto the fine grid." in caplog.text
 
     assert isinstance(sfc_forcing.ds, xr.Dataset)
     assert "uwnd" in sfc_forcing.ds
@@ -902,7 +902,9 @@ def test_from_yaml_missing_surface_forcing(tmp_path, use_dask):
     yaml_filepath.unlink()
 
 
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
 @pytest.mark.stream
+@pytest.mark.use_dask
 def test_surface_forcing_arco(surface_forcing_arco, tmp_path):
     """One big integration test for cloud-based ERA5 data because the streaming takes a
     long time.
@@ -932,3 +934,25 @@ def test_surface_forcing_arco(surface_forcing_arco, tmp_path):
     yaml_filepath.unlink()
     Path(expected_filepath1).unlink()
     Path(expected_filepath2).unlink()
+
+
+@pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
+@pytest.mark.stream
+@pytest.mark.use_dask
+@pytest.mark.use_gcsfs
+def test_default_era5_dataset_loading(small_grid: Grid) -> None:
+    """Verify the default ERA5 dataset is loaded when a path is not provided."""
+    start_time = datetime(2020, 2, 1)
+    end_time = datetime(2020, 2, 2)
+
+    sf = SurfaceForcing(
+        grid=small_grid,
+        source={"name": "ERA5"},
+        type="physics",
+        start_time=start_time,
+        end_time=end_time,
+        use_dask=True,
+    )
+
+    expected_vars = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
+    assert set(sf.ds.data_vars).issuperset(expected_vars)
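
Note: the test above exercises the new behaviour in which SurfaceForcing streams the cloud-hosted ERA5 (ARCO) data when the source dict carries no path; the fallback to DEFAULT_ERA5_ARCO_PATH is inferred from the tests in this diff, not confirmed by library documentation. A minimal usage sketch, assuming the top-level roms_tools exports and reusing the Grid arguments shown elsewhere in this diff:

    from datetime import datetime
    from roms_tools import Grid, SurfaceForcing  # assumed top-level exports

    grid = Grid(nx=20, ny=20, size_x=100, size_y=100, center_lon=-20, center_lat=0, rot=0)

    # No "path" in source: the data is expected to stream from the default
    # ERA5 ARCO store, which requires dask (and gcsfs) to be installed.
    sf = SurfaceForcing(
        grid=grid,
        source={"name": "ERA5"},
        type="physics",
        start_time=datetime(2020, 2, 1),
        end_time=datetime(2020, 2, 2),
        use_dask=True,
    )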
@@ -1,15 +1,11 @@
-import os
 import shutil
+from collections.abc import Callable
+from pathlib import Path
 
 import pytest
 import xarray as xr
 
 
-def _get_fname(name):
-    dirname = os.path.dirname(__file__)
-    return os.path.join(dirname, "test_data", f"{name}.zarr")
-
-
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
@@ -34,7 +30,11 @@ def _get_fname(name):
 # this test will not be run by default
 # to run it and overwrite the test data, invoke pytest as follows
 # pytest --overwrite=tidal_forcing --overwrite=boundary_forcing
-def test_save_results(forcing_fixture, request):
+def test_save_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
     overwrite = request.config.getoption("--overwrite")
 
     # Skip the test if the fixture isn't marked for overwriting, unless 'all' is specified
@@ -42,10 +42,10 @@ def test_save_results(forcing_fixture, request):
         pytest.skip(f"Skipping overwrite for {forcing_fixture}")
 
     forcing = request.getfixturevalue(forcing_fixture)
-    fname = _get_fname(forcing_fixture)
+    fname = get_test_data_path(forcing_fixture)
 
     # Check if the Zarr directory exists and delete it if it does
-    if os.path.exists(fname):
+    if fname.exists():
         shutil.rmtree(fname)
 
     forcing.ds.to_zarr(fname)
@@ -72,8 +72,12 @@ def test_save_results(forcing_fixture, request):
         "river_forcing_no_climatology",
     ],
 )
-def test_check_results(forcing_fixture, request):
-    fname = _get_fname(forcing_fixture)
+def test_check_results(
+    forcing_fixture,
+    request: pytest.FixtureRequest,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    fname = get_test_data_path(forcing_fixture)
     expected_forcing_ds = xr.open_zarr(fname, decode_timedelta=False)
     forcing = request.getfixturevalue(forcing_fixture)
 
@@ -83,6 +87,7 @@ def test_check_results(forcing_fixture, request):
     )
 
 
+@pytest.mark.use_dask
 @pytest.mark.parametrize(
     "forcing_fixture",
     [
@@ -97,11 +102,12 @@ def test_check_results(forcing_fixture, request):
         "bgc_boundary_forcing_from_climatology",
     ],
 )
-def test_dask_vs_no_dask(forcing_fixture, request, tmp_path, use_dask):
+def test_dask_vs_no_dask(
+    forcing_fixture: str,
+    request: pytest.FixtureRequest,
+    tmp_path: Path,
+) -> None:
     """Test comparing the forcing created with and without Dask on same platform."""
-    if not use_dask:
-        pytest.skip("Test only runs when --use_dask is specified")
-
     # Get the forcing with Dask
     forcing_with_dask = request.getfixturevalue(forcing_fixture)
 
@@ -297,3 +297,48 @@ class TestPartitionNetcdf:
         for expected_filepath in expected_filepath_list:
             assert expected_filepath.exists()
             expected_filepath.unlink()
+
+    def test_partition_netcdf_with_output_dir(self, grid, tmp_path):
+        # Save the input file
+        input_file = tmp_path / "input_grid.nc"
+        grid.save(input_file)
+
+        # Create a custom output directory
+        output_dir = tmp_path / "custom_output"
+        output_dir.mkdir()
+
+        saved_filenames = partition_netcdf(
+            input_file, np_eta=3, np_xi=5, output_dir=output_dir
+        )
+
+        base_name = input_file.stem  # "input_grid"
+        expected_filenames = [output_dir / f"{base_name}.{i:02d}.nc" for i in range(15)]
+
+        assert saved_filenames == expected_filenames
+
+        for f in expected_filenames:
+            assert f.exists()
+            f.unlink()
+
+    def test_partition_netcdf_multiple_files(self, grid, tmp_path):
+        # Create two test input files
+        file1 = tmp_path / "grid1.nc"
+        file2 = tmp_path / "grid2.nc"
+        grid.save(file1)
+        grid.save(file2)
+
+        # Run partitioning with 3x5 tiles on both files
+        saved_filenames = partition_netcdf([file1, file2], np_eta=3, np_xi=5)
+
+        # Expect 15 tiles per file → 30 total output files
+        expected_filepaths = []
+        for file in [file1, file2]:
+            base = file.with_suffix("")
+            expected_filepaths += [Path(f"{base}.{i:02d}.nc") for i in range(15)]
+
+        assert len(saved_filenames) == 30
+        assert saved_filenames == expected_filepaths
+
+        for path in expected_filepaths:
+            assert path.exists()
+            path.unlink()
@@ -1,7 +1,18 @@
+from collections.abc import Callable
+from pathlib import Path
+from unittest import mock
+
 import numpy as np
 import pytest
+import xarray as xr
 
-from roms_tools.utils import _generate_focused_coordinate_range
+from roms_tools.utils import (
+    _generate_focused_coordinate_range,
+    _has_copernicus,
+    _has_dask,
+    _has_gcsfs,
+    _load_data,
+)
 
 
 @pytest.mark.parametrize(
@@ -19,3 +30,92 @@ def test_coordinate_range_monotonicity(min_val, max_val, center, sc, N):
     )
     assert np.all(np.diff(faces) > 0), "faces is not strictly increasing"
     assert np.all(np.diff(centers) > 0), "centers is not strictly increasing"
+
+
+def test_has_dask() -> None:
+    """Verify that dask existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_dask()
+
+
+def test_has_dask_error_when_missing() -> None:
+    """Verify that dask existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_dask()
+
+
+def test_has_gcfs() -> None:
+    """Verify that GCFS existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_gcsfs()
+
+
+def test_has_gcfs_error_when_missing() -> None:
+    """Verify that GCFS existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_gcsfs()
+
+
+def test_has_copernicus() -> None:
+    """Verify that copernicus existence is correctly reported when found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=mock.MagicMock):
+        assert _has_copernicus()
+
+
+def test_has_copernicus_error_when_missing() -> None:
+    """Verify that copernicus existence is correctly reported when not found."""
+    with mock.patch("roms_tools.utils.find_spec", return_value=None):
+        assert not _has_copernicus()
+
+
+def test_load_data_dask_not_found() -> None:
+    """Verify that load data raises an exception when dask is requested and missing."""
+    with (
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        pytest.raises(RuntimeError),
+    ):
+        _load_data("foo.zarr", {"a": "a"}, use_dask=True)
+
+
+def test_load_data_open_zarr_without_dask() -> None:
+    """Verify that load data raises an exception when zarr is requested without dask."""
+    with (
+        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        pytest.raises(ValueError),
+    ):
+        # read_zarr should require use_dask to be True
+        _load_data("foo.zarr", {"a": ""}, use_dask=False, read_zarr=True)
+
+
+@pytest.mark.parametrize(
+    ("dataset_name", "expected_dim"),
+    [
+        ("surface_forcing", "time"),
+        ("bgc_surface_forcing", "time"),
+        ("tidal_forcing", "eta_rho"),
+        ("coarse_surface_forcing", "eta_rho"),
+    ],
+)
+def test_load_data_open_dataset(
+    dataset_name: str,
+    expected_dim: str,
+    get_test_data_path: Callable[[str], Path],
+) -> None:
+    """Verify that a zarr file is correctly loaded when not using Dask.
+
+    This must use xr.open_dataset
+    """
+    ds_path = get_test_data_path(dataset_name)
+
+    with mock.patch(
+        "roms_tools.utils.xr.open_dataset",
+        wraps=xr.open_dataset,
+    ) as fn_od:
+        ds = _load_data(
+            ds_path,
+            {"latitude": "latitude"},
+            use_dask=False,
+        )
+    assert fn_od.called
+
+    assert expected_dim in ds.dims
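
The _has_dask / _has_gcsfs / _has_copernicus helpers that these tests patch are not part of this diff; the tests only imply that they wrap find_spec imported into roms_tools.utils. A minimal sketch of that pattern, with the exact module layout and helper bodies assumed:

    from importlib.util import find_spec

    def _has_dask() -> bool:
        # Report whether the optional dask dependency is importable,
        # without actually importing it.
        return find_spec("dask") is not None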
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from numbers import Integral
 from pathlib import Path
 
@@ -296,20 +297,21 @@ def partition(
 
 
 def partition_netcdf(
-    filepath: str | Path,
+    filepath: str | Path | Sequence[str | Path],
     np_eta: int = 1,
     np_xi: int = 1,
+    output_dir: str | Path | None = None,
     include_coarse_dims: bool = True,
-) -> None:
-    """Partition a ROMS NetCDF file into smaller spatial tiles and save them to disk.
+) -> list[Path]:
+    """Partition one or more ROMS NetCDF files into smaller spatial tiles and save them to disk.
 
-    This function divides the dataset in the specified NetCDF file into `np_eta` by `np_xi` tiles.
+    This function divides each dataset into `np_eta` by `np_xi` tiles.
     Each tile is saved as a separate NetCDF file.
 
     Parameters
     ----------
-    filepath : Union[str, Path]
-        The path to the input NetCDF file.
+    filepath : str | Path | Sequence[str | Path]
+        A path or list of paths to input NetCDF files.
 
     np_eta : int, optional
         The number of partitions along the `eta` direction. Must be a positive integer. Default is 1.
@@ -317,6 +319,10 @@ def partition_netcdf(
     np_xi : int, optional
         The number of partitions along the `xi` direction. Must be a positive integer. Default is 1.
 
+    output_dir : str | Path | None, optional
+        Directory or base path to save partitioned files.
+        If None, files are saved alongside the input file.
+
     include_coarse_dims : bool, optional
         Whether to include coarse grid dimensions (`eta_coarse`, `xi_coarse`) in the partitioning.
         If False, these dimensions will not be split. Relevant if none of the coarse resolution variables are actually used by ROMS.
@@ -324,31 +330,39 @@ def partition_netcdf(
 
     Returns
     -------
-    List[Path]
+    list[Path]
         A list of Path objects for the filenames that were saved.
     """
-    # Ensure filepath is a Path object
-    filepath = Path(filepath)
-
-    # Open the dataset
-    ds = xr.open_dataset(filepath.with_suffix(".nc"), decode_timedelta=False)
-
-    # Partition the dataset
-    file_numbers, partitioned_datasets = partition(
-        ds, np_eta=np_eta, np_xi=np_xi, include_coarse_dims=include_coarse_dims
-    )
-
-    # Generate paths to the partitioned files
-    base_filepath = filepath.with_suffix("")
-    ndigits = len(str(max(np.array(file_numbers))))
-    paths_to_partitioned_files = [
-        Path(f"{base_filepath}.{file_number:0{ndigits}d}")
-        for file_number in file_numbers
-    ]
+    if isinstance(filepath, str | Path):
+        filepaths = [Path(filepath)]
+    else:
+        filepaths = [Path(fp) for fp in filepath]
+
+    all_saved_filenames = []
+
+    for fp in filepaths:
+        input_file = fp.with_suffix(".nc")
+        ds = xr.open_dataset(input_file, decode_timedelta=False)
+
+        file_numbers, partitioned_datasets = partition(
+            ds, np_eta=np_eta, np_xi=np_xi, include_coarse_dims=include_coarse_dims
+        )
+
+        if output_dir:
+            output_dir = Path(output_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            base_filepath = output_dir / fp.stem
+        else:
+            base_filepath = fp.with_suffix("")
 
-    # Save the partitioned datasets to files
-    saved_filenames = save_datasets(
-        partitioned_datasets, paths_to_partitioned_files, verbose=False
-    )
+        ndigits = len(str(max(file_numbers)))
+        paths_to_partitioned_files = [
+            Path(f"{base_filepath}.{num:0{ndigits}d}") for num in file_numbers
+        ]
+
+        saved = save_datasets(
+            partitioned_datasets, paths_to_partitioned_files, verbose=False
+        )
+        all_saved_filenames.extend(saved)
 
-    return saved_filenames
+    return all_saved_filenames
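
A short usage sketch of the revised partition_netcdf signature, based on the parameters documented above; file names are illustrative and the top-level import location is assumed:

    from pathlib import Path
    from roms_tools import partition_netcdf  # assumed top-level export

    # Split two ROMS NetCDF files into 3 x 5 tiles each; tiles are written
    # next to the inputs and the 30 resulting paths are returned.
    saved = partition_netcdf([Path("grid1.nc"), Path("grid2.nc")], np_eta=3, np_xi=5)

    # Direct the tiles of a single file to a separate directory instead;
    # the directory is created if it does not already exist.
    saved = partition_netcdf("input_grid.nc", np_eta=3, np_xi=5, output_dir=Path("custom_output"))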