roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,757 @@
+ import xarray as xr
+ import numpy as np
+ import pandas as pd
+ import yaml
+ from datatree import DataTree
+ import importlib.metadata
+ from typing import Dict, Union, Optional
+ from dataclasses import dataclass, field, asdict
+ from roms_tools.setup.grid import Grid
+ from roms_tools.setup.vertical_coordinate import VerticalCoordinate
+ from roms_tools.setup.mixins import ROMSToolsMixins
+ from datetime import datetime
+ from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
+ from roms_tools.setup.utils import (
+     nan_check,
+ )
+ from roms_tools.setup.plot import _section_plot, _line_plot
+ import calendar
+ import dask
+ import matplotlib.pyplot as plt
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class BoundaryForcing(ROMSToolsMixins):
+     """
+     Represents boundary forcing for ROMS.
+
+     Parameters
+     ----------
+     grid : Grid
+         Object representing the grid information.
+     vertical_coordinate : VerticalCoordinate
+         Object representing the vertical coordinate information.
+     start_time : datetime
+         Start time of the desired boundary forcing data.
+     end_time : datetime
+         End time of the desired boundary forcing data.
+     boundaries : Dict[str, bool], optional
+         Dictionary specifying which boundaries are forced (south, east, north, west). Default is all True.
+     physics_source : Dict[str, Union[str, None]]
+         Dictionary specifying the source of the physical boundary forcing data:
+         - "name" (str): Name of the data source (e.g., "GLORYS").
+         - "path" (str): Path to the physical data file. Can contain wildcards.
+         - "climatology" (bool): Indicates if the physical data is climatology data. Defaults to False.
+     bgc_source : Optional[Dict[str, Union[str, None]]]
+         Dictionary specifying the source of the biogeochemical (BGC) boundary forcing data:
+         - "name" (str): Name of the BGC data source (e.g., "CESM_REGRIDDED").
+         - "path" (str): Path to the BGC data file. Can contain wildcards.
+         - "climatology" (bool): Indicates if the BGC data is climatology data. Defaults to True.
+     model_reference_date : datetime, optional
+         Reference date for the model. Default is January 1, 2000.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         Xarray Dataset containing the boundary forcing data.
+
+     Examples
+     --------
+     >>> boundary_forcing = BoundaryForcing(
+     ...     grid=grid,
+     ...     vertical_coordinate=vertical_coordinate,
+     ...     boundaries={"south": True, "east": True, "north": False, "west": True},
+     ...     start_time=datetime(2022, 1, 1),
+     ...     end_time=datetime(2022, 1, 2),
+     ...     physics_source={"name": "GLORYS", "path": "physics_data.nc"},
+     ...     bgc_source={
+     ...         "name": "CESM_REGRIDDED",
+     ...         "path": "bgc_data.nc",
+     ...         "climatology": True,
+     ...     },
+     ... )
+     """
+
+     grid: Grid
+     vertical_coordinate: VerticalCoordinate
+     start_time: datetime
+     end_time: datetime
+     boundaries: Dict[str, bool] = field(
+         default_factory=lambda: {
+             "south": True,
+             "east": True,
+             "north": True,
+             "west": True,
+         }
+     )
+     physics_source: Dict[str, Union[str, None]]
+     bgc_source: Optional[Dict[str, Union[str, None]]] = None
+     model_reference_date: datetime = datetime(2000, 1, 1)
+
+     ds: xr.Dataset = field(init=False, repr=False)
+
+     def __post_init__(self):
+
+         self._input_checks()
+         lon, lat, angle, straddle = super().get_target_lon_lat()
+
+         data = self._get_data()
+         data.choose_subdomain(
+             latitude_range=[lat.min().values, lat.max().values],
+             longitude_range=[lon.min().values, lon.max().values],
+             margin=2,
+             straddle=straddle,
+         )
+
+         vars_2d = ["zeta"]
+         vars_3d = ["temp", "salt", "u", "v"]
+         data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
+         data_vars = super().process_velocities(data_vars, angle)
+         object.__setattr__(data, "data_vars", data_vars)
+
+         if self.bgc_source is not None:
+             bgc_data = self._get_bgc_data()
+             bgc_data.choose_subdomain(
+                 latitude_range=[lat.min().values, lat.max().values],
+                 longitude_range=[lon.min().values, lon.max().values],
+                 margin=2,
+                 straddle=straddle,
+             )
+
+             vars_2d = []
+             vars_3d = bgc_data.var_names.keys()
+             data_vars = super().regrid_data(bgc_data, vars_2d, vars_3d, lon, lat)
+             object.__setattr__(bgc_data, "data_vars", data_vars)
+         else:
+             bgc_data = None
+
+         d_meta = super().get_variable_metadata()
+         bdry_coords, rename = super().get_boundary_info()
+
+         ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords, rename)
+
+         for direction in ["south", "east", "north", "west"]:
+             if self.boundaries[direction]:
+                 nan_check(
+                     ds["physics"][f"zeta_{direction}"].isel(bry_time=0),
+                     self.grid.ds.mask_rho.isel(**bdry_coords["rho"][direction]),
+                 )
+
+         object.__setattr__(self, "ds", ds)
+
+     def _input_checks(self):
+
+         if "name" not in self.physics_source.keys():
+             raise ValueError("`physics_source` must include a 'name'.")
+         if "path" not in self.physics_source.keys():
+             raise ValueError("`physics_source` must include a 'path'.")
+         # set self.physics_source["climatology"] to False if not provided
+         object.__setattr__(
+             self,
+             "physics_source",
+             {
+                 **self.physics_source,
+                 "climatology": self.physics_source.get("climatology", False),
+             },
+         )
+         if self.bgc_source is not None:
+             if "name" not in self.bgc_source.keys():
+                 raise ValueError(
+                     "`bgc_source` must include a 'name' if it is provided."
+                 )
+             if "path" not in self.bgc_source.keys():
+                 raise ValueError(
+                     "`bgc_source` must include a 'path' if it is provided."
+                 )
+             # set self.bgc_source["climatology"] to True if not provided
+             object.__setattr__(
+                 self,
+                 "bgc_source",
+                 {
+                     **self.bgc_source,
+                     "climatology": self.bgc_source.get("climatology", True),
+                 },
+             )
+
+     def _get_data(self):
+
+         if self.physics_source["name"] == "GLORYS":
+             data = GLORYSDataset(
+                 filename=self.physics_source["path"],
+                 start_time=self.start_time,
+                 end_time=self.end_time,
+                 climatology=self.physics_source["climatology"],
+             )
+         else:
+             raise ValueError(
+                 'Only "GLORYS" is a valid option for physics_source["name"].'
+             )
+
+         return data
+
+     def _get_bgc_data(self):
+
+         if self.bgc_source["name"] == "CESM_REGRIDDED":
+
+             data = CESMBGCDataset(
+                 filename=self.bgc_source["path"],
+                 start_time=self.start_time,
+                 end_time=self.end_time,
+                 climatology=self.bgc_source["climatology"],
+             )
+             data.post_process()
+         else:
+             raise ValueError(
+                 'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
+             )
+
+         return data
+
+     def _write_into_dataset(self, data, d_meta, bdry_coords, rename):
+
+         # save in new dataset
+         ds = xr.Dataset()
+
+         for direction in ["south", "east", "north", "west"]:
+             if self.boundaries[direction]:
+
+                 for var in data.data_vars.keys():
+                     if var in ["u", "ubar"]:
+                         ds[f"{var}_{direction}"] = (
+                             data.data_vars[var]
+                             .isel(**bdry_coords["u"][direction])
+                             .rename(**rename["u"][direction])
+                             .astype(np.float32)
+                         )
+                     elif var in ["v", "vbar"]:
+                         ds[f"{var}_{direction}"] = (
+                             data.data_vars[var]
+                             .isel(**bdry_coords["v"][direction])
+                             .rename(**rename["v"][direction])
+                             .astype(np.float32)
+                         )
+                     else:
+                         ds[f"{var}_{direction}"] = (
+                             data.data_vars[var]
+                             .isel(**bdry_coords["rho"][direction])
+                             .rename(**rename["rho"][direction])
+                             .astype(np.float32)
+                         )
+                     ds[f"{var}_{direction}"].attrs[
+                         "long_name"
+                     ] = f"{direction}ern boundary {d_meta[var]['long_name']}"
+                     ds[f"{var}_{direction}"].attrs["units"] = d_meta[var]["units"]
+
+         # Gracefully handle dropping variables that might not be present
+         variables_to_drop = [
+             "s_rho",
+             "layer_depth_rho",
+             "layer_depth_u",
+             "layer_depth_v",
+             "interface_depth_rho",
+             "interface_depth_u",
+             "interface_depth_v",
+             "lat_rho",
+             "lon_rho",
+             "lat_u",
+             "lon_u",
+             "lat_v",
+             "lon_v",
+         ]
+         existing_vars = [var for var in variables_to_drop if var in ds]
+         ds = ds.drop_vars(existing_vars)
+
+         # Preserve absolute time coordinate for readability
+         ds = ds.assign_coords({"abs_time": ds["time"]})
+
+         # Convert the time coordinate to the format expected by ROMS
+         if data.climatology:
+             # Convert to pandas TimedeltaIndex
+             timedelta_index = pd.to_timedelta(ds["time"].values)
+             # Determine the start of the year for the base_datetime
+             start_of_year = datetime(self.model_reference_date.year, 1, 1)
+             # Calculate the offset from midnight of the new year
+             offset = self.model_reference_date - start_of_year
+             bry_time = xr.DataArray(
+                 timedelta_index - offset,
+                 dims="time",
+             )
+         else:
+             # TODO: Check if we need to convert from 12:00:00 to 00:00:00 as in matlab scripts
+             bry_time = ds["time"] - np.datetime64(self.model_reference_date)
+
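+         # Note: with the default model_reference_date of January 1, the offset computed above is
+         # zero, so climatological bry_time values equal the climatological timedeltas; in the
+         # non-climatology branch, bry_time is simply the elapsed time since the reference date.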
+         ds = ds.assign_coords({"bry_time": bry_time})
+         ds["bry_time"].attrs[
+             "long_name"
+         ] = f"nanoseconds since {np.datetime_as_string(np.datetime64(self.model_reference_date), unit='ns')}"
+         ds["bry_time"].encoding["units"] = "nanoseconds"
+         ds = ds.swap_dims({"time": "bry_time"})
+         ds = ds.drop_vars("time")
+
+         if data.climatology:
+             ds["bry_time"].attrs["cycle_length"] = 365.25
+
+         return ds
+
+     def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords, rename):
+
+         ds = self._add_global_metadata()
+         ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
+         ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
+
+         ds = DataTree(name="root", data=ds)
+
+         ds_physics = self._write_into_dataset(data, d_meta, bdry_coords, rename)
+         ds_physics = self._add_coordinates(bdry_coords, rename, ds_physics)
+         ds_physics = self._add_global_metadata(ds_physics)
+         ds_physics.attrs["physics_source"] = self.physics_source["name"]
+
+         ds_physics = DataTree(name="physics", parent=ds, data=ds_physics)
+
+         if bgc_data:
+             ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords, rename)
+             ds_bgc = self._add_coordinates(bdry_coords, rename, ds_bgc)
+             ds_bgc = self._add_global_metadata(ds_bgc)
+             ds_bgc.attrs["bgc_source"] = self.bgc_source["name"]
+             ds_bgc = DataTree(name="bgc", parent=ds, data=ds_bgc)
+
+         return ds
+
+     def _add_coordinates(self, bdry_coords, rename, ds=None):
+
+         if ds is None:
+             ds = xr.Dataset()
+
+         for direction in ["south", "east", "north", "west"]:
+
+             if self.boundaries[direction]:
+
+                 lat_rho = self.grid.ds.lat_rho.isel(
+                     **bdry_coords["rho"][direction]
+                 ).rename(**rename["rho"][direction])
+                 lon_rho = self.grid.ds.lon_rho.isel(
+                     **bdry_coords["rho"][direction]
+                 ).rename(**rename["rho"][direction])
+                 layer_depth_rho = (
+                     self.vertical_coordinate.ds["layer_depth_rho"]
+                     .isel(**bdry_coords["rho"][direction])
+                     .rename(**rename["rho"][direction])
+                 )
+                 interface_depth_rho = (
+                     self.vertical_coordinate.ds["interface_depth_rho"]
+                     .isel(**bdry_coords["rho"][direction])
+                     .rename(**rename["rho"][direction])
+                 )
+
+                 lat_u = self.grid.ds.lat_u.isel(**bdry_coords["u"][direction]).rename(
+                     **rename["u"][direction]
+                 )
+                 lon_u = self.grid.ds.lon_u.isel(**bdry_coords["u"][direction]).rename(
+                     **rename["u"][direction]
+                 )
+                 layer_depth_u = (
+                     self.vertical_coordinate.ds["layer_depth_u"]
+                     .isel(**bdry_coords["u"][direction])
+                     .rename(**rename["u"][direction])
+                 )
+                 interface_depth_u = (
+                     self.vertical_coordinate.ds["interface_depth_u"]
+                     .isel(**bdry_coords["u"][direction])
+                     .rename(**rename["u"][direction])
+                 )
+
+                 lat_v = self.grid.ds.lat_v.isel(**bdry_coords["v"][direction]).rename(
+                     **rename["v"][direction]
+                 )
+                 lon_v = self.grid.ds.lon_v.isel(**bdry_coords["v"][direction]).rename(
+                     **rename["v"][direction]
+                 )
+                 layer_depth_v = (
+                     self.vertical_coordinate.ds["layer_depth_v"]
+                     .isel(**bdry_coords["v"][direction])
+                     .rename(**rename["v"][direction])
+                 )
+                 interface_depth_v = (
+                     self.vertical_coordinate.ds["interface_depth_v"]
+                     .isel(**bdry_coords["v"][direction])
+                     .rename(**rename["v"][direction])
+                 )
+
+                 ds = ds.assign_coords(
+                     {
+                         f"layer_depth_rho_{direction}": layer_depth_rho,
+                         f"layer_depth_u_{direction}": layer_depth_u,
+                         f"layer_depth_v_{direction}": layer_depth_v,
+                         f"interface_depth_rho_{direction}": interface_depth_rho,
+                         f"interface_depth_u_{direction}": interface_depth_u,
+                         f"interface_depth_v_{direction}": interface_depth_v,
+                         f"lat_rho_{direction}": lat_rho,
+                         f"lat_u_{direction}": lat_u,
+                         f"lat_v_{direction}": lat_v,
+                         f"lon_rho_{direction}": lon_rho,
+                         f"lon_u_{direction}": lon_u,
+                         f"lon_v_{direction}": lon_v,
+                     }
+                 )
+
+         # Gracefully handle dropping variables that might not be present
+         variables_to_drop = [
+             "s_rho",
+             "layer_depth_rho",
+             "layer_depth_u",
+             "layer_depth_v",
+             "interface_depth_rho",
+             "interface_depth_u",
+             "interface_depth_v",
+             "lat_rho",
+             "lon_rho",
+             "lat_u",
+             "lon_u",
+             "lat_v",
+             "lon_v",
+         ]
+         existing_vars = [var for var in variables_to_drop if var in ds]
+         ds = ds.drop_vars(existing_vars)
+
+         return ds
+
+     def _add_global_metadata(self, ds=None):
+
+         if ds is None:
+             ds = xr.Dataset()
+         ds.attrs["title"] = "ROMS boundary forcing file created by ROMS-Tools"
+         # Include the version of roms-tools
+         try:
+             roms_tools_version = importlib.metadata.version("roms-tools")
+         except importlib.metadata.PackageNotFoundError:
+             roms_tools_version = "unknown"
+         ds.attrs["roms_tools_version"] = roms_tools_version
+         ds.attrs["start_time"] = str(self.start_time)
+         ds.attrs["end_time"] = str(self.end_time)
+         ds.attrs["model_reference_date"] = str(self.model_reference_date)
+
+         ds.attrs["theta_s"] = self.vertical_coordinate.ds["theta_s"].item()
+         ds.attrs["theta_b"] = self.vertical_coordinate.ds["theta_b"].item()
+         ds.attrs["Tcline"] = self.vertical_coordinate.ds["Tcline"].item()
+         ds.attrs["hc"] = self.vertical_coordinate.ds["hc"].item()
+
+         return ds
+
+     def plot(
+         self,
+         varname,
+         time=0,
+         layer_contours=False,
+     ) -> None:
+         """
+         Plot the boundary forcing field for a given time-slice.
+
+         Parameters
+         ----------
+         varname : str
+             The name of the boundary forcing field to plot, composed of a variable prefix and a
+             boundary suffix, where {direction} is one of ["south", "east", "north", "west"].
+             Options include:
+             - "temp_{direction}": Potential temperature.
+             - "salt_{direction}": Salinity.
+             - "zeta_{direction}": Sea surface height.
+             - "u_{direction}": u-flux component.
+             - "v_{direction}": v-flux component.
+             - "ubar_{direction}": Vertically integrated u-flux component.
+             - "vbar_{direction}": Vertically integrated v-flux component.
+             - "PO4_{direction}": Dissolved Inorganic Phosphate (mmol/m³).
+             - "NO3_{direction}": Dissolved Inorganic Nitrate (mmol/m³).
+             - "SiO3_{direction}": Dissolved Inorganic Silicate (mmol/m³).
+             - "NH4_{direction}": Dissolved Ammonia (mmol/m³).
+             - "Fe_{direction}": Dissolved Inorganic Iron (mmol/m³).
+             - "Lig_{direction}": Iron Binding Ligand (mmol/m³).
+             - "O2_{direction}": Dissolved Oxygen (mmol/m³).
+             - "DIC_{direction}": Dissolved Inorganic Carbon (mmol/m³).
+             - "DIC_ALT_CO2_{direction}": Dissolved Inorganic Carbon, Alternative CO2 (mmol/m³).
+             - "ALK_{direction}": Alkalinity (meq/m³).
+             - "ALK_ALT_CO2_{direction}": Alkalinity, Alternative CO2 (meq/m³).
+             - "DOC_{direction}": Dissolved Organic Carbon (mmol/m³).
+             - "DON_{direction}": Dissolved Organic Nitrogen (mmol/m³).
+             - "DOP_{direction}": Dissolved Organic Phosphorus (mmol/m³).
+             - "DOPr_{direction}": Refractory Dissolved Organic Phosphorus (mmol/m³).
+             - "DONr_{direction}": Refractory Dissolved Organic Nitrogen (mmol/m³).
+             - "DOCr_{direction}": Refractory Dissolved Organic Carbon (mmol/m³).
+             - "zooC_{direction}": Zooplankton Carbon (mmol/m³).
+             - "spChl_{direction}": Small Phytoplankton Chlorophyll (mg/m³).
+             - "spC_{direction}": Small Phytoplankton Carbon (mmol/m³).
+             - "spP_{direction}": Small Phytoplankton Phosphorus (mmol/m³).
+             - "spFe_{direction}": Small Phytoplankton Iron (mmol/m³).
+             - "spCaCO3_{direction}": Small Phytoplankton CaCO3 (mmol/m³).
+             - "diatChl_{direction}": Diatom Chlorophyll (mg/m³).
+             - "diatC_{direction}": Diatom Carbon (mmol/m³).
+             - "diatP_{direction}": Diatom Phosphorus (mmol/m³).
+             - "diatFe_{direction}": Diatom Iron (mmol/m³).
+             - "diatSi_{direction}": Diatom Silicate (mmol/m³).
+             - "diazChl_{direction}": Diazotroph Chlorophyll (mg/m³).
+             - "diazC_{direction}": Diazotroph Carbon (mmol/m³).
+             - "diazP_{direction}": Diazotroph Phosphorus (mmol/m³).
+             - "diazFe_{direction}": Diazotroph Iron (mmol/m³).
+         time : int, optional
+             The time index to plot. Default is 0.
+         layer_contours : bool, optional
+             Whether to include layer contours in the plot. This can help visualize the depth levels
+             of the field. Default is False.
+
+         Returns
+         -------
+         None
+             This method does not return any value. It generates and displays a plot.
+
+         Raises
+         ------
+         ValueError
+             If the specified varname is not one of the valid options.
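+
+         Examples
+         --------
+         A minimal, illustrative call (it assumes the southern boundary was included
+         when this object was created):
+
+         >>> boundary_forcing.plot("temp_south", time=0, layer_contours=True)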
+         """
+
+         if varname in self.ds["physics"]:
+             ds = self.ds["physics"]
+         else:
+             if "bgc" in self.ds and varname in self.ds["bgc"]:
+                 ds = self.ds["bgc"]
+             else:
+                 raise ValueError(
+                     f"Variable '{varname}' is not found in 'physics' or 'bgc' datasets."
+                 )
+
+         field = ds[varname].isel(bry_time=time).load()
+         title = field.long_name
+
+         # choose colormap and value range
+         if varname.startswith(("u", "v", "ubar", "vbar", "zeta")):
+             vmax = max(field.max().values, -field.min().values)
+             vmin = -vmax
+             cmap = plt.colormaps.get_cmap("RdBu_r")
+         else:
+             vmax = field.max().values
+             vmin = field.min().values
+             if varname.startswith(("temp", "salt")):
+                 cmap = plt.colormaps.get_cmap("YlOrRd")
+             else:
+                 cmap = plt.colormaps.get_cmap("YlGn")
+         cmap.set_bad(color="gray")
+         kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
+
+         if len(field.dims) == 2:
+             if layer_contours:
+                 depths_to_check = [
+                     "interface_depth_rho",
+                     "interface_depth_u",
+                     "interface_depth_v",
+                 ]
+                 try:
+                     interface_depth = next(
+                         ds[depth_label]
+                         for depth_label in ds.coords
+                         if any(
+                             depth_label.startswith(prefix) for prefix in depths_to_check
+                         )
+                         and (
+                             set(ds[depth_label].dims) - {"s_w"}
+                             == set(field.dims) - {"s_rho"}
+                         )
+                     )
+                 except StopIteration:
+                     raise ValueError(
+                         f"None of the expected depths ({', '.join(depths_to_check)}) have dimensions matching field.dims"
+                     )
+                 # restrict the number of layer contours to 10 for the sake of plot clarity
+                 nr_layers = len(interface_depth["s_w"])
+                 selected_layers = np.linspace(
+                     0, nr_layers - 1, min(nr_layers, 10), dtype=int
+                 )
+                 interface_depth = interface_depth.isel(s_w=selected_layers)
+
+             else:
+                 interface_depth = None
+
+             _section_plot(
+                 field, interface_depth=interface_depth, title=title, kwargs=kwargs
+             )
+         else:
+             _line_plot(field, title=title)
+
+     def save(self, filepath: str, time_chunk_size: int = 1) -> None:
+         """
+         Save the interpolated boundary forcing fields to netCDF4 files.
+
+         This method groups the dataset by year and month, chunks the data by the specified
+         time chunk size, and saves each chunked subset to a separate netCDF4 file. The day
+         range is appended to the file name whenever a file does not cover a complete month
+         of data.
+
+         Parameters
+         ----------
+         filepath : str
+             The base path and filename for the output files. Each file is named
+             "filepath_{node}_YYYYMM.nc" if it covers a full month of data,
+             "filepath_{node}_YYYYMMDD-DD.nc" otherwise, or "filepath_{node}_clim.nc"
+             for climatological data, where {node} is "physics" or "bgc".
+         time_chunk_size : int, optional
+             Number of time slices to include in each chunk along the time dimension. Default is 1,
+             meaning each chunk contains one time slice.
+
+         Returns
+         -------
+         None
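+
+         Examples
+         --------
+         An illustrative call; the path prefix is an assumption, and the exact file names
+         depend on the time range covered by the data:
+
+         >>> boundary_forcing.save("roms_bry", time_chunk_size=1)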
+         """
+         datasets = []
+         filenames = []
+         writes = []
+
+         for node in ["physics", "bgc"]:
+             if node in self.ds:
+                 ds = self.ds[node].to_dataset()
+                 # Copy vertical coordinate variables from the parent node into each child
+                 # dataset, since ROMS likely needs this information in every forcing file.
+                 for var in self.ds.data_vars:
+                     ds[var] = self.ds[var]
+                 if hasattr(ds["bry_time"], "cycle_length"):
+                     filename = f"{filepath}_{node}_clim.nc"
+                     filenames.append(filename)
+                     datasets.append(ds)
+                 else:
+                     # Group dataset by year
+                     gb = ds.groupby("abs_time.year")
+
+                     for year, group_ds in gb:
+                         # Further group each yearly group by month
+                         sub_gb = group_ds.groupby("abs_time.month")
+
+                         for month, ds in sub_gb:
+                             # Chunk the dataset by the specified time chunk size
+                             ds = ds.chunk({"bry_time": time_chunk_size})
+                             datasets.append(ds)
+
+                             # Determine the number of days in the month
+                             num_days_in_month = calendar.monthrange(year, month)[1]
+                             first_day = ds.abs_time.dt.day.values[0]
+                             last_day = ds.abs_time.dt.day.values[-1]
+
+                             # Create filename based on whether the dataset contains a full month
+                             if first_day == 1 and last_day == num_days_in_month:
+                                 # Full month format: "filepath_physics_YYYYMM.nc"
+                                 year_month_str = f"{year}{month:02}"
+                                 filename = f"{filepath}_{node}_{year_month_str}.nc"
+                             else:
+                                 # Partial month format: "filepath_physics_YYYYMMDD-DD.nc"
+                                 year_month_day_str = (
+                                     f"{year}{month:02}{first_day:02}-{last_day:02}"
+                                 )
+                                 filename = f"{filepath}_{node}_{year_month_day_str}.nc"
+                             filenames.append(filename)
+
+         print("Saving the following files:")
+         for ds, filename in zip(datasets, filenames):
+             print(filename)
+             # Prepare the dataset for writing to a netCDF file without immediately computing
+             write = ds.to_netcdf(filename, compute=False)
+             writes.append(write)
+
+         # Perform the actual write operations in parallel
+         dask.compute(*writes)
+
+     def to_yaml(self, filepath: str) -> None:
+         """
+         Export the parameters of the class to a YAML file, including the version of roms-tools.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file where the parameters will be saved.
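+
+         Examples
+         --------
+         Illustrative usage (the file name is an assumption):
+
+         >>> boundary_forcing.to_yaml("boundary_forcing.yaml")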
+         """
+         # Serialize Grid data
+         grid_data = asdict(self.grid)
+         grid_data.pop("ds", None)  # Exclude non-serializable fields
+         grid_data.pop("straddle", None)
+
+         # Serialize VerticalCoordinate data
+         vertical_coordinate_data = asdict(self.vertical_coordinate)
+         vertical_coordinate_data.pop("ds", None)  # Exclude non-serializable fields
+         vertical_coordinate_data.pop("grid", None)  # Exclude non-serializable fields
+
+         # Include the version of roms-tools
+         try:
+             roms_tools_version = importlib.metadata.version("roms-tools")
+         except importlib.metadata.PackageNotFoundError:
+             roms_tools_version = "unknown"
+
+         # Create header
+         header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
+
+         grid_yaml_data = {"Grid": grid_data}
+         vertical_coordinate_yaml_data = {"VerticalCoordinate": vertical_coordinate_data}
+
+         boundary_forcing_data = {
+             "BoundaryForcing": {
+                 "start_time": self.start_time.isoformat(),
+                 "end_time": self.end_time.isoformat(),
+                 "boundaries": self.boundaries,
+                 "physics_source": self.physics_source,
+                 "bgc_source": self.bgc_source,
+                 "model_reference_date": self.model_reference_date.isoformat(),
+             }
+         }
+
+         yaml_data = {
+             **grid_yaml_data,
+             **vertical_coordinate_yaml_data,
+             **boundary_forcing_data,
+         }
+
+         with open(filepath, "w") as file:
+             # Write header
+             file.write(header)
+             # Write YAML data
+             yaml.dump(yaml_data, file, default_flow_style=False)
+
+     @classmethod
+     def from_yaml(cls, filepath: str) -> "BoundaryForcing":
+         """
+         Create an instance of the BoundaryForcing class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         BoundaryForcing
+             An instance of the BoundaryForcing class.
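+
+         Examples
+         --------
+         Illustrative usage, assuming the file was written by `to_yaml`:
+
+         >>> boundary_forcing = BoundaryForcing.from_yaml("boundary_forcing.yaml")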
+         """
+         # Read the entire file content
+         with open(filepath, "r") as file:
+             file_content = file.read()
+
+         # Split the content into YAML documents
+         documents = list(yaml.safe_load_all(file_content))
+
+         boundary_forcing_data = None
+
+         # Process the YAML documents
+         for doc in documents:
+             if doc is None:
+                 continue
+             if "BoundaryForcing" in doc:
+                 boundary_forcing_data = doc["BoundaryForcing"]
+                 break
+
+         if boundary_forcing_data is None:
+             raise ValueError("No BoundaryForcing configuration found in the YAML file.")
+
+         # Convert from string to datetime
+         for date_string in ["model_reference_date", "start_time", "end_time"]:
+             boundary_forcing_data[date_string] = datetime.fromisoformat(
+                 boundary_forcing_data[date_string]
+             )
+
+         # Create VerticalCoordinate instance from the YAML file
+         vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
+         grid = vertical_coordinate.grid
+
+         # Create and return an instance of BoundaryForcing
+         return cls(
+             grid=grid,
+             vertical_coordinate=vertical_coordinate,
+             **boundary_forcing_data,
+         )