roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,711 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+ import yaml
4
+ import importlib.metadata
5
+ from typing import Dict
6
+ from dataclasses import dataclass, field, asdict
7
+ from roms_tools.setup.grid import Grid
8
+ from roms_tools.setup.vertical_coordinate import VerticalCoordinate
9
+ from datetime import datetime
10
+ from roms_tools.setup.datasets import Dataset
11
+ from roms_tools.setup.fill import fill_and_interpolate
12
+ from roms_tools.setup.utils import (
13
+ nan_check,
14
+ interpolate_from_rho_to_u,
15
+ interpolate_from_rho_to_v,
16
+ extrapolate_deepest_to_bottom,
17
+ )
18
+ from roms_tools.setup.plot import _section_plot, _line_plot
19
+ import calendar
20
+ import dask
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
+ @dataclass(frozen=True, kw_only=True)
25
+ class BoundaryForcing:
26
+ """
27
+ Represents boundary forcing for ROMS.
28
+
29
+ Parameters
30
+ ----------
31
+ grid : Grid
32
+ Object representing the grid information.
33
+ vertical_coordinate: VerticalCoordinate
34
+ Object representing the vertical coordinate information.
35
+ start_time : datetime
36
+ Start time of the desired boundary forcing data.
37
+ end_time : datetime
38
+ End time of the desired boundary forcing data.
39
+ boundaries : Dict[str, bool], optional
40
+ Dictionary specifying which boundaries are forced (south, east, north, west). Default is all True.
41
+ model_reference_date : datetime, optional
42
+ Reference date for the model. Default is January 1, 2000.
43
+ source : str, optional
44
+ Source of the boundary forcing data. Default is "glorys".
45
+ filename: str
46
+ Path to the source data file. Can contain wildcards.
47
+
48
+ Attributes
49
+ ----------
50
+ ds : xr.Dataset
51
+ Xarray Dataset containing the atmospheric forcing data.
52
+
53
+ Notes
54
+ -----
55
+ This class represents atmospheric forcing data used in ocean modeling. It provides a convenient
56
+ interface to work with forcing data including shortwave radiation correction and river forcing.
57
+ """
58
+
59
+ grid: Grid
60
+ vertical_coordinate: VerticalCoordinate
61
+ start_time: datetime
62
+ end_time: datetime
63
+ boundaries: Dict[str, bool] = field(
64
+ default_factory=lambda: {
65
+ "south": True,
66
+ "east": True,
67
+ "north": True,
68
+ "west": True,
69
+ }
70
+ )
71
+ model_reference_date: datetime = datetime(2000, 1, 1)
72
+ source: str = "glorys"
73
+ filename: str
74
+ ds: xr.Dataset = field(init=False, repr=False)
75
+
76
+ def __post_init__(self):
77
+
78
+ lon = self.grid.ds.lon_rho
79
+ lat = self.grid.ds.lat_rho
80
+ angle = self.grid.ds.angle
81
+
82
+ if self.source == "glorys":
83
+ dims = {
84
+ "longitude": "longitude",
85
+ "latitude": "latitude",
86
+ "depth": "depth",
87
+ "time": "time",
88
+ }
89
+
90
+ varnames = {
91
+ "temp": "thetao",
92
+ "salt": "so",
93
+ "u": "uo",
94
+ "v": "vo",
95
+ "ssh": "zos",
96
+ }
97
+ data = Dataset(
98
+ filename=self.filename,
99
+ start_time=self.start_time,
100
+ end_time=self.end_time,
101
+ var_names=varnames.values(),
102
+ dim_names=dims,
103
+ )
104
+
105
+ # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in lontitude away from Greenwich meridian
106
+ lon = xr.where(lon > 180, lon - 360, lon)
107
+ straddle = True
108
+ if not self.grid.straddle and abs(lon).min() > 5:
109
+ lon = xr.where(lon < 0, lon + 360, lon)
110
+ straddle = False
111
+
112
+ # Restrict data to relevant subdomain to achieve better performance and to avoid discontinuous longitudes introduced by converting
113
+ # to a different longitude range (+- 360 degrees). Discontinues longitudes can lead to artifacts in the interpolation process that
114
+ # would not be detected by the nan_check function.
115
+ data.choose_subdomain(
116
+ latitude_range=[lat.min().values, lat.max().values],
117
+ longitude_range=[lon.min().values, lon.max().values],
118
+ margin=2,
119
+ straddle=straddle,
120
+ )
121
+
122
+ # extrapolate deepest value all the way to bottom ("flooding") to prepare for 3d interpolation
123
+ for var in ["temp", "salt", "u", "v"]:
124
+ data.ds[varnames[var]] = extrapolate_deepest_to_bottom(
125
+ data.ds[varnames[var]], dims["depth"]
126
+ )
127
+
128
+ # interpolate onto desired grid
129
+ fill_dims = [dims["latitude"], dims["longitude"]]
130
+
131
+ # 2d interpolation
132
+ coords = {dims["latitude"]: lat, dims["longitude"]: lon}
133
+ mask = xr.where(data.ds[varnames["ssh"]].isel(time=0).isnull(), 0, 1)
134
+
135
+ ssh = fill_and_interpolate(
136
+ data.ds[varnames["ssh"]].astype(np.float64),
137
+ mask,
138
+ fill_dims=fill_dims,
139
+ coords=coords,
140
+ method="linear",
141
+ )
142
+
143
+ # 3d interpolation
144
+ coords = {
145
+ dims["latitude"]: lat,
146
+ dims["longitude"]: lon,
147
+ dims["depth"]: self.vertical_coordinate.ds["layer_depth_rho"],
148
+ }
149
+ mask = xr.where(data.ds[varnames["temp"]].isel(time=0).isnull(), 0, 1)
150
+
151
+ data_vars = {}
152
+ # setting fillvalue_interp to None means that we allow extrapolation in the
153
+ # interpolation step to avoid NaNs at the surface if the lowest depth in original
154
+ # data is greater than zero
155
+
156
+ for var in ["temp", "salt", "u", "v"]:
157
+
158
+ data_vars[var] = fill_and_interpolate(
159
+ data.ds[varnames[var]].astype(np.float64),
160
+ mask,
161
+ fill_dims=fill_dims,
162
+ coords=coords,
163
+ method="linear",
164
+ fillvalue_interp=None,
165
+ )
166
+
167
+ # rotate velocities to grid orientation
168
+ u_rot = data_vars["u"] * np.cos(angle) + data_vars["v"] * np.sin(angle)
169
+ v_rot = data_vars["v"] * np.cos(angle) - data_vars["u"] * np.sin(angle)
170
+
171
+ # interpolate to u- and v-points
172
+ u = interpolate_from_rho_to_u(u_rot)
173
+ v = interpolate_from_rho_to_v(v_rot)
174
+
175
+ # 3d masks for ROMS domain
176
+ umask = self.grid.ds.mask_u.expand_dims({"s_rho": u.s_rho})
177
+ vmask = self.grid.ds.mask_v.expand_dims({"s_rho": v.s_rho})
178
+
179
+ u = u * umask
180
+ v = v * vmask
181
+
182
+ # Compute barotropic velocity
183
+
184
+ # thicknesses
185
+ dz = -self.vertical_coordinate.ds["interface_depth_rho"].diff(dim="s_w")
186
+ dz = dz.rename({"s_w": "s_rho"})
187
+ # thicknesses at u- and v-points
188
+ dzu = interpolate_from_rho_to_u(dz)
189
+ dzv = interpolate_from_rho_to_v(dz)
190
+
191
+ ubar = (dzu * u).sum(dim="s_rho") / dzu.sum(dim="s_rho")
192
+ vbar = (dzv * v).sum(dim="s_rho") / dzv.sum(dim="s_rho")
193
+
194
+ # Boundary coordinates for rho-points
195
+ bdry_coords_rho = {
196
+ "south": {"eta_rho": 0},
197
+ "east": {"xi_rho": -1},
198
+ "north": {"eta_rho": -1},
199
+ "west": {"xi_rho": 0},
200
+ }
201
+ # How to rename the dimensions at rho-points
202
+ rename_rho = {
203
+ "south": {"xi_rho": "xi_rho_south"},
204
+ "east": {"eta_rho": "eta_rho_east"},
205
+ "north": {"xi_rho": "xi_rho_north"},
206
+ "west": {"eta_rho": "eta_rho_west"},
207
+ }
208
+
209
+ # Boundary coordinates for u-points
210
+ bdry_coords_u = {
211
+ "south": {"eta_rho": 0},
212
+ "east": {"xi_u": -1},
213
+ "north": {"eta_rho": -1},
214
+ "west": {"xi_u": 0},
215
+ }
216
+ # How to rename the dimensions at u-points
217
+ rename_u = {
218
+ "south": {"xi_u": "xi_u_south"},
219
+ "east": {"eta_rho": "eta_u_east"},
220
+ "north": {"xi_u": "xi_u_north"},
221
+ "west": {"eta_rho": "eta_u_west"},
222
+ }
223
+
224
+ # Boundary coordinates for v-points
225
+ bdry_coords_v = {
226
+ "south": {"eta_v": 0},
227
+ "east": {"xi_rho": -1},
228
+ "north": {"eta_v": -1},
229
+ "west": {"xi_rho": 0},
230
+ }
231
+ # How to rename the dimensions at v-points
232
+ rename_v = {
233
+ "south": {"xi_rho": "xi_v_south"},
234
+ "east": {"eta_v": "eta_v_east"},
235
+ "north": {"xi_rho": "xi_v_north"},
236
+ "west": {"eta_v": "eta_v_west"},
237
+ }
238
+
239
+ ds = xr.Dataset()
240
+
241
+ for direction in ["south", "east", "north", "west"]:
242
+
243
+ if self.boundaries[direction]:
244
+
245
+ ds[f"zeta_{direction}"] = (
246
+ ssh.isel(**bdry_coords_rho[direction])
247
+ .rename(**rename_rho[direction])
248
+ .astype(np.float32)
249
+ )
250
+ ds[f"zeta_{direction}"].attrs[
251
+ "long_name"
252
+ ] = f"{direction}ern boundary sea surface height"
253
+ ds[f"zeta_{direction}"].attrs["units"] = "m"
254
+
255
+ ds[f"temp_{direction}"] = (
256
+ data_vars["temp"]
257
+ .isel(**bdry_coords_rho[direction])
258
+ .rename(**rename_rho[direction])
259
+ .astype(np.float32)
260
+ )
261
+ ds[f"temp_{direction}"].attrs[
262
+ "long_name"
263
+ ] = f"{direction}ern boundary potential temperature"
264
+ ds[f"temp_{direction}"].attrs["units"] = "Celsius"
265
+
266
+ ds[f"salt_{direction}"] = (
267
+ data_vars["salt"]
268
+ .isel(**bdry_coords_rho[direction])
269
+ .rename(**rename_rho[direction])
270
+ .astype(np.float32)
271
+ )
272
+ ds[f"salt_{direction}"].attrs[
273
+ "long_name"
274
+ ] = f"{direction}ern boundary salinity"
275
+ ds[f"salt_{direction}"].attrs["units"] = "PSU"
276
+
277
+ ds[f"u_{direction}"] = (
278
+ u.isel(**bdry_coords_u[direction])
279
+ .rename(**rename_u[direction])
280
+ .astype(np.float32)
281
+ )
282
+ ds[f"u_{direction}"].attrs[
283
+ "long_name"
284
+ ] = f"{direction}ern boundary u-flux component"
285
+ ds[f"u_{direction}"].attrs["units"] = "m/s"
286
+
287
+ ds[f"v_{direction}"] = (
288
+ v.isel(**bdry_coords_v[direction])
289
+ .rename(**rename_v[direction])
290
+ .astype(np.float32)
291
+ )
292
+ ds[f"v_{direction}"].attrs[
293
+ "long_name"
294
+ ] = f"{direction}ern boundary v-flux component"
295
+ ds[f"v_{direction}"].attrs["units"] = "m/s"
296
+
297
+ ds[f"ubar_{direction}"] = (
298
+ ubar.isel(**bdry_coords_u[direction])
299
+ .rename(**rename_u[direction])
300
+ .astype(np.float32)
301
+ )
302
+ ds[f"ubar_{direction}"].attrs[
303
+ "long_name"
304
+ ] = f"{direction}ern boundary vertically integrated u-flux component"
305
+ ds[f"ubar_{direction}"].attrs["units"] = "m/s"
306
+
307
+ ds[f"vbar_{direction}"] = (
308
+ vbar.isel(**bdry_coords_v[direction])
309
+ .rename(**rename_v[direction])
310
+ .astype(np.float32)
311
+ )
312
+ ds[f"vbar_{direction}"].attrs[
313
+ "long_name"
314
+ ] = f"{direction}ern boundary vertically integrated v-flux component"
315
+ ds[f"vbar_{direction}"].attrs["units"] = "m/s"
316
+
317
+ # assign the correct depth coordinates
318
+
319
+ lat_rho = self.grid.ds.lat_rho.isel(
320
+ **bdry_coords_rho[direction]
321
+ ).rename(**rename_rho[direction])
322
+ lon_rho = self.grid.ds.lon_rho.isel(
323
+ **bdry_coords_rho[direction]
324
+ ).rename(**rename_rho[direction])
325
+ layer_depth_rho = (
326
+ self.vertical_coordinate.ds["layer_depth_rho"]
327
+ .isel(**bdry_coords_rho[direction])
328
+ .rename(**rename_rho[direction])
329
+ )
330
+ interface_depth_rho = (
331
+ self.vertical_coordinate.ds["interface_depth_rho"]
332
+ .isel(**bdry_coords_rho[direction])
333
+ .rename(**rename_rho[direction])
334
+ )
335
+
336
+ lat_u = self.grid.ds.lat_u.isel(**bdry_coords_u[direction]).rename(
337
+ **rename_u[direction]
338
+ )
339
+ lon_u = self.grid.ds.lon_u.isel(**bdry_coords_u[direction]).rename(
340
+ **rename_u[direction]
341
+ )
342
+ layer_depth_u = (
343
+ self.vertical_coordinate.ds["layer_depth_u"]
344
+ .isel(**bdry_coords_u[direction])
345
+ .rename(**rename_u[direction])
346
+ )
347
+ interface_depth_u = (
348
+ self.vertical_coordinate.ds["interface_depth_u"]
349
+ .isel(**bdry_coords_u[direction])
350
+ .rename(**rename_u[direction])
351
+ )
352
+
353
+ lat_v = self.grid.ds.lat_v.isel(**bdry_coords_v[direction]).rename(
354
+ **rename_v[direction]
355
+ )
356
+ lon_v = self.grid.ds.lon_v.isel(**bdry_coords_v[direction]).rename(
357
+ **rename_v[direction]
358
+ )
359
+ layer_depth_v = (
360
+ self.vertical_coordinate.ds["layer_depth_v"]
361
+ .isel(**bdry_coords_v[direction])
362
+ .rename(**rename_v[direction])
363
+ )
364
+ interface_depth_v = (
365
+ self.vertical_coordinate.ds["interface_depth_v"]
366
+ .isel(**bdry_coords_v[direction])
367
+ .rename(**rename_v[direction])
368
+ )
369
+
370
+ ds = ds.assign_coords(
371
+ {
372
+ f"layer_depth_rho_{direction}": layer_depth_rho,
373
+ f"layer_depth_u_{direction}": layer_depth_u,
374
+ f"layer_depth_v_{direction}": layer_depth_v,
375
+ f"interface_depth_rho_{direction}": interface_depth_rho,
376
+ f"interface_depth_u_{direction}": interface_depth_u,
377
+ f"interface_depth_v_{direction}": interface_depth_v,
378
+ f"lat_rho_{direction}": lat_rho,
379
+ f"lat_u_{direction}": lat_u,
380
+ f"lat_v_{direction}": lat_v,
381
+ f"lon_rho_{direction}": lon_rho,
382
+ f"lon_u_{direction}": lon_u,
383
+ f"lon_v_{direction}": lon_v,
384
+ }
385
+ )
386
+
387
+ ds = ds.drop_vars(
388
+ [
389
+ "layer_depth_rho",
390
+ "layer_depth_u",
391
+ "layer_depth_v",
392
+ "interface_depth_rho",
393
+ "interface_depth_u",
394
+ "interface_depth_v",
395
+ "lat_rho",
396
+ "lon_rho",
397
+ "lat_u",
398
+ "lon_u",
399
+ "lat_v",
400
+ "lon_v",
401
+ "s_rho",
402
+ ]
403
+ )
404
+
405
+ # deal with time dimension
406
+ if dims["time"] != "time":
407
+ ds = ds.rename({dims["time"]: "time"})
408
+
409
+ # Translate the time coordinate to days since the model reference date
410
+ # TODO: Check if we need to convert from 12:00:00 to 00:00:00 as in matlab scripts
411
+ model_reference_date = np.datetime64(self.model_reference_date)
412
+
413
+ # Convert the time coordinate to the format expected by ROMS (days since model reference date)
414
+ bry_time = ds["time"] - model_reference_date
415
+ ds = ds.assign_coords(bry_time=("time", bry_time.data))
416
+ ds["bry_time"].attrs[
417
+ "long_name"
418
+ ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
419
+
420
+ ds["theta_s"] = self.vertical_coordinate.ds["theta_s"]
421
+ ds["theta_b"] = self.vertical_coordinate.ds["theta_b"]
422
+ ds["Tcline"] = self.vertical_coordinate.ds["Tcline"]
423
+ ds["hc"] = self.vertical_coordinate.ds["hc"]
424
+ ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
425
+ ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
426
+
427
+ ds.attrs["title"] = "ROMS boundary forcing file created by ROMS-Tools"
428
+ # Include the version of roms-tools
429
+ try:
430
+ roms_tools_version = importlib.metadata.version("roms-tools")
431
+ except importlib.metadata.PackageNotFoundError:
432
+ roms_tools_version = "unknown"
433
+ ds.attrs["roms_tools_version"] = roms_tools_version
434
+ ds.attrs["start_time"] = str(self.start_time)
435
+ ds.attrs["end_time"] = str(self.end_time)
436
+ ds.attrs["model_reference_date"] = str(self.model_reference_date)
437
+ ds.attrs["source"] = self.source
438
+
439
+ object.__setattr__(self, "ds", ds)
440
+
441
+ for direction in ["south", "east", "north", "west"]:
442
+ nan_check(
443
+ ds[f"zeta_{direction}"].isel(time=0),
444
+ self.grid.ds.mask_rho.isel(**bdry_coords_rho[direction]),
445
+ )
446
+
447
+ def plot(
448
+ self,
449
+ varname,
450
+ time=0,
451
+ layer_contours=False,
452
+ ) -> None:
453
+ """
454
+ Plot the boundary forcing field for a given time-slice.
455
+
456
+ Parameters
457
+ ----------
458
+ varname : str
459
+ The name of the initial conditions field to plot. Options include:
460
+ - "temp_{direction}": Potential temperature.
461
+ - "salt_{direction}": Salinity.
462
+ - "zeta_{direction}": Sea surface height.
463
+ - "u_{direction}": u-flux component.
464
+ - "v_{direction}": v-flux component.
465
+ - "ubar_{direction}": Vertically integrated u-flux component.
466
+ - "vbar_{direction}": Vertically integrated v-flux component.
467
+ where {direction} can be one of ["south", "east", "north", "west"].
468
+ time : int, optional
469
+ The time index to plot. Default is 0.
470
+ layer_contours : bool, optional
471
+ Whether to include layer contours in the plot. This can help visualize the depth levels
472
+ of the field. Default is False.
473
+
474
+ Returns
475
+ -------
476
+ None
477
+ This method does not return any value. It generates and displays a plot.
478
+
479
+ Raises
480
+ ------
481
+ ValueError
482
+ If the specified varname is not one of the valid options.
483
+ """
484
+
485
+ field = self.ds[varname].isel(time=time).load()
486
+
487
+ title = field.long_name
488
+
489
+ # chose colorbar
490
+ if varname.startswith(("u", "v", "ubar", "vbar", "zeta")):
491
+ vmax = max(field.max().values, -field.min().values)
492
+ vmin = -vmax
493
+ cmap = plt.colormaps.get_cmap("RdBu_r")
494
+ else:
495
+ vmax = field.max().values
496
+ vmin = field.min().values
497
+ cmap = plt.colormaps.get_cmap("YlOrRd")
498
+ cmap.set_bad(color="gray")
499
+ kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
500
+
501
+ if len(field.dims) == 2:
502
+ if layer_contours:
503
+ depths_to_check = [
504
+ "interface_depth_rho",
505
+ "interface_depth_u",
506
+ "interface_depth_v",
507
+ ]
508
+ try:
509
+ interface_depth = next(
510
+ self.ds[depth_label]
511
+ for depth_label in self.ds.coords
512
+ if any(
513
+ depth_label.startswith(prefix) for prefix in depths_to_check
514
+ )
515
+ and (
516
+ set(self.ds[depth_label].dims) - {"s_w"}
517
+ == set(field.dims) - {"s_rho"}
518
+ )
519
+ )
520
+ except StopIteration:
521
+ raise ValueError(
522
+ f"None of the expected depths ({', '.join(depths_to_check)}) have dimensions matching field.dims"
523
+ )
524
+ # restrict number of layer_contours to 10 for the sake of plot clearity
525
+ nr_layers = len(interface_depth["s_w"])
526
+ selected_layers = np.linspace(
527
+ 0, nr_layers - 1, min(nr_layers, 10), dtype=int
528
+ )
529
+ interface_depth = interface_depth.isel(s_w=selected_layers)
530
+
531
+ else:
532
+ interface_depth = None
533
+
534
+ _section_plot(
535
+ field, interface_depth=interface_depth, title=title, kwargs=kwargs
536
+ )
537
+ else:
538
+ _line_plot(field, title=title)
539
+
540
+ def save(self, filepath: str, time_chunk_size: int = 1) -> None:
541
+ """
542
+ Save the interpolated boundary forcing fields to netCDF4 files.
543
+
544
+ This method groups the dataset by year and month, chunks the data by the specified
545
+ time chunk size, and saves each chunked subset to a separate netCDF4 file named
546
+ according to the year, month, and day range if not a complete month of data is included.
547
+
548
+ Parameters
549
+ ----------
550
+ filepath : str
551
+ The base path and filename for the output files. The files will be named with
552
+ the format "filepath.YYYYMM.nc" if a full month of data is included, or
553
+ "filepath.YYYYMMDD-DD.nc" otherwise.
554
+ time_chunk_size : int, optional
555
+ Number of time slices to include in each chunk along the time dimension. Default is 1,
556
+ meaning each chunk contains one time slice.
557
+
558
+ Returns
559
+ -------
560
+ None
561
+ """
562
+ datasets = []
563
+ filenames = []
564
+ writes = []
565
+
566
+ # Group dataset by year
567
+ gb = self.ds.groupby("time.year")
568
+
569
+ for year, group_ds in gb:
570
+ # Further group each yearly group by month
571
+ sub_gb = group_ds.groupby("time.month")
572
+
573
+ for month, ds in sub_gb:
574
+ # Chunk the dataset by the specified time chunk size
575
+ ds = ds.chunk({"time": time_chunk_size})
576
+ datasets.append(ds)
577
+
578
+ # Determine the number of days in the month
579
+ num_days_in_month = calendar.monthrange(year, month)[1]
580
+ first_day = ds.time.dt.day.values[0]
581
+ last_day = ds.time.dt.day.values[-1]
582
+
583
+ # Create filename based on whether the dataset contains a full month
584
+ if first_day == 1 and last_day == num_days_in_month:
585
+ # Full month format: "filepath.YYYYMM.nc"
586
+ year_month_str = f"{year}{month:02}"
587
+ filename = f"{filepath}.{year_month_str}.nc"
588
+ else:
589
+ # Partial month format: "filepath.YYYYMMDD-DD.nc"
590
+ year_month_day_str = f"{year}{month:02}{first_day:02}-{last_day:02}"
591
+ filename = f"{filepath}.{year_month_day_str}.nc"
592
+ filenames.append(filename)
593
+
594
+ print("Saving the following files:")
595
+ for filename in filenames:
596
+ print(filename)
597
+
598
+ for ds, filename in zip(datasets, filenames):
599
+
600
+ # Prepare the dataset for writing to a netCDF file without immediately computing
601
+ write = ds.to_netcdf(filename, compute=False)
602
+ writes.append(write)
603
+
604
+ # Perform the actual write operations in parallel
605
+ dask.compute(*writes)
606
+
607
+ def to_yaml(self, filepath: str) -> None:
608
+ """
609
+ Export the parameters of the class to a YAML file, including the version of roms-tools.
610
+
611
+ Parameters
612
+ ----------
613
+ filepath : str
614
+ The path to the YAML file where the parameters will be saved.
615
+ """
616
+ # Serialize Grid data
617
+ grid_data = asdict(self.grid)
618
+ grid_data.pop("ds", None) # Exclude non-serializable fields
619
+ grid_data.pop("straddle", None)
620
+
621
+ # Serialize VerticalCoordinate data
622
+ vertical_coordinate_data = asdict(self.vertical_coordinate)
623
+ vertical_coordinate_data.pop("ds", None) # Exclude non-serializable fields
624
+ vertical_coordinate_data.pop("grid", None) # Exclude non-serializable fields
625
+
626
+ # Include the version of roms-tools
627
+ try:
628
+ roms_tools_version = importlib.metadata.version("roms-tools")
629
+ except importlib.metadata.PackageNotFoundError:
630
+ roms_tools_version = "unknown"
631
+
632
+ # Create header
633
+ header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
634
+
635
+ grid_yaml_data = {"Grid": grid_data}
636
+ vertical_coordinate_yaml_data = {"VerticalCoordinate": vertical_coordinate_data}
637
+
638
+ boundary_forcing_data = {
639
+ "BoundaryForcing": {
640
+ "filename": self.filename,
641
+ "start_time": self.start_time.isoformat(),
642
+ "end_time": self.end_time.isoformat(),
643
+ "model_reference_date": self.model_reference_date.isoformat(),
644
+ "source": self.source,
645
+ "boundaries": self.boundaries,
646
+ }
647
+ }
648
+
649
+ yaml_data = {
650
+ **grid_yaml_data,
651
+ **vertical_coordinate_yaml_data,
652
+ **boundary_forcing_data,
653
+ }
654
+
655
+ with open(filepath, "w") as file:
656
+ # Write header
657
+ file.write(header)
658
+ # Write YAML data
659
+ yaml.dump(yaml_data, file, default_flow_style=False)
660
+
661
+ @classmethod
662
+ def from_yaml(cls, filepath: str) -> "BoundaryForcing":
663
+ """
664
+ Create an instance of the BoundaryForcing class from a YAML file.
665
+
666
+ Parameters
667
+ ----------
668
+ filepath : str
669
+ The path to the YAML file from which the parameters will be read.
670
+
671
+ Returns
672
+ -------
673
+ BoundaryForcing
674
+ An instance of the BoundaryForcing class.
675
+ """
676
+ # Read the entire file content
677
+ with open(filepath, "r") as file:
678
+ file_content = file.read()
679
+
680
+ # Split the content into YAML documents
681
+ documents = list(yaml.safe_load_all(file_content))
682
+
683
+ boundary_forcing_data = None
684
+
685
+ # Process the YAML documents
686
+ for doc in documents:
687
+ if doc is None:
688
+ continue
689
+ if "BoundaryForcing" in doc:
690
+ boundary_forcing_data = doc["BoundaryForcing"]
691
+ break
692
+
693
+ if boundary_forcing_data is None:
694
+ raise ValueError("No BoundaryForcing configuration found in the YAML file.")
695
+
696
+ # Convert from string to datetime
697
+ for date_string in ["model_reference_date", "start_time", "end_time"]:
698
+ boundary_forcing_data[date_string] = datetime.fromisoformat(
699
+ boundary_forcing_data[date_string]
700
+ )
701
+
702
+ # Create VerticalCoordinate instance from the YAML file
703
+ vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
704
+ grid = vertical_coordinate.grid
705
+
706
+ # Create and return an instance of InitialConditions
707
+ return cls(
708
+ grid=grid,
709
+ vertical_coordinate=vertical_coordinate,
710
+ **boundary_forcing_data,
711
+ )