roms-tools 2.5.0__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,6 +7,11 @@ from datetime import datetime
7
7
  from roms_tools import Grid, ROMSOutput
8
8
  from roms_tools.download import download_test_data
9
9
 
10
+ try:
11
+ import xesmf # type: ignore
12
+ except ImportError:
13
+ xesmf = None
14
+
10
15
 
11
16
  @pytest.fixture
12
17
  def roms_output_from_restart_file(use_dask):
@@ -22,9 +27,32 @@ def roms_output_from_restart_file(use_dask):
22
27
  )
23
28
 
24
29
 
25
- def test_load_model_output_file(roms_output_from_restart_file, use_dask):
30
+ @pytest.fixture
31
+ def roms_output_from_restart_file_adjusted_for_zeta(use_dask):
32
+
33
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
34
+ grid = Grid.from_file(fname_grid)
26
35
 
27
- assert isinstance(roms_output_from_restart_file.ds, xr.Dataset)
36
+ # Single file
37
+ return ROMSOutput(
38
+ grid=grid,
39
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
40
+ adjust_depth_for_sea_surface_height=True,
41
+ use_dask=use_dask,
42
+ )
43
+
44
+
45
+ @pytest.mark.parametrize(
46
+ "roms_output_fixture",
47
+ [
48
+ "roms_output_from_restart_file",
49
+ "roms_output_from_restart_file_adjusted_for_zeta",
50
+ ],
51
+ )
52
+ def test_load_model_output_file(roms_output_fixture, use_dask, request):
53
+ roms_output = request.getfixturevalue(roms_output_fixture)
54
+
55
+ assert isinstance(roms_output.ds, xr.Dataset)
28
56
 
29
57
 
30
58
  def test_load_model_output_file_list(use_dask):
@@ -235,198 +263,226 @@ def test_that_coordinates_are_added(use_dask):
235
263
  assert "lon_rho" in output.ds.coords
236
264
 
237
265
 
238
- def test_plot(roms_output_from_restart_file, use_dask):
266
+ @pytest.mark.parametrize(
267
+ "roms_output_fixture",
268
+ [
269
+ "roms_output_from_restart_file",
270
+ "roms_output_from_restart_file_adjusted_for_zeta",
271
+ ],
272
+ )
273
+ def test_plot_on_native_model_grid(roms_output_fixture, use_dask, request):
274
+ roms_output = request.getfixturevalue(roms_output_fixture)
239
275
 
240
- kwargs = {}
241
- for var_name in ["temp", "u", "v"]:
242
- for include_boundary in [False, True]:
243
- roms_output_from_restart_file.plot(
244
- var_name, time=0, s=-1, **kwargs, include_boundary=include_boundary
245
- )
246
- roms_output_from_restart_file.plot(
247
- var_name, time=0, eta=0, **kwargs, include_boundary=include_boundary
248
- )
249
- roms_output_from_restart_file.plot(
250
- var_name, time=0, eta=1, **kwargs, include_boundary=include_boundary
251
- )
252
- roms_output_from_restart_file.plot(
253
- var_name, time=0, xi=0, **kwargs, include_boundary=include_boundary
254
- )
255
- roms_output_from_restart_file.plot(
256
- var_name, time=0, xi=1, **kwargs, include_boundary=include_boundary
257
- )
258
- roms_output_from_restart_file.plot(
259
- var_name,
260
- time=0,
261
- eta=0,
262
- xi=0,
263
- **kwargs,
264
- include_boundary=include_boundary
265
- )
266
- roms_output_from_restart_file.plot(
267
- var_name,
268
- time=0,
269
- eta=0,
270
- xi=1,
271
- **kwargs,
272
- include_boundary=include_boundary
273
- )
274
- roms_output_from_restart_file.plot(
275
- var_name,
276
- time=0,
277
- eta=1,
278
- xi=0,
279
- **kwargs,
280
- include_boundary=include_boundary
281
- )
282
- roms_output_from_restart_file.plot(
283
- var_name,
284
- time=0,
285
- eta=1,
286
- xi=1,
287
- **kwargs,
288
- include_boundary=include_boundary
289
- )
290
- roms_output_from_restart_file.plot(
291
- var_name,
292
- time=0,
293
- s=-1,
294
- eta=0,
295
- **kwargs,
296
- include_boundary=include_boundary
297
- )
298
- roms_output_from_restart_file.plot(
299
- var_name,
300
- time=0,
301
- s=-1,
302
- eta=1,
303
- **kwargs,
304
- include_boundary=include_boundary
305
- )
306
- roms_output_from_restart_file.plot(
307
- var_name,
308
- time=0,
309
- s=-1,
310
- xi=0,
311
- **kwargs,
312
- include_boundary=include_boundary
313
- )
314
- roms_output_from_restart_file.plot(
315
- var_name,
316
- time=0,
317
- s=-1,
318
- xi=1,
319
- **kwargs,
320
- include_boundary=include_boundary
321
- )
322
-
323
- kwargs = {"depth_contours": True, "layer_contours": True}
324
- for var_name in ["temp", "u", "v"]:
325
- for include_boundary in [False, True]:
326
- roms_output_from_restart_file.plot(
327
- var_name, time=0, s=-1, **kwargs, include_boundary=include_boundary
328
- )
329
- roms_output_from_restart_file.plot(
330
- var_name, time=0, eta=0, **kwargs, include_boundary=include_boundary
331
- )
332
- roms_output_from_restart_file.plot(
333
- var_name, time=0, eta=1, **kwargs, include_boundary=include_boundary
334
- )
335
- roms_output_from_restart_file.plot(
336
- var_name, time=0, xi=0, **kwargs, include_boundary=include_boundary
337
- )
338
- roms_output_from_restart_file.plot(
339
- var_name, time=0, xi=1, **kwargs, include_boundary=include_boundary
340
- )
341
- roms_output_from_restart_file.plot(
342
- var_name,
343
- time=0,
344
- eta=0,
345
- xi=0,
346
- **kwargs,
347
- include_boundary=include_boundary
348
- )
349
- roms_output_from_restart_file.plot(
350
- var_name,
351
- time=0,
352
- eta=0,
353
- xi=1,
354
- **kwargs,
355
- include_boundary=include_boundary
356
- )
357
- roms_output_from_restart_file.plot(
358
- var_name,
359
- time=0,
360
- eta=1,
361
- xi=0,
362
- **kwargs,
363
- include_boundary=include_boundary
364
- )
365
- roms_output_from_restart_file.plot(
366
- var_name,
367
- time=0,
368
- eta=1,
369
- xi=1,
370
- **kwargs,
371
- include_boundary=include_boundary
372
- )
373
- roms_output_from_restart_file.plot(
374
- var_name,
375
- time=0,
376
- s=-1,
377
- eta=0,
378
- **kwargs,
379
- include_boundary=include_boundary
380
- )
381
- roms_output_from_restart_file.plot(
382
- var_name,
383
- time=0,
384
- s=-1,
385
- eta=1,
386
- **kwargs,
387
- include_boundary=include_boundary
388
- )
389
- roms_output_from_restart_file.plot(
390
- var_name,
391
- time=0,
392
- s=-1,
393
- xi=0,
394
- **kwargs,
395
- include_boundary=include_boundary
396
- )
397
- roms_output_from_restart_file.plot(
398
- var_name,
399
- time=0,
400
- s=-1,
401
- xi=1,
402
- **kwargs,
403
- include_boundary=include_boundary
404
- )
276
+ for include_boundary in [False, True]:
277
+ for depth_contours in [False, True]:
278
+
279
+ # 3D fields
280
+ for var_name in ["temp", "u", "v"]:
281
+ kwargs = {
282
+ "include_boundary": include_boundary,
283
+ "depth_contours": depth_contours,
284
+ }
285
+
286
+ roms_output.plot(var_name, time=1, s=-1, **kwargs)
287
+ roms_output.plot(var_name, time=1, depth=1000, **kwargs)
288
+
289
+ roms_output.plot(var_name, time=1, eta=1, **kwargs)
290
+ roms_output.plot(var_name, time=1, xi=1, **kwargs)
291
+
292
+ roms_output.plot(
293
+ var_name,
294
+ time=1,
295
+ eta=1,
296
+ xi=1,
297
+ **kwargs,
298
+ )
299
+
300
+ roms_output.plot(
301
+ var_name,
302
+ time=1,
303
+ s=-1,
304
+ eta=1,
305
+ **kwargs,
306
+ )
307
+ roms_output.plot(
308
+ var_name,
309
+ time=1,
310
+ depth=1000,
311
+ eta=1,
312
+ **kwargs,
313
+ )
314
+
315
+ roms_output.plot(
316
+ var_name,
317
+ time=1,
318
+ s=-1,
319
+ xi=1,
320
+ **kwargs,
321
+ )
322
+ roms_output.plot(
323
+ var_name,
324
+ time=1,
325
+ depth=1000,
326
+ xi=1,
327
+ **kwargs,
328
+ )
329
+
330
+ # 2D fields
331
+ roms_output.plot("zeta", time=1, **kwargs)
332
+ roms_output.plot("zeta", time=1, eta=1, **kwargs)
333
+ roms_output.plot("zeta", time=1, xi=1, **kwargs)
334
+
335
+
336
+ @pytest.mark.parametrize(
337
+ "roms_output_fixture",
338
+ [
339
+ "roms_output_from_restart_file",
340
+ "roms_output_from_restart_file_adjusted_for_zeta",
341
+ ],
342
+ )
343
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
344
+ def test_plot_on_lat_lon(roms_output_fixture, use_dask, request):
345
+ roms_output = request.getfixturevalue(roms_output_fixture)
405
346
 
406
347
  for include_boundary in [False, True]:
407
- roms_output_from_restart_file.plot(
408
- "zeta", time=0, **kwargs, include_boundary=include_boundary
409
- )
410
- roms_output_from_restart_file.plot(
411
- "zeta", time=0, eta=0, **kwargs, include_boundary=include_boundary
412
- )
413
- roms_output_from_restart_file.plot(
414
- "zeta", time=0, eta=1, **kwargs, include_boundary=include_boundary
415
- )
416
- roms_output_from_restart_file.plot(
417
- "zeta", time=0, xi=0, **kwargs, include_boundary=include_boundary
418
- )
419
- roms_output_from_restart_file.plot(
420
- "zeta", time=0, xi=1, **kwargs, include_boundary=include_boundary
421
- )
348
+ for depth_contours in [False, True]:
349
+
350
+ # 3D fields
351
+ for var_name in ["temp", "u", "v"]:
352
+ kwargs = {
353
+ "include_boundary": include_boundary,
354
+ "depth_contours": depth_contours,
355
+ }
356
+ roms_output.plot(
357
+ var_name,
358
+ time=1,
359
+ lat=9,
360
+ lon=-128,
361
+ **kwargs,
362
+ )
363
+ roms_output.plot(
364
+ var_name,
365
+ time=1,
366
+ lat=9,
367
+ **kwargs,
368
+ )
369
+ roms_output.plot(
370
+ var_name,
371
+ time=1,
372
+ lat=9,
373
+ s=-1,
374
+ **kwargs,
375
+ )
376
+ roms_output.plot(
377
+ var_name,
378
+ time=1,
379
+ lat=9,
380
+ depth=1000,
381
+ **kwargs,
382
+ )
383
+ roms_output.plot(
384
+ var_name,
385
+ time=1,
386
+ lon=-128,
387
+ **kwargs,
388
+ )
389
+ roms_output.plot(
390
+ var_name,
391
+ time=1,
392
+ lon=-128,
393
+ s=-1,
394
+ **kwargs,
395
+ )
396
+ roms_output.plot(
397
+ var_name,
398
+ time=1,
399
+ lon=-128,
400
+ depth=1000,
401
+ **kwargs,
402
+ )
403
+
404
+ # 2D fields
405
+ roms_output.plot("zeta", time=1, lat=9, **kwargs)
406
+ roms_output.plot("zeta", time=1, lon=-128, **kwargs)
422
407
 
423
408
 
424
409
  def test_plot_errors(roms_output_from_restart_file, use_dask):
410
+ """Test error conditions for the ROMSOutput.plot() method."""
411
+
412
+ # Invalid time index
425
413
  with pytest.raises(ValueError, match="Invalid time index"):
426
414
  roms_output_from_restart_file.plot("temp", time=10, s=-1)
427
- with pytest.raises(ValueError, match="Invalid input"):
428
- roms_output_from_restart_file.plot("temp", time=0)
415
+
416
+ with pytest.raises(
417
+ ValueError,
418
+ match="Conflicting input: You cannot specify both 's' and 'depth' at the same time.",
419
+ ):
420
+ roms_output_from_restart_file.plot("temp", time=0, s=-1, depth=10)
421
+
422
+ # Ambiguous input: Too many dimensions specified for 3D fields
429
423
  with pytest.raises(ValueError, match="Ambiguous input"):
430
- roms_output_from_restart_file.plot("temp", time=0, s=-1, eta=0, xi=0)
431
- with pytest.raises(ValueError, match="Conflicting input"):
432
- roms_output_from_restart_file.plot("zeta", time=0, eta=0, xi=0)
424
+ roms_output_from_restart_file.plot("temp", time=1, s=-1, eta=0, xi=0)
425
+
426
+ # Vertical dimension specified for 2D fields
427
+ with pytest.raises(
428
+ ValueError, match="Vertical dimension 's' should be None for 2D fields"
429
+ ):
430
+ roms_output_from_restart_file.plot("zeta", time=1, s=-1)
431
+ with pytest.raises(
432
+ ValueError, match="Vertical dimension 'depth' should be None for 2D fields"
433
+ ):
434
+ roms_output_from_restart_file.plot("zeta", time=1, depth=100)
435
+
436
+ # Conflicting input: Both eta and xi specified for 2D fields
437
+ with pytest.raises(
438
+ ValueError,
439
+ match="Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both.",
440
+ ):
441
+ roms_output_from_restart_file.plot("zeta", time=1, eta=0, xi=0)
442
+ # Conflicting input: Both lat and lon specified for 2D fields
443
+ with pytest.raises(
444
+ ValueError,
445
+ match="Conflicting input: For 2D fields, specify only one dimension, either 'lat' or 'lon', not both.",
446
+ ):
447
+ roms_output_from_restart_file.plot("zeta", time=1, lat=0, lon=0)
448
+
449
+ # Conflicting input: lat or lon provided with eta or xi
450
+ with pytest.raises(
451
+ ValueError,
452
+ match="Conflicting input: You cannot specify 'lat' or 'lon' simultaneously with 'eta' or 'xi'.",
453
+ ):
454
+ roms_output_from_restart_file.plot("temp", time=1, lat=10, lon=20, eta=5)
455
+
456
+ # Invalid eta index out of bounds
457
+ with pytest.raises(ValueError, match="Invalid eta index"):
458
+ roms_output_from_restart_file.plot("temp", time=1, eta=999)
459
+
460
+ # Invalid xi index out of bounds
461
+ with pytest.raises(ValueError, match="Invalid eta index"):
462
+ roms_output_from_restart_file.plot("temp", time=1, xi=999)
463
+
464
+ # Boundary exclusion error for eta
465
+ with pytest.raises(ValueError, match="Invalid eta index.*boundary.*excluded"):
466
+ roms_output_from_restart_file.plot(
467
+ "temp", time=1, eta=0, include_boundary=False
468
+ )
469
+
470
+ # Boundary exclusion error for xi
471
+ with pytest.raises(ValueError, match="Invalid xi index.*boundary.*excluded"):
472
+ roms_output_from_restart_file.plot("temp", time=1, xi=0, include_boundary=False)
473
+
474
+ # No dimension specified for 3D field
475
+ with pytest.raises(
476
+ ValueError,
477
+ match="Invalid input: For 3D fields, you must specify at least one of the dimensions",
478
+ ):
479
+ roms_output_from_restart_file.plot("temp", time=1)
480
+
481
+
482
+ def test_figure_gets_saved(roms_output_from_restart_file, tmp_path):
483
+
484
+ filename = tmp_path / "figure.png"
485
+ roms_output_from_restart_file.plot("temp", time=0, depth=1000, save_path=filename)
486
+
487
+ assert filename.exists()
488
+ filename.unlink()
@@ -1,9 +1,92 @@
1
1
  import pytest
2
2
  import numpy as np
3
3
  import xarray as xr
4
- from roms_tools.regrid import VerticalRegrid
4
+ from roms_tools.regrid import VerticalRegridToROMS
5
5
 
6
+ try:
7
+ import xesmf # type: ignore
8
+ except ImportError:
9
+ xesmf = None
6
10
 
11
+ from roms_tools.regrid import LateralRegridFromROMS
12
+
13
+ # Lateral regridding
14
+ @pytest.mark.skipif(xesmf is None, reason="xesmf required")
15
+ def test_lateral_regrid_with_curvilinear_grid():
16
+ """Test that LateralRegridFromROMS regrids data correctly from a curvilinear ROMS
17
+ grid."""
18
+
19
+ # Define ROMS curvilinear grid dimensions
20
+ eta_rho, xi_rho = 10, 20
21
+
22
+ # Create a mock ROMS grid with curvilinear coordinates
23
+ lat_rho = np.linspace(-10, 10, eta_rho).reshape(-1, 1) * np.ones((1, xi_rho))
24
+ lon_rho = np.linspace(120, 140, xi_rho).reshape(1, -1) * np.ones((eta_rho, 1))
25
+
26
+ ds_in = xr.Dataset(
27
+ {
28
+ "temp": (("eta_rho", "xi_rho"), np.random.rand(eta_rho, xi_rho)),
29
+ },
30
+ coords={
31
+ "lat": (("eta_rho", "xi_rho"), lat_rho),
32
+ "lon": (("eta_rho", "xi_rho"), lon_rho),
33
+ },
34
+ )
35
+
36
+ # Define target latitude and longitude coordinates
37
+ target_coords = {
38
+ "lat": np.linspace(-5, 5, 5),
39
+ "lon": np.linspace(125, 135, 10),
40
+ }
41
+
42
+ # Instantiate the regridder
43
+ regridder = LateralRegridFromROMS(ds_in, target_coords, method="bilinear")
44
+
45
+ # Apply the regridding to the input data
46
+ regridded_da = regridder.apply(ds_in["temp"])
47
+
48
+ # Assertions to verify that the output is as expected
49
+ assert isinstance(regridded_da, xr.DataArray)
50
+ assert regridded_da.shape == (5, 10)
51
+ assert np.allclose(regridded_da.coords["lat"], target_coords["lat"])
52
+ assert np.allclose(regridded_da.coords["lon"], target_coords["lon"])
53
+
54
+
55
+ @pytest.mark.skipif(xesmf is not None, reason="xesmf has to be missing")
56
+ def test_lateral_regrid_import_error():
57
+ """Test that LateralRegridFromROMS raises ImportError when xesmf is missing."""
58
+
59
+ # Define mock ROMS curvilinear grid dimensions
60
+ eta_rho, xi_rho = 10, 20
61
+
62
+ # Create a mock ROMS grid with curvilinear coordinates
63
+ lat_rho = np.linspace(-10, 10, eta_rho).reshape(-1, 1) * np.ones((1, xi_rho))
64
+ lon_rho = np.linspace(120, 140, xi_rho).reshape(1, -1) * np.ones((eta_rho, 1))
65
+
66
+ ds_in = xr.Dataset(
67
+ {
68
+ "temp": (("eta_rho", "xi_rho"), np.random.rand(eta_rho, xi_rho)),
69
+ },
70
+ coords={
71
+ "lat": (("eta_rho", "xi_rho"), lat_rho),
72
+ "lon": (("eta_rho", "xi_rho"), lon_rho),
73
+ },
74
+ )
75
+
76
+ # Define target latitude and longitude coordinates
77
+ target_coords = {
78
+ "lat": np.linspace(-5, 5, 5),
79
+ "lon": np.linspace(125, 135, 10),
80
+ }
81
+
82
+ # Check that ImportError is raised when xesmf is missing
83
+ with pytest.raises(
84
+ ImportError, match="xesmf is required for this regridding task.*"
85
+ ):
86
+ LateralRegridFromROMS(ds_in, target_coords, method="bilinear")
87
+
88
+
89
+ # Vertical regridding
7
90
  def vertical_regridder(depth_values, layer_depth_rho_values):
8
91
  class DataContainer:
9
92
  """Mock class for holding data and dimension names."""
@@ -21,7 +104,7 @@ def vertical_regridder(depth_values, layer_depth_rho_values):
21
104
  target_depth = xr.DataArray(data=layer_depth_rho_values, dims=["s_rho"])
22
105
  source_depth = xr.DataArray(data=depth_values, dims=["depth"])
23
106
 
24
- return VerticalRegrid(target_depth, source_depth)
107
+ return VerticalRegridToROMS(target_depth, source_depth)
25
108
 
26
109
 
27
110
  @pytest.mark.parametrize(