roms-tools 2.2.1__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +586 -0
  3. roms_tools/{setup/download.py → download.py} +3 -0
  4. roms_tools/{setup/plot.py → plot.py} +34 -28
  5. roms_tools/setup/boundary_forcing.py +23 -12
  6. roms_tools/setup/datasets.py +2 -135
  7. roms_tools/setup/grid.py +54 -15
  8. roms_tools/setup/initial_conditions.py +105 -149
  9. roms_tools/setup/nesting.py +4 -4
  10. roms_tools/setup/river_forcing.py +7 -9
  11. roms_tools/setup/surface_forcing.py +14 -14
  12. roms_tools/setup/tides.py +24 -21
  13. roms_tools/setup/topography.py +1 -1
  14. roms_tools/setup/utils.py +20 -154
  15. roms_tools/tests/test_analysis/test_roms_output.py +269 -0
  16. roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
  17. roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
  18. roms_tools/tests/test_setup/test_datasets.py +1 -1
  19. roms_tools/tests/test_setup/test_grid.py +1 -1
  20. roms_tools/tests/test_setup/test_initial_conditions.py +1 -1
  21. roms_tools/tests/test_setup/test_river_forcing.py +1 -1
  22. roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
  23. roms_tools/tests/test_setup/test_tides.py +1 -1
  24. roms_tools/tests/test_setup/test_topography.py +1 -1
  25. roms_tools/tests/test_setup/test_utils.py +56 -1
  26. roms_tools/utils.py +301 -0
  27. roms_tools/vertical_coordinate.py +306 -0
  28. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
  29. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
  30. roms_tools/setup/vertical_coordinate.py +0 -109
  31. /roms_tools/{setup/regrid.py → regrid.py} +0 -0
  32. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
  33. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
  34. {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
roms_tools/setup/utils.py CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime
9
9
  from dataclasses import fields, asdict
10
10
  import importlib.metadata
11
11
  import yaml
12
+ from roms_tools.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
12
13
 
13
14
 
14
15
  def nan_check(field, mask, error_message=None) -> None:
@@ -71,100 +72,6 @@ def substitute_nans_by_fillvalue(field, fill_value=0.0) -> xr.DataArray:
71
72
  return field.fillna(fill_value)
72
73
 
73
74
 
74
- def interpolate_from_rho_to_u(field, method="additive"):
75
- """Interpolates the given field from rho points to u points.
76
-
77
- This function performs an interpolation from the rho grid (cell centers) to the u grid
78
- (cell edges in the xi direction). Depending on the chosen method, it either averages
79
- (additive) or multiplies (multiplicative) the field values between adjacent rho points
80
- along the xi dimension. It also handles the removal of unnecessary coordinate variables
81
- and updates the dimensions accordingly.
82
-
83
- Parameters
84
- ----------
85
- field : xr.DataArray
86
- The input data array on the rho grid to be interpolated. It is assumed to have a dimension
87
- named "xi_rho".
88
-
89
- method : str, optional, default='additive'
90
- The method to use for interpolation. Options are:
91
- - 'additive': Average the field values between adjacent rho points.
92
- - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
93
- binary masks.
94
-
95
- Returns
96
- -------
97
- field_interpolated : xr.DataArray
98
- The interpolated data array on the u grid with the dimension "xi_u".
99
- """
100
-
101
- if method == "additive":
102
- field_interpolated = 0.5 * (field + field.shift(xi_rho=1)).isel(
103
- xi_rho=slice(1, None)
104
- )
105
- elif method == "multiplicative":
106
- field_interpolated = (field * field.shift(xi_rho=1)).isel(xi_rho=slice(1, None))
107
- else:
108
- raise NotImplementedError(f"Unsupported method '{method}' specified.")
109
-
110
- vars_to_drop = ["lat_rho", "lon_rho", "eta_rho", "xi_rho"]
111
- for var in vars_to_drop:
112
- if var in field_interpolated.coords:
113
- field_interpolated = field_interpolated.drop_vars(var)
114
-
115
- field_interpolated = field_interpolated.swap_dims({"xi_rho": "xi_u"})
116
-
117
- return field_interpolated
118
-
119
-
120
- def interpolate_from_rho_to_v(field, method="additive"):
121
- """Interpolates the given field from rho points to v points.
122
-
123
- This function performs an interpolation from the rho grid (cell centers) to the v grid
124
- (cell edges in the eta direction). Depending on the chosen method, it either averages
125
- (additive) or multiplies (multiplicative) the field values between adjacent rho points
126
- along the eta dimension. It also handles the removal of unnecessary coordinate variables
127
- and updates the dimensions accordingly.
128
-
129
- Parameters
130
- ----------
131
- field : xr.DataArray
132
- The input data array on the rho grid to be interpolated. It is assumed to have a dimension
133
- named "eta_rho".
134
-
135
- method : str, optional, default='additive'
136
- The method to use for interpolation. Options are:
137
- - 'additive': Average the field values between adjacent rho points.
138
- - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
139
- binary masks.
140
-
141
- Returns
142
- -------
143
- field_interpolated : xr.DataArray
144
- The interpolated data array on the v grid with the dimension "eta_v".
145
- """
146
-
147
- if method == "additive":
148
- field_interpolated = 0.5 * (field + field.shift(eta_rho=1)).isel(
149
- eta_rho=slice(1, None)
150
- )
151
- elif method == "multiplicative":
152
- field_interpolated = (field * field.shift(eta_rho=1)).isel(
153
- eta_rho=slice(1, None)
154
- )
155
- else:
156
- raise NotImplementedError(f"Unsupported method '{method}' specified.")
157
-
158
- vars_to_drop = ["lat_rho", "lon_rho", "eta_rho", "xi_rho"]
159
- for var in vars_to_drop:
160
- if var in field_interpolated.coords:
161
- field_interpolated = field_interpolated.drop_vars(var)
162
-
163
- field_interpolated = field_interpolated.swap_dims({"eta_rho": "eta_v"})
164
-
165
- return field_interpolated
166
-
167
-
168
75
  def one_dim_fill(da: xr.DataArray, dim: str, direction="forward") -> xr.DataArray:
169
76
  """Fill NaN values in a DataArray along a specified dimension.
170
77
 
@@ -863,45 +770,6 @@ def compute_barotropic_velocity(
863
770
  return vel_bar
864
771
 
865
772
 
866
- def transpose_dimensions(da: xr.DataArray) -> xr.DataArray:
867
- """Transpose the dimensions of an xarray.DataArray to ensure that 'time', any
868
- dimension starting with 's_', 'eta_', and 'xi_' are ordered first, followed by the
869
- remaining dimensions in their original order.
870
-
871
- Parameters
872
- ----------
873
- da : xarray.DataArray
874
- The input DataArray whose dimensions are to be reordered.
875
-
876
- Returns
877
- -------
878
- xarray.DataArray
879
- The DataArray with dimensions reordered so that 'time', 's_*', 'eta_*',
880
- and 'xi_*' are first, in that order, if they exist.
881
- """
882
-
883
- # List of preferred dimension patterns
884
- preferred_order = ["time", "s_", "eta_", "xi_"]
885
-
886
- # Get the existing dimensions in the DataArray
887
- dims = list(da.dims)
888
-
889
- # Collect dimensions that match any of the preferred patterns
890
- matched_dims = []
891
- for pattern in preferred_order:
892
- # Find dimensions that start with the pattern
893
- matched_dims += [dim for dim in dims if dim.startswith(pattern)]
894
-
895
- # Create a new order: first the matched dimensions, then the rest
896
- remaining_dims = [dim for dim in dims if dim not in matched_dims]
897
- new_order = matched_dims + remaining_dims
898
-
899
- # Transpose the DataArray to the new order
900
- transposed_da = da.transpose(*new_order)
901
-
902
- return transposed_da
903
-
904
-
905
773
  def get_vector_pairs(variable_info):
906
774
  """Extracts all unique vector pairs from the variable_info dictionary.
907
775
 
@@ -1079,22 +947,20 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1079
947
 
1080
948
  grid_yaml_data = {**parent_grid_yaml_data, **child_grid_yaml_data}
1081
949
 
1082
- # Ensure forcing_object.source.path is a string (convert if it's a pathlib object)
1083
- if (
1084
- hasattr(forcing_object, "source")
1085
- and forcing_object.source is not None
1086
- and "path" in forcing_object.source
1087
- ):
1088
- forcing_object.source["path"] = str(forcing_object.source["path"])
1089
- if (
1090
- hasattr(forcing_object, "bgc_source")
1091
- and forcing_object.bgc_source is not None
1092
- and "path" in forcing_object.bgc_source
1093
- ):
1094
- forcing_object.bgc_source["path"] = str(forcing_object.bgc_source["path"])
1095
-
1096
- # Step 2: Get ROMS Tools version
1097
- # Fetch the version of the 'roms-tools' package for inclusion in the YAML header
950
+ # Step 2: Ensure Paths are Strings
951
+ def ensure_paths_are_strings(obj, key):
952
+ attr = getattr(obj, key, None)
953
+ if attr is not None and "path" in attr:
954
+ paths = attr["path"]
955
+ if isinstance(paths, list):
956
+ attr["path"] = [str(p) if isinstance(p, Path) else p for p in paths]
957
+ elif isinstance(paths, Path):
958
+ attr["path"] = str(paths)
959
+
960
+ ensure_paths_are_strings(forcing_object, "source")
961
+ ensure_paths_are_strings(forcing_object, "bgc_source")
962
+
963
+ # Step 3: Get ROMS Tools version
1098
964
  try:
1099
965
  roms_tools_version = importlib.metadata.version("roms-tools")
1100
966
  except importlib.metadata.PackageNotFoundError:
@@ -1103,8 +969,7 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1103
969
  # Create YAML header with version information
1104
970
  header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
1105
971
 
1106
- # Step 3: Prepare Forcing Data
1107
- # Prepare the forcing object fields, excluding 'grid' and 'ds'
972
+ # Step 4: Prepare Forcing Data
1108
973
  forcing_data = {}
1109
974
  field_names = [field.name for field in fields(forcing_object)]
1110
975
  filtered_field_names = [
@@ -1133,14 +998,13 @@ def _to_yaml(forcing_object, filepath: Union[str, Path]) -> None:
1133
998
  # Add the field and its value to the forcing_data dictionary
1134
999
  forcing_data[field_name] = value
1135
1000
 
1136
- # Step 4: Combine Grid and Forcing Data
1137
- # Combine grid and forcing data into a single dictionary for the final YAML content
1001
+ # Step 5: Combine Grid and Forcing Data into a single dictionary for the final YAML content
1138
1002
  yaml_data = {
1139
1003
  **grid_yaml_data, # Add the grid data to the final YAML structure
1140
1004
  forcing_object.__class__.__name__: forcing_data, # Include the serialized forcing object data
1141
1005
  }
1142
1006
 
1143
- # Step 5: Write to YAML file
1007
+ # Step 6: Write to YAML file
1144
1008
  with filepath.open("w") as file:
1145
1009
  # Write the header first
1146
1010
  file.write(header)
@@ -1186,6 +1050,8 @@ def _from_yaml(forcing_object: Type, filepath: Union[str, Path]) -> Dict[str, An
1186
1050
  ValueError
1187
1051
  If no configuration for the specified class name is found in the YAML file.
1188
1052
  """
1053
+ # Ensure filepath is a Path object
1054
+ filepath = Path(filepath)
1189
1055
 
1190
1056
  # Read the entire file content
1191
1057
  with filepath.open("r") as file:
@@ -0,0 +1,269 @@
1
+ import pytest
2
+ from pathlib import Path
3
+ import xarray as xr
4
+ import os
5
+ import logging
6
+ from datetime import datetime
7
+ from roms_tools import Grid, ROMSOutput
8
+ from roms_tools.download import download_test_data
9
+
10
+
11
+ @pytest.fixture
12
+ def roms_output_from_restart_file(use_dask):
13
+
14
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
15
+ grid = Grid.from_file(fname_grid)
16
+
17
+ # Single file
18
+ return ROMSOutput(
19
+ grid=grid,
20
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
21
+ type="restart",
22
+ use_dask=use_dask,
23
+ )
24
+
25
+
26
+ def test_load_model_output_file(roms_output_from_restart_file, use_dask):
27
+
28
+ assert isinstance(roms_output_from_restart_file.ds, xr.Dataset)
29
+
30
+
31
+ def test_load_model_output_directory(use_dask):
32
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
33
+ grid = Grid.from_file(fname_grid)
34
+
35
+ # Download at least two files, so these will be found within the pooch directory
36
+ _ = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
37
+ _ = Path(download_test_data("eastpac25km_rst.19980126000000.nc"))
38
+
39
+ # Directory
40
+ directory = os.path.dirname(download_test_data("eastpac25km_rst.19980106000000.nc"))
41
+ output = ROMSOutput(grid=grid, path=directory, type="restart", use_dask=use_dask)
42
+ assert isinstance(output.ds, xr.Dataset)
43
+
44
+
45
+ def test_load_model_output_file_list(use_dask):
46
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
47
+ grid = Grid.from_file(fname_grid)
48
+
49
+ # List of files
50
+ file1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
51
+ file2 = Path(download_test_data("eastpac25km_rst.19980126000000.nc"))
52
+ output = ROMSOutput(
53
+ grid=grid, path=[file1, file2], type="restart", use_dask=use_dask
54
+ )
55
+ assert isinstance(output.ds, xr.Dataset)
56
+
57
+
58
+ def test_invalid_type(use_dask):
59
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
60
+ grid = Grid.from_file(fname_grid)
61
+
62
+ # Invalid type
63
+ with pytest.raises(ValueError, match="Invalid type 'invalid_type'"):
64
+ ROMSOutput(
65
+ grid=grid,
66
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
67
+ type="invalid_type",
68
+ use_dask=use_dask,
69
+ )
70
+
71
+
72
+ def test_invalid_path(use_dask):
73
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
74
+ grid = Grid.from_file(fname_grid)
75
+
76
+ # Non-existent file
77
+ with pytest.raises(FileNotFoundError):
78
+ ROMSOutput(
79
+ grid=grid,
80
+ path=Path("/path/to/nonexistent/file.nc"),
81
+ type="restart",
82
+ use_dask=use_dask,
83
+ )
84
+
85
+ # Non-existent directory
86
+ with pytest.raises(FileNotFoundError):
87
+ ROMSOutput(
88
+ grid=grid,
89
+ path=Path("/path/to/nonexistent/directory"),
90
+ type="restart",
91
+ use_dask=use_dask,
92
+ )
93
+
94
+
95
+ def test_set_correct_model_reference_date(use_dask):
96
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
97
+ grid = Grid.from_file(fname_grid)
98
+
99
+ output = ROMSOutput(
100
+ grid=grid,
101
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
102
+ type="restart",
103
+ use_dask=use_dask,
104
+ )
105
+ assert output.model_reference_date == datetime(1995, 1, 1)
106
+
107
+
108
+ def test_model_reference_date_mismatch(use_dask):
109
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
110
+ grid = Grid.from_file(fname_grid)
111
+
112
+ # Create a ROMSOutput with a specified model_reference_date
113
+ model_ref_date = datetime(2020, 1, 1)
114
+ with pytest.raises(
115
+ ValueError, match="Mismatch between `self.model_reference_date`"
116
+ ):
117
+ ROMSOutput(
118
+ grid=grid,
119
+ path=Path(download_test_data("eastpac25km_rst.19980106000000.nc")),
120
+ type="restart",
121
+ model_reference_date=model_ref_date,
122
+ use_dask=use_dask,
123
+ )
124
+
125
+
126
+ def test_model_reference_date_no_metadata(use_dask, tmp_path, caplog):
127
+ # Helper function to handle the test logic for cases where metadata is missing or invalid
128
+ def test_no_metadata(faulty_ocean_time_attr, expected_exception, log_message=None):
129
+ ds = xr.open_dataset(fname)
130
+ ds["ocean_time"].attrs = faulty_ocean_time_attr
131
+
132
+ # Write modified dataset to a new file
133
+ fname_mod = tmp_path / "eastpac25km_rst.19980106000000_without_metadata.nc"
134
+ ds.to_netcdf(fname_mod)
135
+
136
+ # Test case 1: Expecting a ValueError when metadata is missing or invalid
137
+ with pytest.raises(
138
+ expected_exception,
139
+ match="Model reference date could not be inferred from the metadata",
140
+ ):
141
+ ROMSOutput(grid=grid, path=fname_mod, type="restart", use_dask=use_dask)
142
+
143
+ # Test case 2: When a model reference date is explicitly set, verify the warning
144
+ with caplog.at_level(logging.WARNING):
145
+ ROMSOutput(
146
+ grid=grid,
147
+ path=fname_mod,
148
+ model_reference_date=datetime(1995, 1, 1),
149
+ type="restart",
150
+ use_dask=use_dask,
151
+ )
152
+
153
+ if log_message:
154
+ # Verify the warning message in the log
155
+ assert log_message in caplog.text
156
+
157
+ fname_mod.unlink()
158
+
159
+ # Load grid and test data
160
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
161
+ grid = Grid.from_file(fname_grid)
162
+ fname = download_test_data("eastpac25km_rst.19980106000000.nc")
163
+
164
+ # Test 1: Ocean time attribute 'long_name' is missing
165
+ test_no_metadata({}, ValueError)
166
+
167
+ # Test 2: Ocean time attribute 'long_name' contains invalid information
168
+ test_no_metadata(
169
+ {"long_name": "some random text"},
170
+ ValueError,
171
+ "Could not infer the model reference date from the metadata.",
172
+ )
173
+
174
+
175
+ def test_compute_depth_coordinates(use_dask):
176
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
177
+ grid = Grid.from_file(fname_grid)
178
+
179
+ fname_restart1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
180
+ output = ROMSOutput(
181
+ grid=grid, path=fname_restart1, type="restart", use_dask=use_dask
182
+ )
183
+
184
+ # Before calling get_vertical_coordinates, check if the dataset doesn't already have depth coordinates
185
+ assert "layer_depth_rho" not in output.ds.data_vars
186
+
187
+ # Call the method to get vertical coordinates
188
+ output.compute_depth_coordinates(depth_type="layer")
189
+
190
+ # Check if the depth coordinates were added
191
+ assert "layer_depth_rho" in output.ds.data_vars
192
+
193
+
194
+ def test_check_vertical_coordinate_mismatch(use_dask):
195
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
196
+ grid = Grid.from_file(fname_grid)
197
+
198
+ fname_restart1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
199
+ output = ROMSOutput(
200
+ grid=grid, path=fname_restart1, type="restart", use_dask=use_dask
201
+ )
202
+
203
+ # create a mock dataset with inconsistent vertical coordinate parameters
204
+ ds_mock = output.ds.copy()
205
+
206
+ # Modify one of the vertical coordinate attributes to cause a mismatch
207
+ ds_mock.attrs["theta_s"] = 999
208
+
209
+ # Check if ValueError is raised due to mismatch
210
+ with pytest.raises(ValueError, match="theta_s from grid"):
211
+ output._check_vertical_coordinate(ds_mock)
212
+
213
+ # create a mock dataset with inconsistent vertical coordinate parameters
214
+ ds_mock = output.ds.copy()
215
+
216
+ # Modify one of the vertical coordinate attributes to cause a mismatch
217
+ ds_mock.attrs["Cs_w"] = ds_mock.attrs["Cs_w"] + 0.01
218
+
219
+ # Check if ValueError is raised due to mismatch
220
+ with pytest.raises(ValueError, match="Cs_w from grid"):
221
+ output._check_vertical_coordinate(ds_mock)
222
+
223
+
224
+ def test_that_coordinates_are_added(use_dask):
225
+ fname_grid = Path(download_test_data("epac25km_grd.nc"))
226
+ grid = Grid.from_file(fname_grid)
227
+
228
+ fname_restart1 = Path(download_test_data("eastpac25km_rst.19980106000000.nc"))
229
+ output = ROMSOutput(
230
+ grid=grid, path=fname_restart1, type="restart", use_dask=use_dask
231
+ )
232
+
233
+ assert "abs_time" in output.ds.coords
234
+ assert "lat_rho" in output.ds.coords
235
+ assert "lon_rho" in output.ds.coords
236
+
237
+
238
+ def test_plot(roms_output_from_restart_file, use_dask):
239
+
240
+ kwargs = {}
241
+ for var_name in ["temp", "u", "v"]:
242
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, **kwargs)
243
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, **kwargs)
244
+ roms_output_from_restart_file.plot(var_name, time=0, xi=0, **kwargs)
245
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, xi=0, **kwargs)
246
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, eta=0, **kwargs)
247
+
248
+ kwargs = {"depth_contours": True, "layer_contours": True}
249
+ for var_name in ["temp", "u", "v"]:
250
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, **kwargs)
251
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, **kwargs)
252
+ roms_output_from_restart_file.plot(var_name, time=0, xi=0, **kwargs)
253
+ roms_output_from_restart_file.plot(var_name, time=0, eta=0, xi=0, **kwargs)
254
+ roms_output_from_restart_file.plot(var_name, time=0, s=-1, eta=0, **kwargs)
255
+
256
+ roms_output_from_restart_file.plot("zeta", time=0, **kwargs)
257
+ roms_output_from_restart_file.plot("zeta", time=0, eta=0, **kwargs)
258
+ roms_output_from_restart_file.plot("zeta", time=0, xi=0, **kwargs)
259
+
260
+
261
+ def test_plot_errors(roms_output_from_restart_file, use_dask):
262
+ with pytest.raises(ValueError, match="Invalid time index"):
263
+ roms_output_from_restart_file.plot("temp", time=10, s=-1)
264
+ with pytest.raises(ValueError, match="Invalid input"):
265
+ roms_output_from_restart_file.plot("temp", time=0)
266
+ with pytest.raises(ValueError, match="Ambiguous input"):
267
+ roms_output_from_restart_file.plot("temp", time=0, s=-1, eta=0, xi=0)
268
+ with pytest.raises(ValueError, match="Conflicting input"):
269
+ roms_output_from_restart_file.plot("zeta", time=0, eta=0, xi=0)
@@ -1,7 +1,7 @@
1
1
  import pytest
2
2
  import numpy as np
3
3
  import xarray as xr
4
- from roms_tools.setup.regrid import VerticalRegrid
4
+ from roms_tools.regrid import VerticalRegrid
5
5
 
6
6
 
7
7
  def vertical_regridder(depth_values, layer_depth_rho_values):
@@ -3,7 +3,7 @@ from datetime import datetime
3
3
  import xarray as xr
4
4
  from roms_tools import Grid, BoundaryForcing
5
5
  import textwrap
6
- from roms_tools.setup.download import download_test_data
6
+ from roms_tools.download import download_test_data
7
7
  from conftest import calculate_file_hash
8
8
  from pathlib import Path
9
9
  import logging
@@ -9,7 +9,7 @@ from roms_tools.setup.datasets import (
9
9
  ERA5Correction,
10
10
  CESMBGCDataset,
11
11
  )
12
- from roms_tools.setup.download import download_test_data
12
+ from roms_tools.download import download_test_data
13
13
  from pathlib import Path
14
14
 
15
15
 
@@ -4,7 +4,7 @@ import xarray as xr
4
4
  from roms_tools import Grid
5
5
  import importlib.metadata
6
6
  import textwrap
7
- from roms_tools.setup.download import download_test_data
7
+ from roms_tools.download import download_test_data
8
8
  from conftest import calculate_file_hash
9
9
  from pathlib import Path
10
10
 
@@ -4,7 +4,7 @@ from roms_tools import InitialConditions, Grid
4
4
  import xarray as xr
5
5
  import numpy as np
6
6
  import textwrap
7
- from roms_tools.setup.download import download_test_data
7
+ from roms_tools.download import download_test_data
8
8
  from roms_tools.setup.datasets import CESMBGCDataset
9
9
  from pathlib import Path
10
10
  from conftest import calculate_file_hash
@@ -6,7 +6,7 @@ import textwrap
6
6
  from pathlib import Path
7
7
  import pytest
8
8
  from conftest import calculate_file_hash
9
- from roms_tools.setup.download import download_river_data
9
+ from roms_tools.download import download_river_data
10
10
 
11
11
 
12
12
  @pytest.fixture
@@ -2,7 +2,7 @@ import pytest
2
2
  from datetime import datetime
3
3
  import xarray as xr
4
4
  from roms_tools import Grid, SurfaceForcing
5
- from roms_tools.setup.download import download_test_data
5
+ from roms_tools.download import download_test_data
6
6
  import textwrap
7
7
  from pathlib import Path
8
8
  from conftest import calculate_file_hash
@@ -1,7 +1,7 @@
1
1
  import pytest
2
2
  from roms_tools import Grid, TidalForcing
3
3
  import xarray as xr
4
- from roms_tools.setup.download import download_test_data
4
+ from roms_tools.download import download_test_data
5
5
  import textwrap
6
6
  from pathlib import Path
7
7
  from conftest import calculate_file_hash
@@ -1,7 +1,7 @@
1
1
  import pytest
2
2
  from roms_tools import Grid
3
3
  from roms_tools.setup.topography import _compute_rfactor
4
- from roms_tools.setup.download import download_test_data
4
+ from roms_tools.download import download_test_data
5
5
  import numpy as np
6
6
  import numpy.testing as npt
7
7
  from scipy.ndimage import label
@@ -1,7 +1,11 @@
1
+ from roms_tools import Grid, BoundaryForcing
1
2
  from roms_tools.setup.utils import interpolate_from_climatology
2
3
  from roms_tools.setup.datasets import ERA5Correction
3
- from roms_tools.setup.download import download_test_data
4
+ from roms_tools.download import download_test_data
4
5
  import xarray as xr
6
+ import pytest
7
+ from datetime import datetime
8
+ from pathlib import Path
5
9
 
6
10
 
7
11
  def test_interpolate_from_climatology(use_dask):
@@ -14,3 +18,54 @@ def test_interpolate_from_climatology(use_dask):
14
18
 
15
19
  interpolated_field = interpolate_from_climatology(field, "time", era5_times)
16
20
  assert len(interpolated_field.time) == len(era5_times)
21
+
22
+
23
+ # Test yaml roundtrip with multiple source files
24
+ @pytest.fixture()
25
+ def boundary_forcing_from_multiple_source_files(request, use_dask):
26
+ """Fixture for creating a BoundaryForcing object."""
27
+
28
+ grid = Grid(
29
+ nx=5,
30
+ ny=5,
31
+ size_x=100,
32
+ size_y=100,
33
+ center_lon=-8,
34
+ center_lat=60,
35
+ rot=10,
36
+ N=3, # number of vertical levels
37
+ )
38
+
39
+ fname1 = Path(download_test_data("GLORYS_NA_20120101.nc"))
40
+ fname2 = Path(download_test_data("GLORYS_NA_20121231.nc"))
41
+
42
+ return BoundaryForcing(
43
+ grid=grid,
44
+ start_time=datetime(2011, 1, 1),
45
+ end_time=datetime(2013, 1, 1),
46
+ source={"name": "GLORYS", "path": [fname1, fname2]},
47
+ use_dask=use_dask,
48
+ )
49
+
50
+
51
+ def test_roundtrip_yaml(
52
+ boundary_forcing_from_multiple_source_files, request, tmp_path, use_dask
53
+ ):
54
+ """Test that creating a BoundaryForcing object, saving its parameters to yaml file,
55
+ and re-opening yaml file creates the same object."""
56
+
57
+ # Create a temporary filepath using the tmp_path fixture
58
+ file_str = "test_yaml"
59
+ for filepath in [
60
+ tmp_path / file_str,
61
+ str(tmp_path / file_str),
62
+ ]: # test for Path object and str
63
+
64
+ boundary_forcing_from_multiple_source_files.to_yaml(filepath)
65
+
66
+ bdry_forcing_from_file = BoundaryForcing.from_yaml(filepath, use_dask=use_dask)
67
+
68
+ assert boundary_forcing_from_multiple_source_files == bdry_forcing_from_file
69
+
70
+ filepath = Path(filepath)
71
+ filepath.unlink()