roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment.yml +29 -0
- roms_tools/__init__.py +6 -0
- roms_tools/_version.py +1 -1
- roms_tools/setup/atmospheric_forcing.py +935 -0
- roms_tools/setup/boundary_forcing.py +711 -0
- roms_tools/setup/datasets.py +457 -0
- roms_tools/setup/fill.py +376 -0
- roms_tools/setup/grid.py +610 -325
- roms_tools/setup/initial_conditions.py +528 -0
- roms_tools/setup/plot.py +203 -0
- roms_tools/setup/tides.py +809 -0
- roms_tools/setup/topography.py +257 -0
- roms_tools/setup/utils.py +162 -0
- roms_tools/setup/vertical_coordinate.py +494 -0
- roms_tools/tests/test_atmospheric_forcing.py +1645 -0
- roms_tools/tests/test_boundary_forcing.py +332 -0
- roms_tools/tests/test_datasets.py +306 -0
- roms_tools/tests/test_grid.py +226 -0
- roms_tools/tests/test_initial_conditions.py +300 -0
- roms_tools/tests/test_tides.py +366 -0
- roms_tools/tests/test_topography.py +78 -0
- roms_tools/tests/test_vertical_coordinate.py +337 -0
- roms_tools-0.20.dist-info/METADATA +90 -0
- roms_tools-0.20.dist-info/RECORD +28 -0
- {roms_tools-0.0.6.dist-info → roms_tools-0.20.dist-info}/WHEEL +1 -1
- {roms_tools-0.0.6.dist-info → roms_tools-0.20.dist-info}/top_level.txt +1 -0
- roms_tools/tests/test_setup.py +0 -54
- roms_tools-0.0.6.dist-info/METADATA +0 -134
- roms_tools-0.0.6.dist-info/RECORD +0 -10
- {roms_tools-0.0.6.dist-info → roms_tools-0.20.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,528 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
import numpy as np
|
|
3
|
+
import yaml
|
|
4
|
+
import importlib.metadata
|
|
5
|
+
from dataclasses import dataclass, field, asdict
|
|
6
|
+
from roms_tools.setup.grid import Grid
|
|
7
|
+
from roms_tools.setup.vertical_coordinate import VerticalCoordinate
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from roms_tools.setup.datasets import Dataset
|
|
10
|
+
from roms_tools.setup.fill import fill_and_interpolate
|
|
11
|
+
from roms_tools.setup.utils import (
|
|
12
|
+
nan_check,
|
|
13
|
+
interpolate_from_rho_to_u,
|
|
14
|
+
interpolate_from_rho_to_v,
|
|
15
|
+
extrapolate_deepest_to_bottom,
|
|
16
|
+
)
|
|
17
|
+
from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
|
|
18
|
+
import matplotlib.pyplot as plt
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True, kw_only=True)
class InitialConditions:
    """
    Represents initial conditions for ROMS.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information.
    vertical_coordinate : VerticalCoordinate
        Object representing the vertical coordinate information.
    ini_time : datetime
        Desired initialization time.
    model_reference_date : datetime, optional
        Reference date for the model. Default is January 1, 2000.
    source : str, optional
        Source of the initial condition data. Default is "GLORYS", which is
        currently the only supported option.
    filename : str
        Path to the source data file. Can contain wildcards.

    Attributes
    ----------
    ds : xr.Dataset
        Xarray Dataset containing the initial condition data.

    Raises
    ------
    ValueError
        If ``source`` is not "GLORYS".
    """

    grid: Grid
    vertical_coordinate: VerticalCoordinate
    ini_time: datetime
    model_reference_date: datetime = datetime(2000, 1, 1)
    source: str = "GLORYS"
    # No default, which is allowed after defaulted fields only because kw_only=True.
    filename: str
    # Computed in __post_init__; not a constructor argument.
    ds: xr.Dataset = field(init=False, repr=False)

    @staticmethod
    def _get_roms_tools_version() -> str:
        """Return the installed roms-tools version, or "unknown" if it is not installed."""
        try:
            return importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            return "unknown"

    def __post_init__(self):
        # Only GLORYS is supported; its dimension/variable names are hard-coded below.
        if self.source != "GLORYS":
            raise ValueError('Only "GLORYS" is a valid option for source.')

        # Generic dimension names -> names used in the source files.
        dims = {
            "longitude": "longitude",
            "latitude": "latitude",
            "depth": "depth",
            "time": "time",
        }

        # ROMS-side variable names -> names used in the source files.
        varnames = {
            "temp": "thetao",
            "salt": "so",
            "u": "uo",
            "v": "vo",
            "ssh": "zos",
        }

        data = Dataset(
            filename=self.filename,
            start_time=self.ini_time,
            var_names=varnames.values(),
            dim_names=dims,
        )

        lon = self.grid.ds.lon_rho
        lat = self.grid.ds.lat_rho
        angle = self.grid.ds.angle

        # Operate on longitudes between -180 and 180 unless the ROMS domain lies
        # at least 5 degrees in longitude away from the Greenwich meridian, in
        # which case longitudes between 0 and 360 are used.
        lon = xr.where(lon > 180, lon - 360, lon)
        straddle = True
        if not self.grid.straddle and abs(lon).min() > 5:
            lon = xr.where(lon < 0, lon + 360, lon)
            straddle = False

        # The following consists of two steps:
        # Step 1: Choose a subdomain of the forcing data, including a safety margin
        # for interpolation. Step 2: Convert to the proper longitude range.
        # We perform these two steps for two reasons:
        # A) Since the horizontal dimensions consist of a single chunk, selecting a
        #    subdomain before interpolation is much more performant.
        # B) Step 1 is necessary to avoid discontinuous longitudes that could be
        #    introduced by Step 2. Discontinuous longitudes can lead to artifacts in
        #    the interpolation process: if there is a data gap (data is not global),
        #    they could produce values that appear to come from a distant location
        #    instead of producing NaNs. These NaNs are important because they can be
        #    identified and handled appropriately by the nan_check function.
        data.choose_subdomain(
            latitude_range=[lat.min().values, lat.max().values],
            longitude_range=[lon.min().values, lon.max().values],
            margin=2,
            straddle=straddle,
        )

        # Horizontal dimensions along which missing source values are filled
        # before interpolating onto the desired grid.
        fill_dims = [dims["latitude"], dims["longitude"]]

        # 2D interpolation (sea surface height).
        # The mask is 1 where the first time slice of the source SSH is valid, 0 where it is NaN.
        mask = xr.where(data.ds[varnames["ssh"]].isel(time=0).isnull(), 0, 1)
        coords = {dims["latitude"]: lat, dims["longitude"]: lon}

        ssh = fill_and_interpolate(
            data.ds[varnames["ssh"]].astype(np.float64),
            mask,
            fill_dims=fill_dims,
            coords=coords,
            method="linear",
        )

        # 3D interpolation.
        # Extrapolate the deepest value all the way to the bottom ("flooding").
        for var in ["temp", "salt", "u", "v"]:
            data.ds[varnames[var]] = extrapolate_deepest_to_bottom(
                data.ds[varnames[var]], dims["depth"]
            )

        mask = xr.where(data.ds[varnames["temp"]].isel(time=0).isnull(), 0, 1)
        coords = {
            dims["latitude"]: lat,
            dims["longitude"]: lon,
            dims["depth"]: self.vertical_coordinate.ds["layer_depth_rho"],
        }

        # Setting fillvalue_interp to None allows extrapolation in the
        # interpolation step, to avoid NaNs at the surface if the shallowest
        # depth in the original data is greater than zero.
        data_vars = {}
        for var in ["temp", "salt", "u", "v"]:
            data_vars[var] = fill_and_interpolate(
                data.ds[varnames[var]].astype(np.float64),
                mask,
                fill_dims=fill_dims,
                coords=coords,
                method="linear",
                fillvalue_interp=None,
            )

        # Rotate velocities to the grid orientation using the grid angle.
        u_rot = data_vars["u"] * np.cos(angle) + data_vars["v"] * np.sin(angle)
        v_rot = data_vars["v"] * np.cos(angle) - data_vars["u"] * np.sin(angle)

        # Interpolate to u- and v-points.
        u = interpolate_from_rho_to_u(u_rot)
        v = interpolate_from_rho_to_v(v_rot)

        # 3D masks for the ROMS domain (2D masks broadcast over s_rho).
        umask = self.grid.ds.mask_u.expand_dims({"s_rho": u.s_rho})
        vmask = self.grid.ds.mask_v.expand_dims({"s_rho": v.s_rho})

        u = u * umask
        v = v * vmask

        # Compute the barotropic velocity as a thickness-weighted vertical mean.
        # Layer thicknesses from differences of interface depths. The sign is
        # flipped, presumably so thicknesses come out positive (assumes interface
        # depth decreases with increasing s_w -- confirm against VerticalCoordinate).
        dz = -self.vertical_coordinate.ds["interface_depth_rho"].diff(dim="s_w")
        dz = dz.rename({"s_w": "s_rho"})
        # Thicknesses at u- and v-points.
        dzu = interpolate_from_rho_to_u(dz)
        dzv = interpolate_from_rho_to_v(dz)

        ubar = (dzu * u).sum(dim="s_rho") / dzu.sum(dim="s_rho")
        vbar = (dzv * v).sum(dim="s_rho") / dzv.sum(dim="s_rho")

        # Assemble the output dataset.
        ds = xr.Dataset()

        ds["temp"] = data_vars["temp"].astype(np.float32)
        ds["temp"].attrs["long_name"] = "Potential temperature"
        ds["temp"].attrs["units"] = "Celsius"

        ds["salt"] = data_vars["salt"].astype(np.float32)
        ds["salt"].attrs["long_name"] = "Salinity"
        ds["salt"].attrs["units"] = "PSU"

        ds["zeta"] = ssh.astype(np.float32)
        ds["zeta"].attrs["long_name"] = "Free surface"
        ds["zeta"].attrs["units"] = "m"

        ds["u"] = u.astype(np.float32)
        ds["u"].attrs["long_name"] = "u-flux component"
        ds["u"].attrs["units"] = "m/s"

        ds["v"] = v.astype(np.float32)
        ds["v"].attrs["long_name"] = "v-flux component"
        ds["v"].attrs["units"] = "m/s"

        # Initialize vertical velocity to zero on layer interfaces. ds already
        # carries the time coordinate (from temp) at this point.
        ds["w"] = xr.zeros_like(
            self.vertical_coordinate.ds["interface_depth_rho"].expand_dims(
                time=ds[dims["time"]]
            )
        ).astype(np.float32)
        ds["w"].attrs["long_name"] = "w-flux component"
        ds["w"].attrs["units"] = "m/s"

        ds["ubar"] = ubar.transpose(dims["time"], "eta_rho", "xi_u").astype(np.float32)
        ds["ubar"].attrs["long_name"] = "vertically integrated u-flux component"
        ds["ubar"].attrs["units"] = "m/s"

        ds["vbar"] = vbar.transpose(dims["time"], "eta_v", "xi_rho").astype(np.float32)
        ds["vbar"].attrs["long_name"] = "vertically integrated v-flux component"
        ds["vbar"].attrs["units"] = "m/s"

        ds = ds.assign_coords(
            {
                "layer_depth_u": self.vertical_coordinate.ds["layer_depth_u"],
                "layer_depth_v": self.vertical_coordinate.ds["layer_depth_v"],
                "interface_depth_u": self.vertical_coordinate.ds["interface_depth_u"],
                "interface_depth_v": self.vertical_coordinate.ds["interface_depth_v"],
            }
        )

        ds.attrs["title"] = "ROMS initial conditions file created by ROMS-Tools"
        # Include the version of roms-tools
        ds.attrs["roms_tools_version"] = self._get_roms_tools_version()
        ds.attrs["ini_time"] = str(self.ini_time)
        ds.attrs["model_reference_date"] = str(self.model_reference_date)
        ds.attrs["source"] = self.source

        if dims["time"] != "time":
            ds = ds.rename({dims["time"]: "time"})

        # Translate the time coordinate to seconds since the model reference date,
        # the format expected by ROMS. The timedelta is converted to float
        # nanoseconds and then scaled to seconds.
        model_reference_date = np.datetime64(self.model_reference_date)
        ocean_time = (ds["time"] - model_reference_date).astype("float64") * 1e-9
        # NOTE(review): float32 only resolves ~1 minute at magnitudes of ~1e9 s;
        # kept as float32 to preserve the existing output format.
        ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
        ds["ocean_time"].attrs[
            "long_name"
        ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
        ds["ocean_time"].attrs["units"] = "seconds"

        # Copy the vertical coordinate parameters into the output file.
        ds["theta_s"] = self.vertical_coordinate.ds["theta_s"]
        ds["theta_b"] = self.vertical_coordinate.ds["theta_b"]
        ds["Tcline"] = self.vertical_coordinate.ds["Tcline"]
        ds["hc"] = self.vertical_coordinate.ds["hc"]
        ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
        ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]

        ds = ds.drop_vars(["s_rho"])

        # The dataclass is frozen, so bypass __setattr__ to store the result.
        object.__setattr__(self, "ds", ds)

        ds["zeta"].load()
        nan_check(ds["zeta"].squeeze(), self.grid.ds.mask_rho)

    def plot(
        self,
        varname,
        s=None,
        eta=None,
        xi=None,
        depth_contours=False,
        layer_contours=False,
    ) -> None:
        """
        Plot the initial conditions field for a given eta-, xi-, or s-slice.

        Parameters
        ----------
        varname : str
            The name of the initial conditions field to plot. Options include:
            - "temp": Potential temperature.
            - "salt": Salinity.
            - "zeta": Free surface.
            - "u": u-flux component.
            - "v": v-flux component.
            - "w": w-flux component.
            - "ubar": Vertically integrated u-flux component.
            - "vbar": Vertically integrated v-flux component.
        s : int, optional
            The index of the vertical layer to plot (along s_rho, or along s_w
            for "w"). Only valid for 3D fields. Default is None.
        eta : int, optional
            The eta-index to plot. Default is None.
        xi : int, optional
            The xi-index to plot. Default is None.
        depth_contours : bool, optional
            Whether to include depth contours in the plot. This only applies to
            horizontal slices, i.e. when ``s`` is given. Default is False.
        layer_contours : bool, optional
            Whether to overlay layer-interface contours (at most 10) on vertical
            section plots, i.e. when ``eta`` or ``xi`` is given. Default is False.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        ValueError
            If the specified varname is not one of the valid options.
            If field is 3D and none of s, eta, xi are specified.
            If field is 2D and both eta and xi are specified.
            If s is specified for a field without a vertical dimension.
        """

        valid_varnames = ["temp", "salt", "zeta", "u", "v", "w", "ubar", "vbar"]
        if varname not in valid_varnames:
            raise ValueError(
                f"Invalid varname: {varname}. "
                f"Valid options are: {', '.join(valid_varnames)}."
            )

        ndim = len(self.ds[varname].squeeze().dims)

        if ndim == 3 and s is None and eta is None and xi is None:
            raise ValueError(
                "For 3D fields, at least one of s, eta, or xi must be specified."
            )

        if ndim == 2 and eta is not None and xi is not None:
            raise ValueError("For 2D fields, specify either eta or xi, not both.")

        self.ds[varname].load()
        field = self.ds[varname].squeeze()

        # "w" lives on layer interfaces (s_w); all other 3D fields live on
        # layer centers (s_rho).
        s_dim = "s_w" if "s_w" in field.dims else "s_rho"
        if s is not None and s_dim not in field.dims:
            raise ValueError(
                f"ds[{varname}] has no vertical dimension; s cannot be specified."
            )

        if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
            interface_depth = self.ds.interface_depth_rho
        elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
            interface_depth = self.ds.interface_depth_u
        elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
            interface_depth = self.ds.interface_depth_v

        # Slice the field as desired and build the title along the way.
        title = field.long_name
        if s is not None:
            title = title + f", {s_dim} = {field[s_dim][s].item()}"
            field = field.isel({s_dim: s})
        else:
            # Depth contours only make sense on horizontal slices.
            depth_contours = False

        if eta is not None:
            if "eta_rho" in field.dims:
                title = title + f", eta_rho = {field.eta_rho[eta].item()}"
                field = field.isel(eta_rho=eta)
                interface_depth = interface_depth.isel(eta_rho=eta)
            elif "eta_v" in field.dims:
                title = title + f", eta_v = {field.eta_v[eta].item()}"
                field = field.isel(eta_v=eta)
                interface_depth = interface_depth.isel(eta_v=eta)
            else:
                raise ValueError(
                    f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
                )
        if xi is not None:
            if "xi_rho" in field.dims:
                title = title + f", xi_rho = {field.xi_rho[xi].item()}"
                field = field.isel(xi_rho=xi)
                interface_depth = interface_depth.isel(xi_rho=xi)
            elif "xi_u" in field.dims:
                title = title + f", xi_u = {field.xi_u[xi].item()}"
                field = field.isel(xi_u=xi)
                interface_depth = interface_depth.isel(xi_u=xi)
            else:
                raise ValueError(
                    f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
                )

        # Choose colorbar: symmetric diverging map for signed quantities,
        # sequential map otherwise.
        if varname in ["u", "v", "w", "ubar", "vbar", "zeta"]:
            vmax = max(field.max().values, -field.min().values)
            vmin = -vmax
            cmap = plt.colormaps.get_cmap("RdBu_r")
        else:
            vmax = field.max().values
            vmin = field.min().values
            cmap = plt.colormaps.get_cmap("YlOrRd")
        cmap.set_bad(color="gray")
        kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

        if eta is None and xi is None:
            _plot(
                self.grid.ds,
                field=field,
                straddle=self.grid.straddle,
                depth_contours=depth_contours,
                title=title,
                kwargs=kwargs,
                c="g",
            )
        else:
            if not layer_contours:
                interface_depth = None
            else:
                # Restrict the number of layer contours to 10 for plot clarity.
                nr_layers = len(interface_depth["s_w"])
                selected_layers = np.linspace(
                    0, nr_layers - 1, min(nr_layers, 10), dtype=int
                )
                interface_depth = interface_depth.isel(s_w=selected_layers)

            if len(field.dims) == 2:
                _section_plot(
                    field, interface_depth=interface_depth, title=title, kwargs=kwargs
                )
            else:
                if "s_rho" in field.dims:
                    _profile_plot(field, title=title)
                else:
                    _line_plot(field, title=title)

    def save(self, filepath: str) -> None:
        """
        Save the initial conditions information to a netCDF4 file.

        Parameters
        ----------
        filepath : str
            Path of the netCDF file to write.
        """
        self.ds.to_netcdf(filepath)

    def to_yaml(self, filepath: str) -> None:
        """
        Export the parameters of the class to a YAML file, including the version of roms-tools.

        Parameters
        ----------
        filepath : str
            The path to the YAML file where the parameters will be saved.
        """
        # Serialize Grid data; exclude non-serializable / derived fields.
        grid_data = asdict(self.grid)
        grid_data.pop("ds", None)
        grid_data.pop("straddle", None)

        # Serialize VerticalCoordinate data; the grid is stored separately above.
        vertical_coordinate_data = asdict(self.vertical_coordinate)
        vertical_coordinate_data.pop("ds", None)
        vertical_coordinate_data.pop("grid", None)

        # Header document recording the roms-tools version.
        header = f"---\nroms_tools_version: {self._get_roms_tools_version()}\n---\n"

        yaml_data = {
            "Grid": grid_data,
            "VerticalCoordinate": vertical_coordinate_data,
            "InitialConditions": {
                "filename": self.filename,
                "ini_time": self.ini_time.isoformat(),
                "model_reference_date": self.model_reference_date.isoformat(),
                "source": self.source,
            },
        }

        with open(filepath, "w") as file:
            file.write(header)
            yaml.dump(yaml_data, file, default_flow_style=False)

    @classmethod
    def from_yaml(cls, filepath: str) -> "InitialConditions":
        """
        Create an instance of the InitialConditions class from a YAML file.

        Parameters
        ----------
        filepath : str
            The path to the YAML file from which the parameters will be read.

        Returns
        -------
        InitialConditions
            An instance of the InitialConditions class.

        Raises
        ------
        ValueError
            If the file contains no InitialConditions configuration.
        """
        with open(filepath, "r") as file:
            file_content = file.read()

        # The file may contain several YAML documents (e.g. a version header);
        # use the first one that holds an InitialConditions section.
        initial_conditions_data = None
        for doc in yaml.safe_load_all(file_content):
            if doc is None:
                continue
            if "InitialConditions" in doc:
                initial_conditions_data = doc["InitialConditions"]
                break

        if initial_conditions_data is None:
            raise ValueError(
                "No InitialConditions configuration found in the YAML file."
            )

        # Convert ISO strings to datetime. Unquoted timestamps in hand-written
        # YAML are already parsed into datetime objects by the loader, so only
        # strings are converted.
        for date_string in ["model_reference_date", "ini_time"]:
            value = initial_conditions_data.get(date_string)
            if isinstance(value, str):
                initial_conditions_data[date_string] = datetime.fromisoformat(value)

        # The VerticalCoordinate (and its Grid) are reconstructed from the same file.
        vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
        grid = vertical_coordinate.grid

        return cls(
            grid=grid,
            vertical_coordinate=vertical_coordinate,
            **initial_conditions_data,
        )
|