roms-tools 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registries.
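The most visible additions in this release are YAML-based configuration round-trips: AtmosphericForcing gains to_yaml/from_yaml methods (with matching from_yaml constructors on SWRCorrection and Rivers), and the standalone ForcingDataset class is replaced by roms_tools.setup.datasets.Dataset. Below is a minimal sketch of the new workflow; the import path for AtmosphericForcing is an assumption (this diff does not show the module name), and the file paths are placeholders.

from datetime import datetime
from roms_tools.setup.grid import Grid
# Assumed import path; the diff does not show which module defines AtmosphericForcing.
from roms_tools.setup.atmospheric_forcing import AtmosphericForcing

# Rebuild a grid saved earlier; Grid.from_yaml exists in this release (placeholder path).
grid = Grid.from_yaml("grid.yaml")

forcing = AtmosphericForcing(
    grid=grid,
    start_time=datetime(2020, 1, 1),
    end_time=datetime(2020, 2, 1),
    source="ERA5",         # the only value accepted in this version
    filename="era5_*.nc",  # placeholder pattern; wildcards are supported
)

# New in this release: persist the full configuration, then reconstruct the object from it.
forcing.to_yaml("forcing.yaml")
restored = AtmosphericForcing.from_yaml("forcing.yaml")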
@@ -1,191 +1,19 @@
  import xarray as xr
  import dask
- from dataclasses import dataclass, field
+ import yaml
+ import importlib.metadata
+ from dataclasses import dataclass, field, asdict
  from roms_tools.setup.grid import Grid
  from datetime import datetime
  import glob
  import numpy as np
  from typing import Optional, Dict
- from roms_tools.setup.fill import lateral_fill
+ from roms_tools.setup.fill import fill_and_interpolate
+ from roms_tools.setup.datasets import Dataset
+ from roms_tools.setup.utils import nan_check
+ from roms_tools.setup.plot import _plot
  import calendar
-
-
- @dataclass(frozen=True, kw_only=True)
- class ForcingDataset:
-     """
-     Represents forcing data on its original grid.
-
-     Parameters
-     ----------
-     filename : str
-         The path to the data files. Can contain wildcards.
-     start_time : datetime
-         The start time for selecting relevant data.
-     end_time : datetime
-         The end time for selecting relevant data.
-     dim_names : Dict[str, str], optional
-         Dictionary specifying the names of dimensions in the dataset.
-
-     Attributes
-     ----------
-     ds : xr.Dataset
-         The xarray Dataset containing the forcing data on its original grid.
-
-     Examples
-     --------
-     >>> dataset = ForcingDataset(
-     ...     filename="data.nc",
-     ...     start_time=datetime(2022, 1, 1),
-     ...     end_time=datetime(2022, 12, 31),
-     ... )
-     >>> dataset.load_data()
-     >>> print(dataset.ds)
-     <xarray.Dataset>
-     Dimensions: ...
-     """
-
-     filename: str
-     start_time: datetime
-     end_time: datetime
-     dim_names: Dict[str, str] = field(
-         default_factory=lambda: {"longitude": "lon", "latitude": "lat", "time": "time"}
-     )
-
-     ds: xr.Dataset = field(init=False, repr=False)
-
-     def __post_init__(self):
-
-         ds = self.load_data()
-
-         # Select relevant times
-         times = (np.datetime64(self.start_time) < ds[self.dim_names["time"]]) & (
-             ds[self.dim_names["time"]] < np.datetime64(self.end_time)
-         )
-         ds = ds.where(times, drop=True)
-
-         # Make sure that latitude is ascending
-         diff = np.diff(ds[self.dim_names["latitude"]])
-         if np.all(diff < 0):
-             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
-
-         object.__setattr__(self, "ds", ds)
-
-     def load_data(self) -> xr.Dataset:
-         """
-         Load dataset from the specified file.
-
-         Returns
-         -------
-         ds : xr.Dataset
-             The loaded xarray Dataset containing the forcing data.
-
-         Raises
-         ------
-         FileNotFoundError
-             If the specified file does not exist.
-         """
-
-         # Check if the file exists
-         matching_files = glob.glob(self.filename)
-         if not matching_files:
-             raise FileNotFoundError(
-                 f"No files found matching the pattern '{self.filename}'."
-             )
-
-         # Load the dataset
-         with dask.config.set(**{"array.slicing.split_large_chunks": False}):
-             # initially, we want a time chunk size of 1 to enable quick .nan_check() and .plot() methods for AtmosphericForcing
-             ds = xr.open_mfdataset(
-                 self.filename,
-                 combine="nested",
-                 concat_dim=self.dim_names["time"],
-                 coords="minimal",
-                 compat="override",
-                 chunks={self.dim_names["time"]: 1},
-             )
-
-         return ds
-
-     def choose_subdomain(self, latitude_range, longitude_range, margin, straddle):
-         """
-         Selects a subdomain from the given xarray Dataset based on latitude and longitude ranges,
-         extending the selection by the specified margin. Handles the conversion of longitude values
-         in the dataset from one range to another.
-
-         Parameters
-         ----------
-         latitude_range : tuple
-             A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
-         longitude_range : tuple
-             A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
-         margin : float
-             Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
-         straddle : bool
-             If True, target longitudes are expected in the range [-180, 180].
-             If False, target longitudes are expected in the range [0, 360].
-
-         Returns
-         -------
-         xr.Dataset
-             The subset of the original dataset representing the chosen subdomain, including an extended area
-             to cover one extra grid point beyond the specified ranges.
-
-         Raises
-         ------
-         ValueError
-             If the selected latitude or longitude range does not intersect with the dataset.
-         """
-         lat_min, lat_max = latitude_range
-         lon_min, lon_max = longitude_range
-
-         lon = self.ds[self.dim_names["longitude"]]
-         # Adjust longitude range if needed to match the expected range
-         if not straddle:
-             if lon.min() < -180:
-                 if lon_max + margin > 0:
-                     lon_min -= 360
-                     lon_max -= 360
-             elif lon.min() < 0:
-                 if lon_max + margin > 180:
-                     lon_min -= 360
-                     lon_max -= 360
-
-         if straddle:
-             if lon.max() > 360:
-                 if lon_min - margin < 180:
-                     lon_min += 360
-                     lon_max += 360
-             elif lon.max() > 180:
-                 if lon_min - margin < 0:
-                     lon_min += 360
-                     lon_max += 360
-
-         # Select the subdomain
-         subdomain = self.ds.sel(
-             **{
-                 self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
-                 self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
-             }
-         )
-
-         # Check if the selected subdomain has zero dimensions in latitude or longitude
-         if subdomain[self.dim_names["latitude"]].size == 0:
-             raise ValueError("Selected latitude range does not intersect with dataset.")
-
-         if subdomain[self.dim_names["longitude"]].size == 0:
-             raise ValueError(
-                 "Selected longitude range does not intersect with dataset."
-             )
-
-         # Adjust longitudes to expected range if needed
-         lon = subdomain[self.dim_names["longitude"]]
-         if straddle:
-             subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
-         else:
-             subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
-
-         # Set the modified subdomain to the object attribute
-         object.__setattr__(self, "ds", subdomain)
+ import matplotlib.pyplot as plt


  @dataclass(frozen=True, kw_only=True)
@@ -445,6 +273,43 @@ class SWRCorrection:

          return field_interpolated

+     @classmethod
+     def from_yaml(cls, filepath: str) -> "SWRCorrection":
+         """
+         Create an instance of the class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         SWRCorrection
+             An instance of the SWRCorrection class.
+         """
+         # Read the entire file content
+         with open(filepath, "r") as file:
+             file_content = file.read()
+
+         # Split the content into YAML documents
+         documents = list(yaml.safe_load_all(file_content))
+
+         swr_correction_data = None
+
+         # Iterate over the documents to find the SWRCorrection configuration
+         for doc in documents:
+             if doc is None:
+                 continue
+             if "SWRCorrection" in doc:
+                 swr_correction_data = doc["SWRCorrection"]
+                 break
+
+         if swr_correction_data is None:
+             raise ValueError("No SWRCorrection configuration found in the YAML file.")
+
+         return cls(**swr_correction_data)
+

  @dataclass(frozen=True, kw_only=True)
  class Rivers:
@@ -463,6 +328,43 @@ class Rivers:
          if not self.filename:
              raise ValueError("The 'filename' must be provided.")

+     @classmethod
+     def from_yaml(cls, filepath: str) -> "Rivers":
+         """
+         Create an instance of the class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         Rivers
+             An instance of the Rivers class.
+         """
+         # Read the entire file content
+         with open(filepath, "r") as file:
+             file_content = file.read()
+
+         # Split the content into YAML documents
+         documents = list(yaml.safe_load_all(file_content))
+
+         rivers_data = None
+
+         # Iterate over the documents to find the Rivers configuration
+         for doc in documents:
+             if doc is None:
+                 continue
+             if "Rivers" in doc:
+                 rivers_data = doc["Rivers"]
+                 break
+
+         if rivers_data is None:
+             raise ValueError("No Rivers configuration found in the YAML file.")
+
+         return cls(**rivers_data)
+

  @dataclass(frozen=True, kw_only=True)
  class AtmosphericForcing:
@@ -482,7 +384,7 @@ class AtmosphericForcing:
      model_reference_date : datetime, optional
          Reference date for the model. Default is January 1, 2000.
      source : str, optional
-         Source of the atmospheric forcing data. Default is "era5".
+         Source of the atmospheric forcing data. Default is "ERA5".
      filename: str
          Path to the atmospheric forcing source data file. Can contain wildcards.
      swr_correction : SWRCorrection
@@ -495,10 +397,6 @@ class AtmosphericForcing:
      ds : xr.Dataset
          Xarray Dataset containing the atmospheric forcing data.

-     Notes
-     -----
-     This class represents atmospheric forcing data used in ocean modeling. It provides a convenient
-     interface to work with forcing data including shortwave radiation correction and river forcing.

      Examples
      --------
@@ -509,7 +407,7 @@ class AtmosphericForcing:
      ...     grid=grid_info,
      ...     start_time=start_time,
      ...     end_time=end_time,
-     ...     source="era5",
+     ...     source="ERA5",
      ...     filename="atmospheric_data_*.nc",
      ...     swr_correction=swr_correction,
      ... )
@@ -520,7 +418,7 @@ class AtmosphericForcing:
      start_time: datetime
      end_time: datetime
      model_reference_date: datetime = datetime(2000, 1, 1)
-     source: str = "era5"
+     source: str = "ERA5"
      filename: str
      swr_correction: Optional["SWRCorrection"] = None
      rivers: Optional["Rivers"] = None
@@ -528,6 +426,30 @@ class AtmosphericForcing:

      def __post_init__(self):

+         # Check that the source is "ERA5"
+         if self.source != "ERA5":
+             raise ValueError('Only "ERA5" is a valid option for source.')
+         if self.source == "ERA5":
+             dims = {"longitude": "longitude", "latitude": "latitude", "time": "time"}
+             varnames = {
+                 "u10": "u10",
+                 "v10": "v10",
+                 "swr": "ssr",
+                 "lwr": "strd",
+                 "t2m": "t2m",
+                 "d2m": "d2m",
+                 "rain": "tp",
+                 "mask": "sst",
+             }
+
+             data = Dataset(
+                 filename=self.filename,
+                 start_time=self.start_time,
+                 end_time=self.end_time,
+                 var_names=varnames.values(),
+                 dim_names=dims,
+             )
+
          if self.use_coarse_grid:
              if "lon_coarse" not in self.grid.ds:
                  raise ValueError(
@@ -542,16 +464,6 @@ class AtmosphericForcing:
          lat = self.grid.ds.lat_rho
          angle = self.grid.ds.angle

-         if self.source == "era5":
-             dims = {"longitude": "longitude", "latitude": "latitude", "time": "time"}
-
-             data = ForcingDataset(
-                 filename=self.filename,
-                 start_time=self.start_time,
-                 end_time=self.end_time,
-                 dim_names=dims,
-             )
-
          # operate on longitudes between -180 and 180 unless the ROMS domain lies at least 5 degrees in longitude away from the Greenwich meridian
          lon = xr.where(lon > 180, lon - 360, lon)
          straddle = True
@@ -559,9 +471,12 @@ class AtmosphericForcing:
          lon = xr.where(lon < 0, lon + 360, lon)
          straddle = False

+         # The following consists of two steps:
          # Step 1: Choose subdomain of forcing data including safety margin for interpolation, and Step 2: Convert to the proper longitude range.
-         # Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2.
-         # Discontinuous longitudes can lead to artifacts in the interpolation process. Specifically, if there is a data gap,
+         # We perform these two steps for two reasons:
+         # A) Since the horizontal dimensions consist of a single chunk, selecting a subdomain before interpolation is a lot more performant.
+         # B) Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2, since discontinuous longitudes
+         #    can lead to artifacts in the interpolation process. Specifically, if the data is not global,
          # discontinuous longitudes could result in values that appear to come from a distant location instead of producing NaNs.
          # These NaNs are important as they can be identified and handled appropriately by the nan_check function.
          data.choose_subdomain(
@@ -572,42 +487,33 @@ class AtmosphericForcing:
          )

          # interpolate onto desired grid
-         if self.source == "era5":
-             mask = xr.where(data.ds["sst"].isel(time=0).isnull(), 0, 1)
-             varnames = {
-                 "u10": "u10",
-                 "v10": "v10",
-                 "swr": "ssr",
-                 "lwr": "strd",
-                 "t2m": "t2m",
-                 "d2m": "d2m",
-                 "rain": "tp",
-             }
-
          coords = {dims["latitude"]: lat, dims["longitude"]: lon}
-         u10 = self._interpolate(
-             data.ds[varnames["u10"]], mask, coords=coords, method="linear"
-         )
-         v10 = self._interpolate(
-             data.ds[varnames["v10"]], mask, coords=coords, method="linear"
-         )
-         swr = self._interpolate(
-             data.ds[varnames["swr"]], mask, coords=coords, method="linear"
-         )
-         lwr = self._interpolate(
-             data.ds[varnames["lwr"]], mask, coords=coords, method="linear"
-         )
-         t2m = self._interpolate(
-             data.ds[varnames["t2m"]], mask, coords=coords, method="linear"
-         )
-         d2m = self._interpolate(
-             data.ds[varnames["d2m"]], mask, coords=coords, method="linear"
-         )
-         rain = self._interpolate(
-             data.ds[varnames["rain"]], mask, coords=coords, method="linear"
-         )

-         if self.source == "era5":
+         data_vars = {}
+
+         mask = xr.where(data.ds[varnames["mask"]].isel(time=0).isnull(), 0, 1)
+
+         # Fill and interpolate each variable
+         for var in varnames.keys():
+             if var != "mask":
+                 data_vars[var] = fill_and_interpolate(
+                     data.ds[varnames[var]],
+                     mask,
+                     list(coords.keys()),
+                     coords,
+                     method="linear",
+                 )
+
+         # Access the interpolated variables using data_vars dictionary
+         u10 = data_vars["u10"]
+         v10 = data_vars["v10"]
+         swr = data_vars["swr"]
+         lwr = data_vars["lwr"]
+         t2m = data_vars["t2m"]
+         d2m = data_vars["d2m"]
+         rain = data_vars["rain"]
+
+         if self.source == "ERA5":
              # translate radiation to fluxes. ERA5 stores values integrated over 1 hour.
              swr = swr / 3600  # from J/m^2 to W/m^2
              lwr = lwr / 3600  # from J/m^2 to W/m^2
@@ -650,8 +556,12 @@ class AtmosphericForcing:
                  self.swr_correction.dim_names["latitude"]: lat,
                  self.swr_correction.dim_names["longitude"]: lon,
              }
-             corr_factor = self._interpolate(
-                 corr_factor, mask, coords=coords_correction, method="linear"
+             corr_factor = fill_and_interpolate(
+                 corr_factor,
+                 mask,
+                 list(coords_correction.keys()),
+                 coords_correction,
+                 method="linear",
              )

              # temporal interpolation
@@ -700,11 +610,22 @@ class AtmosphericForcing:
          ds["rain"].attrs["long_name"] = "Total precipitation"
          ds["rain"].attrs["units"] = "cm/day"

-         ds.attrs["Title"] = "ROMS bulk surface forcing file produced by roms-tools"
+         ds.attrs["title"] = "ROMS atmospheric forcing file created by ROMS-Tools"
+         # Include the version of roms-tools
+         try:
+             roms_tools_version = importlib.metadata.version("roms-tools")
+         except importlib.metadata.PackageNotFoundError:
+             roms_tools_version = "unknown"
+         ds.attrs["roms_tools_version"] = roms_tools_version
+         ds.attrs["start_time"] = str(self.start_time)
+         ds.attrs["end_time"] = str(self.end_time)
+         ds.attrs["model_reference_date"] = str(self.model_reference_date)
+         ds.attrs["source"] = self.source
+         ds.attrs["use_coarse_grid"] = str(self.use_coarse_grid)
+         ds.attrs["swr_correction"] = str(self.swr_correction is not None)
+         ds.attrs["rivers"] = str(self.rivers is not None)

          ds = ds.assign_coords({"lon": lon, "lat": lat})
-         if dims["time"] != "time":
-             ds = ds.rename({dims["time"]: "time"})
          if self.use_coarse_grid:
              ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})
              mask_roms = self.grid.ds["mask_coarse"].rename(
@@ -713,84 +634,28 @@ class AtmosphericForcing:
          else:
              mask_roms = self.grid.ds["mask_rho"]

-         object.__setattr__(self, "ds", ds)
-
-         self.nan_check(mask_roms, time=0)
-
-     @staticmethod
-     def _interpolate(field, mask, coords, method="linear"):
-         """
-         Interpolate a field using specified coordinates and a given method.
-
-         Parameters
-         ----------
-         field : xr.DataArray
-             The data array to be interpolated.
-
-         mask : xr.DataArray
-             A data array with same spatial dimensions as `field`, where `1` indicates wet (ocean)
-             points and `0` indicates land points.
-
-         coords : dict
-             A dictionary specifying the target coordinates for interpolation. The keys
-             should match the dimensions of `field` (e.g., {"longitude": lon_values, "latitude": lat_values}).
-
-         method : str, optional, default='linear'
-             The interpolation method to use. Valid options are those supported by
-             `xarray.DataArray.interp`.
-
-         Returns
-         -------
-         xr.DataArray
-             The interpolated data array.
-
-         Notes
-         -----
-         This method first sets land values to NaN based on the provided mask. It then uses the
-         `lateral_fill` function to propagate ocean values. These two steps serve the purpose to
-         avoid interpolation across the land-ocean boundary. Finally, it performs interpolation
-         over the specified coordinates.
-
-         """
-
-         dims = list(coords.keys())
-
-         # set land values to nan
-         field = field.where(mask)
-         # propagate ocean values into land interior before interpolation
-         field = lateral_fill(field, 1 - mask, dims)
-         # interpolate
-         field_interpolated = field.interp(**coords, method=method).drop_vars(dims)
-
-         return field_interpolated
+         if dims["time"] != "time":
+             ds = ds.rename({dims["time"]: "time"})

-     def nan_check(self, mask, time=0) -> None:
-         """
-         Checks for NaN values at wet points in all variables of the dataset at a specified time step.
+         # Preserve the original time coordinate for readability
+         ds = ds.assign_coords({"absolute_time": ds["time"]})

-         Parameters
-         ----------
-         mask : array-like
-             A boolean mask indicating the wet points in the dataset.
-         time : int
-             The time step at which to check for NaN values. Default is 0.
+         # Translate the time coordinate to days since the model reference date
+         model_reference_date = np.datetime64(self.model_reference_date)

-         Raises
-         ------
-         ValueError
-             If any variable contains NaN values at the specified time step.
+         # Convert the time coordinate to the format expected by ROMS (days since model reference date)
+         ds["time"] = (
+             (ds["time"] - model_reference_date).astype("float64") / 3600 / 24 * 1e-9
+         )
+         ds["time"].attrs[
+             "long_name"
+         ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
+         ds["time"].attrs["units"] = "days"

-         """
+         for var in ds.data_vars:
+             nan_check(ds[var].isel(time=0), mask_roms)

-         for var in self.ds.data_vars:
-             da = xr.where(mask == 1, self.ds[var].isel(time=time), 0)
-             if da.isnull().any().values:
-                 raise ValueError(
-                     f"NaN values found in the variable '{var}' at time step {time} over the ocean. This likely "
-                     "occurs because the ROMS grid, including a small safety margin for interpolation, is not "
-                     "fully contained within the dataset's longitude/latitude range. Please ensure that the "
-                     "dataset covers the entire area required by the ROMS grid."
-                 )
+         object.__setattr__(self, "ds", ds)

      def plot(self, varname, time=0) -> None:
          """
@@ -828,87 +693,45 @@ class AtmosphericForcing:
          ...     grid=grid_info,
          ...     start_time=start_time,
          ...     end_time=end_time,
-         ...     source="era5",
+         ...     source="ERA5",
          ...     filename="atmospheric_data_*.nc",
          ...     swr_correction=swr_correction,
          ... )
          >>> atm_forcing.plot("uwnd", time=0)
          """

-         import cartopy.crs as ccrs
-         import matplotlib.pyplot as plt
-
-         lon_deg = self.ds.lon
-         lat_deg = self.ds.lat
-
-         # check if North or South pole are in domain
-         if lat_deg.max().values > 89 or lat_deg.min().values < -89:
-             raise NotImplementedError(
-                 "Plotting the atmospheric forcing is not implemented for the case that the domain contains the North or South pole."
-             )
-
-         if self.grid.straddle:
-             lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
-
-         # Define projections
-         proj = ccrs.PlateCarree()
-
-         trans = ccrs.NearsidePerspective(
-             central_longitude=lon_deg.mean().values,
-             central_latitude=lat_deg.mean().values,
+         title = "%s at time %s" % (
+             self.ds[varname].long_name,
+             np.datetime_as_string(self.ds["absolute_time"].isel(time=time), unit="s"),
          )

-         lon_deg = lon_deg.values
-         lat_deg = lat_deg.values
-
-         # find corners
-         (lo1, la1) = (lon_deg[0, 0], lat_deg[0, 0])
-         (lo2, la2) = (lon_deg[0, -1], lat_deg[0, -1])
-         (lo3, la3) = (lon_deg[-1, -1], lat_deg[-1, -1])
-         (lo4, la4) = (lon_deg[-1, 0], lat_deg[-1, 0])
-
-         # transform coordinates to projected space
-         lo1t, la1t = trans.transform_point(lo1, la1, proj)
-         lo2t, la2t = trans.transform_point(lo2, la2, proj)
-         lo3t, la3t = trans.transform_point(lo3, la3, proj)
-         lo4t, la4t = trans.transform_point(lo4, la4, proj)
-
-         plt.figure(figsize=(10, 10))
-         ax = plt.axes(projection=trans)
-
-         ax.plot(
-             [lo1t, lo2t, lo3t, lo4t, lo1t],
-             [la1t, la2t, la3t, la4t, la1t],
-             "go-",
-         )
-
-         ax.coastlines(
-             resolution="50m", linewidth=0.5, color="black"
-         )  # add map of coastlines
-         ax.gridlines()
-
          field = self.ds[varname].isel(time=time).compute()
+
+         # choose colorbar
          if varname in ["uwnd", "vwnd"]:
              vmax = max(field.max().values, -field.min().values)
              vmin = -vmax
-             cmap = "RdBu_r"
+             cmap = plt.colormaps.get_cmap("RdBu_r")
          else:
              vmax = field.max().values
              vmin = field.min().values
              if varname in ["swrad", "lwrad", "Tair", "qair"]:
-                 cmap = "YlOrRd"
+                 cmap = plt.colormaps.get_cmap("YlOrRd")
              else:
-                 cmap = "YlGnBu"
-
-         p = ax.pcolormesh(
-             lon_deg, lat_deg, field, transform=proj, vmax=vmax, vmin=vmin, cmap=cmap
+                 cmap = plt.colormaps.get_cmap("YlGnBu")
+         cmap.set_bad(color="gray")
+
+         kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
+
+         _plot(
+             self.grid.ds,
+             field=field,
+             straddle=self.grid.straddle,
+             coarse_grid=self.use_coarse_grid,
+             title=title,
+             kwargs=kwargs,
+             c="g",
          )
-         plt.colorbar(p, label=field.units)
-         ax.set_title(
-             "%s at time %s"
-             % (field.long_name, np.datetime_as_string(field.time, unit="s"))
-         )
-         plt.show()

      def save(self, filepath: str, time_chunk_size: int = 1) -> None:
          """
@@ -938,11 +761,11 @@ class AtmosphericForcing:
          writes = []

          # Group dataset by year
-         gb = self.ds.groupby("time.year")
+         gb = self.ds.groupby("absolute_time.year")

          for year, group_ds in gb:
              # Further group each yearly group by month
-             sub_gb = group_ds.groupby("time.month")
+             sub_gb = group_ds.groupby("absolute_time.month")

              for month, ds in sub_gb:
                  # Chunk the dataset by the specified time chunk size
@@ -951,8 +774,8 @@ class AtmosphericForcing:

                  # Determine the number of days in the month
                  num_days_in_month = calendar.monthrange(year, month)[1]
-                 first_day = ds.time.dt.day.values[0]
-                 last_day = ds.time.dt.day.values[-1]
+                 first_day = ds.time.absolute_time.dt.day.values[0]
+                 last_day = ds.time.absolute_time.dt.day.values[-1]

                  # Create filename based on whether the dataset contains a full month
                  if first_day == 1 and last_day == num_days_in_month:
@@ -971,23 +794,142 @@ class AtmosphericForcing:

          for ds, filename in zip(datasets, filenames):

-             # Translate the time coordinate to days since the model reference date
-             model_reference_date = np.datetime64(self.model_reference_date)
-
-             # Preserve the original time coordinate for readability
-             ds["Time"] = ds["time"]
-
-             # Convert the time coordinate to the format expected by ROMS (days since model reference date)
-             ds["time"] = (
-                 (ds["time"] - model_reference_date).astype("float64") / 3600 / 24 * 1e-9
-             )
-             ds["time"].attrs[
-                 "long_name"
-             ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
-
              # Prepare the dataset for writing to a netCDF file without immediately computing
              write = ds.to_netcdf(filename, compute=False)
              writes.append(write)

          # Perform the actual write operations in parallel
-         dask.compute(*writes)
+         dask.persist(*writes)
+
+     def to_yaml(self, filepath: str) -> None:
+         """
+         Export the parameters of the class to a YAML file, including the version of roms-tools.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file where the parameters will be saved.
+         """
+         # Serialize Grid data
+         grid_data = asdict(self.grid)
+         grid_data.pop("ds", None)  # Exclude non-serializable fields
+         grid_data.pop("straddle", None)
+
+         if self.swr_correction:
+             swr_correction_data = asdict(self.swr_correction)
+             swr_correction_data.pop("ds", None)
+         else:
+             swr_correction_data = None
+
+         rivers_data = asdict(self.rivers) if self.rivers else None
+
+         # Include the version of roms-tools
+         try:
+             roms_tools_version = importlib.metadata.version("roms-tools")
+         except importlib.metadata.PackageNotFoundError:
+             roms_tools_version = "unknown"
+
+         # Create header
+         header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
+
+         # Create YAML data for Grid and optional attributes
+         grid_yaml_data = {"Grid": grid_data}
+         swr_correction_yaml_data = (
+             {"SWRCorrection": swr_correction_data} if swr_correction_data else {}
+         )
+         rivers_yaml_data = {"Rivers": rivers_data} if rivers_data else {}
+
+         # Combine all sections
+         atmospheric_forcing_data = {
+             "AtmosphericForcing": {
+                 "filename": self.filename,
+                 "start_time": self.start_time.isoformat(),
+                 "end_time": self.end_time.isoformat(),
+                 "model_reference_date": self.model_reference_date.isoformat(),
+                 "source": self.source,
+                 "use_coarse_grid": self.use_coarse_grid,
+             }
+         }
+
+         # Merge YAML data while excluding empty sections
+         yaml_data = {
+             **grid_yaml_data,
+             **swr_correction_yaml_data,
+             **rivers_yaml_data,
+             **atmospheric_forcing_data,
+         }
+
+         with open(filepath, "w") as file:
+             # Write header
+             file.write(header)
+             # Write YAML data
+             yaml.dump(yaml_data, file, default_flow_style=False)
+
+     @classmethod
+     def from_yaml(cls, filepath: str) -> "AtmosphericForcing":
+         """
+         Create an instance of the AtmosphericForcing class from a YAML file.
+
+         Parameters
+         ----------
+         filepath : str
+             The path to the YAML file from which the parameters will be read.
+
+         Returns
+         -------
+         AtmosphericForcing
+             An instance of the AtmosphericForcing class.
+         """
+         # Read the entire file content
+         with open(filepath, "r") as file:
+             file_content = file.read()
+
+         # Split the content into YAML documents
+         documents = list(yaml.safe_load_all(file_content))
+
+         swr_correction_data = None
+         rivers_data = None
+         atmospheric_forcing_data = None
+
+         # Process the YAML documents
+         for doc in documents:
+             if doc is None:
+                 continue
+             if "AtmosphericForcing" in doc:
+                 atmospheric_forcing_data = doc["AtmosphericForcing"]
+             if "SWRCorrection" in doc:
+                 swr_correction_data = doc["SWRCorrection"]
+             if "Rivers" in doc:
+                 rivers_data = doc["Rivers"]
+
+         if atmospheric_forcing_data is None:
+             raise ValueError(
+                 "No AtmosphericForcing configuration found in the YAML file."
+             )
+
+         # Convert from string to datetime
+         for date_string in ["model_reference_date", "start_time", "end_time"]:
+             atmospheric_forcing_data[date_string] = datetime.fromisoformat(
+                 atmospheric_forcing_data[date_string]
+             )
+
+         # Create Grid instance from the YAML file
+         grid = Grid.from_yaml(filepath)
+
+         if swr_correction_data is not None:
+             swr_correction = SWRCorrection.from_yaml(filepath)
+         else:
+             swr_correction = None
+
+         if rivers_data is not None:
+             rivers = Rivers.from_yaml(filepath)
+         else:
+             rivers = None
+
+         # Create and return an instance of AtmosphericForcing
+         return cls(
+             grid=grid,
+             swr_correction=swr_correction,
+             rivers=rivers,
+             **atmospheric_forcing_data,
+         )
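
For reference, the file written by to_yaml is a two-document YAML stream: a header document carrying the roms-tools version, followed by one document holding the Grid section, the optional SWRCorrection and Rivers sections, and the AtmosphericForcing section. from_yaml and the per-class from_yaml constructors scan these documents by key, as the sketch below illustrates (the file path is a placeholder):

import yaml

# Inspect the two-document stream produced by to_yaml (placeholder path).
with open("forcing.yaml") as f:
    documents = list(yaml.safe_load_all(f))

# documents[0] is the version header, e.g. {'roms_tools_version': '0.2.0'};
# documents[1] holds the merged sections ('Grid', optional
# 'SWRCorrection'/'Rivers', and 'AtmosphericForcing').
for doc in documents:
    if doc and "AtmosphericForcing" in doc:
        print(doc["AtmosphericForcing"]["source"])  # -> ERA5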