roms-tools 0.1.0__py3-none-any.whl → 0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment.yml +1 -0
- roms_tools/__init__.py +3 -0
- roms_tools/_version.py +1 -1
- roms_tools/setup/atmospheric_forcing.py +335 -393
- roms_tools/setup/boundary_forcing.py +711 -0
- roms_tools/setup/datasets.py +434 -25
- roms_tools/setup/fill.py +118 -5
- roms_tools/setup/grid.py +145 -19
- roms_tools/setup/initial_conditions.py +528 -0
- roms_tools/setup/plot.py +149 -4
- roms_tools/setup/tides.py +570 -437
- roms_tools/setup/topography.py +17 -2
- roms_tools/setup/utils.py +162 -0
- roms_tools/setup/vertical_coordinate.py +494 -0
- roms_tools/tests/test_atmospheric_forcing.py +1645 -0
- roms_tools/tests/test_boundary_forcing.py +332 -0
- roms_tools/tests/test_datasets.py +306 -0
- roms_tools/tests/test_grid.py +226 -0
- roms_tools/tests/test_initial_conditions.py +300 -0
- roms_tools/tests/test_tides.py +366 -0
- roms_tools/tests/test_topography.py +78 -0
- roms_tools/tests/test_vertical_coordinate.py +337 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/METADATA +3 -2
- roms_tools-0.20.dist-info/RECORD +28 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/WHEEL +1 -1
- roms_tools/tests/test_setup.py +0 -181
- roms_tools-0.1.0.dist-info/RECORD +0 -17
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/LICENSE +0 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/top_level.txt +0 -0
roms_tools/setup/datasets.py
CHANGED

@@ -1,48 +1,457 @@
 import pooch
 import xarray as xr
+from dataclasses import dataclass, field
+import glob
+from datetime import datetime, timedelta
+import numpy as np
+from typing import Dict, Optional, List
+import dask
 
-
-
+# Create a Pooch object to manage the global topography data
+pup_data = pooch.create(
     # Use the default cache folder for the operating system
     path=pooch.os_cache("roms-tools"),
     base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
-    # If this is a development version, get the data from the "main" branch
     # The registry specifies the files that can be fetched
     registry={
         "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
     },
 )
 
+# Create a Pooch object to manage the test data
+pup_test_data = pooch.create(
+    # Use the default cache folder for the operating system
+    path=pooch.os_cache("roms-tools"),
+    base_url="https://github.com/CWorthy-ocean/roms-tools-test-data/raw/main/",
+    # The registry specifies the files that can be fetched
+    registry={
+        "GLORYS_test_data.nc": "648f88ec29c433bcf65f257c1fb9497bd3d5d3880640186336b10ed54f7129d2",
+        "ERA5_regional_test_data.nc": "bd12ce3b562fbea2a80a3b79ba74c724294043c28dc98ae092ad816d74eac794",
+        "ERA5_global_test_data.nc": "8ed177ab64c02caf509b9fb121cf6713f286cc603b1f302f15f3f4eb0c21dc4f",
+        "TPXO_global_test_data.nc": "457bfe87a7b247ec6e04e3c7d3e741ccf223020c41593f8ae33a14f2b5255e60",
+        "TPXO_regional_test_data.nc": "11739245e2286d9c9d342dce5221e6435d2072b50028bef2e86a30287b3b4032",
+    },
+)
+
 
-def fetch_topo(topography_source) -> xr.Dataset:
+def fetch_topo(topography_source: str) -> xr.Dataset:
     """
     Load the global topography data as an xarray Dataset.
+
+    Parameters
+    ----------
+    topography_source : str
+        The source of the topography data to be loaded. Available options:
+        - "ETOPO5"
+
+    Returns
+    -------
+    xr.Dataset
+        The global topography data as an xarray Dataset.
     """
     # Mapping from user-specified topography options to corresponding filenames in the registry
-    topo_dict = {"
-
-    #
-
-
-
-    # The "fetch" method returns the full path to the downloaded data file.
-    # All we need to do now is load it with our standard Python tools.
+    topo_dict = {"ETOPO5": "etopo5.nc"}
+
+    # Fetch the file using Pooch, downloading if necessary
+    fname = pup_data.fetch(topo_dict[topography_source])
+
+    # Load the dataset using xarray and return it
     ds = xr.open_dataset(fname)
     return ds
 
 
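The rewritten fetch_topo just maps a source name onto a registry file, fetches it through Pooch, and opens it with xarray. A minimal usage sketch (illustrative; the first call downloads and checksums etopo5.nc, later calls reuse the cached copy under pooch.os_cache("roms-tools")):

>>> from roms_tools.setup.datasets import fetch_topo
>>> topo = fetch_topo("ETOPO5")  # doctest: +SKIP
>>> topo  # doctest: +SKIP
<xarray.Dataset>

Note that any key other than "ETOPO5" raises a plain KeyError from the topo_dict lookup rather than a dedicated error message.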
-def
+def download_test_data(filename: str) -> str:
     """
-
+    Download the test data file.
+
+    Parameters
+    ----------
+    filename : str
+        The name of the test data file to be downloaded. Available options:
+        - "GLORYS_test_data.nc"
+        - "ERA5_regional_test_data.nc"
+        - "ERA5_global_test_data.nc"
+
+    Returns
+    -------
+    str
+        The path to the downloaded test data file.
     """
-    #
-
-
-
-
-
-
-
-
-
-
+    # Fetch the file using Pooch, downloading if necessary
+    fname = pup_test_data.fetch(filename)
+
+    return fname
+
+
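download_test_data is the same Pooch pattern pointed at the test-data registry. A hedged sketch (the returned cache path varies by operating system):

>>> from roms_tools.setup.datasets import download_test_data
>>> path = download_test_data("ERA5_regional_test_data.nc")  # doctest: +SKIP

The sha256 entries in the registry above mean a corrupted or tampered download fails loudly instead of silently feeding bad data into the test suite. The docstring lists only three of the five registered files; the two TPXO files can be fetched the same way.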
+@dataclass(frozen=True, kw_only=True)
+class Dataset:
+    """
+    Represents forcing data on original grid.
+
+    Parameters
+    ----------
+    filename : str
+        The path to the data files. Can contain wildcards.
+    start_time : Optional[datetime], optional
+        The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+    end_time : Optional[datetime], optional
+        The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+        or no filtering is applied if start_time is not provided.
+    var_names : List[str]
+        List of variable names that are required in the dataset.
+    dim_names: Dict[str, str], optional
+        Dictionary specifying the names of dimensions in the dataset.
+
+    Attributes
+    ----------
+    ds : xr.Dataset
+        The xarray Dataset containing the forcing data on its original grid.
+
+    Examples
+    --------
+    >>> dataset = Dataset(
+    ...     filename="data.nc",
+    ...     start_time=datetime(2022, 1, 1),
+    ...     end_time=datetime(2022, 12, 31),
+    ... )
+    >>> dataset.load_data()
+    >>> print(dataset.ds)
+    <xarray.Dataset>
+    Dimensions: ...
+    """
+
+    filename: str
+    start_time: Optional[datetime] = None
+    end_time: Optional[datetime] = None
+    var_names: List[str]
+    dim_names: Dict[str, str] = field(
+        default_factory=lambda: {
+            "longitude": "longitude",
+            "latitude": "latitude",
+            "time": "time",
+        }
+    )
+
+    ds: xr.Dataset = field(init=False, repr=False)
+
+    def __post_init__(self):
+        ds = self.load_data()
+
+        # Select relevant times
+        if "time" in self.dim_names and self.start_time is not None:
+            ds = self.select_relevant_times(ds)
+
+        # Select relevant fields
+        ds = self.select_relevant_fields(ds)
+
+        # Make sure that latitude is ascending
+        diff = np.diff(ds[self.dim_names["latitude"]])
+        if np.all(diff < 0):
+            ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
+
+        # Check whether the data covers the entire globe
+        is_global = self.check_if_global(ds)
+
+        if is_global:
+            ds = self.concatenate_longitudes(ds)
+
+        object.__setattr__(self, "ds", ds)
+
+    def load_data(self) -> xr.Dataset:
+        """
+        Load dataset from the specified file.
+
+        Returns
+        -------
+        ds : xr.Dataset
+            The loaded xarray Dataset containing the forcing data.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the specified file does not exist.
+        """
+
+        # Check if the file exists
+        matching_files = glob.glob(self.filename)
+        if not matching_files:
+            raise FileNotFoundError(
+                f"No files found matching the pattern '{self.filename}'."
+            )
+
+        # Load the dataset
+        with dask.config.set(**{"array.slicing.split_large_chunks": False}):
+            # Define the chunk sizes
+            chunks = {
+                self.dim_names["latitude"]: -1,
+                self.dim_names["longitude"]: -1,
+            }
+            if "depth" in self.dim_names.keys():
+                chunks[self.dim_names["depth"]] = -1
+            if "time" in self.dim_names.keys():
+                chunks[self.dim_names["time"]] = 1
+
+                ds = xr.open_mfdataset(
+                    self.filename,
+                    combine="nested",
+                    concat_dim=self.dim_names["time"],
+                    coords="minimal",
+                    compat="override",
+                    chunks=chunks,
+                    engine="netcdf4",
+                )
+            else:
+                ds = xr.open_dataset(
+                    self.filename,
+                    chunks=chunks,
+                )
+
+        return ds
+
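Taken together, __post_init__ means a Dataset is fully prepared at construction time: times filtered, variables pruned, latitude made ascending, and global data padded in longitude. A hedged construction sketch (the file pattern and GLORYS-style variable names are hypothetical stand-ins):

>>> from datetime import datetime
>>> from roms_tools.setup.datasets import Dataset
>>> data = Dataset(
...     filename="glorys_2012-*.nc",  # hypothetical files; wildcards are expanded via glob
...     start_time=datetime(2012, 1, 1),
...     end_time=datetime(2012, 2, 1),
...     var_names=["thetao", "so", "zos"],  # hypothetical required variables
...     dim_names={
...         "longitude": "longitude",
...         "latitude": "latitude",
...         "time": "time",
...         "depth": "depth",
...     },
... )  # doctest: +SKIP
>>> data.ds  # doctest: +SKIP
<xarray.Dataset>

Two details worth noting: the frozen dataclass is populated via object.__setattr__, and the presence of a "time" key in dim_names is what routes loading through xr.open_mfdataset with one chunk per time step.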
+    def select_relevant_fields(self, ds) -> xr.Dataset:
+        """
+        Selects and returns a subset of the dataset containing only the variables specified in `self.var_names`.
+
+        Parameters
+        ----------
+        ds : xr.Dataset
+            The input dataset from which variables will be selected.
+
+        Returns
+        -------
+        xr.Dataset
+            A dataset containing only the variables specified in `self.var_names`.
+
+        Raises
+        ------
+        ValueError
+            If `ds` does not contain all variables listed in `self.var_names`.
+
+        """
+        missing_vars = [var for var in self.var_names if var not in ds.data_vars]
+        if missing_vars:
+            raise ValueError(
+                f"Dataset does not contain all required variables. The following variables are missing: {missing_vars}"
+            )
+
+        for var in ds.data_vars:
+            if var not in self.var_names:
+                ds = ds.drop_vars(var)
+
+        return ds
+
+    def select_relevant_times(self, ds) -> xr.Dataset:
+
+        """
+        Selects and returns the subset of the dataset corresponding to the specified time range.
+
+        This function filters the dataset to include only the data points within the specified
+        time range, defined by `self.start_time` and `self.end_time`. If `self.end_time` is not
+        provided, it defaults to one day after `self.start_time`.
+
+        Parameters
+        ----------
+        ds : xr.Dataset
+            The input dataset to be filtered.
+
+        Returns
+        -------
+        xr.Dataset
+            A dataset containing only the data points within the specified time range.
+
+        """
+
+        time_dim = self.dim_names["time"]
+
+        if not self.end_time:
+            end_time = self.start_time + timedelta(days=1)
+        else:
+            end_time = self.end_time
+
+        times = (np.datetime64(self.start_time) <= ds[time_dim]) & (
+            ds[time_dim] < np.datetime64(end_time)
+        )
+        ds = ds.where(times, drop=True)
+
+        if not ds.sizes[time_dim]:
+            raise ValueError("No matching times found.")
+
+        if not self.end_time:
+            if ds.sizes[time_dim] != 1:
+                found_times = ds.sizes[time_dim]
+                raise ValueError(
+                    f"There must be exactly one time matching the start_time. Found {found_times} matching times."
+                )
+
+        return ds
+
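One contract worth flagging in select_relevant_times: with a start_time but no end_time, the window is [start_time, start_time + 1 day) and exactly one snapshot must fall inside it. A sketch of the failure mode (hypothetical file with hourly output):

>>> Dataset(
...     filename="hourly_20120101.nc",  # hypothetical file holding 24 snapshots that day
...     start_time=datetime(2012, 1, 1),
...     var_names=["zos"],
... )  # doctest: +SKIP
Traceback (most recent call last):
...
ValueError: There must be exactly one time matching the start_time. Found 24 matching times.

That makes the no-end_time path a safe way to pull a single initial-conditions snapshot.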
+    def check_if_global(self, ds) -> bool:
+        """
+        Checks if the dataset covers the entire globe in the longitude dimension.
+
+        This function calculates the mean difference between consecutive longitude values.
+        It then checks if the difference between the first and last longitude values (plus 360 degrees)
+        is close to this mean difference, within a specified tolerance. If it is, the dataset is considered
+        to cover the entire globe in the longitude dimension.
+
+        Returns
+        -------
+        bool
+            True if the dataset covers the entire globe in the longitude dimension, False otherwise.
+
+        """
+        dlon_mean = (
+            ds[self.dim_names["longitude"]].diff(dim=self.dim_names["longitude"]).mean()
+        )
+        dlon = (
+            ds[self.dim_names["longitude"]][0] - ds[self.dim_names["longitude"]][-1]
+        ) % 360.0
+        is_global = np.isclose(dlon, dlon_mean, rtol=0.0, atol=1e-3)
+
+        return is_global
+
+    def concatenate_longitudes(self, ds):
+        """
+        Concatenates the field three times: with longitudes shifted by -360, original longitudes, and shifted by +360.
+
+        Parameters
+        ----------
+        field : xr.DataArray
+            The field to be concatenated.
+
+        Returns
+        -------
+        xr.DataArray
+            The concatenated field, with the longitude dimension extended.
+
+        Notes
+        -----
+        Concatenating three times may be overkill in most situations, but it is safe. Alternatively, we could refactor
+        to figure out whether concatenating on the lower end, upper end, or at all is needed.
+
+        """
+        ds_concatenated = xr.Dataset()
+
+        lon = ds[self.dim_names["longitude"]]
+        lon_minus360 = lon - 360
+        lon_plus360 = lon + 360
+        lon_concatenated = xr.concat(
+            [lon_minus360, lon, lon_plus360], dim=self.dim_names["longitude"]
+        )
+
+        ds_concatenated[self.dim_names["longitude"]] = lon_concatenated
+
+        for var in self.var_names:
+            if self.dim_names["longitude"] in ds[var].dims:
+                field = ds[var]
+                field_concatenated = xr.concat(
+                    [field, field, field], dim=self.dim_names["longitude"]
+                ).chunk({self.dim_names["longitude"]: -1})
+                field_concatenated[self.dim_names["longitude"]] = lon_concatenated
+                ds_concatenated[var] = field_concatenated
+            else:
+                ds_concatenated[var] = ds[var]
+
+        return ds_concatenated
+
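The globality test is pure arithmetic on the longitude coordinate, so it can be sanity-checked outside xarray. A standalone restatement of the formula from check_if_global (not roms-tools code):

>>> import numpy as np
>>> lon = np.arange(0.0, 360.0, 1.0)  # uniform 1-degree global grid, 0..359
>>> dlon_mean = np.diff(lon).mean()  # 1.0
>>> dlon = (lon[0] - lon[-1]) % 360.0  # wrap-around gap: (0 - 359) % 360 = 1.0
>>> bool(np.isclose(dlon, dlon_mean, rtol=0.0, atol=1e-3))
True

A regional grid that stops at, say, 300 degrees leaves a wrap-around gap far larger than the mean spacing, so the check fails and concatenate_longitudes is skipped.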
+    def choose_subdomain(
+        self, latitude_range, longitude_range, margin, straddle, return_subdomain=False
+    ):
+        """
+        Selects a subdomain from the given xarray Dataset based on latitude and longitude ranges,
+        extending the selection by the specified margin. Handles the conversion of longitude values
+        in the dataset from one range to another.
+
+        Parameters
+        ----------
+        latitude_range : tuple
+            A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
+        longitude_range : tuple
+            A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
+        margin : float
+            Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
+        straddle : bool
+            If True, target longitudes are expected in the range [-180, 180].
+            If False, target longitudes are expected in the range [0, 360].
+        return_subdomain : bool, optional
+            If True, returns the subset of the original dataset. If False, assigns it to self.ds.
+            Default is False.
+
+        Returns
+        -------
+        xr.Dataset
+            The subset of the original dataset representing the chosen subdomain, including an extended area
+            to cover one extra grid point beyond the specified ranges if return_subdomain is True.
+            Otherwise, returns None.
+
+        Raises
+        ------
+        ValueError
+            If the selected latitude or longitude range does not intersect with the dataset.
+        """
+        lat_min, lat_max = latitude_range
+        lon_min, lon_max = longitude_range
+
+        lon = self.ds[self.dim_names["longitude"]]
+        # Adjust longitude range if needed to match the expected range
+        if not straddle:
+            if lon.min() < -180:
+                if lon_max + margin > 0:
+                    lon_min -= 360
+                    lon_max -= 360
+            elif lon.min() < 0:
+                if lon_max + margin > 180:
+                    lon_min -= 360
+                    lon_max -= 360
+
+        if straddle:
+            if lon.max() > 360:
+                if lon_min - margin < 180:
+                    lon_min += 360
+                    lon_max += 360
+            elif lon.max() > 180:
+                if lon_min - margin < 0:
+                    lon_min += 360
+                    lon_max += 360
+
+        # Select the subdomain
+        subdomain = self.ds.sel(
+            **{
+                self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
+                self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
+            }
+        )
+
+        # Check if the selected subdomain has zero dimensions in latitude or longitude
+        if subdomain[self.dim_names["latitude"]].size == 0:
+            raise ValueError("Selected latitude range does not intersect with dataset.")
+
+        if subdomain[self.dim_names["longitude"]].size == 0:
+            raise ValueError(
+                "Selected longitude range does not intersect with dataset."
+            )
+
+        # Adjust longitudes to expected range if needed
+        lon = subdomain[self.dim_names["longitude"]]
+        if straddle:
+            subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
+        else:
+            subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
+
+        if return_subdomain:
+            return subdomain
+        else:
+            object.__setattr__(self, "ds", subdomain)
+
+    def convert_to_negative_depth(self):
+        """
+        Converts the depth values in the dataset to negative if they are non-negative.
+
+        This method checks the values in the depth dimension of the dataset (`self.ds[self.dim_names["depth"]]`).
+        If all values are greater than or equal to zero, it negates them and updates the dataset accordingly.
+
+        """
+        depth = self.ds[self.dim_names["depth"]]
+
+        if (depth >= 0).all():
+            self.ds[self.dim_names["depth"]] = -depth
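Closing out datasets.py, a hedged sketch of cutting a regional subset from a loaded Dataset (the bounds are invented for illustration):

>>> data.choose_subdomain(
...     latitude_range=(30.0, 45.0),
...     longitude_range=(-130.0, -110.0),
...     margin=2.0,  # degrees added on every side of the target range
...     straddle=True,  # target longitudes live in [-180, 180]
... )  # doctest: +SKIP

With the default return_subdomain=False the method rebinds self.ds through object.__setattr__, the same frozen-dataclass workaround used in __post_init__.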
roms_tools/setup/fill.py
CHANGED

@@ -3,7 +3,103 @@ import xarray as xr
 from numba import jit
 
 
-def lateral_fill(var, land_mask, dims=["latitude", "longitude"]):
+def fill_and_interpolate(
+    field,
+    mask,
+    fill_dims,
+    coords,
+    method="linear",
+    fillvalue_fill=0.0,
+    fillvalue_interp=np.nan,
+):
+    """
+    Propagates ocean values into land areas and interpolates the data to specified coordinates using a given method.
+
+    Parameters
+    ----------
+    field : xr.DataArray
+        The data array to be interpolated, typically containing oceanographic or atmospheric data
+        with dimensions such as latitude and longitude.
+
+    mask : xr.DataArray
+        A data array with the same spatial dimensions as `field`, where `1` indicates ocean points
+        and `0` indicates land points. This mask is used to identify land and ocean areas in the dataset.
+
+    fill_dims : list of str
+        List specifying the dimensions along which to perform the lateral fill, typically the horizontal
+        dimensions such as latitude and longitude, e.g., ["latitude", "longitude"].
+
+    coords : dict
+        Dictionary specifying the target coordinates for interpolation. The keys should match the dimensions
+        of `field` (e.g., {"longitude": lon_values, "latitude": lat_values, "depth": depth_values}).
+        This dictionary provides the new coordinates onto which the data array will be interpolated.
+
+    method : str, optional, default='linear'
+        The interpolation method to use. Valid options are those supported by `xarray.DataArray.interp`,
+        such as 'linear' or 'nearest'.
+
+    fillvalue_fill : float, optional, default=0.0
+        Value to use in the fill step if an entire data slice along the fill dimensions contains only NaNs.
+
+    fillvalue_interp : float, optional, default=np.nan
+        Value to use in the interpolation step. `np.nan` means that no extrapolation is applied.
+        `None` means that extrapolation is applied, which often makes sense when interpolating in the
+        vertical direction to avoid NaNs at the surface if the lowest depth is greater than zero.
+
+    Returns
+    -------
+    xr.DataArray
+        The interpolated data array. This array has the same dimensions as the input `field` but with values
+        interpolated to the new coordinates specified in `coords`.
+
+    Notes
+    -----
+    This method performs the following steps:
+    1. Sets land values to NaN based on the provided mask to ensure that interpolation does not cross
+       the land-ocean boundary.
+    2. Uses the `lateral_fill` function to propagate ocean values into the land interior, helping to fill
+       gaps in the dataset.
+    3. Interpolates the filled data array over the specified coordinates using the selected interpolation method.
+
+    Example
+    -------
+    >>> import xarray as xr
+    >>> field = xr.DataArray(...)
+    >>> mask = xr.DataArray(...)
+    >>> fill_dims = ["latitude", "longitude"]
+    >>> coords = {"latitude": new_lat_values, "longitude": new_lon_values}
+    >>> interpolated_field = fill_and_interpolate(
+    ...     field, mask, fill_dims, coords, method="linear"
+    ... )
+    >>> print(interpolated_field)
+    """
+    if not isinstance(field, xr.DataArray):
+        raise TypeError("field must be an xarray.DataArray")
+    if not isinstance(mask, xr.DataArray):
+        raise TypeError("mask must be an xarray.DataArray")
+    if not isinstance(coords, dict):
+        raise TypeError("coords must be a dictionary")
+    if not all(dim in field.dims for dim in coords.keys()):
+        raise ValueError("All keys in coords must match dimensions of field")
+    if method not in ["linear", "nearest"]:
+        raise ValueError(
+            "Unsupported interpolation method. Choose from 'linear', 'nearest'"
+        )
+
+    # Set land values to NaN
+    field = field.where(mask)
+
+    # Propagate ocean values into land interior before interpolation
+    field = lateral_fill(field, 1 - mask, fill_dims, fillvalue_fill)
+
+    field_interpolated = field.interp(
+        coords, method=method, kwargs={"fill_value": fillvalue_interp}
+    ).drop_vars(list(coords.keys()))
+
+    return field_interpolated
+
+
+def lateral_fill(var, land_mask, dims=["latitude", "longitude"], fillvalue=0.0):
     """
     Perform lateral fill on an xarray DataArray using a land mask.
 
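A self-contained sketch of the new entry point on synthetic data (dimension names match the defaults; the numbers are arbitrary):

>>> import numpy as np
>>> import xarray as xr
>>> from roms_tools.setup.fill import fill_and_interpolate
>>> lat = np.linspace(-5.0, 5.0, 11)
>>> lon = np.linspace(0.0, 10.0, 11)
>>> field = xr.DataArray(
...     np.random.rand(11, 11),
...     coords={"latitude": lat, "longitude": lon},
...     dims=["latitude", "longitude"],
... )
>>> mask = xr.ones_like(field)  # 1 = ocean
>>> mask[4:7, 4:7] = 0  # carve out a small land patch
>>> coords = {"latitude": np.linspace(-4.0, 4.0, 21), "longitude": np.linspace(1.0, 9.0, 21)}
>>> out = fill_and_interpolate(field, mask, ["latitude", "longitude"], coords)  # doctest: +SKIP

The land patch is flooded with neighboring ocean values before the interpolation, so target points near the coastline never blend land NaNs into the result.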
@@ -20,6 +116,9 @@ def lateral_fill(var, land_mask, dims=["latitude", "longitude"]):
     dims : list of str, optional, default=['latitude', 'longitude']
         Dimensions along which to perform the fill. The default is ['latitude', 'longitude'].
 
+    fillvalue : float, optional, default=0.0
+        Value to use if an entire data slice along the dims contains only NaNs.
+
     Returns
     -------
     var_filled : xarray.DataArray
@@ -27,21 +126,25 @@ def lateral_fill(var, land_mask, dims=["latitude", "longitude"]):
         specified by `land_mask` where NaNs are preserved.
 
     """
+
     var_filled = xr.apply_ufunc(
         _lateral_fill_np_array,
         var,
         land_mask,
         input_core_dims=[dims, dims],
         output_core_dims=[dims],
-        dask="parallelized",
         output_dtypes=[var.dtype],
+        dask="parallelized",
         vectorize=True,
+        kwargs={"fillvalue": fillvalue},
     )
 
     return var_filled
 
 
-def _lateral_fill_np_array(var, isvalid_mask, tol=1.0e-4, rc=1.8, max_iter=10000):
+def _lateral_fill_np_array(
+    var, isvalid_mask, fillvalue=0.0, tol=1.0e-4, rc=1.8, max_iter=10000
+):
     """
     Perform lateral fill on a numpy array.
 
@@ -55,6 +158,9 @@ def _lateral_fill_np_array(var, isvalid_mask, tol=1.0e-4, rc=1.8, max_iter=10000
         Valid values mask: `True` where data should be filled. Must have same shape
         as `var`.
 
+    fillvalue: float
+        Value to use if the full field `var` contains only NaNs. Default is 0.0.
+
     tol : float, optional, default=1.0e-4
         Convergence criteria: stop filling when the value change is less than
         or equal to `tol * var`, i.e., `delta <= tol * np.abs(var[j, i])`.
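lateral_fill can also be exercised directly; a hedged sketch on a toy array (per the isvalid_mask docstring above, the second argument is True where values should be created):

>>> import numpy as np
>>> import xarray as xr
>>> from roms_tools.setup.fill import lateral_fill
>>> data = np.array([[1.0, np.nan, 3.0], [np.nan, np.nan, np.nan], [1.0, np.nan, 3.0]])
>>> var = xr.DataArray(data, dims=["latitude", "longitude"])
>>> fill_here = xr.DataArray(np.isnan(data), dims=["latitude", "longitude"])
>>> filled = lateral_fill(var, fill_here, dims=["latitude", "longitude"], fillvalue=0.0)  # doctest: +SKIP

NaNs at points where the mask is False are restored after the solve (the keepNaNs step in _lateral_fill_np_array), which is how permanently masked points stay NaN.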
@@ -90,14 +196,14 @@ def _lateral_fill_np_array(var, isvalid_mask, tol=1.0e-4, rc=1.8, max_iter=10000
 
     fillmask = np.isnan(var)  # Fill all NaNs
     keepNaNs = ~isvalid_mask & np.isnan(var)
-    var = _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter)
+    var = _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter, fillvalue)
     var[keepNaNs] = np.nan  # Replace NaNs in areas not designated for filling
 
     return var
 
 
 @jit(nopython=True, parallel=True)
-def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter):
+def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter, fillvalue=0.0):
     """
     Perform an iterative land fill algorithm using the Successive Over-Relaxation (SOR)
     solution of the Laplace Equation.
@@ -126,6 +232,9 @@ def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter):
     max_iter : int
         Maximum number of iterations allowed before the process is terminated.
 
+    fillvalue: float
+        Value to use if the full field is NaNs. Default is 0.0.
+
     Returns
     -------
     None
@@ -155,6 +264,10 @@ def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter):
     if np.max(np.fabs(var)) == 0.0:
         var = np.zeros_like(var)
         return var
+    # If field consists only of NaNs, fill NaNs with fill value
+    if np.isnan(var).all():
+        var = fillvalue * np.ones_like(var)
+        return var
 
     # Compute a zonal mean to use as a first guess
     zoncnt = np.zeros(nlat)
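The new all-NaN guard mirrors the all-zero short-circuit directly above it; in isolation the added branch does just this (a plain-numpy restatement, not a call into the jitted function):

>>> import numpy as np
>>> var = np.full((2, 2), np.nan)
>>> fillvalue = 0.0
>>> if np.isnan(var).all():
...     var = fillvalue * np.ones_like(var)
>>> var
array([[0., 0.],
       [0., 0.]])

Without it, an all-NaN slice would presumably leave the zonal-mean first guess below undefined and the SOR sweeps iterating on NaNs until max_iter.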