roms-tools 0.20__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/setup/tides.py CHANGED
@@ -3,165 +3,21 @@ import xarray as xr
 import numpy as np
 import yaml
 import importlib.metadata
+from typing import Dict, Union
 
 from dataclasses import dataclass, field, asdict
 from roms_tools.setup.grid import Grid
 from roms_tools.setup.plot import _plot
 from roms_tools.setup.fill import fill_and_interpolate
-from roms_tools.setup.datasets import Dataset
+from roms_tools.setup.datasets import TPXODataset
 from roms_tools.setup.utils import (
     nan_check,
     interpolate_from_rho_to_u,
     interpolate_from_rho_to_v,
 )
-from typing import Dict, List
 import matplotlib.pyplot as plt
 
 
-@dataclass(frozen=True, kw_only=True)
-class TPXO(Dataset):
-    """
-    Represents tidal data on original grid.
-
-    Parameters
-    ----------
-    filename : str
-        The path to the TPXO dataset.
-    var_names : List[str], optional
-        List of variable names that are required in the dataset. Defaults to
-        ["h_Re", "h_Im", "sal_Re", "sal_Im", "u_Re", "u_Im", "v_Re", "v_Im"].
-    dim_names: Dict[str, str], optional
-        Dictionary specifying the names of dimensions in the dataset. Defaults to
-        {"longitude": "ny", "latitude": "nx"}.
-
-    Attributes
-    ----------
-    ds : xr.Dataset
-        The xarray Dataset containing TPXO tidal model data.
-    """
-
-    filename: str
-    var_names: List[str] = field(
-        default_factory=lambda: [
-            "h_Re",
-            "h_Im",
-            "sal_Re",
-            "sal_Im",
-            "u_Re",
-            "u_Im",
-            "v_Re",
-            "v_Im",
-            "depth",
-        ]
-    )
-    dim_names: Dict[str, str] = field(
-        default_factory=lambda: {"longitude": "ny", "latitude": "nx", "ntides": "nc"}
-    )
-    ds: xr.Dataset = field(init=False, repr=False)
-
-    def __post_init__(self):
-        # Perform any necessary dataset initialization or modifications here
-        ds = super().load_data()
-
-        # Clean up dataset
-        ds = ds.assign_coords(
-            {
-                "omega": ds["omega"],
-                "nx": ds["lon_r"].isel(
-                    ny=0
-                ),  # lon_r is constant along ny, i.e., is only a function of nx
-                "ny": ds["lat_r"].isel(
-                    nx=0
-                ),  # lat_r is constant along nx, i.e., is only a function of ny
-            }
-        )
-        ds = ds.rename({"nx": "longitude", "ny": "latitude"})
-
-        object.__setattr__(
-            self,
-            "dim_names",
-            {
-                "latitude": "latitude",
-                "longitude": "longitude",
-                "ntides": self.dim_names["ntides"],
-            },
-        )
-        # Select relevant fields
-        ds = super().select_relevant_fields(ds)
-
-        # Check whether the data covers the entire globe
-        is_global = self.check_if_global(ds)
-
-        if is_global:
-            ds = self.concatenate_longitudes(ds)
-
-        object.__setattr__(self, "ds", ds)
-
-    def check_number_constituents(self, ntides: int):
-        """
-        Checks if the number of constituents in the dataset is at least `ntides`.
-
-        Parameters
-        ----------
-        ntides : int
-            The required number of tidal constituents.
-
-        Raises
-        ------
-        ValueError
-            If the number of constituents in the dataset is less than `ntides`.
-        """
-        if len(self.ds[self.dim_names["ntides"]]) < ntides:
-            raise ValueError(
-                f"The dataset contains fewer than {ntides} tidal constituents."
-            )
-
-    def get_corrected_tides(self, model_reference_date, allan_factor):
-        # Get equilibrium tides
-        tpc = compute_equilibrium_tide(self.ds["longitude"], self.ds["latitude"]).isel(
-            nc=self.ds.nc
-        )
-        # Correct for SAL
-        tsc = allan_factor * (self.ds["sal_Re"] + 1j * self.ds["sal_Im"])
-        tpc = tpc - tsc
-
-        # Elevations and transports
-        thc = self.ds["h_Re"] + 1j * self.ds["h_Im"]
-        tuc = self.ds["u_Re"] + 1j * self.ds["u_Im"]
-        tvc = self.ds["v_Re"] + 1j * self.ds["v_Im"]
-
-        # Apply correction for phases and amplitudes
-        pf, pu, aa = egbert_correction(model_reference_date)
-        pf = pf.isel(nc=self.ds.nc)
-        pu = pu.isel(nc=self.ds.nc)
-        aa = aa.isel(nc=self.ds.nc)
-
-        tpxo_reference_date = datetime(1992, 1, 1)
-        dt = (model_reference_date - tpxo_reference_date).days * 3600 * 24
-
-        thc = pf * thc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-        tuc = pf * tuc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-        tvc = pf * tvc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-        tpc = pf * tpc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-
-        tides = {
-            "ssh_Re": thc.real,
-            "ssh_Im": thc.imag,
-            "u_Re": tuc.real,
-            "u_Im": tuc.imag,
-            "v_Re": tvc.real,
-            "v_Im": tvc.imag,
-            "pot_Re": tpc.real,
-            "pot_Im": tpc.imag,
-            "omega": self.ds["omega"],
-        }
-
-        for k in tides.keys():
-            tides[k] = tides[k].rename({"nc": "ntides"})
-
-        return tides
-
-
 @dataclass(frozen=True, kw_only=True)
 class TidalForcing:
     """
@@ -171,16 +27,16 @@ class TidalForcing:
     ----------
     grid : Grid
         The grid object representing the ROMS grid associated with the tidal forcing data.
-    filename: str
-        The path to the native tidal dataset.
+    source : Dict[str, Union[str, None]]
+        Dictionary specifying the source of the tidal data:
+        - "name" (str): Name of the data source (e.g., "TPXO").
+        - "path" (str): Path to the tidal data file. Can contain wildcards.
     ntides : int, optional
         Number of constituents to consider. Maximum number is 14. Default is 10.
-    model_reference_date : datetime, optional
-        The reference date for the ROMS simulation. Default is datetime(2000, 1, 1).
-    source : str, optional
-        The source of the tidal data. Default is "TPXO".
     allan_factor : float, optional
         The Allan factor used in tidal model computation. Default is 2.0.
+    model_reference_date : datetime, optional
+        The reference date for the ROMS simulation. Default is datetime(2000, 1, 1).
 
     Attributes
     ----------
@@ -189,27 +45,31 @@ class TidalForcing:
 
     Examples
     --------
-    >>> grid = Grid(...)
-    >>> tidal_forcing = TidalForcing(grid)
-    >>> print(tidal_forcing.ds)
+    >>> tidal_forcing = TidalForcing(
+    ...     grid=grid, source={"name": "TPXO", "path": "tpxo_data.nc"}
+    ... )
     """
 
     grid: Grid
-    filename: str
+    source: Dict[str, Union[str, None]]
     ntides: int = 10
-    model_reference_date: datetime = datetime(2000, 1, 1)
-    source: str = "TPXO"
     allan_factor: float = 2.0
+    model_reference_date: datetime = datetime(2000, 1, 1)
+
     ds: xr.Dataset = field(init=False, repr=False)
 
     def __post_init__(self):
-        if self.source == "TPXO":
-            data = TPXO(filename=self.filename)
+        if "name" not in self.source.keys():
+            raise ValueError("`source` must include a 'name'.")
+        if "path" not in self.source.keys():
+            raise ValueError("`source` must include a 'path'.")
+        if self.source["name"] == "TPXO":
+            data = TPXODataset(filename=self.source["path"])
         else:
-            raise ValueError('Only "TPXO" is a valid option for source.')
+            raise ValueError('Only "TPXO" is a valid option for source["name"].')
 
         data.check_number_constituents(self.ntides)
-        # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in lontitude away from Greenwich meridian
+        # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in longitude away from Greenwich meridian
         lon = self.grid.ds.lon_rho
        lat = self.grid.ds.lat_rho
         angle = self.grid.ds.angle
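
For orientation, a minimal usage sketch of the new dict-based `source` argument and its validation, following the updated docstring example (the `Grid(...)` construction is elided exactly as in the removed docstring, and "tpxo_data.nc" is a placeholder path):

from datetime import datetime
from roms_tools.setup.grid import Grid
from roms_tools.setup.tides import TidalForcing

grid = Grid(...)  # build a ROMS grid as before

# "name" and "path" are now both required keys of `source`.
tidal_forcing = TidalForcing(
    grid=grid,
    source={"name": "TPXO", "path": "tpxo_data.nc"},
    ntides=10,
    model_reference_date=datetime(2000, 1, 1),
)

# Omitting either key now fails fast, e.g.:
# TidalForcing(grid=grid, source={"name": "TPXO"})
# -> ValueError: `source` must include a 'path'.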
@@ -220,14 +80,9 @@ class TidalForcing:
             lon = xr.where(lon < 0, lon + 360, lon)
             straddle = False
 
-        # The following consists of two steps:
-        # Step 1: Choose subdomain of forcing data including safety margin for interpolation, and Step 2: Convert to the proper longitude range.
-        # We perform these two steps for two reasons:
-        # A) Since the horizontal dimensions consist of a single chunk, selecting a subdomain before interpolation is a lot more performant.
-        # B) Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2. Specifically, discontinuous longitudes
-        # can lead to artifacts in the interpolation process. Specifically, if the data is not global and has a gap,
-        # discontinuous longitudes could result in values that appear to come from a distant location instead of producing NaNs.
-        # These NaNs are important as they can be identified and handled appropriately by the nan_check function.
+        # Restrict data to relevant subdomain to achieve better performance and to avoid discontinuous longitudes introduced by converting
+        # to a different longitude range (+- 360 degrees). Discontinuous longitudes can lead to artifacts in the interpolation process that
+        # would not be detected by the nan_check function.
         data.choose_subdomain(
             latitude_range=[lat.min().values, lat.max().values],
             longitude_range=[lon.min().values, lon.max().values],
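
To illustrate the discontinuity the rewritten comment refers to (a toy sketch with made-up values, not package code): a domain straddling the Greenwich meridian is contiguous in [-180, 180) but jumps by roughly 360 degrees when expressed in [0, 360), and interpolating across such a jump can silently pull values from the far side of the dataset instead of producing the NaNs that `nan_check` relies on.

import numpy as np

lon = np.array([-5.0, -2.5, 0.0, 2.5])         # contiguous in [-180, 180)
lon_0_360 = np.where(lon < 0, lon + 360, lon)  # [355. , 357.5, 0. , 2.5]: discontinuous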
@@ -235,7 +90,7 @@ class TidalForcing:
             straddle=straddle,
         )
 
-        tides = data.get_corrected_tides(self.model_reference_date, self.allan_factor)
+        tides = self._get_corrected_tides(data)
 
         # select desired number of constituents
         for k in tides.keys():
@@ -326,7 +181,7 @@ class TidalForcing:
 
         ds.attrs["roms_tools_version"] = roms_tools_version
 
-        ds.attrs["source"] = self.source
+        ds.attrs["source"] = self.source["name"]
         ds.attrs["model_reference_date"] = str(self.model_reference_date)
         ds.attrs["allan_factor"] = self.allan_factor
 
@@ -430,10 +285,9 @@ class TidalForcing:
         # Extract tidal forcing data
         tidal_forcing_data = {
             "TidalForcing": {
-                "filename": self.filename,
+                "source": self.source,
                 "ntides": self.ntides,
                 "model_reference_date": self.model_reference_date.isoformat(),
-                "source": self.source,
                 "allan_factor": self.allan_factor,
             }
         }
@@ -494,6 +348,54 @@ class TidalForcing:
         # Create and return an instance of TidalForcing
         return cls(grid=grid, **tidal_forcing_params)
 
+    def _get_corrected_tides(self, data):
+
+        # Get equilibrium tides
+        tpc = compute_equilibrium_tide(
+            data.ds[data.dim_names["longitude"]], data.ds[data.dim_names["latitude"]]
+        )
+        tpc = tpc.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+        # Correct for SAL
+        tsc = self.allan_factor * (
+            data.ds[data.var_names["sal_Re"]] + 1j * data.ds[data.var_names["sal_Im"]]
+        )
+        tpc = tpc - tsc
+
+        # Elevations and transports
+        thc = data.ds[data.var_names["ssh_Re"]] + 1j * data.ds[data.var_names["ssh_Im"]]
+        tuc = data.ds[data.var_names["u_Re"]] + 1j * data.ds[data.var_names["u_Im"]]
+        tvc = data.ds[data.var_names["v_Re"]] + 1j * data.ds[data.var_names["v_Im"]]
+
+        # Apply correction for phases and amplitudes
+        pf, pu, aa = egbert_correction(self.model_reference_date)
+        pf = pf.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+        pu = pu.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+        aa = aa.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+
+        dt = (self.model_reference_date - data.reference_date).days * 3600 * 24
+
+        thc = pf * thc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+        tuc = pf * tuc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+        tvc = pf * tvc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+        tpc = pf * tpc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+
+        tides = {
+            "ssh_Re": thc.real,
+            "ssh_Im": thc.imag,
+            "u_Re": tuc.real,
+            "u_Im": tuc.imag,
+            "v_Re": tvc.real,
+            "v_Im": tvc.imag,
+            "pot_Re": tpc.real,
+            "pot_Im": tpc.imag,
+            "omega": data.ds["omega"],
+        }
+
+        for k in tides.keys():
+            tides[k] = tides[k].rename({data.dim_names["ntides"]: "ntides"})
+
+        return tides
+
 
 def modified_julian_days(year, month, day, hour=0):
     """
@@ -3,7 +3,7 @@ import numpy as np
 import gcm_filters
 from scipy.interpolate import RegularGridInterpolator
 from scipy.ndimage import label
-from roms_tools.setup.datasets import fetch_topo
+from roms_tools.setup.download import fetch_topo
 from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
 import warnings
 from itertools import count
roms_tools/setup/utils.py CHANGED
@@ -1,4 +1,8 @@
 import xarray as xr
+import numpy as np
+from typing import Union
+import pandas as pd
+import cftime
 
 
 def nan_check(field, mask) -> None:
@@ -160,3 +164,189 @@ def extrapolate_deepest_to_bottom(field: xr.DataArray, dim: str) -> xr.DataArray
     )
 
     return field_interpolated
+
+
+def assign_dates_to_climatology(ds: xr.Dataset, time_dim: str) -> xr.Dataset:
+    """
+    Assigns climatology dates to the dataset's time dimension.
+
+    This function updates the dataset's time coordinates to reflect climatological dates.
+    It defines fixed day increments for each month and assigns these to the specified time dimension.
+    The increments represent the cumulative days at mid-month for each month.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The xarray Dataset to which climatological dates will be assigned.
+    time_dim : str
+        The name of the time dimension in the dataset that will be updated with climatological dates.
+
+    Returns
+    -------
+    xr.Dataset
+        The updated xarray Dataset with climatological dates assigned to the specified time dimension.
+    """
+    # Define the days in each month and convert to timedelta
+    increments = [15, 30, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30]
+    days = np.cumsum(increments)
+    timedelta_ns = np.array(days, dtype="timedelta64[D]").astype("timedelta64[ns]")
+    time = xr.DataArray(timedelta_ns, dims=[time_dim])
+    ds = ds.assign_coords({"time": time})
+    return ds
+
+
+def interpolate_from_climatology(
+    field: Union[xr.DataArray, xr.Dataset],
+    time_dim_name: str,
+    time: Union[xr.DataArray, pd.DatetimeIndex],
+) -> Union[xr.DataArray, xr.Dataset]:
+    """
+    Interpolates the given field temporally based on the specified time points.
+
+    If `field` is an xarray.Dataset, this function applies the interpolation to all data variables in the dataset.
+
+    Parameters
+    ----------
+    field : xarray.DataArray or xarray.Dataset
+        The field data to be interpolated. Can be a single DataArray or a Dataset.
+    time_dim_name : str
+        The name of the dimension in `field` that represents time.
+    time : xarray.DataArray or pandas.DatetimeIndex
+        The target time points for interpolation.
+
+    Returns
+    -------
+    xarray.DataArray or xarray.Dataset
+        The field values interpolated to the specified time points. The type matches the input type.
+    """
+
+    def interpolate_single_field(data_array: xr.DataArray) -> xr.DataArray:
+
+        if isinstance(time, xr.DataArray):
+            # Extract day of year from xarray.DataArray
+            day_of_year = time.dt.dayofyear
+        else:
+            if np.size(time) == 1:
+                day_of_year = time.timetuple().tm_yday
+            else:
+                day_of_year = np.array([t.timetuple().tm_yday for t in time])
+
+        data_array[time_dim_name] = data_array[time_dim_name].dt.days
+
+        # Concatenate across the beginning and end of the year
+        time_concat = xr.concat(
+            [
+                data_array[time_dim_name][-1] - 365.25,
+                data_array[time_dim_name],
+                365.25 + data_array[time_dim_name][0],
+            ],
+            dim=time_dim_name,
+        )
+        data_array_concat = xr.concat(
+            [
+                data_array.isel(**{time_dim_name: -1}),
+                data_array,
+                data_array.isel(**{time_dim_name: 0}),
+            ],
+            dim=time_dim_name,
+        )
+        data_array_concat[time_dim_name] = time_concat
+
+        # Interpolate to specified times
+        data_array_interpolated = data_array_concat.interp(
+            **{time_dim_name: day_of_year}, method="linear"
+        )
+
+        if np.size(time) == 1:
+            data_array_interpolated = data_array_interpolated.expand_dims(
+                {time_dim_name: 1}
+            )
+        return data_array_interpolated
+
+    if isinstance(field, xr.DataArray):
+        return interpolate_single_field(field)
+    elif isinstance(field, xr.Dataset):
+        interpolated_data_vars = {
+            var: interpolate_single_field(data_array)
+            for var, data_array in field.data_vars.items()
+        }
+        return xr.Dataset(interpolated_data_vars, attrs=field.attrs)
+    else:
+        raise TypeError("Input 'field' must be an xarray.DataArray or xarray.Dataset.")
+
+
+def is_cftime_datetime(data_array: xr.DataArray) -> bool:
+    """
+    Checks if the xarray DataArray contains cftime datetime objects.
+
+    Parameters
+    ----------
+    data_array : xr.DataArray
+        The xarray DataArray to be checked for cftime datetime objects.
+
+    Returns
+    -------
+    bool
+        True if the DataArray contains cftime datetime objects, False otherwise.
+
+    Raises
+    ------
+    TypeError
+        If the values in the DataArray are not of type numpy.ndarray or list.
+    """
+    # List of cftime datetime types
+    cftime_types = (
+        cftime.DatetimeNoLeap,
+        cftime.DatetimeJulian,
+        cftime.DatetimeGregorian,
+    )
+
+    # Check if any of the coordinate values are of cftime type
+    if isinstance(data_array.values, (np.ndarray, list)):
+        # Check the dtype of the array; numpy datetime64 indicates it's not cftime
+        if data_array.values.dtype == "datetime64[ns]":
+            return False
+
+        # Check if any of the values in the array are instances of cftime types
+        return any(isinstance(value, cftime_types) for value in data_array.values)
+
+    # Handle unexpected types
+    raise TypeError("DataArray values must be of type numpy.ndarray or list.")
+
+
+def convert_cftime_to_datetime(data_array: np.ndarray) -> np.ndarray:
+    """
+    Converts cftime datetime objects to numpy datetime64 objects in a numpy ndarray.
+
+    Parameters
+    ----------
+    data_array : np.ndarray
+        The numpy ndarray containing cftime datetime objects to be converted.
+
+    Returns
+    -------
+    np.ndarray
+        The ndarray with cftime datetimes converted to numpy datetime64 objects.
+
+    Notes
+    -----
+    This function is intended to be used with numpy ndarrays. If you need to convert
+    cftime datetime objects in an xarray.DataArray, please use the appropriate function
+    to handle xarray.DataArray conversions.
+    """
+    # List of cftime datetime types
+    cftime_types = (
+        cftime.DatetimeNoLeap,
+        cftime.DatetimeJulian,
+        cftime.DatetimeGregorian,
+    )
+
+    # Define a conversion function for cftime to numpy datetime64
+    def convert_datetime(dt):
+        if isinstance(dt, cftime_types):
+            # Convert to ISO format and then to nanosecond precision
+            return np.datetime64(dt.isoformat(), "ns")
+        return np.datetime64(dt, "ns")
+
+    return np.vectorize(convert_datetime)(data_array)
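
A minimal usage sketch of the two climatology helpers added above (the toy 12-month dataset, dimension names, and target dates are illustrative, not from the package):

import numpy as np
import pandas as pd
import xarray as xr

from roms_tools.setup.utils import (
    assign_dates_to_climatology,
    interpolate_from_climatology,
)

# Toy monthly climatology: 12 time slices of a 4x4 field.
ds = xr.Dataset(
    {"sst": (("time", "eta_rho", "xi_rho"), np.random.default_rng(0).random((12, 4, 4)))}
)

# Attach mid-month timedeltas (day 15, 45, 74, ...) as the time coordinate.
ds = assign_dates_to_climatology(ds, "time")

# Interpolate to arbitrary dates; the helper pads the climatology cyclically
# (December reappears before January and vice versa), so dates near the year
# boundary are interpolated rather than extrapolated.
targets = pd.DatetimeIndex(["2012-01-05", "2012-07-01", "2012-12-28"])
sst_at_dates = interpolate_from_climatology(ds["sst"], "time", targets)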