roms-tools 0.1.0__py3-none-any.whl → 0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment.yml +1 -0
- roms_tools/__init__.py +3 -0
- roms_tools/_version.py +1 -1
- roms_tools/setup/atmospheric_forcing.py +335 -393
- roms_tools/setup/boundary_forcing.py +711 -0
- roms_tools/setup/datasets.py +434 -25
- roms_tools/setup/fill.py +118 -5
- roms_tools/setup/grid.py +145 -19
- roms_tools/setup/initial_conditions.py +528 -0
- roms_tools/setup/plot.py +149 -4
- roms_tools/setup/tides.py +570 -437
- roms_tools/setup/topography.py +17 -2
- roms_tools/setup/utils.py +162 -0
- roms_tools/setup/vertical_coordinate.py +494 -0
- roms_tools/tests/test_atmospheric_forcing.py +1645 -0
- roms_tools/tests/test_boundary_forcing.py +332 -0
- roms_tools/tests/test_datasets.py +306 -0
- roms_tools/tests/test_grid.py +226 -0
- roms_tools/tests/test_initial_conditions.py +300 -0
- roms_tools/tests/test_tides.py +366 -0
- roms_tools/tests/test_topography.py +78 -0
- roms_tools/tests/test_vertical_coordinate.py +337 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/METADATA +3 -2
- roms_tools-0.20.dist-info/RECORD +28 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/WHEEL +1 -1
- roms_tools/tests/test_setup.py +0 -181
- roms_tools-0.1.0.dist-info/RECORD +0 -17
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/LICENSE +0 -0
- {roms_tools-0.1.0.dist-info → roms_tools-0.20.dist-info}/top_level.txt +0 -0
roms_tools/setup/atmospheric_forcing.py

@@ -1,191 +1,19 @@
 import xarray as xr
 import dask
-
+import yaml
+import importlib.metadata
+from dataclasses import dataclass, field, asdict
 from roms_tools.setup.grid import Grid
 from datetime import datetime
 import glob
 import numpy as np
 from typing import Optional, Dict
-from roms_tools.setup.fill import lateral_fill
+from roms_tools.setup.fill import fill_and_interpolate
+from roms_tools.setup.datasets import Dataset
+from roms_tools.setup.utils import nan_check
+from roms_tools.setup.plot import _plot
 import calendar
-
-
-@dataclass(frozen=True, kw_only=True)
-class ForcingDataset:
-    """
-    Represents forcing data on the original grid.
-
-    Parameters
-    ----------
-    filename : str
-        The path to the data files. Can contain wildcards.
-    start_time : datetime
-        The start time for selecting relevant data.
-    end_time : datetime
-        The end time for selecting relevant data.
-    dim_names : Dict[str, str], optional
-        Dictionary specifying the names of dimensions in the dataset.
-
-    Attributes
-    ----------
-    ds : xr.Dataset
-        The xarray Dataset containing the forcing data on its original grid.
-
-    Examples
-    --------
-    >>> dataset = ForcingDataset(
-    ...     filename="data.nc",
-    ...     start_time=datetime(2022, 1, 1),
-    ...     end_time=datetime(2022, 12, 31),
-    ... )
-    >>> dataset.load_data()
-    >>> print(dataset.ds)
-    <xarray.Dataset>
-    Dimensions: ...
-    """
-
-    filename: str
-    start_time: datetime
-    end_time: datetime
-    dim_names: Dict[str, str] = field(
-        default_factory=lambda: {"longitude": "lon", "latitude": "lat", "time": "time"}
-    )
-
-    ds: xr.Dataset = field(init=False, repr=False)
-
-    def __post_init__(self):
-
-        ds = self.load_data()
-
-        # Select relevant times
-        times = (np.datetime64(self.start_time) < ds[self.dim_names["time"]]) & (
-            ds[self.dim_names["time"]] < np.datetime64(self.end_time)
-        )
-        ds = ds.where(times, drop=True)
-
-        # Make sure that latitude is ascending
-        diff = np.diff(ds[self.dim_names["latitude"]])
-        if np.all(diff < 0):
-            ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
-
-        object.__setattr__(self, "ds", ds)
-
-    def load_data(self) -> xr.Dataset:
-        """
-        Load dataset from the specified file.
-
-        Returns
-        -------
-        ds : xr.Dataset
-            The loaded xarray Dataset containing the forcing data.
-
-        Raises
-        ------
-        FileNotFoundError
-            If the specified file does not exist.
-        """
-
-        # Check if the file exists
-        matching_files = glob.glob(self.filename)
-        if not matching_files:
-            raise FileNotFoundError(
-                f"No files found matching the pattern '{self.filename}'."
-            )
-
-        # Load the dataset
-        with dask.config.set(**{"array.slicing.split_large_chunks": False}):
-            # initially, we want a time chunk size of 1 to enable quick .nan_check() and .plot() methods for AtmosphericForcing
-            ds = xr.open_mfdataset(
-                self.filename,
-                combine="nested",
-                concat_dim=self.dim_names["time"],
-                coords="minimal",
-                compat="override",
-                chunks={self.dim_names["time"]: 1},
-            )
-
-        return ds
-
-    def choose_subdomain(self, latitude_range, longitude_range, margin, straddle):
-        """
-        Selects a subdomain from the given xarray Dataset based on latitude and longitude ranges,
-        extending the selection by the specified margin. Handles the conversion of longitude values
-        in the dataset from one range to another.
-
-        Parameters
-        ----------
-        latitude_range : tuple
-            A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
-        longitude_range : tuple
-            A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
-        margin : float
-            Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
-        straddle : bool
-            If True, target longitudes are expected in the range [-180, 180].
-            If False, target longitudes are expected in the range [0, 360].
-
-        Returns
-        -------
-        xr.Dataset
-            The subset of the original dataset representing the chosen subdomain, including an extended area
-            to cover one extra grid point beyond the specified ranges.
-
-        Raises
-        ------
-        ValueError
-            If the selected latitude or longitude range does not intersect with the dataset.
-        """
-        lat_min, lat_max = latitude_range
-        lon_min, lon_max = longitude_range
-
-        lon = self.ds[self.dim_names["longitude"]]
-        # Adjust longitude range if needed to match the expected range
-        if not straddle:
-            if lon.min() < -180:
-                if lon_max + margin > 0:
-                    lon_min -= 360
-                    lon_max -= 360
-            elif lon.min() < 0:
-                if lon_max + margin > 180:
-                    lon_min -= 360
-                    lon_max -= 360
-
-        if straddle:
-            if lon.max() > 360:
-                if lon_min - margin < 180:
-                    lon_min += 360
-                    lon_max += 360
-            elif lon.max() > 180:
-                if lon_min - margin < 0:
-                    lon_min += 360
-                    lon_max += 360
-
-        # Select the subdomain
-        subdomain = self.ds.sel(
-            **{
-                self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
-                self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
-            }
-        )
-
-        # Check if the selected subdomain has zero dimensions in latitude or longitude
-        if subdomain[self.dim_names["latitude"]].size == 0:
-            raise ValueError("Selected latitude range does not intersect with dataset.")
-
-        if subdomain[self.dim_names["longitude"]].size == 0:
-            raise ValueError(
-                "Selected longitude range does not intersect with dataset."
-            )
-
-        # Adjust longitudes to expected range if needed
-        lon = subdomain[self.dim_names["longitude"]]
-        if straddle:
-            subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
-        else:
-            subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
-
-        # Set the modified subdomain to the object attribute
-        object.__setattr__(self, "ds", subdomain)
+import matplotlib.pyplot as plt


 @dataclass(frozen=True, kw_only=True)
@@ -445,6 +273,43 @@ class SWRCorrection:

         return field_interpolated

+    @classmethod
+    def from_yaml(cls, filepath: str) -> "SWRCorrection":
+        """
+        Create an instance of the class from a YAML file.
+
+        Parameters
+        ----------
+        filepath : str
+            The path to the YAML file from which the parameters will be read.
+
+        Returns
+        -------
+        SWRCorrection
+            An instance of the SWRCorrection class.
+        """
+        # Read the entire file content
+        with open(filepath, "r") as file:
+            file_content = file.read()
+
+        # Split the content into YAML documents
+        documents = list(yaml.safe_load_all(file_content))
+
+        swr_correction_data = None
+
+        # Iterate over documents to find the SWRCorrection configuration
+        for doc in documents:
+            if doc is None:
+                continue
+            if "SWRCorrection" in doc:
+                swr_correction_data = doc["SWRCorrection"]
+                break
+
+        if swr_correction_data is None:
+            raise ValueError("No SWRCorrection configuration found in the YAML file.")
+
+        return cls(**swr_correction_data)
+

 @dataclass(frozen=True, kw_only=True)
 class Rivers:
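Both `from_yaml` helpers scan every document in the file because `to_yaml` (added near the end of this diff) writes the `roms_tools_version` header as a separate YAML document ahead of the configuration. A minimal, self-contained sketch of the layout being parsed; the section names match this diff, while the parameter names and values are invented:

```python
# Sketch of the multi-document YAML layout that from_yaml() iterates over.
# Section names match this diff; parameter names/values are invented.
import yaml

file_content = """\
---
roms_tools_version: 0.20
---
Grid:
  nx: 100
  ny: 100
SWRCorrection:
  filename: swr_corr.nc
"""

for doc in yaml.safe_load_all(file_content):
    if doc and "SWRCorrection" in doc:
        print(doc["SWRCorrection"])  # {'filename': 'swr_corr.nc'}
```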
@@ -463,6 +328,43 @@ class Rivers:

         if not self.filename:
             raise ValueError("The 'filename' must be provided.")

+    @classmethod
+    def from_yaml(cls, filepath: str) -> "Rivers":
+        """
+        Create an instance of the class from a YAML file.
+
+        Parameters
+        ----------
+        filepath : str
+            The path to the YAML file from which the parameters will be read.
+
+        Returns
+        -------
+        Rivers
+            An instance of the Rivers class.
+        """
+        # Read the entire file content
+        with open(filepath, "r") as file:
+            file_content = file.read()
+
+        # Split the content into YAML documents
+        documents = list(yaml.safe_load_all(file_content))
+
+        rivers_data = None
+
+        # Iterate over documents to find the Rivers configuration
+        for doc in documents:
+            if doc is None:
+                continue
+            if "Rivers" in doc:
+                rivers_data = doc["Rivers"]
+                break
+
+        if rivers_data is None:
+            raise ValueError("No Rivers configuration found in the YAML file.")
+
+        return cls(**rivers_data)
+

 @dataclass(frozen=True, kw_only=True)
 class AtmosphericForcing:
@@ -482,7 +384,7 @@ class AtmosphericForcing:

     model_reference_date : datetime, optional
         Reference date for the model. Default is January 1, 2000.
     source : str, optional
-        Source of the atmospheric forcing data. Default is "era5".
+        Source of the atmospheric forcing data. Default is "ERA5".
     filename : str
         Path to the atmospheric forcing source data file. Can contain wildcards.
     swr_correction : SWRCorrection
@@ -495,10 +397,6 @@ class AtmosphericForcing:

     ds : xr.Dataset
         Xarray Dataset containing the atmospheric forcing data.

-    Notes
-    -----
-    This class represents atmospheric forcing data used in ocean modeling. It provides a convenient
-    interface to work with forcing data including shortwave radiation correction and river forcing.

     Examples
     --------
@@ -509,7 +407,7 @@ class AtmosphericForcing:

     ...     grid=grid_info,
     ...     start_time=start_time,
     ...     end_time=end_time,
-    ...     source="era5",
+    ...     source="ERA5",
     ...     filename="atmospheric_data_*.nc",
     ...     swr_correction=swr_correction,
     ... )
@@ -520,7 +418,7 @@ class AtmosphericForcing:

     start_time: datetime
     end_time: datetime
     model_reference_date: datetime = datetime(2000, 1, 1)
-    source: str = "era5"
+    source: str = "ERA5"
     filename: str
     swr_correction: Optional["SWRCorrection"] = None
     rivers: Optional["Rivers"] = None
@@ -528,6 +426,30 @@ class AtmosphericForcing:


     def __post_init__(self):

+        # Check that the source is "ERA5"
+        if self.source != "ERA5":
+            raise ValueError('Only "ERA5" is a valid option for source.')
+        if self.source == "ERA5":
+            dims = {"longitude": "longitude", "latitude": "latitude", "time": "time"}
+            varnames = {
+                "u10": "u10",
+                "v10": "v10",
+                "swr": "ssr",
+                "lwr": "strd",
+                "t2m": "t2m",
+                "d2m": "d2m",
+                "rain": "tp",
+                "mask": "sst",
+            }
+
+        data = Dataset(
+            filename=self.filename,
+            start_time=self.start_time,
+            end_time=self.end_time,
+            var_names=varnames.values(),
+            dim_names=dims,
+        )
+
         if self.use_coarse_grid:
             if "lon_coarse" not in self.grid.ds:
                 raise ValueError(
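The `varnames` dictionary maps the internal names used in this module to ERA5's standard short names: `ssr` (surface net solar radiation), `strd` (surface thermal radiation downwards), `tp` (total precipitation), `d2m` (2 m dewpoint temperature); `sst` is carried along only to derive the land-sea mask. A sketch of applying the same mapping by hand, with a hypothetical file name:

```python
# Sketch: renaming ERA5 short names to the internal names used above.
# "era5_sample.nc" is a hypothetical file; varnames mirrors __post_init__.
import xarray as xr

varnames = {
    "u10": "u10", "v10": "v10", "swr": "ssr", "lwr": "strd",
    "t2m": "t2m", "d2m": "d2m", "rain": "tp", "mask": "sst",
}

ds = xr.open_dataset("era5_sample.nc")
ds = ds.rename({era5: internal for internal, era5 in varnames.items()})
```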
@@ -542,16 +464,6 @@ class AtmosphericForcing:

         lat = self.grid.ds.lat_rho
         angle = self.grid.ds.angle

-        if self.source == "era5":
-            dims = {"longitude": "longitude", "latitude": "latitude", "time": "time"}
-
-            data = ForcingDataset(
-                filename=self.filename,
-                start_time=self.start_time,
-                end_time=self.end_time,
-                dim_names=dims,
-            )
-
         # operate on longitudes between -180 and 180 unless the ROMS domain lies at least 5 degrees in longitude away from the Greenwich meridian
         lon = xr.where(lon > 180, lon - 360, lon)
         straddle = True
@@ -559,9 +471,12 @@ class AtmosphericForcing:

             lon = xr.where(lon < 0, lon + 360, lon)
             straddle = False

+        # The following consists of two steps:
         # Step 1: Choose subdomain of forcing data including safety margin for interpolation, and Step 2: Convert to the proper longitude range.
-        #
-        #
+        # We perform these two steps for two reasons:
+        # A) Since the horizontal dimensions consist of a single chunk, selecting a subdomain before interpolation is a lot more performant.
+        # B) Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2. Discontinuous longitudes
+        # can lead to artifacts in the interpolation process: if the data is not global and thus has a gap,
         # discontinuous longitudes could result in values that appear to come from a distant location instead of producing NaNs.
         # These NaNs are important as they can be identified and handled appropriately by the nan_check function.
         data.choose_subdomain(
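The two longitude conventions referred to in these comments can be converted back and forth with `xr.where`, exactly as the surrounding code does; a standalone illustration with arbitrary values:

```python
# Converting between the [0, 360) and [-180, 180) longitude conventions.
import numpy as np
import xarray as xr

lon = xr.DataArray(np.array([350.0, 10.0, 181.0]))
lon_straddle = xr.where(lon > 180, lon - 360, lon)  # [-10., 10., -179.]
lon_0_360 = xr.where(lon_straddle < 0, lon_straddle + 360, lon_straddle)  # [350., 10., 181.]
```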
@@ -572,42 +487,33 @@ class AtmosphericForcing:

         )

         # interpolate onto desired grid
-        if self.source == "era5":
-            mask = xr.where(data.ds["sst"].isel(time=0).isnull(), 0, 1)
-            varnames = {
-                "u10": "u10",
-                "v10": "v10",
-                "swr": "ssr",
-                "lwr": "strd",
-                "t2m": "t2m",
-                "d2m": "d2m",
-                "rain": "tp",
-            }
-
         coords = {dims["latitude"]: lat, dims["longitude"]: lon}
-            u10 = self._interpolate(
-                data.ds[varnames["u10"]], mask, coords=coords, method="linear"
-            )
-            v10 = self._interpolate(
-                data.ds[varnames["v10"]], mask, coords=coords, method="linear"
-            )
-            swr = self._interpolate(
-                data.ds[varnames["swr"]], mask, coords=coords, method="linear"
-            )
-            lwr = self._interpolate(
-                data.ds[varnames["lwr"]], mask, coords=coords, method="linear"
-            )
-            t2m = self._interpolate(
-                data.ds[varnames["t2m"]], mask, coords=coords, method="linear"
-            )
-            d2m = self._interpolate(
-                data.ds[varnames["d2m"]], mask, coords=coords, method="linear"
-            )
-            rain = self._interpolate(
-                data.ds[varnames["rain"]], mask, coords=coords, method="linear"
-            )

-
+        data_vars = {}
+
+        mask = xr.where(data.ds[varnames["mask"]].isel(time=0).isnull(), 0, 1)
+
+        # Fill and interpolate each variable
+        for var in varnames.keys():
+            if var != "mask":
+                data_vars[var] = fill_and_interpolate(
+                    data.ds[varnames[var]],
+                    mask,
+                    list(coords.keys()),
+                    coords,
+                    method="linear",
+                )
+
+        # Access the interpolated variables using the data_vars dictionary
+        u10 = data_vars["u10"]
+        v10 = data_vars["v10"]
+        swr = data_vars["swr"]
+        lwr = data_vars["lwr"]
+        t2m = data_vars["t2m"]
+        d2m = data_vars["d2m"]
+        rain = data_vars["rain"]
+
+        if self.source == "ERA5":
             # translate radiation to fluxes. ERA5 stores values integrated over 1 hour.
             swr = swr / 3600  # from J/m^2 to W/m^2
             lwr = lwr / 3600  # from J/m^2 to W/m^2
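`fill_and_interpolate` replaces the removed `_interpolate` staticmethod (visible further down in this diff). Judging from that removed code and the arguments passed here (field, mask, fill dimensions, target coordinates, method), it presumably follows the same mask, lateral-fill, interpolate pattern. A rough, self-contained sketch of that pattern, not the helper's actual implementation:

```python
# Rough sketch of the mask -> lateral fill -> interpolate pattern; the real
# helper lives in roms_tools/setup/fill.py and its internals may differ.
import xarray as xr

def mask_fill_interp(field, mask, dims, coords, method="linear"):
    field = field.where(mask)  # set land values to NaN
    # (the real helper then propagates ocean values laterally into the land
    #  interior, so interpolation never mixes land and ocean values)
    return field.interp(**coords, method=method).drop_vars(dims)
```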
@@ -650,8 +556,12 @@ class AtmosphericForcing:

                 self.swr_correction.dim_names["latitude"]: lat,
                 self.swr_correction.dim_names["longitude"]: lon,
             }
-            corr_factor = self._interpolate(
-                corr_factor,
+            corr_factor = fill_and_interpolate(
+                corr_factor,
+                mask,
+                list(coords_correction.keys()),
+                coords_correction,
+                method="linear",
             )

             # temporal interpolation
@@ -700,11 +610,22 @@ class AtmosphericForcing:

         ds["rain"].attrs["long_name"] = "Total precipitation"
         ds["rain"].attrs["units"] = "cm/day"

-        ds.attrs["
+        ds.attrs["title"] = "ROMS atmospheric forcing file created by ROMS-Tools"
+        # Include the version of roms-tools
+        try:
+            roms_tools_version = importlib.metadata.version("roms-tools")
+        except importlib.metadata.PackageNotFoundError:
+            roms_tools_version = "unknown"
+        ds.attrs["roms_tools_version"] = roms_tools_version
+        ds.attrs["start_time"] = str(self.start_time)
+        ds.attrs["end_time"] = str(self.end_time)
+        ds.attrs["model_reference_date"] = str(self.model_reference_date)
+        ds.attrs["source"] = self.source
+        ds.attrs["use_coarse_grid"] = str(self.use_coarse_grid)
+        ds.attrs["swr_correction"] = str(self.swr_correction is not None)
+        ds.attrs["rivers"] = str(self.rivers is not None)

         ds = ds.assign_coords({"lon": lon, "lat": lat})
-        if dims["time"] != "time":
-            ds = ds.rename({dims["time"]: "time"})
         if self.use_coarse_grid:
             ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})
             mask_roms = self.grid.ds["mask_coarse"].rename(
@@ -713,84 +634,28 @@ class AtmosphericForcing:

         else:
             mask_roms = self.grid.ds["mask_rho"]

-
-
-        self.nan_check(mask_roms, time=0)
-
-    @staticmethod
-    def _interpolate(field, mask, coords, method="linear"):
-        """
-        Interpolate a field using specified coordinates and a given method.
-
-        Parameters
-        ----------
-        field : xr.DataArray
-            The data array to be interpolated.
-
-        mask : xr.DataArray
-            A data array with same spatial dimensions as `field`, where `1` indicates wet (ocean)
-            points and `0` indicates land points.
-
-        coords : dict
-            A dictionary specifying the target coordinates for interpolation. The keys
-            should match the dimensions of `field` (e.g., {"longitude": lon_values, "latitude": lat_values}).
-
-        method : str, optional, default='linear'
-            The interpolation method to use. Valid options are those supported by
-            `xarray.DataArray.interp`.
-
-        Returns
-        -------
-        xr.DataArray
-            The interpolated data array.
-
-        Notes
-        -----
-        This method first sets land values to NaN based on the provided mask. It then uses the
-        `lateral_fill` function to propagate ocean values. These two steps serve to
-        avoid interpolation across the land-ocean boundary. Finally, it performs interpolation
-        over the specified coordinates.
-
-        """
-
-        dims = list(coords.keys())
-
-        # set land values to nan
-        field = field.where(mask)
-        # propagate ocean values into land interior before interpolation
-        field = lateral_fill(field, 1 - mask, dims)
-        # interpolate
-        field_interpolated = field.interp(**coords, method=method).drop_vars(dims)
-
-        return field_interpolated
+        if dims["time"] != "time":
+            ds = ds.rename({dims["time"]: "time"})

-
-        """
-        Checks for NaN values at wet points in all variables of the dataset at a specified time step.
+        # Preserve the original time coordinate for readability
+        ds = ds.assign_coords({"absolute_time": ds["time"]})

-
-
-        mask : array-like
-            A boolean mask indicating the wet points in the dataset.
-        time : int
-            The time step at which to check for NaN values. Default is 0.
+        # Translate the time coordinate to days since the model reference date
+        model_reference_date = np.datetime64(self.model_reference_date)

-
-
-
-
+        # Convert the time coordinate to the format expected by ROMS (days since model reference date)
+        ds["time"] = (
+            (ds["time"] - model_reference_date).astype("float64") / 3600 / 24 * 1e-9
+        )
+        ds["time"].attrs[
+            "long_name"
+        ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
+        ds["time"].attrs["units"] = "days"

-
+        for var in ds.data_vars:
+            nan_check(ds[var].isel(time=0), mask_roms)

-
-        da = xr.where(mask == 1, self.ds[var].isel(time=time), 0)
-        if da.isnull().any().values:
-            raise ValueError(
-                f"NaN values found in the variable '{var}' at time step {time} over the ocean. This likely "
-                "occurs because the ROMS grid, including a small safety margin for interpolation, is not "
-                "fully contained within the dataset's longitude/latitude range. Please ensure that the "
-                "dataset covers the entire area required by the ROMS grid."
-            )
+        object.__setattr__(self, "ds", ds)

     def plot(self, varname, time=0) -> None:
         """
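The conversion above exploits the fact that xarray time coordinates are `datetime64[ns]`: casting a timedelta to float yields nanoseconds, so dividing by 3600 * 24 and scaling by 1e-9 gives days. A quick numeric check:

```python
# Numeric check of the nanoseconds-to-days conversion used above.
import numpy as np

ref = np.datetime64("2000-01-01", "ns")
t = np.datetime64("2000-01-03T12:00", "ns")
days = (t - ref).astype("float64") / 3600 / 24 * 1e-9
print(days)  # 2.5
```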
@@ -828,87 +693,45 @@ class AtmosphericForcing:

         ...     grid=grid_info,
         ...     start_time=start_time,
         ...     end_time=end_time,
-        ...     source="era5",
+        ...     source="ERA5",
         ...     filename="atmospheric_data_*.nc",
         ...     swr_correction=swr_correction,
         ... )
         >>> atm_forcing.plot("uwnd", time=0)
         """

-
-
-
-        lon_deg = self.ds.lon
-        lat_deg = self.ds.lat
-
-        # check if North or South pole are in domain
-        if lat_deg.max().values > 89 or lat_deg.min().values < -89:
-            raise NotImplementedError(
-                "Plotting the atmospheric forcing is not implemented for the case that the domain contains the North or South pole."
-            )
-
-        if self.grid.straddle:
-            lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
-
-        # Define projections
-        proj = ccrs.PlateCarree()
-
-        trans = ccrs.NearsidePerspective(
-            central_longitude=lon_deg.mean().values,
-            central_latitude=lat_deg.mean().values,
+        title = "%s at time %s" % (
+            self.ds[varname].long_name,
+            np.datetime_as_string(self.ds["absolute_time"].isel(time=time), unit="s"),
         )

-        lon_deg = lon_deg.values
-        lat_deg = lat_deg.values
-
-        # find corners
-        (lo1, la1) = (lon_deg[0, 0], lat_deg[0, 0])
-        (lo2, la2) = (lon_deg[0, -1], lat_deg[0, -1])
-        (lo3, la3) = (lon_deg[-1, -1], lat_deg[-1, -1])
-        (lo4, la4) = (lon_deg[-1, 0], lat_deg[-1, 0])
-
-        # transform coordinates to projected space
-        lo1t, la1t = trans.transform_point(lo1, la1, proj)
-        lo2t, la2t = trans.transform_point(lo2, la2, proj)
-        lo3t, la3t = trans.transform_point(lo3, la3, proj)
-        lo4t, la4t = trans.transform_point(lo4, la4, proj)
-
-        plt.figure(figsize=(10, 10))
-        ax = plt.axes(projection=trans)
-
-        ax.plot(
-            [lo1t, lo2t, lo3t, lo4t, lo1t],
-            [la1t, la2t, la3t, la4t, la1t],
-            "go-",
-        )
-
-        ax.coastlines(
-            resolution="50m", linewidth=0.5, color="black"
-        )  # add map of coastlines
-        ax.gridlines()
-
         field = self.ds[varname].isel(time=time).compute()
+
+        # choose colormap and limits
         if varname in ["uwnd", "vwnd"]:
             vmax = max(field.max().values, -field.min().values)
             vmin = -vmax
-            cmap = "RdBu_r"
+            cmap = plt.colormaps.get_cmap("RdBu_r")
         else:
             vmax = field.max().values
             vmin = field.min().values
             if varname in ["swrad", "lwrad", "Tair", "qair"]:
-                cmap = "YlOrRd"
+                cmap = plt.colormaps.get_cmap("YlOrRd")
             else:
-                cmap = "YlGnBu"
-
-
-
+                cmap = plt.colormaps.get_cmap("YlGnBu")
+        cmap.set_bad(color="gray")
+
+        kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
+
+        _plot(
+            self.grid.ds,
+            field=field,
+            straddle=self.grid.straddle,
+            coarse_grid=self.use_coarse_grid,
+            title=title,
+            kwargs=kwargs,
+            c="g",
         )
-        plt.colorbar(p, label=field.units)
-        ax.set_title(
-            "%s at time %s"
-            % (field.long_name, np.datetime_as_string(field.time, unit="s"))
-        )
-        plt.show()

     def save(self, filepath: str, time_chunk_size: int = 1) -> None:
         """
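Fetching `Colormap` objects instead of passing name strings is what enables the `set_bad` call: cells that are NaN (for example over land) are then drawn in gray rather than left blank. A standalone illustration of the same idiom:

```python
# NaN cells render in the colormap's "bad" color.
import matplotlib.pyplot as plt
import numpy as np

field = np.random.rand(4, 4)
field[1, 2] = np.nan  # e.g. a masked land point

cmap = plt.colormaps.get_cmap("YlGnBu")
cmap.set_bad(color="gray")
plt.pcolormesh(field, cmap=cmap)
plt.show()
```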
@@ -938,11 +761,11 @@ class AtmosphericForcing:

         writes = []

         # Group dataset by year
-        gb = self.ds.groupby("time.year")
+        gb = self.ds.groupby("absolute_time.year")

         for year, group_ds in gb:
             # Further group each yearly group by month
-            sub_gb = group_ds.groupby("time.month")
+            sub_gb = group_ds.groupby("absolute_time.month")

             for month, ds in sub_gb:
                 # Chunk the dataset by the specified time chunk size
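Since `time` now holds float days rather than datetimes, the grouping goes through the preserved `absolute_time` coordinate. A minimal sketch of the same year/month split, assuming a `datetime64` coordinate named `absolute_time` on the `time` dimension:

```python
# Minimal sketch of the year/month grouping used by save().
import numpy as np
import xarray as xr

abs_time = np.arange("2020-12-30", "2021-01-03", dtype="datetime64[D]").astype("datetime64[ns]")
ds = xr.Dataset(
    {"u10": ("time", np.zeros(abs_time.size))},
    coords={"time": np.arange(abs_time.size, dtype="float64"),
            "absolute_time": ("time", abs_time)},
)

for year, yearly in ds.groupby("absolute_time.year"):
    for month, monthly in yearly.groupby("absolute_time.month"):
        print(year, month, monthly.sizes["time"])  # 2020 12 2, then 2021 1 2
```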
@@ -951,8 +774,8 @@ class AtmosphericForcing:


                 # Determine the number of days in the month
                 num_days_in_month = calendar.monthrange(year, month)[1]
-                first_day = ds.time.dt.day.values[0]
-                last_day = ds.time.dt.day.values[-1]
+                first_day = ds.time.absolute_time.dt.day.values[0]
+                last_day = ds.time.absolute_time.dt.day.values[-1]

                 # Create filename based on whether the dataset contains a full month
                 if first_day == 1 and last_day == num_days_in_month:
@@ -971,23 +794,142 @@ class AtmosphericForcing:


         for ds, filename in zip(datasets, filenames):

-            # Translate the time coordinate to days since the model reference date
-            model_reference_date = np.datetime64(self.model_reference_date)
-
-            # Preserve the original time coordinate for readability
-            ds["Time"] = ds["time"]
-
-            # Convert the time coordinate to the format expected by ROMS (days since model reference date)
-            ds["time"] = (
-                (ds["time"] - model_reference_date).astype("float64") / 3600 / 24 * 1e-9
-            )
-            ds["time"].attrs[
-                "long_name"
-            ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
-
             # Prepare the dataset for writing to a netCDF file without immediately computing
             write = ds.to_netcdf(filename, compute=False)
             writes.append(write)

         # Perform the actual write operations in parallel
-        dask.
+        dask.persist(*writes)
+
+    def to_yaml(self, filepath: str) -> None:
+        """
+        Export the parameters of the class to a YAML file, including the version of roms-tools.
+
+        Parameters
+        ----------
+        filepath : str
+            The path to the YAML file where the parameters will be saved.
+        """
+        # Serialize Grid data
+        grid_data = asdict(self.grid)
+        grid_data.pop("ds", None)  # Exclude non-serializable fields
+        grid_data.pop("straddle", None)
+
+        if self.swr_correction:
+            swr_correction_data = asdict(self.swr_correction)
+            swr_correction_data.pop("ds", None)
+        else:
+            swr_correction_data = None
+
+        rivers_data = asdict(self.rivers) if self.rivers else None
+
+        # Include the version of roms-tools
+        try:
+            roms_tools_version = importlib.metadata.version("roms-tools")
+        except importlib.metadata.PackageNotFoundError:
+            roms_tools_version = "unknown"
+
+        # Create header
+        header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
+
+        # Create YAML data for Grid and optional attributes
+        grid_yaml_data = {"Grid": grid_data}
+        swr_correction_yaml_data = (
+            {"SWRCorrection": swr_correction_data} if swr_correction_data else {}
+        )
+        rivers_yaml_data = {"Rivers": rivers_data} if rivers_data else {}
+
+        # Combine all sections
+        atmospheric_forcing_data = {
+            "AtmosphericForcing": {
+                "filename": self.filename,
+                "start_time": self.start_time.isoformat(),
+                "end_time": self.end_time.isoformat(),
+                "model_reference_date": self.model_reference_date.isoformat(),
+                "source": self.source,
+                "use_coarse_grid": self.use_coarse_grid,
+            }
+        }
+
+        # Merge YAML data while excluding empty sections
+        yaml_data = {
+            **grid_yaml_data,
+            **swr_correction_yaml_data,
+            **rivers_yaml_data,
+            **atmospheric_forcing_data,
+        }
+
+        with open(filepath, "w") as file:
+            # Write header
+            file.write(header)
+            # Write YAML data
+            yaml.dump(yaml_data, file, default_flow_style=False)
+
+    @classmethod
+    def from_yaml(cls, filepath: str) -> "AtmosphericForcing":
+        """
+        Create an instance of the AtmosphericForcing class from a YAML file.
+
+        Parameters
+        ----------
+        filepath : str
+            The path to the YAML file from which the parameters will be read.
+
+        Returns
+        -------
+        AtmosphericForcing
+            An instance of the AtmosphericForcing class.
+        """
+        # Read the entire file content
+        with open(filepath, "r") as file:
+            file_content = file.read()
+
+        # Split the content into YAML documents
+        documents = list(yaml.safe_load_all(file_content))
+
+        swr_correction_data = None
+        rivers_data = None
+        atmospheric_forcing_data = None
+
+        # Process the YAML documents
+        for doc in documents:
+            if doc is None:
+                continue
+            if "AtmosphericForcing" in doc:
+                atmospheric_forcing_data = doc["AtmosphericForcing"]
+            if "SWRCorrection" in doc:
+                swr_correction_data = doc["SWRCorrection"]
+            if "Rivers" in doc:
+                rivers_data = doc["Rivers"]
+
+        if atmospheric_forcing_data is None:
+            raise ValueError(
+                "No AtmosphericForcing configuration found in the YAML file."
+            )
+
+        # Convert from string to datetime
+        for date_string in ["model_reference_date", "start_time", "end_time"]:
+            atmospheric_forcing_data[date_string] = datetime.fromisoformat(
+                atmospheric_forcing_data[date_string]
+            )
+
+        # Create Grid instance from the YAML file
+        grid = Grid.from_yaml(filepath)
+
+        if swr_correction_data is not None:
+            swr_correction = SWRCorrection.from_yaml(filepath)
+        else:
+            swr_correction = None
+
+        if rivers_data is not None:
+            rivers = Rivers.from_yaml(filepath)
+        else:
+            rivers = None
+
+        # Create and return an instance of AtmosphericForcing
+        return cls(
+            grid=grid,
+            swr_correction=swr_correction,
+            rivers=rivers,
+            **atmospheric_forcing_data,
+        )
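Taken together, `to_yaml` and `from_yaml` make the forcing configuration round-trippable. A hedged usage sketch: `atm_forcing` is the object from the class docstring's Examples section, `forcing.yaml` a hypothetical path, and the import path follows the file layout listed at the top of this diff:

```python
# Round-trip sketch: export the configuration, then rebuild the object.
from roms_tools.setup.atmospheric_forcing import AtmosphericForcing

atm_forcing.to_yaml("forcing.yaml")  # writes version header + Grid/.../AtmosphericForcing sections
atm_forcing_2 = AtmosphericForcing.from_yaml("forcing.yaml")  # re-reads Grid, SWRCorrection, Rivers too
```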