roms-tools 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,8 @@ import numpy as np
3
3
  import gcm_filters
4
4
  from scipy.interpolate import RegularGridInterpolator
5
5
  from scipy.ndimage import label
6
- from roms_tools.setup.datasets import fetch_topo
6
+ from roms_tools.setup.download import fetch_topo
7
+ from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
7
8
  import warnings
8
9
  from itertools import count
9
10
 
@@ -19,7 +20,7 @@ def _add_topography_and_mask(
19
20
  hraw = xr.DataArray(data=hraw, dims=["eta_rho", "xi_rho"])
20
21
 
21
22
  # Mask is obtained by finding locations where ocean depth is positive
22
- mask = xr.where(hraw > 0, 1, 0)
23
+ mask = xr.where(hraw > 0, 1.0, 0.0)
23
24
 
24
25
  # smooth topography domain-wide with Gaussian kernel to avoid grid scale instabilities
25
26
  ds["hraw"] = _smooth_topography_globally(hraw, mask, smooth_factor)
@@ -37,6 +38,8 @@ def _add_topography_and_mask(
37
38
  "units": "land/water (0/1)",
38
39
  }
39
40
 
41
+ ds = _add_velocity_masks(ds)
42
+
40
43
  # smooth topography locally to satisfy r < rmax
41
44
  ds["h"] = _smooth_topography_locally(ds["hraw"] * ds["mask_rho"], hmin, rmax)
42
45
  ds["h"].attrs = {
@@ -57,7 +60,7 @@ def _make_raw_topography(lon, lat, topography_source) -> np.ndarray:
57
60
  topo_ds = fetch_topo(topography_source)
58
61
 
59
62
  # the following will depend on the topography source
60
- if topography_source == "etopo5":
63
+ if topography_source == "ETOPO5":
61
64
  topo_lon = topo_ds["topo_lon"].copy()
62
65
  # Modify longitude values where necessary
63
66
  topo_lon = xr.where(topo_lon < 0, topo_lon + 360, topo_lon)
@@ -240,3 +243,15 @@ def _add_topography_metadata(ds, topography_source, smooth_factor, hmin, rmax):
240
243
  ds.attrs["rmax"] = rmax
241
244
 
242
245
  return ds
246
+
247
+
248
def _add_velocity_masks(ds):
    """Attach land/water masks on the u- and v-grids, derived from the rho mask."""

    # Multiplicative interpolation keeps the mask binary: a velocity point is
    # wet only if both neighbouring rho points are wet.
    ds["mask_u"] = interpolate_from_rho_to_u(ds["mask_rho"], method="multiplicative")
    ds["mask_v"] = interpolate_from_rho_to_v(ds["mask_rho"], method="multiplicative")

    ds["mask_u"].attrs = {"long_name": "Mask at u-points", "units": "land/water (0/1)"}
    ds["mask_v"].attrs = {"long_name": "Mask at v-points", "units": "land/water (0/1)"}

    return ds
@@ -0,0 +1,352 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+ from typing import Union
4
+ import pandas as pd
5
+ import cftime
6
+
7
+
8
def nan_check(field, mask) -> None:
    """
    Checks for NaN values at wet points in the field.

    This function examines the interpolated input field for NaN values at positions indicated as wet points by the mask.
    If any NaN values are found at these wet points, a ValueError is raised.

    Parameters
    ----------
    field : array-like
        The data array to be checked for NaN values. This is typically an xarray.DataArray or numpy array.

    mask : array-like
        A boolean mask or data array with the same shape as `field`. The wet points (usually ocean points)
        are indicated by `1` or `True`, and land points by `0` or `False`.

    Raises
    ------
    ValueError
        If the field contains NaN values at any of the wet points indicated by the mask.
        The error message will explain the potential cause and suggest ensuring the dataset's coverage.

    """

    # Zero out land points so only wet points can contribute NaNs.
    wet_points_only = xr.where(mask == 1, field, 0)

    # Any remaining NaN must sit at a wet point.
    if wet_points_only.isnull().any().values:
        raise ValueError(
            "NaN values found in interpolated field. This likely occurs because the ROMS grid, including "
            "a small safety margin for interpolation, is not fully contained within the dataset's longitude/latitude range. Please ensure that the "
            "dataset covers the entire area required by the ROMS grid."
        )
42
+
43
+
44
def interpolate_from_rho_to_u(field, method="additive"):

    """
    Interpolates the given field from rho points to u points.

    This function performs an interpolation from the rho grid (cell centers) to the u grid
    (cell edges in the xi direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the xi dimension. It also handles the removal of unnecessary coordinate variables
    and updates the dimensions accordingly.

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "xi_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the u grid with the dimension "xi_u".

    Raises
    ------
    NotImplementedError
        If `method` is neither 'additive' nor 'multiplicative'.
    """

    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(xi_rho=1)).isel(
            xi_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(xi_rho=1)).isel(xi_rho=slice(1, None))
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # drop_vars returns a new object rather than modifying in place, so the
    # result must be reassigned (the original code discarded it, leaving the
    # rho-grid coordinates attached to the interpolated field).
    if "lat_rho" in field_interpolated.coords:
        field_interpolated = field_interpolated.drop_vars(["lat_rho"])
    if "lon_rho" in field_interpolated.coords:
        field_interpolated = field_interpolated.drop_vars(["lon_rho"])

    # Relabel the staggered dimension to the u-grid convention.
    field_interpolated = field_interpolated.swap_dims({"xi_rho": "xi_u"})

    return field_interpolated
90
+
91
+
92
def interpolate_from_rho_to_v(field, method="additive"):

    """
    Interpolates the given field from rho points to v points.

    This function performs an interpolation from the rho grid (cell centers) to the v grid
    (cell edges in the eta direction). Depending on the chosen method, it either averages
    (additive) or multiplies (multiplicative) the field values between adjacent rho points
    along the eta dimension. It also handles the removal of unnecessary coordinate variables
    and updates the dimensions accordingly.

    Parameters
    ----------
    field : xr.DataArray
        The input data array on the rho grid to be interpolated. It is assumed to have a dimension
        named "eta_rho".

    method : str, optional, default='additive'
        The method to use for interpolation. Options are:
        - 'additive': Average the field values between adjacent rho points.
        - 'multiplicative': Multiply the field values between adjacent rho points. Appropriate for
          binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The interpolated data array on the v grid with the dimension "eta_v".

    Raises
    ------
    NotImplementedError
        If `method` is neither 'additive' nor 'multiplicative'.
    """

    if method == "additive":
        field_interpolated = 0.5 * (field + field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    elif method == "multiplicative":
        field_interpolated = (field * field.shift(eta_rho=1)).isel(
            eta_rho=slice(1, None)
        )
    else:
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    # drop_vars returns a new object rather than modifying in place, so the
    # result must be reassigned (the original code discarded it, leaving the
    # rho-grid coordinates attached to the interpolated field).
    if "lat_rho" in field_interpolated.coords:
        field_interpolated = field_interpolated.drop_vars(["lat_rho"])
    if "lon_rho" in field_interpolated.coords:
        field_interpolated = field_interpolated.drop_vars(["lon_rho"])

    # Relabel the staggered dimension to the v-grid convention.
    field_interpolated = field_interpolated.swap_dims({"eta_rho": "eta_v"})

    return field_interpolated
140
+
141
+
142
def extrapolate_deepest_to_bottom(field: xr.DataArray, dim: str) -> xr.DataArray:
    """
    Extrapolate the deepest non-NaN values to the bottom along a specified dimension.

    Parameters
    ----------
    field : xr.DataArray
        The input data array containing NaN values that need to be filled. This array
        should have at least one dimension named by `dim`.
    dim : str
        The name of the dimension along which to perform the interpolation and extrapolation.
        Typically, this would be a vertical dimension such as 'depth' or 's_rho'.

    Returns
    -------
    field_interpolated : xr.DataArray
        A new data array with NaN values along the specified dimension filled by nearest
        neighbor interpolation and extrapolation to the bottom. The original data array is not modified.

    """
    # Nearest-neighbour interpolation with extrapolation carries the deepest
    # valid value down through trailing NaNs along `dim`.
    return field.interpolate_na(dim=dim, method="nearest", fill_value="extrapolate")
167
+
168
+
169
def assign_dates_to_climatology(ds: xr.Dataset, time_dim: str) -> xr.Dataset:
    """
    Assigns climatology dates to the dataset's time dimension.

    This function updates the dataset's time coordinates to reflect climatological dates.
    It defines fixed day increments for each month and assigns these to the specified time dimension.
    The increments represent the cumulative days at mid-month for each month.

    Parameters
    ----------
    ds : xr.Dataset
        The xarray Dataset to which climatological dates will be assigned.
    time_dim : str
        The name of the time dimension in the dataset that will be updated with climatological dates.

    Returns
    -------
    xr.Dataset
        The updated xarray Dataset with climatological dates assigned to the specified time dimension.

    """
    # Day offsets between consecutive mid-month points; cumsum lands each
    # entry on the middle of its month.
    month_increments = [15, 30, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30]
    midmonth_days = np.cumsum(month_increments)

    # Express the offsets as nanosecond-precision timedeltas along time_dim.
    offsets = np.array(midmonth_days, dtype="timedelta64[D]").astype("timedelta64[ns]")
    climatology_time = xr.DataArray(offsets, dims=[time_dim])

    # NOTE(review): the coordinate is always named "time" here, even when
    # `time_dim` is something else -- confirm this is intended.
    return ds.assign_coords({"time": climatology_time})
197
+
198
+
199
def interpolate_from_climatology(
    field: Union[xr.DataArray, xr.Dataset],
    time_dim_name: str,
    time: Union[xr.DataArray, pd.DatetimeIndex],
) -> Union[xr.DataArray, xr.Dataset]:
    """
    Interpolates the given field temporally based on the specified time points.

    If `field` is an xarray.Dataset, this function applies the interpolation to all data variables in the dataset.

    Parameters
    ----------
    field : xarray.DataArray or xarray.Dataset
        The field data to be interpolated. Can be a single DataArray or a Dataset.
        The input is not modified.
    time_dim_name : str
        The name of the dimension in `field` that represents time.
    time : xarray.DataArray or pandas.DatetimeIndex
        The target time points for interpolation.

    Returns
    -------
    xarray.DataArray or xarray.Dataset
        The field values interpolated to the specified time points. The type matches the input type.

    Raises
    ------
    TypeError
        If `field` is neither an xarray.DataArray nor an xarray.Dataset.
    """

    def interpolate_single_field(data_array: xr.DataArray) -> xr.DataArray:

        if isinstance(time, xr.DataArray):
            # Extract day of year from xarray.DataArray
            day_of_year = time.dt.dayofyear
        else:
            if np.size(time) == 1:
                day_of_year = time.timetuple().tm_yday
            else:
                day_of_year = np.array([t.timetuple().tm_yday for t in time])

        # Convert the climatology time axis to day counts. assign_coords
        # returns a new object, so the caller's input is left untouched
        # (the original code assigned the coordinate in place, mutating
        # the passed-in DataArray).
        data_array = data_array.assign_coords(
            {time_dim_name: data_array[time_dim_name].dt.days}
        )

        # Concatenate across the beginning and end of the year so that
        # target days before the first or after the last climatology point
        # interpolate across the year boundary.
        time_concat = xr.concat(
            [
                data_array[time_dim_name][-1] - 365.25,
                data_array[time_dim_name],
                365.25 + data_array[time_dim_name][0],
            ],
            dim=time_dim_name,
        )
        data_array_concat = xr.concat(
            [
                data_array.isel(**{time_dim_name: -1}),
                data_array,
                data_array.isel(**{time_dim_name: 0}),
            ],
            dim=time_dim_name,
        )
        data_array_concat[time_dim_name] = time_concat

        # Interpolate to specified times
        data_array_interpolated = data_array_concat.interp(
            **{time_dim_name: day_of_year}, method="linear"
        )

        # Keep a time dimension of length 1 even for a single target time.
        if np.size(time) == 1:
            data_array_interpolated = data_array_interpolated.expand_dims(
                {time_dim_name: 1}
            )
        return data_array_interpolated

    if isinstance(field, xr.DataArray):
        return interpolate_single_field(field)
    elif isinstance(field, xr.Dataset):
        interpolated_data_vars = {
            var: interpolate_single_field(data_array)
            for var, data_array in field.data_vars.items()
        }
        return xr.Dataset(interpolated_data_vars, attrs=field.attrs)
    else:
        raise TypeError("Input 'field' must be an xarray.DataArray or xarray.Dataset.")
277
+
278
+
279
def is_cftime_datetime(data_array: xr.DataArray) -> bool:
    """
    Checks if the xarray DataArray contains cftime datetime objects.

    Parameters
    ----------
    data_array : xr.DataArray
        The xarray DataArray to be checked for cftime datetime objects.

    Returns
    -------
    bool
        True if the DataArray contains cftime datetime objects, False otherwise.

    Raises
    ------
    TypeError
        If the values in the DataArray are not of type numpy.ndarray or list.
    """
    # Recognised cftime calendar datetime classes.
    cftime_types = (
        cftime.DatetimeNoLeap,
        cftime.DatetimeJulian,
        cftime.DatetimeGregorian,
    )

    values = data_array.values

    # Reject anything that is not array-like.
    if not isinstance(values, (np.ndarray, list)):
        raise TypeError("DataArray values must be of type numpy.ndarray or list.")

    # A numpy datetime64 dtype means the data is already native datetimes,
    # hence not cftime.
    if values.dtype == "datetime64[ns]":
        return False

    # Otherwise, look for any cftime instance among the values.
    return any(isinstance(value, cftime_types) for value in values)
316
+
317
+
318
def convert_cftime_to_datetime(data_array: np.ndarray) -> np.ndarray:
    """
    Converts cftime datetime objects to numpy datetime64 objects in a numpy ndarray.

    Parameters
    ----------
    data_array : np.ndarray
        The numpy ndarray containing cftime datetime objects to be converted.

    Returns
    -------
    np.ndarray
        The ndarray with cftime datetimes converted to numpy datetime64 objects.

    Notes
    -----
    This function is intended to be used with numpy ndarrays. If you need to convert
    cftime datetime objects in an xarray.DataArray, please use the appropriate function
    to handle xarray.DataArray conversions.
    """
    # Recognised cftime calendar datetime classes.
    cftime_types = (
        cftime.DatetimeNoLeap,
        cftime.DatetimeJulian,
        cftime.DatetimeGregorian,
    )

    def _to_datetime64(value):
        # cftime objects go through their ISO representation; everything else
        # is handed to numpy directly. Both paths yield nanosecond precision.
        if isinstance(value, cftime_types):
            return np.datetime64(value.isoformat(), "ns")
        return np.datetime64(value, "ns")

    # Apply the element-wise conversion across the whole array.
    return np.vectorize(_to_datetime64)(data_array)