roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,935 @@
+ import xarray as xr
+ import dask
+ import yaml
+ import importlib.metadata
+ from dataclasses import dataclass, field, asdict
+ from roms_tools.setup.grid import Grid
+ from datetime import datetime
+ import glob
+ import numpy as np
+ from typing import Optional, Dict
+ from roms_tools.setup.fill import fill_and_interpolate
+ from roms_tools.setup.datasets import Dataset
+ from roms_tools.setup.utils import nan_check
+ from roms_tools.setup.plot import _plot
+ import calendar
+ import matplotlib.pyplot as plt
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class SWRCorrection:
+     """
+     Configuration for shortwave radiation correction.
+
+     Parameters
+     ----------
+     filename : str
+         Filename of the correction data.
+     varname : str
+         Variable identifier for the correction.
+     dim_names : Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+         Default is {"longitude": "longitude", "latitude": "latitude", "time": "time"}.
+     temporal_resolution : str, optional
+         Temporal resolution of the correction data. Default is "climatology".
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The loaded xarray Dataset containing the correction data.
+
+     Examples
+     --------
+     >>> swr_correction = SWRCorrection(
+     ...     filename="correction_data.nc",
+     ...     varname="corr",
+     ...     dim_names={
+     ...         "time": "time",
+     ...         "latitude": "latitude",
+     ...         "longitude": "longitude",
+     ...     },
+     ...     temporal_resolution="climatology",
+     ... )
+     """
+
+     filename: str
+     varname: str
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "longitude",
+             "latitude": "latitude",
+             "time": "time",
+         }
+     )
+     temporal_resolution: str = "climatology"
+     ds: xr.Dataset = field(init=False, repr=False)
+
+     def __post_init__(self):
+         if self.temporal_resolution != "climatology":
+             raise NotImplementedError(
+                 f"temporal_resolution must be 'climatology', got {self.temporal_resolution}"
+             )
+
+         ds = self._load_data()
+         self._check_dataset(ds)
+         ds = self._ensure_latitude_ascending(ds)
+
+         object.__setattr__(self, "ds", ds)
+
+     def _load_data(self):
+         """
+         Load data from the specified file.
+
+         Returns
+         -------
+         ds : xr.Dataset
+             The loaded xarray Dataset containing the correction data.
+
+         Raises
+         ------
+         FileNotFoundError
+             If no files match the specified pattern.
+         """
+         # Check if any file matching the wildcard pattern exists
+         matching_files = glob.glob(self.filename)
+         if not matching_files:
+             raise FileNotFoundError(
+                 f"No files found matching the pattern '{self.filename}'."
+             )
+
+         # Load the dataset; open_mfdataset handles both a single file and a
+         # wildcard pattern that matched several files
+         ds = xr.open_mfdataset(
+             matching_files,
+             combine="by_coords",
+             chunks={
+                 self.dim_names["time"]: -1,
+                 self.dim_names["latitude"]: -1,
+                 self.dim_names["longitude"]: -1,
+             },
+         )
+
+         return ds
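
For reference, here is a minimal sketch of a correction file that this loader accepts with the default `dim_names`; the file name, variable name, and values are illustrative. Note that `_interpolate_temporally` below calls `.dt.days` on the time coordinate, which suggests the climatology's time axis is stored as timedeltas (days since the start of the climatological year), and the sketch follows that convention:

    import numpy as np
    import pandas as pd
    import xarray as xr

    # mid-month offsets, stored as timedeltas so that .dt.days works later
    time = pd.to_timedelta(np.arange(15, 365, 30), unit="D")
    lat = np.arange(-90.0, 91.0, 1.0)
    lon = np.arange(0.0, 360.0, 1.0)

    corr = xr.DataArray(
        np.ones((time.size, lat.size, lon.size)),  # neutral correction factor
        coords={"time": time, "latitude": lat, "longitude": lon},
        dims=("time", "latitude", "longitude"),
        name="corr",
    )
    corr.to_dataset().to_netcdf("correction_data.nc")

    swr_correction = SWRCorrection(filename="correction_data.nc", varname="corr")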
+
+     def _check_dataset(self, ds: xr.Dataset) -> None:
+         """
+         Check if the dataset contains the specified variable and dimensions.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The xarray Dataset to check.
+
+         Raises
+         ------
+         ValueError
+             If the dataset does not contain the specified variable or dimensions.
+         """
+         if self.varname not in ds:
+             raise ValueError(
+                 f"The dataset does not contain the variable '{self.varname}'."
+             )
+
+         for dim in self.dim_names.values():
+             if dim not in ds.dims:
+                 raise ValueError(f"The dataset does not contain the dimension '{dim}'.")
+
+     def _ensure_latitude_ascending(self, ds: xr.Dataset) -> xr.Dataset:
+         """
+         Ensure that the latitude dimension is in ascending order.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The xarray Dataset to check.
+
+         Returns
+         -------
+         ds : xr.Dataset
+             The xarray Dataset with latitude in ascending order.
+         """
+         # Make sure that latitude is ascending
+         lat_diff = np.diff(ds[self.dim_names["latitude"]])
+         if np.all(lat_diff < 0):
+             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
+
+         return ds
+
+     def _handle_longitudes(self, straddle: bool) -> None:
+         """
+         Convert the longitude values in the dataset from one range to another.
+
+         Parameters
+         ----------
+         straddle : bool
+             If True, target longitudes are in range [-180, 180].
+             If False, target longitudes are in range [0, 360].
+         """
+         lon = self.ds[self.dim_names["longitude"]]
+
+         if lon.min().values < 0 and not straddle:
+             # Convert from [-180, 180] to [0, 360]
+             self.ds[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
+
+         if lon.max().values > 180 and straddle:
+             # Convert from [0, 360] to [-180, 180]
+             self.ds[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
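
As a quick illustration of the two conversions (a standalone sketch, not part of the module):

    import numpy as np
    import xarray as xr

    lon = xr.DataArray(np.array([-150.0, -30.0, 30.0, 150.0]), dims="longitude")

    # straddle=False: map [-180, 180] onto [0, 360]
    print(xr.where(lon < 0, lon + 360, lon).values)  # [210. 330. 30. 150.]

    # straddle=True: map [0, 360] onto [-180, 180]
    lon2 = xr.DataArray(np.array([10.0, 170.0, 190.0, 350.0]), dims="longitude")
    print(xr.where(lon2 > 180, lon2 - 360, lon2).values)  # [10. 170. -170. -10.]

Note that both shifted axes are no longer monotonic. This is exactly the "discontinuous longitudes" concern addressed in AtmosphericForcing.__post_init__ below, which selects a subdomain before converting longitude conventions.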
+
+     def _choose_subdomain(self, coords) -> xr.Dataset:
+         """
+         Select a subdomain of the dataset at the specified latitude and longitude coordinates.
+
+         Parameters
+         ----------
+         coords : dict
+             A dictionary specifying the target coordinates.
+
+         Returns
+         -------
+         xr.Dataset
+             The subset of the original dataset representing the chosen subdomain.
+
+         Raises
+         ------
+         ValueError
+             If the specified subdomain is not fully contained within the dataset.
+         """
+         # Select the subdomain at the specified latitude and longitude coordinates
+         subdomain = self.ds.sel(**coords)
+
+         # Check that the selected subdomain contains all requested latitude and longitude values
+         if not subdomain[self.dim_names["latitude"]].equals(
+             coords[self.dim_names["latitude"]]
+         ):
+             raise ValueError(
+                 "The correction dataset does not contain all specified latitude values."
+             )
+         if not subdomain[self.dim_names["longitude"]].equals(
+             coords[self.dim_names["longitude"]]
+         ):
+             raise ValueError(
+                 "The correction dataset does not contain all specified longitude values."
+             )
+
+         return subdomain
+
+     def _interpolate_temporally(self, field, time):
+         """
+         Interpolate the given field temporally to the specified time points.
+
+         Parameters
+         ----------
+         field : xarray.DataArray
+             The field data to be interpolated. This can be any variable from the dataset that
+             requires temporal interpolation, such as correction factors or any other relevant data.
+         time : xarray.DataArray or pandas.DatetimeIndex
+             The target time points for interpolation.
+
+         Returns
+         -------
+         xr.DataArray
+             The field values interpolated to the specified time points.
+
+         Raises
+         ------
+         NotImplementedError
+             If the temporal resolution is not set to 'climatology'.
+         """
+         if self.temporal_resolution != "climatology":
+             raise NotImplementedError(
+                 f"temporal_resolution must be 'climatology', got {self.temporal_resolution}"
+             )
+
+         time_dim = self.dim_names["time"]
+         field[time_dim] = field[time_dim].dt.days
+         day_of_year = time.dt.dayofyear
+
+         # Concatenate across the beginning and end of the year so that the
+         # interpolation wraps around the climatological year
+         time_concat = xr.concat(
+             [
+                 field[time_dim][-1] - 365.25,
+                 field[time_dim],
+                 365.25 + field[time_dim][0],
+             ],
+             dim=time_dim,
+         )
+         field_concat = xr.concat(
+             [
+                 field.isel({time_dim: -1}),
+                 field,
+                 field.isel({time_dim: 0}),
+             ],
+             dim=time_dim,
+         )
+         field_concat[time_dim] = time_concat
+         # Interpolate to the requested day-of-year values
+         field_interpolated = field_concat.interp(
+             {time_dim: day_of_year}, method="linear"
+         )
+
+         return field_interpolated
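
To make the wrap-around concrete, here is a small sketch of the padding step for a hypothetical 12-point monthly climatology:

    import numpy as np
    import pandas as pd
    import xarray as xr

    clim_time = pd.to_timedelta(np.arange(15, 365, 30), unit="D")
    field = xr.DataArray(np.arange(12.0), coords={"time": clim_time}, dims="time")

    field["time"] = field["time"].dt.days  # integer days: 15, 45, ..., 345

    # Pad one value on each side so the interpolation wraps around the year:
    # December reappears at day 345 - 365.25, January at day 15 + 365.25.
    time_concat = xr.concat(
        [field["time"][-1] - 365.25, field["time"], field["time"][0] + 365.25],
        dim="time",
    )
    field_concat = xr.concat(
        [field.isel(time=-1), field, field.isel(time=0)], dim="time"
    )
    field_concat["time"] = time_concat

    # Day 5 now interpolates between the December and January values instead
    # of falling outside the range [15, 345] and returning NaN.
    print(field_concat.interp(time=5, method="linear").values)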
+
+     @classmethod
+     def from_yaml(cls, filepath: str) -> "SWRCorrection":
+         """
+         Create an instance of the class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         SWRCorrection
+             An instance of the SWRCorrection class.
+         """
+         # Read the entire file content
+         with open(filepath, "r") as file:
+             file_content = file.read()
+
+         # Split the content into YAML documents
+         documents = list(yaml.safe_load_all(file_content))
+
+         swr_correction_data = None
+
+         # Iterate over the documents to find the SWRCorrection configuration
+         for doc in documents:
+             if doc is None:
+                 continue
+             if "SWRCorrection" in doc:
+                 swr_correction_data = doc["SWRCorrection"]
+                 break
+
+         if swr_correction_data is None:
+             raise ValueError("No SWRCorrection configuration found in the YAML file.")
+
+         return cls(**swr_correction_data)
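
`from_yaml` scans every document in the YAML stream for a top-level `SWRCorrection` key, so a file like the following is sufficient (file name and values illustrative):

    ---
    SWRCorrection:
      filename: correction_data.nc
      varname: corr
      temporal_resolution: climatology

    swr_correction = SWRCorrection.from_yaml("swr_correction.yaml")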
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class Rivers:
+     """
+     Configuration for river forcing.
+
+     Parameters
+     ----------
+     filename : str
+         Filename of the river forcing data. Must be provided; an empty
+         value raises a ValueError.
+     """
+
+     filename: str = ""
+
+     def __post_init__(self):
+         if not self.filename:
+             raise ValueError("The 'filename' must be provided.")
+
+     @classmethod
+     def from_yaml(cls, filepath: str) -> "Rivers":
+         """
+         Create an instance of the class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         Rivers
+             An instance of the Rivers class.
+         """
+         # Read the entire file content
+         with open(filepath, "r") as file:
+             file_content = file.read()
+
+         # Split the content into YAML documents
+         documents = list(yaml.safe_load_all(file_content))
+
+         rivers_data = None
+
+         # Iterate over the documents to find the Rivers configuration
+         for doc in documents:
+             if doc is None:
+                 continue
+             if "Rivers" in doc:
+                 rivers_data = doc["Rivers"]
+                 break
+
+         if rivers_data is None:
+             raise ValueError("No Rivers configuration found in the YAML file.")
+
+         return cls(**rivers_data)
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class AtmosphericForcing:
+     """
+     Represents atmospheric forcing data for ocean modeling.
+
+     Parameters
+     ----------
+     grid : Grid
+         Object representing the grid information.
+     use_coarse_grid : bool, optional
+         Whether to interpolate onto the coarsened grid. Default is False.
+     start_time : datetime
+         Start time of the desired forcing data.
+     end_time : datetime
+         End time of the desired forcing data.
+     model_reference_date : datetime, optional
+         Reference date for the model. Default is January 1, 2000.
+     source : str, optional
+         Source of the atmospheric forcing data. Default is "ERA5".
+     filename : str
+         Path to the atmospheric forcing source data file. Can contain wildcards.
+     swr_correction : SWRCorrection, optional
+         Shortwave radiation correction configuration. Default is None.
+     rivers : Rivers, optional
+         River forcing configuration. Default is None.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         Xarray Dataset containing the atmospheric forcing data.
+
+     Examples
+     --------
+     >>> grid_info = Grid(...)
+     >>> start_time = datetime(2000, 1, 1)
+     >>> end_time = datetime(2000, 1, 2)
+     >>> atm_forcing = AtmosphericForcing(
+     ...     grid=grid_info,
+     ...     start_time=start_time,
+     ...     end_time=end_time,
+     ...     source="ERA5",
+     ...     filename="atmospheric_data_*.nc",
+     ...     swr_correction=swr_correction,
+     ... )
+     """
+
+     grid: Grid
+     use_coarse_grid: bool = False
+     start_time: datetime
+     end_time: datetime
+     model_reference_date: datetime = datetime(2000, 1, 1)
+     source: str = "ERA5"
+     filename: str
+     swr_correction: Optional["SWRCorrection"] = None
+     rivers: Optional["Rivers"] = None
+     ds: xr.Dataset = field(init=False, repr=False)
+
+     def __post_init__(self):
+
+         # Check that the source is "ERA5"; no other source is supported yet
+         if self.source != "ERA5":
+             raise ValueError('Only "ERA5" is a valid option for source.')
+
+         dims = {"longitude": "longitude", "latitude": "latitude", "time": "time"}
+         varnames = {
+             "u10": "u10",
+             "v10": "v10",
+             "swr": "ssr",
+             "lwr": "strd",
+             "t2m": "t2m",
+             "d2m": "d2m",
+             "rain": "tp",
+             "mask": "sst",
+         }
+
+         data = Dataset(
+             filename=self.filename,
+             start_time=self.start_time,
+             end_time=self.end_time,
+             var_names=varnames.values(),
+             dim_names=dims,
+         )
+
+         if self.use_coarse_grid:
+             if "lon_coarse" not in self.grid.ds:
+                 raise ValueError(
+                     "Grid has not been coarsened yet. Execute grid.coarsen() first."
+                 )
+
+             lon = self.grid.ds.lon_coarse
+             lat = self.grid.ds.lat_coarse
+             angle = self.grid.ds.angle_coarse
+         else:
+             lon = self.grid.ds.lon_rho
+             lat = self.grid.ds.lat_rho
+             angle = self.grid.ds.angle
+
+         # Operate on longitudes in [-180, 180] unless the ROMS domain lies at
+         # least 5 degrees of longitude away from the Greenwich meridian
+         lon = xr.where(lon > 180, lon - 360, lon)
+         straddle = True
+         if not self.grid.straddle and abs(lon).min() > 5:
+             lon = xr.where(lon < 0, lon + 360, lon)
+             straddle = False
+
+         # The following consists of two steps:
+         # Step 1: Choose a subdomain of the forcing data, including a safety margin for interpolation.
+         # Step 2: Convert that subdomain to the proper longitude range.
+         # We perform these two steps for two reasons:
+         # A) Since the horizontal dimensions consist of a single chunk, selecting a subdomain before
+         #    interpolation is much more performant.
+         # B) Step 1 is necessary to avoid the discontinuous longitudes that Step 2 could otherwise
+         #    introduce. Discontinuous longitudes can lead to artifacts in the interpolation: if the
+         #    data is not global and has a gap in coverage, a discontinuous longitude axis could yield
+         #    values that appear to come from a distant location instead of producing NaNs. These NaNs
+         #    are important because they can be identified and handled appropriately by the nan_check
+         #    function.
+         data.choose_subdomain(
+             latitude_range=[lat.min().values, lat.max().values],
+             longitude_range=[lon.min().values, lon.max().values],
+             margin=2,
+             straddle=straddle,
+         )
+
+         # Interpolate onto the desired grid
+         coords = {dims["latitude"]: lat, dims["longitude"]: lon}
+
+         data_vars = {}
+
+         mask = xr.where(data.ds[varnames["mask"]].isel(time=0).isnull(), 0, 1)
+
+         # Fill and interpolate each variable
+         for var in varnames.keys():
+             if var != "mask":
+                 data_vars[var] = fill_and_interpolate(
+                     data.ds[varnames[var]],
+                     mask,
+                     list(coords.keys()),
+                     coords,
+                     method="linear",
+                 )
+
+         # Access the interpolated variables via the data_vars dictionary
+         u10 = data_vars["u10"]
+         v10 = data_vars["v10"]
+         swr = data_vars["swr"]
+         lwr = data_vars["lwr"]
+         t2m = data_vars["t2m"]
+         d2m = data_vars["d2m"]
+         rain = data_vars["rain"]
+
+         if self.source == "ERA5":
+             # Translate radiation to fluxes. ERA5 stores values integrated over 1 hour.
+             swr = swr / 3600  # from J/m^2 to W/m^2
+             lwr = lwr / 3600  # from J/m^2 to W/m^2
+             rain = rain * 100 * 24  # from m/hour to cm/day
+             # Convert from K to degrees C
+             t2m = t2m - 273.15
+             d2m = d2m - 273.15
+             # Relative humidity fraction from air and dew-point temperatures
+             qair = np.exp((17.625 * d2m) / (243.04 + d2m)) / np.exp(
+                 (17.625 * t2m) / (243.04 + t2m)
+             )
+             # Convert relative to absolute humidity assuming constant pressure
+             patm = 1010.0
+             cff = (
+                 (1.0007 + 3.46e-6 * patm)
+                 * 6.1121
+                 * np.exp(17.502 * t2m / (240.97 + t2m))
+             )
+             cff = cff * qair
+             qair = 0.62197 * (cff / (patm - 0.378 * cff))
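+             # The two exponentials in the relative-humidity step above are
+             # August-Roche-Magnus fits for the saturation vapor pressure at
+             # the dew point and at the air temperature, so their ratio is the
+             # relative humidity. The Buck-type fit (6.1121 hPa prefactor)
+             # then gives the saturation vapor pressure at t2m, cff * qair is
+             # the actual vapor pressure, and the final line converts it to kg
+             # of water vapor per kg of moist air via the molecular-weight
+             # ratio 0.62197, assuming a constant surface pressure of 1010 hPa.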
+
+         # Correct shortwave radiation
+         if self.swr_correction:
+
+             # Choose the same subdomain as the forcing data so that we can use the same mask
+             self.swr_correction._handle_longitudes(straddle=straddle)
+             coords_correction = {
+                 self.swr_correction.dim_names["latitude"]: data.ds[
+                     data.dim_names["latitude"]
+                 ],
+                 self.swr_correction.dim_names["longitude"]: data.ds[
+                     data.dim_names["longitude"]
+                 ],
+             }
+             subdomain = self.swr_correction._choose_subdomain(coords_correction)
+
+             # Spatial interpolation
+             corr_factor = subdomain[self.swr_correction.varname]
+             coords_correction = {
+                 self.swr_correction.dim_names["latitude"]: lat,
+                 self.swr_correction.dim_names["longitude"]: lon,
+             }
+             corr_factor = fill_and_interpolate(
+                 corr_factor,
+                 mask,
+                 list(coords_correction.keys()),
+                 coords_correction,
+                 method="linear",
+             )
+
+             # Temporal interpolation
+             corr_factor = self.swr_correction._interpolate_temporally(
+                 corr_factor, time=swr.time
+             )
+
+             swr = corr_factor * swr
+
+         if self.rivers:
+             raise NotImplementedError("River forcing is not implemented yet.")
+             # rain = rain + rivers
+
+         # Save the results in a new dataset
+         ds = xr.Dataset()
+
+         ds["uwnd"] = (u10 * np.cos(angle) + v10 * np.sin(angle)).astype(
+             np.float32
+         )  # rotate to grid orientation
+         ds["uwnd"].attrs["long_name"] = "10 meter wind in x-direction"
+         ds["uwnd"].attrs["units"] = "m/s"
+
+         ds["vwnd"] = (v10 * np.cos(angle) - u10 * np.sin(angle)).astype(
+             np.float32
+         )  # rotate to grid orientation
+         ds["vwnd"].attrs["long_name"] = "10 meter wind in y-direction"
+         ds["vwnd"].attrs["units"] = "m/s"
+
+         ds["swrad"] = swr.astype(np.float32)
+         ds["swrad"].attrs["long_name"] = "Downward short-wave (solar) radiation"
+         ds["swrad"].attrs["units"] = "W/m^2"
+
+         ds["lwrad"] = lwr.astype(np.float32)
+         ds["lwrad"].attrs["long_name"] = "Downward long-wave (thermal) radiation"
+         ds["lwrad"].attrs["units"] = "W/m^2"
+
+         ds["Tair"] = t2m.astype(np.float32)
+         ds["Tair"].attrs["long_name"] = "Air temperature at 2m"
+         ds["Tair"].attrs["units"] = "degrees C"
+
+         ds["qair"] = qair.astype(np.float32)
+         ds["qair"].attrs["long_name"] = "Absolute humidity at 2m"
+         ds["qair"].attrs["units"] = "kg/kg"
+
+         ds["rain"] = rain.astype(np.float32)
+         ds["rain"].attrs["long_name"] = "Total precipitation"
+         ds["rain"].attrs["units"] = "cm/day"
+
+         ds.attrs["title"] = "ROMS atmospheric forcing file created by ROMS-Tools"
+         # Include the version of roms-tools
+         try:
+             roms_tools_version = importlib.metadata.version("roms-tools")
+         except importlib.metadata.PackageNotFoundError:
+             roms_tools_version = "unknown"
+         ds.attrs["roms_tools_version"] = roms_tools_version
+         ds.attrs["start_time"] = str(self.start_time)
+         ds.attrs["end_time"] = str(self.end_time)
+         ds.attrs["model_reference_date"] = str(self.model_reference_date)
+         ds.attrs["source"] = self.source
+         ds.attrs["use_coarse_grid"] = str(self.use_coarse_grid)
+         ds.attrs["swr_correction"] = str(self.swr_correction is not None)
+         ds.attrs["rivers"] = str(self.rivers is not None)
+
+         ds = ds.assign_coords({"lon": lon, "lat": lat})
+         if self.use_coarse_grid:
+             ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})
+             mask_roms = self.grid.ds["mask_coarse"].rename(
+                 {"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"}
+             )
+         else:
+             mask_roms = self.grid.ds["mask_rho"]
+
+         if dims["time"] != "time":
+             ds = ds.rename({dims["time"]: "time"})
+
+         # Preserve the original time coordinate for readability
+         ds = ds.assign_coords({"absolute_time": ds["time"]})
+
+         # Convert the time coordinate to the format expected by ROMS:
+         # days since the model reference date
+         model_reference_date = np.datetime64(self.model_reference_date)
+         ds["time"] = (
+             (ds["time"] - model_reference_date).astype("float64") * 1e-9 / 3600 / 24
+         )  # nanoseconds -> days
+         ds["time"].attrs[
+             "long_name"
+         ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
+         ds["time"].attrs["units"] = "days"
+
+         for var in ds.data_vars:
+             nan_check(ds[var].isel(time=0), mask_roms)
+
+         object.__setattr__(self, "ds", ds)
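
Putting the pieces together, a typical workflow might look like the following sketch (the file paths and dates are illustrative, and the grid YAML file "my_grid.yaml" is assumed to exist):

    from datetime import datetime
    from roms_tools.setup.grid import Grid

    grid = Grid.from_yaml("my_grid.yaml")

    forcing = AtmosphericForcing(
        grid=grid,
        start_time=datetime(2000, 1, 1),
        end_time=datetime(2000, 2, 15),
        source="ERA5",
        filename="era5_data_*.nc",
    )

    forcing.plot("Tair", time=0)  # quick visual sanity check
    forcing.save("era5_forcing", time_chunk_size=1)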
+
+     def plot(self, varname, time=0) -> None:
+         """
+         Plot the specified atmospheric forcing field for a given time slice.
+
+         Parameters
+         ----------
+         varname : str
+             The name of the atmospheric forcing field to plot. Options include:
+             - "uwnd": 10 meter wind in x-direction.
+             - "vwnd": 10 meter wind in y-direction.
+             - "swrad": Downward short-wave (solar) radiation.
+             - "lwrad": Downward long-wave (thermal) radiation.
+             - "Tair": Air temperature at 2m.
+             - "qair": Absolute humidity at 2m.
+             - "rain": Total precipitation.
+         time : int, optional
+             The time index to plot. Default is 0, which corresponds to the first
+             time slice.
+
+         Returns
+         -------
+         None
+             This method does not return any value. It generates and displays a plot.
+
+         Raises
+         ------
+         KeyError
+             If the specified varname is not present in the dataset.
+
+         Examples
+         --------
+         >>> atm_forcing = AtmosphericForcing(
+         ...     grid=grid_info,
+         ...     start_time=start_time,
+         ...     end_time=end_time,
+         ...     source="ERA5",
+         ...     filename="atmospheric_data_*.nc",
+         ...     swr_correction=swr_correction,
+         ... )
+         >>> atm_forcing.plot("uwnd", time=0)
+         """
+
+         title = "%s at time %s" % (
+             self.ds[varname].long_name,
+             np.datetime_as_string(self.ds["absolute_time"].isel(time=time), unit="s"),
+         )
+
+         field = self.ds[varname].isel(time=time).compute()
+
+         # Choose the colorbar: diverging for winds, sequential otherwise
+         if varname in ["uwnd", "vwnd"]:
+             vmax = max(field.max().values, -field.min().values)
+             vmin = -vmax
+             cmap = plt.colormaps.get_cmap("RdBu_r")
+         else:
+             vmax = field.max().values
+             vmin = field.min().values
+             if varname in ["swrad", "lwrad", "Tair", "qair"]:
+                 cmap = plt.colormaps.get_cmap("YlOrRd")
+             else:
+                 cmap = plt.colormaps.get_cmap("YlGnBu")
+         cmap.set_bad(color="gray")
+
+         kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
+
+         _plot(
+             self.grid.ds,
+             field=field,
+             straddle=self.grid.straddle,
+             coarse_grid=self.use_coarse_grid,
+             title=title,
+             kwargs=kwargs,
+             c="g",
+         )
+
+     def save(self, filepath: str, time_chunk_size: int = 1) -> None:
+         """
+         Save the interpolated atmospheric forcing fields to netCDF4 files.
+
+         This method groups the dataset by year and month, chunks the data by the specified
+         time chunk size, and saves each subset to a separate netCDF4 file. Files are named
+         by year and month, with a day range appended when a subset does not span a
+         complete month.
+
+         Parameters
+         ----------
+         filepath : str
+             The base path and filename for the output files. The files will be named with
+             the format "filepath.YYYYMM.nc" if a full month of data is included, or
+             "filepath.YYYYMMDD-DD.nc" otherwise.
+         time_chunk_size : int, optional
+             Number of time slices to include in each chunk along the time dimension. Default is 1,
+             meaning each chunk contains one time slice.
+
+         Returns
+         -------
+         None
+         """
+
+         datasets = []
+         filenames = []
+         writes = []
+
+         # Group the dataset by year
+         gb = self.ds.groupby("absolute_time.year")
+
+         for year, group_ds in gb:
+             # Further group each yearly group by month
+             sub_gb = group_ds.groupby("absolute_time.month")
+
+             for month, ds in sub_gb:
+                 # Chunk the dataset by the specified time chunk size
+                 ds = ds.chunk({"time": time_chunk_size})
+                 datasets.append(ds)
+
+                 # Determine the number of days in the month
+                 num_days_in_month = calendar.monthrange(year, month)[1]
+                 first_day = ds.time.absolute_time.dt.day.values[0]
+                 last_day = ds.time.absolute_time.dt.day.values[-1]
+
+                 # Create the filename based on whether the dataset contains a full month
+                 if first_day == 1 and last_day == num_days_in_month:
+                     # Full month format: "filepath.YYYYMM.nc"
+                     year_month_str = f"{year}{month:02}"
+                     filename = f"{filepath}.{year_month_str}.nc"
+                 else:
+                     # Partial month format: "filepath.YYYYMMDD-DD.nc"
+                     year_month_day_str = f"{year}{month:02}{first_day:02}-{last_day:02}"
+                     filename = f"{filepath}.{year_month_day_str}.nc"
+                 filenames.append(filename)
+
+         print("Saving the following files:")
+         for filename in filenames:
+             print(filename)
+
+         for ds, filename in zip(datasets, filenames):
+             # Prepare the dataset for writing to a netCDF file without immediately computing
+             write = ds.to_netcdf(filename, compute=False)
+             writes.append(write)
+
+         # Perform the actual write operations in parallel
+         dask.persist(*writes)
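
For example, saving a dataset that spans 2000-01-01 through 2000-02-15 with filepath="era5_forcing" would produce (dates illustrative):

    era5_forcing.200001.nc       # the full month of January
    era5_forcing.20000201-15.nc  # the partial month of February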
+
+     def to_yaml(self, filepath: str) -> None:
+         """
+         Export the parameters of the class to a YAML file, including the version of roms-tools.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file where the parameters will be saved.
+         """
+         # Serialize the Grid data
+         grid_data = asdict(self.grid)
+         grid_data.pop("ds", None)  # Exclude non-serializable fields
+         grid_data.pop("straddle", None)
+
+         if self.swr_correction:
+             swr_correction_data = asdict(self.swr_correction)
+             swr_correction_data.pop("ds", None)
+         else:
+             swr_correction_data = None
+
+         rivers_data = asdict(self.rivers) if self.rivers else None
+
+         # Include the version of roms-tools
+         try:
+             roms_tools_version = importlib.metadata.version("roms-tools")
+         except importlib.metadata.PackageNotFoundError:
+             roms_tools_version = "unknown"
+
+         # Create the header
+         header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
+
+         # Create YAML data for Grid and the optional attributes
+         grid_yaml_data = {"Grid": grid_data}
+         swr_correction_yaml_data = (
+             {"SWRCorrection": swr_correction_data} if swr_correction_data else {}
+         )
+         rivers_yaml_data = {"Rivers": rivers_data} if rivers_data else {}
+
+         # Combine all sections
+         atmospheric_forcing_data = {
+             "AtmosphericForcing": {
+                 "filename": self.filename,
+                 "start_time": self.start_time.isoformat(),
+                 "end_time": self.end_time.isoformat(),
+                 "model_reference_date": self.model_reference_date.isoformat(),
+                 "source": self.source,
+                 "use_coarse_grid": self.use_coarse_grid,
+             }
+         }
+
+         # Merge the YAML data while excluding empty sections
+         yaml_data = {
+             **grid_yaml_data,
+             **swr_correction_yaml_data,
+             **rivers_yaml_data,
+             **atmospheric_forcing_data,
+         }
+
+         with open(filepath, "w") as file:
+             # Write the header
+             file.write(header)
+             # Write the YAML data
+             yaml.dump(yaml_data, file, default_flow_style=False)
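
The written file is a multi-document YAML stream: a short version header, followed by a single document holding one section per component. A skeleton of the output might look like this (values illustrative):

    ---
    roms_tools_version: 0.20
    ---
    Grid:
      # serialized grid parameters
    SWRCorrection:
      filename: correction_data.nc
      varname: corr
      dim_names:
        longitude: longitude
        latitude: latitude
        time: time
      temporal_resolution: climatology
    AtmosphericForcing:
      filename: era5_data_*.nc
      start_time: '2000-01-01T00:00:00'
      end_time: '2000-02-15T00:00:00'
      model_reference_date: '2000-01-01T00:00:00'
      source: ERA5
      use_coarse_grid: false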
+
+     @classmethod
+     def from_yaml(cls, filepath: str) -> "AtmosphericForcing":
+         """
+         Create an instance of the AtmosphericForcing class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         AtmosphericForcing
+             An instance of the AtmosphericForcing class.
+         """
+         # Read the entire file content
+         with open(filepath, "r") as file:
+             file_content = file.read()
+
+         # Split the content into YAML documents
+         documents = list(yaml.safe_load_all(file_content))
+
+         swr_correction_data = None
+         rivers_data = None
+         atmospheric_forcing_data = None
+
+         # Process the YAML documents
+         for doc in documents:
+             if doc is None:
+                 continue
+             if "AtmosphericForcing" in doc:
+                 atmospheric_forcing_data = doc["AtmosphericForcing"]
+             if "SWRCorrection" in doc:
+                 swr_correction_data = doc["SWRCorrection"]
+             if "Rivers" in doc:
+                 rivers_data = doc["Rivers"]
+
+         if atmospheric_forcing_data is None:
+             raise ValueError(
+                 "No AtmosphericForcing configuration found in the YAML file."
+             )
+
+         # Convert from string to datetime
+         for date_string in ["model_reference_date", "start_time", "end_time"]:
+             atmospheric_forcing_data[date_string] = datetime.fromisoformat(
+                 atmospheric_forcing_data[date_string]
+             )
+
+         # Create Grid instance from the YAML file
+         grid = Grid.from_yaml(filepath)
+
+         if swr_correction_data is not None:
+             swr_correction = SWRCorrection.from_yaml(filepath)
+         else:
+             swr_correction = None
+
+         if rivers_data is not None:
+             rivers = Rivers.from_yaml(filepath)
+         else:
+             rivers = None
+
+         # Create and return an instance of AtmosphericForcing
+         return cls(
+             grid=grid,
+             swr_correction=swr_correction,
+             rivers=rivers,
+             **atmospheric_forcing_data,
+         )
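
Together, to_yaml and from_yaml give a simple round trip for reproducing a forcing setup (a sketch; forcing is an existing AtmosphericForcing instance):

    forcing.to_yaml("forcing.yaml")
    restored = AtmosphericForcing.from_yaml("forcing.yaml")

Because to_yaml also writes the Grid section (and the SWRCorrection and Rivers sections when present), the single YAML file carries everything from_yaml needs to rebuild the object.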