roms-tools 2.5.0__py3-none-any.whl → 2.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment-with-xesmf.yml +16 -0
- roms_tools/analysis/roms_output.py +521 -187
- roms_tools/analysis/utils.py +169 -0
- roms_tools/plot.py +351 -214
- roms_tools/regrid.py +161 -9
- roms_tools/setup/boundary_forcing.py +22 -22
- roms_tools/setup/datasets.py +40 -44
- roms_tools/setup/grid.py +28 -28
- roms_tools/setup/initial_conditions.py +23 -31
- roms_tools/setup/nesting.py +3 -3
- roms_tools/setup/river_forcing.py +22 -23
- roms_tools/setup/surface_forcing.py +14 -13
- roms_tools/setup/tides.py +7 -7
- roms_tools/setup/topography.py +2 -2
- roms_tools/tests/test_analysis/test_roms_output.py +299 -188
- roms_tools/tests/test_regrid.py +85 -2
- roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -2
- roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +2 -2
- roms_tools/tests/test_setup/test_river_forcing.py +47 -51
- roms_tools/tests/test_vertical_coordinate.py +73 -0
- roms_tools/utils.py +11 -7
- roms_tools/vertical_coordinate.py +7 -0
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/METADATA +22 -11
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/RECORD +33 -30
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/WHEEL +1 -1
- /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zarray +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zattrs +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/0.0 +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zarray +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zattrs +0 -0
- /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/0.0 +0 -0
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info/licenses}/LICENSE +0 -0
- {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/top_level.txt +0 -0
roms_tools/regrid.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
|
+
import xgcm
|
|
1
2
|
import xarray as xr
|
|
3
|
+
import warnings
|
|
2
4
|
|
|
3
5
|
|
|
4
|
-
class
|
|
6
|
+
class LateralRegridToROMS:
|
|
5
7
|
"""Handles lateral regridding of data onto a new spatial grid."""
|
|
6
8
|
|
|
7
9
|
def __init__(self, target_coords, source_dim_names):
|
|
@@ -49,7 +51,92 @@ class LateralRegrid:
|
|
|
49
51
|
return regridded
|
|
50
52
|
|
|
51
53
|
|
|
52
|
-
class
|
|
54
|
+
class LateralRegridFromROMS:
|
|
55
|
+
"""Regrids data from a curvilinear ROMS grid onto latitude-longitude coordinates
|
|
56
|
+
using xESMF.
|
|
57
|
+
|
|
58
|
+
It requires the `xesmf` library, which can be installed by installing `roms-tools` via conda.
|
|
59
|
+
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
source_grid_ds : xarray.Dataset
|
|
63
|
+
The source dataset containing the curvilinear ROMS grid with 'lat_rho' and 'lon_rho'.
|
|
64
|
+
|
|
65
|
+
target_coords : dict
|
|
66
|
+
A dictionary containing 'lat' and 'lon' arrays representing the target
|
|
67
|
+
latitude and longitude coordinates for regridding.
|
|
68
|
+
|
|
69
|
+
method : str, optional
|
|
70
|
+
The regridding method to use. Default is "bilinear". Other options include "nearest_s2d" and "conservative".
|
|
71
|
+
|
|
72
|
+
Raises
|
|
73
|
+
------
|
|
74
|
+
ImportError
|
|
75
|
+
If xESMF is not installed.
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def __init__(self, ds_in, target_coords, method="bilinear"):
|
|
79
|
+
"""Initializes the regridder with the source and target grids.
|
|
80
|
+
|
|
81
|
+
Parameters
|
|
82
|
+
----------
|
|
83
|
+
ds_in : xarray.Dataset or xarray.DataArray
|
|
84
|
+
The source dataset or dataarray containing the curvilinear ROMS grid with coordinates 'lat' and 'lon'.
|
|
85
|
+
|
|
86
|
+
target_coords : dict
|
|
87
|
+
A dictionary containing 'lat' and 'lon' arrays representing the target latitude
|
|
88
|
+
and longitude coordinates for regridding.
|
|
89
|
+
|
|
90
|
+
method : str, optional
|
|
91
|
+
The regridding method to use. Default is "bilinear". Other options include
|
|
92
|
+
"nearest_s2d" and "conservative".
|
|
93
|
+
|
|
94
|
+
Raises
|
|
95
|
+
------
|
|
96
|
+
ImportError
|
|
97
|
+
If xESMF is not installed.
|
|
98
|
+
"""
|
|
99
|
+
|
|
100
|
+
try:
|
|
101
|
+
import xesmf as xe
|
|
102
|
+
|
|
103
|
+
except ImportError:
|
|
104
|
+
raise ImportError(
|
|
105
|
+
"xesmf is required for this regridding task. Please install `roms-tools` via conda, which includes xesmf."
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
ds_out = xr.Dataset()
|
|
109
|
+
ds_out["lat"] = target_coords["lat"]
|
|
110
|
+
ds_out["lon"] = target_coords["lon"]
|
|
111
|
+
|
|
112
|
+
with warnings.catch_warnings():
|
|
113
|
+
warnings.filterwarnings("ignore", category=UserWarning, module="xesmf")
|
|
114
|
+
self.regridder = xe.Regridder(
|
|
115
|
+
ds_in, ds_out, method=method, unmapped_to_nan=True
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
def apply(self, da):
|
|
119
|
+
"""Applies the regridding to the provided data array.
|
|
120
|
+
|
|
121
|
+
Parameters
|
|
122
|
+
----------
|
|
123
|
+
da : xarray.DataArray
|
|
124
|
+
The data array to regrid. This should have the same dimension names as the
|
|
125
|
+
source grid (e.g., 'lat' and 'lon').
|
|
126
|
+
|
|
127
|
+
Returns
|
|
128
|
+
-------
|
|
129
|
+
xarray.DataArray
|
|
130
|
+
The regridded data array.
|
|
131
|
+
"""
|
|
132
|
+
|
|
133
|
+
with warnings.catch_warnings():
|
|
134
|
+
warnings.filterwarnings("ignore", category=UserWarning, module="xesmf")
|
|
135
|
+
regridded = self.regridder(da, keep_attrs=True)
|
|
136
|
+
return regridded
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class VerticalRegridToROMS:
|
|
53
140
|
"""Interpolates data onto new vertical (depth) coordinates.
|
|
54
141
|
|
|
55
142
|
Parameters
|
|
@@ -106,13 +193,13 @@ class VerticalRegrid:
|
|
|
106
193
|
}
|
|
107
194
|
)
|
|
108
195
|
|
|
109
|
-
def apply(self,
|
|
196
|
+
def apply(self, da, fill_nans=True):
|
|
110
197
|
"""Interpolates the variable onto the new depth grid using precomputed
|
|
111
198
|
coefficients for linear interpolation between layers.
|
|
112
199
|
|
|
113
200
|
Parameters
|
|
114
201
|
----------
|
|
115
|
-
|
|
202
|
+
da : xarray.DataArray
|
|
116
203
|
The input data to be regridded along the depth dimension. This should be
|
|
117
204
|
an array with the same depth coordinates as the original grid.
|
|
118
205
|
fill_nans : bool, optional
|
|
@@ -130,16 +217,81 @@ class VerticalRegrid:
|
|
|
130
217
|
|
|
131
218
|
dims = {"dim": self.depth_dim}
|
|
132
219
|
|
|
133
|
-
|
|
134
|
-
|
|
220
|
+
da_below = da.where(self.coeff["is_below"]).sum(**dims)
|
|
221
|
+
da_above = da.where(self.coeff["is_above"]).sum(**dims)
|
|
135
222
|
|
|
136
|
-
result =
|
|
223
|
+
result = da_below + (da_above - da_below) * self.coeff["factor"]
|
|
137
224
|
if fill_nans:
|
|
138
|
-
result = result.where(self.coeff["upper_mask"],
|
|
139
|
-
result = result.where(self.coeff["lower_mask"],
|
|
225
|
+
result = result.where(self.coeff["upper_mask"], da.isel({dims["dim"]: 0}))
|
|
226
|
+
result = result.where(self.coeff["lower_mask"], da.isel({dims["dim"]: -1}))
|
|
140
227
|
else:
|
|
141
228
|
result = result.where(self.coeff["upper_mask"]).where(
|
|
142
229
|
self.coeff["lower_mask"]
|
|
143
230
|
)
|
|
144
231
|
|
|
145
232
|
return result
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class VerticalRegridFromROMS:
|
|
236
|
+
"""A class for regridding data from the ROMS vertical coordinate system to target
|
|
237
|
+
depth levels.
|
|
238
|
+
|
|
239
|
+
This class uses the `xgcm` package to perform the transformation from the ROMS depth coordinates to
|
|
240
|
+
a user-defined set of target depth levels. It assumes that the input dataset `ds` contains the necessary
|
|
241
|
+
vertical coordinate information (`s_rho`).
|
|
242
|
+
|
|
243
|
+
Attributes
|
|
244
|
+
----------
|
|
245
|
+
grid : xgcm.Grid
|
|
246
|
+
The grid object used for regridding, initialized with the given dataset `ds`.
|
|
247
|
+
"""
|
|
248
|
+
|
|
249
|
+
def __init__(self, ds):
|
|
250
|
+
"""Initializes the `VerticalRegridFromROMS` object by creating an `xgcm.Grid`
|
|
251
|
+
instance.
|
|
252
|
+
|
|
253
|
+
Parameters
|
|
254
|
+
----------
|
|
255
|
+
ds : xarray.Dataset
|
|
256
|
+
The dataset containing the ROMS output data, which must include the vertical coordinate `s_rho`.
|
|
257
|
+
"""
|
|
258
|
+
self.grid = xgcm.Grid(ds, coords={"s_rho": {"center": "s_rho"}}, periodic=False)
|
|
259
|
+
|
|
260
|
+
def apply(self, da, depth_coords, target_depth_levels, mask_edges=True):
|
|
261
|
+
"""Applies vertical regridding from ROMS to the specified target depth levels.
|
|
262
|
+
|
|
263
|
+
This method transforms the input data array `da` from the ROMS vertical coordinate (`s_rho`)
|
|
264
|
+
to a set of target depth levels defined by `target_depth_levels`.
|
|
265
|
+
|
|
266
|
+
Parameters
|
|
267
|
+
----------
|
|
268
|
+
da : xarray.DataArray
|
|
269
|
+
The data array containing the ROMS output field to be regridded. It must have a vertical
|
|
270
|
+
dimension corresponding to `s_rho`.
|
|
271
|
+
|
|
272
|
+
depth_coords : array-like
|
|
273
|
+
The depth coordinates of the input data array `da` (typically the `s_rho` coordinate in ROMS).
|
|
274
|
+
|
|
275
|
+
target_depth_levels : array-like
|
|
276
|
+
The target depth levels to which the input data `da` will be regridded.
|
|
277
|
+
|
|
278
|
+
mask_edges: bool, optional
|
|
279
|
+
If activated, target values outside the range of depth_coords are masked with nan. Defaults to True.
|
|
280
|
+
|
|
281
|
+
Returns
|
|
282
|
+
-------
|
|
283
|
+
xarray.DataArray
|
|
284
|
+
A new `xarray.DataArray` containing the regridded data at the specified target depth levels.
|
|
285
|
+
"""
|
|
286
|
+
|
|
287
|
+
with warnings.catch_warnings():
|
|
288
|
+
warnings.filterwarnings("ignore", category=FutureWarning, module="xgcm")
|
|
289
|
+
transformed = self.grid.transform(
|
|
290
|
+
da,
|
|
291
|
+
"s_rho",
|
|
292
|
+
target_depth_levels,
|
|
293
|
+
target_data=depth_coords,
|
|
294
|
+
mask_edges=mask_edges,
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
return transformed
|
|
@@ -9,7 +9,7 @@ from datetime import datetime
|
|
|
9
9
|
import matplotlib.pyplot as plt
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
from roms_tools import Grid
|
|
12
|
-
from roms_tools.regrid import
|
|
12
|
+
from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
|
|
13
13
|
from roms_tools.utils import save_datasets
|
|
14
14
|
from roms_tools.vertical_coordinate import compute_depth
|
|
15
15
|
from roms_tools.plot import _section_plot, _line_plot
|
|
@@ -35,7 +35,7 @@ from roms_tools.setup.utils import (
|
|
|
35
35
|
)
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
@dataclass(
|
|
38
|
+
@dataclass(kw_only=True)
|
|
39
39
|
class BoundaryForcing:
|
|
40
40
|
"""Represents boundary forcing input data for ROMS.
|
|
41
41
|
|
|
@@ -124,7 +124,7 @@ class BoundaryForcing:
|
|
|
124
124
|
|
|
125
125
|
self._input_checks()
|
|
126
126
|
# Dataset for depth coordinates
|
|
127
|
-
|
|
127
|
+
self.ds_depth_coords = xr.Dataset()
|
|
128
128
|
|
|
129
129
|
target_coords = get_target_coords(self.grid)
|
|
130
130
|
|
|
@@ -185,7 +185,7 @@ class BoundaryForcing:
|
|
|
185
185
|
lat = target_coords["lat"].isel(
|
|
186
186
|
**self.bdry_coords["vector"][direction]
|
|
187
187
|
)
|
|
188
|
-
lateral_regrid =
|
|
188
|
+
lateral_regrid = LateralRegridToROMS(
|
|
189
189
|
{"lat": lat, "lon": lon}, bdry_data.dim_names
|
|
190
190
|
)
|
|
191
191
|
for var_name in vector_var_names:
|
|
@@ -214,7 +214,7 @@ class BoundaryForcing:
|
|
|
214
214
|
lat = target_coords["lat"].isel(
|
|
215
215
|
**self.bdry_coords["rho"][direction]
|
|
216
216
|
)
|
|
217
|
-
lateral_regrid =
|
|
217
|
+
lateral_regrid = LateralRegridToROMS(
|
|
218
218
|
{"lat": lat, "lon": lon}, bdry_data.dim_names
|
|
219
219
|
)
|
|
220
220
|
for var_name in tracer_var_names:
|
|
@@ -292,7 +292,7 @@ class BoundaryForcing:
|
|
|
292
292
|
# vertical regridding
|
|
293
293
|
for location in ["rho", "u", "v"]:
|
|
294
294
|
if len(var_names_dict[location]) > 0:
|
|
295
|
-
vertical_regrid =
|
|
295
|
+
vertical_regrid = VerticalRegridToROMS(
|
|
296
296
|
self.ds_depth_coords[f"layer_depth_{location}_{direction}"],
|
|
297
297
|
bdry_data.ds[bdry_data.dim_names["depth"]],
|
|
298
298
|
)
|
|
@@ -335,7 +335,7 @@ class BoundaryForcing:
|
|
|
335
335
|
for var_name in ds.data_vars:
|
|
336
336
|
ds[var_name] = substitute_nans_by_fillvalue(ds[var_name])
|
|
337
337
|
|
|
338
|
-
|
|
338
|
+
self.ds = ds
|
|
339
339
|
|
|
340
340
|
def _input_checks(self):
|
|
341
341
|
# Check that start_time and end_time are both None or none of them is
|
|
@@ -361,11 +361,10 @@ class BoundaryForcing:
|
|
|
361
361
|
raise ValueError("`source` must include a 'path'.")
|
|
362
362
|
|
|
363
363
|
# Set 'climatology' to False if not provided in 'source'
|
|
364
|
-
|
|
365
|
-
self,
|
|
366
|
-
"source",
|
|
367
|
-
|
|
368
|
-
)
|
|
364
|
+
self.source = {
|
|
365
|
+
**self.source,
|
|
366
|
+
"climatology": self.source.get("climatology", False),
|
|
367
|
+
}
|
|
369
368
|
|
|
370
369
|
# Ensure adjust_depth_for_sea_surface_height is only used with type="physics"
|
|
371
370
|
if self.type == "bgc" and self.adjust_depth_for_sea_surface_height:
|
|
@@ -373,7 +372,7 @@ class BoundaryForcing:
|
|
|
373
372
|
"adjust_depth_for_sea_surface_height is not applicable for BGC fields. "
|
|
374
373
|
"Setting it to False."
|
|
375
374
|
)
|
|
376
|
-
|
|
375
|
+
self.adjust_depth_for_sea_surface_height = False
|
|
377
376
|
elif self.adjust_depth_for_sea_surface_height:
|
|
378
377
|
logging.info("Sea surface height will be used to adjust depth coordinates.")
|
|
379
378
|
else:
|
|
@@ -495,7 +494,7 @@ class BoundaryForcing:
|
|
|
495
494
|
else:
|
|
496
495
|
variable_info[var_name] = {**default_info, "validate": False}
|
|
497
496
|
|
|
498
|
-
|
|
497
|
+
self.variable_info = variable_info
|
|
499
498
|
|
|
500
499
|
def _write_into_dataset(self, direction, processed_fields, ds=None):
|
|
501
500
|
if ds is None:
|
|
@@ -563,7 +562,7 @@ class BoundaryForcing:
|
|
|
563
562
|
|
|
564
563
|
bdry_coords = get_boundary_coords()
|
|
565
564
|
|
|
566
|
-
|
|
565
|
+
self.bdry_coords = bdry_coords
|
|
567
566
|
|
|
568
567
|
def _get_depth_coordinates(
|
|
569
568
|
self,
|
|
@@ -861,12 +860,6 @@ class BoundaryForcing:
|
|
|
861
860
|
|
|
862
861
|
field = self.ds[var_name].isel(bry_time=time)
|
|
863
862
|
|
|
864
|
-
if self.use_dask:
|
|
865
|
-
from dask.diagnostics import ProgressBar
|
|
866
|
-
|
|
867
|
-
with ProgressBar():
|
|
868
|
-
field = field.load()
|
|
869
|
-
|
|
870
863
|
title = field.long_name
|
|
871
864
|
var_name_wo_direction, direction = var_name.split("_")
|
|
872
865
|
location = self.variable_info[var_name_wo_direction]["location"]
|
|
@@ -881,10 +874,17 @@ class BoundaryForcing:
|
|
|
881
874
|
|
|
882
875
|
mask = mask.isel(**self.bdry_coords[location][direction])
|
|
883
876
|
|
|
877
|
+
# Load the data
|
|
878
|
+
if self.use_dask:
|
|
879
|
+
from dask.diagnostics import ProgressBar
|
|
880
|
+
|
|
881
|
+
with ProgressBar():
|
|
882
|
+
field = field.load()
|
|
883
|
+
|
|
884
884
|
if "s_rho" in field.dims:
|
|
885
885
|
layer_depth = self.ds_depth_coords[f"layer_depth_{location}_{direction}"]
|
|
886
886
|
if self.adjust_depth_for_sea_surface_height:
|
|
887
|
-
layer_depth = layer_depth.isel(time=time)
|
|
887
|
+
layer_depth = layer_depth.isel(time=time).load()
|
|
888
888
|
field = field.assign_coords({"layer_depth": layer_depth})
|
|
889
889
|
if var_name.startswith(("u", "v", "ubar", "vbar", "zeta")):
|
|
890
890
|
vmax = max(field.max().values, -field.min().values)
|
roms_tools/setup/datasets.py
CHANGED
|
@@ -25,7 +25,7 @@ from roms_tools.setup.fill import LateralFill
|
|
|
25
25
|
# lat-lon datasets
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
@dataclass(
|
|
28
|
+
@dataclass(kw_only=True)
|
|
29
29
|
class Dataset:
|
|
30
30
|
"""Represents forcing data on original grid.
|
|
31
31
|
|
|
@@ -132,8 +132,8 @@ class Dataset:
|
|
|
132
132
|
self.infer_horizontal_resolution(ds)
|
|
133
133
|
|
|
134
134
|
# Check whether the data covers the entire globe
|
|
135
|
-
|
|
136
|
-
|
|
135
|
+
self.is_global = self.check_if_global(ds)
|
|
136
|
+
self.ds = ds
|
|
137
137
|
|
|
138
138
|
if self.apply_post_processing:
|
|
139
139
|
self.post_process()
|
|
@@ -354,7 +354,7 @@ class Dataset:
|
|
|
354
354
|
resolution = np.mean([lat_resolution, lon_resolution])
|
|
355
355
|
|
|
356
356
|
# Set the computed resolution as an attribute
|
|
357
|
-
|
|
357
|
+
self.resolution = resolution
|
|
358
358
|
|
|
359
359
|
def compute_minimal_grid_spacing(self, ds: xr.Dataset):
|
|
360
360
|
"""Compute the minimal grid spacing in a dataset based on latitude and longitude
|
|
@@ -528,7 +528,7 @@ class Dataset:
|
|
|
528
528
|
This method modifies the dataset in place and does not return anything.
|
|
529
529
|
"""
|
|
530
530
|
ds = self.ds.astype({var: "float64" for var in self.ds.data_vars})
|
|
531
|
-
|
|
531
|
+
self.ds = ds
|
|
532
532
|
|
|
533
533
|
def choose_subdomain(
|
|
534
534
|
self,
|
|
@@ -671,7 +671,7 @@ class Dataset:
|
|
|
671
671
|
if return_copy:
|
|
672
672
|
return Dataset.from_ds(self, subdomain)
|
|
673
673
|
else:
|
|
674
|
-
|
|
674
|
+
self.ds = subdomain
|
|
675
675
|
|
|
676
676
|
def apply_lateral_fill(self):
|
|
677
677
|
"""Apply lateral fill to variables using the dataset's mask and grid dimensions.
|
|
@@ -757,7 +757,7 @@ class Dataset:
|
|
|
757
757
|
dataset = cls.__new__(cls)
|
|
758
758
|
|
|
759
759
|
# Directly set the provided dataset as the 'ds' attribute
|
|
760
|
-
|
|
760
|
+
dataset.ds = ds
|
|
761
761
|
|
|
762
762
|
# Copy all other attributes from the original data instance
|
|
763
763
|
for attr in vars(original_dataset):
|
|
@@ -767,7 +767,7 @@ class Dataset:
|
|
|
767
767
|
return dataset
|
|
768
768
|
|
|
769
769
|
|
|
770
|
-
@dataclass(
|
|
770
|
+
@dataclass(kw_only=True)
|
|
771
771
|
class TPXODataset(Dataset):
|
|
772
772
|
"""Represents tidal data on the original grid from the TPXO dataset.
|
|
773
773
|
|
|
@@ -858,15 +858,11 @@ class TPXODataset(Dataset):
|
|
|
858
858
|
{"nx": "longitude", "ny": "latitude", self.dim_names["ntides"]: "ntides"}
|
|
859
859
|
)
|
|
860
860
|
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
"
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
"longitude": "longitude",
|
|
867
|
-
"ntides": "ntides",
|
|
868
|
-
},
|
|
869
|
-
)
|
|
861
|
+
self.dim_names = {
|
|
862
|
+
"latitude": "latitude",
|
|
863
|
+
"longitude": "longitude",
|
|
864
|
+
"ntides": "ntides",
|
|
865
|
+
}
|
|
870
866
|
|
|
871
867
|
return ds
|
|
872
868
|
|
|
@@ -909,15 +905,15 @@ class TPXODataset(Dataset):
|
|
|
909
905
|
ds["mask"] = mask
|
|
910
906
|
ds = ds.drop_vars(["depth"])
|
|
911
907
|
|
|
912
|
-
|
|
908
|
+
self.ds = ds
|
|
913
909
|
|
|
914
910
|
# Remove "depth" from var_names
|
|
915
911
|
updated_var_names = {**self.var_names} # Create a copy of the dictionary
|
|
916
912
|
updated_var_names.pop("depth", None) # Remove "depth" if it exists
|
|
917
|
-
|
|
913
|
+
self.var_names = updated_var_names
|
|
918
914
|
|
|
919
915
|
|
|
920
|
-
@dataclass(
|
|
916
|
+
@dataclass(kw_only=True)
|
|
921
917
|
class GLORYSDataset(Dataset):
|
|
922
918
|
"""Represents GLORYS data on original grid.
|
|
923
919
|
|
|
@@ -996,7 +992,7 @@ class GLORYSDataset(Dataset):
|
|
|
996
992
|
self.ds["mask_vel"] = mask_vel
|
|
997
993
|
|
|
998
994
|
|
|
999
|
-
@dataclass(
|
|
995
|
+
@dataclass(kw_only=True)
|
|
1000
996
|
class CESMDataset(Dataset):
|
|
1001
997
|
"""Represents CESM data on original grid.
|
|
1002
998
|
|
|
@@ -1077,12 +1073,12 @@ class CESMDataset(Dataset):
|
|
|
1077
1073
|
# Update dimension names
|
|
1078
1074
|
updated_dim_names = self.dim_names.copy()
|
|
1079
1075
|
updated_dim_names["time"] = "time"
|
|
1080
|
-
|
|
1076
|
+
self.dim_names = updated_dim_names
|
|
1081
1077
|
|
|
1082
1078
|
return ds
|
|
1083
1079
|
|
|
1084
1080
|
|
|
1085
|
-
@dataclass(
|
|
1081
|
+
@dataclass(kw_only=True)
|
|
1086
1082
|
class CESMBGCDataset(CESMDataset):
|
|
1087
1083
|
"""Represents CESM BGC data on original grid.
|
|
1088
1084
|
|
|
@@ -1184,12 +1180,12 @@ class CESMBGCDataset(CESMDataset):
|
|
|
1184
1180
|
if "z_t_150m" in ds.variables:
|
|
1185
1181
|
ds = ds.drop_vars("z_t_150m")
|
|
1186
1182
|
# update dataset
|
|
1187
|
-
|
|
1183
|
+
self.ds = ds
|
|
1188
1184
|
|
|
1189
1185
|
# Update dim_names with "depth": "depth" key-value pair
|
|
1190
1186
|
updated_dim_names = self.dim_names.copy()
|
|
1191
1187
|
updated_dim_names["depth"] = "depth"
|
|
1192
|
-
|
|
1188
|
+
self.dim_names = updated_dim_names
|
|
1193
1189
|
|
|
1194
1190
|
mask = xr.where(
|
|
1195
1191
|
self.ds[self.var_names["PO4"]]
|
|
@@ -1202,7 +1198,7 @@ class CESMBGCDataset(CESMDataset):
|
|
|
1202
1198
|
self.ds["mask"] = mask
|
|
1203
1199
|
|
|
1204
1200
|
|
|
1205
|
-
@dataclass(
|
|
1201
|
+
@dataclass(kw_only=True)
|
|
1206
1202
|
class CESMBGCSurfaceForcingDataset(CESMDataset):
|
|
1207
1203
|
"""Represents CESM BGC surface forcing data on original grid.
|
|
1208
1204
|
|
|
@@ -1258,7 +1254,7 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
|
|
|
1258
1254
|
|
|
1259
1255
|
if "z_t" in self.ds.variables:
|
|
1260
1256
|
ds = self.ds.drop_vars("z_t")
|
|
1261
|
-
|
|
1257
|
+
self.ds = ds
|
|
1262
1258
|
|
|
1263
1259
|
mask = xr.where(
|
|
1264
1260
|
self.ds[self.var_names["pco2_air"]]
|
|
@@ -1271,7 +1267,7 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
|
|
|
1271
1267
|
self.ds["mask"] = mask
|
|
1272
1268
|
|
|
1273
1269
|
|
|
1274
|
-
@dataclass(
|
|
1270
|
+
@dataclass(kw_only=True)
|
|
1275
1271
|
class ERA5Dataset(Dataset):
|
|
1276
1272
|
"""Represents ERA5 data on original grid.
|
|
1277
1273
|
|
|
@@ -1371,27 +1367,27 @@ class ERA5Dataset(Dataset):
|
|
|
1371
1367
|
ds["qair"].attrs["long_name"] = "Absolute humidity at 2m"
|
|
1372
1368
|
ds["qair"].attrs["units"] = "kg/kg"
|
|
1373
1369
|
ds = ds.drop_vars([self.var_names["d2m"]])
|
|
1374
|
-
|
|
1370
|
+
self.ds = ds
|
|
1375
1371
|
|
|
1376
1372
|
# Update var_names dictionary
|
|
1377
1373
|
var_names = {**self.var_names, "qair": "qair"}
|
|
1378
1374
|
var_names.pop("d2m")
|
|
1379
|
-
|
|
1375
|
+
self.var_names = var_names
|
|
1380
1376
|
|
|
1381
1377
|
if "mask" in self.var_names.keys():
|
|
1382
1378
|
ds = self.ds
|
|
1383
1379
|
mask = xr.where(self.ds[self.var_names["mask"]].isel(time=0).isnull(), 0, 1)
|
|
1384
1380
|
ds["mask"] = mask
|
|
1385
1381
|
ds = ds.drop_vars([self.var_names["mask"]])
|
|
1386
|
-
|
|
1382
|
+
self.ds = ds
|
|
1387
1383
|
|
|
1388
1384
|
# Remove mask from var_names dictionary
|
|
1389
1385
|
var_names = self.var_names
|
|
1390
1386
|
var_names.pop("mask")
|
|
1391
|
-
|
|
1387
|
+
self.var_names = var_names
|
|
1392
1388
|
|
|
1393
1389
|
|
|
1394
|
-
@dataclass(
|
|
1390
|
+
@dataclass(kw_only=True)
|
|
1395
1391
|
class ERA5Correction(Dataset):
|
|
1396
1392
|
"""Global dataset to correct ERA5 radiation. The dataset contains multiplicative
|
|
1397
1393
|
correction factors for the ERA5 shortwave radiation, obtained by comparing the
|
|
@@ -1493,10 +1489,10 @@ class ERA5Correction(Dataset):
|
|
|
1493
1489
|
raise ValueError(
|
|
1494
1490
|
"The correction dataset does not contain all specified longitude values."
|
|
1495
1491
|
)
|
|
1496
|
-
|
|
1492
|
+
self.ds = subdomain
|
|
1497
1493
|
|
|
1498
1494
|
|
|
1499
|
-
@dataclass(
|
|
1495
|
+
@dataclass(kw_only=True)
|
|
1500
1496
|
class ETOPO5Dataset(Dataset):
|
|
1501
1497
|
"""Represents topography data on the original grid from the ETOPO5 dataset.
|
|
1502
1498
|
|
|
@@ -1553,7 +1549,7 @@ class ETOPO5Dataset(Dataset):
|
|
|
1553
1549
|
return ds
|
|
1554
1550
|
|
|
1555
1551
|
|
|
1556
|
-
@dataclass(
|
|
1552
|
+
@dataclass(kw_only=True)
|
|
1557
1553
|
class SRTM15Dataset(Dataset):
|
|
1558
1554
|
"""Represents topography data on the original grid from the SRTM15 dataset.
|
|
1559
1555
|
|
|
@@ -1589,7 +1585,7 @@ class SRTM15Dataset(Dataset):
|
|
|
1589
1585
|
|
|
1590
1586
|
|
|
1591
1587
|
# river datasets
|
|
1592
|
-
@dataclass(
|
|
1588
|
+
@dataclass(kw_only=True)
|
|
1593
1589
|
class RiverDataset:
|
|
1594
1590
|
"""Represents river data.
|
|
1595
1591
|
|
|
@@ -1647,7 +1643,7 @@ class RiverDataset:
|
|
|
1647
1643
|
|
|
1648
1644
|
# Select relevant times
|
|
1649
1645
|
ds = self.add_time_info(ds)
|
|
1650
|
-
|
|
1646
|
+
self.ds = ds
|
|
1651
1647
|
|
|
1652
1648
|
def load_data(self) -> xr.Dataset:
|
|
1653
1649
|
"""Load dataset from the specified file.
|
|
@@ -1785,13 +1781,13 @@ class RiverDataset:
|
|
|
1785
1781
|
|
|
1786
1782
|
ds = assign_dates_to_climatology(self.ds, "month")
|
|
1787
1783
|
ds = ds.swap_dims({"month": "time"})
|
|
1788
|
-
|
|
1784
|
+
self.ds = ds
|
|
1789
1785
|
|
|
1790
1786
|
updated_dim_names = {**self.dim_names}
|
|
1791
1787
|
updated_dim_names["time"] = "time"
|
|
1792
|
-
|
|
1788
|
+
self.dim_names = updated_dim_names
|
|
1793
1789
|
|
|
1794
|
-
|
|
1790
|
+
self.climatology = True
|
|
1795
1791
|
|
|
1796
1792
|
def sort_by_river_volume(self, ds: xr.Dataset) -> xr.Dataset:
|
|
1797
1793
|
"""Sorts the dataset by river volume in descending order (largest rivers first),
|
|
@@ -1911,7 +1907,7 @@ class RiverDataset:
|
|
|
1911
1907
|
ds = xr.Dataset()
|
|
1912
1908
|
river_indices = {}
|
|
1913
1909
|
|
|
1914
|
-
|
|
1910
|
+
self.ds = ds
|
|
1915
1911
|
|
|
1916
1912
|
return river_indices
|
|
1917
1913
|
|
|
@@ -1961,10 +1957,10 @@ class RiverDataset:
|
|
|
1961
1957
|
)
|
|
1962
1958
|
|
|
1963
1959
|
# Set the filtered dataset as the new `ds`
|
|
1964
|
-
|
|
1960
|
+
self.ds = ds_filtered
|
|
1965
1961
|
|
|
1966
1962
|
|
|
1967
|
-
@dataclass(
|
|
1963
|
+
@dataclass(kw_only=True)
|
|
1968
1964
|
class DaiRiverDataset(RiverDataset):
|
|
1969
1965
|
"""Represents river data from the Dai river dataset.
|
|
1970
1966
|
|