roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,365 @@
1
+ import pytest
2
+ import tempfile
3
+ import os
4
+ from roms_tools import Grid, TidalForcing
5
+ import xarray as xr
6
+ import numpy as np
7
+ from roms_tools.setup.download import download_test_data
8
+ import textwrap
9
+
10
+
11
@pytest.fixture
def grid_that_lies_within_bounds_of_regional_tpxo_data():
    """Small 3x3 grid centred at (235E, 25N).

    The fixture name states it is meant to lie inside the coverage of the
    regional TPXO test dataset.
    """
    return Grid(
        nx=3,
        ny=3,
        size_x=1500,
        size_y=1500,
        center_lon=235,
        center_lat=25,
        rot=-20,
    )
17
+
18
+
19
@pytest.fixture
def grid_that_is_out_of_bounds_of_regional_tpxo_data():
    """3x3 grid centred at (235E, 25N), 300 units wider in x than the in-bounds grid.

    The extra east-west extent is intended to push the domain past the edge
    of the regional TPXO test dataset.
    """
    return Grid(
        nx=3,
        ny=3,
        size_x=1800,
        size_y=1500,
        center_lon=235,
        center_lat=25,
        rot=-20,
    )
25
+
26
+
27
@pytest.fixture
def grid_that_straddles_dateline():
    """
    Fixture for a 5x5 domain centred at 10 degrees West (center_lon=-10), 30N.

    NOTE: despite the fixture name, this domain sits near the prime meridian
    (the 0/360 degree longitude wrap), not the 180 degree dateline; the
    180 degree case is covered by ``grid_that_straddles_180_degree_meridian``.
    The name is kept unchanged because tests request the fixture by name.
    """
    grid = Grid(
        nx=5,
        ny=5,
        size_x=1800,
        size_y=2400,
        center_lon=-10,
        center_lat=30,
        rot=20,
    )

    return grid
43
+
44
+
45
@pytest.fixture
def grid_that_straddles_180_degree_meridian():
    """Fixture for a 5x5 domain centred on the 180 degree meridian at 30N."""
    return Grid(
        nx=5,
        ny=5,
        size_x=1800,
        size_y=2400,
        center_lon=180,
        center_lat=30,
        rot=20,
    )
62
+
63
+
64
@pytest.mark.parametrize(
    "grid_fixture",
    [
        "grid_that_lies_within_bounds_of_regional_tpxo_data",
        "grid_that_is_out_of_bounds_of_regional_tpxo_data",
        "grid_that_straddles_dateline",
        "grid_that_straddles_180_degree_meridian",
    ],
)
def test_successful_initialization_with_global_data(grid_fixture, request):
    """Global TPXO test data can build a 2-constituent TidalForcing on each grid fixture.

    Checks that every expected output variable is present and that the
    ``source`` and ``ntides`` arguments are stored as given.
    """
    fname = download_test_data("TPXO_global_test_data.nc")
    grid = request.getfixturevalue(grid_fixture)

    forcing = TidalForcing(
        grid=grid, source={"name": "TPXO", "path": fname}, ntides=2
    )

    assert isinstance(forcing.ds, xr.Dataset)
    expected_vars = (
        "omega",
        "ssh_Re",
        "ssh_Im",
        "pot_Re",
        "pot_Im",
        "u_Re",
        "u_Im",
        "v_Re",
        "v_Im",
    )
    for varname in expected_vars:
        assert varname in forcing.ds

    assert forcing.source == {"name": "TPXO", "path": fname}
    assert forcing.ntides == 2
96
+
97
+
98
def test_successful_initialization_with_regional_data(
    grid_that_lies_within_bounds_of_regional_tpxo_data,
):
    """Regional TPXO test data can build a 10-constituent TidalForcing on an in-bounds grid.

    Checks that every expected output variable is present and that the
    ``source`` and ``ntides`` arguments are stored as given.
    """
    regional_path = download_test_data("TPXO_regional_test_data.nc")

    forcing = TidalForcing(
        grid=grid_that_lies_within_bounds_of_regional_tpxo_data,
        source={"name": "TPXO", "path": regional_path},
        ntides=10,
    )

    assert isinstance(forcing.ds, xr.Dataset)
    for varname in (
        "omega",
        "ssh_Re",
        "ssh_Im",
        "pot_Re",
        "pot_Im",
        "u_Re",
        "u_Im",
        "v_Re",
        "v_Im",
    ):
        assert varname in forcing.ds

    assert forcing.source == {"name": "TPXO", "path": regional_path}
    assert forcing.ntides == 10
123
+
124
+
125
def test_unsuccessful_initialization_with_regional_data_due_to_nans(
    grid_that_is_out_of_bounds_of_regional_tpxo_data,
):
    """TidalForcing rejects regional data that leaves NaNs on a grid extending past it."""
    regional_path = download_test_data("TPXO_regional_test_data.nc")
    tpxo_source = {"name": "TPXO", "path": regional_path}

    with pytest.raises(ValueError, match="NaN values found"):
        TidalForcing(
            grid=grid_that_is_out_of_bounds_of_regional_tpxo_data,
            source=tpxo_source,
            ntides=10,
        )
137
+
138
+
139
@pytest.mark.parametrize(
    "grid_fixture",
    ["grid_that_straddles_dateline", "grid_that_straddles_180_degree_meridian"],
)
def test_unsuccessful_initialization_with_regional_data_due_to_no_overlap(
    grid_fixture, request
):
    """TidalForcing rejects regional data whose longitudes do not intersect the grid."""
    regional_path = download_test_data("TPXO_regional_test_data.nc")
    target_grid = request.getfixturevalue(grid_fixture)

    expected_message = "Selected longitude range does not intersect with dataset"
    with pytest.raises(ValueError, match=expected_message):
        TidalForcing(
            grid=target_grid,
            source={"name": "TPXO", "path": regional_path},
            ntides=10,
        )
155
+
156
+
157
def test_insufficient_number_of_consituents(grid_that_straddles_dateline):
    """Requesting more constituents (10) than the global test file holds raises ValueError."""
    # NOTE: "consituents" in the test name is a typo, kept so existing
    # test selections (-k / node IDs) keep working.
    global_path = download_test_data("TPXO_global_test_data.nc")

    with pytest.raises(ValueError, match="The dataset contains fewer"):
        TidalForcing(
            grid=grid_that_straddles_dateline,
            source={"name": "TPXO", "path": global_path},
            ntides=10,
        )
167
+
168
+
169
@pytest.fixture
def tidal_forcing(
    grid_that_lies_within_bounds_of_regional_tpxo_data,
):
    """Single-constituent TidalForcing built from the regional TPXO test data."""
    regional_path = download_test_data("TPXO_regional_test_data.nc")
    forcing = TidalForcing(
        grid=grid_that_lies_within_bounds_of_regional_tpxo_data,
        source={"name": "TPXO", "path": regional_path},
        ntides=1,
    )
    return forcing
181
+
182
+
183
def test_tidal_forcing_data_consistency_plot_save(tidal_forcing, tmp_path):
    """
    Test that the data within the TidalForcing object remains consistent.

    Also test the plot and save methods in the same test, since the dask
    arrays have already been computed by the ``load()`` call.
    """
    tidal_forcing.ds.load()

    expected = {
        "ssh_Re": np.array(
            [
                [
                    [0.03362583, 0.0972546, 0.15625167, 0.20162642, 0.22505085],
                    [0.04829295, 0.13148762, 0.2091077, 0.26777256, 0.28947946],
                    [0.0574473, 0.16427538, 0.26692376, 0.335315, 0.35217384],
                    [0.0555277, 0.1952368, 0.32960117, 0.41684473, 0.43021917],
                    [0.04893931, 0.22524744, 0.39933527, 0.39793402, -0.18146336],
                ]
            ],
            dtype=np.float32,
        ),
        "ssh_Im": np.array(
            [
                [
                    [0.28492996, 0.33401084, 0.3791059, 0.40458283, 0.39344734],
                    [0.14864475, 0.19812492, 0.25232342, 0.29423112, 0.30793712],
                    [-0.01214434, 0.04206207, 0.11318521, 0.18785079, 0.24001373],
                    [-0.18849652, -0.13063835, -0.02998546, 0.0921034, 0.20685565],
                    [-0.36839223, -0.31615746, -0.18911538, -0.08607443, -0.51923835],
                ]
            ],
            dtype=np.float32,
        ),
        "pot_Re": np.array(
            [
                [
                    [-0.11110803, -0.08998635, -0.06672653, -0.04285957, -0.01980283],
                    [-0.10053363, -0.07692371, -0.05161811, -0.02654761, -0.00358691],
                    [-0.08996155, -0.06400539, -0.03679418, -0.01115401, 0.01084424],
                    [-0.08017255, -0.05190206, -0.0231975, 0.00163647, 0.01880641],
                    [-0.07144432, -0.04169955, -0.0143679, -0.00313035, 0.00145161],
                ]
            ],
            dtype=np.float32,
        ),
        "pot_Im": np.array(
            [
                [
                    [-0.05019786, -0.06314129, -0.07475527, -0.08616351, -0.09869237],
                    [-0.06716369, -0.07930522, -0.08920974, -0.09815053, -0.10770469],
                    [-0.08508184, -0.09582505, -0.10310414, -0.10833713, -0.1137427],
                    [-0.10244609, -0.11144008, -0.1151103, -0.11618311, -0.11833992],
                    [-0.11764989, -0.12432244, -0.12302232, -0.12279626, -0.13328244],
                ]
            ],
            dtype=np.float32,
        ),
        "u_Re": np.array(
            [
                [
                    [-0.01043007, -0.00768077, -0.00370782, 0.00174401],
                    [-0.01046313, -0.00833564, -0.00534876, -0.00036892],
                    [-0.01149787, -0.0117521, -0.01165313, -0.00668873],
                    [-0.01435909, -0.01959155, -0.02610414, -0.02264688],
                    [-0.01590802, -0.02578601, -0.01770638, -0.00307389],
                ]
            ],
            dtype=np.float32,
        ),
        "u_Im": np.array(
            [
                [
                    [0.00068068, 0.00041515, -0.00098873, -0.00315086],
                    [0.00103654, 0.0007357, -0.000998, -0.00411228],
                    [0.0017258, 0.00158265, -0.0014292, -0.00713451],
                    [0.00458748, 0.00451903, 0.00046625, -0.00838845],
                    [0.00930313, 0.01001076, 0.00501656, 0.0004481],
                ]
            ],
            dtype=np.float32,
        ),
        "v_Re": np.array(
            [
                [
                    [0.01867937, 0.0175135, 0.0163139, 0.01373139, 0.0114212],
                    [0.02016588, 0.01930715, 0.01812451, 0.01638661, 0.01130948],
                    [0.02174281, 0.02100098, 0.02242658, 0.02136369, 0.01128179],
                    [0.02275964, 0.0218871, 0.02481382, 0.01351837, 0.00718958],
                ]
            ],
            dtype=np.float32,
        ),
        "v_Im": np.array(
            [
                [
                    [-0.00304336, 0.00069296, 0.00384371, 0.00627055, 0.00745201],
                    [-0.00472402, -0.00109876, 0.0024061, 0.00627166, 0.00790893],
                    [-0.00699575, -0.00359212, 0.00066638, 0.00706607, 0.01097147],
                    [-0.00954442, -0.00623799, -0.00171383, 0.00425109, 0.00574474],
                ]
            ],
            dtype=np.float32,
        ),
    }

    # Check the values in the dataset; the variable name is attached to the
    # assertion so a failure says which field drifted.
    for varname, expected_values in expected.items():
        assert np.allclose(
            tidal_forcing.ds[varname].values, expected_values
        ), varname

    tidal_forcing.plot(varname="ssh_Re", ntides=0)

    # Save into pytest's per-test directory under a name that does not exist
    # yet, so the existence check proves save() actually wrote the file
    # (a pre-created NamedTemporaryFile would make the check vacuous).
    filepath = tmp_path / "tidal_forcing.nc"
    tidal_forcing.save(str(filepath))
    assert filepath.exists()
314
+
315
+
316
def test_roundtrip_yaml(tidal_forcing, tmp_path):
    """Test that creating a TidalForcing object, saving its parameters to yaml file, and re-opening yaml file creates the same object."""
    # tmp_path is a per-test directory created and cleaned up by pytest, so
    # no manual try/finally removal is needed even if to_yaml/from_yaml fail.
    filepath = str(tmp_path / "tidal_forcing.yaml")

    tidal_forcing.to_yaml(filepath)
    tidal_forcing_from_file = TidalForcing.from_yaml(filepath)

    assert tidal_forcing == tidal_forcing_from_file
332
+
333
+
334
def test_from_yaml_missing_tidal_forcing(tmp_path):
    """TidalForcing.from_yaml raises ValueError when the file has no TidalForcing section.

    The YAML written here contains only a version header and a Grid section.
    """
    yaml_content = textwrap.dedent(
        """\
    ---
    roms_tools_version: 0.0.0
    ---
    Grid:
      nx: 100
      ny: 100
      size_x: 1800
      size_y: 2400
      center_lon: -10
      center_lat: 61
      rot: -20
      topography_source: ETOPO5
      smooth_factor: 8
      hmin: 5.0
      rmax: 0.2
    """
    )

    # Write into pytest's managed temporary directory; it is removed by
    # pytest, so no manual cleanup is required.
    yaml_filepath = tmp_path / "missing_tidal_forcing.yaml"
    yaml_filepath.write_text(yaml_content, encoding="utf-8")

    with pytest.raises(
        ValueError, match="No TidalForcing configuration found in the YAML file."
    ):
        TidalForcing.from_yaml(str(yaml_filepath))
@@ -0,0 +1,78 @@
1
+ from roms_tools import Grid
2
+ from roms_tools.setup.topography import _compute_rfactor
3
+ import numpy as np
4
+ import numpy.testing as npt
5
+ from scipy.ndimage import label
6
+
7
+
8
def test_enclosed_regions():
    """Test that the nonzero part of ``mask_rho`` forms exactly two connected regions.

    ``scipy.ndimage.label`` only labels nonzero (foreground) cells; zero cells
    are background and are not counted. ``nreg`` is therefore the number of
    connected nonzero regions of the mask, not a count of dry plus wet regions.
    NOTE(review): assumes nonzero mask_rho marks wet points (ROMS convention) —
    confirm against roms_tools.
    """
    grid = Grid(
        nx=100,
        ny=100,
        size_x=1800,
        size_y=2400,
        center_lon=30,
        center_lat=61,
        rot=20,
    )

    # The labelled array itself is not needed, only the region count.
    _, nreg = label(grid.ds.mask_rho)
    npt.assert_equal(nreg, 2)
23
+
24
+
25
def test_rmax_criterion():
    """After smoothing, the largest r-factor of ``h`` in either direction is below ``grid.rmax``."""
    smoothed_grid = Grid(
        nx=100,
        ny=100,
        size_x=1800,
        size_y=2400,
        center_lon=30,
        center_lat=61,
        rot=20,
        smooth_factor=4,
        rmax=0.2,
    )
    r_along_eta, r_along_xi = _compute_rfactor(smoothed_grid.ds.h)
    largest_r = np.max([r_along_eta.max(), r_along_xi.max()])
    npt.assert_array_less(largest_r, smoothed_grid.rmax)
40
+
41
+
42
def test_hmin_criterion():
    """The minimum of ``h`` on the grid is not shallower than the requested ``hmin``."""
    grid = Grid(
        nx=100,
        ny=100,
        size_x=1800,
        size_y=2400,
        center_lon=30,
        center_lat=61,
        rot=20,
        smooth_factor=2,
        rmax=0.2,
        hmin=5,
    )

    shallowest_depth = grid.ds.h.min()
    assert np.less_equal(grid.hmin, shallowest_depth)
57
+
58
+
59
def test_data_consistency():
    """
    Regression check: ``h`` for a fixed small grid matches stored reference values.
    """
    reference_h = np.array(
        [
            [4505.16995868, 4505.16995868, 4407.37986032, 4306.51226663, 4306.51226663],
            [4505.16995868, 4505.16995868, 4407.37986032, 4306.51226663, 4306.51226663],
            [4400.69482254, 4400.69482254, 3940.84931344, 3060.19573878, 3060.19573878],
            [4234.97356606, 4234.97356606, 2880.90226836, 2067.46801754, 2067.46801754],
            [4234.97356606, 4234.97356606, 2880.90226836, 2067.46801754, 2067.46801754],
        ]
    )

    grid = Grid(
        nx=3, ny=3, size_x=1500, size_y=1500, center_lon=235, center_lat=25, rot=-20
    )
    computed_h = grid.ds["h"].values

    assert np.allclose(computed_h, reference_h)
@@ -0,0 +1,16 @@
1
+ from roms_tools.setup.utils import interpolate_from_climatology
2
+ from roms_tools.setup.datasets import ERA5Correction
3
+ from roms_tools.setup.download import download_test_data
4
+ import xarray as xr
5
+
6
+
7
def test_interpolate_from_climatology():
    """Interpolating the ERA5 ``ssr_corr`` climatology onto ERA5 test times yields one value per time."""
    fname = download_test_data("ERA5_regional_test_data.nc")

    # Open the file in a context manager so its handle is released; load the
    # time coordinate into memory first so it stays usable after closing.
    with xr.open_dataset(fname) as era5_ds:
        era5_times = era5_ds.time.load()

    climatology = ERA5Correction()
    field = climatology.ds["ssr_corr"]

    interpolated_field = interpolate_from_climatology(field, "time", era5_times)
    assert len(interpolated_field.time) == len(era5_times)