roms-tools 2.5.0__py3-none-any.whl → 2.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. ci/environment-with-xesmf.yml +16 -0
  2. roms_tools/analysis/roms_output.py +521 -187
  3. roms_tools/analysis/utils.py +169 -0
  4. roms_tools/plot.py +351 -214
  5. roms_tools/regrid.py +161 -9
  6. roms_tools/setup/boundary_forcing.py +22 -22
  7. roms_tools/setup/datasets.py +40 -44
  8. roms_tools/setup/grid.py +28 -28
  9. roms_tools/setup/initial_conditions.py +23 -31
  10. roms_tools/setup/nesting.py +3 -3
  11. roms_tools/setup/river_forcing.py +22 -23
  12. roms_tools/setup/surface_forcing.py +14 -13
  13. roms_tools/setup/tides.py +7 -7
  14. roms_tools/setup/topography.py +2 -2
  15. roms_tools/tests/test_analysis/test_roms_output.py +299 -188
  16. roms_tools/tests/test_regrid.py +85 -2
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -2
  18. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +2 -2
  19. roms_tools/tests/test_setup/test_river_forcing.py +47 -51
  20. roms_tools/tests/test_vertical_coordinate.py +73 -0
  21. roms_tools/utils.py +11 -7
  22. roms_tools/vertical_coordinate.py +7 -0
  23. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/METADATA +22 -11
  24. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/RECORD +33 -30
  25. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/WHEEL +1 -1
  26. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zarray +0 -0
  27. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zattrs +0 -0
  28. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/0.0 +0 -0
  29. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zarray +0 -0
  30. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zattrs +0 -0
  31. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/0.0 +0 -0
  32. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info/licenses}/LICENSE +0 -0
  33. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,18 @@
1
1
  import pytest
2
2
  from pathlib import Path
3
3
  import xarray as xr
4
+ import numpy as np
4
5
  import os
5
6
  import logging
6
7
  from datetime import datetime
7
8
  from roms_tools import Grid, ROMSOutput
8
9
  from roms_tools.download import download_test_data
9
10
 
11
+ try:
12
+ import xesmf # type: ignore
13
+ except ImportError:
14
+ xesmf = None
15
+
10
16
 
11
17
  @pytest.fixture
12
18
  def roms_output_from_restart_file(use_dask):
@@ -22,9 +28,32 @@ def roms_output_from_restart_file(use_dask):
22
28
  )
23
29
 
24
30
 
25
- def test_load_model_output_file(roms_output_from_restart_file, use_dask):
31
+ @pytest.fixture
32
+ def roms_output_from_restart_file_adjusted_for_zeta(use_dask):
33
+
34
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
35
+ grid = Grid.from_file(fname_grid)
36
+
37
+ # Single file
38
+ return ROMSOutput(
39
+ grid=grid,
40
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
41
+ adjust_depth_for_sea_surface_height=True,
42
+ use_dask=use_dask,
43
+ )
44
+
26
45
 
27
- assert isinstance(roms_output_from_restart_file.ds, xr.Dataset)
46
+ @pytest.mark.parametrize(
47
+ "roms_output_fixture",
48
+ [
49
+ "roms_output_from_restart_file",
50
+ "roms_output_from_restart_file_adjusted_for_zeta",
51
+ ],
52
+ )
53
+ def test_load_model_output_file(roms_output_fixture, use_dask, request):
54
+ roms_output = request.getfixturevalue(roms_output_fixture)
55
+
56
+ assert isinstance(roms_output.ds, xr.Dataset)
28
57
 
29
58
 
30
59
  def test_load_model_output_file_list(use_dask):
@@ -235,198 +264,280 @@ def test_that_coordinates_are_added(use_dask):
235
264
  assert "lon_rho" in output.ds.coords
236
265
 
237
266
 
238
- def test_plot(roms_output_from_restart_file, use_dask):
267
+ @pytest.mark.parametrize(
268
+ "roms_output_fixture",
269
+ [
270
+ "roms_output_from_restart_file",
271
+ "roms_output_from_restart_file_adjusted_for_zeta",
272
+ ],
273
+ )
274
+ def test_plot_on_native_model_grid(roms_output_fixture, use_dask, request):
275
+ roms_output = request.getfixturevalue(roms_output_fixture)
239
276
 
240
- kwargs = {}
241
- for var_name in ["temp", "u", "v"]:
242
- for include_boundary in [False, True]:
243
- roms_output_from_restart_file.plot(
244
- var_name, time=0, s=-1, **kwargs, include_boundary=include_boundary
245
- )
246
- roms_output_from_restart_file.plot(
247
- var_name, time=0, eta=0, **kwargs, include_boundary=include_boundary
248
- )
249
- roms_output_from_restart_file.plot(
250
- var_name, time=0, eta=1, **kwargs, include_boundary=include_boundary
251
- )
252
- roms_output_from_restart_file.plot(
253
- var_name, time=0, xi=0, **kwargs, include_boundary=include_boundary
254
- )
255
- roms_output_from_restart_file.plot(
256
- var_name, time=0, xi=1, **kwargs, include_boundary=include_boundary
257
- )
258
- roms_output_from_restart_file.plot(
259
- var_name,
260
- time=0,
261
- eta=0,
262
- xi=0,
263
- **kwargs,
264
- include_boundary=include_boundary
265
- )
266
- roms_output_from_restart_file.plot(
267
- var_name,
268
- time=0,
269
- eta=0,
270
- xi=1,
271
- **kwargs,
272
- include_boundary=include_boundary
273
- )
274
- roms_output_from_restart_file.plot(
275
- var_name,
276
- time=0,
277
- eta=1,
278
- xi=0,
279
- **kwargs,
280
- include_boundary=include_boundary
281
- )
282
- roms_output_from_restart_file.plot(
283
- var_name,
284
- time=0,
285
- eta=1,
286
- xi=1,
287
- **kwargs,
288
- include_boundary=include_boundary
289
- )
290
- roms_output_from_restart_file.plot(
291
- var_name,
292
- time=0,
293
- s=-1,
294
- eta=0,
295
- **kwargs,
296
- include_boundary=include_boundary
297
- )
298
- roms_output_from_restart_file.plot(
299
- var_name,
300
- time=0,
301
- s=-1,
302
- eta=1,
303
- **kwargs,
304
- include_boundary=include_boundary
305
- )
306
- roms_output_from_restart_file.plot(
307
- var_name,
308
- time=0,
309
- s=-1,
310
- xi=0,
311
- **kwargs,
312
- include_boundary=include_boundary
313
- )
314
- roms_output_from_restart_file.plot(
315
- var_name,
316
- time=0,
317
- s=-1,
318
- xi=1,
319
- **kwargs,
320
- include_boundary=include_boundary
321
- )
322
-
323
- kwargs = {"depth_contours": True, "layer_contours": True}
324
- for var_name in ["temp", "u", "v"]:
325
- for include_boundary in [False, True]:
326
- roms_output_from_restart_file.plot(
327
- var_name, time=0, s=-1, **kwargs, include_boundary=include_boundary
328
- )
329
- roms_output_from_restart_file.plot(
330
- var_name, time=0, eta=0, **kwargs, include_boundary=include_boundary
331
- )
332
- roms_output_from_restart_file.plot(
333
- var_name, time=0, eta=1, **kwargs, include_boundary=include_boundary
334
- )
335
- roms_output_from_restart_file.plot(
336
- var_name, time=0, xi=0, **kwargs, include_boundary=include_boundary
337
- )
338
- roms_output_from_restart_file.plot(
339
- var_name, time=0, xi=1, **kwargs, include_boundary=include_boundary
340
- )
341
- roms_output_from_restart_file.plot(
342
- var_name,
343
- time=0,
344
- eta=0,
345
- xi=0,
346
- **kwargs,
347
- include_boundary=include_boundary
348
- )
349
- roms_output_from_restart_file.plot(
350
- var_name,
351
- time=0,
352
- eta=0,
353
- xi=1,
354
- **kwargs,
355
- include_boundary=include_boundary
356
- )
357
- roms_output_from_restart_file.plot(
358
- var_name,
359
- time=0,
360
- eta=1,
361
- xi=0,
362
- **kwargs,
363
- include_boundary=include_boundary
364
- )
365
- roms_output_from_restart_file.plot(
366
- var_name,
367
- time=0,
368
- eta=1,
369
- xi=1,
370
- **kwargs,
371
- include_boundary=include_boundary
372
- )
373
- roms_output_from_restart_file.plot(
374
- var_name,
375
- time=0,
376
- s=-1,
377
- eta=0,
378
- **kwargs,
379
- include_boundary=include_boundary
380
- )
381
- roms_output_from_restart_file.plot(
382
- var_name,
383
- time=0,
384
- s=-1,
385
- eta=1,
386
- **kwargs,
387
- include_boundary=include_boundary
388
- )
389
- roms_output_from_restart_file.plot(
390
- var_name,
391
- time=0,
392
- s=-1,
393
- xi=0,
394
- **kwargs,
395
- include_boundary=include_boundary
396
- )
397
- roms_output_from_restart_file.plot(
398
- var_name,
399
- time=0,
400
- s=-1,
401
- xi=1,
402
- **kwargs,
403
- include_boundary=include_boundary
404
- )
277
+ for include_boundary in [False, True]:
278
+ for depth_contours in [False, True]:
279
+
280
+ # 3D fields
281
+ for var_name in ["temp", "u", "v"]:
282
+ kwargs = {
283
+ "include_boundary": include_boundary,
284
+ "depth_contours": depth_contours,
285
+ }
286
+
287
+ roms_output.plot(var_name, time=1, s=-1, **kwargs)
288
+ roms_output.plot(var_name, time=1, depth=1000, **kwargs)
289
+
290
+ roms_output.plot(var_name, time=1, eta=1, **kwargs)
291
+ roms_output.plot(var_name, time=1, xi=1, **kwargs)
292
+
293
+ roms_output.plot(
294
+ var_name,
295
+ time=1,
296
+ eta=1,
297
+ xi=1,
298
+ **kwargs,
299
+ )
300
+
301
+ roms_output.plot(
302
+ var_name,
303
+ time=1,
304
+ s=-1,
305
+ eta=1,
306
+ **kwargs,
307
+ )
308
+ roms_output.plot(
309
+ var_name,
310
+ time=1,
311
+ depth=1000,
312
+ eta=1,
313
+ **kwargs,
314
+ )
315
+
316
+ roms_output.plot(
317
+ var_name,
318
+ time=1,
319
+ s=-1,
320
+ xi=1,
321
+ **kwargs,
322
+ )
323
+ roms_output.plot(
324
+ var_name,
325
+ time=1,
326
+ depth=1000,
327
+ xi=1,
328
+ **kwargs,
329
+ )
330
+
331
+ # 2D fields
332
+ roms_output.plot("zeta", time=1, **kwargs)
333
+ roms_output.plot("zeta", time=1, eta=1, **kwargs)
334
+ roms_output.plot("zeta", time=1, xi=1, **kwargs)
335
+
336
+
337
+ @pytest.mark.parametrize(
338
+ "roms_output_fixture",
339
+ [
340
+ "roms_output_from_restart_file",
341
+ "roms_output_from_restart_file_adjusted_for_zeta",
342
+ ],
343
+ )
344
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
345
+ def test_plot_on_lat_lon(roms_output_fixture, use_dask, request):
346
+ roms_output = request.getfixturevalue(roms_output_fixture)
405
347
 
406
348
  for include_boundary in [False, True]:
407
- roms_output_from_restart_file.plot(
408
- "zeta", time=0, **kwargs, include_boundary=include_boundary
409
- )
410
- roms_output_from_restart_file.plot(
411
- "zeta", time=0, eta=0, **kwargs, include_boundary=include_boundary
412
- )
413
- roms_output_from_restart_file.plot(
414
- "zeta", time=0, eta=1, **kwargs, include_boundary=include_boundary
415
- )
416
- roms_output_from_restart_file.plot(
417
- "zeta", time=0, xi=0, **kwargs, include_boundary=include_boundary
418
- )
419
- roms_output_from_restart_file.plot(
420
- "zeta", time=0, xi=1, **kwargs, include_boundary=include_boundary
421
- )
349
+ for depth_contours in [False, True]:
350
+
351
+ # 3D fields
352
+ for var_name in ["temp", "u", "v"]:
353
+ kwargs = {
354
+ "include_boundary": include_boundary,
355
+ "depth_contours": depth_contours,
356
+ }
357
+ roms_output.plot(
358
+ var_name,
359
+ time=1,
360
+ lat=9,
361
+ lon=-128,
362
+ **kwargs,
363
+ )
364
+ roms_output.plot(
365
+ var_name,
366
+ time=1,
367
+ lat=9,
368
+ **kwargs,
369
+ )
370
+ roms_output.plot(
371
+ var_name,
372
+ time=1,
373
+ lat=9,
374
+ s=-1,
375
+ **kwargs,
376
+ )
377
+ roms_output.plot(
378
+ var_name,
379
+ time=1,
380
+ lat=9,
381
+ depth=1000,
382
+ **kwargs,
383
+ )
384
+ roms_output.plot(
385
+ var_name,
386
+ time=1,
387
+ lon=-128,
388
+ **kwargs,
389
+ )
390
+ roms_output.plot(
391
+ var_name,
392
+ time=1,
393
+ lon=-128,
394
+ s=-1,
395
+ **kwargs,
396
+ )
397
+ roms_output.plot(
398
+ var_name,
399
+ time=1,
400
+ lon=-128,
401
+ depth=1000,
402
+ **kwargs,
403
+ )
404
+
405
+ # 2D fields
406
+ roms_output.plot("zeta", time=1, lat=9, **kwargs)
407
+ roms_output.plot("zeta", time=1, lon=-128, **kwargs)
422
408
 
423
409
 
424
410
  def test_plot_errors(roms_output_from_restart_file, use_dask):
411
+ """Test error conditions for the ROMSOutput.plot() method."""
412
+
413
+ # Invalid time index
425
414
  with pytest.raises(ValueError, match="Invalid time index"):
426
415
  roms_output_from_restart_file.plot("temp", time=10, s=-1)
427
- with pytest.raises(ValueError, match="Invalid input"):
428
- roms_output_from_restart_file.plot("temp", time=0)
416
+
417
+ with pytest.raises(
418
+ ValueError,
419
+ match="Conflicting input: You cannot specify both 's' and 'depth' at the same time.",
420
+ ):
421
+ roms_output_from_restart_file.plot("temp", time=0, s=-1, depth=10)
422
+
423
+ # Ambiguous input: Too many dimensions specified for 3D fields
429
424
  with pytest.raises(ValueError, match="Ambiguous input"):
430
- roms_output_from_restart_file.plot("temp", time=0, s=-1, eta=0, xi=0)
431
- with pytest.raises(ValueError, match="Conflicting input"):
432
- roms_output_from_restart_file.plot("zeta", time=0, eta=0, xi=0)
425
+ roms_output_from_restart_file.plot("temp", time=1, s=-1, eta=0, xi=0)
426
+
427
+ # Vertical dimension specified for 2D fields
428
+ with pytest.raises(
429
+ ValueError, match="Vertical dimension 's' should be None for 2D fields"
430
+ ):
431
+ roms_output_from_restart_file.plot("zeta", time=1, s=-1)
432
+ with pytest.raises(
433
+ ValueError, match="Vertical dimension 'depth' should be None for 2D fields"
434
+ ):
435
+ roms_output_from_restart_file.plot("zeta", time=1, depth=100)
436
+
437
+ # Conflicting input: Both eta and xi specified for 2D fields
438
+ with pytest.raises(
439
+ ValueError,
440
+ match="Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both.",
441
+ ):
442
+ roms_output_from_restart_file.plot("zeta", time=1, eta=0, xi=0)
443
+ # Conflicting input: Both lat and lon specified for 2D fields
444
+ with pytest.raises(
445
+ ValueError,
446
+ match="Conflicting input: For 2D fields, specify only one dimension, either 'lat' or 'lon', not both.",
447
+ ):
448
+ roms_output_from_restart_file.plot("zeta", time=1, lat=0, lon=0)
449
+
450
+ # Conflicting input: lat or lon provided with eta or xi
451
+ with pytest.raises(
452
+ ValueError,
453
+ match="Conflicting input: You cannot specify 'lat' or 'lon' simultaneously with 'eta' or 'xi'.",
454
+ ):
455
+ roms_output_from_restart_file.plot("temp", time=1, lat=10, lon=20, eta=5)
456
+
457
+ # Invalid eta index out of bounds
458
+ with pytest.raises(ValueError, match="Invalid eta index"):
459
+ roms_output_from_restart_file.plot("temp", time=1, eta=999)
460
+
461
+ # Invalid xi index out of bounds
462
+ with pytest.raises(ValueError, match="Invalid eta index"):
463
+ roms_output_from_restart_file.plot("temp", time=1, xi=999)
464
+
465
+ # Boundary exclusion error for eta
466
+ with pytest.raises(ValueError, match="Invalid eta index.*boundary.*excluded"):
467
+ roms_output_from_restart_file.plot(
468
+ "temp", time=1, eta=0, include_boundary=False
469
+ )
470
+
471
+ # Boundary exclusion error for xi
472
+ with pytest.raises(ValueError, match="Invalid xi index.*boundary.*excluded"):
473
+ roms_output_from_restart_file.plot("temp", time=1, xi=0, include_boundary=False)
474
+
475
+ # No dimension specified for 3D field
476
+ with pytest.raises(
477
+ ValueError,
478
+ match="Invalid input: For 3D fields, you must specify at least one of the dimensions",
479
+ ):
480
+ roms_output_from_restart_file.plot("temp", time=1)
481
+
482
+
483
+ def test_figure_gets_saved(roms_output_from_restart_file, tmp_path):
484
+
485
+ filename = tmp_path / "figure.png"
486
+ roms_output_from_restart_file.plot("temp", time=0, depth=1000, save_path=filename)
487
+
488
+ assert filename.exists()
489
+ filename.unlink()
490
+
491
+
492
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
493
+ def test_regrid_all_variables(roms_output_from_restart_file):
494
+ ds_regridded = roms_output_from_restart_file.regrid()
495
+ assert isinstance(ds_regridded, xr.Dataset)
496
+ assert set(ds_regridded.data_vars).issubset(
497
+ set(roms_output_from_restart_file.ds.data_vars)
498
+ )
499
+ assert "lon" in ds_regridded.coords
500
+ assert "lat" in ds_regridded.coords
501
+ assert "depth" in ds_regridded.coords
502
+ assert "time" in ds_regridded.coords
503
+
504
+
505
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
506
+ def test_regrid_specific_variables(roms_output_from_restart_file):
507
+ var_names = ["temp", "salt"]
508
+ ds_regridded = roms_output_from_restart_file.regrid(var_names=var_names)
509
+ assert isinstance(ds_regridded, xr.Dataset)
510
+ assert set(ds_regridded.data_vars) == set(var_names)
511
+
512
+ ds = roms_output_from_restart_file.regrid(var_names=[])
513
+ assert ds is None
514
+
515
+
516
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
517
+ def test_regrid_missing_variable_raises_error(roms_output_from_restart_file):
518
+ with pytest.raises(
519
+ ValueError, match="The following variables are not found in the dataset"
520
+ ):
521
+ roms_output_from_restart_file.regrid(var_names=["fake_variable"])
522
+
523
+
524
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
525
+ def test_regrid_with_custom_horizontal_resolution(roms_output_from_restart_file):
526
+ ds_regridded = roms_output_from_restart_file.regrid(horizontal_resolution=0.1)
527
+ assert isinstance(ds_regridded, xr.Dataset)
528
+ assert "lon" in ds_regridded.coords
529
+ assert "lat" in ds_regridded.coords
530
+
531
+ assert np.allclose(ds_regridded.lon.diff(dim="lon"), 0.1, atol=1e-4)
532
+ assert np.allclose(ds_regridded.lat.diff(dim="lat"), 0.1, atol=1e-4)
533
+
534
+
535
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
536
+ def test_regrid_with_custom_depth_levels(roms_output_from_restart_file):
537
+ depth_levels = xr.DataArray(
538
+ np.linspace(0, 500, 51), dims=["depth"], attrs={"units": "m"}
539
+ )
540
+ ds_regridded = roms_output_from_restart_file.regrid(depth_levels=depth_levels)
541
+ assert isinstance(ds_regridded, xr.Dataset)
542
+ assert "depth" in ds_regridded.coords
543
+ np.allclose(ds_regridded.depth, depth_levels, atol=0.0)
@@ -1,9 +1,92 @@
1
1
  import pytest
2
2
  import numpy as np
3
3
  import xarray as xr
4
- from roms_tools.regrid import VerticalRegrid
4
+ from roms_tools.regrid import VerticalRegridToROMS
5
5
 
6
+ try:
7
+ import xesmf # type: ignore
8
+ except ImportError:
9
+ xesmf = None
6
10
 
11
+ from roms_tools.regrid import LateralRegridFromROMS
12
+
13
+ # Lateral regridding
14
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
15
+ def test_lateral_regrid_with_curvilinear_grid():
16
+ """Test that LateralRegridFromROMS regrids data correctly from a curvilinear ROMS
17
+ grid."""
18
+
19
+ # Define ROMS curvilinear grid dimensions
20
+ eta_rho, xi_rho = 10, 20
21
+
22
+ # Create a mock ROMS grid with curvilinear coordinates
23
+ lat_rho = np.linspace(-10, 10, eta_rho).reshape(-1, 1) * np.ones((1, xi_rho))
24
+ lon_rho = np.linspace(120, 140, xi_rho).reshape(1, -1) * np.ones((eta_rho, 1))
25
+
26
+ ds_in = xr.Dataset(
27
+ {
28
+ "temp": (("eta_rho", "xi_rho"), np.random.rand(eta_rho, xi_rho)),
29
+ },
30
+ coords={
31
+ "lat": (("eta_rho", "xi_rho"), lat_rho),
32
+ "lon": (("eta_rho", "xi_rho"), lon_rho),
33
+ },
34
+ )
35
+
36
+ # Define target latitude and longitude coordinates
37
+ target_coords = {
38
+ "lat": np.linspace(-5, 5, 5),
39
+ "lon": np.linspace(125, 135, 10),
40
+ }
41
+
42
+ # Instantiate the regridder
43
+ regridder = LateralRegridFromROMS(ds_in, target_coords, method="bilinear")
44
+
45
+ # Apply the regridding to the input data
46
+ regridded_da = regridder.apply(ds_in["temp"])
47
+
48
+ # Assertions to verify that the output is as expected
49
+ assert isinstance(regridded_da, xr.DataArray)
50
+ assert regridded_da.shape == (5, 10)
51
+ assert np.allclose(regridded_da.coords["lat"], target_coords["lat"])
52
+ assert np.allclose(regridded_da.coords["lon"], target_coords["lon"])
53
+
54
+
55
+ @pytest.mark.skipif(xesmf is not None, reason="xesmf has to be missing")
56
+ def test_lateral_regrid_import_error():
57
+ """Test that LateralRegridFromROMS raises ImportError when xesmf is missing."""
58
+
59
+ # Define mock ROMS curvilinear grid dimensions
60
+ eta_rho, xi_rho = 10, 20
61
+
62
+ # Create a mock ROMS grid with curvilinear coordinates
63
+ lat_rho = np.linspace(-10, 10, eta_rho).reshape(-1, 1) * np.ones((1, xi_rho))
64
+ lon_rho = np.linspace(120, 140, xi_rho).reshape(1, -1) * np.ones((eta_rho, 1))
65
+
66
+ ds_in = xr.Dataset(
67
+ {
68
+ "temp": (("eta_rho", "xi_rho"), np.random.rand(eta_rho, xi_rho)),
69
+ },
70
+ coords={
71
+ "lat": (("eta_rho", "xi_rho"), lat_rho),
72
+ "lon": (("eta_rho", "xi_rho"), lon_rho),
73
+ },
74
+ )
75
+
76
+ # Define target latitude and longitude coordinates
77
+ target_coords = {
78
+ "lat": np.linspace(-5, 5, 5),
79
+ "lon": np.linspace(125, 135, 10),
80
+ }
81
+
82
+ # Check that ImportError is raised when xesmf is missing
83
+ with pytest.raises(
84
+ ImportError, match="xesmf is required for this regridding task.*"
85
+ ):
86
+ LateralRegridFromROMS(ds_in, target_coords, method="bilinear")
87
+
88
+
89
+ # Vertical regridding
7
90
  def vertical_regridder(depth_values, layer_depth_rho_values):
8
91
  class DataContainer:
9
92
  """Mock class for holding data and dimension names."""
@@ -21,7 +104,7 @@ def vertical_regridder(depth_values, layer_depth_rho_values):
21
104
  target_depth = xr.DataArray(data=layer_depth_rho_values, dims=["s_rho"])
22
105
  source_depth = xr.DataArray(data=depth_values, dims=["depth"])
23
106
 
24
- return VerticalRegrid(target_depth, source_depth)
107
+ return VerticalRegridToROMS(target_depth, source_depth)
25
108
 
26
109
 
27
110
  @pytest.mark.parametrize(
@@ -58,7 +58,7 @@
58
58
  ],
59
59
  "long_name": "River ID (1-based Fortran indexing)"
60
60
  },
61
- "river_location/.zarray": {
61
+ "river_flux/.zarray": {
62
62
  "chunks": [
63
63
  20,
64
64
  20
@@ -80,7 +80,7 @@
80
80
  ],
81
81
  "zarr_format": 2
82
82
  },
83
- "river_location/.zattrs": {
83
+ "river_flux/.zattrs": {
84
84
  "_ARRAY_DIMENSIONS": [
85
85
  "eta_rho",
86
86
  "xi_rho"
@@ -86,7 +86,7 @@
86
86
  ],
87
87
  "long_name": "River ID (1-based Fortran indexing)"
88
88
  },
89
- "river_location/.zarray": {
89
+ "river_flux/.zarray": {
90
90
  "chunks": [
91
91
  20,
92
92
  20
@@ -108,7 +108,7 @@
108
108
  ],
109
109
  "zarr_format": 2
110
110
  },
111
- "river_location/.zattrs": {
111
+ "river_flux/.zattrs": {
112
112
  "_ARRAY_DIMENSIONS": [
113
113
  "eta_rho",
114
114
  "xi_rho"