roms-tools 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +586 -0
  3. roms_tools/{setup/download.py → download.py} +3 -0
  4. roms_tools/{setup/plot.py → plot.py} +34 -28
  5. roms_tools/setup/boundary_forcing.py +23 -12
  6. roms_tools/setup/datasets.py +2 -135
  7. roms_tools/setup/grid.py +54 -15
  8. roms_tools/setup/initial_conditions.py +105 -149
  9. roms_tools/setup/nesting.py +4 -4
  10. roms_tools/setup/river_forcing.py +7 -9
  11. roms_tools/setup/surface_forcing.py +14 -14
  12. roms_tools/setup/tides.py +24 -21
  13. roms_tools/setup/topography.py +1 -1
  14. roms_tools/setup/utils.py +19 -143
  15. roms_tools/tests/test_analysis/test_roms_output.py +269 -0
  16. roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
  17. roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
  18. roms_tools/tests/test_setup/test_datasets.py +1 -1
  19. roms_tools/tests/test_setup/test_grid.py +1 -1
  20. roms_tools/tests/test_setup/test_initial_conditions.py +8 -4
  21. roms_tools/tests/test_setup/test_river_forcing.py +1 -1
  22. roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
  23. roms_tools/tests/test_setup/test_tides.py +1 -1
  24. roms_tools/tests/test_setup/test_topography.py +1 -1
  25. roms_tools/tests/test_setup/test_utils.py +56 -1
  26. roms_tools/utils.py +301 -0
  27. roms_tools/vertical_coordinate.py +306 -0
  28. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
  29. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
  30. roms_tools/setup/vertical_coordinate.py +0 -109
  31. roms_tools/{setup/regrid.py → regrid.py} +0 -0
  32. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
  33. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
  34. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
roms_tools/setup/utils.py CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime
9
9
  from dataclasses import fields, asdict
10
10
  import importlib.metadata
11
11
  import yaml
12
+ from roms_tools.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
12
13
 
13
14
 
14
15
  def nan_check(field, mask, error_message=None) -> None:
@@ -71,100 +72,6 @@ def substitute_nans_by_fillvalue(field, fill_value=0.0) -> xr.DataArray:
71
72
  return field.fillna(fill_value)
72
73
 
73
74
 
74
- def interpolate_from_rho_to_u(field, method="additive"):
75
- """Interpolates the given field from rho points to u points.
76
-
77
- This function performs an interpolation from the rho grid (cell centers) to the u grid
78
- (cell edges in the xi direction). Depending on the chosen method, it either averages
79
- (additive) or multiplies (multiplicative) the field values between adjacent rho points
80
- along the xi dimension. It also handles the removal of unnecessary coordinate variables
81
- and updates the dimensions accordingly.
82
-
83
- Parameters
84
- ----------
85
- field : xr.DataArray
86
- The input data array on the rho grid to be interpolated. It is assumed to have a dimension
87
- named "xi_rho".
88
-
89
- method : str, optional, default='additive'
90
- The method to use for interpolation. Options are:
91
- - 'additive': Average the field values between adjacent rho points.
92
- - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
93
- binary masks.
94
-
95
- Returns
96
- -------
97
- field_interpolated : xr.DataArray
98
- The interpolated data array on the u grid with the dimension "xi_u".
99
- """
100
-
101
- if method == "additive":
102
- field_interpolated = 0.5 * (field + field.shift(xi_rho=1)).isel(
103
- xi_rho=slice(1, None)
104
- )
105
- elif method == "multiplicative":
106
- field_interpolated = (field * field.shift(xi_rho=1)).isel(xi_rho=slice(1, None))
107
- else:
108
- raise NotImplementedError(f"Unsupported method '{method}' specified.")
109
-
110
- vars_to_drop = ["lat_rho", "lon_rho", "eta_rho", "xi_rho"]
111
- for var in vars_to_drop:
112
- if var in field_interpolated.coords:
113
- field_interpolated = field_interpolated.drop_vars(var)
114
-
115
- field_interpolated = field_interpolated.swap_dims({"xi_rho": "xi_u"})
116
-
117
- return field_interpolated
118
-
119
-
120
- def interpolate_from_rho_to_v(field, method="additive"):
121
- """Interpolates the given field from rho points to v points.
122
-
123
- This function performs an interpolation from the rho grid (cell centers) to the v grid
124
- (cell edges in the eta direction). Depending on the chosen method, it either averages
125
- (additive) or multiplies (multiplicative) the field values between adjacent rho points
126
- along the eta dimension. It also handles the removal of unnecessary coordinate variables
127
- and updates the dimensions accordingly.
128
-
129
- Parameters
130
- ----------
131
- field : xr.DataArray
132
- The input data array on the rho grid to be interpolated. It is assumed to have a dimension
133
- named "eta_rho".
134
-
135
- method : str, optional, default='additive'
136
- The method to use for interpolation. Options are:
137
- - 'additive': Average the field values between adjacent rho points.
138
- - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
139
- binary masks.
140
-
141
- Returns
142
- -------
143
- field_interpolated : xr.DataArray
144
- The interpolated data array on the v grid with the dimension "eta_v".
145
- """
146
-
147
- if method == "additive":
148
- field_interpolated = 0.5 * (field + field.shift(eta_rho=1)).isel(
149
- eta_rho=slice(1, None)
150
- )
151
- elif method == "multiplicative":
152
- field_interpolated = (field * field.shift(eta_rho=1)).isel(
153
- eta_rho=slice(1, None)
154
- )
155
- else:
156
- raise NotImplementedError(f"Unsupported method '{method}' specified.")
157
-
158
- vars_to_drop = ["lat_rho", "lon_rho", "eta_rho", "xi_rho"]
159
- for var in vars_to_drop:
160
- if var in field_interpolated.coords:
161
- field_interpolated = field_interpolated.drop_vars(var)
162
-
163
- field_interpolated = field_interpolated.swap_dims({"eta_rho": "eta_v"})
164
-
165
- return field_interpolated
166
-
167
-
168
75
  def one_dim_fill(da: xr.DataArray, dim: str, direction="forward") -> xr.DataArray:
169
76
  """Fill NaN values in a DataArray along a specified dimension.
170
77
 
@@ -863,45 +770,6 @@ def compute_barotropic_velocity(
863
770
  return vel_bar
864
771
 
865
772
 
866
- def transpose_dimensions(da: xr.DataArray) -> xr.DataArray:
867
- """Transpose the dimensions of an xarray.DataArray to ensure that 'time', any
868
- dimension starting with 's_', 'eta_', and 'xi_' are ordered first, followed by the
869
- remaining dimensions in their original order.
870
-
871
- Parameters
872
- ----------
873
- da : xarray.DataArray
874
- The input DataArray whose dimensions are to be reordered.
875
-
876
- Returns
877
- -------
878
- xarray.DataArray
879
- The DataArray with dimensions reordered so that 'time', 's_*', 'eta_*',
880
- and 'xi_*' are first, in that order, if they exist.
881
- """
882
-
883
- # List of preferred dimension patterns
884
- preferred_order = ["time", "s_", "eta_", "xi_"]
885
-
886
- # Get the existing dimensions in the DataArray
887
- dims = list(da.dims)
888
-
889
- # Collect dimensions that match any of the preferred patterns
890
- matched_dims = []
891
- for pattern in preferred_order:
892
- # Find dimensions that start with the pattern
893
- matched_dims += [dim for dim in dims if dim.startswith(pattern)]
894
-
895
- # Create a new order: first the matched dimensions, then the rest
896
- remaining_dims = [dim for dim in dims if dim not in matched_dims]
897
- new_order = matched_dims + remaining_dims
898
-
899
- # Transpose the DataArray to the new order
900
- transposed_da = da.transpose(*new_order)
901
-
902
- return transposed_da
903
-
904
-
905
773
  def get_vector_pairs(variable_info):
906
774
  """Extracts all unique vector pairs from the variable_info dictionary.
907
775
 
@@ -1079,12 +947,20 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1079
947
 
1080
948
  grid_yaml_data = {**parent_grid_yaml_data, **child_grid_yaml_data}
1081
949
 
1082
- # Ensure forcing_object.source.path is a string (convert if it's a pathlib object)
1083
- if hasattr(forcing_object, "source") and "path" in forcing_object.source:
1084
- forcing_object.source["path"] = str(forcing_object.source["path"])
950
+ # Step 2: Ensure Paths are Strings
951
+ def ensure_paths_are_strings(obj, key):
952
+ attr = getattr(obj, key, None)
953
+ if attr is not None and "path" in attr:
954
+ paths = attr["path"]
955
+ if isinstance(paths, list):
956
+ attr["path"] = [str(p) if isinstance(p, Path) else p for p in paths]
957
+ elif isinstance(paths, Path):
958
+ attr["path"] = str(paths)
959
+
960
+ ensure_paths_are_strings(forcing_object, "source")
961
+ ensure_paths_are_strings(forcing_object, "bgc_source")
1085
962
 
1086
- # Step 2: Get ROMS Tools version
1087
- # Fetch the version of the 'roms-tools' package for inclusion in the YAML header
963
+ # Step 3: Get ROMS Tools version
1088
964
  try:
1089
965
  roms_tools_version = importlib.metadata.version("roms-tools")
1090
966
  except importlib.metadata.PackageNotFoundError:
@@ -1093,8 +969,7 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1093
969
  # Create YAML header with version information
1094
970
  header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
1095
971
 
1096
- # Step 3: Prepare Forcing Data
1097
- # Prepare the forcing object fields, excluding 'grid' and 'ds'
972
+ # Step 4: Prepare Forcing Data
1098
973
  forcing_data = {}
1099
974
  field_names = [field.name for field in fields(forcing_object)]
1100
975
  filtered_field_names = [
@@ -1123,14 +998,13 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1123
998
  # Add the field and its value to the forcing_data dictionary
1124
999
  forcing_data[field_name] = value
1125
1000
 
1126
- # Step 4: Combine Grid and Forcing Data
1127
- # Combine grid and forcing data into a single dictionary for the final YAML content
1001
+ # Step 5: Combine Grid and Forcing Data into a single dictionary for the final YAML content
1128
1002
  yaml_data = {
1129
1003
  **grid_yaml_data, # Add the grid data to the final YAML structure
1130
1004
  forcing_object.__class__.__name__: forcing_data, # Include the serialized forcing object data
1131
1005
  }
1132
1006
 
1133
- # Step 5: Write to YAML file
1007
+ # Step 6: Write to YAML file
1134
1008
  with filepath.open("w") as file:
1135
1009
  # Write the header first
1136
1010
  file.write(header)
@@ -1176,6 +1050,8 @@ def _from_yaml(forcing_object: Type, filepath: Union[str, Path]) -> Dict[str, An
1176
1050
  ValueError
1177
1051
  If no configuration for the specified class name is found in the YAML file.
1178
1052
  """
1053
+ # Ensure filepath is a Path object
1054
+ filepath = Path(filepath)
1179
1055
 
1180
1056
  # Read the entire file content
1181
1057
  with filepath.open("r") as file:
@@ -0,0 +1,269 @@
1
+ import pytest
2
+ from pathlib import Path
3
+ import xarray as xr
4
+ import os
5
+ import logging
6
+ from datetime import datetime
7
+ from roms_tools import Grid, ROMSOutput
8
+ from roms_tools.download import download_test_data
9
+
10
+
11
+ @pytest.fixture
12
+ def roms_output_from_restart_file(use_dask):
13
+
14
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
15
+ grid = Grid.from_file(fname_grid)
16
+
17
+ # Single file
18
+ return ROMSOutput(
19
+ grid=grid,
20
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
21
+ type="restart",
22
+ use_dask=use_dask,
23
+ )
24
+
25
+
26
+ def test_load_model_output_file(roms_output_from_restart_file, use_dask):
27
+
28
+ assert isinstance(roms_output_from_restart_file.ds, xr.Dataset)
29
+
30
+
31
+ def test_load_model_output_directory(use_dask):
32
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
33
+ grid = Grid.from_file(fname_grid)
34
+
35
+ # Download at least two files, so these will be found within the pooch directory
36
+ _ = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
37
+ _ = Path(download_test_data("eastpac25km_rst.19980126000000.nc"))
38
+
39
+ # Directory
40
+ directory = os.path.dirname(download_test_data("eastpac25km_rst.19980106000000.nc"))
41
+ output = ROMSOutput(grid=grid, path=directory, type="restart", use_dask=use_dask)
42
+ assert isinstance(output.ds, xr.Dataset)
43
+
44
+
45
+ def test_load_model_output_file_list(use_dask):
46
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
47
+ grid = Grid.from_file(fname_grid)
48
+
49
+ # List of files
50
+ file1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
51
+ file2 = Path(download_test_data("eastpac25km_rst.19980126000000.nc"))
52
+ output = ROMSOutput(
53
+ grid=grid, path=[file1, file2], type="restart", use_dask=use_dask
54
+ )
55
+ assert isinstance(output.ds, xr.Dataset)
56
+
57
+
58
+ def test_invalid_type(use_dask):
59
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
60
+ grid = Grid.from_file(fname_grid)
61
+
62
+ # Invalid type
63
+ with pytest.raises(ValueError, match="Invalid type 'invalid_type'"):
64
+ ROMSOutput(
65
+ grid=grid,
66
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
67
+ type="invalid_type",
68
+ use_dask=use_dask,
69
+ )
70
+
71
+
72
+ def test_invalid_path(use_dask):
73
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
74
+ grid = Grid.from_file(fname_grid)
75
+
76
+ # Non-existent file
77
+ with pytest.raises(FileNotFoundError):
78
+ ROMSOutput(
79
+ grid=grid,
80
+ path=Path("/path/to/nonexistent/file.nc"),
81
+ type="restart",
82
+ use_dask=use_dask,
83
+ )
84
+
85
+ # Non-existent directory
86
+ with pytest.raises(FileNotFoundError):
87
+ ROMSOutput(
88
+ grid=grid,
89
+ path=Path("/path/to/nonexistent/directory"),
90
+ type="restart",
91
+ use_dask=use_dask,
92
+ )
93
+
94
+
95
+ def test_set_correct_model_reference_date(use_dask):
96
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
97
+ grid = Grid.from_file(fname_grid)
98
+
99
+ output = ROMSOutput(
100
+ grid=grid,
101
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
102
+ type="restart",
103
+ use_dask=use_dask,
104
+ )
105
+ assert output.model_reference_date == datetime(1995, 1, 1)
106
+
107
+
108
+ def test_model_reference_date_mismatch(use_dask):
109
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
110
+ grid = Grid.from_file(fname_grid)
111
+
112
+ # Create a ROMSOutput with a specified model_reference_date
113
+ model_ref_date = datetime(2020, 1, 1)
114
+ with pytest.raises(
115
+ ValueError, match="Mismatch between `self.model_reference_date`"
116
+ ):
117
+ ROMSOutput(
118
+ grid=grid,
119
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
120
+ type="restart",
121
+ model_reference_date=model_ref_date,
122
+ use_dask=use_dask,
123
+ )
124
+
125
+
126
+ def test_model_reference_date_no_metadata(use_dask, tmp_path, caplog):
127
+ # Helper function to handle the test logic for cases where metadata is missing or invalid
128
+ def test_no_metadata(faulty_ocean_time_attr, expected_exception, log_message=None):
129
+ ds = xr.open_dataset(fname)
130
+ ds["ocean_time"].attrs = faulty_ocean_time_attr
131
+
132
+ # Write modified dataset to a new file
133
+ fname_mod = tmp_path / "eastpac25km_rst.19980106000000_without_metadata.nc"
134
+ ds.to_netcdf(fname_mod)
135
+
136
+ # Test case 1: Expecting a ValueError when metadata is missing or invalid
137
+ with pytest.raises(
138
+ expected_exception,
139
+ match="Model reference date could not be inferred from the metadata",
140
+ ):
141
+ ROMSOutput(grid=grid, path=fname_mod, type="restart", use_dask=use_dask)
142
+
143
+ # Test case 2: When a model reference date is explicitly set, verify the warning
144
+ with caplog.at_level(logging.WARNING):
145
+ ROMSOutput(
146
+ grid=grid,
147
+ path=fname_mod,
148
+ model_reference_date=datetime(1995, 1, 1),
149
+ type="restart",
150
+ use_dask=use_dask,
151
+ )
152
+
153
+ if log_message:
154
+ # Verify the warning message in the log
155
+ assert log_message in caplog.text
156
+
157
+ fname_mod.unlink()
158
+
159
+ # Load grid and test data
160
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
161
+ grid = Grid.from_file(fname_grid)
162
+ fname = download_test_data("eastpac25km_rst.19980106000000.nc")
163
+
164
+ # Test 1: Ocean time attribute 'long_name' is missing
165
+ test_no_metadata({}, ValueError)
166
+
167
+ # Test 2: Ocean time attribute 'long_name' contains invalid information
168
+ test_no_metadata(
169
+ {"long_name": "some random text"},
170
+ ValueError,
171
+ "Could not infer the model reference date from the metadata.",
172
+ )
173
+
174
+
175
+ def test_compute_depth_coordinates(use_dask):
176
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
177
+ grid = Grid.from_file(fname_grid)
178
+
179
+ fname_restart1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
180
+ output = ROMSOutput(
181
+ grid=grid, path=fname_restart1, type="restart", use_dask=use_dask
182
+ )
183
+
184
+ # Before calling get_vertical_coordinates, check if the dataset doesn't already have depth coordinates
185
+ assert "layer_depth_rho" not in output.ds.data_vars
186
+
187
+ # Call the method to get vertical coordinates
188
+ output.compute_depth_coordinates(depth_type="layer")
189
+
190
+ # Check if the depth coordinates were added
191
+ assert "layer_depth_rho" in output.ds.data_vars
192
+
193
+
194
+ def test_check_vertical_coordinate_mismatch(use_dask):
195
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
196
+ grid = Grid.from_file(fname_grid)
197
+
198
+ fname_restart1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
199
+ output = ROMSOutput(
200
+ grid=grid, path=fname_restart1, type="restart", use_dask=use_dask
201
+ )
202
+
203
+ # create a mock dataset with inconsistent vertical coordinate parameters
204
+ ds_mock = output.ds.copy()
205
+
206
+ # Modify one of the vertical coordinate attributes to cause a mismatch
207
+ ds_mock.attrs["theta_s"] = 999
208
+
209
+ # Check if ValueError is raised due to mismatch
210
+ with pytest.raises(ValueError, match="theta_s from grid"):
211
+ output._check_vertical_coordinate(ds_mock)
212
+
213
+ # create a mock dataset with inconsistent vertical coordinate parameters
214
+ ds_mock = output.ds.copy()
215
+
216
+ # Modify one of the vertical coordinate attributes to cause a mismatch
217
+ ds_mock.attrs["Cs_w"] = ds_mock.attrs["Cs_w"] + 0.01
218
+
219
+ # Check if ValueError is raised due to mismatch
220
+ with pytest.raises(ValueError, match="Cs_w from grid"):
221
+ output._check_vertical_coordinate(ds_mock)
222
+
223
+
224
+ def test_that_coordinates_are_added(use_dask):
225
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
226
+ grid = Grid.from_file(fname_grid)
227
+
228
+ fname_restart1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
229
+ output = ROMSOutput(
230
+ grid=grid, path=fname_restart1, type="restart", use_dask=use_dask
231
+ )
232
+
233
+ assert "abs_time" in output.ds.coords
234
+ assert "lat_rho" in output.ds.coords
235
+ assert "lon_rho" in output.ds.coords
236
+
237
+
238
+ def test_plot(roms_output_from_restart_file, use_dask):
239
+
240
+ kwargs = {}
241
+ for var_name in ["temp", "u", "v"]:
242
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, **kwargs)
243
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, **kwargs)
244
+ roms_output_from_restart_file.plot(var_name, time=0, xi=0, **kwargs)
245
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, xi=0, **kwargs)
246
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, eta=0, **kwargs)
247
+
248
+ kwargs = {"depth_contours": True, "layer_contours": True}
249
+ for var_name in ["temp", "u", "v"]:
250
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, **kwargs)
251
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, **kwargs)
252
+ roms_output_from_restart_file.plot(var_name, time=0, xi=0, **kwargs)
253
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, xi=0, **kwargs)
254
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, eta=0, **kwargs)
255
+
256
+ roms_output_from_restart_file.plot("zeta", time=0, **kwargs)
257
+ roms_output_from_restart_file.plot("zeta", time=0, eta=0, **kwargs)
258
+ roms_output_from_restart_file.plot("zeta", time=0, xi=0, **kwargs)
259
+
260
+
261
+ def test_plot_errors(roms_output_from_restart_file, use_dask):
262
+ with pytest.raises(ValueError, match="Invalid time index"):
263
+ roms_output_from_restart_file.plot("temp", time=10, s=-1)
264
+ with pytest.raises(ValueError, match="Invalid input"):
265
+ roms_output_from_restart_file.plot("temp", time=0)
266
+ with pytest.raises(ValueError, match="Ambiguous input"):
267
+ roms_output_from_restart_file.plot("temp", time=0, s=-1, eta=0, xi=0)
268
+ with pytest.raises(ValueError, match="Conflicting input"):
269
+ roms_output_from_restart_file.plot("zeta", time=0, eta=0, xi=0)
@@ -1,7 +1,7 @@
1
1
  import pytest
2
2
  import numpy as np
3
3
  import xarray as xr
4
- from roms_tools.setup.regrid import VerticalRegrid
4
+ from roms_tools.regrid import VerticalRegrid
5
5
 
6
6
 
7
7
  def vertical_regridder(depth_values, layer_depth_rho_values):
@@ -3,7 +3,7 @@ from datetime import datetime
3
3
  import xarray as xr
4
4
  from roms_tools import Grid, BoundaryForcing
5
5
  import textwrap
6
- from roms_tools.setup.download import download_test_data
6
+ from roms_tools.download import download_test_data
7
7
  from conftest import calculate_file_hash
8
8
  from pathlib import Path
9
9
  import logging
@@ -9,7 +9,7 @@ from roms_tools.setup.datasets import (
9
9
  ERA5Correction,
10
10
  CESMBGCDataset,
11
11
  )
12
- from roms_tools.setup.download import download_test_data
12
+ from roms_tools.download import download_test_data
13
13
  from pathlib import Path
14
14
 
15
15
 
@@ -4,7 +4,7 @@ import xarray as xr
4
4
  from roms_tools import Grid
5
5
  import importlib.metadata
6
6
  import textwrap
7
- from roms_tools.setup.download import download_test_data
7
+ from roms_tools.download import download_test_data
8
8
  from conftest import calculate_file_hash
9
9
  from pathlib import Path
10
10
 
@@ -4,7 +4,7 @@ from roms_tools import InitialConditions, Grid
4
4
  import xarray as xr
5
5
  import numpy as np
6
6
  import textwrap
7
- from roms_tools.setup.download import download_test_data
7
+ from roms_tools.download import download_test_data
8
8
  from roms_tools.setup.datasets import CESMBGCDataset
9
9
  from pathlib import Path
10
10
  from conftest import calculate_file_hash
@@ -250,7 +250,9 @@ def test_initial_conditions_save(
250
250
  expected_filepath.unlink()
251
251
 
252
252
 
253
- def test_roundtrip_yaml(initial_conditions, tmp_path, use_dask):
253
+ def test_roundtrip_yaml(
254
+ initial_conditions_with_bgc_from_climatology, tmp_path, use_dask
255
+ ):
254
256
  """Test that creating an InitialConditions object, saving its parameters to yaml
255
257
  file, and re-opening yaml file creates the same object."""
256
258
 
@@ -261,13 +263,15 @@ def test_roundtrip_yaml(initial_conditions, tmp_path, use_dask):
261
263
  str(tmp_path / file_str),
262
264
  ]: # test for Path object and str
263
265
 
264
- initial_conditions.to_yaml(filepath)
266
+ initial_conditions_with_bgc_from_climatology.to_yaml(filepath)
265
267
 
266
268
  initial_conditions_from_file = InitialConditions.from_yaml(
267
269
  filepath, use_dask=use_dask
268
270
  )
269
271
 
270
- assert initial_conditions == initial_conditions_from_file
272
+ assert (
273
+ initial_conditions_with_bgc_from_climatology == initial_conditions_from_file
274
+ )
271
275
 
272
276
  filepath = Path(filepath)
273
277
  filepath.unlink()
@@ -6,7 +6,7 @@ import textwrap
6
6
  from pathlib import Path
7
7
  import pytest
8
8
  from conftest import calculate_file_hash
9
- from roms_tools.setup.download import download_river_data
9
+ from roms_tools.download import download_river_data
10
10
 
11
11
 
12
12
  @pytest.fixture
@@ -2,7 +2,7 @@ import pytest
2
2
  from datetime import datetime
3
3
  import xarray as xr
4
4
  from roms_tools import Grid, SurfaceForcing
5
- from roms_tools.setup.download import download_test_data
5
+ from roms_tools.download import download_test_data
6
6
  import textwrap
7
7
  from pathlib import Path
8
8
  from conftest import calculate_file_hash
@@ -1,7 +1,7 @@
1
1
  import pytest
2
2
  from roms_tools import Grid, TidalForcing
3
3
  import xarray as xr
4
- from roms_tools.setup.download import download_test_data
4
+ from roms_tools.download import download_test_data
5
5
  import textwrap
6
6
  from pathlib import Path
7
7
  from conftest import calculate_file_hash
@@ -1,7 +1,7 @@
1
1
  import pytest
2
2
  from roms_tools import Grid
3
3
  from roms_tools.setup.topography import _compute_rfactor
4
- from roms_tools.setup.download import download_test_data
4
+ from roms_tools.download import download_test_data
5
5
  import numpy as np
6
6
  import numpy.testing as npt
7
7
  from scipy.ndimage import label
@@ -1,7 +1,11 @@
1
+ from roms_tools import Grid, BoundaryForcing
1
2
  from roms_tools.setup.utils import interpolate_from_climatology
2
3
  from roms_tools.setup.datasets import ERA5Correction
3
- from roms_tools.setup.download import download_test_data
4
+ from roms_tools.download import download_test_data
4
5
  import xarray as xr
6
+ import pytest
7
+ from datetime import datetime
8
+ from pathlib import Path
5
9
 
6
10
 
7
11
  def test_interpolate_from_climatology(use_dask):
@@ -14,3 +18,54 @@ def test_interpolate_from_climatology(use_dask):
14
18
 
15
19
  interpolated_field = interpolate_from_climatology(field, "time", era5_times)
16
20
  assert len(interpolated_field.time) == len(era5_times)
21
+
22
+
23
+ # Test yaml roundtrip with multiple source files
24
+ @pytest.fixture()
25
+ def boundary_forcing_from_multiple_source_files(request, use_dask):
26
+ """Fixture for creating a BoundaryForcing object."""
27
+
28
+ grid = Grid(
29
+ nx=5,
30
+ ny=5,
31
+ size_x=100,
32
+ size_y=100,
33
+ center_lon=-8,
34
+ center_lat=60,
35
+ rot=10,
36
+ N=3, # number of vertical levels
37
+ )
38
+
39
+ fname1 = Path(download_test_data("GLORYS_NA_20120101.nc"))
40
+ fname2 = Path(download_test_data("GLORYS_NA_20121231.nc"))
41
+
42
+ return BoundaryForcing(
43
+ grid=grid,
44
+ start_time=datetime(2011, 1, 1),
45
+ end_time=datetime(2013, 1, 1),
46
+ source={"name": "GLORYS", "path": [fname1, fname2]},
47
+ use_dask=use_dask,
48
+ )
49
+
50
+
51
+ def test_roundtrip_yaml(
52
+ boundary_forcing_from_multiple_source_files, request, tmp_path, use_dask
53
+ ):
54
+ """Test that creating a BoundaryForcing object, saving its parameters to yaml file,
55
+ and re-opening yaml file creates the same object."""
56
+
57
+ # Create a temporary filepath using the tmp_path fixture
58
+ file_str = "test_yaml"
59
+ for filepath in [
60
+ tmp_path / file_str,
61
+ str(tmp_path / file_str),
62
+ ]: # test for Path object and str
63
+
64
+ boundary_forcing_from_multiple_source_files.to_yaml(filepath)
65
+
66
+ bdry_forcing_from_file = BoundaryForcing.from_yaml(filepath, use_dask=use_dask)
67
+
68
+ assert boundary_forcing_from_multiple_source_files == bdry_forcing_from_file
69
+
70
+ filepath = Path(filepath)
71
+ filepath.unlink()