roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (45)
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
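
The hunks below appear to belong to roms_tools/tests/test_setup/test_datasets.py (item 28 above): the final hunk header `@@ -672,3 +800,299 @@` ends at new line 1098, which is what that file's previous 674 lines plus the listed +458 −34 change work out to.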
@@ -1,7 +1,8 @@
  import logging
  from collections import OrderedDict
- from datetime import datetime
+ from datetime import datetime, timedelta
  from pathlib import Path
+ from unittest import mock

  import numpy as np
  import pytest
@@ -9,13 +10,27 @@ import xarray as xr

  from roms_tools.download import download_test_data
  from roms_tools.setup.datasets import (
+     GLORYS_GLOBAL_GRID_PATH,
      CESMBGCDataset,
      Dataset,
+     ERA5ARCODataset,
      ERA5Correction,
      GLORYSDataset,
+     GLORYSDefaultDataset,
      RiverDataset,
      TPXODataset,
+     _concatenate_longitudes,
+     choose_subdomain,
+     get_glorys_bounds,
  )
+ from roms_tools.setup.surface_forcing import DEFAULT_ERA5_ARCO_PATH
+ from roms_tools.setup.utils import get_target_coords
+ from roms_tools.tests.test_setup.utils import download_regional_and_bigger
+
+ try:
+     import copernicusmarine # type: ignore
+ except ImportError:
+     copernicusmarine = None


  @pytest.fixture
@@ -25,9 +40,9 @@ def global_dataset():
      depth = np.linspace(0, 2000, 10)
      time = [
          np.datetime64("2022-01-01T00:00:00"),
-         np.datetime64("2022-02-01T00:00:00"),
-         np.datetime64("2022-03-01T00:00:00"),
-         np.datetime64("2022-04-01T00:00:00"),
+         np.datetime64("2022-01-02T00:00:00"),
+         np.datetime64("2022-01-03T00:00:00"),
+         np.datetime64("2022-01-04T00:00:00"),
      ]
      data = np.random.rand(4, 10, 180, 360)
      ds = xr.Dataset(
@@ -48,9 +63,9 @@ def global_dataset_with_noon_times():
      lat = np.linspace(-90, 90, 180)
      time = [
          np.datetime64("2022-01-01T12:00:00"),
-         np.datetime64("2022-02-01T12:00:00"),
-         np.datetime64("2022-03-01T12:00:00"),
-         np.datetime64("2022-04-01T12:00:00"),
+         np.datetime64("2022-01-02T12:00:00"),
+         np.datetime64("2022-01-03T12:00:00"),
+         np.datetime64("2022-01-04T12:00:00"),
      ]
      data = np.random.rand(4, 180, 360)
      ds = xr.Dataset(
@@ -71,12 +86,12 @@ def global_dataset_with_multiple_times_per_day():
      time = [
          np.datetime64("2022-01-01T00:00:00"),
          np.datetime64("2022-01-01T12:00:00"),
-         np.datetime64("2022-02-01T00:00:00"),
-         np.datetime64("2022-02-01T12:00:00"),
-         np.datetime64("2022-03-01T00:00:00"),
-         np.datetime64("2022-03-01T12:00:00"),
-         np.datetime64("2022-04-01T00:00:00"),
-         np.datetime64("2022-04-01T12:00:00"),
+         np.datetime64("2022-01-02T00:00:00"),
+         np.datetime64("2022-01-02T12:00:00"),
+         np.datetime64("2022-01-03T00:00:00"),
+         np.datetime64("2022-01-03T12:00:00"),
+         np.datetime64("2022-01-04T00:00:00"),
+         np.datetime64("2022-01-04T12:00:00"),
      ]
      data = np.random.rand(8, 180, 360)
      ds = xr.Dataset(
@@ -108,34 +123,37 @@ def non_global_dataset():
          (
              "global_dataset",
              [
-                 np.datetime64("2022-02-01T00:00:00"),
-                 np.datetime64("2022-03-01T00:00:00"),
+                 np.datetime64("2022-01-02T00:00:00"),
+                 np.datetime64("2022-01-03T00:00:00"),
+                 np.datetime64("2022-01-04T00:00:00"),
              ],
          ),
          (
              "global_dataset_with_noon_times",
              [
                  np.datetime64("2022-01-01T12:00:00"),
-                 np.datetime64("2022-02-01T12:00:00"),
-                 np.datetime64("2022-03-01T12:00:00"),
+                 np.datetime64("2022-01-02T12:00:00"),
+                 np.datetime64("2022-01-03T12:00:00"),
+                 np.datetime64("2022-01-04T12:00:00"),
              ],
          ),
          (
              "global_dataset_with_multiple_times_per_day",
              [
-                 np.datetime64("2022-02-01T00:00:00"),
-                 np.datetime64("2022-02-01T12:00:00"),
-                 np.datetime64("2022-03-01T00:00:00"),
+                 np.datetime64("2022-01-02T00:00:00"),
+                 np.datetime64("2022-01-02T12:00:00"),
+                 np.datetime64("2022-01-03T00:00:00"),
+                 np.datetime64("2022-01-03T12:00:00"),
+                 np.datetime64("2022-01-04T00:00:00"),
              ],
          ),
      ],
  )
  def test_select_times(data_fixture, expected_time_values, request, tmp_path, use_dask):
      """Test selecting times with different datasets."""
-     start_time = datetime(2022, 2, 1)
-     end_time = datetime(2022, 3, 1)
+     start_time = datetime(2022, 1, 2)
+     end_time = datetime(2022, 1, 4)

-     # Get the fixture dynamically based on the parameter
      dataset = request.getfixturevalue(data_fixture)

      filepath = tmp_path / "test.nc"
@@ -157,16 +175,15 @@ def test_select_times(data_fixture, expected_time_values, request, tmp_path, use
  @pytest.mark.parametrize(
      "data_fixture, expected_time_values",
      [
-         ("global_dataset", [np.datetime64("2022-02-01T00:00:00")]),
-         ("global_dataset_with_noon_times", [np.datetime64("2022-02-01T12:00:00")]),
+         ("global_dataset", [np.datetime64("2022-01-02T00:00:00")]),
+         ("global_dataset_with_noon_times", [np.datetime64("2022-01-02T12:00:00")]),
      ],
  )
  def test_select_times_valid_start_no_end_time(
      data_fixture, expected_time_values, request, tmp_path, use_dask
  ):
      """Test selecting times with only start_time specified."""
-     start_time = datetime(2022, 2, 1)
-
+     start_time = datetime(2022, 1, 2)
      # Get the fixture dynamically based on the parameter
      dataset = request.getfixturevalue(data_fixture)

@@ -180,9 +197,11 @@ def test_select_times_valid_start_no_end_time(
          var_names={"var": "var"},
          start_time=start_time,
          use_dask=use_dask,
+         allow_flex_time=True,
      )

      assert dataset.ds is not None
+     assert "time" in dataset.ds.dims
      assert len(dataset.ds.time) == len(expected_time_values)
      for expected_time in expected_time_values:
          assert expected_time in dataset.ds.time.values
@@ -208,7 +227,7 @@ def test_select_times_invalid_start_no_end_time(

      with pytest.raises(
          ValueError,
-         match="The dataset does not contain any time entries between the specified start_time",
+         match="No exact match found ",
      ):
          dataset = Dataset(
              filename=filepath,
@@ -229,11 +248,12 @@ def test_multiple_matching_times(
      dataset = Dataset(
          filename=filepath,
          var_names={"var": "var"},
-         start_time=datetime(2022, 1, 31, 22, 0),
+         start_time=datetime(2021, 12, 31, 22, 0),
          use_dask=use_dask,
+         allow_flex_time=True,
      )

-     assert dataset.ds["time"].values == np.datetime64(datetime(2022, 2, 1, 0, 0))
+     assert dataset.ds["time"].values == np.datetime64(datetime(2022, 1, 1, 0, 0))


  def test_warnings_times(global_dataset, tmp_path, caplog, use_dask):
@@ -253,7 +273,7 @@ def test_warnings_times(global_dataset, tmp_path, caplog, use_dask):
              use_dask=use_dask,
          )
      # Verify the warning message in the log
-     assert "No records found at or before the start_time." in caplog.text
+     assert "No records found at or before the start_time" in caplog.text

      with caplog.at_level(logging.WARNING):
          start_time = datetime(2024, 1, 1)
@@ -267,7 +287,7 @@ def test_warnings_times(global_dataset, tmp_path, caplog, use_dask):
              use_dask=use_dask,
          )
      # Verify the warning message in the log
-     assert "No records found at or after the end_time." in caplog.text
+     assert "No records found at or after the end_time" in caplog.text


  def test_from_ds(global_dataset, global_dataset_with_noon_times, use_dask, tmp_path):
@@ -288,6 +308,7 @@ def test_from_ds(global_dataset, global_dataset_with_noon_times, use_dask, tmp_p
          },
          start_time=start_time,
          use_dask=use_dask,
+         allow_flex_time=True,
      )

      new_dataset = Dataset.from_ds(dataset, global_dataset_with_noon_times)
@@ -329,6 +350,7 @@ def test_reverse_latitude_reverse_depth_choose_subdomain(
          },
          start_time=start_time,
          use_dask=use_dask,
+         allow_flex_time=True,
      )

      assert np.all(np.diff(dataset.ds["latitude"]) > 0)
@@ -427,16 +449,122 @@ def test_check_dataset(global_dataset, tmp_path, use_dask):
          )


- def test_era5_correction_choose_subdomain(use_dask):
+ def test_era5_correction_match_subdomain(use_dask):
      data = ERA5Correction(use_dask=use_dask)
      lats = data.ds.latitude[10:20]
      lons = data.ds.longitude[10:20]
      target_coords = {"lat": lats, "lon": lons}
-     data.choose_subdomain(target_coords, straddle=False)
+     data.match_subdomain(target_coords)
      assert (data.ds["latitude"] == lats).all()
      assert (data.ds["longitude"] == lons).all()


+ @pytest.mark.use_gcsfs
+ def test_default_era5_dataset_loading_without_dask() -> None:
+     """Verify that loading the default ERA5 dataset fails if use_dask is not True."""
+     start_time = datetime(2020, 2, 1)
+     end_time = datetime(2020, 2, 2)
+
+     with pytest.raises(ValueError):
+         _ = ERA5ARCODataset(
+             filename=DEFAULT_ERA5_ARCO_PATH,
+             start_time=start_time,
+             end_time=end_time,
+             use_dask=False,
+         )
+
+
+ @pytest.mark.skip("Temporary skip until memory consumption issue is addressed. # TODO")
+ @pytest.mark.stream
+ @pytest.mark.use_dask
+ @pytest.mark.use_gcsfs
+ def test_default_era5_dataset_loading() -> None:
+     """Verify the default ERA5 dataset is loaded correctly."""
+     start_time = datetime(2020, 2, 1)
+     end_time = datetime(2020, 2, 2)
+
+     ds = ERA5ARCODataset(
+         filename=DEFAULT_ERA5_ARCO_PATH,
+         start_time=start_time,
+         end_time=end_time,
+         use_dask=True,
+     )
+
+     expected_vars = {"uwnd", "vwnd", "swrad", "lwrad", "Tair", "rain"}
+     assert set(ds.var_names).issuperset(expected_vars)
+
+
+ @pytest.mark.use_copernicus
+ def test_default_glorys_dataset_loading_dask_not_installed() -> None:
+     """Verify that loading the default GLORYS dataset fails if dask is not available."""
+     start_time = datetime(2020, 2, 1)
+     end_time = datetime(2020, 2, 2)
+
+     with (
+         pytest.raises(RuntimeError),
+         mock.patch("roms_tools.utils.has_dask", return_value=False),
+     ):
+         _ = GLORYSDefaultDataset(
+             filename=GLORYSDefaultDataset.dataset_name,
+             start_time=start_time,
+             end_time=end_time,
+             use_dask=True,
+         )
+
+
+ @pytest.mark.stream
+ @pytest.mark.use_copernicus
+ @pytest.mark.use_dask
+ def test_default_glorys_dataset_loading() -> None:
+     """Verify the default GLORYS dataset is loaded correctly."""
+     start_time = datetime(2012, 1, 1)
+
+     for end_time in [start_time, start_time + timedelta(days=0.5)]:
+         data = GLORYSDefaultDataset(
+             filename=GLORYSDefaultDataset.dataset_name,
+             start_time=start_time,
+             end_time=end_time,
+             use_dask=True,
+         )
+
+         expected_vars = {"temp", "salt", "u", "v", "zeta"}
+         assert set(data.var_names).issuperset(expected_vars)
+
+         expected_vars = {"thetao", "so", "uo", "vo", "zos"}
+         assert "time" in data.ds.dims
+         assert set(data.ds.data_vars).issuperset(expected_vars)
+
+
+ @pytest.mark.parametrize(
+     "fname,start_time",
+     [
+         (download_test_data("GLORYS_NA_2012.nc"), datetime(2012, 1, 1, 12)),
+         (download_test_data("GLORYS_NA_20121231.nc"), datetime(2012, 12, 31, 12)),
+         (download_test_data("GLORYS_coarse_test_data.nc"), datetime(2021, 6, 29)),
+     ],
+ )
+ @pytest.mark.parametrize("allow_flex_time", [True, False])
+ def test_non_default_glorys_dataset_loading(
+     fname, start_time, allow_flex_time, use_dask
+ ) -> None:
+     """Verify the default GLORYS dataset is loaded correctly."""
+     for end_time in [None, start_time, start_time]:
+         data = GLORYSDataset(
+             filename=fname,
+             start_time=start_time,
+             end_time=end_time,
+             use_dask=use_dask,
+             allow_flex_time=allow_flex_time,
+         )
+
+         expected_vars = {"temp", "salt", "u", "v", "zeta"}
+         assert set(data.var_names).issuperset(expected_vars)
+
+         expected_vars = {"thetao", "so", "uo", "vo", "zos"}
+         assert "time" in data.ds.dims
+         assert set(data.ds.data_vars).issuperset(expected_vars)
+
+
  def test_data_concatenation(use_dask):
      fname = download_test_data("GLORYS_NA_2012.nc")
      data = GLORYSDataset(
@@ -672,3 +800,299 @@ class TestRiverDataset:
          assert "Amazon_2" in names
          assert "Nile" in names
          assert len(set(names)) == len(names) # all names must be unique
+
+
+ # test _concatenate_longitudes
+
+
+ @pytest.fixture
+ def sample_ds(use_dask):
+     lon = xr.DataArray(np.array([0, 90, 180]), dims="lon", name="lon")
+     lat = xr.DataArray(np.array([-30, 0, 30]), dims="lat", name="lat")
+
+     var_with_lon = xr.DataArray(
+         np.arange(9).reshape(3, 3),
+         dims=("lat", "lon"),
+         coords={"lat": lat, "lon": lon},
+         name="var_with_lon",
+     )
+
+     var_no_lon = xr.DataArray(
+         np.array([1, 2, 3]),
+         dims="lat",
+         coords={"lat": lat},
+         name="var_no_lon",
+     )
+
+     ds = xr.Dataset({"var_with_lon": var_with_lon, "var_no_lon": var_no_lon})
+
+     if use_dask:
+         ds = ds.chunk({"lat": -1, "lon": -1})
+
+     return ds
+
+
+ @pytest.mark.parametrize(
+     "end,expected_lons",
+     [
+         ("lower", [-360, -270, -180, 0, 90, 180]),
+         ("upper", [0, 90, 180, 360, 450, 540]),
+         ("both", [-360, -270, -180, 0, 90, 180, 360, 450, 540]),
+     ],
+ )
+ def test_concatenate_longitudes(sample_ds, end, expected_lons, use_dask):
+     dim_names = {"longitude": "lon"}
+
+     ds_concat = _concatenate_longitudes(
+         sample_ds, dim_names, end=end, use_dask=use_dask
+     )
+
+     # longitude should be extended as expected
+     np.testing.assert_array_equal(ds_concat.lon.values, expected_lons)
+
+     # variable with longitude should be extended in size
+     assert ds_concat.var_with_lon.shape[-1] == len(expected_lons)
+
+     # variable without longitude should remain untouched
+     np.testing.assert_array_equal(
+         ds_concat.var_no_lon.values,
+         sample_ds.var_no_lon.values,
+     )
+
+     if use_dask:
+         import dask
+
+         # Ensure dask array backing the data
+         assert isinstance(ds_concat.var_with_lon.data, dask.array.Array)
+         # Longitude dimension should be chunked (-1 → one chunk spanning the whole dim)
+         assert ds_concat.var_with_lon.chunks[-1] == (len(expected_lons),)
+     else:
+         # With use_dask=False, data should be a numpy array
+         assert isinstance(ds_concat.var_with_lon.data, np.ndarray)
+
+
+ def test_invalid_end_raises(sample_ds):
+     dim_names = {"longitude": "lon"}
+     with pytest.raises(ValueError):
+         _concatenate_longitudes(sample_ds, dim_names, end="invalid")
+
+
+ # test choose_subdomain
+
+
+ def test_choose_subdomain_basic(global_dataset, use_dask):
+     target_coords = {
+         "lat": xr.DataArray([0, 10]),
+         "lon": xr.DataArray([30, 40]),
+         "straddle": False,
+     }
+     out = choose_subdomain(
+         global_dataset,
+         dim_names={"latitude": "latitude", "longitude": "longitude"},
+         resolution=1.0,
+         is_global=True,
+         target_coords=target_coords,
+         buffer_points=2,
+         use_dask=use_dask,
+     )
+     assert out.latitude.min() <= 0
+     assert out.latitude.max() >= 10
+     assert out.longitude.min() <= 30
+     assert out.longitude.max() >= 40
+
+
+ def test_choose_subdomain_raises_on_empty_lon(non_global_dataset, use_dask):
+     target_coords = {
+         "lat": xr.DataArray([-10, 10]),
+         "lon": xr.DataArray([210, 215]), # outside 0-180
+         "straddle": False,
+     }
+     with pytest.raises(ValueError, match="longitude range"):
+         choose_subdomain(
+             non_global_dataset,
+             dim_names={"latitude": "latitude", "longitude": "longitude"},
+             resolution=1.0,
+             is_global=False,
+             target_coords=target_coords,
+             use_dask=use_dask,
+         )
+
+
+ def test_choose_subdomain_raises_on_empty_lat(global_dataset, use_dask):
+     target_coords = {
+         "lat": xr.DataArray([1000, 1010]), # outside dataset range
+         "lon": xr.DataArray([30, 40]),
+         "straddle": False,
+     }
+     with pytest.raises(ValueError, match="latitude range"):
+         choose_subdomain(
+             global_dataset,
+             dim_names={"latitude": "latitude", "longitude": "longitude"},
+             resolution=1.0,
+             is_global=True,
+             target_coords=target_coords,
+             use_dask=use_dask,
+         )
+
+
+ def test_choose_subdomain_straddle(global_dataset, use_dask):
+     target_coords = {
+         "lat": xr.DataArray([-10, 10]),
+         "lon": xr.DataArray([-170, 170]), # cross the dateline
+         "straddle": True,
+     }
+     out = choose_subdomain(
+         global_dataset,
+         dim_names={"latitude": "latitude", "longitude": "longitude"},
+         resolution=1.0,
+         is_global=True,
+         target_coords=target_coords,
+         buffer_points=5,
+         use_dask=use_dask,
+     )
+     # Ensure output includes both sides of the dateline, mapped into -180 - 180
+     assert (out.longitude.min() < 0) and (out.longitude.max() > 0)
+
+
+ def test_choose_subdomain_wraps_negative_lon(global_dataset, use_dask):
+     target_coords = {
+         "lat": xr.DataArray([0, 20]),
+         "lon": xr.DataArray([-20, -10]), # negative longitudes
+         "straddle": False,
+     }
+     out = choose_subdomain(
+         global_dataset,
+         dim_names={"latitude": "latitude", "longitude": "longitude"},
+         resolution=1.0,
+         is_global=True,
+         target_coords=target_coords,
+         buffer_points=2,
+         use_dask=use_dask,
+     )
+     # Output longitudes should be shifted into [0, 360]
+     assert (out.longitude >= 0).all()
+
+
+ def test_choose_subdomain_respects_buffer(global_dataset, use_dask):
+     target_coords = {
+         "lat": xr.DataArray([0, 0]),
+         "lon": xr.DataArray([50, 50]),
+         "straddle": False,
+     }
+     out = choose_subdomain(
+         global_dataset,
+         dim_names={"latitude": "latitude", "longitude": "longitude"},
+         resolution=1.0,
+         is_global=True,
+         target_coords=target_coords,
+         buffer_points=10,
+         use_dask=use_dask,
+     )
+     # Buffer should extend at least 10 degrees beyond target lon
+     assert out.longitude.min() <= 40
+     assert out.longitude.max() >= 60
+
+
+ # test get_glorys_bounds
+
+
+ @pytest.fixture
+ def glorys_grid_0_360(tmp_path):
+     lats = np.linspace(-90, 90, 181)
+     lons = np.linspace(0, 360, 361)
+     ds = xr.Dataset(coords={"latitude": lats, "longitude": lons})
+     path = tmp_path / "GLORYS_0_360.nc"
+     ds.to_netcdf(path)
+     return path
+
+
+ @pytest.fixture
+ def glorys_grid_neg180_180(tmp_path):
+     lats = np.linspace(-90, 90, 181)
+     lons = np.linspace(-180, 180, 361)
+     ds = xr.Dataset(coords={"latitude": lats, "longitude": lons})
+     path = tmp_path / "GLORYS_neg180_180.nc"
+     ds.to_netcdf(path)
+     return path
+
+
+ @pytest.fixture
+ def glorys_grid_real():
+     return GLORYS_GLOBAL_GRID_PATH
+
+
+ @pytest.mark.parametrize(
+     "grid_fixture",
+     [
+         "grid",
+         "grid_that_straddles_dateline",
+         "grid_that_straddles_180_degree_meridian",
+         "small_grid",
+         "tiny_grid",
+     ],
+ )
+ @pytest.mark.parametrize(
+     "glorys_grid_fixture",
+     ["glorys_grid_0_360", "glorys_grid_neg180_180", "glorys_grid_real"],
+ )
+ def test_get_glorys_bounds(tmp_path, grid_fixture, glorys_grid_fixture, request):
+     grid = request.getfixturevalue(grid_fixture)
+     glorys_grid_path = request.getfixturevalue(glorys_grid_fixture)
+
+     bounds = get_glorys_bounds(grid=grid, glorys_grid_path=glorys_grid_path)
+     assert set(bounds) == {
+         "minimum_latitude",
+         "maximum_latitude",
+         "minimum_longitude",
+         "maximum_longitude",
+     }
+
+
+ @pytest.mark.use_copernicus
+ @pytest.mark.skipif(copernicusmarine is None, reason="copernicusmarine required")
+ @pytest.mark.parametrize(
+     "grid_fixture",
+     [
+         "tiny_grid_that_straddles_dateline",
+         "tiny_grid_that_straddles_180_degree_meridian",
+         "tiny_rotated_grid",
+     ],
+ )
+ def test_invariance_to_get_glorys_bounds(tmp_path, grid_fixture, use_dask, request):
+     start_time = datetime(2012, 1, 1)
+     grid = request.getfixturevalue(grid_fixture)
+     target_coords = get_target_coords(grid)
+
+     regional_file, bigger_regional_file = download_regional_and_bigger(
+         tmp_path, grid, start_time, variables=["thetao", "uo", "zos"]
+     )
+
+     # create datasets from regional and bigger regional data
+     regional_data = GLORYSDataset(
+         var_names={"temp": "thetao", "u": "uo", "zeta": "zos"},
+         filename=regional_file,
+         start_time=start_time,
+         climatology=False,
+         allow_flex_time=False,
+         use_dask=use_dask,
+     )
+     bigger_regional_data = GLORYSDataset(
+         var_names={"temp": "thetao", "u": "uo", "zeta": "zos"},
+         filename=bigger_regional_file,
+         start_time=start_time,
+         climatology=False,
+         allow_flex_time=False,
+         use_dask=use_dask,
+     )
+
+     # subset both datasets and check they are the same
+     regional_data.choose_subdomain(target_coords)
+     bigger_regional_data.choose_subdomain(target_coords)
+
+     # Use assert_allclose instead of equals: necessary for grids that straddle the 180° meridian.
+     # Copernicus returns data on [-180, 180] by default, but if you request a range
+     # like [170, 190], it remaps longitudes. That remapping introduces tiny floating
+     # point differences in the longitude coordinate, so strict equality will fail
+     # even though the data are effectively identical.
+     # For grids that do not straddle the 180° meridian, strict equality still holds.
+     xr.testing.assert_allclose(bigger_regional_data.ds, regional_data.ds)
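
The parametrized expectations in test_concatenate_longitudes (e.g. end="lower" turning [0, 90, 180] into [-360, -270, -180, 0, 90, 180]) imply that the helper tiles the longitude axis with copies shifted by ±360°. The following is a minimal, hypothetical sketch of that behaviour in plain xarray; tile_longitudes is an illustrative stand-in, not the package's _concatenate_longitudes implementation.

```python
import numpy as np
import xarray as xr


def tile_longitudes(ds: xr.Dataset, lon_name: str = "lon", end: str = "both") -> xr.Dataset:
    """Tile a global dataset along longitude by -360/+360 degrees (illustrative sketch only)."""
    if end not in ("lower", "upper", "both"):
        raise ValueError(f"end must be 'lower', 'upper', or 'both', got {end!r}")
    parts = []
    if end in ("lower", "both"):
        # copy shifted westward: lon - 360
        parts.append(ds.assign_coords({lon_name: ds[lon_name] - 360}))
    parts.append(ds)
    if end in ("upper", "both"):
        # copy shifted eastward: lon + 360
        parts.append(ds.assign_coords({lon_name: ds[lon_name] + 360}))
    # data_vars="minimal" leaves variables that lack a longitude dimension untouched
    return xr.concat(parts, dim=lon_name, data_vars="minimal")


lon = xr.DataArray([0, 90, 180], dims="lon", name="lon")
ds = xr.Dataset({"var": ("lon", np.arange(3))}, coords={"lon": lon})
print(tile_longitudes(ds, end="lower").lon.values)  # -> [-360 -270 -180 0 90 180]
```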