roms-tools 0.20__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,86 +1,18 @@
- import pooch
  import xarray as xr
  from dataclasses import dataclass, field
  import glob
  from datetime import datetime, timedelta
  import numpy as np
- from typing import Dict, Optional, List
+ from typing import Dict, Optional
  import dask
-
- # Create a Pooch object to manage the global topography data
- pup_data = pooch.create(
-     # Use the default cache folder for the operating system
-     path=pooch.os_cache("roms-tools"),
-     base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
-     # The registry specifies the files that can be fetched
-     registry={
-         "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
-     },
- )
-
- # Create a Pooch object to manage the test data
- pup_test_data = pooch.create(
-     # Use the default cache folder for the operating system
-     path=pooch.os_cache("roms-tools"),
-     base_url="https://github.com/CWorthy-ocean/roms-tools-test-data/raw/main/",
-     # The registry specifies the files that can be fetched
-     registry={
-         "GLORYS_test_data.nc": "648f88ec29c433bcf65f257c1fb9497bd3d5d3880640186336b10ed54f7129d2",
-         "ERA5_regional_test_data.nc": "bd12ce3b562fbea2a80a3b79ba74c724294043c28dc98ae092ad816d74eac794",
-         "ERA5_global_test_data.nc": "8ed177ab64c02caf509b9fb121cf6713f286cc603b1f302f15f3f4eb0c21dc4f",
-         "TPXO_global_test_data.nc": "457bfe87a7b247ec6e04e3c7d3e741ccf223020c41593f8ae33a14f2b5255e60",
-         "TPXO_regional_test_data.nc": "11739245e2286d9c9d342dce5221e6435d2072b50028bef2e86a30287b3b4032",
-     },
+ import warnings
+ from roms_tools.setup.utils import (
+     assign_dates_to_climatology,
+     interpolate_from_climatology,
+     is_cftime_datetime,
+     convert_cftime_to_datetime,
  )
-
-
- def fetch_topo(topography_source: str) -> xr.Dataset:
-     """
-     Load the global topography data as an xarray Dataset.
-
-     Parameters
-     ----------
-     topography_source : str
-         The source of the topography data to be loaded. Available options:
-         - "ETOPO5"
-
-     Returns
-     -------
-     xr.Dataset
-         The global topography data as an xarray Dataset.
-     """
-     # Mapping from user-specified topography options to corresponding filenames in the registry
-     topo_dict = {"ETOPO5": "etopo5.nc"}
-
-     # Fetch the file using Pooch, downloading if necessary
-     fname = pup_data.fetch(topo_dict[topography_source])
-
-     # Load the dataset using xarray and return it
-     ds = xr.open_dataset(fname)
-     return ds
-
-
- def download_test_data(filename: str) -> str:
-     """
-     Download the test data file.
-
-     Parameters
-     ----------
-     filename : str
-         The name of the test data file to be downloaded. Available options:
-         - "GLORYS_test_data.nc"
-         - "ERA5_regional_test_data.nc"
-         - "ERA5_global_test_data.nc"
-
-     Returns
-     -------
-     str
-         The path to the downloaded test data file.
-     """
-     # Fetch the file using Pooch, downloading if necessary
-     fname = pup_test_data.fetch(filename)
-
-     return fname
+ from roms_tools.setup.download import download_correction_data


  @dataclass(frozen=True, kw_only=True)
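The 0.2.0 module-level pooch registries (pup_data, pup_test_data) and the fetch_topo / download_test_data helpers are gone; 1.0.1 imports its download logic from roms_tools.setup.download instead. For readers unfamiliar with the removed pattern, a minimal self-contained sketch of what pooch.create plus fetch does, reusing the URL and hash from the deleted block:

    import pooch
    import xarray as xr

    # Mirrors the removed pup_data object: files are downloaded once, cached
    # in the OS cache directory, and verified against their SHA-256 hash.
    pup = pooch.create(
        path=pooch.os_cache("roms-tools"),
        base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
        registry={
            "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
        },
    )

    fname = pup.fetch("etopo5.nc")  # downloads on first call, reuses the cache afterwards
    ds = xr.open_dataset(fname)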
@@ -97,13 +29,17 @@ class Dataset:
      end_time : Optional[datetime], optional
          The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
          or no filtering is applied if start_time is not provided.
-     var_names : List[str]
-         List of variable names that are required in the dataset.
+     var_names: Dict[str, str]
+         Dictionary of variable names that are required in the dataset.
      dim_names: Dict[str, str], optional
          Dictionary specifying the names of dimensions in the dataset.
+     climatology : bool
+         Indicates whether the dataset is climatological. Defaults to False.

      Attributes
      ----------
+     is_global : bool
+         Indicates whether the dataset covers the entire globe.
      ds : xr.Dataset
          The xarray Dataset containing the forcing data on its original grid.
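The var_names change from List[str] to Dict[str, str] runs through the rest of the diff: keys are the names roms-tools uses internally, values are the variable names in the source file. A small sketch of the difference (variable names borrowed from the GLORYSDataset defaults added later in this diff):

    # 0.2.x: a flat list of source-file variable names
    var_names_old = ["thetao", "so", "zos"]

    # 1.0.1: internal name -> source-file name
    var_names_new = {"temp": "thetao", "salt": "so", "zeta": "zos"}

    # Membership tests therefore move from the list to the dict's values
    assert "thetao" in var_names_new.values()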

@@ -123,7 +59,7 @@ class Dataset:
      filename: str
      start_time: Optional[datetime] = None
      end_time: Optional[datetime] = None
-     var_names: List[str]
+     var_names: Dict[str, str]
      dim_names: Dict[str, str] = field(
          default_factory=lambda: {
              "longitude": "longitude",
@@ -131,28 +67,40 @@ class Dataset:
              "time": "time",
          }
      )
+     climatology: Optional[bool] = False

+     is_global: bool = field(init=False, repr=False)
      ds: xr.Dataset = field(init=False, repr=False)

      def __post_init__(self):
+         """
+         Post-initialization processing:
+         1. Loads the dataset from the specified filename.
+         2. Applies time filtering based on start_time and end_time if provided.
+         3. Selects relevant fields as specified by var_names.
+         4. Ensures latitude values are in ascending order.
+         5. Checks if the dataset covers the entire globe and adjusts if necessary.
+         """
+
          ds = self.load_data()
+         self.check_dataset(ds)

          # Select relevant times
          if "time" in self.dim_names and self.start_time is not None:
+             ds = self.add_time_info(ds)
              ds = self.select_relevant_times(ds)

          # Select relevant fields
          ds = self.select_relevant_fields(ds)

          # Make sure that latitude is ascending
-         diff = np.diff(ds[self.dim_names["latitude"]])
-         if np.all(diff < 0):
-             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
+         ds = self.ensure_latitude_ascending(ds)

          # Check whether the data covers the entire globe
-         is_global = self.check_if_global(ds)
+         object.__setattr__(self, "is_global", self.check_if_global(ds))

-         if is_global:
+         # If dataset is global concatenate three copies of field along longitude dimension
+         if self.is_global:
              ds = self.concatenate_longitudes(ds)

          object.__setattr__(self, "ds", ds)
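Because Dataset is declared with frozen=True, __post_init__ cannot assign attributes directly; that is why the new is_global, like ds, is written via object.__setattr__. A minimal sketch of the pattern in isolation:

    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class Example:
        x: int
        doubled: int = field(init=False, repr=False)

        def __post_init__(self):
            # self.doubled = 2 * self.x would raise FrozenInstanceError here
            object.__setattr__(self, "doubled", 2 * self.x)

    assert Example(x=3).doubled == 6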
@@ -208,6 +156,34 @@ class Dataset:

          return ds

+     def check_dataset(self, ds: xr.Dataset) -> None:
+         """
+         Check if the dataset contains the specified variables and dimensions.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The xarray Dataset to check.
+
+         Raises
+         ------
+         ValueError
+             If the dataset does not contain the specified variables or dimensions.
+         """
+         missing_vars = [
+             var for var in self.var_names.values() if var not in ds.data_vars
+         ]
+         if missing_vars:
+             raise ValueError(
+                 f"Dataset does not contain all required variables. The following variables are missing: {missing_vars}"
+             )
+
+         missing_dims = [dim for dim in self.dim_names.values() if dim not in ds.dims]
+         if missing_dims:
+             raise ValueError(
+                 f"Dataset does not contain all required dimensions. The following dimensions are missing: {missing_dims}"
+             )
+
      def select_relevant_fields(self, ds) -> xr.Dataset:
          """
          Selects and returns a subset of the dataset containing only the variables specified in `self.var_names`.
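The new check_dataset turns a late KeyError into an early, explicit ValueError. A sketch of the behaviour it adds, with a deliberately incomplete toy dataset (names illustrative):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"thetao": ("time", np.zeros(3))})

    var_names = {"temp": "thetao", "salt": "so"}  # "so" is deliberately absent
    missing_vars = [v for v in var_names.values() if v not in ds.data_vars]
    if missing_vars:
        print(f"would raise ValueError: missing {missing_vars}")  # -> ['so']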
@@ -222,26 +198,36 @@ class Dataset:
          xr.Dataset
              A dataset containing only the variables specified in `self.var_names`.

-         Raises
-         ------
-         ValueError
-             If `ds` does not contain all variables listed in `self.var_names`.
-
          """
-         missing_vars = [var for var in self.var_names if var not in ds.data_vars]
-         if missing_vars:
-             raise ValueError(
-                 f"Dataset does not contain all required variables. The following variables are missing: {missing_vars}"
-             )

          for var in ds.data_vars:
-             if var not in self.var_names:
+             if var not in self.var_names.values():
                  ds = ds.drop_vars(var)

          return ds

-     def select_relevant_times(self, ds) -> xr.Dataset:

+     def add_time_info(self, ds: xr.Dataset) -> xr.Dataset:
+         """
+         Placeholder method to be overridden by child classes to add time information to the dataset.
+
+         This method is intended as a placeholder and should be implemented in subclasses
+         to provide specific functionality for adding time-related information to the dataset.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The xarray Dataset to which time information will be added.
+
+         Returns
+         -------
+         xr.Dataset
+             The xarray Dataset with time information added (as implemented by child classes).
+         """
+         return ds
+
+     def select_relevant_times(self, ds) -> xr.Dataset:
          """
          Selects and returns the subset of the dataset corresponding to the specified time range.

@@ -259,22 +245,53 @@ class Dataset:
          xr.Dataset
              A dataset containing only the data points within the specified time range.

+         Raises
+         ------
+         ValueError
+             If no matching times are found or if the number of matching times does not meet expectations.
+
+         Warns
+         -----
+         UserWarning
+             If the dataset contains only 12 time steps but the climatology flag is not set.
+             This may indicate that the dataset represents climatology data.
          """

          time_dim = self.dim_names["time"]
-
-         if not self.end_time:
-             end_time = self.start_time + timedelta(days=1)
+         if time_dim in ds.coords or time_dim in ds.data_vars:
+             if self.climatology:
+                 if not self.end_time:
+                     # Interpolate from climatology for initial conditions
+                     ds = interpolate_from_climatology(
+                         ds, self.dim_names["time"], self.start_time
+                     )
+             else:
+                 if len(ds[time_dim]) == 12:
+                     warnings.warn(
+                         "The dataset contains exactly 12 time steps. This may indicate that it is "
+                         "climatological data. Please verify if climatology is appropriate for your "
+                         "analysis and set the climatology flag to True."
+                     )
+                 if is_cftime_datetime(ds[time_dim]):
+                     ds = ds.assign_coords(
+                         {time_dim: convert_cftime_to_datetime(ds[time_dim])}
+                     )
+                 if not self.end_time:
+                     end_time = self.start_time + timedelta(days=1)
+                 else:
+                     end_time = self.end_time
+
+                 times = (np.datetime64(self.start_time) <= ds[time_dim]) & (
+                     ds[time_dim] < np.datetime64(end_time)
+                 )
+                 ds = ds.where(times, drop=True)
          else:
-             end_time = self.end_time
-
-         times = (np.datetime64(self.start_time) <= ds[time_dim]) & (
-             ds[time_dim] < np.datetime64(end_time)
-         )
-         ds = ds.where(times, drop=True)
-
+             warnings.warn(
+                 "Dataset does not contain any time information. Please check if the time dimension "
+                 "is correctly named or if the dataset includes time data."
+             )
          if not ds.sizes[time_dim]:
-             raise ValueError("No matching times found.")
+             raise ValueError("No matching times found in the dataset.")

          if not self.end_time:
              if ds.sizes[time_dim] != 1:
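For non-climatology data the selection is a half-open window [start_time, end_time), with end_time defaulting to start_time + 1 day. A quick standalone illustration of that predicate (dates illustrative):

    from datetime import datetime, timedelta
    import numpy as np
    import xarray as xr

    time = xr.DataArray(
        np.array(["2011-12-31", "2012-01-01", "2012-01-02"], dtype="datetime64[ns]"),
        dims="time",
    )
    start_time = datetime(2012, 1, 1)
    end_time = start_time + timedelta(days=1)  # default when end_time is omitted

    keep = (np.datetime64(start_time) <= time) & (time < np.datetime64(end_time))
    print(keep.values)  # [False  True False]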
@@ -285,6 +302,27 @@ class Dataset:

          return ds

+     def ensure_latitude_ascending(self, ds: xr.Dataset) -> xr.Dataset:
+         """
+         Ensure that the latitude dimension is in ascending order.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The xarray Dataset to check.
+
+         Returns
+         -------
+         ds : xr.Dataset
+             The xarray Dataset with latitude in ascending order.
+         """
+         # Make sure that latitude is ascending
+         lat_diff = np.diff(ds[self.dim_names["latitude"]])
+         if np.all(lat_diff < 0):
+             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
+
+         return ds
+
      def check_if_global(self, ds) -> bool:
          """
          Checks if the dataset covers the entire globe in the longitude dimension.
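The inlined latitude check from __post_init__ is now the reusable ensure_latitude_ascending; it only reverses the axis when latitude is strictly descending. In isolation:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"t": ("latitude", [1.0, 2.0, 3.0])},
        coords={"latitude": [30.0, 20.0, 10.0]},  # descending source grid
    )
    if np.all(np.diff(ds["latitude"]) < 0):
        ds = ds.isel(latitude=slice(None, None, -1))
    print(ds["latitude"].values)  # [10. 20. 30.]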
@@ -306,7 +344,7 @@ class Dataset:
          dlon = (
              ds[self.dim_names["longitude"]][0] - ds[self.dim_names["longitude"]][-1]
          ) % 360.0
-         is_global = np.isclose(dlon, dlon_mean, rtol=0.0, atol=1e-3)
+         is_global = np.isclose(dlon, dlon_mean, rtol=0.0, atol=1e-3).item()

          return is_global

@@ -341,7 +379,7 @@ class Dataset:

          ds_concatenated[self.dim_names["longitude"]] = lon_concatenated

-         for var in self.var_names:
+         for var in self.var_names.values():
              if self.dim_names["longitude"] in ds[var].dims:
                  field = ds[var]
                  field_concatenated = xr.concat(
@@ -358,15 +396,15 @@ class Dataset:
          self, latitude_range, longitude_range, margin, straddle, return_subdomain=False
      ):
          """
-         Selects a subdomain from the given xarray Dataset based on latitude and longitude ranges,
-         extending the selection by the specified margin. Handles the conversion of longitude values
-         in the dataset from one range to another.
+         Selects a subdomain from the xarray Dataset based on specified latitude and longitude ranges,
+         extending the selection by a specified margin. Handles longitude conversions to accommodate different
+         longitude ranges.

          Parameters
          ----------
-         latitude_range : tuple
+         latitude_range : tuple of float
              A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
-         longitude_range : tuple
+         longitude_range : tuple of float
              A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
          margin : float
              Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
@@ -374,45 +412,53 @@ class Dataset:
              If True, target longitudes are expected in the range [-180, 180].
              If False, target longitudes are expected in the range [0, 360].
          return_subdomain : bool, optional
-             If True, returns the subset of the original dataset. If False, assigns it to self.ds.
-             Default is False.
+             If True, returns the subset of the original dataset as an xarray Dataset. If False, assigns the subset to `self.ds`.
+             Defaults to False.

          Returns
          -------
-         xr.Dataset
-             The subset of the original dataset representing the chosen subdomain, including an extended area
-             to cover one extra grid point beyond the specified ranges if return_subdomain is True.
-             Otherwise, returns None.
+         xr.Dataset or None
+             If `return_subdomain` is True, returns the subset of the original dataset representing the chosen subdomain,
+             including an extended area to cover one extra grid point beyond the specified ranges. If `return_subdomain` is False,
+             returns None as the subset is assigned to `self.ds`.
+
+         Notes
+         -----
+         This method adjusts the longitude range if necessary to ensure it matches the expected range for the dataset.
+         It also handles longitude discontinuities that can occur when converting to different longitude ranges.
+         This is important for avoiding artifacts in the interpolation process.

          Raises
          ------
          ValueError
              If the selected latitude or longitude range does not intersect with the dataset.
          """
+
          lat_min, lat_max = latitude_range
          lon_min, lon_max = longitude_range

-         lon = self.ds[self.dim_names["longitude"]]
-         # Adjust longitude range if needed to match the expected range
-         if not straddle:
-             if lon.min() < -180:
-                 if lon_max + margin > 0:
-                     lon_min -= 360
-                     lon_max -= 360
-             elif lon.min() < 0:
-                 if lon_max + margin > 180:
-                     lon_min -= 360
-                     lon_max -= 360
-
-         if straddle:
-             if lon.max() > 360:
-                 if lon_min - margin < 180:
-                     lon_min += 360
-                     lon_max += 360
-             elif lon.max() > 180:
-                 if lon_min - margin < 0:
-                     lon_min += 360
-                     lon_max += 360
+         if not self.is_global:
+             # Adjust longitude range if needed to match the expected range
+             lon = self.ds[self.dim_names["longitude"]]
+             if not straddle:
+                 if lon.min() < -180:
+                     if lon_max + margin > 0:
+                         lon_min -= 360
+                         lon_max -= 360
+                 elif lon.min() < 0:
+                     if lon_max + margin > 180:
+                         lon_min -= 360
+                         lon_max -= 360
+
+             if straddle:
+                 if lon.max() > 360:
+                     if lon_min - margin < 180:
+                         lon_min += 360
+                         lon_max += 360
+                 elif lon.max() > 180:
+                     if lon_min - margin < 0:
+                         lon_min += 360
+                         lon_max += 360

          # Select the subdomain
          subdomain = self.ds.sel(
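With the new guard, the window arithmetic only runs for regional datasets; global ones have already been concatenated three times along longitude and cover any request. The shift itself just moves the requested window into the source grid's longitude convention; for example, mirroring the straddle branch with illustrative numbers:

    # Requested window in [-180, 180]; regional source grid stored in [0, 360]
    lon_min, lon_max, margin = -20.0, 15.0, 0.5

    lon_grid_max = 359.75  # stands in for lon.max() in the guarded code
    if lon_grid_max > 180 and lon_min - margin < 0:
        lon_min += 360
        lon_max += 360
    print(lon_min, lon_max)  # 340.0 375.0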
@@ -455,3 +501,654 @@ class Dataset:

          if (depth >= 0).all():
              self.ds[self.dim_names["depth"]] = -depth
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class TPXODataset(Dataset):
+     """
+     Represents tidal data on the original grid from the TPXO dataset.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the TPXO dataset file.
+     var_names : Dict[str, str], optional
+         Dictionary of variable names required in the dataset. Defaults to:
+         {
+             "ssh_Re": "h_Re",
+             "ssh_Im": "h_Im",
+             "sal_Re": "sal_Re",
+             "sal_Im": "sal_Im",
+             "u_Re": "u_Re",
+             "u_Im": "u_Im",
+             "v_Re": "v_Re",
+             "v_Im": "v_Im",
+             "depth": "depth"
+         }
+     dim_names : Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset. Defaults to:
+         {"longitude": "ny", "latitude": "nx", "ntides": "nc"}.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the TPXO tidal model data, loaded from the specified file.
+     reference_date : datetime
+         The reference date for the TPXO data. Default is datetime(1992, 1, 1).
+     """
+
+     filename: str
+     var_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "ssh_Re": "h_Re",
+             "ssh_Im": "h_Im",
+             "sal_Re": "sal_Re",
+             "sal_Im": "sal_Im",
+             "u_Re": "u_Re",
+             "u_Im": "u_Im",
+             "v_Re": "v_Re",
+             "v_Im": "v_Im",
+             "depth": "depth",
+         }
+     )
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {"longitude": "ny", "latitude": "nx", "ntides": "nc"}
+     )
+     ds: xr.Dataset = field(init=False, repr=False)
+     reference_date: datetime = datetime(1992, 1, 1)
+
+     def __post_init__(self):
+         # Perform any necessary dataset initialization or modifications here
+         ds = super().load_data()
+
+         # Clean up dataset
+         ds = ds.assign_coords(
+             {
+                 "omega": ds["omega"],
+                 "nx": ds["lon_r"].isel(
+                     ny=0
+                 ),  # lon_r is constant along ny, i.e., is only a function of nx
+                 "ny": ds["lat_r"].isel(
+                     nx=0
+                 ),  # lat_r is constant along nx, i.e., is only a function of ny
+             }
+         )
+         ds = ds.rename({"nx": "longitude", "ny": "latitude"})
+
+         object.__setattr__(
+             self,
+             "dim_names",
+             {
+                 "latitude": "latitude",
+                 "longitude": "longitude",
+                 "ntides": self.dim_names["ntides"],
+             },
+         )
+         # Select relevant fields
+         ds = super().select_relevant_fields(ds)
+
+         # Check whether the data covers the entire globe
+         object.__setattr__(self, "is_global", super().check_if_global(ds))
+
+         # If dataset is global concatenate three copies of field along longitude dimension
+         if self.is_global:
+             ds = super().concatenate_longitudes(ds)
+
+         object.__setattr__(self, "ds", ds)
+
+     def check_number_constituents(self, ntides: int):
+         """
+         Checks if the number of constituents in the dataset is at least `ntides`.
+
+         Parameters
+         ----------
+         ntides : int
+             The required number of tidal constituents.
+
+         Raises
+         ------
+         ValueError
+             If the number of constituents in the dataset is less than `ntides`.
+         """
+         if len(self.ds[self.dim_names["ntides"]]) < ntides:
+             raise ValueError(
+                 f"The dataset contains fewer than {ntides} tidal constituents."
+             )
+
+
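A hedged usage sketch for the new class (the file name is illustrative and the module path roms_tools.setup.datasets is inferred from the imports above):

    from roms_tools.setup.datasets import TPXODataset

    tpxo = TPXODataset(filename="TPXO_global_test_data.nc")
    tpxo.check_number_constituents(ntides=10)  # ValueError if fewer than 10 constituents
    print(tpxo.ds["h_Re"])  # source-grid variable, per the var_names mapping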
+ @dataclass(frozen=True, kw_only=True)
+ class GLORYSDataset(Dataset):
+     """
+     Represents GLORYS data on original grid.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the data files. Can contain wildcards.
+     start_time : Optional[datetime], optional
+         The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+     end_time : Optional[datetime], optional
+         The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+         or no filtering is applied if start_time is not provided.
+     var_names: Dict[str, str], optional
+         Dictionary of variable names that are required in the dataset.
+     dim_names: Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+     climatology : bool
+         Indicates whether the dataset is climatological. Defaults to False.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the GLORYS data on its original grid.
+     """
+
+     var_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "temp": "thetao",
+             "salt": "so",
+             "u": "uo",
+             "v": "vo",
+             "zeta": "zos",
+         }
+     )
+
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "longitude",
+             "latitude": "latitude",
+             "depth": "depth",
+             "time": "time",
+         }
+     )
+
+     climatology: Optional[bool] = False
+
+
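GLORYSDataset is the Dataset base class plus defaults, so usage is just construction (file path illustrative; module path inferred):

    from datetime import datetime
    from roms_tools.setup.datasets import GLORYSDataset

    glorys = GLORYSDataset(
        filename="GLORYS_test_data.nc",
        start_time=datetime(2012, 1, 1),
        end_time=datetime(2012, 1, 2),
    )
    print(glorys.ds["thetao"])  # temperature on the source grid, per the var_names mapping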
+ @dataclass(frozen=True, kw_only=True)
+ class CESMDataset(Dataset):
+     """
+     Represents CESM data on original grid.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the data files. Can contain wildcards.
+     start_time : Optional[datetime], optional
+         The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+     end_time : Optional[datetime], optional
+         The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+         or no filtering is applied if start_time is not provided.
+     var_names: Dict[str, str], optional
+         Dictionary of variable names that are required in the dataset.
+     dim_names: Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+     climatology : bool
+         Indicates whether the dataset is climatological. Defaults to True.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the CESM data on its original grid.
+     """
+
+     # overwrite load_data method from parent class
+     def load_data(self) -> xr.Dataset:
+         """
+         Load dataset from the specified file.
+
+         Returns
+         -------
+         ds : xr.Dataset
+             The loaded xarray Dataset containing the forcing data.
+
+         Raises
+         ------
+         FileNotFoundError
+             If the specified file does not exist.
+         """
+
+         # Check if the file exists
+         matching_files = glob.glob(self.filename)
+         if not matching_files:
+             raise FileNotFoundError(
+                 f"No files found matching the pattern '{self.filename}'."
+             )
+
+         # Load the dataset
+         with dask.config.set(**{"array.slicing.split_large_chunks": False}):
+             # Define the chunk sizes
+             chunks = {
+                 self.dim_names["latitude"]: -1,
+                 self.dim_names["longitude"]: -1,
+             }
+
+             ds = xr.open_mfdataset(
+                 self.filename,
+                 combine="nested",
+                 coords="minimal",
+                 compat="override",
+                 chunks=chunks,
+                 engine="netcdf4",
+             )
+         if "time" not in self.dim_names:
+             if "time" in ds.dims:
+                 self.dim_names["time"] = "time"
+             else:
+                 if "month" in ds.dims:
+                     self.dim_names["time"] = "month"
+                 else:
+                     ds = ds.expand_dims({"time": 1})
+                     self.dim_names["time"] = "time"
+
+         return ds
+
+     def add_time_info(self, ds: xr.Dataset) -> xr.Dataset:
+         """
+         Adds time information to the dataset based on the climatology flag and dimension names.
+
+         This method processes the dataset to include time information according to the climatology
+         setting. If the dataset represents climatology data and the time dimension is labeled as
+         "month", it assigns dates to the dataset based on a monthly climatology. Additionally, it
+         handles dimension name updates if necessary.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The input dataset to which time information will be added.
+
+         Returns
+         -------
+         xr.Dataset
+             The dataset with time information added, including adjustments for climatology and
+             dimension names.
+         """
+         time_dim = self.dim_names["time"]
+
+         if self.climatology and time_dim == "month":
+             ds = assign_dates_to_climatology(ds, time_dim)
+             # rename dimension
+             ds = ds.swap_dims({time_dim: "time"})
+             # Update dimension names
+             updated_dim_names = self.dim_names.copy()
+             updated_dim_names["time"] = "time"
+             object.__setattr__(self, "dim_names", updated_dim_names)
+
+         return ds
+
+
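add_time_info's month-to-time swap rests on xarray's swap_dims. assign_dates_to_climatology is internal to roms_tools.setup.utils, so the stand-in below only imitates its effect (attaching a time coordinate along the month dimension) to make the mechanics visible:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"sst": ("month", np.zeros(12))}, coords={"month": np.arange(1, 13)})

    # Stand-in for assign_dates_to_climatology: a time coordinate along "month"
    ds = ds.assign_coords(time=("month", np.arange(12, dtype="float64")))
    ds = ds.swap_dims({"month": "time"})
    print(ds["sst"].dims)  # ('time',)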
+ @dataclass(frozen=True, kw_only=True)
+ class CESMBGCDataset(CESMDataset):
+     """
+     Represents CESM BGC data on original grid.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the data files. Can contain wildcards.
+     start_time : Optional[datetime], optional
+         The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+     end_time : Optional[datetime], optional
+         The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+         or no filtering is applied if start_time is not provided.
+     var_names: Dict[str, str], optional
+         Dictionary of variable names that are required in the dataset.
+     dim_names: Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+     climatology : bool
+         Indicates whether the dataset is climatological. Defaults to True.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the CESM BGC data on its original grid.
+     """
+
+     var_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "PO4": "PO4",
+             "NO3": "NO3",
+             "SiO3": "SiO3",
+             "NH4": "NH4",
+             "Fe": "Fe",
+             "Lig": "Lig",
+             "O2": "O2",
+             "DIC": "DIC",
+             "DIC_ALT_CO2": "DIC_ALT_CO2",
+             "ALK": "ALK",
+             "ALK_ALT_CO2": "ALK_ALT_CO2",
+             "DOC": "DOC",
+             "DON": "DON",
+             "DOP": "DOP",
+             "DOPr": "DOPr",
+             "DONr": "DONr",
+             "DOCr": "DOCr",
+             "spChl": "spChl",
+             "spC": "spC",
+             "spP": "spP",
+             "spFe": "spFe",
+             "diatChl": "diatChl",
+             "diatC": "diatC",
+             "diatP": "diatP",
+             "diatFe": "diatFe",
+             "diatSi": "diatSi",
+             "diazChl": "diazChl",
+             "diazC": "diazC",
+             "diazP": "diazP",
+             "diazFe": "diazFe",
+             "spCaCO3": "spCaCO3",
+             "zooC": "zooC",
+         }
+     )
+
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "lon",
+             "latitude": "lat",
+             "depth": "z_t",
+         }
+     )
+
+     climatology: Optional[bool] = True
+
+     def post_process(self):
+         """
+         Processes and converts CESM data values as follows:
+         - Convert depth values from cm to m.
+         """
+
+         if self.dim_names["depth"] == "z_t":
+             # Fill variables that only have data in upper 150m with NaNs below
+             if (
+                 "z_t_150m" in self.ds.dims
+                 and np.equal(self.ds.z_t[:15].values, self.ds.z_t_150m.values).all()
+             ):
+                 for var in self.var_names:
+                     if "z_t_150m" in self.ds[var].dims:
+                         self.ds[var] = self.ds[var].rename({"z_t_150m": "z_t"})
+                         self.ds[var] = self.ds[var].chunk({"z_t": -1})
+             # Convert depth from cm to m
+             ds = self.ds.assign_coords({"depth": self.ds["z_t"] / 100})
+             ds["depth"].attrs["long_name"] = "Depth"
+             ds["depth"].attrs["units"] = "m"
+             ds = ds.swap_dims({"z_t": "depth"})
+             if "z_t" in ds:
+                 ds = ds.drop_vars("z_t")
+             if "z_t_150m" in ds:
+                 ds = ds.drop_vars("z_t_150m")
+             # update dataset
+             object.__setattr__(self, "ds", ds)
+
+             # Update dim_names with "depth": "depth" key-value pair
+             updated_dim_names = self.dim_names.copy()
+             updated_dim_names["depth"] = "depth"
+             object.__setattr__(self, "dim_names", updated_dim_names)
+
+
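The z_t handling in post_process boils down to a unit conversion plus a dimension swap. In isolation, with an illustrative three-level column:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"O2": ("z_t", np.zeros(3))}, coords={"z_t": [500.0, 1500.0, 2500.0]})  # cm

    ds = ds.assign_coords(depth=ds["z_t"] / 100)  # cm -> m
    ds = ds.swap_dims({"z_t": "depth"}).drop_vars("z_t")
    print(ds["depth"].values)  # [ 5. 15. 25.]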
+ @dataclass(frozen=True, kw_only=True)
+ class CESMBGCSurfaceForcingDataset(CESMDataset):
+     """
+     Represents CESM BGC surface forcing data on original grid.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the data files. Can contain wildcards.
+     start_time : Optional[datetime], optional
+         The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+     end_time : Optional[datetime], optional
+         The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+         or no filtering is applied if start_time is not provided.
+     var_names: Dict[str, str], optional
+         Dictionary of variable names that are required in the dataset.
+     dim_names: Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+     climatology : bool
+         Indicates whether the dataset is climatological. Defaults to False.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the CESM BGC surface forcing data on its original grid.
+     """
+
+     var_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "pco2_air": "pCO2SURF",
+             "pco2_air_alt": "pCO2SURF",
+             "iron": "IRON_FLUX",
+             "dust": "dust_FLUX_IN",
+             "nox": "NOx_FLUX",
+             "nhy": "NHy_FLUX",
+         }
+     )
+
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "lon",
+             "latitude": "lat",
+         }
+     )
+
+     climatology: Optional[bool] = False
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class ERA5Dataset(Dataset):
+     """
+     Represents ERA5 data on original grid.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the data files. Can contain wildcards.
+     start_time : Optional[datetime], optional
+         The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+     end_time : Optional[datetime], optional
+         The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+         or no filtering is applied if start_time is not provided.
+     var_names: Dict[str, str], optional
+         Dictionary of variable names that are required in the dataset.
+     dim_names: Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+     climatology : bool
+         Indicates whether the dataset is climatological. Defaults to False.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the ERA5 data on its original grid.
+     """
+
+     var_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "uwnd": "u10",
+             "vwnd": "v10",
+             "swrad": "ssr",
+             "lwrad": "strd",
+             "Tair": "t2m",
+             "d2m": "d2m",
+             "rain": "tp",
+             "mask": "sst",
+         }
+     )
+
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "longitude",
+             "latitude": "latitude",
+             "time": "time",
+         }
+     )
+
+     climatology: Optional[bool] = False
+
+     def post_process(self):
+         """
+         Processes and converts ERA5 data values as follows:
+         - Convert radiation values from J/m^2 to W/m^2.
+         - Convert rainfall from meters to cm/day.
+         - Convert temperature from Kelvin to Celsius.
+         - Compute relative humidity if not present, convert to absolute humidity.
+         """
+         # Translate radiation to fluxes. ERA5 stores values integrated over 1 hour.
+         # Convert radiation from J/m^2 to W/m^2
+         self.ds[self.var_names["swrad"]] /= 3600
+         self.ds[self.var_names["lwrad"]] /= 3600
+         self.ds[self.var_names["swrad"]].attrs["units"] = "W/m^2"
+         self.ds[self.var_names["lwrad"]].attrs["units"] = "W/m^2"
+         # Convert rainfall from m to cm/day
+         self.ds[self.var_names["rain"]] *= 100 * 24
+
+         # Convert temperature from Kelvin to Celsius
+         self.ds[self.var_names["Tair"]] -= 273.15
+         self.ds[self.var_names["d2m"]] -= 273.15
+         self.ds[self.var_names["Tair"]].attrs["units"] = "degrees C"
+         self.ds[self.var_names["d2m"]].attrs["units"] = "degrees C"
+
+         # Compute relative humidity if not present
+         if "qair" not in self.ds.data_vars:
+             qair = np.exp(
+                 (17.625 * self.ds[self.var_names["d2m"]])
+                 / (243.04 + self.ds[self.var_names["d2m"]])
+             ) / np.exp(
+                 (17.625 * self.ds[self.var_names["Tair"]])
+                 / (243.04 + self.ds[self.var_names["Tair"]])
+             )
+             # Convert relative to absolute humidity
+             patm = 1010.0
+             cff = (
+                 (1.0007 + 3.46e-6 * patm)
+                 * 6.1121
+                 * np.exp(
+                     17.502
+                     * self.ds[self.var_names["Tair"]]
+                     / (240.97 + self.ds[self.var_names["Tair"]])
+                 )
+             )
+             cff = cff * qair
+             self.ds["qair"] = 0.62197 * (cff / (patm - 0.378 * cff))
+             self.ds["qair"].attrs["long_name"] = "Absolute humidity at 2m"
+             self.ds["qair"].attrs["units"] = "kg/kg"
+
+             # Update var_names dictionary
+             var_names = {**self.var_names, "qair": "qair"}
+             object.__setattr__(self, "var_names", var_names)
+
+         if "mask" in self.var_names.keys():
+             mask = xr.where(self.ds[self.var_names["mask"]].isel(time=0).isnull(), 0, 1)
+
+             for var in self.ds.data_vars:
+                 self.ds[var] = xr.where(mask == 1, self.ds[var], np.nan)
+
+
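The humidity block first derives relative humidity from dew point and air temperature (a Magnus-type formula), then converts it to the mass-ratio humidity (kg/kg) stored as qair. A scalar spot-check of the same arithmetic, with illustrative temperatures:

    import numpy as np

    Tair, d2m = 20.0, 15.0  # deg C, after the Kelvin -> Celsius conversion above
    rh = np.exp((17.625 * d2m) / (243.04 + d2m)) / np.exp((17.625 * Tair) / (243.04 + Tair))

    patm = 1010.0  # hPa, fixed surface pressure assumed by the code
    cff = (1.0007 + 3.46e-6 * patm) * 6.1121 * np.exp(17.502 * Tair / (240.97 + Tair))
    cff = cff * rh
    qair = 0.62197 * (cff / (patm - 0.378 * cff))
    print(round(float(qair), 5))  # 0.01061 kg/kg at roughly 73% relative humidity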
+ @dataclass(frozen=True, kw_only=True)
+ class ERA5Correction(Dataset):
+     """
+     Global dataset to correct ERA5 radiation. The dataset contains multiplicative correction factors for the ERA5 shortwave radiation, obtained by comparing the COREv2 climatology to the ERA5 climatology.
+
+     Parameters
+     ----------
+     filename : str, optional
+         The path to the correction files. Defaults to download_correction_data('SSR_correction.nc').
+     var_names: Dict[str, str], optional
+         Dictionary of variable names that are required in the dataset.
+         Defaults to {"swr_corr": "ssr_corr"}.
+     dim_names: Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+         Defaults to {"longitude": "longitude", "latitude": "latitude", "time": "time"}.
+     climatology : bool, optional
+         Indicates if the correction data is a climatology. Defaults to True.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The loaded xarray Dataset containing the correction data.
+     """
+
+     filename: str = field(
+         default_factory=lambda: download_correction_data("SSR_correction.nc")
+     )
+     var_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "swr_corr": "ssr_corr",  # multiplicative correction factor for ERA5 shortwave radiation
+         }
+     )
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "longitude",
+             "latitude": "latitude",
+             "time": "time",
+         }
+     )
+     climatology: Optional[bool] = True
+
+     ds: xr.Dataset = field(init=False, repr=False)
+
+     def __post_init__(self):
+
+         if not self.climatology:
+             raise NotImplementedError(
+                 "Correction data must be a climatology. Set climatology to True."
+             )
+
+         super().__post_init__()
+
+     def choose_subdomain(self, coords, straddle: bool):
+         """
+         Converts longitude values in the dataset if necessary and selects a subdomain based on the specified coordinates.
+
+         This method converts longitude values between different ranges if required and then extracts a subset of the
+         dataset according to the given coordinates. It updates the dataset in place to reflect the selected subdomain.
+
+         Parameters
+         ----------
+         coords : dict
+             A dictionary specifying the target coordinates for selecting the subdomain. Keys should correspond to the
+             dimension names of the dataset (e.g., latitude and longitude), and values should be the desired ranges or
+             specific coordinate values.
+         straddle : bool
+             If True, assumes that target longitudes are in the range [-180, 180]. If False, assumes longitudes are in the
+             range [0, 360]. This parameter determines how longitude values are converted if necessary.
+
+         Raises
+         ------
+         ValueError
+             If the specified subdomain does not fully contain the specified latitude or longitude values. This can occur
+             if the dataset does not cover the full range of provided coordinates.
+
+         Notes
+         -----
+         - The dataset (`self.ds`) is updated in place to reflect the chosen subdomain.
+         """
+
+         lon = self.ds[self.dim_names["longitude"]]
+
+         if not self.is_global:
+             if lon.min().values < 0 and not straddle:
+                 # Convert from [-180, 180] to [0, 360]
+                 self.ds[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
+
+             if lon.max().values > 180 and straddle:
+                 # Convert from [0, 360] to [-180, 180]
+                 self.ds[self.dim_names["longitude"]] = xr.where(
+                     lon > 180, lon - 360, lon
+                 )
+
+         # Select the subdomain based on the specified latitude and longitude ranges
+         subdomain = self.ds.sel(**coords)
+
+         # Check if the selected subdomain contains the specified latitude and longitude values
+         if not subdomain[self.dim_names["latitude"]].equals(
+             coords[self.dim_names["latitude"]]
+         ):
+             raise ValueError(
+                 "The correction dataset does not contain all specified latitude values."
+             )
+         if not subdomain[self.dim_names["longitude"]].equals(
+             coords[self.dim_names["longitude"]]
+         ):
+             raise ValueError(
+                 "The correction dataset does not contain all specified longitude values."
+             )
+         object.__setattr__(self, "ds", subdomain)
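
Finally, a hedged end-to-end sketch of the correction workflow (coordinate ranges illustrative; module path inferred from the imports above):

    from roms_tools.setup.datasets import ERA5Correction

    corr = ERA5Correction()  # fetches SSR_correction.nc via download_correction_data

    # Subset to the target grid's coordinates; they must match exactly, or ValueError is raised
    coords = {
        "latitude": corr.ds["latitude"].sel(latitude=slice(10, 20)),
        "longitude": corr.ds["longitude"].sel(longitude=slice(100, 120)),
    }
    corr.choose_subdomain(coords, straddle=False)
    print(corr.ds["ssr_corr"])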