roms-tools 3.1.2-py3-none-any.whl → 3.2.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (41)
  1. roms_tools/__init__.py +3 -0
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +75 -21
  8. roms_tools/setup/boundary_forcing.py +44 -19
  9. roms_tools/setup/cdr_forcing.py +122 -8
  10. roms_tools/setup/cdr_release.py +161 -8
  11. roms_tools/setup/datasets.py +626 -340
  12. roms_tools/setup/grid.py +138 -137
  13. roms_tools/setup/initial_conditions.py +113 -48
  14. roms_tools/setup/mask.py +63 -7
  15. roms_tools/setup/nesting.py +67 -42
  16. roms_tools/setup/river_forcing.py +45 -19
  17. roms_tools/setup/surface_forcing.py +4 -6
  18. roms_tools/setup/tides.py +1 -2
  19. roms_tools/setup/topography.py +4 -4
  20. roms_tools/setup/utils.py +134 -22
  21. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  22. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  23. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  24. roms_tools/tests/test_setup/test_boundary_forcing.py +54 -52
  25. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  26. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  27. roms_tools/tests/test_setup/test_datasets.py +392 -44
  28. roms_tools/tests/test_setup/test_grid.py +222 -115
  29. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  30. roms_tools/tests/test_setup/test_surface_forcing.py +2 -1
  31. roms_tools/tests/test_setup/test_utils.py +91 -1
  32. roms_tools/tests/test_setup/utils.py +71 -0
  33. roms_tools/tests/test_tiling/test_join.py +241 -0
  34. roms_tools/tests/test_utils.py +139 -17
  35. roms_tools/tiling/join.py +189 -0
  36. roms_tools/utils.py +131 -99
  37. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/METADATA +12 -2
  38. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/RECORD +41 -33
  39. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  40. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  41. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 import logging
 from collections import OrderedDict
-from datetime import datetime
+from datetime import datetime, timedelta
 from pathlib import Path
 from unittest import mock
 
@@ -10,6 +10,7 @@ import xarray as xr
 
 from roms_tools.download import download_test_data
 from roms_tools.setup.datasets import (
+    GLORYS_GLOBAL_GRID_PATH,
     CESMBGCDataset,
     Dataset,
     ERA5ARCODataset,
@@ -18,8 +19,18 @@ from roms_tools.setup.datasets import (
     GLORYSDefaultDataset,
     RiverDataset,
     TPXODataset,
+    _concatenate_longitudes,
+    choose_subdomain,
+    get_glorys_bounds,
 )
 from roms_tools.setup.surface_forcing import DEFAULT_ERA5_ARCO_PATH
+from roms_tools.setup.utils import get_target_coords
+from roms_tools.tests.test_setup.utils import download_regional_and_bigger
+
+try:
+    import copernicusmarine  # type: ignore
+except ImportError:
+    copernicusmarine = None
 
 
 @pytest.fixture
@@ -29,9 +40,9 @@ def global_dataset():
     depth = np.linspace(0, 2000, 10)
     time = [
         np.datetime64("2022-01-01T00:00:00"),
-        np.datetime64("2022-02-01T00:00:00"),
-        np.datetime64("2022-03-01T00:00:00"),
-        np.datetime64("2022-04-01T00:00:00"),
+        np.datetime64("2022-01-02T00:00:00"),
+        np.datetime64("2022-01-03T00:00:00"),
+        np.datetime64("2022-01-04T00:00:00"),
     ]
     data = np.random.rand(4, 10, 180, 360)
     ds = xr.Dataset(
@@ -52,9 +63,9 @@ def global_dataset_with_noon_times():
     lat = np.linspace(-90, 90, 180)
     time = [
         np.datetime64("2022-01-01T12:00:00"),
-        np.datetime64("2022-02-01T12:00:00"),
-        np.datetime64("2022-03-01T12:00:00"),
-        np.datetime64("2022-04-01T12:00:00"),
+        np.datetime64("2022-01-02T12:00:00"),
+        np.datetime64("2022-01-03T12:00:00"),
+        np.datetime64("2022-01-04T12:00:00"),
     ]
     data = np.random.rand(4, 180, 360)
     ds = xr.Dataset(
@@ -75,12 +86,12 @@ def global_dataset_with_multiple_times_per_day():
     time = [
         np.datetime64("2022-01-01T00:00:00"),
         np.datetime64("2022-01-01T12:00:00"),
-        np.datetime64("2022-02-01T00:00:00"),
-        np.datetime64("2022-02-01T12:00:00"),
-        np.datetime64("2022-03-01T00:00:00"),
-        np.datetime64("2022-03-01T12:00:00"),
-        np.datetime64("2022-04-01T00:00:00"),
-        np.datetime64("2022-04-01T12:00:00"),
+        np.datetime64("2022-01-02T00:00:00"),
+        np.datetime64("2022-01-02T12:00:00"),
+        np.datetime64("2022-01-03T00:00:00"),
+        np.datetime64("2022-01-03T12:00:00"),
+        np.datetime64("2022-01-04T00:00:00"),
+        np.datetime64("2022-01-04T12:00:00"),
     ]
     data = np.random.rand(8, 180, 360)
     ds = xr.Dataset(
@@ -112,34 +123,37 @@ def non_global_dataset():
         (
             "global_dataset",
             [
-                np.datetime64("2022-02-01T00:00:00"),
-                np.datetime64("2022-03-01T00:00:00"),
+                np.datetime64("2022-01-02T00:00:00"),
+                np.datetime64("2022-01-03T00:00:00"),
+                np.datetime64("2022-01-04T00:00:00"),
             ],
         ),
         (
             "global_dataset_with_noon_times",
             [
                 np.datetime64("2022-01-01T12:00:00"),
-                np.datetime64("2022-02-01T12:00:00"),
-                np.datetime64("2022-03-01T12:00:00"),
+                np.datetime64("2022-01-02T12:00:00"),
+                np.datetime64("2022-01-03T12:00:00"),
+                np.datetime64("2022-01-04T12:00:00"),
             ],
         ),
         (
             "global_dataset_with_multiple_times_per_day",
             [
-                np.datetime64("2022-02-01T00:00:00"),
-                np.datetime64("2022-02-01T12:00:00"),
-                np.datetime64("2022-03-01T00:00:00"),
+                np.datetime64("2022-01-02T00:00:00"),
+                np.datetime64("2022-01-02T12:00:00"),
+                np.datetime64("2022-01-03T00:00:00"),
+                np.datetime64("2022-01-03T12:00:00"),
+                np.datetime64("2022-01-04T00:00:00"),
             ],
         ),
     ],
 )
 def test_select_times(data_fixture, expected_time_values, request, tmp_path, use_dask):
     """Test selecting times with different datasets."""
-    start_time = datetime(2022, 2, 1)
-    end_time = datetime(2022, 3, 1)
+    start_time = datetime(2022, 1, 2)
+    end_time = datetime(2022, 1, 4)
 
-    # Get the fixture dynamically based on the parameter
     dataset = request.getfixturevalue(data_fixture)
 
     filepath = tmp_path / "test.nc"
@@ -161,16 +175,15 @@ def test_select_times(data_fixture, expected_time_values, request, tmp_path, use
 @pytest.mark.parametrize(
     "data_fixture, expected_time_values",
     [
-        ("global_dataset", [np.datetime64("2022-02-01T00:00:00")]),
-        ("global_dataset_with_noon_times", [np.datetime64("2022-02-01T12:00:00")]),
+        ("global_dataset", [np.datetime64("2022-01-02T00:00:00")]),
+        ("global_dataset_with_noon_times", [np.datetime64("2022-01-02T12:00:00")]),
     ],
 )
 def test_select_times_valid_start_no_end_time(
     data_fixture, expected_time_values, request, tmp_path, use_dask
 ):
     """Test selecting times with only start_time specified."""
-    start_time = datetime(2022, 2, 1)
-
+    start_time = datetime(2022, 1, 2)
     # Get the fixture dynamically based on the parameter
     dataset = request.getfixturevalue(data_fixture)
 
@@ -184,9 +197,11 @@ def test_select_times_valid_start_no_end_time(
         var_names={"var": "var"},
         start_time=start_time,
         use_dask=use_dask,
+        allow_flex_time=True,
     )
 
     assert dataset.ds is not None
+    assert "time" in dataset.ds.dims
     assert len(dataset.ds.time) == len(expected_time_values)
     for expected_time in expected_time_values:
         assert expected_time in dataset.ds.time.values
@@ -212,7 +227,7 @@ def test_select_times_invalid_start_no_end_time(
 
     with pytest.raises(
         ValueError,
-        match="The dataset does not contain any time entries between the specified start_time",
+        match="No exact match found ",
     ):
         dataset = Dataset(
             filename=filepath,
@@ -233,11 +248,12 @@ def test_multiple_matching_times(
     dataset = Dataset(
         filename=filepath,
         var_names={"var": "var"},
-        start_time=datetime(2022, 1, 31, 22, 0),
+        start_time=datetime(2021, 12, 31, 22, 0),
         use_dask=use_dask,
+        allow_flex_time=True,
     )
 
-    assert dataset.ds["time"].values == np.datetime64(datetime(2022, 2, 1, 0, 0))
+    assert dataset.ds["time"].values == np.datetime64(datetime(2022, 1, 1, 0, 0))
 
 
 def test_warnings_times(global_dataset, tmp_path, caplog, use_dask):
@@ -257,7 +273,7 @@ def test_warnings_times(global_dataset, tmp_path, caplog, use_dask):
             use_dask=use_dask,
         )
     # Verify the warning message in the log
-    assert "No records found at or before the start_time." in caplog.text
+    assert "No records found at or before the start_time" in caplog.text
 
     with caplog.at_level(logging.WARNING):
         start_time = datetime(2024, 1, 1)
@@ -271,7 +287,7 @@ def test_warnings_times(global_dataset, tmp_path, caplog, use_dask):
             use_dask=use_dask,
         )
     # Verify the warning message in the log
-    assert "No records found at or after the end_time." in caplog.text
+    assert "No records found at or after the end_time" in caplog.text
 
 
 def test_from_ds(global_dataset, global_dataset_with_noon_times, use_dask, tmp_path):
@@ -292,6 +308,7 @@ def test_from_ds(global_dataset, global_dataset_with_noon_times, use_dask, tmp_p
         },
         start_time=start_time,
         use_dask=use_dask,
+        allow_flex_time=True,
     )
 
     new_dataset = Dataset.from_ds(dataset, global_dataset_with_noon_times)
@@ -333,6 +350,7 @@ def test_reverse_latitude_reverse_depth_choose_subdomain(
         },
         start_time=start_time,
         use_dask=use_dask,
+        allow_flex_time=True,
     )
 
     assert np.all(np.diff(dataset.ds["latitude"]) > 0)
@@ -431,12 +449,12 @@ def test_check_dataset(global_dataset, tmp_path, use_dask):
     )
 
 
-def test_era5_correction_choose_subdomain(use_dask):
+def test_era5_correction_match_subdomain(use_dask):
     data = ERA5Correction(use_dask=use_dask)
     lats = data.ds.latitude[10:20]
     lons = data.ds.longitude[10:20]
     target_coords = {"lat": lats, "lon": lons}
-    data.choose_subdomain(target_coords, straddle=False)
+    data.match_subdomain(target_coords)
     assert (data.ds["latitude"] == lats).all()
     assert (data.ds["longitude"] == lons).all()
 
@@ -484,7 +502,7 @@ def test_default_glorys_dataset_loading_dask_not_installed() -> None:
 
     with (
         pytest.raises(RuntimeError),
-        mock.patch("roms_tools.utils._has_dask", return_value=False),
+        mock.patch("roms_tools.utils.has_dask", return_value=False),
     ):
         _ = GLORYSDefaultDataset(
             filename=GLORYSDefaultDataset.dataset_name,
@@ -500,17 +518,51 @@ def test_default_glorys_dataset_loading_dask_not_installed() -> None:
 def test_default_glorys_dataset_loading() -> None:
     """Verify the default GLORYS dataset is loaded correctly."""
     start_time = datetime(2012, 1, 1)
-    end_time = datetime(2013, 1, 1)
 
-    ds = GLORYSDefaultDataset(
-        filename=GLORYSDefaultDataset.dataset_name,
-        start_time=start_time,
-        end_time=end_time,
-        use_dask=True,
-    )
+    for end_time in [start_time, start_time + timedelta(days=0.5)]:
+        data = GLORYSDefaultDataset(
+            filename=GLORYSDefaultDataset.dataset_name,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=True,
+        )
 
-    expected_vars = {"temp", "salt", "u", "v", "zeta"}
-    assert set(ds.var_names).issuperset(expected_vars)
+        expected_vars = {"temp", "salt", "u", "v", "zeta"}
+        assert set(data.var_names).issuperset(expected_vars)
+
+        expected_vars = {"thetao", "so", "uo", "vo", "zos"}
+        assert "time" in data.ds.dims
+        assert set(data.ds.data_vars).issuperset(expected_vars)
+
+
+@pytest.mark.parametrize(
+    "fname,start_time",
+    [
+        (download_test_data("GLORYS_NA_2012.nc"), datetime(2012, 1, 1, 12)),
+        (download_test_data("GLORYS_NA_20121231.nc"), datetime(2012, 12, 31, 12)),
+        (download_test_data("GLORYS_coarse_test_data.nc"), datetime(2021, 6, 29)),
+    ],
+)
+@pytest.mark.parametrize("allow_flex_time", [True, False])
+def test_non_default_glorys_dataset_loading(
+    fname, start_time, allow_flex_time, use_dask
+) -> None:
+    """Verify the default GLORYS dataset is loaded correctly."""
+    for end_time in [None, start_time, start_time]:
+        data = GLORYSDataset(
+            filename=fname,
+            start_time=start_time,
+            end_time=end_time,
+            use_dask=use_dask,
+            allow_flex_time=allow_flex_time,
+        )
+
+        expected_vars = {"temp", "salt", "u", "v", "zeta"}
+        assert set(data.var_names).issuperset(expected_vars)
+
+        expected_vars = {"thetao", "so", "uo", "vo", "zos"}
+        assert "time" in data.ds.dims
+        assert set(data.ds.data_vars).issuperset(expected_vars)
 
 
 def test_data_concatenation(use_dask):
@@ -748,3 +800,299 @@ class TestRiverDataset:
         assert "Amazon_2" in names
         assert "Nile" in names
         assert len(set(names)) == len(names)  # all names must be unique
+
+
+# test _concatenate_longitudes
+
+
+@pytest.fixture
+def sample_ds(use_dask):
+    lon = xr.DataArray(np.array([0, 90, 180]), dims="lon", name="lon")
+    lat = xr.DataArray(np.array([-30, 0, 30]), dims="lat", name="lat")
+
+    var_with_lon = xr.DataArray(
+        np.arange(9).reshape(3, 3),
+        dims=("lat", "lon"),
+        coords={"lat": lat, "lon": lon},
+        name="var_with_lon",
+    )
+
+    var_no_lon = xr.DataArray(
+        np.array([1, 2, 3]),
+        dims="lat",
+        coords={"lat": lat},
+        name="var_no_lon",
+    )
+
+    ds = xr.Dataset({"var_with_lon": var_with_lon, "var_no_lon": var_no_lon})
+
+    if use_dask:
+        ds = ds.chunk({"lat": -1, "lon": -1})
+
+    return ds
+
+
+@pytest.mark.parametrize(
+    "end,expected_lons",
+    [
+        ("lower", [-360, -270, -180, 0, 90, 180]),
+        ("upper", [0, 90, 180, 360, 450, 540]),
+        ("both", [-360, -270, -180, 0, 90, 180, 360, 450, 540]),
+    ],
+)
+def test_concatenate_longitudes(sample_ds, end, expected_lons, use_dask):
+    dim_names = {"longitude": "lon"}
+
+    ds_concat = _concatenate_longitudes(
+        sample_ds, dim_names, end=end, use_dask=use_dask
+    )
+
+    # longitude should be extended as expected
+    np.testing.assert_array_equal(ds_concat.lon.values, expected_lons)
+
+    # variable with longitude should be extended in size
+    assert ds_concat.var_with_lon.shape[-1] == len(expected_lons)
+
+    # variable without longitude should remain untouched
+    np.testing.assert_array_equal(
+        ds_concat.var_no_lon.values,
+        sample_ds.var_no_lon.values,
+    )
+
+    if use_dask:
+        import dask
+
+        # Ensure dask array backing the data
+        assert isinstance(ds_concat.var_with_lon.data, dask.array.Array)
+        # Longitude dimension should be chunked (-1 → one chunk spanning the whole dim)
+        assert ds_concat.var_with_lon.chunks[-1] == (len(expected_lons),)
+    else:
+        # With use_dask=False, data should be a numpy array
+        assert isinstance(ds_concat.var_with_lon.data, np.ndarray)
+
+
+def test_invalid_end_raises(sample_ds):
+    dim_names = {"longitude": "lon"}
+    with pytest.raises(ValueError):
+        _concatenate_longitudes(sample_ds, dim_names, end="invalid")
+
+
+# test choose_subdomain
+
+
+def test_choose_subdomain_basic(global_dataset, use_dask):
+    target_coords = {
+        "lat": xr.DataArray([0, 10]),
+        "lon": xr.DataArray([30, 40]),
+        "straddle": False,
+    }
+    out = choose_subdomain(
+        global_dataset,
+        dim_names={"latitude": "latitude", "longitude": "longitude"},
+        resolution=1.0,
+        is_global=True,
+        target_coords=target_coords,
+        buffer_points=2,
+        use_dask=use_dask,
+    )
+    assert out.latitude.min() <= 0
+    assert out.latitude.max() >= 10
+    assert out.longitude.min() <= 30
+    assert out.longitude.max() >= 40
+
+
+def test_choose_subdomain_raises_on_empty_lon(non_global_dataset, use_dask):
+    target_coords = {
+        "lat": xr.DataArray([-10, 10]),
+        "lon": xr.DataArray([210, 215]),  # outside 0-180
+        "straddle": False,
+    }
+    with pytest.raises(ValueError, match="longitude range"):
+        choose_subdomain(
+            non_global_dataset,
+            dim_names={"latitude": "latitude", "longitude": "longitude"},
+            resolution=1.0,
+            is_global=False,
+            target_coords=target_coords,
+            use_dask=use_dask,
+        )
+
+
+def test_choose_subdomain_raises_on_empty_lat(global_dataset, use_dask):
+    target_coords = {
+        "lat": xr.DataArray([1000, 1010]),  # outside dataset range
+        "lon": xr.DataArray([30, 40]),
+        "straddle": False,
+    }
+    with pytest.raises(ValueError, match="latitude range"):
+        choose_subdomain(
+            global_dataset,
+            dim_names={"latitude": "latitude", "longitude": "longitude"},
+            resolution=1.0,
+            is_global=True,
+            target_coords=target_coords,
+            use_dask=use_dask,
+        )
+
+
+def test_choose_subdomain_straddle(global_dataset, use_dask):
+    target_coords = {
+        "lat": xr.DataArray([-10, 10]),
+        "lon": xr.DataArray([-170, 170]),  # cross the dateline
+        "straddle": True,
+    }
+    out = choose_subdomain(
+        global_dataset,
+        dim_names={"latitude": "latitude", "longitude": "longitude"},
+        resolution=1.0,
+        is_global=True,
+        target_coords=target_coords,
+        buffer_points=5,
+        use_dask=use_dask,
+    )
+    # Ensure output includes both sides of the dateline, mapped into [-180, 180]
+    assert (out.longitude.min() < 0) and (out.longitude.max() > 0)
+
+
+def test_choose_subdomain_wraps_negative_lon(global_dataset, use_dask):
+    target_coords = {
+        "lat": xr.DataArray([0, 20]),
+        "lon": xr.DataArray([-20, -10]),  # negative longitudes
+        "straddle": False,
+    }
+    out = choose_subdomain(
+        global_dataset,
+        dim_names={"latitude": "latitude", "longitude": "longitude"},
+        resolution=1.0,
+        is_global=True,
+        target_coords=target_coords,
+        buffer_points=2,
+        use_dask=use_dask,
+    )
+    # Output longitudes should be shifted into [0, 360]
+    assert (out.longitude >= 0).all()
+
+
+def test_choose_subdomain_respects_buffer(global_dataset, use_dask):
+    target_coords = {
+        "lat": xr.DataArray([0, 0]),
+        "lon": xr.DataArray([50, 50]),
+        "straddle": False,
+    }
+    out = choose_subdomain(
+        global_dataset,
+        dim_names={"latitude": "latitude", "longitude": "longitude"},
+        resolution=1.0,
+        is_global=True,
+        target_coords=target_coords,
+        buffer_points=10,
+        use_dask=use_dask,
+    )
+    # Buffer should extend at least 10 degrees beyond target lon
+    assert out.longitude.min() <= 40
+    assert out.longitude.max() >= 60
+
+
+# test get_glorys_bounds
+
+
+@pytest.fixture
+def glorys_grid_0_360(tmp_path):
+    lats = np.linspace(-90, 90, 181)
+    lons = np.linspace(0, 360, 361)
+    ds = xr.Dataset(coords={"latitude": lats, "longitude": lons})
+    path = tmp_path / "GLORYS_0_360.nc"
+    ds.to_netcdf(path)
+    return path
+
+
+@pytest.fixture
+def glorys_grid_neg180_180(tmp_path):
+    lats = np.linspace(-90, 90, 181)
+    lons = np.linspace(-180, 180, 361)
+    ds = xr.Dataset(coords={"latitude": lats, "longitude": lons})
+    path = tmp_path / "GLORYS_neg180_180.nc"
+    ds.to_netcdf(path)
+    return path
+
+
+@pytest.fixture
+def glorys_grid_real():
+    return GLORYS_GLOBAL_GRID_PATH
+
+
+@pytest.mark.parametrize(
+    "grid_fixture",
+    [
+        "grid",
+        "grid_that_straddles_dateline",
+        "grid_that_straddles_180_degree_meridian",
+        "small_grid",
+        "tiny_grid",
+    ],
+)
+@pytest.mark.parametrize(
+    "glorys_grid_fixture",
+    ["glorys_grid_0_360", "glorys_grid_neg180_180", "glorys_grid_real"],
+)
+def test_get_glorys_bounds(tmp_path, grid_fixture, glorys_grid_fixture, request):
+    grid = request.getfixturevalue(grid_fixture)
+    glorys_grid_path = request.getfixturevalue(glorys_grid_fixture)
+
+    bounds = get_glorys_bounds(grid=grid, glorys_grid_path=glorys_grid_path)
+    assert set(bounds) == {
+        "minimum_latitude",
+        "maximum_latitude",
+        "minimum_longitude",
+        "maximum_longitude",
+    }
+
+
+@pytest.mark.use_copernicus
+@pytest.mark.skipif(copernicusmarine is None, reason="copernicusmarine required")
+@pytest.mark.parametrize(
+    "grid_fixture",
+    [
+        "tiny_grid_that_straddles_dateline",
+        "tiny_grid_that_straddles_180_degree_meridian",
+        "tiny_rotated_grid",
+    ],
+)
+def test_invariance_to_get_glorys_bounds(tmp_path, grid_fixture, use_dask, request):
+    start_time = datetime(2012, 1, 1)
+    grid = request.getfixturevalue(grid_fixture)
+    target_coords = get_target_coords(grid)
+
+    regional_file, bigger_regional_file = download_regional_and_bigger(
+        tmp_path, grid, start_time, variables=["thetao", "uo", "zos"]
+    )
+
+    # create datasets from regional and bigger regional data
+    regional_data = GLORYSDataset(
+        var_names={"temp": "thetao", "u": "uo", "zeta": "zos"},
+        filename=regional_file,
+        start_time=start_time,
+        climatology=False,
+        allow_flex_time=False,
+        use_dask=use_dask,
+    )
+    bigger_regional_data = GLORYSDataset(
+        var_names={"temp": "thetao", "u": "uo", "zeta": "zos"},
+        filename=bigger_regional_file,
+        start_time=start_time,
+        climatology=False,
+        allow_flex_time=False,
+        use_dask=use_dask,
+    )
+
+    # subset both datasets and check they are the same
+    regional_data.choose_subdomain(target_coords)
+    bigger_regional_data.choose_subdomain(target_coords)
+
+    # Use assert_allclose instead of equals: necessary for grids that straddle the 180° meridian.
+    # Copernicus returns data on [-180, 180] by default, but if you request a range
+    # like [170, 190], it remaps longitudes. That remapping introduces tiny floating
+    # point differences in the longitude coordinate, so strict equality will fail
+    # even though the data are effectively identical.
+    # For grids that do not straddle the 180° meridian, strict equality still holds.
+    xr.testing.assert_allclose(bigger_regional_data.ds, regional_data.ds)