roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,144 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import pytest
5
+ import xarray as xr
6
+
7
+ from roms_tools.analysis.cdr_analysis import (
8
+ _validate_source,
9
+ _validate_uptake_efficiency,
10
+ compute_cdr_metrics,
11
+ )
12
+
13
+
14
@pytest.fixture
def minimal_grid_ds():
    """Return a 2x2 grid dataset whose ``pm`` and ``pn`` fields are all ones."""
    horizontal_dims = ("eta_rho", "xi_rho")
    # Each variable gets its own freshly allocated array of ones.
    grid_vars = {name: (horizontal_dims, np.ones((2, 2))) for name in ("pm", "pn")}
    return xr.Dataset(grid_vars)
23
+
24
+
25
@pytest.fixture
def minimal_ds():
    """Return a small constant-valued dataset with the variables the CDR tests use.

    The dataset has 3 time steps, 1 vertical level and a 2x2 horizontal
    grid. Sources are +1 (``ALK_source``) and -1 (``DIC_source``) everywhere;
    the flux and ``hDIC`` fields are constant.
    """
    n_time, n_s, n_eta, n_xi = 3, 1, 2, 2
    time = np.arange(n_time)

    dims_2d = ("time", "eta_rho", "xi_rho")
    dims_3d = ("time", "s_rho", "eta_rho", "xi_rho")
    shape_2d = (n_time, n_eta, n_xi)
    shape_3d = (n_time, n_s, n_eta, n_xi)

    data_vars = {
        # Averaging windows: each window ends one unit after it begins.
        "avg_begin_time": ("time", time),
        "avg_end_time": ("time", time + 1),
        "ALK_source": (dims_3d, np.ones(shape_3d)),
        "DIC_source": (dims_3d, -np.ones(shape_3d)),
        "FG_CO2": (dims_2d, np.full(shape_2d, 2.0)),
        "FG_ALT_CO2": (dims_2d, np.full(shape_2d, 1.0)),
        "hDIC": (dims_3d, np.full(shape_3d, 10.0)),
        "hDIC_ALT_CO2": (dims_3d, np.full(shape_3d, 9.0)),
    }
    coords = {
        "time": time,
        "s_rho": np.arange(n_s),
        "eta_rho": np.arange(n_eta),
        "xi_rho": np.arange(n_xi),
    }
    return xr.Dataset(data_vars, coords=coords)
58
+
59
+
60
def test_compute_cdr_metrics_outputs(
    minimal_ds: xr.Dataset, minimal_grid_ds: xr.Dataset
) -> None:
    """Check that ``compute_cdr_metrics`` returns every expected variable.

    Also checks the two values that follow directly from the fixtures:
    unit cell area and unit averaging-window length.
    """
    ds_cdr = compute_cdr_metrics(minimal_ds, minimal_grid_ds)

    expected_outputs = (
        "area",
        "window_length",
        "FG_CO2",
        "FG_ALT_CO2",
        "hDIC",
        "hDIC_ALT_CO2",
        "cdr_efficiency",
        "cdr_efficiency_from_delta_diff",
    )
    for var in expected_outputs:
        assert var in ds_cdr

    # The grid fixture has pm = pn = 1, so every cell area is expected to be 1.
    assert np.allclose(ds_cdr["area"], 1.0)

    # avg_end_time - avg_begin_time is 1 for every step in the fixture.
    window_length = ds_cdr["window_length"].values
    assert np.all(window_length == 1)
83
+
84
+
85
def test_missing_variable_in_ds(
    minimal_ds: xr.Dataset, minimal_grid_ds: xr.Dataset
) -> None:
    """``compute_cdr_metrics`` must raise ``KeyError`` when ``FG_CO2`` is absent."""
    ds_without_flux = minimal_ds.drop_vars("FG_CO2")

    with pytest.raises(KeyError, match="Missing required variables"):
        compute_cdr_metrics(ds_without_flux, minimal_grid_ds)
91
+
92
+
93
def test_missing_variable_in_grid(
    minimal_ds: xr.Dataset, minimal_grid_ds: xr.Dataset
) -> None:
    """``compute_cdr_metrics`` must raise ``KeyError`` when the grid lacks ``pm``."""
    grid_without_pm = minimal_grid_ds.drop_vars("pm")

    with pytest.raises(KeyError, match="Missing required variables"):
        compute_cdr_metrics(minimal_ds, grid_without_pm)
99
+
100
+
101
def test_validate_source_passes(minimal_ds):
    """``_validate_source`` accepts the unmodified fixture without raising."""
    # The test passes as long as this call does not raise.
    _validate_source(minimal_ds)
104
+
105
+
106
def test_validate_source_alk_negative(minimal_ds):
    """``_validate_source`` must raise ``ValueError`` naming ``ALK_source``.

    The first time step of ``ALK_source`` is set to -1 before validation.
    """
    # Deep copy: the default ``Dataset.copy()`` is shallow and shares its
    # arrays with the fixture, so the in-place ``.loc`` write below would
    # otherwise modify ``minimal_ds`` as well.
    bad_ds = minimal_ds.copy(deep=True)
    bad_ds["ALK_source"].loc[dict(time=0)] = -1

    with pytest.raises(ValueError, match="ALK_source"):
        _validate_source(bad_ds)
111
+
112
+
113
def test_validate_source_dic_positive(minimal_ds):
    """``_validate_source`` must raise ``ValueError`` naming ``DIC_source``.

    The first time step of ``DIC_source`` is set to +1 before validation.
    """
    # Deep copy: the default ``Dataset.copy()`` is shallow and shares its
    # arrays with the fixture, so the in-place ``.loc`` write below would
    # otherwise modify ``minimal_ds`` as well.
    bad_ds = minimal_ds.copy(deep=True)
    bad_ds["DIC_source"].loc[dict(time=0)] = 1

    with pytest.raises(ValueError, match="DIC_source"):
        _validate_source(bad_ds)
118
+
119
+
120
def test_validate_uptake_efficiency_logs(caplog):
    """Check the return value and the log message of ``_validate_uptake_efficiency``.

    The two inputs differ only in their second element (2.0 vs 2.5); the
    returned value is expected to be close to 0.5.
    """
    first = xr.DataArray([1.0, 2.0, 3.0], dims="time")
    second = xr.DataArray([1.0, 2.5, 3.0], dims="time")

    # Capture INFO-level records so the message text can be inspected.
    with caplog.at_level(logging.INFO):
        diff = _validate_uptake_efficiency(first, second)

    assert np.isclose(diff, 0.5)
    assert "flux-based and DIC-based uptake efficiency" in caplog.text
129
+
130
+
131
def test_efficiency_nan_when_zero_source(minimal_ds, minimal_grid_ds):
    """Both efficiency variables must be NaN (not inf) at a zero-source step.

    ``ALK_source`` and ``DIC_source`` are set to zero at the first time step
    before the metrics are computed.
    """
    # Deep copy: the default ``Dataset.copy()`` is shallow and shares its
    # arrays with the fixture, so the in-place ``.loc`` writes below would
    # otherwise modify ``minimal_ds`` as well.
    ds = minimal_ds.copy(deep=True)
    for source_name in ("ALK_source", "DIC_source"):
        ds[source_name].loc[dict(time=0)] = 0

    ds_cdr = compute_cdr_metrics(ds, minimal_grid_ds)

    eff_flux = ds_cdr["cdr_efficiency"].isel(time=0).item()
    eff_diff = ds_cdr["cdr_efficiency_from_delta_diff"].isel(time=0).item()

    # np.isnan is False for inf, so these asserts reject an infinite result.
    assert np.isnan(eff_flux)
    assert np.isnan(eff_diff)
@@ -0,0 +1,202 @@
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import pytest
5
+ import xarray as xr
6
+
7
+ from roms_tools import Ensemble
8
+
9
+ # ----------------------------
10
+ # Fixtures
11
+ # ----------------------------
12
+
13
+
14
@pytest.fixture
def create_member_ds() -> xr.Dataset:
    """Return a three-day dataset with ``cdr_efficiency`` and ``abs_time``."""
    times = np.array(["2000-01-01", "2000-01-02", "2000-01-03"], dtype="datetime64[ns]")
    data_vars = {
        "cdr_efficiency": ("time", [0.1, 0.2, 0.3]),
        # abs_time carries the same timestamps as the time coordinate.
        "abs_time": ("time", times),
    }
    return xr.Dataset(data_vars, coords={"time": times})
23
+
24
+
25
@pytest.fixture
def identical_members(create_member_ds: xr.Dataset) -> dict[str, xr.Dataset]:
    """Return two members, ``member1`` and ``member2``, copied from one dataset."""
    return {f"member{index}": create_member_ds.copy() for index in (1, 2)}
32
+
33
+
34
@pytest.fixture
def varied_members() -> dict[str, xr.Dataset]:
    """Return three members differing in length, sampling, start date and NaNs."""

    def build_member(dates: list[str], efficiency: list[float]) -> xr.Dataset:
        # Every member uses the same layout: cdr_efficiency plus an abs_time
        # variable carrying the same timestamps as the time coordinate.
        stamps = np.array(dates, dtype="datetime64[ns]")
        return xr.Dataset(
            {"cdr_efficiency": ("time", efficiency), "abs_time": ("time", stamps)},
            coords={"time": stamps},
        )

    return {
        # Daily, 5 days, starts 2000-01-01, first value is NaN.
        "member1": build_member(
            ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04", "2000-01-05"],
            [np.nan, 0.2, 0.3, 0.4, 0.5],
        ),
        # Every 2 days, 4 entries, starts 2000-01-02, first two values NaN.
        "member2": build_member(
            ["2000-01-02", "2000-01-04", "2000-01-06", "2000-01-08"],
            [np.nan, np.nan, 0.6, 0.8],
        ),
        # Daily, 3 days, starts 1999-12-31, no NaNs.
        "member3": build_member(
            ["1999-12-31", "2000-01-01", "2000-01-02"],
            [0.05, 0.15, 0.25],
        ),
    }
72
+
73
+
74
+ # ----------------------------
75
+ # Tests
76
+ # ----------------------------
77
+
78
+
79
def test_extract_efficiency(create_member_ds: xr.Dataset) -> None:
    """Check the type, time axis and metadata returned by ``_extract_efficiency``."""
    # __new__ creates the instance without running Ensemble.__init__.
    ens = Ensemble.__new__(Ensemble)
    eff_rel = ens._extract_efficiency(create_member_ds)

    assert isinstance(eff_rel, xr.DataArray)

    # The time axis must be a timedelta and abs_time must not be a coordinate.
    time_axis = eff_rel.time
    assert np.issubdtype(time_axis.dtype, np.timedelta64)
    assert "abs_time" not in eff_rel.coords
    assert time_axis.attrs.get("long_name") == "time since release start"
87
+
88
+
89
def test_extract_efficiency_missing_abs_time() -> None:
    """``_extract_efficiency`` must raise ``ValueError`` when ``abs_time`` is missing."""
    times = np.array(["2000-01-01", "2000-01-02"], dtype="datetime64[ns]")
    # Deliberately built without any 'abs_time' entry.
    ds_without_abs_time = xr.Dataset(
        {"cdr_efficiency": ("time", [0.1, 0.2])},
        coords={"time": times},
    )

    # __new__ creates the instance without running Ensemble.__init__.
    ens = Ensemble.__new__(Ensemble)
    expected_message = "Dataset must contain an 'abs_time' coordinate."
    with pytest.raises(ValueError, match=expected_message):
        ens._extract_efficiency(ds_without_abs_time)
102
+
103
+
104
def test_align_times_identical(identical_members: dict[str, xr.Dataset]) -> None:
    """``_align_times`` keeps every member and spans the union of their times."""
    # __new__ creates the instance without running Ensemble.__init__.
    ens = Ensemble.__new__(Ensemble)
    effs = {}
    for name, ds in identical_members.items():
        effs[name] = ens._extract_efficiency(ds)

    aligned = ens._align_times(effs)
    assert isinstance(aligned, xr.Dataset)

    for name in identical_members:
        assert name in aligned.data_vars

    # The aligned time axis must have one entry per distinct member time.
    member_times = [eff.time.values for eff in effs.values()]
    distinct_times = np.unique(np.concatenate(member_times))
    assert len(aligned.time) == len(distinct_times)
119
+
120
+
121
def test_align_times_varied(varied_members: dict[str, xr.Dataset]) -> None:
    """``_align_times`` pads each member with NaN outside its valid time range."""
    # __new__ creates the instance without running Ensemble.__init__.
    ens = Ensemble.__new__(Ensemble)
    effs = {}
    for name, ds in varied_members.items():
        effs[name] = ens._extract_efficiency(ds)
    aligned = ens._align_times(effs)

    # Every member must be present in the aligned dataset.
    for name in varied_members:
        assert name in aligned.data_vars

    # The aligned time axis must have one entry per distinct member time.
    member_times = [eff.time.values for eff in effs.values()]
    distinct_times = np.unique(np.concatenate(member_times))
    assert len(aligned.time) == len(distinct_times)

    aligned_times = aligned.time.values
    for name, eff in effs.items():
        # Relative times at which this member holds a non-NaN value.
        valid_times = eff.time.values[~np.isnan(eff.values)]
        first_valid_time = valid_times[0]
        last_valid_time = valid_times[-1]
        member_values = aligned[name].values

        # Everything before the first valid time must be NaN.
        before_first = aligned_times < first_valid_time
        if before_first.any():
            assert np.all(np.isnan(member_values[before_first]))

        # Everything after the last valid time must be NaN.
        after_last = aligned_times > last_valid_time
        if after_last.any():
            assert np.all(np.isnan(member_values[after_last]))
153
+
154
+
155
def test_compute_statistics(identical_members: dict[str, xr.Dataset]) -> None:
    """For identical members the ensemble mean equals a member and the std is 0."""
    # __new__ creates the instance without running Ensemble.__init__.
    ens = Ensemble.__new__(Ensemble)
    effs = {}
    for name, ds in identical_members.items():
        effs[name] = ens._extract_efficiency(ds)
    ds_stats = ens._compute_statistics(ens._align_times(effs))

    for statistic in ("ensemble_mean", "ensemble_std"):
        assert statistic in ds_stats.data_vars

    # Both statistics must have one value per time step.
    n_time = len(ds_stats.time)
    assert ds_stats.ensemble_mean.shape[0] == n_time
    assert ds_stats.ensemble_std.shape[0] == n_time

    # All members are identical, so the mean must match any single member.
    reference_member = next(iter(identical_members))
    xr.testing.assert_allclose(ds_stats.ensemble_mean, ds_stats[reference_member])

    # Identical members give zero spread.
    np.testing.assert_allclose(ds_stats.ensemble_std.values, 0.0)
176
+
177
+
178
def test_ensemble_post_init(identical_members: dict[str, xr.Dataset]) -> None:
    """Constructing ``Ensemble`` populates ``ds`` with mean and std variables."""
    ensemble_ds = Ensemble(identical_members).ds

    assert isinstance(ensemble_ds, xr.Dataset)
    for statistic in ("ensemble_mean", "ensemble_std"):
        assert statistic in ensemble_ds.data_vars

    # Identical members give zero spread.
    np.testing.assert_allclose(ensemble_ds.ensemble_std.values, 0.0)
184
+
185
+
186
def test_plot(identical_members: dict[str, xr.Dataset], tmp_path: Path) -> None:
    """``Ensemble.plot`` writes a file at the requested ``save_path``."""
    target = tmp_path / "plot.png"

    # plot() is handed the path as a plain string.
    Ensemble(identical_members).plot(save_path=str(target))

    assert target.exists()
191
+
192
+
193
def test_extract_efficiency_empty() -> None:
    """``_extract_efficiency`` must raise ``ValueError`` when every value is NaN."""
    times = np.array(["2000-01-01", "2000-01-02"], dtype="datetime64[ns]")
    all_nan_ds = xr.Dataset(
        {
            "cdr_efficiency": ("time", [np.nan, np.nan]),
            "abs_time": ("time", times),
        },
        coords={"time": times},
    )

    # __new__ creates the instance without running Ensemble.__init__.
    ens = Ensemble.__new__(Ensemble)
    with pytest.raises(ValueError):
        ens._extract_efficiency(all_nan_ds)
@@ -72,15 +72,19 @@ def test_load_model_output_file(roms_output_fixture, request):
72
72
  assert isinstance(roms_output.ds, xr.Dataset)
73
73
 
74
74
 
75
- def test_load_model_output_file_list(use_dask):
75
+ @pytest.fixture
76
+ def roms_output_from_two_restart_files(use_dask):
76
77
  fname_grid = Path(download_test_data("epac25km_grd.nc"))
77
78
  grid = Grid.from_file(fname_grid)
78
79
 
79
80
  # List of files
80
81
  file1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
81
82
  file2 = Path(download_test_data("eastpac25km_rst.19980126000000.nc"))
82
- output = ROMSOutput(grid=grid, path=[file1, file2], use_dask=use_dask)
83
- assert isinstance(output.ds, xr.Dataset)
83
+ return ROMSOutput(grid=grid, path=[file1, file2], use_dask=use_dask)
84
+
85
+
86
+ def test_load_model_output_file_list(roms_output_from_two_restart_files):
87
+ assert isinstance(roms_output_from_two_restart_files.ds, xr.Dataset)
84
88
 
85
89
 
86
90
  def test_load_model_output_with_wildcard(use_dask):
@@ -618,3 +622,57 @@ def test_regrid_with_custom_depth_levels(roms_output_fixture, request):
618
622
  assert isinstance(ds_regridded, xr.Dataset)
619
623
  assert "depth" in ds_regridded.coords
620
624
  np.allclose(ds_regridded.depth, depth_levels, atol=0.0)
625
+
626
+
627
@pytest.fixture
def roms_output_with_cdr_vars(roms_output_from_two_restart_files):
    """Add minimal synthetic CDR variables to the ``ROMSOutput`` dataset.

    The incoming fixture object is modified: its ``ds`` attribute is replaced
    by a copy that also carries the CDR variables. The same object is
    returned.
    """
    ds = roms_output_from_two_restart_files.ds.copy()

    # Seeded generator so the synthetic fields are identical on every run,
    # rather than depending on the global NumPy random state.
    rng = np.random.default_rng(0)

    # Array shapes are taken from the loaded dataset's dimension sizes.
    n_time = ds.sizes["time"]
    n_eta = ds.sizes["eta_rho"]
    n_xi = ds.sizes["xi_rho"]
    n_s = ds.sizes["s_rho"]

    dims_2d = ("time", "eta_rho", "xi_rho")
    dims_3d = ("time", "s_rho", "eta_rho", "xi_rho")
    shape_2d = (n_time, n_eta, n_xi)
    shape_3d = (n_time, n_s, n_eta, n_xi)

    # Sources are sign-constrained: ALK_source >= 0 and DIC_source <= 0.
    ds["ALK_source"] = xr.DataArray(
        np.abs(rng.standard_normal(shape_3d)), dims=dims_3d
    )
    ds["DIC_source"] = xr.DataArray(
        -np.abs(rng.standard_normal(shape_3d)), dims=dims_3d
    )
    ds["FG_CO2"] = xr.DataArray(rng.standard_normal(shape_2d), dims=dims_2d)
    ds["FG_ALT_CO2"] = xr.DataArray(rng.standard_normal(shape_2d), dims=dims_2d)
    ds["hDIC"] = xr.DataArray(rng.standard_normal(shape_3d), dims=dims_3d)
    ds["hDIC_ALT_CO2"] = xr.DataArray(rng.standard_normal(shape_3d), dims=dims_3d)

    # Averaging windows of 3600 each (simulated seconds), back to back.
    step_index = np.arange(n_time)
    ds["avg_begin_time"] = xr.DataArray(step_index * 3600, dims=("time",))
    ds["avg_end_time"] = xr.DataArray((step_index + 1) * 3600, dims=("time",))

    roms_output_from_two_restart_files.ds = ds
    return roms_output_from_two_restart_files
668
+
669
+
670
def test_cdr_metrics_computes_and_plots(roms_output_with_cdr_vars):
    """``cdr_metrics`` sets ``ds_cdr`` with both efficiency variables present."""
    output = roms_output_with_cdr_vars
    output.cdr_metrics()

    assert hasattr(output, "ds_cdr")

    # Both efficiency variables must be present in the result.
    for efficiency_name in ("cdr_efficiency", "cdr_efficiency_from_delta_diff"):
        assert efficiency_name in output.ds_cdr
@@ -1,7 +1,9 @@
1
1
  import logging
2
+ import os
2
3
  import textwrap
3
4
  from datetime import datetime
4
5
  from pathlib import Path
6
+ from unittest import mock
5
7
 
6
8
  import matplotlib.pyplot as plt
7
9
  import numpy as np
@@ -11,6 +13,12 @@ import xarray as xr
11
13
  from conftest import calculate_data_hash
12
14
  from roms_tools import BoundaryForcing, Grid
13
15
  from roms_tools.download import download_test_data
16
+ from roms_tools.tests.test_setup.utils import download_regional_and_bigger
17
+
18
+ try:
19
+ import copernicusmarine # type: ignore
20
+ except ImportError:
21
+ copernicusmarine = None
14
22
 
15
23
 
16
24
  @pytest.mark.parametrize(
@@ -269,7 +277,7 @@ def test_boundary_divided_by_land_warning(caplog, use_dask):
269
277
  use_dask=use_dask,
270
278
  )
271
279
  # Verify the warning message in the log
272
- assert "the western boundary is divided by land" in caplog.text
280
+ assert "divided by land" in caplog.text
273
281
 
274
282
 
275
283
  def test_info_depth(caplog, use_dask):
@@ -322,57 +330,6 @@ def test_info_depth(caplog, use_dask):
322
330
  )
323
331
 
324
332
 
325
- def test_info_fill(caplog, use_dask):
326
- grid = Grid(
327
- nx=3,
328
- ny=3,
329
- size_x=400,
330
- size_y=400,
331
- center_lon=-8,
332
- center_lat=58,
333
- rot=0,
334
- N=3, # number of vertical levels
335
- theta_s=5.0, # surface control parameter
336
- theta_b=2.0, # bottom control parameter
337
- hc=250.0, # critical depth
338
- )
339
-
340
- fname1 = Path(download_test_data("GLORYS_NA_20120101.nc"))
341
- fname2 = Path(download_test_data("GLORYS_NA_20121231.nc"))
342
-
343
- with caplog.at_level(logging.INFO):
344
- BoundaryForcing(
345
- grid=grid,
346
- start_time=datetime(2012, 1, 1),
347
- end_time=datetime(2012, 12, 31),
348
- source={"name": "GLORYS", "path": [fname1, fname2]},
349
- apply_2d_horizontal_fill=True,
350
- use_dask=use_dask,
351
- )
352
-
353
- # Verify the warning message in the log
354
- assert (
355
- "Applying 2D horizontal fill to the source data before regridding."
356
- in caplog.text
357
- )
358
-
359
- # Clear the log before the next test
360
- caplog.clear()
361
-
362
- with caplog.at_level(logging.INFO):
363
- BoundaryForcing(
364
- grid=grid,
365
- start_time=datetime(2012, 1, 1),
366
- end_time=datetime(2012, 12, 31),
367
- source={"name": "GLORYS", "path": [fname1, fname2]},
368
- apply_2d_horizontal_fill=False,
369
- use_dask=use_dask,
370
- )
371
- # Verify the warning message in the log
372
- for direction in ["south", "east", "north", "west"]:
373
- assert f"Applying 1D horizontal fill to {direction}ern boundary." in caplog.text
374
-
375
-
376
333
  def test_1d_and_2d_fill_coincide_if_no_fill(use_dask):
377
334
  grid = Grid(
378
335
  nx=2,
@@ -758,3 +715,105 @@ def test_from_yaml_missing_boundary_forcing(tmp_path, use_dask):
758
715
 
759
716
  yaml_filepath = Path(yaml_filepath)
760
717
  yaml_filepath.unlink()
718
+
719
+
720
@pytest.mark.stream
@pytest.mark.use_dask
@pytest.mark.use_copernicus
def test_default_glorys_dataset_loading(tiny_grid: Grid) -> None:
    """Verify the default GLORYS dataset is loaded when a path is not provided."""
    # NOTE(review): clear=True removes every other environment variable while
    # the block runs — confirm this is intended.
    env_for_test = {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}
    with mock.patch.dict(os.environ, env_for_test, clear=True):
        bf = BoundaryForcing(
            grid=tiny_grid,
            # No "path" entry: only the source name is given.
            source={"name": "GLORYS"},
            type="physics",
            start_time=datetime(2010, 2, 1),
            end_time=datetime(2010, 3, 1),
            use_dask=True,
            bypass_validation=True,
        )

    required_vars = {"u_south", "v_south", "temp_south", "salt_south"}
    assert required_vars <= set(bf.ds.data_vars)
743
+
744
+
745
@pytest.mark.use_copernicus
@pytest.mark.skipif(copernicusmarine is None, reason="copernicusmarine required")
@pytest.mark.parametrize(
    "grid_fixture",
    [
        "tiny_grid_that_straddles_dateline",
        "tiny_grid_that_straddles_180_degree_meridian",
        "tiny_rotated_grid",
    ],
)
def test_invariance_to_get_glorys_bounds(tmp_path, grid_fixture, use_dask, request):
    """Boundary forcing must agree for a regional file and a larger one.

    Two ``BoundaryForcing`` objects are built with identical settings, one
    from each file returned by ``download_regional_and_bigger``, and their
    datasets are compared within a tolerance.
    """
    start_time = datetime(2012, 1, 1)
    grid = request.getfixturevalue(grid_fixture)

    regional_file, bigger_regional_file = download_regional_and_bigger(
        tmp_path, grid, start_time
    )

    def build_forcing(source_file):
        # Only the source path differs between the two objects under test.
        return BoundaryForcing(
            grid=grid,
            source={"name": "GLORYS", "path": str(source_file)},
            type="physics",
            start_time=start_time,
            end_time=start_time,
            apply_2d_horizontal_fill=True,
            use_dask=use_dask,
        )

    bf_from_regional = build_forcing(regional_file)
    bf_from_bigger_regional = build_forcing(bigger_regional_file)

    # assert_allclose rather than equals. Per the original author's note,
    # Copernicus serves longitudes on [-180, 180] by default and remaps them
    # for a request such as [170, 190]. The remapping introduces tiny
    # floating-point differences in the longitude coordinate, which then
    # propagate through regridding. The tolerances are loosened for grids
    # that straddle the 180° meridian.
    xr.testing.assert_allclose(
        bf_from_bigger_regional.ds, bf_from_regional.ds, rtol=1e-4, atol=1e-5
    )
790
+
791
+
792
@pytest.mark.parametrize(
    "use_dask",
    [pytest.param(True, marks=pytest.mark.use_dask), False],
)
def test_nondefault_glorys_dataset_loading(small_grid: Grid, use_dask: bool) -> None:
    """Verify a non-default GLORYS dataset is loaded when a path is provided."""
    local_path = Path(download_test_data("GLORYS_NA_20120101.nc"))
    glorys_source = {"name": "GLORYS", "path": local_path}

    # NOTE(review): clear=True removes every other environment variable while
    # the block runs — confirm this is intended.
    env_for_test = {"PYDEVD_WARN_EVALUATION_TIMEOUT": "90"}
    with mock.patch.dict(os.environ, env_for_test, clear=True):
        bf = BoundaryForcing(
            grid=small_grid,
            source=glorys_source,
            type="physics",
            start_time=datetime(2012, 1, 1),
            end_time=datetime(2012, 12, 31),
            use_dask=use_dask,
        )

    required_vars = {"u_south", "v_south", "temp_south", "salt_south"}
    assert required_vars <= set(bf.ds.data_vars)
@@ -3,6 +3,7 @@ from datetime import datetime, timedelta
3
3
  from pathlib import Path
4
4
 
5
5
  import numpy as np
6
+ import pandas as pd
6
7
  import pytest
7
8
  import xarray as xr
8
9
  from pydantic import ValidationError
@@ -16,6 +17,7 @@ from roms_tools.setup.cdr_forcing import (
16
17
  ReleaseSimulationManager,
17
18
  )
18
19
  from roms_tools.setup.cdr_release import ReleaseType
20
+ from roms_tools.setup.utils import get_tracer_metadata_dict
19
21
 
20
22
  try:
21
23
  import xesmf # type: ignore
@@ -977,3 +979,55 @@ class TestCDRForcing:
977
979
  yaml_filepath.unlink()
978
980
  filepath1.unlink()
979
981
  filepath2.unlink()
982
+
983
@pytest.mark.parametrize(
    "cdr_forcing, tracer_attr",
    [
        ("volume_release_cdr_forcing_without_grid", "tracer_concentrations"),
        ("tracer_perturbation_cdr_forcing_without_grid", "tracer_fluxes"),
    ],
)
def test_compute_total_cdr_source(self, cdr_forcing, tracer_attr, request):
    """Check the layout and content of ``compute_total_cdr_source``'s DataFrame.

    ``cdr_forcing`` names an attribute on the test class holding the forcing
    object. ``tracer_attr`` names the per-release mapping whose keys are the
    tracer names.
    """
    dt = 30.0
    cdr_instance = getattr(self, cdr_forcing)
    releases = cdr_instance.releases

    df = cdr_instance.compute_total_cdr_source(dt)
    assert isinstance(df, pd.DataFrame)

    # One row per release plus the extra "units" row.
    assert df.shape[0] == len(releases) + 1

    # Expected columns: every tracer named by any release, except temp and
    # salt, which are not expected among the integrated totals.
    expected_columns = set()
    for release in releases:
        expected_columns |= set(getattr(release, tracer_attr).keys())
    expected_columns -= {"temp", "salt"}
    assert set(df.columns) == expected_columns

    # Where the metadata defines units for a tracer, the "units" row must match.
    tracer_meta = get_tracer_metadata_dict(include_bgc=True, unit_type="integrated")
    for tracer in df.columns:
        unit = tracer_meta.get(tracer, {}).get("units", None)
        if unit:
            assert df.loc["units", tracer] == unit, (
                f"Units row for '{tracer}' is incorrect"
            )

    # Drop the units row and coerce the rest to numbers. A non-numeric entry
    # becomes NaN, which the finiteness check below rejects.
    numeric_rows = df.drop("units").apply(pd.to_numeric, errors="coerce")
    assert np.all(np.isfinite(numeric_rows.values)), (
        "Some values are not finite numbers"
    )