roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,557 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+ import yaml
4
+ import importlib.metadata
5
+ from dataclasses import dataclass, field, asdict
6
+ from typing import Optional, Dict, Union
7
+ from roms_tools.setup.grid import Grid
8
+ from roms_tools.setup.vertical_coordinate import VerticalCoordinate
9
+ from datetime import datetime
10
+ from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
11
+ from roms_tools.setup.utils import (
12
+ nan_check,
13
+ )
14
+ from roms_tools.setup.mixins import ROMSToolsMixins
15
+ from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
16
+ import matplotlib.pyplot as plt
17
+
18
+
19
@dataclass(frozen=True, kw_only=True)
class InitialConditions(ROMSToolsMixins):
    """
    Initial conditions for ROMS, covering physical and (optionally)
    biogeochemical fields.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information used for the model.
    vertical_coordinate : VerticalCoordinate
        Object representing the vertical coordinate system.
    ini_time : datetime
        Date and time at which the initial conditions are set.
    physics_source : Dict[str, Union[str, None]]
        Source of the physical initial condition data, with keys:

        - "name" (str): Name of the data source (e.g., "GLORYS").
        - "path" (str): Path to the physical data file. Can contain wildcards.
        - "climatology" (bool): Whether the physical data is climatology data.
          Defaults to False.
    bgc_source : Optional[Dict[str, Union[str, None]]]
        Source of the biogeochemical (BGC) initial condition data, with keys:

        - "name" (str): Name of the BGC data source (e.g., "CESM_REGRIDDED").
        - "path" (str): Path to the BGC data file. Can contain wildcards.
        - "climatology" (bool): Whether the BGC data is climatology data.
          Defaults to True.
    model_reference_date : datetime, optional
        Reference date for the model. Defaults to January 1, 2000.

    Attributes
    ----------
    ds : xr.Dataset
        Xarray Dataset containing the initial condition data loaded from the
        specified files, set during post-initialization.

    Examples
    --------
    >>> initial_conditions = InitialConditions(
    ...     grid=grid,
    ...     vertical_coordinate=vertical_coordinate,
    ...     ini_time=datetime(2022, 1, 1),
    ...     physics_source={"name": "GLORYS", "path": "physics_data.nc"},
    ...     bgc_source={
    ...         "name": "CESM_REGRIDDED",
    ...         "path": "bgc_data.nc",
    ...         "climatology": True,
    ...     },
    ... )
    """

    grid: Grid
    vertical_coordinate: VerticalCoordinate
    ini_time: datetime
    physics_source: Dict[str, Union[str, None]]
    bgc_source: Optional[Dict[str, Union[str, None]]] = None
    model_reference_date: datetime = datetime(2000, 1, 1)

    # Populated by __post_init__; excluded from the generated __init__/__repr__.
    ds: xr.Dataset = field(init=False, repr=False)
74
+ def __post_init__(self):
75
+
76
+ self._input_checks()
77
+ lon, lat, angle, straddle = super().get_target_lon_lat()
78
+
79
+ data = self._get_data()
80
+ data.choose_subdomain(
81
+ latitude_range=[lat.min().values, lat.max().values],
82
+ longitude_range=[lon.min().values, lon.max().values],
83
+ margin=2,
84
+ straddle=straddle,
85
+ )
86
+
87
+ vars_2d = ["zeta"]
88
+ vars_3d = ["temp", "salt", "u", "v"]
89
+ data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
90
+ data_vars = super().process_velocities(data_vars, angle)
91
+
92
+ if self.bgc_source is not None:
93
+ bgc_data = self._get_bgc_data()
94
+ bgc_data.choose_subdomain(
95
+ latitude_range=[lat.min().values, lat.max().values],
96
+ longitude_range=[lon.min().values, lon.max().values],
97
+ margin=2,
98
+ straddle=straddle,
99
+ )
100
+
101
+ vars_2d = []
102
+ vars_3d = bgc_data.var_names.keys()
103
+ bgc_data_vars = super().regrid_data(bgc_data, vars_2d, vars_3d, lon, lat)
104
+
105
+ # Ensure time coordinate matches if climatology is applied in one case but not the other
106
+ if (
107
+ not self.physics_source["climatology"]
108
+ and self.bgc_source["climatology"]
109
+ ):
110
+ for var in bgc_data_vars.keys():
111
+ bgc_data_vars[var] = bgc_data_vars[var].assign_coords(
112
+ {"time": data_vars["temp"]["time"]}
113
+ )
114
+
115
+ # Combine data variables from physical and biogeochemical sources
116
+ data_vars.update(bgc_data_vars)
117
+
118
+ d_meta = super().get_variable_metadata()
119
+ ds = self._write_into_dataset(data_vars, d_meta)
120
+
121
+ ds = self._add_global_metadata(ds)
122
+
123
+ ds["zeta"].load()
124
+ nan_check(ds["zeta"].squeeze(), self.grid.ds.mask_rho)
125
+
126
+ object.__setattr__(self, "ds", ds)
127
+
128
+ def _input_checks(self):
129
+
130
+ if "name" not in self.physics_source.keys():
131
+ raise ValueError("`physics_source` must include a 'name'.")
132
+ if "path" not in self.physics_source.keys():
133
+ raise ValueError("`physics_source` must include a 'path'.")
134
+ # set self.physics_source["climatology"] to False if not provided
135
+ object.__setattr__(
136
+ self,
137
+ "physics_source",
138
+ {
139
+ **self.physics_source,
140
+ "climatology": self.physics_source.get("climatology", False),
141
+ },
142
+ )
143
+ if self.bgc_source is not None:
144
+ if "name" not in self.bgc_source.keys():
145
+ raise ValueError(
146
+ "`bgc_source` must include a 'name' if it is provided."
147
+ )
148
+ if "path" not in self.bgc_source.keys():
149
+ raise ValueError(
150
+ "`bgc_source` must include a 'path' if it is provided."
151
+ )
152
+ # set self.bgc_source["climatology"] to True if not provided
153
+ object.__setattr__(
154
+ self,
155
+ "bgc_source",
156
+ {
157
+ **self.bgc_source,
158
+ "climatology": self.bgc_source.get("climatology", True),
159
+ },
160
+ )
161
+
162
+ def _get_data(self):
163
+
164
+ if self.physics_source["name"] == "GLORYS":
165
+ data = GLORYSDataset(
166
+ filename=self.physics_source["path"],
167
+ start_time=self.ini_time,
168
+ climatology=self.physics_source["climatology"],
169
+ )
170
+ else:
171
+ raise ValueError(
172
+ 'Only "GLORYS" is a valid option for physics_source["name"].'
173
+ )
174
+ return data
175
+
176
+ def _get_bgc_data(self):
177
+
178
+ if self.bgc_source["name"] == "CESM_REGRIDDED":
179
+
180
+ bgc_data = CESMBGCDataset(
181
+ filename=self.bgc_source["path"],
182
+ start_time=self.ini_time,
183
+ climatology=self.bgc_source["climatology"],
184
+ )
185
+ bgc_data.post_process()
186
+ else:
187
+ raise ValueError(
188
+ 'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
189
+ )
190
+
191
+ return bgc_data
192
+
193
+ def _write_into_dataset(self, data_vars, d_meta):
194
+
195
+ # save in new dataset
196
+ ds = xr.Dataset()
197
+
198
+ for var in data_vars.keys():
199
+ ds[var] = data_vars[var].astype(np.float32)
200
+ ds[var].attrs["long_name"] = d_meta[var]["long_name"]
201
+ ds[var].attrs["units"] = d_meta[var]["units"]
202
+
203
+ # initialize vertical velocity to zero
204
+ ds["w"] = xr.zeros_like(
205
+ self.vertical_coordinate.ds["interface_depth_rho"].expand_dims(
206
+ time=data_vars["u"].time
207
+ )
208
+ ).astype(np.float32)
209
+ ds["w"].attrs["long_name"] = d_meta["w"]["long_name"]
210
+ ds["w"].attrs["units"] = d_meta["w"]["units"]
211
+
212
+ return ds
213
+
214
+ def _add_global_metadata(self, ds):
215
+
216
+ ds = ds.assign_coords(
217
+ {
218
+ "layer_depth_u": self.vertical_coordinate.ds["layer_depth_u"],
219
+ "layer_depth_v": self.vertical_coordinate.ds["layer_depth_v"],
220
+ "interface_depth_u": self.vertical_coordinate.ds["interface_depth_u"],
221
+ "interface_depth_v": self.vertical_coordinate.ds["interface_depth_v"],
222
+ }
223
+ )
224
+
225
+ ds.attrs["title"] = "ROMS initial conditions file created by ROMS-Tools"
226
+ # Include the version of roms-tools
227
+ try:
228
+ roms_tools_version = importlib.metadata.version("roms-tools")
229
+ except importlib.metadata.PackageNotFoundError:
230
+ roms_tools_version = "unknown"
231
+ ds.attrs["roms_tools_version"] = roms_tools_version
232
+ ds.attrs["ini_time"] = str(self.ini_time)
233
+ ds.attrs["model_reference_date"] = str(self.model_reference_date)
234
+ ds.attrs["physical_source"] = self.physics_source["name"]
235
+ if self.bgc_source is not None:
236
+ ds.attrs["bgc_source"] = self.bgc_source["name"]
237
+
238
+ # Translate the time coordinate to days since the model reference date
239
+ model_reference_date = np.datetime64(self.model_reference_date)
240
+
241
+ # Convert the time coordinate to the format expected by ROMS (days since model reference date)
242
+ ocean_time = (ds["time"] - model_reference_date).astype("float64") * 1e-9
243
+ ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
244
+ ds["ocean_time"].attrs[
245
+ "long_name"
246
+ ] = f"seconds since {np.datetime_as_string(model_reference_date, unit='s')}"
247
+ ds["ocean_time"].attrs["units"] = "seconds"
248
+
249
+ ds.attrs["theta_s"] = self.vertical_coordinate.ds["theta_s"].item()
250
+ ds.attrs["theta_b"] = self.vertical_coordinate.ds["theta_b"].item()
251
+ ds.attrs["Tcline"] = self.vertical_coordinate.ds["Tcline"].item()
252
+ ds.attrs["hc"] = self.vertical_coordinate.ds["hc"].item()
253
+ ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
254
+ ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
255
+
256
+ ds = ds.drop_vars(["s_rho"])
257
+
258
+ return ds
259
+
260
+ def plot(
261
+ self,
262
+ varname,
263
+ s=None,
264
+ eta=None,
265
+ xi=None,
266
+ depth_contours=False,
267
+ layer_contours=False,
268
+ ) -> None:
269
+ """
270
+ Plot the initial conditions field for a given eta-, xi-, or s_rho-slice.
271
+
272
+ Parameters
273
+ ----------
274
+ varname : str
275
+ The name of the initial conditions field to plot. Options include:
276
+ - "temp": Potential temperature.
277
+ - "salt": Salinity.
278
+ - "zeta": Free surface.
279
+ - "u": u-flux component.
280
+ - "v": v-flux component.
281
+ - "w": w-flux component.
282
+ - "ubar": Vertically integrated u-flux component.
283
+ - "vbar": Vertically integrated v-flux component.
284
+ - "PO4": Dissolved Inorganic Phosphate (mmol/m³).
285
+ - "NO3": Dissolved Inorganic Nitrate (mmol/m³).
286
+ - "SiO3": Dissolved Inorganic Silicate (mmol/m³).
287
+ - "NH4": Dissolved Ammonia (mmol/m³).
288
+ - "Fe": Dissolved Inorganic Iron (mmol/m³).
289
+ - "Lig": Iron Binding Ligand (mmol/m³).
290
+ - "O2": Dissolved Oxygen (mmol/m³).
291
+ - "DIC": Dissolved Inorganic Carbon (mmol/m³).
292
+ - "DIC_ALT_CO2": Dissolved Inorganic Carbon, Alternative CO2 (mmol/m³).
293
+ - "ALK": Alkalinity (meq/m³).
294
+ - "ALK_ALT_CO2": Alkalinity, Alternative CO2 (meq/m³).
295
+ - "DOC": Dissolved Organic Carbon (mmol/m³).
296
+ - "DON": Dissolved Organic Nitrogen (mmol/m³).
297
+ - "DOP": Dissolved Organic Phosphorus (mmol/m³).
298
+ - "DOPr": Refractory Dissolved Organic Phosphorus (mmol/m³).
299
+ - "DONr": Refractory Dissolved Organic Nitrogen (mmol/m³).
300
+ - "DOCr": Refractory Dissolved Organic Carbon (mmol/m³).
301
+ - "zooC": Zooplankton Carbon (mmol/m³).
302
+ - "spChl": Small Phytoplankton Chlorophyll (mg/m³).
303
+ - "spC": Small Phytoplankton Carbon (mmol/m³).
304
+ - "spP": Small Phytoplankton Phosphorous (mmol/m³).
305
+ - "spFe": Small Phytoplankton Iron (mmol/m³).
306
+ - "spCaCO3": Small Phytoplankton CaCO3 (mmol/m³).
307
+ - "diatChl": Diatom Chlorophyll (mg/m³).
308
+ - "diatC": Diatom Carbon (mmol/m³).
309
+ - "diatP": Diatom Phosphorus (mmol/m³).
310
+ - "diatFe": Diatom Iron (mmol/m³).
311
+ - "diatSi": Diatom Silicate (mmol/m³).
312
+ - "diazChl": Diazotroph Chlorophyll (mg/m³).
313
+ - "diazC": Diazotroph Carbon (mmol/m³).
314
+ - "diazP": Diazotroph Phosphorus (mmol/m³).
315
+ - "diazFe": Diazotroph Iron (mmol/m³).
316
+ s : int, optional
317
+ The index of the vertical layer to plot. Default is None.
318
+ eta : int, optional
319
+ The eta-index to plot. Default is None.
320
+ xi : int, optional
321
+ The xi-index to plot. Default is None.
322
+ depth_contours : bool, optional
323
+ Whether to include depth contours in the plot. Default is False.
324
+
325
+ Returns
326
+ -------
327
+ None
328
+ This method does not return any value. It generates and displays a plot.
329
+
330
+ Raises
331
+ ------
332
+ ValueError
333
+ If the specified `varname` is not one of the valid options.
334
+ If the field specified by `varname` is 3D and none of `s`, `eta`, or `xi` are specified.
335
+ If the field specified by `varname` is 2D and both `eta` and `xi` are specified.
336
+ """
337
+
338
+ if len(self.ds[varname].squeeze().dims) == 3 and not any(
339
+ [s is not None, eta is not None, xi is not None]
340
+ ):
341
+ raise ValueError(
342
+ "For 3D fields, at least one of s, eta, or xi must be specified."
343
+ )
344
+
345
+ if len(self.ds[varname].squeeze().dims) == 2 and all(
346
+ [eta is not None, xi is not None]
347
+ ):
348
+ raise ValueError("For 2D fields, specify either eta or xi, not both.")
349
+
350
+ self.ds[varname].load()
351
+ field = self.ds[varname].squeeze()
352
+
353
+ if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
354
+ interface_depth = self.ds.interface_depth_rho
355
+ elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
356
+ interface_depth = self.ds.interface_depth_u
357
+ elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
358
+ interface_depth = self.ds.interface_depth_v
359
+
360
+ # slice the field as desired
361
+ title = field.long_name
362
+ if s is not None:
363
+ title = title + f", s_rho = {field.s_rho[s].item()}"
364
+ field = field.isel(s_rho=s)
365
+ else:
366
+ depth_contours = False
367
+
368
+ if eta is not None:
369
+ if "eta_rho" in field.dims:
370
+ title = title + f", eta_rho = {field.eta_rho[eta].item()}"
371
+ field = field.isel(eta_rho=eta)
372
+ interface_depth = interface_depth.isel(eta_rho=eta)
373
+ elif "eta_v" in field.dims:
374
+ title = title + f", eta_v = {field.eta_v[eta].item()}"
375
+ field = field.isel(eta_v=eta)
376
+ interface_depth = interface_depth.isel(eta_v=eta)
377
+ else:
378
+ raise ValueError(
379
+ f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
380
+ )
381
+ if xi is not None:
382
+ if "xi_rho" in field.dims:
383
+ title = title + f", xi_rho = {field.xi_rho[xi].item()}"
384
+ field = field.isel(xi_rho=xi)
385
+ interface_depth = interface_depth.isel(xi_rho=xi)
386
+ elif "xi_u" in field.dims:
387
+ title = title + f", xi_u = {field.xi_u[xi].item()}"
388
+ field = field.isel(xi_u=xi)
389
+ interface_depth = interface_depth.isel(xi_u=xi)
390
+ else:
391
+ raise ValueError(
392
+ f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
393
+ )
394
+
395
+ # chose colorbar
396
+ if varname in ["u", "v", "w", "ubar", "vbar", "zeta"]:
397
+ vmax = max(field.max().values, -field.min().values)
398
+ vmin = -vmax
399
+ cmap = plt.colormaps.get_cmap("RdBu_r")
400
+ else:
401
+ vmax = field.max().values
402
+ vmin = field.min().values
403
+ if varname in ["temp", "salt"]:
404
+ cmap = plt.colormaps.get_cmap("YlOrRd")
405
+ else:
406
+ cmap = plt.colormaps.get_cmap("YlGn")
407
+ cmap.set_bad(color="gray")
408
+ kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
409
+
410
+ if eta is None and xi is None:
411
+ _plot(
412
+ self.grid.ds,
413
+ field=field,
414
+ straddle=self.grid.straddle,
415
+ depth_contours=depth_contours,
416
+ title=title,
417
+ kwargs=kwargs,
418
+ c="g",
419
+ )
420
+ else:
421
+ if not layer_contours:
422
+ interface_depth = None
423
+ else:
424
+ # restrict number of layer_contours to 10 for the sake of plot clearity
425
+ nr_layers = len(interface_depth["s_w"])
426
+ selected_layers = np.linspace(
427
+ 0, nr_layers - 1, min(nr_layers, 10), dtype=int
428
+ )
429
+ interface_depth = interface_depth.isel(s_w=selected_layers)
430
+
431
+ if len(field.dims) == 2:
432
+ _section_plot(
433
+ field, interface_depth=interface_depth, title=title, kwargs=kwargs
434
+ )
435
+ else:
436
+ if "s_rho" in field.dims:
437
+ _profile_plot(field, title=title)
438
+ else:
439
+ _line_plot(field, title=title)
440
+
441
+ def save(self, filepath: str) -> None:
442
+ """
443
+ Save the initial conditions information to a netCDF4 file.
444
+
445
+ Parameters
446
+ ----------
447
+ filepath
448
+ """
449
+ self.ds.to_netcdf(filepath)
450
+
451
+ def to_yaml(self, filepath: str) -> None:
452
+ """
453
+ Export the parameters of the class to a YAML file, including the version of roms-tools.
454
+
455
+ Parameters
456
+ ----------
457
+ filepath : str
458
+ The path to the YAML file where the parameters will be saved.
459
+ """
460
+ # Serialize Grid data
461
+ grid_data = asdict(self.grid)
462
+ grid_data.pop("ds", None) # Exclude non-serializable fields
463
+ grid_data.pop("straddle", None)
464
+
465
+ # Serialize VerticalCoordinate data
466
+ vertical_coordinate_data = asdict(self.vertical_coordinate)
467
+ vertical_coordinate_data.pop("ds", None) # Exclude non-serializable fields
468
+ vertical_coordinate_data.pop("grid", None) # Exclude non-serializable fields
469
+
470
+ # Include the version of roms-tools
471
+ try:
472
+ roms_tools_version = importlib.metadata.version("roms-tools")
473
+ except importlib.metadata.PackageNotFoundError:
474
+ roms_tools_version = "unknown"
475
+
476
+ # Create header
477
+ header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
478
+
479
+ grid_yaml_data = {"Grid": grid_data}
480
+ vertical_coordinate_yaml_data = {"VerticalCoordinate": vertical_coordinate_data}
481
+
482
+ initial_conditions_data = {
483
+ "InitialConditions": {
484
+ "physics_source": self.physics_source,
485
+ "ini_time": self.ini_time.isoformat(),
486
+ "model_reference_date": self.model_reference_date.isoformat(),
487
+ }
488
+ }
489
+ # Include bgc_source if it's not None
490
+ if self.bgc_source is not None:
491
+ initial_conditions_data["InitialConditions"]["bgc_source"] = self.bgc_source
492
+
493
+ yaml_data = {
494
+ **grid_yaml_data,
495
+ **vertical_coordinate_yaml_data,
496
+ **initial_conditions_data,
497
+ }
498
+
499
+ with open(filepath, "w") as file:
500
+ # Write header
501
+ file.write(header)
502
+ # Write YAML data
503
+ yaml.dump(yaml_data, file, default_flow_style=False)
504
+
505
+ @classmethod
506
+ def from_yaml(cls, filepath: str) -> "InitialConditions":
507
+ """
508
+ Create an instance of the InitialConditions class from a YAML file.
509
+
510
+ Parameters
511
+ ----------
512
+ filepath : str
513
+ The path to the YAML file from which the parameters will be read.
514
+
515
+ Returns
516
+ -------
517
+ InitialConditions
518
+ An instance of the InitialConditions class.
519
+ """
520
+ # Read the entire file content
521
+ with open(filepath, "r") as file:
522
+ file_content = file.read()
523
+
524
+ # Split the content into YAML documents
525
+ documents = list(yaml.safe_load_all(file_content))
526
+
527
+ initial_conditions_data = None
528
+
529
+ # Process the YAML documents
530
+ for doc in documents:
531
+ if doc is None:
532
+ continue
533
+ if "InitialConditions" in doc:
534
+ initial_conditions_data = doc["InitialConditions"]
535
+ break
536
+
537
+ if initial_conditions_data is None:
538
+ raise ValueError(
539
+ "No InitialConditions configuration found in the YAML file."
540
+ )
541
+
542
+ # Convert from string to datetime
543
+ for date_string in ["model_reference_date", "ini_time"]:
544
+ initial_conditions_data[date_string] = datetime.fromisoformat(
545
+ initial_conditions_data[date_string]
546
+ )
547
+
548
+ # Create VerticalCoordinate instance from the YAML file
549
+ vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
550
+ grid = vertical_coordinate.grid
551
+
552
+ # Create and return an instance of InitialConditions
553
+ return cls(
554
+ grid=grid,
555
+ vertical_coordinate=vertical_coordinate,
556
+ **initial_conditions_data,
557
+ )