roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,48 +1,1154 @@
1
- import pooch
2
1
  import xarray as xr
2
+ from dataclasses import dataclass, field
3
+ import glob
4
+ from datetime import datetime, timedelta
5
+ import numpy as np
6
+ from typing import Dict, Optional
7
+ import dask
8
+ import warnings
9
+ from roms_tools.setup.utils import (
10
+ assign_dates_to_climatology,
11
+ interpolate_from_climatology,
12
+ is_cftime_datetime,
13
+ convert_cftime_to_datetime,
14
+ )
15
+ from roms_tools.setup.download import download_correction_data
3
16
 
4
17
 
5
- FRANK = pooch.create(
6
- # Use the default cache folder for the operating system
7
- path=pooch.os_cache("roms-tools"),
8
- base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
9
- # If this is a development version, get the data from the "main" branch
10
- # The registry specifies the files that can be fetched
11
- registry={
12
- "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
13
- },
14
- )
18
+ @dataclass(frozen=True, kw_only=True)
19
+ class Dataset:
20
+ """
21
+ Represents forcing data on original grid.
22
+
23
+ Parameters
24
+ ----------
25
+ filename : str
26
+ The path to the data files. Can contain wildcards.
27
+ start_time : Optional[datetime], optional
28
+ The start time for selecting relevant data. If not provided, the data is not filtered by start time.
29
+ end_time : Optional[datetime], optional
30
+ The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
31
+ or no filtering is applied if start_time is not provided.
32
+ var_names: Dict[str, str]
33
+ Dictionary of variable names that are required in the dataset.
34
+ dim_names: Dict[str, str], optional
35
+ Dictionary specifying the names of dimensions in the dataset.
36
+ climatology : bool
37
+ Indicates whether the dataset is climatological. Defaults to False.
38
+
39
+ Attributes
40
+ ----------
41
+ is_global : bool
42
+ Indicates whether the dataset covers the entire globe.
43
+ ds : xr.Dataset
44
+ The xarray Dataset containing the forcing data on its original grid.
45
+
46
+ Examples
47
+ --------
48
+ >>> dataset = Dataset(
49
+ ... filename="data.nc",
50
+ ... start_time=datetime(2022, 1, 1),
51
+ ... end_time=datetime(2022, 12, 31),
52
+ ... var_names={"temp": "temperature"},
53
+ ... )
54
+ >>> print(dataset.ds)
55
+ <xarray.Dataset>
56
+ Dimensions: ...
57
+ """
58
+
59
+ filename: str
60
+ start_time: Optional[datetime] = None
61
+ end_time: Optional[datetime] = None
62
+ var_names: Dict[str, str]
63
+ dim_names: Dict[str, str] = field(
64
+ default_factory=lambda: {
65
+ "longitude": "longitude",
66
+ "latitude": "latitude",
67
+ "time": "time",
68
+ }
69
+ )
70
+ climatology: Optional[bool] = False
71
+
72
+ is_global: bool = field(init=False, repr=False)
73
+ ds: xr.Dataset = field(init=False, repr=False)
74
+
75
+ def __post_init__(self):
76
+ """
77
+ Post-initialization processing:
78
+ 1. Loads the dataset from the specified filename.
79
+ 2. Applies time filtering based on start_time and end_time if provided.
80
+ 3. Selects relevant fields as specified by var_names.
81
+ 4. Ensures latitude values are in ascending order.
82
+ 5. Checks if the dataset covers the entire globe and adjusts if necessary.
83
+ """
84
+
85
+ ds = self.load_data()
86
+ self.check_dataset(ds)
87
+
88
+ # Select relevant times
89
+ if "time" in self.dim_names and self.start_time is not None:
90
+ ds = self.add_time_info(ds)
91
+ ds = self.select_relevant_times(ds)
92
+
93
+ # Select relevant fields
94
+ ds = self.select_relevant_fields(ds)
95
+
96
+ # Make sure that latitude is ascending
97
+ ds = self.ensure_latitude_ascending(ds)
98
+
99
+ # Check whether the data covers the entire globe
100
+ object.__setattr__(self, "is_global", self.check_if_global(ds))
101
+
102
+ # If dataset is global concatenate three copies of field along longitude dimension
103
+ if self.is_global:
104
+ ds = self.concatenate_longitudes(ds)
105
+
106
+ object.__setattr__(self, "ds", ds)
107
+
108
+ def load_data(self) -> xr.Dataset:
109
+ """
110
+ Load dataset from the specified file.
111
+
112
+ Returns
113
+ -------
114
+ ds : xr.Dataset
115
+ The loaded xarray Dataset containing the forcing data.
116
+
117
+ Raises
118
+ ------
119
+ FileNotFoundError
120
+ If the specified file does not exist.
121
+ """
122
+
123
+ # Check if the file exists
124
+ matching_files = glob.glob(self.filename)
125
+ if not matching_files:
126
+ raise FileNotFoundError(
127
+ f"No files found matching the pattern '{self.filename}'."
128
+ )
129
+
130
+ # Load the dataset
131
+ with dask.config.set(**{"array.slicing.split_large_chunks": False}):
132
+ # Define the chunk sizes
133
+ chunks = {
134
+ self.dim_names["latitude"]: -1,
135
+ self.dim_names["longitude"]: -1,
136
+ }
137
+ if "depth" in self.dim_names.keys():
138
+ chunks[self.dim_names["depth"]] = -1
139
+ if "time" in self.dim_names.keys():
140
+ chunks[self.dim_names["time"]] = 1
141
+
142
+ ds = xr.open_mfdataset(
143
+ self.filename,
144
+ combine="nested",
145
+ concat_dim=self.dim_names["time"],
146
+ coords="minimal",
147
+ compat="override",
148
+ chunks=chunks,
149
+ engine="netcdf4",
150
+ )
151
+ else:
152
+ ds = xr.open_dataset(
153
+ self.filename,
154
+ chunks=chunks,
155
+ )
156
+
157
+ return ds
158
+
159
+ def check_dataset(self, ds: xr.Dataset) -> None:
160
+ """
161
+ Check if the dataset contains the specified variables and dimensions.
162
+
163
+ Parameters
164
+ ----------
165
+ ds : xr.Dataset
166
+ The xarray Dataset to check.
167
+
168
+ Raises
169
+ ------
170
+ ValueError
171
+ If the dataset does not contain the specified variables or dimensions.
172
+ """
173
+ missing_vars = [
174
+ var for var in self.var_names.values() if var not in ds.data_vars
175
+ ]
176
+ if missing_vars:
177
+ raise ValueError(
178
+ f"Dataset does not contain all required variables. The following variables are missing: {missing_vars}"
179
+ )
180
+
181
+ missing_dims = [dim for dim in self.dim_names.values() if dim not in ds.dims]
182
+ if missing_dims:
183
+ raise ValueError(
184
+ f"Dataset does not contain all required dimensions. The following dimensions are missing: {missing_vars}"
185
+ )
186
+
187
+ def select_relevant_fields(self, ds) -> xr.Dataset:
188
+ """
189
+ Selects and returns a subset of the dataset containing only the variables specified in `self.var_names`.
190
+
191
+ Parameters
192
+ ----------
193
+ ds : xr.Dataset
194
+ The input dataset from which variables will be selected.
195
+
196
+ Returns
197
+ -------
198
+ xr.Dataset
199
+ A dataset containing only the variables specified in `self.var_names`.
200
+
201
+ """
202
+
203
+ for var in ds.data_vars:
204
+ if var not in self.var_names.values():
205
+ ds = ds.drop_vars(var)
206
+
207
+ return ds
208
+
210
+
211
+ def add_time_info(self, ds: xr.Dataset) -> xr.Dataset:
212
+ """
213
+ Dummy method to be overridden by child classes to add time information to the dataset.
214
+
215
+ This method is intended as a placeholder and should be implemented in subclasses
216
+ to provide specific functionality for adding time-related information to the dataset.
217
+
218
+ Parameters
219
+ ----------
220
+ ds : xr.Dataset
221
+ The xarray Dataset to which time information will be added.
222
+
223
+ Returns
224
+ -------
225
+ xr.Dataset
226
+ The xarray Dataset with time information added (as implemented by child classes).
227
+ """
228
+ return ds
229
+
230
+ def select_relevant_times(self, ds) -> xr.Dataset:
231
+ """
232
+ Selects and returns the subset of the dataset corresponding to the specified time range.
233
+
234
+ This function filters the dataset to include only the data points within the specified
235
+ time range, defined by `self.start_time` and `self.end_time`. If `self.end_time` is not
236
+ provided, it defaults to one day after `self.start_time`.
237
+
238
+ Parameters
239
+ ----------
240
+ ds : xr.Dataset
241
+ The input dataset to be filtered.
242
+
243
+ Returns
244
+ -------
245
+ xr.Dataset
246
+ A dataset containing only the data points within the specified time range.
247
+
248
+ Raises
249
+ ------
250
+ ValueError
251
+ If no matching times are found or if the number of matching times does not meet expectations.
15
252
 
253
+ Warns
254
+ -----
255
+ UserWarning
256
+ If the dataset contains only 12 time steps but the climatology flag is not set.
257
+ This may indicate that the dataset represents climatology data.
258
+ """
16
259
 
17
- def fetch_topo(topography_source) -> xr.Dataset:
260
+ time_dim = self.dim_names["time"]
261
+ if time_dim in ds.coords or time_dim in ds.data_vars:
262
+ if self.climatology:
263
+ if not self.end_time:
264
+ # Interpolate from climatology for initial conditions
265
+ ds = interpolate_from_climatology(
266
+ ds, self.dim_names["time"], self.start_time
267
+ )
268
+ else:
269
+ if len(ds[time_dim]) == 12:
270
+ warnings.warn(
271
+ "The dataset contains exactly 12 time steps. This may indicate that it is "
272
+ "climatological data. Please verify if climatology is appropriate for your "
273
+ "analysis and set the climatology flag to True."
274
+ )
275
+ if is_cftime_datetime(ds[time_dim]):
276
+ ds = ds.assign_coords(
277
+ {time_dim: convert_cftime_to_datetime(ds[time_dim])}
278
+ )
279
+ if not self.end_time:
280
+ end_time = self.start_time + timedelta(days=1)
281
+ else:
282
+ end_time = self.end_time
283
+
284
+ times = (np.datetime64(self.start_time) <= ds[time_dim]) & (
285
+ ds[time_dim] < np.datetime64(end_time)
286
+ )
287
+ ds = ds.where(times, drop=True)
288
+ else:
289
+ warnings.warn(
290
+ "Dataset does not contain any time information. Please check if the time dimension "
291
+ "is correctly named or if the dataset includes time data."
292
+ )
293
+ if not ds.sizes[time_dim]:
294
+ raise ValueError("No matching times found in the dataset.")
295
+
296
+ if not self.end_time:
297
+ if ds.sizes[time_dim] != 1:
298
+ found_times = ds.sizes[time_dim]
299
+ raise ValueError(
300
+ f"There must be exactly one time matching the start_time. Found {found_times} matching times."
301
+ )
302
+
303
+ return ds
304
+
305
+ def ensure_latitude_ascending(self, ds: xr.Dataset) -> xr.Dataset:
306
+ """
307
+ Ensure that the latitude dimension is in ascending order.
308
+
309
+ Parameters
310
+ ----------
311
+ ds : xr.Dataset
312
+ The xarray Dataset to check.
313
+
314
+ Returns
315
+ -------
316
+ ds : xr.Dataset
317
+ The xarray Dataset with latitude in ascending order.
318
+ """
319
+ # Make sure that latitude is ascending
320
+ lat_diff = np.diff(ds[self.dim_names["latitude"]])
321
+ if np.all(lat_diff < 0):
322
+ ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
323
+
324
+ return ds
325
+
326
+ def check_if_global(self, ds) -> bool:
327
+ """
328
+ Checks if the dataset covers the entire globe in the longitude dimension.
329
+
330
+ This function calculates the mean difference between consecutive longitude values.
331
+ It then checks if the difference between the first and last longitude values (plus 360 degrees)
332
+ is close to this mean difference, within a specified tolerance. If it is, the dataset is considered
333
+ to cover the entire globe in the longitude dimension.
334
+
335
+ Returns
336
+ -------
337
+ bool
338
+ True if the dataset covers the entire globe in the longitude dimension, False otherwise.
339
+
340
+ """
341
+ dlon_mean = (
342
+ ds[self.dim_names["longitude"]].diff(dim=self.dim_names["longitude"]).mean()
343
+ )
344
+ dlon = (
345
+ ds[self.dim_names["longitude"]][0] - ds[self.dim_names["longitude"]][-1]
346
+ ) % 360.0
347
+ is_global = np.isclose(dlon, dlon_mean, rtol=0.0, atol=1e-3).item()
348
+
349
+ return is_global
350
+
351
+ def concatenate_longitudes(self, ds):
352
+ """
353
+ Concatenates the dataset's fields three times: with longitudes shifted by -360, original longitudes, and shifted by +360.
354
+
355
+ Parameters
356
+ ----------
357
+ ds : xr.Dataset
358
+ The dataset whose variables are concatenated along the longitude dimension.
359
+
360
+ Returns
361
+ -------
362
+ xr.Dataset
363
+ The concatenated dataset, with the longitude dimension extended.
364
+
365
+ Notes
366
+ -----
367
+ Concatenating three times may be overkill in most situations, but it is safe. Alternatively, we could refactor
368
+ to figure out whether concatenating on the lower end, upper end, or at all is needed.
369
+
370
+ """
371
+ ds_concatenated = xr.Dataset()
372
+
373
+ lon = ds[self.dim_names["longitude"]]
374
+ lon_minus360 = lon - 360
375
+ lon_plus360 = lon + 360
376
+ lon_concatenated = xr.concat(
377
+ [lon_minus360, lon, lon_plus360], dim=self.dim_names["longitude"]
378
+ )
379
+
380
+ ds_concatenated[self.dim_names["longitude"]] = lon_concatenated
381
+
382
+ for var in self.var_names.values():
383
+ if self.dim_names["longitude"] in ds[var].dims:
384
+ field = ds[var]
385
+ field_concatenated = xr.concat(
386
+ [field, field, field], dim=self.dim_names["longitude"]
387
+ ).chunk({self.dim_names["longitude"]: -1})
388
+ field_concatenated[self.dim_names["longitude"]] = lon_concatenated
389
+ ds_concatenated[var] = field_concatenated
390
+ else:
391
+ ds_concatenated[var] = ds[var]
392
+
393
+ return ds_concatenated
394
+
395
+ def choose_subdomain(
396
+ self, latitude_range, longitude_range, margin, straddle, return_subdomain=False
397
+ ):
398
+ """
399
+ Selects a subdomain from the xarray Dataset based on specified latitude and longitude ranges,
400
+ extending the selection by a specified margin. Handles longitude conversions to accommodate different
401
+ longitude ranges.
402
+
403
+ Parameters
404
+ ----------
405
+ latitude_range : tuple of float
406
+ A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
407
+ longitude_range : tuple of float
408
+ A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
409
+ margin : float
410
+ Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
411
+ straddle : bool
412
+ If True, target longitudes are expected in the range [-180, 180].
413
+ If False, target longitudes are expected in the range [0, 360].
414
+ return_subdomain : bool, optional
415
+ If True, returns the subset of the original dataset as an xarray Dataset. If False, assigns the subset to `self.ds`.
416
+ Defaults to False.
417
+
418
+ Returns
419
+ -------
420
+ xr.Dataset or None
421
+ If `return_subdomain` is True, returns the subset of the original dataset representing the chosen subdomain,
422
+ including an extended area to cover one extra grid point beyond the specified ranges. If `return_subdomain` is False,
423
+ returns None as the subset is assigned to `self.ds`.
424
+
425
+ Notes
426
+ -----
427
+ This method adjusts the longitude range if necessary to ensure it matches the expected range for the dataset.
428
+ It also handles longitude discontinuities that can occur when converting to different longitude ranges.
429
+ This is important for avoiding artifacts in the interpolation process.
430
+
431
+ Raises
432
+ ------
433
+ ValueError
434
+ If the selected latitude or longitude range does not intersect with the dataset.
435
+ """
436
+
437
+ lat_min, lat_max = latitude_range
438
+ lon_min, lon_max = longitude_range
439
+
440
+ if not self.is_global:
441
+ # Adjust longitude range if needed to match the expected range
442
+ lon = self.ds[self.dim_names["longitude"]]
443
+ if not straddle:
444
+ if lon.min() < -180:
445
+ if lon_max + margin > 0:
446
+ lon_min -= 360
447
+ lon_max -= 360
448
+ elif lon.min() < 0:
449
+ if lon_max + margin > 180:
450
+ lon_min -= 360
451
+ lon_max -= 360
452
+
453
+ if straddle:
454
+ if lon.max() > 360:
455
+ if lon_min - margin < 180:
456
+ lon_min += 360
457
+ lon_max += 360
458
+ elif lon.max() > 180:
459
+ if lon_min - margin < 0:
460
+ lon_min += 360
461
+ lon_max += 360
462
+
463
+ # Select the subdomain
464
+ subdomain = self.ds.sel(
465
+ **{
466
+ self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
467
+ self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
468
+ }
469
+ )
470
+
471
+ # Check if the selected subdomain has zero dimensions in latitude or longitude
472
+ if subdomain[self.dim_names["latitude"]].size == 0:
473
+ raise ValueError("Selected latitude range does not intersect with dataset.")
474
+
475
+ if subdomain[self.dim_names["longitude"]].size == 0:
476
+ raise ValueError(
477
+ "Selected longitude range does not intersect with dataset."
478
+ )
479
+
480
+ # Adjust longitudes to expected range if needed
481
+ lon = subdomain[self.dim_names["longitude"]]
482
+ if straddle:
483
+ subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
484
+ else:
485
+ subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
486
+
487
+ if return_subdomain:
488
+ return subdomain
489
+ else:
490
+ object.__setattr__(self, "ds", subdomain)
491
+
492
+ def convert_to_negative_depth(self):
493
+ """
494
+ Converts the depth values in the dataset to negative if they are non-negative.
495
+
496
+ This method checks the values in the depth dimension of the dataset (`self.ds[self.dim_names["depth"]]`).
497
+ If all values are greater than or equal to zero, it negates them and updates the dataset accordingly.
498
+
499
+ """
500
+ depth = self.ds[self.dim_names["depth"]]
501
+
502
+ if (depth >= 0).all():
503
+ self.ds[self.dim_names["depth"]] = -depth
504
+
505
+
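
The base class above is usable on its own. A minimal usage sketch, assuming a hypothetical local file sst_2022.nc that provides an sst variable on the default longitude/latitude/time dimensions:

    from datetime import datetime

    # Hypothetical file and variable mapping; any file matching the default
    # dimension names ("longitude", "latitude", "time") would behave the same.
    data = Dataset(
        filename="sst_2022.nc",
        start_time=datetime(2022, 1, 1),
        end_time=datetime(2022, 2, 1),
        var_names={"sst": "sst"},
    )
    # __post_init__ has already loaded, time-filtered and subset the data:
    print(data.ds)
    # Restrict to a region (30-40N, 150-170E) with a 1-degree margin:
    data.choose_subdomain((30, 40), (150, 170), margin=1, straddle=False)
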
506
+ @dataclass(frozen=True, kw_only=True)
507
+ class TPXODataset(Dataset):
18
508
  """
19
- Load the global topography data as an xarray Dataset.
509
+ Represents tidal data on the original grid from the TPXO dataset.
510
+
511
+ Parameters
512
+ ----------
513
+ filename : str
514
+ The path to the TPXO dataset file.
515
+ var_names : Dict[str, str], optional
516
+ Dictionary of variable names required in the dataset. Defaults to:
517
+ {
518
+ "h_Re": "h_Re",
519
+ "h_Im": "h_Im",
520
+ "sal_Re": "sal_Re",
521
+ "sal_Im": "sal_Im",
522
+ "u_Re": "u_Re",
523
+ "u_Im": "u_Im",
524
+ "v_Re": "v_Re",
525
+ "v_Im": "v_Im",
526
+ "depth": "depth"
527
+ }
528
+ dim_names : Dict[str, str], optional
529
+ Dictionary specifying the names of dimensions in the dataset. Defaults to:
530
+ {"longitude": "ny", "latitude": "nx", "ntides": "nc"}.
531
+
532
+ Attributes
533
+ ----------
534
+ ds : xr.Dataset
535
+ The xarray Dataset containing the TPXO tidal model data, loaded from the specified file.
536
+ reference_date : datetime
537
+ The reference date for the TPXO data. Default is datetime(1992, 1, 1).
20
538
  """
21
- # Mapping from user-specified topography options to corresponding filenames in the registry
22
- topo_dict = {"etopo5": "etopo5.nc"}
23
539
 
24
- # The file will be downloaded automatically the first time this is run
25
- # returns the file path to the downloaded file. Afterwards, Pooch finds
26
- # it in the local cache and doesn't repeat the download.
27
- fname = FRANK.fetch(topo_dict[topography_source])
28
- # The "fetch" method returns the full path to the downloaded data file.
29
- # All we need to do now is load it with our standard Python tools.
30
- ds = xr.open_dataset(fname)
31
- return ds
540
+ filename: str
541
+ var_names: Dict[str, str] = field(
542
+ default_factory=lambda: {
543
+ "ssh_Re": "h_Re",
544
+ "ssh_Im": "h_Im",
545
+ "sal_Re": "sal_Re",
546
+ "sal_Im": "sal_Im",
547
+ "u_Re": "u_Re",
548
+ "u_Im": "u_Im",
549
+ "v_Re": "v_Re",
550
+ "v_Im": "v_Im",
551
+ "depth": "depth",
552
+ }
553
+ )
554
+ dim_names: Dict[str, str] = field(
555
+ default_factory=lambda: {"longitude": "ny", "latitude": "nx", "ntides": "nc"}
556
+ )
557
+ ds: xr.Dataset = field(init=False, repr=False)
558
+ reference_date: datetime = datetime(1992, 1, 1)
559
+
560
+ def __post_init__(self):
561
+ # Perform any necessary dataset initialization or modifications here
562
+ ds = super().load_data()
563
+
564
+ # Clean up dataset
565
+ ds = ds.assign_coords(
566
+ {
567
+ "omega": ds["omega"],
568
+ "nx": ds["lon_r"].isel(
569
+ ny=0
570
+ ), # lon_r is constant along ny, i.e., is only a function of nx
571
+ "ny": ds["lat_r"].isel(
572
+ nx=0
573
+ ), # lat_r is constant along nx, i.e., is only a function of ny
574
+ }
575
+ )
576
+ ds = ds.rename({"nx": "longitude", "ny": "latitude"})
32
577
 
578
+ object.__setattr__(
579
+ self,
580
+ "dim_names",
581
+ {
582
+ "latitude": "latitude",
583
+ "longitude": "longitude",
584
+ "ntides": self.dim_names["ntides"],
585
+ },
586
+ )
587
+ # Select relevant fields
588
+ ds = super().select_relevant_fields(ds)
33
589
 
34
- def fetch_ssr_correction(correction_source) -> xr.Dataset:
590
+ # Check whether the data covers the entire globe
591
+ object.__setattr__(self, "is_global", super().check_if_global(ds))
592
+
593
+ # If dataset is global concatenate three copies of field along longitude dimension
594
+ if self.is_global:
595
+ ds = super().concatenate_longitudes(ds)
596
+
597
+ object.__setattr__(self, "ds", ds)
598
+
599
+ def check_number_constituents(self, ntides: int):
600
+ """
601
+ Checks if the number of constituents in the dataset is at least `ntides`.
602
+
603
+ Parameters
604
+ ----------
605
+ ntides : int
606
+ The required number of tidal constituents.
607
+
608
+ Raises
609
+ ------
610
+ ValueError
611
+ If the number of constituents in the dataset is less than `ntides`.
612
+ """
613
+ if len(self.ds[self.dim_names["ntides"]]) < ntides:
614
+ raise ValueError(
615
+ f"The dataset contains fewer than {ntides} tidal constituents."
616
+ )
617
+
618
+
619
+ @dataclass(frozen=True, kw_only=True)
620
+ class GLORYSDataset(Dataset):
621
+ """
622
+ Represents GLORYS data on original grid.
623
+
624
+ Parameters
625
+ ----------
626
+ filename : str
627
+ The path to the data files. Can contain wildcards.
628
+ start_time : Optional[datetime], optional
629
+ The start time for selecting relevant data. If not provided, the data is not filtered by start time.
630
+ end_time : Optional[datetime], optional
631
+ The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
632
+ or no filtering is applied if start_time is not provided.
633
+ var_names: Dict[str, str], optional
634
+ Dictionary of variable names that are required in the dataset.
635
+ dim_names: Dict[str, str], optional
636
+ Dictionary specifying the names of dimensions in the dataset.
637
+ climatology : bool
638
+ Indicates whether the dataset is climatological. Defaults to False.
639
+
640
+ Attributes
641
+ ----------
642
+ ds : xr.Dataset
643
+ The xarray Dataset containing the GLORYS data on its original grid.
644
+ """
645
+
646
+ var_names: Dict[str, str] = field(
647
+ default_factory=lambda: {
648
+ "temp": "thetao",
649
+ "salt": "so",
650
+ "u": "uo",
651
+ "v": "vo",
652
+ "zeta": "zos",
653
+ }
654
+ )
655
+
656
+ dim_names: Dict[str, str] = field(
657
+ default_factory=lambda: {
658
+ "longitude": "longitude",
659
+ "latitude": "latitude",
660
+ "depth": "depth",
661
+ "time": "time",
662
+ }
663
+ )
664
+
665
+ climatology: Optional[bool] = False
666
+
667
+
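
The var_names mapping lets downstream code address GLORYS fields by their ROMS-side names. A sketch, assuming hypothetical files matching glorys_*.nc:

    from datetime import datetime

    glorys = GLORYSDataset(
        filename="glorys_*.nc",  # hypothetical wildcard pattern
        start_time=datetime(2012, 1, 1),
        end_time=datetime(2012, 2, 1),
    )
    # Access fields through the mapping rather than hard-coding GLORYS names:
    temp = glorys.ds[glorys.var_names["temp"]]  # "thetao"
    zeta = glorys.ds[glorys.var_names["zeta"]]  # "zos"
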
668
+ @dataclass(frozen=True, kw_only=True)
669
+ class CESMDataset(Dataset):
670
+ """
671
+ Represents CESM data on original grid.
672
+
673
+ Parameters
674
+ ----------
675
+ filename : str
676
+ The path to the data files. Can contain wildcards.
677
+ start_time : Optional[datetime], optional
678
+ The start time for selecting relevant data. If not provided, the data is not filtered by start time.
679
+ end_time : Optional[datetime], optional
680
+ The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
681
+ or no filtering is applied if start_time is not provided.
682
+ var_names: Dict[str, str], optional
683
+ Dictionary of variable names that are required in the dataset.
684
+ dim_names: Dict[str, str], optional
685
+ Dictionary specifying the names of dimensions in the dataset.
686
+ climatology : bool
687
+ Indicates whether the dataset is climatological. Defaults to True.
688
+
689
+ Attributes
690
+ ----------
691
+ ds : xr.Dataset
692
+ The xarray Dataset containing the CESM data on its original grid.
693
+ """
694
+
695
+ # overwrite load_data method from parent class
696
+ def load_data(self) -> xr.Dataset:
697
+ """
698
+ Load dataset from the specified file.
699
+
700
+ Returns
701
+ -------
702
+ ds : xr.Dataset
703
+ The loaded xarray Dataset containing the forcing data.
704
+
705
+ Raises
706
+ ------
707
+ FileNotFoundError
708
+ If the specified file does not exist.
709
+ """
710
+
711
+ # Check if the file exists
712
+ matching_files = glob.glob(self.filename)
713
+ if not matching_files:
714
+ raise FileNotFoundError(
715
+ f"No files found matching the pattern '{self.filename}'."
716
+ )
717
+
718
+ # Load the dataset
719
+ with dask.config.set(**{"array.slicing.split_large_chunks": False}):
720
+ # Define the chunk sizes
721
+ chunks = {
722
+ self.dim_names["latitude"]: -1,
723
+ self.dim_names["longitude"]: -1,
724
+ }
725
+
726
+ ds = xr.open_mfdataset(
727
+ self.filename,
728
+ combine="nested",
729
+ coords="minimal",
730
+ compat="override",
731
+ chunks=chunks,
732
+ engine="netcdf4",
733
+ )
734
+ if "time" not in self.dim_names:
735
+ if "time" in ds.dims:
736
+ self.dim_names["time"] = "time"
737
+ else:
738
+ if "month" in ds.dims:
739
+ self.dim_names["time"] = "month"
740
+ else:
741
+ ds = ds.expand_dims({"time": 1})
742
+ self.dim_names["time"] = "time"
743
+
744
+ return ds
745
+
746
+ def add_time_info(self, ds: xr.Dataset) -> xr.Dataset:
747
+ """
748
+ Adds time information to the dataset based on the climatology flag and dimension names.
749
+
750
+ This method processes the dataset to include time information according to the climatology
751
+ setting. If the dataset represents climatology data and the time dimension is labeled as
752
+ "month", it assigns dates to the dataset based on a monthly climatology. Additionally, it
753
+ handles dimension name updates if necessary.
754
+
755
+ Parameters
756
+ ----------
757
+ ds : xr.Dataset
758
+ The input dataset to which time information will be added.
759
+
760
+ Returns
761
+ -------
762
+ xr.Dataset
763
+ The dataset with time information added, including adjustments for climatology and
764
+ dimension names.
765
+ """
766
+ time_dim = self.dim_names["time"]
767
+
768
+ if self.climatology and time_dim == "month":
769
+ ds = assign_dates_to_climatology(ds, time_dim)
770
+ # rename dimension
771
+ ds = ds.swap_dims({time_dim: "time"})
772
+ # Update dimension names
773
+ updated_dim_names = self.dim_names.copy()
774
+ updated_dim_names["time"] = "time"
775
+ object.__setattr__(self, "dim_names", updated_dim_names)
776
+
777
+ return ds
778
+
779
+
780
+ @dataclass(frozen=True, kw_only=True)
781
+ class CESMBGCDataset(CESMDataset):
35
782
  """
36
- Load the SSR correction data as an xarray Dataset.
783
+ Represents CESM BGC data on original grid.
784
+
785
+ Parameters
786
+ ----------
787
+ filename : str
788
+ The path to the data files. Can contain wildcards.
789
+ start_time : Optional[datetime], optional
790
+ The start time for selecting relevant data. If not provided, the data is not filtered by start time.
791
+ end_time : Optional[datetime], optional
792
+ The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
793
+ or no filtering is applied if start_time is not provided.
794
+ var_names: Dict[str, str], optional
795
+ Dictionary of variable names that are required in the dataset.
796
+ dim_names: Dict[str, str], optional
797
+ Dictionary specifying the names of dimensions in the dataset.
798
+ climatology : bool
799
+ Indicates whether the dataset is climatological. Defaults to True.
800
+
801
+ Attributes
802
+ ----------
803
+ ds : xr.Dataset
804
+ The xarray Dataset containing the CESM BGC data on its original grid.
37
805
  """
38
- # Mapping from user-specified topography options to corresponding filenames in the registry
39
- topo_dict = {"corev2": "SSR_correction.nc"}
40
806
 
41
- # The file will be downloaded automatically the first time this is run
42
- # returns the file path to the downloaded file. Afterwards, Pooch finds
43
- # it in the local cache and doesn't repeat the download.
44
- fname = FRANK.fetch(topo_dict[correction_source])
45
- # The "fetch" method returns the full path to the downloaded data file.
46
- # All we need to do now is load it with our standard Python tools.
47
- ds = xr.open_dataset(fname)
48
- return ds
807
+ var_names: Dict[str, str] = field(
808
+ default_factory=lambda: {
809
+ "PO4": "PO4",
810
+ "NO3": "NO3",
811
+ "SiO3": "SiO3",
812
+ "NH4": "NH4",
813
+ "Fe": "Fe",
814
+ "Lig": "Lig",
815
+ "O2": "O2",
816
+ "DIC": "DIC",
817
+ "DIC_ALT_CO2": "DIC_ALT_CO2",
818
+ "ALK": "ALK",
819
+ "ALK_ALT_CO2": "ALK_ALT_CO2",
820
+ "DOC": "DOC",
821
+ "DON": "DON",
822
+ "DOP": "DOP",
823
+ "DOPr": "DOPr",
824
+ "DONr": "DONr",
825
+ "DOCr": "DOCr",
826
+ "spChl": "spChl",
827
+ "spC": "spC",
828
+ "spP": "spP",
829
+ "spFe": "spFe",
830
+ "diatChl": "diatChl",
831
+ "diatC": "diatC",
832
+ "diatP": "diatP",
833
+ "diatFe": "diatFe",
834
+ "diatSi": "diatSi",
835
+ "diazChl": "diazChl",
836
+ "diazC": "diazC",
837
+ "diazP": "diazP",
838
+ "diazFe": "diazFe",
839
+ "spCaCO3": "spCaCO3",
840
+ "zooC": "zooC",
841
+ }
842
+ )
843
+
844
+ dim_names: Dict[str, str] = field(
845
+ default_factory=lambda: {
846
+ "longitude": "lon",
847
+ "latitude": "lat",
848
+ "depth": "z_t",
849
+ }
850
+ )
851
+
852
+ climatology: Optional[bool] = True
853
+
854
+ def post_process(self):
855
+ """
856
+ Processes and converts CESM data values as follows:
857
+ - Map variables defined only in the upper 150 m (z_t_150m) onto the full z_t axis.
+ - Convert depth values from cm to m.
858
+ """
859
+
860
+ if self.dim_names["depth"] == "z_t":
861
+ # Fill variables that only have data in upper 150m with NaNs below
862
+ if (
863
+ "z_t_150m" in self.ds.dims
864
+ and np.equal(self.ds.z_t[:15].values, self.ds.z_t_150m.values).all()
865
+ ):
866
+ for var in self.var_names:
867
+ if "z_t_150m" in self.ds[var].dims:
868
+ self.ds[var] = self.ds[var].rename({"z_t_150m": "z_t"})
869
+ self.ds[var] = self.ds[var].chunk({"z_t": -1})
870
+ # Convert depth from cm to m
871
+ ds = self.ds.assign_coords({"depth": self.ds["z_t"] / 100})
872
+ ds["depth"].attrs["long_name"] = "Depth"
873
+ ds["depth"].attrs["units"] = "m"
874
+ ds = ds.swap_dims({"z_t": "depth"})
875
+ if "z_t" in ds:
876
+ ds = ds.drop_vars("z_t")
877
+ if "z_t_150m" in ds:
878
+ ds = ds.drop_vars("z_t_150m")
879
+ # update dataset
880
+ object.__setattr__(self, "ds", ds)
881
+
882
+ # Update dim_names with "depth": "depth" key-value pair
883
+ updated_dim_names = self.dim_names.copy()
884
+ updated_dim_names["depth"] = "depth"
885
+ object.__setattr__(self, "dim_names", updated_dim_names)
886
+
887
+
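
A usage sketch for the BGC class, assuming a hypothetical monthly CESM climatology file cesm_bgc_clim.nc with the default variable and dimension names:

    from datetime import datetime

    bgc = CESMBGCDataset(
        filename="cesm_bgc_clim.nc",
        start_time=datetime(2012, 1, 15),  # initial-condition time within the climatology
    )
    bgc.post_process()      # maps z_t_150m fields onto z_t and converts depth from cm to m
    print(bgc.ds["depth"])  # depth coordinate, now in metres
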
888
+ @dataclass(frozen=True, kw_only=True)
889
+ class CESMBGCSurfaceForcingDataset(CESMDataset):
890
+ """
891
+ Represents CESM BGC surface forcing data on original grid.
892
+
893
+ Parameters
894
+ ----------
895
+ filename : str
896
+ The path to the data files. Can contain wildcards.
897
+ start_time : Optional[datetime], optional
898
+ The start time for selecting relevant data. If not provided, the data is not filtered by start time.
899
+ end_time : Optional[datetime], optional
900
+ The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
901
+ or no filtering is applied if start_time is not provided.
902
+ var_names: Dict[str, str], optional
903
+ Dictionary of variable names that are required in the dataset.
904
+ dim_names: Dict[str, str], optional
905
+ Dictionary specifying the names of dimensions in the dataset.
906
+ climatology : bool
907
+ Indicates whether the dataset is climatological. Defaults to False.
908
+
909
+ Attributes
910
+ ----------
911
+ ds : xr.Dataset
912
+ The xarray Dataset containing the CESM BGC surface forcing data on its original grid.
913
+ """
914
+
915
+ var_names: Dict[str, str] = field(
916
+ default_factory=lambda: {
917
+ "pco2_air": "pCO2SURF",
918
+ "pco2_air_alt": "pCO2SURF",
919
+ "iron": "IRON_FLUX",
920
+ "dust": "dust_FLUX_IN",
921
+ "nox": "NOx_FLUX",
922
+ "nhy": "NHy_FLUX",
923
+ }
924
+ )
925
+
926
+ dim_names: Dict[str, str] = field(
927
+ default_factory=lambda: {
928
+ "longitude": "lon",
929
+ "latitude": "lat",
930
+ }
931
+ )
932
+
933
+ climatology: Optional[bool] = False
934
+
935
+
936
+ @dataclass(frozen=True, kw_only=True)
937
+ class ERA5Dataset(Dataset):
938
+ """
939
+ Represents ERA5 data on original grid.
940
+
941
+ Parameters
942
+ ----------
943
+ filename : str
944
+ The path to the data files. Can contain wildcards.
945
+ start_time : Optional[datetime], optional
946
+ The start time for selecting relevant data. If not provided, the data is not filtered by start time.
947
+ end_time : Optional[datetime], optional
948
+ The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
949
+ or no filtering is applied if start_time is not provided.
950
+ var_names: Dict[str, str], optional
951
+ Dictionary of variable names that are required in the dataset.
952
+ dim_names: Dict[str, str], optional
953
+ Dictionary specifying the names of dimensions in the dataset.
954
+ climatology : bool
955
+ Indicates whether the dataset is climatological. Defaults to False.
956
+
957
+ Attributes
958
+ ----------
959
+ ds : xr.Dataset
960
+ The xarray Dataset containing the ERA5 data on its original grid.
961
+ """
962
+
963
+ var_names: Dict[str, str] = field(
964
+ default_factory=lambda: {
965
+ "uwnd": "u10",
966
+ "vwnd": "v10",
967
+ "swrad": "ssr",
968
+ "lwrad": "strd",
969
+ "Tair": "t2m",
970
+ "d2m": "d2m",
971
+ "rain": "tp",
972
+ "mask": "sst",
973
+ }
974
+ )
975
+
976
+ dim_names: Dict[str, str] = field(
977
+ default_factory=lambda: {
978
+ "longitude": "longitude",
979
+ "latitude": "latitude",
980
+ "time": "time",
981
+ }
982
+ )
983
+
984
+ climatology: Optional[bool] = False
985
+
986
+ def post_process(self):
987
+ """
988
+ Processes and converts ERA5 data values as follows:
989
+ - Convert radiation values from J/m^2 to W/m^2.
990
+ - Convert rainfall from meters to cm/day.
991
+ - Convert temperature from Kelvin to Celsius.
992
+ - Compute absolute humidity (qair) from temperature and dewpoint if not already present.
+ - Mask land points using the sea-surface temperature field.
993
+ """
994
+ # Translate radiation to fluxes. ERA5 stores values integrated over 1 hour.
995
+ # Convert radiation from J/m^2 to W/m^2
996
+ self.ds[self.var_names["swrad"]] /= 3600
997
+ self.ds[self.var_names["lwrad"]] /= 3600
998
+ self.ds[self.var_names["swrad"]].attrs["units"] = "W/m^2"
999
+ self.ds[self.var_names["lwrad"]].attrs["units"] = "W/m^2"
1000
+ # Convert rainfall from m to cm/day
1001
+ self.ds[self.var_names["rain"]] *= 100 * 24
1002
+
1003
+ # Convert temperature from Kelvin to Celsius
1004
+ self.ds[self.var_names["Tair"]] -= 273.15
1005
+ self.ds[self.var_names["d2m"]] -= 273.15
1006
+ self.ds[self.var_names["Tair"]].attrs["units"] = "degrees C"
1007
+ self.ds[self.var_names["d2m"]].attrs["units"] = "degrees C"
1008
+
1009
+ # Compute relative humidity if not present
1010
+ if "qair" not in self.ds.data_vars:
1011
+ qair = np.exp(
1012
+ (17.625 * self.ds[self.var_names["d2m"]])
1013
+ / (243.04 + self.ds[self.var_names["d2m"]])
1014
+ ) / np.exp(
1015
+ (17.625 * self.ds[self.var_names["Tair"]])
1016
+ / (243.04 + self.ds[self.var_names["Tair"]])
1017
+ )
1018
+ # Convert relative to absolute humidity
1019
+ patm = 1010.0
1020
+ cff = (
1021
+ (1.0007 + 3.46e-6 * patm)
1022
+ * 6.1121
1023
+ * np.exp(
1024
+ 17.502
1025
+ * self.ds[self.var_names["Tair"]]
1026
+ / (240.97 + self.ds[self.var_names["Tair"]])
1027
+ )
1028
+ )
1029
+ cff = cff * qair
1030
+ self.ds["qair"] = 0.62197 * (cff / (patm - 0.378 * cff))
1031
+ self.ds["qair"].attrs["long_name"] = "Absolute humidity at 2m"
1032
+ self.ds["qair"].attrs["units"] = "kg/kg"
1033
+
1034
+ # Update var_names dictionary
1035
+ var_names = {**self.var_names, "qair": "qair"}
1036
+ object.__setattr__(self, "var_names", var_names)
1037
+
1038
+ if "mask" in self.var_names.keys():
1039
+ mask = xr.where(self.ds[self.var_names["mask"]].isel(time=0).isnull(), 0, 1)
1040
+
1041
+ for var in self.ds.data_vars:
1042
+ self.ds[var] = xr.where(mask == 1, self.ds[var], np.nan)
1043
+
1044
+
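
The humidity conversion in post_process combines a Magnus-type relative-humidity estimate from dewpoint with a conversion to specific humidity; the same arithmetic for a single pair of values, as a sketch:

    import numpy as np

    # 2 m temperature and dewpoint in degrees Celsius (example values):
    t2m, d2m = 20.0, 15.0
    rel_hum = np.exp(17.625 * d2m / (243.04 + d2m)) / np.exp(
        17.625 * t2m / (243.04 + t2m)
    )
    patm = 1010.0  # surface pressure in hPa, same constant as in post_process
    cff = (1.0007 + 3.46e-6 * patm) * 6.1121 * np.exp(17.502 * t2m / (240.97 + t2m))
    cff = cff * rel_hum
    qair = 0.62197 * cff / (patm - 0.378 * cff)
    print(qair)  # ~0.0106 kg/kg for these values
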
1045
+ @dataclass(frozen=True, kw_only=True)
1046
+ class ERA5Correction(Dataset):
1047
+ """
1048
+ Global dataset to correct ERA5 radiation. The dataset contains multiplicative correction factors for the ERA5 shortwave radiation, obtained by comparing the COREv2 climatology to the ERA5 climatology.
1049
+
1050
+ Parameters
1051
+ ----------
1052
+ filename : str, optional
1053
+ The path to the correction files. Defaults to download_correction_data('SSR_correction.nc').
1054
+ var_names: Dict[str, str], optional
1055
+ Dictionary of variable names that are required in the dataset.
1056
+ Defaults to {"swr_corr": "ssr_corr"}.
1057
+ dim_names: Dict[str, str], optional
1058
+ Dictionary specifying the names of dimensions in the dataset.
1059
+ Defaults to {"longitude": "longitude", "latitude": "latitude", "time": "time"}.
1060
+ climatology : bool, optional
1061
+ Indicates if the correction data is a climatology. Defaults to True.
1062
+
1063
+ Attributes
1064
+ ----------
1065
+ ds : xr.Dataset
1066
+ The loaded xarray Dataset containing the correction data.
1067
+ """
1068
+
1069
+ filename: str = field(
1070
+ default_factory=lambda: download_correction_data("SSR_correction.nc")
1071
+ )
1072
+ var_names: Dict[str, str] = field(
1073
+ default_factory=lambda: {
1074
+ "swr_corr": "ssr_corr", # multiplicative correction factor for ERA5 shortwave radiation
1075
+ }
1076
+ )
1077
+ dim_names: Dict[str, str] = field(
1078
+ default_factory=lambda: {
1079
+ "longitude": "longitude",
1080
+ "latitude": "latitude",
1081
+ "time": "time",
1082
+ }
1083
+ )
1084
+ climatology: Optional[bool] = True
1085
+
1086
+ ds: xr.Dataset = field(init=False, repr=False)
1087
+
1088
+ def __post_init__(self):
1089
+
1090
+ if not self.climatology:
1091
+ raise NotImplementedError(
1092
+ "Correction data must be a climatology. Set climatology to True."
1093
+ )
1094
+
1095
+ super().__post_init__()
1096
+
1097
+ def choose_subdomain(self, coords, straddle: bool):
1098
+ """
1099
+ Converts longitude values in the dataset if necessary and selects a subdomain based on the specified coordinates.
1100
+
1101
+ This method converts longitude values between different ranges if required and then extracts a subset of the
1102
+ dataset according to the given coordinates. It updates the dataset in place to reflect the selected subdomain.
1103
+
1104
+ Parameters
1105
+ ----------
1106
+ coords : dict
1107
+ A dictionary specifying the target coordinates for selecting the subdomain. Keys should correspond to the
1108
+ dimension names of the dataset (e.g., latitude and longitude), and values should be the desired ranges or
1109
+ specific coordinate values.
1110
+ straddle : bool
1111
+ If True, assumes that target longitudes are in the range [-180, 180]. If False, assumes longitudes are in the
1112
+ range [0, 360]. This parameter determines how longitude values are converted if necessary.
1113
+
1114
+ Raises
1115
+ ------
1116
+ ValueError
1117
+ If the specified subdomain does not fully contain the specified latitude or longitude values. This can occur
1118
+ if the dataset does not cover the full range of provided coordinates.
1119
+
1120
+ Notes
1121
+ -----
1122
+ - The dataset (`self.ds`) is updated in place to reflect the chosen subdomain.
1123
+ """
1124
+
1125
+ lon = self.ds[self.dim_names["longitude"]]
1126
+
1127
+ if not self.is_global:
1128
+ if lon.min().values < 0 and not straddle:
1129
+ # Convert from [-180, 180] to [0, 360]
1130
+ self.ds[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
1131
+
1132
+ if lon.max().values > 180 and straddle:
1133
+ # Convert from [0, 360] to [-180, 180]
1134
+ self.ds[self.dim_names["longitude"]] = xr.where(
1135
+ lon > 180, lon - 360, lon
1136
+ )
1137
+
1138
+ # Select the subdomain based on the specified latitude and longitude ranges
1139
+ subdomain = self.ds.sel(**coords)
1140
+
1141
+ # Check if the selected subdomain contains the specified latitude and longitude values
1142
+ if not subdomain[self.dim_names["latitude"]].equals(
1143
+ coords[self.dim_names["latitude"]]
1144
+ ):
1145
+ raise ValueError(
1146
+ "The correction dataset does not contain all specified latitude values."
1147
+ )
1148
+ if not subdomain[self.dim_names["longitude"]].equals(
1149
+ coords[self.dim_names["longitude"]]
1150
+ ):
1151
+ raise ValueError(
1152
+ "The correction dataset does not contain all specified longitude values."
1153
+ )
1154
+ object.__setattr__(self, "ds", subdomain)