roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,457 @@
+ import pooch
+ import xarray as xr
+ from dataclasses import dataclass, field
+ import glob
+ from datetime import datetime, timedelta
+ import numpy as np
+ from typing import Dict, Optional, List
+ import dask
+
+ # Create a Pooch object to manage the global topography data
+ pup_data = pooch.create(
+     # Use the default cache folder for the operating system
+     path=pooch.os_cache("roms-tools"),
+     base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
+     # The registry specifies the files that can be fetched
+     registry={
+         "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
+     },
+ )
+
+ # Create a Pooch object to manage the test data
+ pup_test_data = pooch.create(
+     # Use the default cache folder for the operating system
+     path=pooch.os_cache("roms-tools"),
+     base_url="https://github.com/CWorthy-ocean/roms-tools-test-data/raw/main/",
+     # The registry specifies the files that can be fetched
+     registry={
+         "GLORYS_test_data.nc": "648f88ec29c433bcf65f257c1fb9497bd3d5d3880640186336b10ed54f7129d2",
+         "ERA5_regional_test_data.nc": "bd12ce3b562fbea2a80a3b79ba74c724294043c28dc98ae092ad816d74eac794",
+         "ERA5_global_test_data.nc": "8ed177ab64c02caf509b9fb121cf6713f286cc603b1f302f15f3f4eb0c21dc4f",
+         "TPXO_global_test_data.nc": "457bfe87a7b247ec6e04e3c7d3e741ccf223020c41593f8ae33a14f2b5255e60",
+         "TPXO_regional_test_data.nc": "11739245e2286d9c9d342dce5221e6435d2072b50028bef2e86a30287b3b4032",
+     },
+ )
+
+
+ def fetch_topo(topography_source: str) -> xr.Dataset:
+     """
+     Load the global topography data as an xarray Dataset.
+
+     Parameters
+     ----------
+     topography_source : str
+         The source of the topography data to be loaded. Available options:
+         - "ETOPO5"
+
+     Returns
+     -------
+     xr.Dataset
+         The global topography data as an xarray Dataset.
+     """
+     # Mapping from user-specified topography options to corresponding filenames in the registry
+     topo_dict = {"ETOPO5": "etopo5.nc"}
+
+     # Fetch the file using Pooch, downloading if necessary
+     fname = pup_data.fetch(topo_dict[topography_source])
+
+     # Load the dataset using xarray and return it
+     ds = xr.open_dataset(fname)
+     return ds
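+
+ # Example usage (a minimal sketch; requires network access on the first call,
+ # after which Pooch serves the cached, checksum-verified copy):
+ #
+ # >>> topo = fetch_topo("ETOPO5")
+ # >>> list(topo.data_vars)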
+
+
+ def download_test_data(filename: str) -> str:
+     """
+     Download the test data file.
+
+     Parameters
+     ----------
+     filename : str
+         The name of the test data file to be downloaded. Available options:
+         - "GLORYS_test_data.nc"
+         - "ERA5_regional_test_data.nc"
+         - "ERA5_global_test_data.nc"
+         - "TPXO_global_test_data.nc"
+         - "TPXO_regional_test_data.nc"
+
+     Returns
+     -------
+     str
+         The path to the downloaded test data file.
+     """
+     # Fetch the file using Pooch, downloading if necessary
+     fname = pup_test_data.fetch(filename)
+
+     return fname
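+
+ # Example usage (a minimal sketch): download a registered test file and open it.
+ #
+ # >>> path = download_test_data("GLORYS_test_data.nc")
+ # >>> ds = xr.open_dataset(path)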
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class Dataset:
+     """
+     Represents forcing data on its original grid.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the data files. Can contain wildcards.
+     start_time : Optional[datetime], optional
+         The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+     end_time : Optional[datetime], optional
+         The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+         or no filtering is applied if start_time is not provided.
+     var_names : List[str]
+         List of variable names that are required in the dataset.
+     dim_names : Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the forcing data on its original grid.
+
+     Examples
+     --------
+     >>> dataset = Dataset(
+     ...     filename="data.nc",
+     ...     var_names=["temp", "salt"],
+     ...     start_time=datetime(2022, 1, 1),
+     ...     end_time=datetime(2022, 12, 31),
+     ... )
+     >>> print(dataset.ds)
+     <xarray.Dataset>
+     Dimensions: ...
+     """
+
+     filename: str
+     start_time: Optional[datetime] = None
+     end_time: Optional[datetime] = None
+     var_names: List[str]
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "longitude",
+             "latitude": "latitude",
+             "time": "time",
+         }
+     )
+
+     ds: xr.Dataset = field(init=False, repr=False)
+
+     def __post_init__(self):
+         ds = self.load_data()
+
+         # Select relevant times
+         if "time" in self.dim_names and self.start_time is not None:
+             ds = self.select_relevant_times(ds)
+
+         # Select relevant fields
+         ds = self.select_relevant_fields(ds)
+
+         # Make sure that latitude is ascending
+         diff = np.diff(ds[self.dim_names["latitude"]])
+         if np.all(diff < 0):
+             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
+
+         # Check whether the data covers the entire globe
+         is_global = self.check_if_global(ds)
+
+         if is_global:
+             ds = self.concatenate_longitudes(ds)
+
+         object.__setattr__(self, "ds", ds)
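+
+     # Note: because the dataclass is frozen, __post_init__ cannot assign to
+     # self.ds directly and instead stores the result via object.__setattr__.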
+
+     def load_data(self) -> xr.Dataset:
+         """
+         Load dataset from the specified file.
+
+         Returns
+         -------
+         ds : xr.Dataset
+             The loaded xarray Dataset containing the forcing data.
+
+         Raises
+         ------
+         FileNotFoundError
+             If no files match the specified pattern.
+         """
+         # Check if the file exists
+         matching_files = glob.glob(self.filename)
+         if not matching_files:
+             raise FileNotFoundError(
+                 f"No files found matching the pattern '{self.filename}'."
+             )
+
+         # Load the dataset
+         with dask.config.set(**{"array.slicing.split_large_chunks": False}):
+             # Define the chunk sizes
+             chunks = {
+                 self.dim_names["latitude"]: -1,
+                 self.dim_names["longitude"]: -1,
+             }
+             if "depth" in self.dim_names.keys():
+                 chunks[self.dim_names["depth"]] = -1
+             if "time" in self.dim_names.keys():
+                 chunks[self.dim_names["time"]] = 1
+
+                 ds = xr.open_mfdataset(
+                     self.filename,
+                     combine="nested",
+                     concat_dim=self.dim_names["time"],
+                     coords="minimal",
+                     compat="override",
+                     chunks=chunks,
+                     engine="netcdf4",
+                 )
+             else:
+                 ds = xr.open_dataset(
+                     self.filename,
+                     chunks=chunks,
+                 )
+
+         return ds
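+
+     # Chunking sketch (illustrative; "era5_*.nc" and "t2m" are placeholder
+     # names): with the default dim_names, each time step becomes one dask chunk
+     # while the full horizontal grid stays in a single chunk.
+     #
+     # >>> d = Dataset(filename="era5_*.nc", var_names=["t2m"])
+     # >>> d.ds["t2m"].chunks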
+
+     def select_relevant_fields(self, ds) -> xr.Dataset:
+         """
+         Selects and returns a subset of the dataset containing only the variables specified in `self.var_names`.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The input dataset from which variables will be selected.
+
+         Returns
+         -------
+         xr.Dataset
+             A dataset containing only the variables specified in `self.var_names`.
+
+         Raises
+         ------
+         ValueError
+             If `ds` does not contain all variables listed in `self.var_names`.
+         """
+         missing_vars = [var for var in self.var_names if var not in ds.data_vars]
+         if missing_vars:
+             raise ValueError(
+                 f"Dataset does not contain all required variables. The following variables are missing: {missing_vars}"
+             )
+
+         for var in ds.data_vars:
+             if var not in self.var_names:
+                 ds = ds.drop_vars(var)
+
+         return ds
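+
+     # Example (illustrative; variable names are placeholders): with
+     # var_names=["temp", "salt"], extra variables such as "u" are dropped,
+     # and a dataset lacking "salt" raises ValueError.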
+
+     def select_relevant_times(self, ds) -> xr.Dataset:
+         """
+         Selects and returns the subset of the dataset corresponding to the specified time range.
+
+         This function filters the dataset to include only the data points within the specified
+         time range, defined by `self.start_time` and `self.end_time`. If `self.end_time` is not
+         provided, it defaults to one day after `self.start_time`.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The input dataset to be filtered.
+
+         Returns
+         -------
+         xr.Dataset
+             A dataset containing only the data points within the specified time range.
+
+         Raises
+         ------
+         ValueError
+             If no matching times are found, or if `end_time` is not provided
+             and more than one time matches `start_time`.
+         """
+         time_dim = self.dim_names["time"]
+
+         if not self.end_time:
+             end_time = self.start_time + timedelta(days=1)
+         else:
+             end_time = self.end_time
+
+         times = (np.datetime64(self.start_time) <= ds[time_dim]) & (
+             ds[time_dim] < np.datetime64(end_time)
+         )
+         ds = ds.where(times, drop=True)
+
+         if not ds.sizes[time_dim]:
+             raise ValueError("No matching times found.")
+
+         if not self.end_time:
+             if ds.sizes[time_dim] != 1:
+                 found_times = ds.sizes[time_dim]
+                 raise ValueError(
+                     f"There must be exactly one time matching the start_time. Found {found_times} matching times."
+                 )
+
+         return ds
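+
+     # Note: the selection is half-open, start_time <= t < end_time. For daily
+     # data starting 2022-01-01, start_time=datetime(2022, 1, 1) with no end_time
+     # selects exactly the single 2022-01-01 record.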
+
+     def check_if_global(self, ds) -> bool:
+         """
+         Checks if the dataset covers the entire globe in the longitude dimension.
+
+         This function calculates the mean difference between consecutive longitude values.
+         It then checks whether the wrap-around gap between the first and last longitude
+         values (taken modulo 360 degrees) is close to this mean difference, within a
+         specified tolerance. If it is, the dataset is considered to cover the entire
+         globe in the longitude dimension.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The dataset to check.
+
+         Returns
+         -------
+         bool
+             True if the dataset covers the entire globe in the longitude dimension, False otherwise.
+         """
+         dlon_mean = (
+             ds[self.dim_names["longitude"]].diff(dim=self.dim_names["longitude"]).mean()
+         )
+         dlon = (
+             ds[self.dim_names["longitude"]][0] - ds[self.dim_names["longitude"]][-1]
+         ) % 360.0
+         is_global = np.isclose(dlon, dlon_mean, rtol=0.0, atol=1e-3)
+
+         return is_global
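+
+     # Worked example: for a 1-degree global grid with lon = 0, 1, ..., 359 the
+     # mean spacing is 1.0 and the wrap-around gap is (0 - 359) % 360 = 1.0, so
+     # is_global is True; a regional grid with lon = 0, ..., 100 gives
+     # (0 - 100) % 360 = 260.0, hence False.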
+
+     def concatenate_longitudes(self, ds):
+         """
+         Concatenates the dataset fields three times: with longitudes shifted by -360, original longitudes, and shifted by +360.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The dataset whose fields are to be concatenated.
+
+         Returns
+         -------
+         xr.Dataset
+             The concatenated dataset, with the longitude dimension extended.
+
+         Notes
+         -----
+         Concatenating three times may be overkill in most situations, but it is safe. Alternatively, we could refactor
+         to figure out whether concatenating on the lower end, upper end, or at all is needed.
+         """
+         ds_concatenated = xr.Dataset()
+
+         lon = ds[self.dim_names["longitude"]]
+         lon_minus360 = lon - 360
+         lon_plus360 = lon + 360
+         lon_concatenated = xr.concat(
+             [lon_minus360, lon, lon_plus360], dim=self.dim_names["longitude"]
+         )
+
+         ds_concatenated[self.dim_names["longitude"]] = lon_concatenated
+
+         for var in self.var_names:
+             if self.dim_names["longitude"] in ds[var].dims:
+                 field = ds[var]
+                 field_concatenated = xr.concat(
+                     [field, field, field], dim=self.dim_names["longitude"]
+                 ).chunk({self.dim_names["longitude"]: -1})
+                 field_concatenated[self.dim_names["longitude"]] = lon_concatenated
+                 ds_concatenated[var] = field_concatenated
+             else:
+                 ds_concatenated[var] = ds[var]
+
+         return ds_concatenated
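+
+     # Example: a global grid with lon = 0, ..., 359 becomes lon = -360, ..., 719
+     # after concatenation, so target longitudes in either [-180, 180] or
+     # [0, 360] can be selected without extra wrap-around handling.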
+
+     def choose_subdomain(
+         self, latitude_range, longitude_range, margin, straddle, return_subdomain=False
+     ):
+         """
+         Selects a subdomain from the given xarray Dataset based on latitude and longitude ranges,
+         extending the selection by the specified margin. Handles the conversion of longitude values
+         in the dataset from one range to another.
+
+         Parameters
+         ----------
+         latitude_range : tuple
+             A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
+         longitude_range : tuple
+             A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
+         margin : float
+             Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
+         straddle : bool
+             If True, target longitudes are expected in the range [-180, 180].
+             If False, target longitudes are expected in the range [0, 360].
+         return_subdomain : bool, optional
+             If True, returns the subset of the original dataset. If False, assigns it to self.ds.
+             Default is False.
+
+         Returns
+         -------
+         xr.Dataset
+             The subset of the original dataset representing the chosen subdomain, including an extended area
+             to cover one extra grid point beyond the specified ranges if return_subdomain is True.
+             Otherwise, returns None.
+
+         Raises
+         ------
+         ValueError
+             If the selected latitude or longitude range does not intersect with the dataset.
+         """
+         lat_min, lat_max = latitude_range
+         lon_min, lon_max = longitude_range
+
+         lon = self.ds[self.dim_names["longitude"]]
+         # Adjust longitude range if needed to match the expected range
+         if not straddle:
+             if lon.min() < -180:
+                 if lon_max + margin > 0:
+                     lon_min -= 360
+                     lon_max -= 360
+             elif lon.min() < 0:
+                 if lon_max + margin > 180:
+                     lon_min -= 360
+                     lon_max -= 360
+
+         if straddle:
+             if lon.max() > 360:
+                 if lon_min - margin < 180:
+                     lon_min += 360
+                     lon_max += 360
+             elif lon.max() > 180:
+                 if lon_min - margin < 0:
+                     lon_min += 360
+                     lon_max += 360
+
+         # Select the subdomain
+         subdomain = self.ds.sel(
+             **{
+                 self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
+                 self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
+             }
+         )
+
+         # Check if the selected subdomain has zero dimensions in latitude or longitude
+         if subdomain[self.dim_names["latitude"]].size == 0:
+             raise ValueError("Selected latitude range does not intersect with dataset.")
+
+         if subdomain[self.dim_names["longitude"]].size == 0:
+             raise ValueError(
+                 "Selected longitude range does not intersect with dataset."
+             )
+
+         # Adjust longitudes to expected range if needed
+         lon = subdomain[self.dim_names["longitude"]]
+         if straddle:
+             subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
+         else:
+             subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
+
+         if return_subdomain:
+             return subdomain
+         else:
+             object.__setattr__(self, "ds", subdomain)
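+
+     # Example usage (a minimal sketch; ranges are placeholders): select 10-30N,
+     # 120-150E with a 1-degree margin, keeping target longitudes in [0, 360]:
+     #
+     # >>> d.choose_subdomain(
+     # ...     latitude_range=(10, 30),
+     # ...     longitude_range=(120, 150),
+     # ...     margin=1,
+     # ...     straddle=False,
+     # ... )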
+
+     def convert_to_negative_depth(self):
+         """
+         Converts the depth values in the dataset to negative if they are non-negative.
+
+         This method checks the values in the depth dimension of the dataset (`self.ds[self.dim_names["depth"]]`).
+         If all values are greater than or equal to zero, it negates them and updates the dataset accordingly.
+         """
+         depth = self.ds[self.dim_names["depth"]]
+
+         if (depth >= 0).all():
+             self.ds[self.dim_names["depth"]] = -depth
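+
+
+ # End-to-end sketch (illustrative; the variable and dimension names passed here
+ # are assumptions, not taken from the test file itself):
+ #
+ # >>> glorys = Dataset(
+ # ...     filename=download_test_data("GLORYS_test_data.nc"),
+ # ...     var_names=["thetao", "so", "zos"],
+ # ...     dim_names={
+ # ...         "longitude": "longitude",
+ # ...         "latitude": "latitude",
+ # ...         "time": "time",
+ # ...         "depth": "depth",
+ # ...     },
+ # ... )
+ # >>> glorys.convert_to_negative_depth()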