roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment.yml +2 -0
- roms_tools/__init__.py +4 -2
- roms_tools/_version.py +1 -1
- roms_tools/setup/boundary_forcing.py +757 -0
- roms_tools/setup/datasets.py +1141 -35
- roms_tools/setup/download.py +118 -0
- roms_tools/setup/fill.py +118 -5
- roms_tools/setup/grid.py +145 -19
- roms_tools/setup/initial_conditions.py +557 -0
- roms_tools/setup/mixins.py +395 -0
- roms_tools/setup/plot.py +149 -4
- roms_tools/setup/surface_forcing.py +596 -0
- roms_tools/setup/tides.py +472 -437
- roms_tools/setup/topography.py +18 -3
- roms_tools/setup/utils.py +352 -0
- roms_tools/setup/vertical_coordinate.py +494 -0
- roms_tools/tests/test_boundary_forcing.py +706 -0
- roms_tools/tests/test_datasets.py +370 -0
- roms_tools/tests/test_grid.py +226 -0
- roms_tools/tests/test_initial_conditions.py +520 -0
- roms_tools/tests/test_surface_forcing.py +2622 -0
- roms_tools/tests/test_tides.py +365 -0
- roms_tools/tests/test_topography.py +78 -0
- roms_tools/tests/test_utils.py +16 -0
- roms_tools/tests/test_vertical_coordinate.py +337 -0
- {roms_tools-0.1.0.dist-info → roms_tools-1.0.0.dist-info}/METADATA +9 -4
- roms_tools-1.0.0.dist-info/RECORD +31 -0
- {roms_tools-0.1.0.dist-info → roms_tools-1.0.0.dist-info}/WHEEL +1 -1
- roms_tools/setup/atmospheric_forcing.py +0 -993
- roms_tools/tests/test_setup.py +0 -181
- roms_tools-0.1.0.dist-info/RECORD +0 -17
- {roms_tools-0.1.0.dist-info → roms_tools-1.0.0.dist-info}/LICENSE +0 -0
- {roms_tools-0.1.0.dist-info → roms_tools-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,993 +0,0 @@
|
|
|
1
|
-
import xarray as xr
|
|
2
|
-
import dask
|
|
3
|
-
from dataclasses import dataclass, field
|
|
4
|
-
from roms_tools.setup.grid import Grid
|
|
5
|
-
from datetime import datetime
|
|
6
|
-
import glob
|
|
7
|
-
import numpy as np
|
|
8
|
-
from typing import Optional, Dict
|
|
9
|
-
from roms_tools.setup.fill import lateral_fill
|
|
10
|
-
import calendar
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
@dataclass(frozen=True, kw_only=True)
class ForcingDataset:
    """
    Represents forcing data on its original grid.

    On construction, the data files are opened (dask-chunked along time),
    restricted to the requested time window, and reordered so that latitude
    is ascending.

    Parameters
    ----------
    filename : str
        The path to the data files. Can contain wildcards.
    start_time : datetime
        The start time for selecting relevant data (inclusive).
    end_time : datetime
        The end time for selecting relevant data (exclusive).
    dim_names : Dict[str, str], optional
        Dictionary specifying the names of dimensions in the dataset.
        Default is {"longitude": "lon", "latitude": "lat", "time": "time"}.

    Attributes
    ----------
    ds : xr.Dataset
        The xarray Dataset containing the forcing data on its original grid.

    Examples
    --------
    >>> dataset = ForcingDataset(
    ...     filename="data.nc",
    ...     start_time=datetime(2022, 1, 1),
    ...     end_time=datetime(2022, 12, 31),
    ... )
    >>> print(dataset.ds)
    <xarray.Dataset>
    Dimensions: ...
    """

    filename: str
    start_time: datetime
    end_time: datetime
    dim_names: Dict[str, str] = field(
        default_factory=lambda: {"longitude": "lon", "latitude": "lat", "time": "time"}
    )

    ds: xr.Dataset = field(init=False, repr=False)

    def __post_init__(self):

        ds = self.load_data()

        # Select relevant times: start_time <= t < end_time. The start bound is
        # inclusive so that a record lying exactly at start_time is kept.
        time_dim = self.dim_names["time"]
        times = (np.datetime64(self.start_time) <= ds[time_dim]) & (
            ds[time_dim] < np.datetime64(self.end_time)
        )
        ds = ds.where(times, drop=True)

        # Make sure that latitude is ascending; the label slice used in
        # choose_subdomain relies on an ascending latitude index.
        lat_dim = self.dim_names["latitude"]
        if np.all(np.diff(ds[lat_dim]) < 0):
            ds = ds.isel(**{lat_dim: slice(None, None, -1)})

        object.__setattr__(self, "ds", ds)

    def load_data(self) -> xr.Dataset:
        """
        Load the dataset from the file(s) matching ``self.filename``.

        Returns
        -------
        ds : xr.Dataset
            The loaded xarray Dataset containing the forcing data, chunked
            with one time step per chunk.

        Raises
        ------
        FileNotFoundError
            If no file matches the specified path/pattern.
        """

        # Check if any file matches the (possibly wildcarded) pattern
        if not glob.glob(self.filename):
            raise FileNotFoundError(
                f"No files found matching the pattern '{self.filename}'."
            )

        # Load the dataset
        with dask.config.set(**{"array.slicing.split_large_chunks": False}):
            # Initially, we want a time chunk size of 1 to enable quick
            # .nan_check() and .plot() methods for AtmosphericForcing.
            ds = xr.open_mfdataset(
                self.filename,
                combine="nested",
                concat_dim=self.dim_names["time"],
                coords="minimal",
                compat="override",
                chunks={self.dim_names["time"]: 1},
            )

        return ds

    def choose_subdomain(self, latitude_range, longitude_range, margin, straddle):
        """
        Select a subdomain of ``self.ds`` and store it back into ``self.ds``.

        The selection covers the given latitude and longitude ranges, extended
        by ``margin`` on every side. The longitude coordinate of the selected
        subdomain is then converted to the range expected by the target grid.

        Parameters
        ----------
        latitude_range : tuple
            A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
        longitude_range : tuple
            A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
        margin : float
            Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
        straddle : bool
            If True, target longitudes are expected in the range [-180, 180].
            If False, target longitudes are expected in the range [0, 360].

        Returns
        -------
        None
            The selected subdomain replaces ``self.ds``.

        Raises
        ------
        ValueError
            If the selected latitude or longitude range does not intersect with the dataset.
        """
        lat_min, lat_max = latitude_range
        lon_min, lon_max = longitude_range

        lon = self.ds[self.dim_names["longitude"]]
        # Shift the requested longitude range into the convention of the
        # dataset so that a contiguous label slice can be taken. Selecting
        # before converting avoids introducing discontinuous longitudes.
        # NOTE(review): only the part of the requested range that exists in
        # the dataset is selected; any remaining gap shows up as NaNs that
        # AtmosphericForcing.nan_check reports.
        if not straddle:
            if lon.min() < -180:
                if lon_max + margin > 0:
                    lon_min -= 360
                    lon_max -= 360
            elif lon.min() < 0:
                if lon_max + margin > 180:
                    lon_min -= 360
                    lon_max -= 360

        if straddle:
            if lon.max() > 360:
                if lon_min - margin < 180:
                    lon_min += 360
                    lon_max += 360
            elif lon.max() > 180:
                if lon_min - margin < 0:
                    lon_min += 360
                    lon_max += 360

        # Select the subdomain
        subdomain = self.ds.sel(
            **{
                self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
                self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
            }
        )

        # Check if the selected subdomain has zero size in latitude or longitude
        if subdomain[self.dim_names["latitude"]].size == 0:
            raise ValueError("Selected latitude range does not intersect with dataset.")

        if subdomain[self.dim_names["longitude"]].size == 0:
            raise ValueError(
                "Selected longitude range does not intersect with dataset."
            )

        # Convert longitudes to the range expected by the target grid
        lon = subdomain[self.dim_names["longitude"]]
        if straddle:
            subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
        else:
            subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)

        # Frozen dataclass: bypass __setattr__ to store the result
        object.__setattr__(self, "ds", subdomain)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
@dataclass(frozen=True, kw_only=True)
class SWRCorrection:
    """
    Configuration for shortwave radiation correction.

    Parameters
    ----------
    filename : str
        Filename of the correction data.
    varname : str
        Variable identifier for the correction.
    dim_names : Dict[str, str], optional
        Dictionary specifying the names of dimensions in the dataset.
        Default is {"longitude": "longitude", "latitude": "latitude", "time": "time"}.
    temporal_resolution : str, optional
        Temporal resolution of the correction data. Default is "climatology",
        which is currently the only supported value.

    Attributes
    ----------
    ds : xr.Dataset
        The loaded xarray Dataset containing the correction data, with
        latitude in ascending order.

    Raises
    ------
    NotImplementedError
        If ``temporal_resolution`` is not "climatology".
    FileNotFoundError
        If no file matches ``filename``.
    ValueError
        If the dataset lacks ``varname`` or any of the dimensions in ``dim_names``.

    Examples
    --------
    >>> swr_correction = SWRCorrection(
    ...     filename="correction_data.nc",
    ...     varname="corr",
    ...     dim_names={
    ...         "time": "time",
    ...         "latitude": "latitude",
    ...         "longitude": "longitude",
    ...     },
    ...     temporal_resolution="climatology",
    ... )
    """

    filename: str
    varname: str
    dim_names: Dict[str, str] = field(
        default_factory=lambda: {
            "longitude": "longitude",
            "latitude": "latitude",
            "time": "time",
        }
    )
    temporal_resolution: str = "climatology"
    ds: xr.Dataset = field(init=False, repr=False)

    def __post_init__(self):
        if self.temporal_resolution != "climatology":
            raise NotImplementedError(
                f"temporal_resolution must be 'climatology', got {self.temporal_resolution}"
            )

        ds = self._load_data()
        self._check_dataset(ds)
        ds = self._ensure_latitude_ascending(ds)

        object.__setattr__(self, "ds", ds)

    def _load_data(self):
        """
        Load data from the specified file.

        Returns
        -------
        ds : xr.Dataset
            The loaded xarray Dataset containing the correction data, with a
            single chunk along each of the configured dimensions.

        Raises
        ------
        FileNotFoundError
            If no file matches the specified filename.
        """
        # Check if any file matching the (possibly wildcarded) pattern exists
        matching_files = glob.glob(self.filename)
        if not matching_files:
            raise FileNotFoundError(
                f"No files found matching the pattern '{self.filename}'."
            )

        # Load the dataset
        ds = xr.open_dataset(
            self.filename,
            chunks={
                self.dim_names["time"]: -1,
                self.dim_names["latitude"]: -1,
                self.dim_names["longitude"]: -1,
            },
        )

        return ds

    def _check_dataset(self, ds: xr.Dataset) -> None:
        """
        Check if the dataset contains the specified variable and dimensions.

        Parameters
        ----------
        ds : xr.Dataset
            The xarray Dataset to check.

        Raises
        ------
        ValueError
            If the dataset does not contain the specified variable or dimensions.
        """
        if self.varname not in ds:
            raise ValueError(
                f"The dataset does not contain the variable '{self.varname}'."
            )

        for dim in self.dim_names.values():
            if dim not in ds.dims:
                raise ValueError(f"The dataset does not contain the dimension '{dim}'.")

    def _ensure_latitude_ascending(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Ensure that the latitude dimension is in ascending order.

        Parameters
        ----------
        ds : xr.Dataset
            The xarray Dataset to check.

        Returns
        -------
        ds : xr.Dataset
            The xarray Dataset with latitude in ascending order.
        """
        # Reverse latitude only if it is strictly descending
        lat_diff = np.diff(ds[self.dim_names["latitude"]])
        if np.all(lat_diff < 0):
            ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})

        return ds

    def _handle_longitudes(self, straddle: bool) -> None:
        """
        Convert the longitude coordinate of ``self.ds`` in place to the target range.

        Parameters
        ----------
        straddle : bool
            If True, target longitudes are in range [-180, 180].
            If False, target longitudes are in range [0, 360].

        Notes
        -----
        The coordinate is replaced via item assignment on ``self.ds``; the
        converted longitudes are not re-sorted.
        """
        lon = self.ds[self.dim_names["longitude"]]

        if lon.min().values < 0 and not straddle:
            # Convert from [-180, 180] to [0, 360]
            self.ds[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)

        if lon.max().values > 180 and straddle:
            # Convert from [0, 360] to [-180, 180]
            self.ds[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)

    def _choose_subdomain(self, coords) -> xr.Dataset:
        """
        Select a subdomain from the dataset at exactly the given latitude/longitude values.

        Parameters
        ----------
        coords : dict
            A dictionary mapping this dataset's latitude/longitude dimension
            names to the target coordinate values.

        Returns
        -------
        xr.Dataset
            The subset of the original dataset representing the chosen subdomain.

        Raises
        ------
        ValueError
            If the dataset does not contain all of the requested coordinate values.
        """

        # Select the subdomain based on the specified latitude and longitude values
        subdomain = self.ds.sel(**coords)

        # Check if the selected subdomain contains the specified latitude and longitude values
        if not subdomain[self.dim_names["latitude"]].equals(
            coords[self.dim_names["latitude"]]
        ):
            raise ValueError(
                "The correction dataset does not contain all specified latitude values."
            )
        if not subdomain[self.dim_names["longitude"]].equals(
            coords[self.dim_names["longitude"]]
        ):
            raise ValueError(
                "The correction dataset does not contain all specified longitude values."
            )

        return subdomain

    def _interpolate_temporally(self, field, time):
        """
        Interpolate the climatological field to the given time points.

        Parameters
        ----------
        field : xarray.DataArray
            The climatological field to be interpolated, with a time coordinate
            named ``self.dim_names["time"]``. That coordinate must support
            ``.dt.days`` (i.e. be timedelta-valued); its day numbers are used
            as the interpolation axis. The input is not modified.
        time : xarray.DataArray
            The target time points (datetime-valued; ``.dt.dayofyear`` is used).

        Returns
        -------
        xr.DataArray
            The field values interpolated to the specified time points.

        Raises
        ------
        NotImplementedError
            If the temporal resolution is not set to 'climatology'.
        """
        if self.temporal_resolution != "climatology":
            raise NotImplementedError(
                f"temporal_resolution must be 'climatology', got {self.temporal_resolution}"
            )

        time_dim = self.dim_names["time"]

        # Express the climatological time axis in days, on a new object so
        # that the caller's DataArray is left untouched.
        field = field.assign_coords({time_dim: field[time_dim].dt.days})
        day_of_year = time.dt.dayofyear

        # Pad periodically across the beginning and end of the year so that
        # target days before the first / after the last record interpolate
        # between December and January values.
        time_concat = xr.concat(
            [
                field[time_dim][-1] - 365.25,
                field[time_dim],
                365.25 + field[time_dim][0],
            ],
            dim=time_dim,
        )
        field_concat = xr.concat(
            [
                field.isel({time_dim: -1}),
                field,
                field.isel({time_dim: 0}),
            ],
            dim=time_dim,
        )
        field_concat[time_dim] = time_concat

        # Interpolate to specified times
        field_interpolated = field_concat.interp(
            {time_dim: day_of_year}, method="linear"
        )

        return field_interpolated
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
@dataclass(frozen=True, kw_only=True)
|
|
450
|
-
class Rivers:
|
|
451
|
-
"""
|
|
452
|
-
Configuration for river forcing.
|
|
453
|
-
|
|
454
|
-
Parameters
|
|
455
|
-
----------
|
|
456
|
-
filename : str, optional
|
|
457
|
-
Filename of the river forcing data.
|
|
458
|
-
"""
|
|
459
|
-
|
|
460
|
-
filename: str = ""
|
|
461
|
-
|
|
462
|
-
def __post_init__(self):
|
|
463
|
-
if not self.filename:
|
|
464
|
-
raise ValueError("The 'filename' must be provided.")
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
@dataclass(frozen=True, kw_only=True)
|
|
468
|
-
class AtmosphericForcing:
|
|
469
|
-
"""
|
|
470
|
-
Represents atmospheric forcing data for ocean modeling.
|
|
471
|
-
|
|
472
|
-
Parameters
|
|
473
|
-
----------
|
|
474
|
-
grid : Grid
|
|
475
|
-
Object representing the grid information.
|
|
476
|
-
use_coarse_grid: bool
|
|
477
|
-
Whether to interpolate to coarsened grid. Default is False.
|
|
478
|
-
start_time : datetime
|
|
479
|
-
Start time of the desired forcing data.
|
|
480
|
-
end_time : datetime
|
|
481
|
-
End time of the desired forcing data.
|
|
482
|
-
model_reference_date : datetime, optional
|
|
483
|
-
Reference date for the model. Default is January 1, 2000.
|
|
484
|
-
source : str, optional
|
|
485
|
-
Source of the atmospheric forcing data. Default is "era5".
|
|
486
|
-
filename: str
|
|
487
|
-
Path to the atmospheric forcing source data file. Can contain wildcards.
|
|
488
|
-
swr_correction : SWRCorrection
|
|
489
|
-
Shortwave radiation correction configuration.
|
|
490
|
-
rivers : Rivers, optional
|
|
491
|
-
River forcing configuration.
|
|
492
|
-
|
|
493
|
-
Attributes
|
|
494
|
-
----------
|
|
495
|
-
ds : xr.Dataset
|
|
496
|
-
Xarray Dataset containing the atmospheric forcing data.
|
|
497
|
-
|
|
498
|
-
Notes
|
|
499
|
-
-----
|
|
500
|
-
This class represents atmospheric forcing data used in ocean modeling. It provides a convenient
|
|
501
|
-
interface to work with forcing data including shortwave radiation correction and river forcing.
|
|
502
|
-
|
|
503
|
-
Examples
|
|
504
|
-
--------
|
|
505
|
-
>>> grid_info = Grid(...)
|
|
506
|
-
>>> start_time = datetime(2000, 1, 1)
|
|
507
|
-
>>> end_time = datetime(2000, 1, 2)
|
|
508
|
-
>>> atm_forcing = AtmosphericForcing(
|
|
509
|
-
... grid=grid_info,
|
|
510
|
-
... start_time=start_time,
|
|
511
|
-
... end_time=end_time,
|
|
512
|
-
... source="era5",
|
|
513
|
-
... filename="atmospheric_data_*.nc",
|
|
514
|
-
... swr_correction=swr_correction,
|
|
515
|
-
... )
|
|
516
|
-
"""
|
|
517
|
-
|
|
518
|
-
grid: Grid
|
|
519
|
-
use_coarse_grid: bool = False
|
|
520
|
-
start_time: datetime
|
|
521
|
-
end_time: datetime
|
|
522
|
-
model_reference_date: datetime = datetime(2000, 1, 1)
|
|
523
|
-
source: str = "era5"
|
|
524
|
-
filename: str
|
|
525
|
-
swr_correction: Optional["SWRCorrection"] = None
|
|
526
|
-
rivers: Optional["Rivers"] = None
|
|
527
|
-
ds: xr.Dataset = field(init=False, repr=False)
|
|
528
|
-
|
|
529
|
-
def __post_init__(self):
    """
    Build the ROMS bulk surface forcing dataset and store it in ``self.ds``.

    Raises
    ------
    NotImplementedError
        If river forcing is requested (not implemented yet).
    ValueError
        If ``source`` is not "era5", if ``use_coarse_grid`` is set on a grid
        that has not been coarsened, or if NaNs remain over the ocean at the
        first time step (see ``nan_check``).
    """
    # River forcing is accepted as a parameter but not supported yet: fail
    # fast instead of silently ignoring the configuration.
    if self.rivers:
        raise NotImplementedError("River forcing is not implemented yet.")

    # Only ERA5 is supported. Without this check an unknown source would fail
    # further down with an unhelpful error about undefined local variables.
    if self.source != "era5":
        raise ValueError(
            f"Unsupported atmospheric forcing source '{self.source}'. Only 'era5' is supported."
        )

    # Target horizontal coordinates and grid rotation angle
    if self.use_coarse_grid:
        if "lon_coarse" not in self.grid.ds:
            raise ValueError(
                "Grid has not been coarsened yet. Execute grid.coarsen() first."
            )

        lon = self.grid.ds.lon_coarse
        lat = self.grid.ds.lat_coarse
        angle = self.grid.ds.angle_coarse
    else:
        lon = self.grid.ds.lon_rho
        lat = self.grid.ds.lat_rho
        angle = self.grid.ds.angle

    # ERA5 dimension names and variable names
    dims = {"longitude": "longitude", "latitude": "latitude", "time": "time"}
    varnames = {
        "u10": "u10",
        "v10": "v10",
        "swr": "ssr",
        "lwr": "strd",
        "t2m": "t2m",
        "d2m": "d2m",
        "rain": "tp",
    }

    data = ForcingDataset(
        filename=self.filename,
        start_time=self.start_time,
        end_time=self.end_time,
        dim_names=dims,
    )

    # Operate on longitudes between -180 and 180 unless the ROMS domain lies at
    # least 5 degrees in longitude away from the Greenwich meridian.
    lon = xr.where(lon > 180, lon - 360, lon)
    straddle = True
    if not self.grid.straddle and abs(lon).min() > 5:
        lon = xr.where(lon < 0, lon + 360, lon)
        straddle = False

    # Step 1: Choose subdomain of forcing data including safety margin for interpolation, and Step 2: Convert to the proper longitude range.
    # Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2.
    # Discontinuous longitudes can lead to artifacts in the interpolation process. Specifically, if there is a data gap,
    # discontinuous longitudes could result in values that appear to come from a distant location instead of producing NaNs.
    # These NaNs are important as they can be identified and handled appropriately by the nan_check function.
    data.choose_subdomain(
        latitude_range=[lat.min().values, lat.max().values],
        longitude_range=[lon.min().values, lon.max().values],
        margin=2,
        straddle=straddle,
    )

    # Source land-sea mask: 1 where SST is defined at the first time step, 0 elsewhere
    mask = xr.where(data.ds["sst"].isel({dims["time"]: 0}).isnull(), 0, 1)

    # Interpolate every ERA5 variable onto the target grid
    coords = {dims["latitude"]: lat, dims["longitude"]: lon}
    interpolated = {
        key: self._interpolate(data.ds[name], mask, coords=coords, method="linear")
        for key, name in varnames.items()
    }
    u10 = interpolated["u10"]
    v10 = interpolated["v10"]
    swr = interpolated["swr"]
    lwr = interpolated["lwr"]
    t2m = interpolated["t2m"]
    d2m = interpolated["d2m"]
    rain = interpolated["rain"]

    # Translate radiation to fluxes. ERA5 stores values integrated over 1 hour.
    swr = swr / 3600  # from J/m^2 to W/m^2
    lwr = lwr / 3600  # from J/m^2 to W/m^2
    rain = rain * 100 * 24  # from m to cm/day
    # convert from K to C
    t2m = t2m - 273.15
    d2m = d2m - 273.15
    # relative humidity fraction
    qair = np.exp((17.625 * d2m) / (243.04 + d2m)) / np.exp(
        (17.625 * t2m) / (243.04 + t2m)
    )
    # convert relative to absolute humidity assuming constant pressure
    patm = 1010.0
    cff = (
        (1.0007 + 3.46e-6 * patm)
        * 6.1121
        * np.exp(17.502 * t2m / (240.97 + t2m))
    )
    cff = cff * qair
    qair = 0.62197 * (cff / (patm - 0.378 * cff))

    # correct shortwave radiation
    if self.swr_correction:

        # choose same subdomain as forcing data so that we can use same mask
        self.swr_correction._handle_longitudes(straddle=straddle)
        coords_correction = {
            self.swr_correction.dim_names["latitude"]: data.ds[
                data.dim_names["latitude"]
            ],
            self.swr_correction.dim_names["longitude"]: data.ds[
                data.dim_names["longitude"]
            ],
        }
        subdomain = self.swr_correction._choose_subdomain(coords_correction)

        # spatial interpolation
        corr_factor = subdomain[self.swr_correction.varname]
        coords_correction = {
            self.swr_correction.dim_names["latitude"]: lat,
            self.swr_correction.dim_names["longitude"]: lon,
        }
        corr_factor = self._interpolate(
            corr_factor, mask, coords=coords_correction, method="linear"
        )

        # temporal interpolation
        corr_factor = self.swr_correction._interpolate_temporally(
            corr_factor, time=swr[dims["time"]]
        )

        swr = corr_factor * swr

    # save in new dataset
    ds = xr.Dataset()

    ds["uwnd"] = (u10 * np.cos(angle) + v10 * np.sin(angle)).astype(
        np.float32
    )  # rotate to grid orientation
    ds["uwnd"].attrs["long_name"] = "10 meter wind in x-direction"
    ds["uwnd"].attrs["units"] = "m/s"

    ds["vwnd"] = (v10 * np.cos(angle) - u10 * np.sin(angle)).astype(
        np.float32
    )  # rotate to grid orientation
    ds["vwnd"].attrs["long_name"] = "10 meter wind in y-direction"
    ds["vwnd"].attrs["units"] = "m/s"

    ds["swrad"] = swr.astype(np.float32)
    ds["swrad"].attrs["long_name"] = "Downward short-wave (solar) radiation"
    ds["swrad"].attrs["units"] = "W/m^2"

    ds["lwrad"] = lwr.astype(np.float32)
    ds["lwrad"].attrs["long_name"] = "Downward long-wave (thermal) radiation"
    ds["lwrad"].attrs["units"] = "W/m^2"

    ds["Tair"] = t2m.astype(np.float32)
    ds["Tair"].attrs["long_name"] = "Air temperature at 2m"
    ds["Tair"].attrs["units"] = "degrees C"

    ds["qair"] = qair.astype(np.float32)
    ds["qair"].attrs["long_name"] = "Absolute humidity at 2m"
    ds["qair"].attrs["units"] = "kg/kg"

    ds["rain"] = rain.astype(np.float32)
    ds["rain"].attrs["long_name"] = "Total precipitation"
    ds["rain"].attrs["units"] = "cm/day"

    ds.attrs["Title"] = "ROMS bulk surface forcing file produced by roms-tools"

    ds = ds.assign_coords({"lon": lon, "lat": lat})
    if dims["time"] != "time":
        ds = ds.rename({dims["time"]: "time"})
    if self.use_coarse_grid:
        ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})
        mask_roms = self.grid.ds["mask_coarse"].rename(
            {"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"}
        )
    else:
        mask_roms = self.grid.ds["mask_rho"]

    # Frozen dataclass: bypass __setattr__ to store the result
    object.__setattr__(self, "ds", ds)

    self.nan_check(mask_roms, time=0)
|
|
719
|
-
|
|
720
|
-
@staticmethod
def _interpolate(field, mask, coords, method="linear"):
    """
    Interpolate ``field`` onto new coordinates without mixing land and ocean values.

    Parameters
    ----------
    field : xr.DataArray
        The data array to be interpolated.

    mask : xr.DataArray
        Array sharing the spatial dimensions of ``field``; ``1`` marks wet
        (ocean) points and ``0`` marks land points.

    coords : dict
        Target coordinates for interpolation, keyed by dimension names of
        ``field`` (e.g. {"longitude": lon_values, "latitude": lat_values}).

    method : str, optional, default='linear'
        Any interpolation method accepted by `xarray.DataArray.interp`.

    Returns
    -------
    xr.DataArray
        The interpolated data array, with the source coordinates named in
        ``coords`` dropped.

    Notes
    -----
    Land values are first set to NaN according to ``mask``, then refilled
    from ocean values via `lateral_fill`. Doing this before interpolating
    keeps values from across the land-ocean boundary out of the result.
    """
    spatial_dims = [*coords]

    # Blank out land points, then fill them from the surrounding ocean values.
    ocean_only = field.where(mask)
    filled = lateral_fill(ocean_only, 1 - mask, spatial_dims)

    # Interpolate and drop the coordinates that were interpolated over.
    return filled.interp(**coords, method=method).drop_vars(spatial_dims)
|
|
766
|
-
|
|
767
|
-
def nan_check(self, mask, time=0) -> None:
    """
    Raise if any variable of ``self.ds`` has NaNs over the ocean at one time step.

    Parameters
    ----------
    mask : array-like
        Wet-point mask; points where ``mask == 1`` are checked, all other
        points are ignored.
    time : int
        Index along the ``time`` dimension to check. Default is 0.

    Raises
    ------
    ValueError
        On the first variable (in ``data_vars`` order) that contains a NaN at
        a wet point for the given time step.
    """
    for var, values in self.ds.data_vars.items():
        snapshot = values.isel(time=time)
        # Replace dry points with 0 so that only wet-point NaNs remain.
        wet_only = xr.where(mask == 1, snapshot, 0)
        if not wet_only.isnull().any().values:
            continue
        raise ValueError(
            f"NaN values found in the variable '{var}' at time step {time} over the ocean. This likely "
            "occurs because the ROMS grid, including a small safety margin for interpolation, is not "
            "fully contained within the dataset's longitude/latitude range. Please ensure that the "
            "dataset covers the entire area required by the ROMS grid."
        )
|
|
794
|
-
|
|
795
|
-
def plot(self, varname, time=0) -> None:
    """
    Plot the specified atmospheric forcing field for a given time slice.

    Parameters
    ----------
    varname : str
        The name of the atmospheric forcing field to plot. Options include:
        - "uwnd": 10 meter wind in x-direction.
        - "vwnd": 10 meter wind in y-direction.
        - "swrad": Downward short-wave (solar) radiation.
        - "lwrad": Downward long-wave (thermal) radiation.
        - "Tair": Air temperature at 2m.
        - "qair": Absolute humidity at 2m.
        - "rain": Total precipitation.
        Any data variable of ``self.ds`` is accepted.
    time : int, optional
        The time index to plot. Default is 0, which corresponds to the first
        time slice.

    Returns
    -------
    None
        This method does not return any value. It generates and displays a plot.

    Raises
    ------
    ValueError
        If the specified varname is not a data variable of the forcing dataset.
    NotImplementedError
        If the domain reaches beyond +/-89 degrees latitude (contains a pole).

    Examples
    --------
    >>> atm_forcing = AtmosphericForcing(
    ...     grid=grid_info,
    ...     start_time=start_time,
    ...     end_time=end_time,
    ...     source="era5",
    ...     filename="atmospheric_data_*.nc",
    ...     swr_correction=swr_correction,
    ... )
    >>> atm_forcing.plot("uwnd", time=0)
    """
    # Validate before importing the plotting stack so an invalid name raises the
    # documented ValueError (previously self.ds[varname] raised a KeyError).
    if varname not in self.ds.data_vars:
        valid = ", ".join(f'"{name}"' for name in self.ds.data_vars)
        raise ValueError(
            f"Invalid varname '{varname}'. Valid options are: {valid}."
        )

    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    lon_deg = self.ds.lon
    lat_deg = self.ds.lat

    # check if North or South pole are in domain
    if lat_deg.max().values > 89 or lat_deg.min().values < -89:
        raise NotImplementedError(
            "Plotting the atmospheric forcing is not implemented for the case that the domain contains the North or South pole."
        )

    if self.grid.straddle:
        # Map longitudes > 180 to negative values. Presumably `straddle` flags a
        # domain crossing the 0-degree meridian, which this makes contiguous.
        lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)

    # Define projections: data are given on a lon/lat grid (PlateCarree) and
    # displayed in a perspective view centred on the domain.
    proj = ccrs.PlateCarree()

    trans = ccrs.NearsidePerspective(
        central_longitude=lon_deg.mean().values,
        central_latitude=lat_deg.mean().values,
    )

    lon_deg = lon_deg.values
    lat_deg = lat_deg.values

    # find corners
    (lo1, la1) = (lon_deg[0, 0], lat_deg[0, 0])
    (lo2, la2) = (lon_deg[0, -1], lat_deg[0, -1])
    (lo3, la3) = (lon_deg[-1, -1], lat_deg[-1, -1])
    (lo4, la4) = (lon_deg[-1, 0], lat_deg[-1, 0])

    # transform coordinates to projected space
    lo1t, la1t = trans.transform_point(lo1, la1, proj)
    lo2t, la2t = trans.transform_point(lo2, la2, proj)
    lo3t, la3t = trans.transform_point(lo3, la3, proj)
    lo4t, la4t = trans.transform_point(lo4, la4, proj)

    plt.figure(figsize=(10, 10))
    ax = plt.axes(projection=trans)

    # Outline the domain boundary.
    ax.plot(
        [lo1t, lo2t, lo3t, lo4t, lo1t],
        [la1t, la2t, la3t, la4t, la1t],
        "go-",
    )

    ax.coastlines(
        resolution="50m", linewidth=0.5, color="black"
    )  # add map of coastlines
    ax.gridlines()

    field = self.ds[varname].isel(time=time).compute()
    if varname in ["uwnd", "vwnd"]:
        # Wind components are signed: use a symmetric, diverging colour scale.
        vmax = max(field.max().values, -field.min().values)
        vmin = -vmax
        cmap = "RdBu_r"
    else:
        vmax = field.max().values
        vmin = field.min().values
        if varname in ["swrad", "lwrad", "Tair", "qair"]:
            cmap = "YlOrRd"
        else:
            cmap = "YlGnBu"

    p = ax.pcolormesh(
        lon_deg, lat_deg, field, transform=proj, vmax=vmax, vmin=vmin, cmap=cmap
    )
    plt.colorbar(p, label=field.units)
    timestamp = np.datetime_as_string(field.time.values, unit="s")
    ax.set_title(f"{field.long_name} at time {timestamp}")
    plt.show()
|
|
912
|
-
|
|
913
|
-
def save(self, filepath: str, time_chunk_size: int = 1) -> None:
    """
    Save the interpolated atmospheric forcing fields to netCDF4 files.

    The dataset is split into one subset per calendar month (grouped by year,
    then month). Each subset is chunked along time and written to its own
    netCDF4 file. The file name encodes the year and month, plus the day range
    when the subset does not cover the complete month. All writes are prepared
    lazily and then executed together with a single ``dask.compute`` call.

    Before writing, the original datetime coordinate is copied to a ``Time``
    variable, and ``time`` is converted to (fractional) days since
    ``self.model_reference_date``.

    Parameters
    ----------
    filepath : str
        The base path and filename for the output files. The files will be named with
        the format "filepath.YYYYMM.nc" if a full month of data is included, or
        "filepath.YYYYMMDD-DD.nc" otherwise.
    time_chunk_size : int, optional
        Number of time slices to include in each chunk along the time dimension. Default is 1,
        meaning each chunk contains one time slice.

    Returns
    -------
    None
    """
    # Reference date for the ROMS time axis; identical for every output file.
    model_reference_date = np.datetime64(self.model_reference_date)

    datasets = []
    filenames = []

    for year, yearly_ds in self.ds.groupby("time.year"):
        for month, ds in yearly_ds.groupby("time.month"):
            # Chunk the dataset by the specified time chunk size
            ds = ds.chunk({"time": time_chunk_size})

            num_days_in_month = calendar.monthrange(year, month)[1]
            first_day = ds.time.dt.day.values[0]
            last_day = ds.time.dt.day.values[-1]

            if first_day == 1 and last_day == num_days_in_month:
                # Full month format: "filepath.YYYYMM.nc"
                filename = f"{filepath}.{year}{month:02}.nc"
            else:
                # Partial month format: "filepath.YYYYMMDD-DD.nc"
                filename = f"{filepath}.{year}{month:02}{first_day:02}-{last_day:02}.nc"

            datasets.append(ds)
            filenames.append(filename)

    print("Saving the following files:")
    for filename in filenames:
        print(filename)

    writes = []
    for ds, filename in zip(datasets, filenames):
        # Preserve the original time coordinate for readability
        ds["Time"] = ds["time"]

        # Convert to days since the model reference date (the format expected by
        # ROMS). Dividing by a one-day timedelta works for any datetime
        # resolution, unlike casting to float and assuming nanoseconds.
        ds["time"] = (ds["time"] - model_reference_date) / np.timedelta64(1, "D")
        ds["time"].attrs[
            "long_name"
        ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"

        # Prepare the write lazily so that all files can be written in parallel.
        writes.append(ds.to_netcdf(filename, compute=False))

    # Perform the actual write operations in parallel
    dask.compute(*writes)
|