roms-tools 0.20__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,596 @@
1
+ import xarray as xr
2
+ import pandas as pd
3
+ import dask
4
+ import yaml
5
+ from datatree import DataTree
6
+ import importlib.metadata
7
+ from dataclasses import dataclass, field, asdict
8
+ from roms_tools.setup.grid import Grid
9
+ from datetime import datetime
10
+ import numpy as np
11
+ from typing import Dict, Union
12
+ from roms_tools.setup.mixins import ROMSToolsMixins
13
+ from roms_tools.setup.datasets import (
14
+ ERA5Dataset,
15
+ ERA5Correction,
16
+ CESMBGCSurfaceForcingDataset,
17
+ )
18
+ from roms_tools.setup.utils import nan_check, interpolate_from_climatology
19
+ from roms_tools.setup.plot import _plot
20
+ import calendar
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
@dataclass(frozen=True, kw_only=True)
class SurfaceForcing(ROMSToolsMixins):
    """
    Represents surface forcing data for ocean modeling.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information.
    start_time : datetime
        Start time of the desired forcing data.
    end_time : datetime
        End time of the desired forcing data.
    physics_source : Dict[str, Union[str, None]]
        Dictionary specifying the source of the physical surface forcing data:
        - "name" (str): Name of the data source (e.g., "ERA5").
        - "path" (str): Path to the physical data file. Can contain wildcards.
        - "climatology" (bool): Indicates if the physical data is climatology data. Defaults to False.
    bgc_source : Optional[Dict[str, Union[str, None]]]
        Dictionary specifying the source of the biogeochemical (BGC) initial condition data:
        - "name" (str): Name of the BGC data source (e.g., "CESM_REGRIDDED").
        - "path" (str): Path to the BGC data file. Can contain wildcards.
        - "climatology" (bool): Indicates if the BGC data is climatology data. Defaults to False.
    correct_radiation : bool
        Whether to correct shortwave radiation. Default is False.
    use_coarse_grid: bool
        Whether to interpolate to coarsened grid. Default is False.
    model_reference_date : datetime, optional
        Reference date for the model. Default is January 1, 2000.

    Attributes
    ----------
    ds : xr.Dataset
        Xarray Dataset containing the surface forcing data.


    Examples
    --------
    >>> atm_forcing = SurfaceForcing(
    ...     grid=grid,
    ...     start_time=datetime(2000, 1, 1),
    ...     end_time=datetime(2000, 1, 2),
    ...     physics_source={"name": "ERA5", "path": "physics_data.nc"},
    ...     correct_radiation=True,
    ... )
    """

    grid: Grid
    start_time: datetime
    end_time: datetime
    physics_source: Dict[str, Union[str, None]]
    # None means "no BGC forcing"; annotation spells out the optionality
    # using the already-imported Union (Optional is not imported here).
    bgc_source: Union[Dict[str, Union[str, None]], None] = None
    correct_radiation: bool = False
    use_coarse_grid: bool = False
    model_reference_date: datetime = datetime(2000, 1, 1)
    # Populated in __post_init__ via object.__setattr__ (class is frozen);
    # excluded from __init__ and repr.
    ds: xr.Dataset = field(init=False, repr=False)
    def __post_init__(self):
        """Validate inputs, load and regrid all forcing data, and build ``self.ds``.

        Pipeline: input checks -> target lon/lat from the (possibly coarse)
        grid -> load physics data and restrict it to the target subdomain ->
        regrid 2D fields and rotate velocities -> optional shortwave-radiation
        correction -> optional BGC forcing -> assemble everything into a
        DataTree and NaN-check the physics fields against the grid mask.
        """

        self._input_checks()
        lon, lat, angle, straddle = super().get_target_lon_lat(self.use_coarse_grid)
        # Frozen dataclass: attributes must be set through object.__setattr__.
        object.__setattr__(self, "target_lon", lon)
        object.__setattr__(self, "target_lat", lat)

        data = self._get_data()
        # Restrict the source data to the model domain (margin of 2 cells for
        # interpolation; `straddle` handles domains crossing the dateline).
        data.choose_subdomain(
            latitude_range=[lat.min().values, lat.max().values],
            longitude_range=[lon.min().values, lon.max().values],
            margin=2,
            straddle=straddle,
        )
        vars_2d = ["uwnd", "vwnd", "swrad", "lwrad", "Tair", "qair", "rain"]
        vars_3d = []
        data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
        # Rotate winds onto the grid's local angle; no vertical interpolation.
        data_vars = super().process_velocities(data_vars, angle, interpolate=False)

        if self.correct_radiation:
            correction_data = self._get_correction_data()
            # choose same subdomain as forcing data so that we can use same mask
            coords_correction = {
                correction_data.dim_names["latitude"]: data.ds[
                    data.dim_names["latitude"]
                ],
                correction_data.dim_names["longitude"]: data.ds[
                    data.dim_names["longitude"]
                ],
            }
            correction_data.choose_subdomain(coords_correction, straddle=straddle)
            # apply mask from ERA5 data
            if "mask" in data.var_names.keys():
                mask = xr.where(
                    data.ds[data.var_names["mask"]].isel(time=0).isnull(), 0, 1
                )
                for var in correction_data.ds.data_vars:
                    correction_data.ds[var] = xr.where(
                        mask == 1, correction_data.ds[var], np.nan
                    )
            vars_2d = ["swr_corr"]
            vars_3d = []
            # spatial interpolation
            data_vars_corr = super().regrid_data(
                correction_data, vars_2d, vars_3d, lon, lat
            )
            # temporal interpolation of the climatological correction factor
            # onto the forcing data's time axis
            corr_factor = interpolate_from_climatology(
                data_vars_corr["swr_corr"],
                correction_data.dim_names["time"],
                time=data_vars["swrad"].time,
            )

            # Multiplicative correction of downward shortwave radiation.
            data_vars["swrad"] = data_vars["swrad"] * corr_factor

        object.__setattr__(data, "data_vars", data_vars)

        if self.bgc_source is not None:
            bgc_data = self._get_bgc_data()
            bgc_data.choose_subdomain(
                latitude_range=[lat.min().values, lat.max().values],
                longitude_range=[lon.min().values, lon.max().values],
                margin=2,
                straddle=straddle,
            )

            # All BGC variables are treated as 2D surface fields.
            vars_2d = bgc_data.var_names.keys()
            vars_3d = []
            data_vars = super().regrid_data(bgc_data, vars_2d, vars_3d, lon, lat)
            object.__setattr__(bgc_data, "data_vars", data_vars)
        else:
            bgc_data = None

        d_meta = super().get_variable_metadata()

        ds = self._write_into_datatree(data, bgc_data, d_meta)

        if self.use_coarse_grid:
            # Rename coarse dims so the mask aligns with the regridded fields.
            mask = self.grid.ds["mask_coarse"].rename(
                {"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"}
            )
        else:
            mask = self.grid.ds["mask_rho"]

        # Fail fast if any regridded physics field has NaNs over ocean points
        # (checked at the first time slice only).
        for var in ds["physics"].data_vars:
            nan_check(ds["physics"][var].isel(time=0), mask)

        object.__setattr__(self, "ds", ds)
+ def _input_checks(self):
171
+
172
+ if "name" not in self.physics_source.keys():
173
+ raise ValueError("`physics_source` must include a 'name'.")
174
+ if "path" not in self.physics_source.keys():
175
+ raise ValueError("`physics_source` must include a 'path'.")
176
+ # set self.physics_source["climatology"] to False if not provided
177
+ object.__setattr__(
178
+ self,
179
+ "physics_source",
180
+ {
181
+ **self.physics_source,
182
+ "climatology": self.physics_source.get("climatology", False),
183
+ },
184
+ )
185
+
186
+ if self.bgc_source is not None:
187
+ if "name" not in self.bgc_source.keys():
188
+ raise ValueError(
189
+ "`bgc_source` must include a 'name' if it is provided."
190
+ )
191
+ if "path" not in self.bgc_source.keys():
192
+ raise ValueError(
193
+ "`bgc_source` must include a 'path' if it is provided."
194
+ )
195
+ # set self.bgc_source["climatology"] to False if not provided
196
+ object.__setattr__(
197
+ self,
198
+ "bgc_source",
199
+ {
200
+ **self.bgc_source,
201
+ "climatology": self.bgc_source.get("climatology", False),
202
+ },
203
+ )
204
+
205
+ def _get_data(self):
206
+
207
+ if self.physics_source["name"] == "ERA5":
208
+ data = ERA5Dataset(
209
+ filename=self.physics_source["path"],
210
+ start_time=self.start_time,
211
+ end_time=self.end_time,
212
+ climatology=self.physics_source["climatology"],
213
+ )
214
+ data.post_process()
215
+ else:
216
+ raise ValueError(
217
+ 'Only "ERA5" is a valid option for physics_source["name"].'
218
+ )
219
+
220
+ return data
221
+
222
+ def _get_correction_data(self):
223
+
224
+ if self.physics_source["name"] == "ERA5":
225
+ correction_data = ERA5Correction()
226
+ else:
227
+ raise ValueError(
228
+ "The 'correct_radiation' feature is currently only supported for 'ERA5' as the physics source. "
229
+ "Please ensure your 'physics_source' is set to 'ERA5' or implement additional handling for other sources."
230
+ )
231
+
232
+ return correction_data
233
+
234
+ def _get_bgc_data(self):
235
+
236
+ if self.bgc_source["name"] == "CESM_REGRIDDED":
237
+
238
+ bgc_data = CESMBGCSurfaceForcingDataset(
239
+ filename=self.bgc_source["path"],
240
+ start_time=self.start_time,
241
+ end_time=self.end_time,
242
+ climatology=self.bgc_source["climatology"],
243
+ )
244
+ else:
245
+ raise ValueError(
246
+ 'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
247
+ )
248
+
249
+ return bgc_data
250
+
251
    def _write_into_dataset(self, data, d_meta):
        """Collect regridded variables into a ROMS-ready ``xr.Dataset``.

        Casts each variable to float32, attaches long_name/units metadata, and
        converts the time coordinate to the relative format expected by ROMS:
        a timedelta from the start of the reference year for climatology data,
        or fractional days since the model reference date otherwise.

        Parameters
        ----------
        data : object
            Source wrapper with a ``data_vars`` dict of regridded DataArrays
            and a ``climatology`` flag.
        d_meta : dict
            Per-variable metadata, mapping name -> {"long_name", "units"}.

        Returns
        -------
        xr.Dataset
        """

        # save in new dataset
        ds = xr.Dataset()

        for var in data.data_vars.keys():
            ds[var] = data.data_vars[var].astype(np.float32)
            ds[var].attrs["long_name"] = d_meta[var]["long_name"]
            ds[var].attrs["units"] = d_meta[var]["units"]

        if self.use_coarse_grid:
            ds = ds.assign_coords({"lon": self.target_lon, "lat": self.target_lat})
            ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})

        # Preserve absolute time coordinate for readability
        ds = ds.assign_coords({"abs_time": ds["time"]})

        # Convert the time coordinate to the format expected by ROMS
        if data.climatology:
            # Convert to pandas TimedeltaIndex
            timedelta_index = pd.to_timedelta(ds["time"].values)
            # Determine the start of the year for the base_datetime
            start_of_year = datetime(self.model_reference_date.year, 1, 1)
            # Calculate the offset from midnight of the new year
            offset = self.model_reference_date - start_of_year
            sfc_time = xr.DataArray(
                timedelta_index - offset,
                dims="time",
            )
        else:
            # Nanoseconds -> days: * 1e-9 converts ns to s, / 3600 / 24 then
            # converts s to days. NOTE(review): assumes nanosecond-resolution
            # datetimes (xarray's default) — confirm for non-default encodings.
            sfc_time = (
                (ds["time"] - np.datetime64(self.model_reference_date)).astype(
                    "float64"
                )
                / 3600
                / 24
                * 1e-9
            )

        ds = ds.assign_coords({"time": sfc_time})
        ds["time"].attrs[
            "long_name"
        ] = f"days since {np.datetime_as_string(np.datetime64(self.model_reference_date), unit='D')}"
        ds["time"].encoding["units"] = "days"
        if data.climatology:
            # Julian-year cycle length tells ROMS to wrap the climatology.
            ds["time"].attrs["cycle_length"] = 365.25

        return ds
+ def _write_into_datatree(self, data, bgc_data, d_meta):
301
+
302
+ ds = self._add_global_metadata()
303
+
304
+ ds = DataTree(name="root", data=ds)
305
+
306
+ ds_physics = self._write_into_dataset(data, d_meta)
307
+ ds_physics = self._add_global_metadata(ds_physics)
308
+ ds_physics.attrs["physics_source"] = self.physics_source["name"]
309
+ ds_physics = DataTree(name="physics", parent=ds, data=ds_physics)
310
+
311
+ if bgc_data:
312
+ ds_bgc = self._write_into_dataset(bgc_data, d_meta)
313
+ ds_bgc = self._add_global_metadata(ds_bgc)
314
+ ds_bgc.attrs["bgc_source"] = self.bgc_source["name"]
315
+ ds_bgc = DataTree(name="bgc", parent=ds, data=ds_bgc)
316
+
317
+ return ds
318
+
319
+ def _add_global_metadata(self, ds=None):
320
+
321
+ if ds is None:
322
+ ds = xr.Dataset()
323
+ ds.attrs["title"] = "ROMS surface forcing file created by ROMS-Tools"
324
+ # Include the version of roms-tools
325
+ try:
326
+ roms_tools_version = importlib.metadata.version("roms-tools")
327
+ except importlib.metadata.PackageNotFoundError:
328
+ roms_tools_version = "unknown"
329
+ ds.attrs["roms_tools_version"] = roms_tools_version
330
+ ds.attrs["start_time"] = str(self.start_time)
331
+ ds.attrs["end_time"] = str(self.end_time)
332
+ ds.attrs["physics_source"] = self.physics_source["name"]
333
+ if self.bgc_source is not None:
334
+ ds.attrs["bgc_source"] = self.bgc_source["name"]
335
+ ds.attrs["correct_radiation"] = str(self.correct_radiation)
336
+ ds.attrs["use_coarse_grid"] = str(self.use_coarse_grid)
337
+ ds.attrs["model_reference_date"] = str(self.model_reference_date)
338
+
339
+ return ds
340
+
341
    def plot(self, varname, time=0) -> None:
        """
        Plot the specified surface forcing field for a given time slice.

        Parameters
        ----------
        varname : str
            The name of the surface forcing field to plot. Options include:
            - "uwnd": 10 meter wind in x-direction.
            - "vwnd": 10 meter wind in y-direction.
            - "swrad": Downward short-wave (solar) radiation.
            - "lwrad": Downward long-wave (thermal) radiation.
            - "Tair": Air temperature at 2m.
            - "qair": Absolute humidity at 2m.
            - "rain": Total precipitation.
            - "pco2_air": Atmospheric pCO2.
            - "pco2_air_alt": Atmospheric pCO2, alternative CO2.
            - "iron": Iron deposition.
            - "dust": Dust deposition.
            - "nox": NOx deposition.
            - "nhy": NHy deposition.
        time : int, optional
            The time index to plot. Default is 0, which corresponds to the first
            time slice.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        ValueError
            If the specified varname is not one of the valid options.


        Examples
        --------
        >>> atm_forcing.plot("uwnd", time=0)
        """

        # Look the variable up in the "physics" node first, then "bgc".
        if varname in self.ds["physics"]:
            ds = self.ds["physics"]
        else:
            if "bgc" in self.ds and varname in self.ds["bgc"]:
                ds = self.ds["bgc"]
            else:
                raise ValueError(
                    f"Variable '{varname}' is not found in 'physics' or 'bgc' datasets."
                )

        # Load the single time slice into memory for plotting.
        field = ds[varname].isel(time=time).load()
        title = field.long_name

        # choose colorbar
        if varname in ["uwnd", "vwnd"]:
            # Winds are signed: use a diverging colormap symmetric about zero.
            vmax = max(field.max().values, -field.min().values)
            vmin = -vmax
            cmap = plt.colormaps.get_cmap("RdBu_r")
        else:
            vmax = field.max().values
            vmin = field.min().values
            if varname in ["swrad", "lwrad", "Tair", "qair"]:
                cmap = plt.colormaps.get_cmap("YlOrRd")
            else:
                cmap = plt.colormaps.get_cmap("YlGnBu")
        # NOTE(review): set_bad here modifies the colormap object returned by
        # the registry; if that is the globally registered instance this could
        # affect (or be rejected by) other matplotlib users — consider
        # operating on cmap.copy(). Confirm against the pinned matplotlib
        # version.
        cmap.set_bad(color="gray")

        kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

        _plot(
            self.grid.ds,
            field=field,
            straddle=self.grid.straddle,
            coarse_grid=self.use_coarse_grid,
            title=title,
            kwargs=kwargs,
            c="g",
        )
    def save(self, filepath: str, time_chunk_size: int = 1) -> None:
        """
        Save the interpolated surface forcing fields to netCDF4 files.

        This method groups the dataset by year and month, chunks the data by the specified
        time chunk size, and saves each chunked subset to a separate netCDF4 file named
        according to the year, month, and day range if not a complete month of data is included.

        Parameters
        ----------
        filepath : str
            The base path and filename for the output files. The files will be named with
            the format "filepath.YYYYMM.nc" if a full month of data is included, or
            "filepath.YYYYMMDD-DD.nc" otherwise.
        time_chunk_size : int, optional
            Number of time slices to include in each chunk along the time dimension. Default is 1,
            meaning each chunk contains one time slice.

        Returns
        -------
        None
        """

        datasets = []
        filenames = []
        writes = []

        # Each tree node ("physics", optional "bgc") is written independently.
        for node in ["physics", "bgc"]:
            if node in self.ds:
                ds = self.ds[node].to_dataset()
                # A "cycle_length" attribute marks climatology data, which goes
                # to a single file instead of monthly files.
                if hasattr(ds["time"], "cycle_length"):
                    filename = f"{filepath}_{node}_clim.nc"
                    filenames.append(filename)
                    datasets.append(ds)
                else:
                    # Group dataset by year
                    gb = ds.groupby("abs_time.year")

                    for year, group_ds in gb:
                        # Further group each yearly group by month
                        sub_gb = group_ds.groupby("abs_time.month")

                        # NOTE: `ds` is deliberately rebound to each monthly
                        # subset from here on.
                        for month, ds in sub_gb:
                            # Chunk the dataset by the specified time chunk size
                            ds = ds.chunk({"time": time_chunk_size})
                            datasets.append(ds)

                            # Determine the number of days in the month
                            num_days_in_month = calendar.monthrange(year, month)[1]
                            first_day = ds.abs_time.dt.day.values[0]
                            last_day = ds.abs_time.dt.day.values[-1]

                            # Create filename based on whether the dataset contains a full month
                            if first_day == 1 and last_day == num_days_in_month:
                                # Full month format: "filepath_physics_YYYYMM.nc"
                                year_month_str = f"{year}{month:02}"
                                filename = f"{filepath}_{node}_{year_month_str}.nc"
                            else:
                                # Partial month format: "filepath_physics_YYYYMMDD-DD.nc"
                                year_month_day_str = (
                                    f"{year}{month:02}{first_day:02}-{last_day:02}"
                                )
                                filename = f"{filepath}_{node}_{year_month_day_str}.nc"
                            filenames.append(filename)

        print("Saving the following files:")
        for ds, filename in zip(datasets, filenames):
            print(filename)
            # Prepare the dataset for writing to a netCDF file without immediately computing
            write = ds.to_netcdf(filename, compute=False)
            writes.append(write)

        # Perform the actual write operations in parallel
        dask.compute(*writes)
+ def to_yaml(self, filepath: str) -> None:
497
+ """
498
+ Export the parameters of the class to a YAML file, including the version of roms-tools.
499
+
500
+ Parameters
501
+ ----------
502
+ filepath : str
503
+ The path to the YAML file where the parameters will be saved.
504
+ """
505
+ # Serialize Grid data
506
+ grid_data = asdict(self.grid)
507
+ grid_data.pop("ds", None) # Exclude non-serializable fields
508
+ grid_data.pop("straddle", None)
509
+
510
+ # Include the version of roms-tools
511
+ try:
512
+ roms_tools_version = importlib.metadata.version("roms-tools")
513
+ except importlib.metadata.PackageNotFoundError:
514
+ roms_tools_version = "unknown"
515
+
516
+ # Create header
517
+ header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
518
+
519
+ # Create YAML data for Grid and optional attributes
520
+ grid_yaml_data = {"Grid": grid_data}
521
+
522
+ # Combine all sections
523
+ surface_forcing_data = {
524
+ "SurfaceForcing": {
525
+ "start_time": self.start_time.isoformat(),
526
+ "end_time": self.end_time.isoformat(),
527
+ "physics_source": self.physics_source,
528
+ "correct_radiation": self.correct_radiation,
529
+ "use_coarse_grid": self.use_coarse_grid,
530
+ "model_reference_date": self.model_reference_date.isoformat(),
531
+ }
532
+ }
533
+ # Include bgc_source if it's not None
534
+ if self.bgc_source is not None:
535
+ surface_forcing_data["SurfaceForcing"]["bgc_source"] = self.bgc_source
536
+
537
+ # Merge YAML data while excluding empty sections
538
+ yaml_data = {
539
+ **grid_yaml_data,
540
+ **surface_forcing_data,
541
+ }
542
+
543
+ with open(filepath, "w") as file:
544
+ # Write header
545
+ file.write(header)
546
+ # Write YAML data
547
+ yaml.dump(yaml_data, file, default_flow_style=False)
548
+
549
+ @classmethod
550
+ def from_yaml(cls, filepath: str) -> "SurfaceForcing":
551
+ """
552
+ Create an instance of the SurfaceForcing class from a YAML file.
553
+
554
+ Parameters
555
+ ----------
556
+ filepath : str
557
+ The path to the YAML file from which the parameters will be read.
558
+
559
+ Returns
560
+ -------
561
+ SurfaceForcing
562
+ An instance of the SurfaceForcing class.
563
+ """
564
+ # Read the entire file content
565
+ with open(filepath, "r") as file:
566
+ file_content = file.read()
567
+
568
+ # Split the content into YAML documents
569
+ documents = list(yaml.safe_load_all(file_content))
570
+
571
+ surface_forcing_data = None
572
+
573
+ # Process the YAML documents
574
+ for doc in documents:
575
+ if doc is None:
576
+ continue
577
+ if "SurfaceForcing" in doc:
578
+ surface_forcing_data = doc["SurfaceForcing"]
579
+
580
+ if surface_forcing_data is None:
581
+ raise ValueError("No SurfaceForcing configuration found in the YAML file.")
582
+
583
+ # Convert from string to datetime
584
+ for date_string in ["model_reference_date", "start_time", "end_time"]:
585
+ surface_forcing_data[date_string] = datetime.fromisoformat(
586
+ surface_forcing_data[date_string]
587
+ )
588
+
589
+ # Create Grid instance from the YAML file
590
+ grid = Grid.from_yaml(filepath)
591
+
592
+ # Create and return an instance of SurfaceForcing
593
+ return cls(
594
+ grid=grid,
595
+ **surface_forcing_data,
596
+ )