roms-tools 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. roms_tools/__init__.py +2 -1
  2. roms_tools/setup/boundary_forcing.py +21 -30
  3. roms_tools/setup/datasets.py +13 -21
  4. roms_tools/setup/grid.py +253 -139
  5. roms_tools/setup/initial_conditions.py +21 -3
  6. roms_tools/setup/mask.py +50 -4
  7. roms_tools/setup/nesting.py +575 -0
  8. roms_tools/setup/plot.py +214 -55
  9. roms_tools/setup/river_forcing.py +125 -29
  10. roms_tools/setup/surface_forcing.py +21 -8
  11. roms_tools/setup/tides.py +21 -3
  12. roms_tools/setup/topography.py +168 -35
  13. roms_tools/setup/utils.py +127 -21
  14. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -3
  15. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -2
  16. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/.zarray +1 -1
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/0 +0 -0
  18. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zmetadata +5 -6
  19. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zarray +2 -2
  20. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zattrs +1 -2
  21. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/0.0.0 +0 -0
  22. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zarray +2 -2
  23. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_name/0 +0 -0
  24. roms_tools/tests/test_setup/test_datasets.py +2 -2
  25. roms_tools/tests/test_setup/test_nesting.py +489 -0
  26. roms_tools/tests/test_setup/test_river_forcing.py +50 -13
  27. roms_tools/tests/test_setup/test_surface_forcing.py +1 -0
  28. roms_tools/tests/test_setup/test_validation.py +2 -2
  29. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/METADATA +8 -4
  30. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/RECORD +51 -50
  31. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/WHEEL +1 -1
  32. roms_tools/_version.py +0 -2
  33. roms_tools/tests/test_setup/test_data/river_forcing.zarr/river_tracer/0.0.0 +0 -0
  34. roms_tools/tests/test_setup/test_data/river_forcing.zarr/tracer_name/0 +0 -0
  35. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zattrs +0 -0
  36. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zgroup +0 -0
  37. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zarray +0 -0
  38. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zattrs +0 -0
  39. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/0 +0 -0
  40. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zarray +0 -0
  41. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zattrs +0 -0
  42. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/0 +0 -0
  43. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zarray +0 -0
  44. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zattrs +0 -0
  45. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/0 +0 -0
  46. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zarray +0 -0
  47. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zattrs +0 -0
  48. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/0 +0 -0
  49. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zarray +0 -0
  50. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zattrs +0 -0
  51. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/0.0 +0 -0
  52. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zattrs +0 -0
  53. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/LICENSE +0 -0
  54. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "chunks": [
3
- 2
3
+ 34
4
4
  ],
5
5
  "compressor": {
6
6
  "blocksize": 0,
@@ -14,7 +14,7 @@
14
14
  "filters": null,
15
15
  "order": "C",
16
16
  "shape": [
17
- 2
17
+ 34
18
18
  ],
19
19
  "zarr_format": 2
20
20
  }
@@ -431,8 +431,8 @@ def test_era5_correction_choose_subdomain(use_dask):
431
431
  data = ERA5Correction(use_dask=use_dask)
432
432
  lats = data.ds.latitude[10:20]
433
433
  lons = data.ds.longitude[10:20]
434
- coords = {"latitude": lats, "longitude": lons}
435
- data.choose_subdomain(coords, straddle=False)
434
+ target_coords = {"lat": lats, "lon": lons}
435
+ data.choose_subdomain(target_coords, straddle=False)
436
436
  assert (data.ds["latitude"] == lats).all()
437
437
  assert (data.ds["longitude"] == lons).all()
438
438
 
@@ -0,0 +1,489 @@
1
+ import pytest
2
+ import xarray as xr
3
+ import numpy as np
4
+ import logging
5
+ from pathlib import Path
6
+ from roms_tools import Grid, Nesting
7
+ from roms_tools.setup.utils import get_boundary_coords
8
+ from conftest import calculate_file_hash
9
+ from roms_tools.setup.nesting import (
10
+ interpolate_indices,
11
+ map_child_boundaries_onto_parent_grid_indices,
12
+ compute_boundary_distance,
13
+ modify_child_topography_and_mask,
14
+ )
15
+
16
+
17
+ @pytest.fixture()
18
+ def parent_grid():
19
+ return Grid(
20
+ nx=5, ny=7, center_lon=-23, center_lat=61, rot=20, size_x=1800, size_y=2400
21
+ )
22
+
23
+
24
+ @pytest.fixture()
25
+ def child_grid():
26
+ return Grid(
27
+ nx=10, ny=10, center_lon=-23, center_lat=61, rot=-20, size_x=500, size_y=500
28
+ )
29
+
30
+
31
+ @pytest.fixture()
32
+ def baby_grid():
33
+ return Grid(
34
+ nx=3, ny=5, center_lon=-23, center_lat=61, rot=0, size_x=200, size_y=200
35
+ )
36
+
37
+
38
+ @pytest.fixture()
39
+ def parent_grid_that_straddles():
40
+ return Grid(
41
+ nx=5, ny=7, center_lon=10, center_lat=61, rot=20, size_x=1800, size_y=2400
42
+ )
43
+
44
+
45
+ @pytest.fixture()
46
+ def child_grid_that_straddles():
47
+ return Grid(
48
+ nx=10, ny=10, center_lon=10, center_lat=61, rot=-20, size_x=500, size_y=500
49
+ )
50
+
51
+
52
+ @pytest.fixture()
53
+ def nesting(parent_grid, child_grid):
54
+ return Nesting(parent_grid=parent_grid, child_grid=child_grid, period=3600.0)
55
+
56
+
57
+ @pytest.fixture()
58
+ def nesting_that_straddles(parent_grid_that_straddles, child_grid_that_straddles):
59
+ return Nesting(
60
+ parent_grid=parent_grid_that_straddles,
61
+ child_grid=child_grid_that_straddles,
62
+ period=3600.0,
63
+ )
64
+
65
+
66
+ class TestInterpolateIndices:
67
+ @pytest.mark.parametrize(
68
+ "grid",
69
+ [
70
+ "parent_grid",
71
+ "parent_grid_that_straddles",
72
+ ],
73
+ )
74
+ def test_correct_indices_of_same_grid(self, grid, caplog, request):
75
+ """Verify boundary indices are correctly interpolated for the same grid."""
76
+
77
+ grid = request.getfixturevalue(grid)
78
+
79
+ bdry_coords_dict = get_boundary_coords()
80
+ location = "rho"
81
+ for direction in ["south", "east", "north", "west"]:
82
+ bdry_coords = bdry_coords_dict[location][direction]
83
+ lon = grid.ds[f"lon_{location}"].isel(**bdry_coords)
84
+ lat = grid.ds[f"lat_{location}"].isel(**bdry_coords)
85
+ mask = grid.ds[f"mask_{location}"].isel(**bdry_coords)
86
+
87
+ with caplog.at_level(logging.WARNING):
88
+ i_eta, i_xi = interpolate_indices(grid.ds, lon, lat, mask)
89
+
90
+ # Verify the warning message in the log
91
+ assert (
92
+ "Some boundary points of the child grid are very close to the boundary of the parent grid."
93
+ in caplog.text
94
+ )
95
+
96
+ if direction == "south":
97
+ expected_i_eta = -0.5 * xr.ones_like(grid.ds.xi_rho)
98
+ expected_i_xi = np.arange(-0.5, grid.ds.xi_rho[-1] + 0.5)
99
+ elif direction == "east":
100
+ expected_i_eta = np.arange(-0.5, grid.ds.eta_rho[-1] + 0.5)
101
+ expected_i_xi = (grid.ds.xi_rho[-1] - 0.5) * xr.ones_like(
102
+ grid.ds.eta_rho
103
+ )
104
+ elif direction == "north":
105
+ expected_i_eta = (grid.ds.eta_rho[-1] - 0.5) * xr.ones_like(
106
+ grid.ds.xi_rho
107
+ )
108
+ expected_i_xi = np.arange(-0.5, grid.ds.xi_rho[-1] + 0.5)
109
+ elif direction == "west":
110
+ expected_i_eta = np.arange(-0.5, grid.ds.eta_rho[-1] + 0.5)
111
+ expected_i_xi = -0.5 * xr.ones_like(grid.ds.eta_rho)
112
+
113
+ np.testing.assert_allclose(i_eta.values, expected_i_eta)
114
+ np.testing.assert_allclose(i_xi.values, expected_i_xi)
115
+
116
+ @pytest.mark.parametrize(
117
+ "parent_grid_fixture, child_grid_fixture",
118
+ [
119
+ ("parent_grid", "child_grid"),
120
+ ("parent_grid_that_straddles", "child_grid_that_straddles"),
121
+ ],
122
+ )
123
+ def test_indices_are_within_range_of_parent_grid(
124
+ self, parent_grid_fixture, child_grid_fixture, request
125
+ ):
126
+ """Ensure interpolated indices fall within the parent grid's bounds."""
127
+
128
+ parent_grid = request.getfixturevalue(parent_grid_fixture)
129
+ child_grid = request.getfixturevalue(child_grid_fixture)
130
+
131
+ bdry_coords_dict = get_boundary_coords()
132
+ for location in ["rho", "u", "v"]:
133
+ for direction in ["south", "east", "north", "west"]:
134
+ bdry_coords = bdry_coords_dict[location][direction]
135
+ lon = child_grid.ds[f"lon_{location}"].isel(**bdry_coords)
136
+ lat = child_grid.ds[f"lat_{location}"].isel(**bdry_coords)
137
+ mask = child_grid.ds[f"mask_{location}"].isel(**bdry_coords)
138
+
139
+ i_eta, i_xi = interpolate_indices(parent_grid.ds, lon, lat, mask)
140
+
141
+ expected_i_eta_min = -0.5
142
+ expected_i_eta_max = parent_grid.ds.eta_rho[-1] - 0.5
143
+ expected_i_xi_min = -0.5
144
+ expected_i_xi_max = parent_grid.ds.xi_rho[-1] - 0.5
145
+
146
+ assert (i_eta >= expected_i_eta_min).all()
147
+ assert (i_eta <= expected_i_eta_max).all()
148
+ assert (i_xi >= expected_i_xi_min).all()
149
+ assert (i_xi <= expected_i_xi_max).all()
150
+
151
+
152
+ class TestMapChildBoundaries:
153
+ def test_update_indices_does_nothing_if_no_parent_land(self, child_grid, baby_grid):
154
+ """Verify no change in indices when parent grid has no land at boundaries."""
155
+
156
+ ds_without_updated_indices = map_child_boundaries_onto_parent_grid_indices(
157
+ child_grid.ds, baby_grid.ds, update_land_indices=False
158
+ )
159
+ ds_with_updated_indices = map_child_boundaries_onto_parent_grid_indices(
160
+ child_grid.ds, baby_grid.ds, update_land_indices=True
161
+ )
162
+
163
+ xr.testing.assert_allclose(ds_without_updated_indices, ds_with_updated_indices)
164
+
165
+ @pytest.mark.parametrize(
166
+ "parent_grid_fixture, child_grid_fixture",
167
+ [
168
+ ("parent_grid", "child_grid"),
169
+ ("parent_grid_that_straddles", "child_grid_that_straddles"),
170
+ ],
171
+ )
172
+ def test_updated_indices_map_to_wet_points(
173
+ self, parent_grid_fixture, child_grid_fixture, request
174
+ ):
175
+ """Check updated indices map to wet points on the parent grid."""
176
+
177
+ parent_grid = request.getfixturevalue(parent_grid_fixture)
178
+ child_grid = request.getfixturevalue(child_grid_fixture)
179
+
180
+ ds = map_child_boundaries_onto_parent_grid_indices(
181
+ parent_grid.ds, child_grid.ds
182
+ )
183
+ for direction in ["south", "east", "north", "west"]:
184
+ for location in ["rho", "u", "v"]:
185
+ if location == "rho":
186
+ dim = "two"
187
+ location = "r"
188
+ # convert from absolute indices [-0.5, ...] to [0, ...]
189
+ i_xi = ds[f"child_{direction}_{location}"].isel({dim: 0}) + 0.5
190
+ i_eta = ds[f"child_{direction}_{location}"].isel({dim: 1}) + 0.5
191
+ for i in range(len(i_xi)):
192
+ i_eta_lower = int(np.floor(i_eta[i]))
193
+ i_xi_lower = int(np.floor(i_xi[i]))
194
+ mask = parent_grid.ds.mask_rho.isel(
195
+ eta_rho=slice(i_eta_lower, i_eta_lower + 2),
196
+ xi_rho=slice(i_xi_lower, i_xi_lower + 2),
197
+ )
198
+ assert np.sum(mask) > 0
199
+ # TODO: check also u and v locations
200
+
201
+ @pytest.mark.parametrize(
202
+ "parent_grid_fixture, child_grid_fixture",
203
+ [
204
+ ("parent_grid", "child_grid"),
205
+ ("parent_grid_that_straddles", "child_grid_that_straddles"),
206
+ ],
207
+ )
208
+ def test_indices_are_monotonically_increasing(
209
+ self, parent_grid_fixture, child_grid_fixture, request
210
+ ):
211
+ """Test that child boundary indices are monotonically increasing or decreasing
212
+ in both the xi and eta directions, for all boundaries and locations."""
213
+
214
+ parent_grid = request.getfixturevalue(parent_grid_fixture)
215
+ child_grid = request.getfixturevalue(child_grid_fixture)
216
+
217
+ for update_land_indices in [False, True]:
218
+ ds = map_child_boundaries_onto_parent_grid_indices(
219
+ parent_grid.ds, child_grid.ds, update_land_indices=update_land_indices
220
+ )
221
+
222
+ for direction in ["south", "east", "north", "west"]:
223
+ for location in ["rho", "u", "v"]:
224
+ if location == "rho":
225
+ dim = "two"
226
+ location = "r"
227
+ else:
228
+ dim = "three"
229
+
230
+ for coord in [0, 1]: # 0 for xi, 1 for eta
231
+ index_values = ds[f"child_{direction}_{location}"].isel(
232
+ {dim: coord}
233
+ )
234
+ assert np.all(np.diff(index_values) >= 0) or np.all(
235
+ np.diff(index_values) <= 0
236
+ )
237
+
238
+
239
+ class TestBoundaryDistance:
240
+ @pytest.mark.parametrize(
241
+ "grid_fixture",
242
+ [
243
+ "child_grid",
244
+ "baby_grid",
245
+ ],
246
+ )
247
+ def test_boundary_distance_for_grid_without_land_along_boundary(
248
+ self, grid_fixture, request
249
+ ):
250
+ """Ensure boundary distance is zero for grids without land along boundaries."""
251
+
252
+ grid = request.getfixturevalue(grid_fixture)
253
+ alpha = compute_boundary_distance(grid.ds.mask_rho)
254
+
255
+ # check that all boundaries are zero
256
+ assert (alpha.isel(eta_rho=0) == 0).all()
257
+ assert (alpha.isel(eta_rho=-1) == 0).all()
258
+ assert (alpha.isel(xi_rho=0) == 0).all()
259
+ assert (alpha.isel(xi_rho=-1) == 0).all()
260
+
261
+ # check that inner values are 1
262
+ assert (
263
+ alpha.isel(
264
+ eta_rho=alpha.sizes["eta_rho"] // 2, xi_rho=alpha.sizes["xi_rho"] // 2
265
+ )
266
+ == 1
267
+ )
268
+
269
+ def test_boundary_distance_for_grid_with_land_along_boundary(self, parent_grid):
270
+ """Test that there are 1s along the boundary of alpha if the grid has land along
271
+ the boundary."""
272
+ alpha = compute_boundary_distance(parent_grid.ds.mask_rho)
273
+ assert (alpha.isel(eta_rho=0) == 1).any()
274
+ assert (alpha.isel(eta_rho=-1) == 1).any()
275
+ assert (alpha.isel(xi_rho=0) == 1).any()
276
+ assert (alpha.isel(xi_rho=-1) == 1).any()
277
+
278
+
279
+ class TestModifyChid:
280
+ def test_mask_is_not_modified_if_no_parent_land_along_boundaries(
281
+ self, child_grid, baby_grid
282
+ ):
283
+ """Confirm child mask remains unchanged if no parent land is at boundaries."""
284
+
285
+ mask_original = baby_grid.ds.mask_rho.copy()
286
+ modified_baby_grid_ds = modify_child_topography_and_mask(
287
+ child_grid.ds, baby_grid.ds
288
+ )
289
+ xr.testing.assert_allclose(modified_baby_grid_ds.mask_rho, mask_original)
290
+
291
+ @pytest.mark.parametrize(
292
+ "grid_fixture",
293
+ [
294
+ "parent_grid",
295
+ "child_grid",
296
+ "baby_grid",
297
+ ],
298
+ )
299
+ def test_no_modification_if_parent_and_child_coincide(self, grid_fixture, request):
300
+ """Ensure no changes occur when parent and child grids coincide."""
301
+
302
+ grid = request.getfixturevalue(grid_fixture)
303
+
304
+ h_original = grid.ds.h.copy()
305
+ mask_original = grid.ds.mask_rho.copy()
306
+ modified_grid_ds = modify_child_topography_and_mask(grid.ds, grid.ds)
307
+
308
+ xr.testing.assert_allclose(modified_grid_ds.h, h_original)
309
+ xr.testing.assert_allclose(modified_grid_ds.mask_rho, mask_original)
310
+
311
+ def test_modification_only_along_boundaries(self, parent_grid, child_grid):
312
+ """Test that modifications to the child grid's topography and mask occur only
313
+ along the boundaries, leaving the interior unchanged."""
314
+
315
+ # Make copies of original data for comparison
316
+ h_original = child_grid.ds.h.copy()
317
+ mask_original = child_grid.ds.mask_rho.copy()
318
+
319
+ # Apply the modification function
320
+ modified_ds = modify_child_topography_and_mask(parent_grid.ds, child_grid.ds)
321
+
322
+ # Calculate the center indices for the grid
323
+ eta_center = h_original.sizes["eta_rho"] // 2
324
+ xi_center = h_original.sizes["xi_rho"] // 2
325
+
326
+ # Assert that the center values remain the same
327
+ assert mask_original.isel(
328
+ eta_rho=eta_center, xi_rho=xi_center
329
+ ) == modified_ds.mask_rho.isel(
330
+ eta_rho=eta_center, xi_rho=xi_center
331
+ ), "Mask at the grid center was modified."
332
+
333
+ assert h_original.isel(
334
+ eta_rho=eta_center, xi_rho=xi_center
335
+ ) == modified_ds.h.isel(
336
+ eta_rho=eta_center, xi_rho=xi_center
337
+ ), "Topography at the grid center was modified."
338
+
339
+
340
+ class TestNesting:
341
+ @pytest.mark.parametrize(
342
+ "nesting_fixture",
343
+ ["nesting", "nesting_that_straddles"],
344
+ )
345
+ def test_successful_initialization(self, nesting_fixture, request):
346
+ nesting = request.getfixturevalue(nesting_fixture)
347
+
348
+ assert nesting.boundaries == {
349
+ "south": True,
350
+ "east": True,
351
+ "north": True,
352
+ "west": True,
353
+ }
354
+ assert nesting.child_prefix == "child"
355
+ assert nesting.period == 3600.0
356
+ assert isinstance(nesting.ds, xr.Dataset)
357
+
358
+ ds = nesting.ds
359
+ for direction in ["south", "east", "north", "west"]:
360
+ for location in ["r", "u", "v"]:
361
+ assert f"child_{direction}_{location}" in ds.data_vars
362
+ assert (
363
+ ds[f"child_{direction}_{location}"].attrs["output_period"] == 3600.0
364
+ )
365
+ if location == "r":
366
+ assert (
367
+ ds[f"child_{direction}_{location}"].attrs["output_vars"]
368
+ == "zeta, temp, salt"
369
+ )
370
+ elif location == "u":
371
+ assert (
372
+ ds[f"child_{direction}_{location}"].attrs["output_vars"]
373
+ == "ubar, u, up"
374
+ )
375
+ elif location == "v":
376
+ assert (
377
+ ds[f"child_{direction}_{location}"].attrs["output_vars"]
378
+ == "vbar, v, vp"
379
+ )
380
+
381
+ @pytest.mark.parametrize(
382
+ "parent_grid_fixture, child_grid_fixture",
383
+ [
384
+ ("parent_grid", "child_grid_that_straddles"),
385
+ ("parent_grid_that_straddles", "child_grid"),
386
+ ],
387
+ )
388
+ def test_error_if_child_grid_beyond_parent_grid(
389
+ self, parent_grid_fixture, child_grid_fixture, request
390
+ ):
391
+ parent_grid = request.getfixturevalue(parent_grid_fixture)
392
+ child_grid = request.getfixturevalue(child_grid_fixture)
393
+
394
+ with pytest.raises(ValueError, match="Some points are outside the grid."):
395
+ Nesting(parent_grid=parent_grid, child_grid=child_grid)
396
+
397
+ @pytest.mark.parametrize(
398
+ "nesting_fixture",
399
+ ["nesting", "nesting_that_straddles"],
400
+ )
401
+ def test_plot(self, nesting_fixture, request):
402
+ """Test plot method."""
403
+ nesting = request.getfixturevalue(nesting_fixture)
404
+
405
+ nesting.plot()
406
+ nesting.plot(with_dim_names=True)
407
+
408
+ def test_save(self, nesting, tmp_path):
409
+ """Test save method."""
410
+
411
+ for file_str, grid_file_str in zip(
412
+ ["test_nesting", "test_nesting.nc"], ["test_grid", "test_grid.nc"]
413
+ ):
414
+ # Create a temporary filepath using the tmp_path fixture
415
+ for filepath, grid_filepath in zip(
416
+ [tmp_path / file_str, str(tmp_path / file_str)],
417
+ [tmp_path / grid_file_str, str(tmp_path / grid_file_str)],
418
+ ): # test for Path object and str
419
+
420
+ # Test saving without partitioning
421
+ saved_filenames = nesting.save(filepath, grid_filepath)
422
+ # Check if the .nc file was created
423
+ filepath = Path(filepath).with_suffix(".nc")
424
+ grid_filepath = Path(grid_filepath).with_suffix(".nc")
425
+ assert saved_filenames == [filepath, grid_filepath]
426
+ assert filepath.exists()
427
+ assert grid_filepath.exists()
428
+ # Clean up the .nc file
429
+ filepath.unlink()
430
+ grid_filepath.unlink()
431
+
432
+ # Test saving with partitioning
433
+ saved_filenames = nesting.save(
434
+ filepath, grid_filepath, np_eta=5, np_xi=5
435
+ )
436
+
437
+ filepath_str = str(filepath.with_suffix(""))
438
+ grid_filepath_str = str(grid_filepath.with_suffix(""))
439
+ expected_filepath_list = [
440
+ Path(filepath_str + f".{index}.nc") for index in range(25)
441
+ ] + [Path(grid_filepath_str + f".{index}.nc") for index in range(25)]
442
+ assert saved_filenames == expected_filepath_list
443
+ for expected_filepath in expected_filepath_list:
444
+ assert expected_filepath.exists()
445
+ expected_filepath.unlink()
446
+
447
+ def test_roundtrip_yaml(self, nesting, tmp_path):
448
+ """Test that creating a Nesting object, saving its parameters to yaml file, and
449
+ re-opening yaml file creates the same object."""
450
+
451
+ # Create a temporary filepath using the tmp_path fixture
452
+ file_str = "test_yaml"
453
+ for filepath in [
454
+ tmp_path / file_str,
455
+ str(tmp_path / file_str),
456
+ ]: # test for Path object and str
457
+
458
+ nesting.to_yaml(filepath)
459
+
460
+ nesting_from_file = Nesting.from_yaml(filepath)
461
+
462
+ assert nesting == nesting_from_file
463
+
464
+ filepath = Path(filepath)
465
+ filepath.unlink()
466
+
467
+ def test_files_have_same_hash(self, nesting, tmp_path):
468
+
469
+ yaml_filepath = tmp_path / "test_yaml.yaml"
470
+ filepath1 = tmp_path / "test1.nc"
471
+ filepath2 = tmp_path / "test2.nc"
472
+ grid_filepath1 = tmp_path / "grid_test1.nc"
473
+ grid_filepath2 = tmp_path / "grid_test2.nc"
474
+
475
+ nesting.to_yaml(yaml_filepath)
476
+ nesting.save(filepath1, grid_filepath1)
477
+ nesting_from_file = Nesting.from_yaml(yaml_filepath)
478
+ nesting_from_file.save(filepath2, grid_filepath2)
479
+
480
+ hash1 = calculate_file_hash(filepath1)
481
+ hash2 = calculate_file_hash(filepath2)
482
+
483
+ assert hash1 == hash2, f"Hashes do not match: {hash1} != {hash2}"
484
+
485
+ yaml_filepath.unlink()
486
+ filepath1.unlink()
487
+ filepath2.unlink()
488
+ grid_filepath1.unlink()
489
+ grid_filepath2.unlink()
@@ -61,7 +61,11 @@ def compare_dictionaries(dict1, dict2):
61
61
 
62
62
  @pytest.mark.parametrize(
63
63
  "river_forcing_fixture",
64
- ["river_forcing", "river_forcing_for_grid_that_straddles_dateline"],
64
+ [
65
+ "river_forcing",
66
+ "river_forcing_for_grid_that_straddles_dateline",
67
+ "river_forcing_with_bgc",
68
+ ],
65
69
  )
66
70
  def test_successful_initialization_with_climatological_dai_data(
67
71
  river_forcing_fixture, request
@@ -121,7 +125,11 @@ def test_reproducibility_indices(river_forcing, river_forcing_no_climatology):
121
125
 
122
126
  @pytest.mark.parametrize(
123
127
  "river_forcing_fixture",
124
- ["river_forcing_climatology", "river_forcing_no_climatology"],
128
+ [
129
+ "river_forcing_climatology",
130
+ "river_forcing_no_climatology",
131
+ "river_forcing_with_bgc",
132
+ ],
125
133
  )
126
134
  def test_constant_tracers(river_forcing_fixture, request):
127
135
  river_forcing = request.getfixturevalue(river_forcing_fixture)
@@ -132,11 +140,18 @@ def test_constant_tracers(river_forcing_fixture, request):
132
140
  np.testing.assert_allclose(
133
141
  river_forcing.ds.river_tracer.isel(ntracers=1).values, 1.0, atol=0
134
142
  )
143
+ np.testing.assert_allclose(
144
+ river_forcing.ds.river_tracer.isel(ntracers=slice(2, None)).values, 0.0, atol=0
145
+ )
135
146
 
136
147
 
137
148
  @pytest.mark.parametrize(
138
149
  "river_forcing_fixture",
139
- ["river_forcing_climatology", "river_forcing_no_climatology"],
150
+ [
151
+ "river_forcing_climatology",
152
+ "river_forcing_no_climatology",
153
+ "river_forcing_with_bgc",
154
+ ],
140
155
  )
141
156
  def test_river_locations_are_along_coast(river_forcing_fixture, request):
142
157
  river_forcing = request.getfixturevalue(river_forcing_fixture)
@@ -228,16 +243,18 @@ def test_update_river_flux_variable_without_conflicts(river_forcing, tmp_path):
228
243
  assert isinstance(another_river_forcing.ds, xr.Dataset)
229
244
 
230
245
 
231
- def test_river_forcing_plot(river_forcing):
246
+ def test_river_forcing_plot(river_forcing_with_bgc):
232
247
  """Test plot method."""
233
248
 
234
- river_forcing.plot_locations()
235
- river_forcing.plot("river_volume")
236
- river_forcing.plot("river_temperature")
237
- river_forcing.plot("river_salinity")
249
+ river_forcing_with_bgc.plot_locations()
250
+ river_forcing_with_bgc.plot("river_volume")
251
+ river_forcing_with_bgc.plot("river_temp")
252
+ river_forcing_with_bgc.plot("river_salt")
253
+ river_forcing_with_bgc.plot("river_ALK")
254
+ river_forcing_with_bgc.plot("river_PO4")
238
255
 
239
256
 
240
- def test_river_forcing_save(river_forcing, tmp_path):
257
+ def test_river_forcing_save(river_forcing_with_bgc, tmp_path):
241
258
  """Test save method."""
242
259
 
243
260
  for file_str, grid_file_str in zip(
@@ -250,7 +267,7 @@ def test_river_forcing_save(river_forcing, tmp_path):
250
267
  ): # test for Path object and str
251
268
 
252
269
  # Test saving without partitioning
253
- saved_filenames = river_forcing.save(filepath, grid_filepath)
270
+ saved_filenames = river_forcing_with_bgc.save(filepath, grid_filepath)
254
271
  # Check if the .nc file was created
255
272
  filepath = Path(filepath).with_suffix(".nc")
256
273
  grid_filepath = Path(grid_filepath).with_suffix(".nc")
@@ -262,7 +279,7 @@ def test_river_forcing_save(river_forcing, tmp_path):
262
279
  grid_filepath.unlink()
263
280
 
264
281
  # Test saving with partitioning
265
- saved_filenames = river_forcing.save(
282
+ saved_filenames = river_forcing_with_bgc.save(
266
283
  filepath, grid_filepath, np_eta=3, np_xi=3
267
284
  )
268
285
 
@@ -277,10 +294,20 @@ def test_river_forcing_save(river_forcing, tmp_path):
277
294
  expected_filepath.unlink()
278
295
 
279
296
 
280
- def test_roundtrip_yaml(river_forcing, tmp_path):
297
+ @pytest.mark.parametrize(
298
+ "river_forcing_fixture",
299
+ [
300
+ "river_forcing_climatology",
301
+ "river_forcing_no_climatology",
302
+ "river_forcing_with_bgc",
303
+ ],
304
+ )
305
+ def test_roundtrip_yaml(river_forcing_fixture, request, tmp_path):
281
306
  """Test that creating an RiverForcing object, saving its parameters to yaml file,
282
307
  and re-opening yaml file creates the same object."""
283
308
 
309
+ river_forcing = request.getfixturevalue(river_forcing_fixture)
310
+
284
311
  # Create a temporary filepath using the tmp_path fixture
285
312
  file_str = "test_yaml"
286
313
  for filepath in [
@@ -298,7 +325,17 @@ def test_roundtrip_yaml(river_forcing, tmp_path):
298
325
  filepath.unlink()
299
326
 
300
327
 
301
- def test_files_have_same_hash(river_forcing, tmp_path):
328
+ @pytest.mark.parametrize(
329
+ "river_forcing_fixture",
330
+ [
331
+ "river_forcing_climatology",
332
+ "river_forcing_no_climatology",
333
+ "river_forcing_with_bgc",
334
+ ],
335
+ )
336
+ def test_files_have_same_hash(river_forcing_fixture, request, tmp_path):
337
+
338
+ river_forcing = request.getfixturevalue(river_forcing_fixture)
302
339
 
303
340
  yaml_filepath = tmp_path / "test_yaml.yaml"
304
341
  filepath1 = tmp_path / "test1.nc"
@@ -185,6 +185,7 @@ def test_successful_initialization_with_regional_data(grid_fixture, request, use
185
185
  start_time=start_time,
186
186
  end_time=end_time,
187
187
  source={"name": "ERA5", "path": fname},
188
+ correct_radiation=True,
188
189
  use_dask=use_dask,
189
190
  )
190
191
 
@@ -23,7 +23,7 @@ def _get_fname(name):
23
23
  "bgc_surface_forcing_from_climatology",
24
24
  "boundary_forcing",
25
25
  "bgc_boundary_forcing_from_climatology",
26
- "river_forcing",
26
+ "river_forcing_with_bgc",
27
27
  "river_forcing_no_climatology",
28
28
  ],
29
29
  )
@@ -62,7 +62,7 @@ def test_save_results(forcing_fixture, request):
62
62
  "bgc_surface_forcing_from_climatology",
63
63
  "boundary_forcing",
64
64
  "bgc_boundary_forcing_from_climatology",
65
- "river_forcing",
65
+ "river_forcing_with_bgc",
66
66
  "river_forcing_no_climatology",
67
67
  ],
68
68
  )