roms-tools 0.1.0__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,366 @@
1
+ import pytest
2
+ import tempfile
3
+ import os
4
+ from roms_tools import Grid, TidalForcing
5
+ import xarray as xr
6
+ import numpy as np
7
+ from roms_tools.setup.datasets import download_test_data
8
+ import textwrap
9
+
10
+
11
@pytest.fixture
def grid_that_lies_within_bounds_of_regional_tpxo_data():
    """
    Fixture returning a 3x3 grid (1500 x 1500) centred at lon 235, lat 25, rotated by -20.

    The tests in this module expect this grid to be usable with the regional
    TPXO test data without raising.
    """
    return Grid(
        nx=3,
        ny=3,
        size_x=1500,
        size_y=1500,
        center_lon=235,
        center_lat=25,
        rot=-20,
    )
17
+
18
+
19
@pytest.fixture
def grid_that_is_out_of_bounds_of_regional_tpxo_data():
    """
    Fixture returning a 3x3 grid like the in-bounds one, but with size_x=1800.

    The tests in this module expect TidalForcing to raise a "NaN values found"
    error when this grid is combined with the regional TPXO test data.
    """
    return Grid(
        nx=3,
        ny=3,
        size_x=1800,
        size_y=1500,
        center_lon=235,
        center_lat=25,
        rot=-20,
    )
25
+
26
+
27
@pytest.fixture
def grid_that_straddles_dateline():
    """
    Fixture for creating a 5x5 domain centred at longitude -10, latitude 30.

    NOTE(review): despite the fixture name, the centre longitude is -10, so
    this domain presumably crosses the 0 degree meridian rather than the
    180 degree one -- confirm the intended wording.
    """
    grid_kwargs = dict(
        nx=5,
        ny=5,
        size_x=1800,
        size_y=2400,
        center_lon=-10,
        center_lat=30,
        rot=20,
    )
    return Grid(**grid_kwargs)
43
+
44
+
45
@pytest.fixture
def grid_that_straddles_180_degree_meridian():
    """
    Fixture for creating a 5x5 domain centred on the 180 degree meridian (latitude 30).
    """
    grid_kwargs = dict(
        nx=5,
        ny=5,
        size_x=1800,
        size_y=2400,
        center_lon=180,
        center_lat=30,
        rot=20,
    )
    return Grid(**grid_kwargs)
62
+
63
+
64
@pytest.mark.parametrize(
    "grid_fixture",
    [
        "grid_that_lies_within_bounds_of_regional_tpxo_data",
        "grid_that_is_out_of_bounds_of_regional_tpxo_data",
        "grid_that_straddles_dateline",
        "grid_that_straddles_180_degree_meridian",
    ],
)
def test_successful_initialization_with_global_data(grid_fixture, request):
    """
    Build a TidalForcing from the global TPXO test file for each grid fixture.

    Checks that the resulting dataset is an xarray Dataset containing all
    expected variables and that the constructor arguments are stored on the
    object unchanged.
    """
    data_path = download_test_data("TPXO_global_test_data.nc")

    # Fixtures are looked up by name because parametrize only passes strings.
    forcing = TidalForcing(
        grid=request.getfixturevalue(grid_fixture),
        filename=data_path,
        source="TPXO",
        ntides=2,
    )

    assert isinstance(forcing.ds, xr.Dataset)
    for variable in (
        "omega",
        "ssh_Re",
        "ssh_Im",
        "pot_Re",
        "pot_Im",
        "u_Re",
        "u_Im",
        "v_Re",
        "v_Im",
    ):
        assert variable in forcing.ds

    assert forcing.filename == data_path
    assert forcing.source == "TPXO"
    assert forcing.ntides == 2
95
+
96
+
97
def test_successful_initialization_with_regional_data(
    grid_that_lies_within_bounds_of_regional_tpxo_data,
):
    """
    Build a TidalForcing from the regional TPXO test file with ntides=10.

    Checks that the resulting dataset is an xarray Dataset containing all
    expected variables and that the constructor arguments are stored on the
    object unchanged.
    """
    data_path = download_test_data("TPXO_regional_test_data.nc")

    forcing = TidalForcing(
        grid=grid_that_lies_within_bounds_of_regional_tpxo_data,
        filename=data_path,
        source="TPXO",
        ntides=10,
    )

    assert isinstance(forcing.ds, xr.Dataset)
    for variable in (
        "omega",
        "ssh_Re",
        "ssh_Im",
        "pot_Re",
        "pot_Im",
        "u_Re",
        "u_Im",
        "v_Re",
        "v_Im",
    ):
        assert variable in forcing.ds

    assert forcing.filename == data_path
    assert forcing.source == "TPXO"
    assert forcing.ntides == 10
124
+
125
+
126
def test_unsuccessful_initialization_with_regional_data_due_to_nans(
    grid_that_is_out_of_bounds_of_regional_tpxo_data,
):
    """
    Expect a ValueError mentioning "NaN values found" when the out-of-bounds
    grid is combined with the regional TPXO test file.
    """
    data_path = download_test_data("TPXO_regional_test_data.nc")
    forcing_kwargs = dict(
        grid=grid_that_is_out_of_bounds_of_regional_tpxo_data,
        filename=data_path,
        source="TPXO",
        ntides=10,
    )

    with pytest.raises(ValueError, match="NaN values found"):
        TidalForcing(**forcing_kwargs)
139
+
140
+
141
@pytest.mark.parametrize(
    "grid_fixture",
    ["grid_that_straddles_dateline", "grid_that_straddles_180_degree_meridian"],
)
def test_unsuccessful_initialization_with_regional_data_due_to_no_overlap(
    grid_fixture, request
):
    """
    Expect a ValueError about a non-intersecting longitude range when either
    of the parametrized grids is combined with the regional TPXO test file.
    """
    data_path = download_test_data("TPXO_regional_test_data.nc")

    # Fixtures are looked up by name because parametrize only passes strings.
    domain = request.getfixturevalue(grid_fixture)

    expected_message = "Selected longitude range does not intersect with dataset"
    with pytest.raises(ValueError, match=expected_message):
        TidalForcing(grid=domain, filename=data_path, source="TPXO", ntides=10)
157
+
158
+
159
def test_insufficient_number_of_consituents(grid_that_straddles_dateline):
    """
    Expect a ValueError starting "The dataset contains fewer" when ntides=10
    is requested from the global TPXO test file.
    """
    data_path = download_test_data("TPXO_global_test_data.nc")
    forcing_kwargs = dict(
        grid=grid_that_straddles_dateline,
        filename=data_path,
        source="TPXO",
        ntides=10,
    )

    with pytest.raises(ValueError, match="The dataset contains fewer"):
        TidalForcing(**forcing_kwargs)
167
+
168
+
169
@pytest.fixture
def tidal_forcing(
    grid_that_lies_within_bounds_of_regional_tpxo_data,
):
    """
    Fixture returning a TidalForcing built from the regional TPXO test file
    on the in-bounds grid, with ntides=1.
    """
    data_path = download_test_data("TPXO_regional_test_data.nc")
    forcing_kwargs = dict(
        grid=grid_that_lies_within_bounds_of_regional_tpxo_data,
        filename=data_path,
        source="TPXO",
        ntides=1,
    )
    return TidalForcing(**forcing_kwargs)
182
+
183
+
184
def test_tidal_forcing_data_consistency_plot_save(tidal_forcing, tmp_path):
    """
    Test that the data within the TidalForcing object remains consistent.

    Also test the plot and save methods in the same test, since the dask
    arrays are already computed at that point.

    The saved file is written to a fresh path below pytest's ``tmp_path``.
    That path does not exist before ``save`` is called, so the existence
    check really verifies that ``save`` wrote something, and pytest takes
    care of the cleanup even if ``save`` raises.
    """
    tidal_forcing.ds.load()

    # Reference values, keyed by the dataset variable they are compared with.
    expected = {
        "ssh_Re": np.array(
            [
                [
                    [0.03362583, 0.0972546, 0.15625167, 0.20162642, 0.22505085],
                    [0.04829295, 0.13148762, 0.2091077, 0.26777256, 0.28947946],
                    [0.0574473, 0.16427538, 0.26692376, 0.335315, 0.35217384],
                    [0.0555277, 0.1952368, 0.32960117, 0.41684473, 0.43021917],
                    [0.04893931, 0.22524744, 0.39933527, 0.39793402, -0.18146336],
                ]
            ],
            dtype=np.float32,
        ),
        "ssh_Im": np.array(
            [
                [
                    [0.28492996, 0.33401084, 0.3791059, 0.40458283, 0.39344734],
                    [0.14864475, 0.19812492, 0.25232342, 0.29423112, 0.30793712],
                    [-0.01214434, 0.04206207, 0.11318521, 0.18785079, 0.24001373],
                    [-0.18849652, -0.13063835, -0.02998546, 0.0921034, 0.20685565],
                    [-0.36839223, -0.31615746, -0.18911538, -0.08607443, -0.51923835],
                ]
            ],
            dtype=np.float32,
        ),
        "pot_Re": np.array(
            [
                [
                    [-0.11110803, -0.08998635, -0.06672653, -0.04285957, -0.01980283],
                    [-0.10053363, -0.07692371, -0.05161811, -0.02654761, -0.00358691],
                    [-0.08996155, -0.06400539, -0.03679418, -0.01115401, 0.01084424],
                    [-0.08017255, -0.05190206, -0.0231975, 0.00163647, 0.01880641],
                    [-0.07144432, -0.04169955, -0.0143679, -0.00313035, 0.00145161],
                ]
            ],
            dtype=np.float32,
        ),
        "pot_Im": np.array(
            [
                [
                    [-0.05019786, -0.06314129, -0.07475527, -0.08616351, -0.09869237],
                    [-0.06716369, -0.07930522, -0.08920974, -0.09815053, -0.10770469],
                    [-0.08508184, -0.09582505, -0.10310414, -0.10833713, -0.1137427],
                    [-0.10244609, -0.11144008, -0.1151103, -0.11618311, -0.11833992],
                    [-0.11764989, -0.12432244, -0.12302232, -0.12279626, -0.13328244],
                ]
            ],
            dtype=np.float32,
        ),
        "u_Re": np.array(
            [
                [
                    [-0.01043007, -0.00768077, -0.00370782, 0.00174401],
                    [-0.01046313, -0.00833564, -0.00534876, -0.00036892],
                    [-0.01149787, -0.0117521, -0.01165313, -0.00668873],
                    [-0.01435909, -0.01959155, -0.02610414, -0.02264688],
                    [-0.01590802, -0.02578601, -0.01770638, -0.00307389],
                ]
            ],
            dtype=np.float32,
        ),
        "u_Im": np.array(
            [
                [
                    [0.00068068, 0.00041515, -0.00098873, -0.00315086],
                    [0.00103654, 0.0007357, -0.000998, -0.00411228],
                    [0.0017258, 0.00158265, -0.0014292, -0.00713451],
                    [0.00458748, 0.00451903, 0.00046625, -0.00838845],
                    [0.00930313, 0.01001076, 0.00501656, 0.0004481],
                ]
            ],
            dtype=np.float32,
        ),
        "v_Re": np.array(
            [
                [
                    [0.01867937, 0.0175135, 0.0163139, 0.01373139, 0.0114212],
                    [0.02016588, 0.01930715, 0.01812451, 0.01638661, 0.01130948],
                    [0.02174281, 0.02100098, 0.02242658, 0.02136369, 0.01128179],
                    [0.02275964, 0.0218871, 0.02481382, 0.01351837, 0.00718958],
                ]
            ],
            dtype=np.float32,
        ),
        "v_Im": np.array(
            [
                [
                    [-0.00304336, 0.00069296, 0.00384371, 0.00627055, 0.00745201],
                    [-0.00472402, -0.00109876, 0.0024061, 0.00627166, 0.00790893],
                    [-0.00699575, -0.00359212, 0.00066638, 0.00706607, 0.01097147],
                    [-0.00954442, -0.00623799, -0.00171383, 0.00425109, 0.00574474],
                ]
            ],
            dtype=np.float32,
        ),
    }

    # Check the values in the dataset
    for varname, reference in expected.items():
        assert np.allclose(
            tidal_forcing.ds[varname].values, reference
        ), f"{varname} differs from the reference values"

    tidal_forcing.plot(varname="ssh_Re", ntides=0)

    # Target a file that does not exist yet; a pre-created temporary file
    # would make the existence check below pass even if save() wrote nothing.
    target = tmp_path / "tidal_forcing.nc"
    assert not target.exists()

    # Passed as a string, the same argument type the test used before.
    tidal_forcing.save(str(target))

    assert target.exists()
315
+
316
+
317
def test_roundtrip_yaml(tidal_forcing):
    """
    Test that writing a TidalForcing object's parameters to a yaml file and
    reading that file back yields an object equal to the original.
    """
    # Only the name is needed; the handle is closed as soon as the block exits.
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        yaml_path = handle.name

    try:
        tidal_forcing.to_yaml(yaml_path)
        restored = TidalForcing.from_yaml(yaml_path)
        assert tidal_forcing == restored
    finally:
        # delete=False above means the file must be removed by hand.
        os.remove(yaml_path)
333
+
334
+
335
def test_from_yaml_missing_tidal_forcing():
    """
    Test that TidalForcing.from_yaml raises a ValueError when the yaml file
    holds a version header and a Grid section but no TidalForcing section.
    """
    yaml_content = textwrap.dedent(
        """\
        ---
        roms_tools_version: 0.0.0
        ---
        Grid:
          nx: 100
          ny: 100
          size_x: 1800
          size_y: 2400
          center_lon: -10
          center_lat: 61
          rot: -20
          topography_source: ETOPO5
          smooth_factor: 8
          hmin: 5.0
          rmax: 0.2
        """
    )

    # The content is flushed to disk when the with-block closes the handle.
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        handle.write(yaml_content.encode())
        yaml_filepath = handle.name

    expected_message = "No TidalForcing configuration found in the YAML file."
    try:
        with pytest.raises(ValueError, match=expected_message):
            TidalForcing.from_yaml(yaml_filepath)
    finally:
        # delete=False above means the file must be removed by hand.
        os.remove(yaml_filepath)
@@ -0,0 +1,78 @@
1
+ from roms_tools import Grid
2
+ from roms_tools.setup.topography import _compute_rfactor
3
+ import numpy as np
4
+ import numpy.testing as npt
5
+ from scipy.ndimage import label
6
+
7
+
8
def test_enclosed_regions():
    """
    Test that there are only two connected regions, one dry and one wet.

    NOTE(review): scipy.ndimage.label counts connected non-zero features
    only, so the count of 2 asserted below may describe two non-zero
    components of mask_rho rather than one dry and one wet -- confirm.
    """
    grid_kwargs = dict(
        nx=100,
        ny=100,
        size_x=1800,
        size_y=2400,
        center_lon=30,
        center_lat=61,
        rot=20,
    )
    grid = Grid(**grid_kwargs)

    # Only the feature count is checked; the labelled array itself is unused.
    _labelled, n_regions = label(grid.ds.mask_rho)
    npt.assert_equal(n_regions, 2)
23
+
24
+
25
def test_rmax_criterion():
    """
    Test that the largest r-factor of the generated topography, taken over
    both arrays returned by _compute_rfactor, is strictly below grid.rmax.
    """
    grid_kwargs = dict(
        nx=100,
        ny=100,
        size_x=1800,
        size_y=2400,
        center_lon=30,
        center_lat=61,
        rot=20,
        smooth_factor=4,
        rmax=0.2,
    )
    grid = Grid(**grid_kwargs)

    r_eta, r_xi = _compute_rfactor(grid.ds.h)
    largest_eta = r_eta.max()
    largest_xi = r_xi.max()
    largest_overall = np.max([largest_eta, largest_xi])

    npt.assert_array_less(largest_overall, grid.rmax)
40
+
41
+
42
def test_hmin_criterion():
    """
    Test that the minimum of the generated topography h is not smaller than
    the hmin the grid was created with.
    """
    grid_kwargs = dict(
        nx=100,
        ny=100,
        size_x=1800,
        size_y=2400,
        center_lon=30,
        center_lat=61,
        rot=20,
        smooth_factor=2,
        rmax=0.2,
        hmin=5,
    )
    grid = Grid(**grid_kwargs)

    smallest_h = grid.ds.h.min()
    assert np.less_equal(grid.hmin, smallest_h)
57
+
58
+
59
def test_data_consistency():
    """
    Test that the topography generation remains consistent.

    The generated field h of a small 3x3 grid is compared with stored
    reference values.
    """
    grid = Grid(
        nx=3,
        ny=3,
        size_x=1500,
        size_y=1500,
        center_lon=235,
        center_lat=25,
        rot=-20,
    )

    # The first two and the last two rows of the reference are identical,
    # so each distinct row is written out only once.
    row_a = [4505.16995868, 4505.16995868, 4407.37986032, 4306.51226663, 4306.51226663]
    row_b = [4400.69482254, 4400.69482254, 3940.84931344, 3060.19573878, 3060.19573878]
    row_c = [4234.97356606, 4234.97356606, 2880.90226836, 2067.46801754, 2067.46801754]
    reference_h = np.array([row_a, row_a, row_b, row_c, row_c])

    assert np.allclose(grid.ds["h"].values, reference_h)