roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,993 +0,0 @@
- import xarray as xr
- import dask
- from dataclasses import dataclass, field
- from roms_tools.setup.grid import Grid
- from datetime import datetime
- import glob
- import numpy as np
- from typing import Optional, Dict
- from roms_tools.setup.fill import lateral_fill
- import calendar
-
-
- @dataclass(frozen=True, kw_only=True)
- class ForcingDataset:
-     """
-     Represents forcing data on its original grid.
-
-     Parameters
-     ----------
-     filename : str
-         The path to the data files. Can contain wildcards.
-     start_time : datetime
-         The start time for selecting relevant data.
-     end_time : datetime
-         The end time for selecting relevant data.
-     dim_names : Dict[str, str], optional
-         Dictionary specifying the names of dimensions in the dataset.
-
-     Attributes
-     ----------
-     ds : xr.Dataset
-         The xarray Dataset containing the forcing data on its original grid.
-
-     Examples
-     --------
-     >>> dataset = ForcingDataset(
-     ...     filename="data.nc",
-     ...     start_time=datetime(2022, 1, 1),
-     ...     end_time=datetime(2022, 12, 31),
-     ... )
-     >>> print(dataset.ds)
-     <xarray.Dataset>
-     Dimensions: ...
-     """
-
-     filename: str
-     start_time: datetime
-     end_time: datetime
-     dim_names: Dict[str, str] = field(
-         default_factory=lambda: {"longitude": "lon", "latitude": "lat", "time": "time"}
-     )
-
-     ds: xr.Dataset = field(init=False, repr=False)
-
-     def __post_init__(self):
-         ds = self.load_data()
-
-         # Select relevant times (strictly between start_time and end_time)
-         times = (np.datetime64(self.start_time) < ds[self.dim_names["time"]]) & (
-             ds[self.dim_names["time"]] < np.datetime64(self.end_time)
-         )
-         ds = ds.where(times, drop=True)
-
-         # Make sure that latitude is ascending
-         diff = np.diff(ds[self.dim_names["latitude"]])
-         if np.all(diff < 0):
-             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
-
-         object.__setattr__(self, "ds", ds)
-
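
The time filter above keeps only records strictly inside the (start_time, end_time) window, and a descending latitude axis is flipped so ascending `slice()` bounds work downstream. A minimal sketch of that behavior on synthetic data (names and values here are invented for illustration, not part of roms-tools):

    import numpy as np
    import xarray as xr
    from datetime import datetime

    # Toy dataset: daily times, descending latitudes
    ds = xr.Dataset(
        {"t2m": (("time", "lat"), np.random.rand(5, 3))},
        coords={
            "time": np.arange("2022-01-01", "2022-01-06", dtype="datetime64[D]").astype("datetime64[ns]"),
            "lat": [30.0, 20.0, 10.0],
        },
    )

    # Strictly-inside time window, as in ForcingDataset.__post_init__
    start = np.datetime64(datetime(2022, 1, 1))
    end = np.datetime64(datetime(2022, 1, 4))
    ds = ds.where((start < ds["time"]) & (ds["time"] < end), drop=True)

    # Flip latitude to ascending if it is monotonically decreasing
    if np.all(np.diff(ds["lat"]) < 0):
        ds = ds.isel(lat=slice(None, None, -1))

    print(ds.time.values)  # 2022-01-02 and 2022-01-03: endpoints are excluded
    print(ds.lat.values)   # [10. 20. 30.]
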
-     def load_data(self) -> xr.Dataset:
-         """
-         Load dataset from the specified file.
-
-         Returns
-         -------
-         ds : xr.Dataset
-             The loaded xarray Dataset containing the forcing data.
-
-         Raises
-         ------
-         FileNotFoundError
-             If no file matches the specified pattern.
-         """
-
-         # Check if any file matching the pattern exists
-         matching_files = glob.glob(self.filename)
-         if not matching_files:
-             raise FileNotFoundError(
-                 f"No files found matching the pattern '{self.filename}'."
-             )
-
-         # Load the dataset
-         with dask.config.set(**{"array.slicing.split_large_chunks": False}):
-             # Use a time chunk size of 1 to enable quick .nan_check() and
-             # .plot() methods for AtmosphericForcing
-             ds = xr.open_mfdataset(
-                 self.filename,
-                 combine="nested",
-                 concat_dim=self.dim_names["time"],
-                 coords="minimal",
-                 compat="override",
-                 chunks={self.dim_names["time"]: 1},
-             )
-
-         return ds
-
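
The `chunks={...: 1}` argument above gives each time slice its own dask chunk, so a later `.isel(time=0)` (as used by `nan_check` and `plot`) reads only one record from disk. A hedged sketch of the same call, assuming files matching a hypothetical pattern exist:

    import xarray as xr

    # One dask chunk per time step: cheap single-slice access later
    ds = xr.open_mfdataset(
        "data_*.nc",        # hypothetical file pattern
        combine="nested",
        concat_dim="time",
        coords="minimal",
        compat="override",
        chunks={"time": 1},
    )
    first = ds.isel(time=0).compute()  # loads only the first chunk
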
-     def choose_subdomain(self, latitude_range, longitude_range, margin, straddle):
-         """
-         Selects a subdomain from the given xarray Dataset based on latitude and longitude
-         ranges, extending the selection by the specified margin. Handles the conversion
-         of longitude values in the dataset from one range to another.
-
-         Parameters
-         ----------
-         latitude_range : tuple
-             A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
-         longitude_range : tuple
-             A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
-         margin : float
-             Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
-         straddle : bool
-             If True, target longitudes are expected in the range [-180, 180].
-             If False, target longitudes are expected in the range [0, 360].
-
-         Returns
-         -------
-         xr.Dataset
-             The subset of the original dataset representing the chosen subdomain, including an extended area
-             to cover one extra grid point beyond the specified ranges.
-
-         Raises
-         ------
-         ValueError
-             If the selected latitude or longitude range does not intersect with the dataset.
-         """
-         lat_min, lat_max = latitude_range
-         lon_min, lon_max = longitude_range
-
-         lon = self.ds[self.dim_names["longitude"]]
-         # Adjust longitude range if needed to match the expected range
-         if not straddle:
-             if lon.min() < -180:
-                 if lon_max + margin > 0:
-                     lon_min -= 360
-                     lon_max -= 360
-             elif lon.min() < 0:
-                 if lon_max + margin > 180:
-                     lon_min -= 360
-                     lon_max -= 360
-
-         if straddle:
-             if lon.max() > 360:
-                 if lon_min - margin < 180:
-                     lon_min += 360
-                     lon_max += 360
-             elif lon.max() > 180:
-                 if lon_min - margin < 0:
-                     lon_min += 360
-                     lon_max += 360
-
-         # Select the subdomain
-         subdomain = self.ds.sel(
-             **{
-                 self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
-                 self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
-             }
-         )
-
-         # Check if the selected subdomain has zero dimensions in latitude or longitude
-         if subdomain[self.dim_names["latitude"]].size == 0:
-             raise ValueError("Selected latitude range does not intersect with dataset.")
-
-         if subdomain[self.dim_names["longitude"]].size == 0:
-             raise ValueError(
-                 "Selected longitude range does not intersect with dataset."
-             )
-
-         # Adjust longitudes to expected range if needed
-         lon = subdomain[self.dim_names["longitude"]]
-         if straddle:
-             subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
-         else:
-             subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
-
-         # Set the modified subdomain to the object attribute
-         object.__setattr__(self, "ds", subdomain)
-
-
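
The ordering in `choose_subdomain` matters: the slice is taken first and only the remaining longitudes are wrapped afterwards, so the wrap cannot produce a non-monotonic coordinate inside the selection. The wrap itself, on toy values:

    import xarray as xr

    lon = xr.DataArray([170.0, 180.0, 190.0, 200.0], dims="lon")
    # straddle=True: map into [-180, 180]
    print(xr.where(lon > 180, lon - 360, lon).values)   # [ 170.  180. -170. -160.]

    lon2 = xr.DataArray([-20.0, -10.0, 0.0, 10.0], dims="lon")
    # straddle=False: map into [0, 360]
    print(xr.where(lon2 < 0, lon2 + 360, lon2).values)  # [340. 350.   0.  10.]
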
- @dataclass(frozen=True, kw_only=True)
- class SWRCorrection:
-     """
-     Configuration for shortwave radiation correction.
-
-     Parameters
-     ----------
-     filename : str
-         Filename of the correction data.
-     varname : str
-         Variable identifier for the correction.
-     dim_names : Dict[str, str], optional
-         Dictionary specifying the names of dimensions in the dataset.
-         Default is {"longitude": "longitude", "latitude": "latitude", "time": "time"}.
-     temporal_resolution : str, optional
-         Temporal resolution of the correction data. Default is "climatology".
-
-     Attributes
-     ----------
-     ds : xr.Dataset
-         The loaded xarray Dataset containing the correction data.
-
-     Examples
-     --------
-     >>> swr_correction = SWRCorrection(
-     ...     filename="correction_data.nc",
-     ...     varname="corr",
-     ...     dim_names={
-     ...         "time": "time",
-     ...         "latitude": "latitude",
-     ...         "longitude": "longitude",
-     ...     },
-     ...     temporal_resolution="climatology",
-     ... )
-     """
-
-     filename: str
-     varname: str
-     dim_names: Dict[str, str] = field(
-         default_factory=lambda: {
-             "longitude": "longitude",
-             "latitude": "latitude",
-             "time": "time",
-         }
-     )
-     temporal_resolution: str = "climatology"
-     ds: xr.Dataset = field(init=False, repr=False)
-
-     def __post_init__(self):
-         if self.temporal_resolution != "climatology":
-             raise NotImplementedError(
-                 f"temporal_resolution must be 'climatology', got {self.temporal_resolution}"
-             )
-
-         ds = self._load_data()
-         self._check_dataset(ds)
-         ds = self._ensure_latitude_ascending(ds)
-
-         object.__setattr__(self, "ds", ds)
-
-     def _load_data(self):
-         """
-         Load data from the specified file.
-
-         Returns
-         -------
-         ds : xr.Dataset
-             The loaded xarray Dataset containing the correction data.
-
-         Raises
-         ------
-         FileNotFoundError
-             If no file matches the specified pattern.
-         """
-         # Check if any file matching the wildcard pattern exists
-         matching_files = glob.glob(self.filename)
-         if not matching_files:
-             raise FileNotFoundError(
-                 f"No files found matching the pattern '{self.filename}'."
-             )
-
-         # Load the dataset with each dimension in a single chunk
-         ds = xr.open_dataset(
-             self.filename,
-             chunks={
-                 self.dim_names["time"]: -1,
-                 self.dim_names["latitude"]: -1,
-                 self.dim_names["longitude"]: -1,
-             },
-         )
-
-         return ds
-
-     def _check_dataset(self, ds: xr.Dataset) -> None:
-         """
-         Check if the dataset contains the specified variable and dimensions.
-
-         Parameters
-         ----------
-         ds : xr.Dataset
-             The xarray Dataset to check.
-
-         Raises
-         ------
-         ValueError
-             If the dataset does not contain the specified variable or dimensions.
-         """
-         if self.varname not in ds:
-             raise ValueError(
-                 f"The dataset does not contain the variable '{self.varname}'."
-             )
-
-         for dim in self.dim_names.values():
-             if dim not in ds.dims:
-                 raise ValueError(f"The dataset does not contain the dimension '{dim}'.")
-
-     def _ensure_latitude_ascending(self, ds: xr.Dataset) -> xr.Dataset:
-         """
-         Ensure that the latitude dimension is in ascending order.
-
-         Parameters
-         ----------
-         ds : xr.Dataset
-             The xarray Dataset to check.
-
-         Returns
-         -------
-         ds : xr.Dataset
-             The xarray Dataset with latitude in ascending order.
-         """
-         # Make sure that latitude is ascending
-         lat_diff = np.diff(ds[self.dim_names["latitude"]])
-         if np.all(lat_diff < 0):
-             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
-
-         return ds
-
-     def _handle_longitudes(self, straddle: bool) -> None:
-         """
-         Handles the conversion of longitude values in the dataset from one range to another.
-
-         Parameters
-         ----------
-         straddle : bool
-             If True, target longitudes are in range [-180, 180].
-             If False, target longitudes are in range [0, 360].
-         """
-         lon = self.ds[self.dim_names["longitude"]]
-
-         if lon.min().values < 0 and not straddle:
-             # Convert from [-180, 180] to [0, 360]
-             self.ds[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
-
-         if lon.max().values > 180 and straddle:
-             # Convert from [0, 360] to [-180, 180]
-             self.ds[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
-
-     def _choose_subdomain(self, coords) -> xr.Dataset:
-         """
-         Selects a subdomain from the dataset based on the specified latitude and longitude ranges.
-
-         Parameters
-         ----------
-         coords : dict
-             A dictionary specifying the target coordinates.
-
-         Returns
-         -------
-         xr.Dataset
-             The subset of the original dataset representing the chosen subdomain.
-
-         Raises
-         ------
-         ValueError
-             If the specified subdomain is not fully contained within the dataset.
-         """
-
-         # Select the subdomain based on the specified latitude and longitude ranges
-         subdomain = self.ds.sel(**coords)
-
-         # Check if the selected subdomain contains the specified latitude and longitude values
-         if not subdomain[self.dim_names["latitude"]].equals(
-             coords[self.dim_names["latitude"]]
-         ):
-             raise ValueError(
-                 "The correction dataset does not contain all specified latitude values."
-             )
-         if not subdomain[self.dim_names["longitude"]].equals(
-             coords[self.dim_names["longitude"]]
-         ):
-             raise ValueError(
-                 "The correction dataset does not contain all specified longitude values."
-             )
-
-         return subdomain
-
-     def _interpolate_temporally(self, field, time):
-         """
-         Interpolates the given field temporally based on the specified time points.
-
-         Parameters
-         ----------
-         field : xarray.DataArray
-             The field data to be interpolated. This can be any variable from the dataset that
-             requires temporal interpolation, such as correction factors or any other relevant data.
-         time : xarray.DataArray or pandas.DatetimeIndex
-             The target time points for interpolation.
-
-         Returns
-         -------
-         xr.DataArray
-             The field values interpolated to the specified time points.
-
-         Raises
-         ------
-         NotImplementedError
-             If the temporal resolution is not set to 'climatology'.
-         """
-         if self.temporal_resolution != "climatology":
-             raise NotImplementedError(
-                 f"temporal_resolution must be 'climatology', got {self.temporal_resolution}"
-             )
-
-         # Express the climatological time axis as day-of-year values
-         field[self.dim_names["time"]] = field[self.dim_names["time"]].dt.days
-         day_of_year = time.dt.dayofyear
-
-         # Concatenate across the beginning and end of the year so that
-         # interpolation near the year boundary wraps around the climatology
-         time_concat = xr.concat(
-             [
-                 field[self.dim_names["time"]][-1] - 365.25,
-                 field[self.dim_names["time"]],
-                 365.25 + field[self.dim_names["time"]][0],
-             ],
-             dim=self.dim_names["time"],
-         )
-         field_concat = xr.concat(
-             [
-                 field.isel({self.dim_names["time"]: -1}),
-                 field,
-                 field.isel({self.dim_names["time"]: 0}),
-             ],
-             dim=self.dim_names["time"],
-         )
-         # Note: the remaining steps assume the time dimension is named "time"
-         field_concat["time"] = time_concat
-
-         # Interpolate to the specified times
-         field_interpolated = field_concat.interp(time=day_of_year, method="linear")
-
-         return field_interpolated
-
-
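
The concatenation above makes the climatology periodic: the last record is repeated 365.25 days earlier and the first one 365.25 days later, so day-of-year values near the year boundary are interpolated rather than extrapolated. A self-contained sketch with an invented 12-point monthly climatology:

    import numpy as np
    import xarray as xr

    # Monthly climatology stamped at (roughly) mid-month day-of-year values
    days = xr.DataArray(np.arange(15.0, 365.0, 30.4), dims="time")
    clim = xr.DataArray(
        np.sin(2 * np.pi * days / 365.25), dims="time", coords={"time": days}
    )

    # Pad one record on each side, shifted by one climatological year
    time_concat = xr.concat(
        [clim["time"][-1] - 365.25, clim["time"], clim["time"][0] + 365.25],
        dim="time",
    )
    clim_concat = xr.concat([clim.isel(time=-1), clim, clim.isel(time=0)], dim="time")
    clim_concat["time"] = time_concat

    # Day 1 now falls between the shifted December point and mid-January
    print(clim_concat.interp(time=1.0, method="linear").values)
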
- @dataclass(frozen=True, kw_only=True)
- class Rivers:
-     """
-     Configuration for river forcing.
-
-     Parameters
-     ----------
-     filename : str
-         Filename of the river forcing data. Must be provided.
-     """
-
-     filename: str = ""
-
-     def __post_init__(self):
-         if not self.filename:
-             raise ValueError("The 'filename' must be provided.")
-
-
- @dataclass(frozen=True, kw_only=True)
- class AtmosphericForcing:
-     """
-     Represents atmospheric forcing data for ocean modeling.
-
-     Parameters
-     ----------
-     grid : Grid
-         Object representing the grid information.
-     use_coarse_grid : bool, optional
-         Whether to interpolate to the coarsened grid. Default is False.
-     start_time : datetime
-         Start time of the desired forcing data.
-     end_time : datetime
-         End time of the desired forcing data.
-     model_reference_date : datetime, optional
-         Reference date for the model. Default is January 1, 2000.
-     source : str, optional
-         Source of the atmospheric forcing data. Default is "era5".
-     filename : str
-         Path to the atmospheric forcing source data file. Can contain wildcards.
-     swr_correction : SWRCorrection, optional
-         Shortwave radiation correction configuration.
-     rivers : Rivers, optional
-         River forcing configuration.
-
-     Attributes
-     ----------
-     ds : xr.Dataset
-         Xarray Dataset containing the atmospheric forcing data.
-
-     Notes
-     -----
-     This class represents atmospheric forcing data used in ocean modeling. It provides a convenient
-     interface to work with forcing data, including shortwave radiation correction and river forcing.
-
-     Examples
-     --------
-     >>> grid_info = Grid(...)
-     >>> start_time = datetime(2000, 1, 1)
-     >>> end_time = datetime(2000, 1, 2)
-     >>> atm_forcing = AtmosphericForcing(
-     ...     grid=grid_info,
-     ...     start_time=start_time,
-     ...     end_time=end_time,
-     ...     source="era5",
-     ...     filename="atmospheric_data_*.nc",
-     ...     swr_correction=swr_correction,
-     ... )
-     """
-
-     grid: Grid
-     use_coarse_grid: bool = False
-     start_time: datetime
-     end_time: datetime
-     model_reference_date: datetime = datetime(2000, 1, 1)
-     source: str = "era5"
-     filename: str
-     swr_correction: Optional["SWRCorrection"] = None
-     rivers: Optional["Rivers"] = None
-     ds: xr.Dataset = field(init=False, repr=False)
-
-     def __post_init__(self):
-         if self.use_coarse_grid:
-             if "lon_coarse" not in self.grid.ds:
-                 raise ValueError(
-                     "Grid has not been coarsened yet. Execute grid.coarsen() first."
-                 )
-
-             lon = self.grid.ds.lon_coarse
-             lat = self.grid.ds.lat_coarse
-             angle = self.grid.ds.angle_coarse
-         else:
-             lon = self.grid.ds.lon_rho
-             lat = self.grid.ds.lat_rho
-             angle = self.grid.ds.angle
-
-         if self.source == "era5":
-             dims = {"longitude": "longitude", "latitude": "latitude", "time": "time"}
-
-         data = ForcingDataset(
-             filename=self.filename,
-             start_time=self.start_time,
-             end_time=self.end_time,
-             dim_names=dims,
-         )
-
-         # Operate on longitudes in [-180, 180] unless the ROMS domain lies at least
-         # 5 degrees in longitude away from the Greenwich meridian
-         lon = xr.where(lon > 180, lon - 360, lon)
-         straddle = True
-         if not self.grid.straddle and abs(lon).min() > 5:
-             lon = xr.where(lon < 0, lon + 360, lon)
-             straddle = False
-
-         # Step 1: Choose a subdomain of the forcing data, including a safety margin
-         # for interpolation. Step 2: Convert to the proper longitude range.
-         # Step 1 is necessary to avoid discontinuous longitudes that could be
-         # introduced by Step 2. Discontinuous longitudes can lead to artifacts in the
-         # interpolation: if there is a data gap, they can produce values that appear
-         # to come from a distant location instead of producing NaNs. These NaNs are
-         # important because they can be identified and handled by the nan_check method.
-         data.choose_subdomain(
-             latitude_range=[lat.min().values, lat.max().values],
-             longitude_range=[lon.min().values, lon.max().values],
-             margin=2,
-             straddle=straddle,
-         )
-
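
The straddle logic above picks whichever longitude convention keeps the ROMS domain contiguous: grids crossing (or within 5 degrees of) the Greenwich meridian are handled in [-180, 180], everything else in [0, 360]. In miniature, ignoring the grid.straddle flag for brevity (longitudes invented):

    import xarray as xr

    lon = xr.DataArray([355.0, 358.0, 2.0, 5.0])  # domain straddles 0E

    lon = xr.where(lon > 180, lon - 360, lon)     # work in [-180, 180] by default
    straddle = True
    if abs(lon).min() > 5:                        # far from Greenwich: use [0, 360]
        lon = xr.where(lon < 0, lon + 360, lon)
        straddle = False

    print(straddle, lon.values)  # True [-5. -2.  2.  5.]
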
-         # Interpolate onto the desired grid
-         if self.source == "era5":
-             mask = xr.where(data.ds["sst"].isel(time=0).isnull(), 0, 1)
-             varnames = {
-                 "u10": "u10",
-                 "v10": "v10",
-                 "swr": "ssr",
-                 "lwr": "strd",
-                 "t2m": "t2m",
-                 "d2m": "d2m",
-                 "rain": "tp",
-             }
-
-             coords = {dims["latitude"]: lat, dims["longitude"]: lon}
-             u10 = self._interpolate(
-                 data.ds[varnames["u10"]], mask, coords=coords, method="linear"
-             )
-             v10 = self._interpolate(
-                 data.ds[varnames["v10"]], mask, coords=coords, method="linear"
-             )
-             swr = self._interpolate(
-                 data.ds[varnames["swr"]], mask, coords=coords, method="linear"
-             )
-             lwr = self._interpolate(
-                 data.ds[varnames["lwr"]], mask, coords=coords, method="linear"
-             )
-             t2m = self._interpolate(
-                 data.ds[varnames["t2m"]], mask, coords=coords, method="linear"
-             )
-             d2m = self._interpolate(
-                 data.ds[varnames["d2m"]], mask, coords=coords, method="linear"
-             )
-             rain = self._interpolate(
-                 data.ds[varnames["rain"]], mask, coords=coords, method="linear"
-             )
-
-         if self.source == "era5":
-             # Translate radiation to fluxes. ERA5 stores values integrated over 1 hour.
-             swr = swr / 3600  # from J/m^2 (per hour) to W/m^2
-             lwr = lwr / 3600  # from J/m^2 (per hour) to W/m^2
-             rain = rain * 100 * 24  # from m (per hour) to cm/day
-             # Convert from K to degrees C
-             t2m = t2m - 273.15
-             d2m = d2m - 273.15
-             # Relative humidity fraction from temperature and dewpoint (Magnus formula)
-             qair = np.exp((17.625 * d2m) / (243.04 + d2m)) / np.exp(
-                 (17.625 * t2m) / (243.04 + t2m)
-             )
-             # Convert relative to absolute humidity assuming constant pressure
-             patm = 1010.0
-             cff = (
-                 (1.0007 + 3.46e-6 * patm)
-                 * 6.1121
-                 * np.exp(17.502 * t2m / (240.97 + t2m))
-             )
-             cff = cff * qair
-             qair = 0.62197 * (cff / (patm - 0.378 * cff))
-
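
The humidity block above first forms a relative humidity fraction from dewpoint and air temperature with a Magnus-type formula, then converts it to specific humidity (kg/kg) assuming a constant surface pressure of 1010 hPa. The same arithmetic as a standalone function (a sketch mirroring the constants above, not a roms-tools API):

    import numpy as np

    def era5_specific_humidity(t2m_c, d2m_c, patm_hpa=1010.0):
        """Specific humidity (kg/kg) from 2 m temperature and dewpoint in deg C."""
        # Relative humidity fraction from dewpoint and temperature (Magnus formula)
        rh = np.exp((17.625 * d2m_c) / (243.04 + d2m_c)) / np.exp(
            (17.625 * t2m_c) / (243.04 + t2m_c)
        )
        # Saturation vapor pressure in hPa, with a pressure enhancement factor
        e_sat = (1.0007 + 3.46e-6 * patm_hpa) * 6.1121 * np.exp(
            17.502 * t2m_c / (240.97 + t2m_c)
        )
        e = e_sat * rh  # actual vapor pressure
        return 0.62197 * e / (patm_hpa - 0.378 * e)

    print(era5_specific_humidity(20.0, 15.0))  # roughly 0.011 kg/kg
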
-         # Correct shortwave radiation
-         if self.swr_correction:
-             # Choose the same subdomain as the forcing data so that we can use the same mask
-             self.swr_correction._handle_longitudes(straddle=straddle)
-             coords_correction = {
-                 self.swr_correction.dim_names["latitude"]: data.ds[
-                     data.dim_names["latitude"]
-                 ],
-                 self.swr_correction.dim_names["longitude"]: data.ds[
-                     data.dim_names["longitude"]
-                 ],
-             }
-             subdomain = self.swr_correction._choose_subdomain(coords_correction)
-
-             # Spatial interpolation
-             corr_factor = subdomain[self.swr_correction.varname]
-             coords_correction = {
-                 self.swr_correction.dim_names["latitude"]: lat,
-                 self.swr_correction.dim_names["longitude"]: lon,
-             }
-             corr_factor = self._interpolate(
-                 corr_factor, mask, coords=coords_correction, method="linear"
-             )
-
-             # Temporal interpolation
-             corr_factor = self.swr_correction._interpolate_temporally(
-                 corr_factor, time=swr.time
-             )
-
-             swr = corr_factor * swr
-
-         if self.rivers:
-             raise NotImplementedError("River forcing is not implemented yet.")
-             # rain = rain + rivers
-
668
- # save in new dataset
669
- ds = xr.Dataset()
670
-
671
- ds["uwnd"] = (u10 * np.cos(angle) + v10 * np.sin(angle)).astype(
672
- np.float32
673
- ) # rotate to grid orientation
674
- ds["uwnd"].attrs["long_name"] = "10 meter wind in x-direction"
675
- ds["uwnd"].attrs["units"] = "m/s"
676
-
677
- ds["vwnd"] = (v10 * np.cos(angle) - u10 * np.sin(angle)).astype(
678
- np.float32
679
- ) # rotate to grid orientation
680
- ds["vwnd"].attrs["long_name"] = "10 meter wind in y-direction"
681
- ds["vwnd"].attrs["units"] = "m/s"
682
-
683
- ds["swrad"] = swr.astype(np.float32)
684
- ds["swrad"].attrs["long_name"] = "Downward short-wave (solar) radiation"
685
- ds["swrad"].attrs["units"] = "W/m^2"
686
-
687
- ds["lwrad"] = lwr.astype(np.float32)
688
- ds["lwrad"].attrs["long_name"] = "Downward long-wave (thermal) radiation"
689
- ds["lwrad"].attrs["units"] = "W/m^2"
690
-
691
- ds["Tair"] = t2m.astype(np.float32)
692
- ds["Tair"].attrs["long_name"] = "Air temperature at 2m"
693
- ds["Tair"].attrs["units"] = "degrees C"
694
-
695
- ds["qair"] = qair.astype(np.float32)
696
- ds["qair"].attrs["long_name"] = "Absolute humidity at 2m"
697
- ds["qair"].attrs["units"] = "kg/kg"
698
-
699
- ds["rain"] = rain.astype(np.float32)
700
- ds["rain"].attrs["long_name"] = "Total precipitation"
701
- ds["rain"].attrs["units"] = "cm/day"
702
-
703
- ds.attrs["Title"] = "ROMS bulk surface forcing file produced by roms-tools"
704
-
705
- ds = ds.assign_coords({"lon": lon, "lat": lat})
706
- if dims["time"] != "time":
707
- ds = ds.rename({dims["time"]: "time"})
708
- if self.use_coarse_grid:
709
- ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})
710
- mask_roms = self.grid.ds["mask_coarse"].rename(
711
- {"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"}
712
- )
713
- else:
714
- mask_roms = self.grid.ds["mask_rho"]
715
-
716
- object.__setattr__(self, "ds", ds)
717
-
718
- self.nan_check(mask_roms, time=0)
719
-
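
The wind rotation above maps east/north components onto the (possibly rotated) ROMS grid axes: with grid angle theta, uwnd = u10 cos(theta) + v10 sin(theta) and vwnd = v10 cos(theta) - u10 sin(theta). A quick numeric check with invented values:

    import numpy as np

    u10, v10 = 5.0, 0.0        # purely eastward wind
    angle = np.deg2rad(90.0)   # grid x-axis rotated to point north

    uwnd = u10 * np.cos(angle) + v10 * np.sin(angle)
    vwnd = v10 * np.cos(angle) - u10 * np.sin(angle)
    print(round(uwnd, 6), round(vwnd, 6))  # 0.0 -5.0: eastward wind maps to -y
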
-     @staticmethod
-     def _interpolate(field, mask, coords, method="linear"):
-         """
-         Interpolate a field using specified coordinates and a given method.
-
-         Parameters
-         ----------
-         field : xr.DataArray
-             The data array to be interpolated.
-
-         mask : xr.DataArray
-             A data array with the same spatial dimensions as `field`, where `1` indicates
-             wet (ocean) points and `0` indicates land points.
-
-         coords : dict
-             A dictionary specifying the target coordinates for interpolation. The keys
-             should match the dimensions of `field` (e.g., {"longitude": lon_values, "latitude": lat_values}).
-
-         method : str, optional, default='linear'
-             The interpolation method to use. Valid options are those supported by
-             `xarray.DataArray.interp`.
-
-         Returns
-         -------
-         xr.DataArray
-             The interpolated data array.
-
-         Notes
-         -----
-         This method first sets land values to NaN based on the provided mask. It then uses the
-         `lateral_fill` function to propagate ocean values. These two steps serve to avoid
-         interpolation across the land-ocean boundary. Finally, it performs interpolation
-         over the specified coordinates.
-         """
-
-         dims = list(coords.keys())
-
-         # Set land values to NaN
-         field = field.where(mask)
-         # Propagate ocean values into the land interior before interpolation
-         field = lateral_fill(field, 1 - mask, dims)
-         # Interpolate
-         field_interpolated = field.interp(**coords, method=method).drop_vars(dims)
-
-         return field_interpolated
-
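
`_interpolate` masks land to NaN, floods ocean values laterally, and only then interpolates, which keeps land values from bleeding across the coastline. `lateral_fill` is internal to roms-tools; the sketch below substitutes a crude 1-D forward/backward fill purely to show the pattern, not the real algorithm:

    import xarray as xr

    field = xr.DataArray([1.0, 2.0, 99.0, 4.0], dims="lon",
                         coords={"lon": [0.0, 1.0, 2.0, 3.0]})
    mask = xr.DataArray([1, 1, 0, 1], dims="lon", coords={"lon": field.lon})

    field = field.where(mask)                # land value (99.0) -> NaN
    field = field.ffill("lon").bfill("lon")  # stand-in for lateral_fill
    print(field.interp(lon=[0.5, 1.5, 2.5], method="linear").values)
    # [1.5 2.  3. ] -- no contamination from the land value
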
-     def nan_check(self, mask, time=0) -> None:
-         """
-         Checks for NaN values at wet points in all variables of the dataset at a specified time step.
-
-         Parameters
-         ----------
-         mask : array-like
-             A boolean mask indicating the wet points in the dataset.
-         time : int
-             The time step at which to check for NaN values. Default is 0.
-
-         Raises
-         ------
-         ValueError
-             If any variable contains NaN values at wet points at the specified time step.
-         """
-
-         for var in self.ds.data_vars:
-             da = xr.where(mask == 1, self.ds[var].isel(time=time), 0)
-             if da.isnull().any().values:
-                 raise ValueError(
-                     f"NaN values found in the variable '{var}' at time step {time} over the ocean. This likely "
-                     "occurs because the ROMS grid, including a small safety margin for interpolation, is not "
-                     "fully contained within the dataset's longitude/latitude range. Please ensure that the "
-                     "dataset covers the entire area required by the ROMS grid."
-                 )
-
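
The check above zeroes out land points first, so only NaNs at wet points trigger the error. In miniature (toy arrays):

    import numpy as np
    import xarray as xr

    da = xr.DataArray([[np.nan, 1.0], [2.0, np.nan]], dims=("eta_rho", "xi_rho"))
    mask = xr.DataArray([[0, 1], [1, 1]], dims=("eta_rho", "xi_rho"))

    wet = xr.where(mask == 1, da, 0)  # land NaN at [0, 0] is replaced by 0
    print(bool(wet.isnull().any()))   # True: the NaN at the wet point is caught
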
-     def plot(self, varname, time=0) -> None:
-         """
-         Plot the specified atmospheric forcing field for a given time slice.
-
-         Parameters
-         ----------
-         varname : str
-             The name of the atmospheric forcing field to plot. Options include:
-             - "uwnd": 10 meter wind in x-direction.
-             - "vwnd": 10 meter wind in y-direction.
-             - "swrad": Downward short-wave (solar) radiation.
-             - "lwrad": Downward long-wave (thermal) radiation.
-             - "Tair": Air temperature at 2m.
-             - "qair": Absolute humidity at 2m.
-             - "rain": Total precipitation.
-         time : int, optional
-             The time index to plot. Default is 0, which corresponds to the first
-             time slice.
-
-         Returns
-         -------
-         None
-             This method does not return any value. It generates and displays a plot.
-
-         Raises
-         ------
-         ValueError
-             If the specified varname is not one of the valid options.
-
-         Examples
-         --------
-         >>> atm_forcing = AtmosphericForcing(
-         ...     grid=grid_info,
-         ...     start_time=start_time,
-         ...     end_time=end_time,
-         ...     source="era5",
-         ...     filename="atmospheric_data_*.nc",
-         ...     swr_correction=swr_correction,
-         ... )
-         >>> atm_forcing.plot("uwnd", time=0)
-         """
-
-         import cartopy.crs as ccrs
-         import matplotlib.pyplot as plt
-
-         lon_deg = self.ds.lon
-         lat_deg = self.ds.lat
-
-         # Check whether the North or South pole is in the domain
-         if lat_deg.max().values > 89 or lat_deg.min().values < -89:
-             raise NotImplementedError(
-                 "Plotting the atmospheric forcing is not implemented for domains "
-                 "containing the North or South pole."
-             )
-
-         if self.grid.straddle:
-             lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
-
-         # Define projections
-         proj = ccrs.PlateCarree()
-
-         trans = ccrs.NearsidePerspective(
-             central_longitude=lon_deg.mean().values,
-             central_latitude=lat_deg.mean().values,
-         )
-
-         lon_deg = lon_deg.values
-         lat_deg = lat_deg.values
-
-         # Find the grid corners
-         (lo1, la1) = (lon_deg[0, 0], lat_deg[0, 0])
-         (lo2, la2) = (lon_deg[0, -1], lat_deg[0, -1])
-         (lo3, la3) = (lon_deg[-1, -1], lat_deg[-1, -1])
-         (lo4, la4) = (lon_deg[-1, 0], lat_deg[-1, 0])
-
-         # Transform the corner coordinates to projected space
-         lo1t, la1t = trans.transform_point(lo1, la1, proj)
-         lo2t, la2t = trans.transform_point(lo2, la2, proj)
-         lo3t, la3t = trans.transform_point(lo3, la3, proj)
-         lo4t, la4t = trans.transform_point(lo4, la4, proj)
-
-         plt.figure(figsize=(10, 10))
-         ax = plt.axes(projection=trans)
-
-         # Outline the domain
-         ax.plot(
-             [lo1t, lo2t, lo3t, lo4t, lo1t],
-             [la1t, la2t, la3t, la4t, la1t],
-             "go-",
-         )
-
-         ax.coastlines(resolution="50m", linewidth=0.5, color="black")
-         ax.gridlines()
-
-         field = self.ds[varname].isel(time=time).compute()
-         if varname in ["uwnd", "vwnd"]:
-             # Symmetric limits and a diverging colormap for signed fields
-             vmax = max(field.max().values, -field.min().values)
-             vmin = -vmax
-             cmap = "RdBu_r"
-         else:
-             vmax = field.max().values
-             vmin = field.min().values
-             if varname in ["swrad", "lwrad", "Tair", "qair"]:
-                 cmap = "YlOrRd"
-             else:
-                 cmap = "YlGnBu"
-
-         p = ax.pcolormesh(
-             lon_deg, lat_deg, field, transform=proj, vmax=vmax, vmin=vmin, cmap=cmap
-         )
-         plt.colorbar(p, label=field.units)
-         ax.set_title(
-             "%s at time %s"
-             % (field.long_name, np.datetime_as_string(field.time, unit="s"))
-         )
-         plt.show()
-
-     def save(self, filepath: str, time_chunk_size: int = 1) -> None:
-         """
-         Save the interpolated atmospheric forcing fields to netCDF4 files.
-
-         This method groups the dataset by year and month, chunks the data by the specified
-         time chunk size, and saves each subset to a separate netCDF4 file. Files are named
-         by year and month, with a day range appended when less than a complete month of
-         data is included.
-
-         Parameters
-         ----------
-         filepath : str
-             The base path and filename for the output files. The files will be named with
-             the format "filepath.YYYYMM.nc" if a full month of data is included, or
-             "filepath.YYYYMMDD-DD.nc" otherwise.
-         time_chunk_size : int, optional
-             Number of time slices to include in each chunk along the time dimension. Default is 1,
-             meaning each chunk contains one time slice.
-
-         Returns
-         -------
-         None
-         """
-
-         datasets = []
-         filenames = []
-         writes = []
-
-         # Group dataset by year
-         gb = self.ds.groupby("time.year")
-
-         for year, group_ds in gb:
-             # Further group each yearly group by month
-             sub_gb = group_ds.groupby("time.month")
-
-             for month, ds in sub_gb:
-                 # Chunk the dataset by the specified time chunk size
-                 ds = ds.chunk({"time": time_chunk_size})
-                 datasets.append(ds)
-
-                 # Determine the number of days in the month
-                 num_days_in_month = calendar.monthrange(year, month)[1]
-                 first_day = ds.time.dt.day.values[0]
-                 last_day = ds.time.dt.day.values[-1]
-
-                 # Create the filename based on whether the dataset contains a full month
-                 if first_day == 1 and last_day == num_days_in_month:
-                     # Full month format: "filepath.YYYYMM.nc"
-                     year_month_str = f"{year}{month:02}"
-                     filename = f"{filepath}.{year_month_str}.nc"
-                 else:
-                     # Partial month format: "filepath.YYYYMMDD-DD.nc"
-                     year_month_day_str = f"{year}{month:02}{first_day:02}-{last_day:02}"
-                     filename = f"{filepath}.{year_month_day_str}.nc"
-                 filenames.append(filename)
-
-         print("Saving the following files:")
-         for filename in filenames:
-             print(filename)
-
-         for ds, filename in zip(datasets, filenames):
-             model_reference_date = np.datetime64(self.model_reference_date)
-
-             # Preserve the original time coordinate for readability
-             ds["Time"] = ds["time"]
-
-             # Convert the time coordinate to the format expected by ROMS:
-             # days since the model reference date (timedelta64[ns] -> days)
-             ds["time"] = (
-                 (ds["time"] - model_reference_date).astype("float64") * 1e-9 / 3600 / 24
-             )
-             ds["time"].attrs[
-                 "long_name"
-             ] = f"days since {np.datetime_as_string(model_reference_date, unit='D')}"
-
-             # Prepare the dataset for writing to a netCDF file without immediately computing
-             write = ds.to_netcdf(filename, compute=False)
-             writes.append(write)
-
-         # Perform the actual write operations in parallel
-         dask.compute(*writes)
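
To illustrate the naming scheme: with data from 15 January through 2 February 2000, a call like the one below (paths hypothetical) would produce one file per partial month, while a full calendar month would collapse to the YYYYMM form:

    # atm_forcing.save("surface_forcing", time_chunk_size=24)
    # -> surface_forcing.20000115-31.nc   (partial month: day range appended)
    # -> surface_forcing.20000201-02.nc   (partial month: day range appended)
    # A full month, e.g. all of March 2000, would be written as surface_forcing.200003.nc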