roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,370 @@
1
+ import pytest
2
+ from datetime import datetime
3
+ import numpy as np
4
+ import xarray as xr
5
+ from roms_tools.setup.datasets import Dataset, ERA5Correction
6
+ import tempfile
7
+ import os
8
+
9
+
10
@pytest.fixture
def global_dataset():
    """Global 4D dataset (time, depth, latitude, longitude) with monthly midnight stamps."""
    times = [np.datetime64(f"2022-{month:02d}-01T00:00:00") for month in (1, 2, 3, 4)]
    depths = np.linspace(0, 2000, 10)
    lats = np.linspace(-90, 90, 180)
    lons = np.linspace(0, 359, 360)
    values = np.random.rand(len(times), depths.size, lats.size, lons.size)
    return xr.Dataset(
        {"var": (["time", "depth", "latitude", "longitude"], values)},
        coords={
            "time": (["time"], times),
            "depth": (["depth"], depths),
            "latitude": (["latitude"], lats),
            "longitude": (["longitude"], lons),
        },
    )
32
+
33
+
34
@pytest.fixture
def global_dataset_with_noon_times():
    """Global surface dataset (time, latitude, longitude) with monthly noon stamps."""
    times = [np.datetime64(f"2022-{month:02d}-01T12:00:00") for month in (1, 2, 3, 4)]
    lats = np.linspace(-90, 90, 180)
    lons = np.linspace(0, 359, 360)
    values = np.random.rand(len(times), lats.size, lons.size)
    return xr.Dataset(
        {"var": (["time", "latitude", "longitude"], values)},
        coords={
            "time": (["time"], times),
            "latitude": (["latitude"], lats),
            "longitude": (["longitude"], lons),
        },
    )
54
+
55
+
56
@pytest.fixture
def global_dataset_with_multiple_times_per_day():
    """Global surface dataset with two stamps (00:00 and 12:00) on the first of each month."""
    times = [
        np.datetime64(f"2022-{month:02d}-01T{hour:02d}:00:00")
        for month in (1, 2, 3, 4)
        for hour in (0, 12)
    ]
    lats = np.linspace(-90, 90, 180)
    lons = np.linspace(0, 359, 360)
    values = np.random.rand(len(times), lats.size, lons.size)
    return xr.Dataset(
        {"var": (["time", "latitude", "longitude"], values)},
        coords={
            "time": (["time"], times),
            "latitude": (["latitude"], lats),
            "longitude": (["longitude"], lons),
        },
    )
80
+
81
+
82
@pytest.fixture
def non_global_dataset():
    """Static dataset covering only longitudes 0-180, i.e. not global in longitude."""
    lats = np.linspace(-90, 90, 180)
    lons = np.linspace(0, 180, 181)
    values = np.random.rand(lats.size, lons.size)
    return xr.Dataset(
        {"var": (["latitude", "longitude"], values)},
        coords={"latitude": (["latitude"], lats), "longitude": (["longitude"], lons)},
    )
92
+
93
+
94
@pytest.mark.parametrize(
    "data_fixture, expected_time_values",
    [
        ("global_dataset", [np.datetime64("2022-02-01T00:00:00")]),
        ("global_dataset_with_noon_times", [np.datetime64("2022-02-01T12:00:00")]),
        (
            "global_dataset_with_multiple_times_per_day",
            [
                np.datetime64("2022-02-01T00:00:00"),
                np.datetime64("2022-02-01T12:00:00"),
            ],
        ),
    ],
)
def test_select_times(data_fixture, expected_time_values, request):
    """
    Verify that Dataset restricts the time axis to [start_time, end_time].

    Runs once per fixture: midnight stamps, noon stamps, and two stamps per day.
    """
    start_time = datetime(2022, 2, 1)
    end_time = datetime(2022, 3, 1)

    # Resolve the fixture named in the parametrization
    source_ds = request.getfixturevalue(data_fixture)

    # Round-trip the in-memory dataset through a temporary netCDF file
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name
        source_ds.to_netcdf(filepath)
    try:
        dataset = Dataset(
            filename=filepath,
            var_names={"var": "var"},
            start_time=start_time,
            end_time=end_time,
        )

        assert dataset.ds is not None
        selected_times = dataset.ds.time.values
        assert len(selected_times) == len(expected_time_values)
        assert all(t in selected_times for t in expected_time_values)
    finally:
        os.remove(filepath)
137
+
138
+
139
@pytest.mark.parametrize(
    "data_fixture, expected_time_values",
    [
        ("global_dataset", [np.datetime64("2022-02-01T00:00:00")]),
        ("global_dataset_with_noon_times", [np.datetime64("2022-02-01T12:00:00")]),
    ],
)
def test_select_times_no_end_time(data_fixture, expected_time_values, request):
    """
    With only start_time given, Dataset should select exactly the single
    matching time entry.
    """
    start_time = datetime(2022, 2, 1)

    # Resolve the fixture named in the parametrization
    source_ds = request.getfixturevalue(data_fixture)

    # Round-trip the in-memory dataset through a temporary netCDF file
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name
        source_ds.to_netcdf(filepath)
    try:
        dataset = Dataset(
            filename=filepath, var_names={"var": "var"}, start_time=start_time
        )

        assert dataset.ds is not None
        selected_times = dataset.ds.time.values
        assert len(selected_times) == len(expected_time_values)
        assert all(t in selected_times for t in expected_time_values)
    finally:
        os.remove(filepath)
171
+
172
+
173
def test_multiple_matching_times(global_dataset_with_multiple_times_per_day):
    """
    When end_time is omitted, start_time must match exactly one entry;
    two matches must raise a ValueError.
    """
    start_time = datetime(2022, 1, 1)

    # Persist the fixture to a temporary netCDF file
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name
        global_dataset_with_multiple_times_per_day.to_netcdf(filepath)
    try:
        expected_msg = "There must be exactly one time matching the start_time. Found 2 matching times."
        with pytest.raises(ValueError, match=expected_msg):
            Dataset(filename=filepath, var_names={"var": "var"}, start_time=start_time)
    finally:
        os.remove(filepath)
192
+
193
+
194
def test_no_matching_times(global_dataset):
    """
    A requested time window containing none of the dataset's times must raise.
    """
    start_time = datetime(2021, 1, 1)
    end_time = datetime(2021, 2, 1)

    # Persist the fixture to a temporary netCDF file
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name
        global_dataset.to_netcdf(filepath)
    try:
        with pytest.raises(ValueError, match="No matching times found."):
            Dataset(
                filename=filepath,
                var_names={"var": "var"},
                start_time=start_time,
                end_time=end_time,
            )
    finally:
        os.remove(filepath)
216
+
217
+
218
def test_reverse_latitude_choose_subdomain_negative_depth(global_dataset):
    """
    Exercise three Dataset behaviors on a file stored with descending latitude:
    automatic reordering to ascending latitude, choose_subdomain (with and
    without the straddle/wrap-around handling), and convert_to_negative_depth.
    """
    start_time = datetime(2022, 1, 1)
    dim_names = {
        "latitude": "latitude",
        "longitude": "longitude",
        "time": "time",
        "depth": "depth",
    }

    # Flip latitude so the stored file has a descending axis
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name
        global_dataset["latitude"] = global_dataset["latitude"][::-1]
        global_dataset.to_netcdf(filepath)
    try:
        dataset = Dataset(
            filename=filepath,
            var_names={"var": "var"},
            dim_names=dim_names,
            start_time=start_time,
        )

        # Latitude must come back strictly ascending
        assert np.all(np.diff(dataset.ds["latitude"]) > 0)

        # Subdomain selection with straddle handling enabled
        dataset.choose_subdomain(
            latitude_range=(-10, 10), longitude_range=(-10, 10), margin=1, straddle=True
        )
        for coord in ("latitude", "longitude"):
            # Range plus the 1-degree margin on either side
            assert -11 <= dataset.ds[coord].min() <= 11
            assert -11 <= dataset.ds[coord].max() <= 11

        # Fresh object: subdomain selection without straddle handling
        dataset = Dataset(
            filename=filepath,
            var_names={"var": "var"},
            dim_names=dim_names,
            start_time=start_time,
        )
        dataset.choose_subdomain(
            latitude_range=(-10, 10), longitude_range=(10, 20), margin=1, straddle=False
        )
        assert -11 <= dataset.ds["latitude"].min() <= 11
        assert -11 <= dataset.ds["latitude"].max() <= 11
        assert 9 <= dataset.ds["longitude"].min() <= 21
        assert 9 <= dataset.ds["longitude"].max() <= 21

        # After conversion every depth value must be non-positive
        dataset.convert_to_negative_depth()
        assert (dataset.ds["depth"] <= 0).all()

    finally:
        os.remove(filepath)
282
+
283
+
284
def test_check_if_global_with_global_dataset(global_dataset):
    """A dataset spanning all longitudes should be classified as global."""
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name
        global_dataset.to_netcdf(filepath)
    try:
        dataset = Dataset(filename=filepath, var_names={"var": "var"})
        assert dataset.check_if_global(dataset.ds)
    finally:
        os.remove(filepath)
295
+
296
+
297
def test_check_if_global_with_non_global_dataset(non_global_dataset):
    """A dataset covering only part of the longitude range is not global."""
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name
        non_global_dataset.to_netcdf(filepath)
    try:
        dataset = Dataset(filename=filepath, var_names={"var": "var"})
        assert not dataset.check_if_global(dataset.ds)
    finally:
        os.remove(filepath)
309
+
310
+
311
def test_check_dataset(global_dataset):
    """
    Dataset construction must fail when the underlying file lacks a
    required variable or a required dimension.
    """
    start_time = datetime(2022, 2, 1)
    end_time = datetime(2022, 3, 1)

    def expect_rejection(broken_ds, message):
        # Persist ``broken_ds`` to disk and assert Dataset refuses to load it.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            filepath = tmp.name
            broken_ds.to_netcdf(filepath)
        try:
            with pytest.raises(ValueError, match=message):
                Dataset(
                    filename=filepath,
                    var_names={"var": "var"},
                    start_time=start_time,
                    end_time=end_time,
                )
        finally:
            os.remove(filepath)

    # Scenario 1: required variable is missing
    expect_rejection(
        global_dataset.copy().drop_vars("var"),
        "Dataset does not contain all required variables.",
    )

    # Scenario 2: required dimensions are missing (renamed away)
    expect_rejection(
        global_dataset.copy().rename({"latitude": "lat", "longitude": "long"}),
        "Dataset does not contain all required dimensions.",
    )
360
+
361
+
362
def test_era5_correction_choose_subdomain():
    """choose_subdomain should restrict the correction data to the given coordinates."""
    correction = ERA5Correction()

    # Pick an arbitrary 10x10 window of the native grid
    target_lats = correction.ds.latitude[10:20]
    target_lons = correction.ds.longitude[10:20]

    correction.choose_subdomain(
        {"latitude": target_lats, "longitude": target_lons}, straddle=False
    )
    assert (correction.ds["latitude"] == target_lats).all()
    assert (correction.ds["longitude"] == target_lons).all()
@@ -0,0 +1,226 @@
1
+ import pytest
2
+ import numpy as np
3
+ import numpy.testing as npt
4
+ from roms_tools import Grid
5
+ import os
6
+ import tempfile
7
+ import importlib.metadata
8
+ import textwrap
9
+
10
+
11
def test_simple_regression():
    """Pin the rho-point lat/lon values of a minimal unrotated grid at the equator."""
    grid = Grid(nx=1, ny=1, size_x=100, size_y=100, center_lon=-20, center_lat=0, rot=0)

    # Latitude is constant along rows, longitude constant along columns
    lat_values = [-8.99249453e-01, 0.0, 8.99249453e-01]
    lon_values = [339.10072286, 340.0, 340.89927714]
    expected_lat = np.array([[v] * 3 for v in lat_values])
    expected_lon = np.array([lon_values] * 3)

    # TODO: adapt tolerances according to order of magnitude of respective fields
    npt.assert_allclose(grid.ds["lat_rho"], expected_lat, atol=1e-8)
    npt.assert_allclose(grid.ds["lon_rho"], expected_lon, atol=1e-8)
32
+
33
+
34
def test_raise_if_domain_too_large():
    """Oversized domains are rejected; a moderate one constructs fine."""
    with pytest.raises(ValueError, match="Domain size has to be smaller"):
        Grid(nx=3, ny=3, size_x=30000, size_y=30000, center_lon=0, center_lat=51.5)

    # A domain of reasonable extent must construct without error
    grid = Grid(
        nx=3, ny=3, size_x=1800, size_y=2400, center_lon=-21, center_lat=61, rot=20
    )
    assert isinstance(grid, Grid)
49
+
50
+
51
def test_grid_straddle_crosses_meridian():
    """The straddle flag reflects whether the domain crosses the Greenwich meridian."""
    common = dict(nx=3, ny=3, size_x=100, size_y=100, center_lat=61, rot=20)

    # Centered on lon=0: the small domain crosses the meridian
    grid = Grid(center_lon=0, **common)
    assert grid.straddle

    # Centered on lon=180: it does not
    grid = Grid(center_lon=180, **common)
    assert not grid.straddle
73
+
74
+
75
def test_roundtrip_netcdf():
    """Test that creating a grid, saving it to file, and re-opening it is the same as just creating it."""
    grid_init = Grid(
        nx=10,
        ny=15,
        size_x=100.0,
        size_y=150.0,
        center_lon=0.0,
        center_lat=0.0,
        rot=0.0,
        topography_source="ETOPO5",
        smooth_factor=2,
        hmin=5.0,
        rmax=0.2,
    )

    # Reserve a temporary path for the netCDF output
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name

    try:
        grid_init.save(filepath)
        grid_from_file = Grid.from_file(filepath)

        # Equality covers the 'ds' attribute as well
        assert grid_init == grid_from_file

    finally:
        os.remove(filepath)
109
+
110
+
111
def test_roundtrip_yaml():
    """Test that creating a grid, saving its parameters to yaml file, and re-opening yaml file creates the same grid."""
    grid_init = Grid(
        nx=10,
        ny=15,
        size_x=100.0,
        size_y=150.0,
        center_lon=0.0,
        center_lat=0.0,
        rot=0.0,
        topography_source="ETOPO5",
        smooth_factor=2,
        hmin=5.0,
        rmax=0.2,
    )

    # Reserve a temporary path for the YAML output
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filepath = tmp.name

    try:
        grid_init.to_yaml(filepath)
        grid_from_file = Grid.from_yaml(filepath)

        # Equality covers the 'ds' attribute as well
        assert grid_init == grid_from_file

    finally:
        os.remove(filepath)
143
+
144
+
145
def test_from_yaml_missing_version():
    """Grid.from_yaml must reject a YAML file lacking the roms-tools version header."""
    yaml_content = textwrap.dedent(
        """\
        Grid:
          nx: 100
          ny: 100
          size_x: 1800
          size_y: 2400
          center_lon: -10
          center_lat: 61
          rot: -20
          topography_source: ETOPO5
          smooth_factor: 8
          hmin: 5.0
          rmax: 0.2
        """
    )
    # Write the YAML to a temporary file
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        yaml_filepath = tmp_file.name
        tmp_file.write(yaml_content.encode())

    try:
        with pytest.raises(
            ValueError, match="Version of ROMS-Tools not found in the YAML file."
        ):
            Grid.from_yaml(yaml_filepath)
    finally:
        os.remove(yaml_filepath)
174
+
175
+
176
def test_from_yaml_missing_grid():
    """Grid.from_yaml must fail when the YAML has a version header but no Grid section."""
    roms_tools_version = importlib.metadata.version("roms-tools")
    yaml_content = f"---\nroms_tools_version: {roms_tools_version}\n---\n"

    # Write the (grid-less) YAML to a temporary file
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        yaml_filepath = tmp_file.name
        tmp_file.write(yaml_content.encode())

    try:
        with pytest.raises(
            ValueError, match="No Grid configuration found in the YAML file."
        ):
            Grid.from_yaml(yaml_filepath)
    finally:
        os.remove(yaml_filepath)
192
+
193
+
194
def test_from_yaml_version_mismatch():
    """A version header that differs from the installed version should only warn."""
    yaml_content = textwrap.dedent(
        """\
        ---
        roms_tools_version: 0.0.0
        ---
        Grid:
          nx: 100
          ny: 100
          size_x: 1800
          size_y: 2400
          center_lon: -10
          center_lat: 61
          rot: -20
          topography_source: ETOPO5
          smooth_factor: 8
          hmin: 5.0
          rmax: 0.2
        """
    )

    # Write the YAML to a temporary file
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        yaml_filepath = tmp_file.name
        tmp_file.write(yaml_content.encode())

    try:
        with pytest.warns(
            UserWarning,
            match="Current roms-tools version.*does not match the version in the YAML header.*",
        ):
            Grid.from_yaml(yaml_filepath)
    finally:
        os.remove(yaml_filepath)