roms-tools 0.20-py3-none-any.whl → 1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ci/environment.yml +1 -0
- roms_tools/__init__.py +1 -2
- roms_tools/_version.py +1 -1
- roms_tools/setup/boundary_forcing.py +390 -344
- roms_tools/setup/datasets.py +838 -141
- roms_tools/setup/download.py +118 -0
- roms_tools/setup/initial_conditions.py +195 -166
- roms_tools/setup/mixins.py +395 -0
- roms_tools/setup/surface_forcing.py +596 -0
- roms_tools/setup/tides.py +76 -174
- roms_tools/setup/topography.py +1 -1
- roms_tools/setup/utils.py +190 -0
- roms_tools/tests/test_boundary_forcing.py +445 -71
- roms_tools/tests/test_datasets.py +73 -9
- roms_tools/tests/test_initial_conditions.py +252 -32
- roms_tools/tests/test_surface_forcing.py +2622 -0
- roms_tools/tests/test_tides.py +13 -14
- roms_tools/tests/test_utils.py +16 -0
- {roms_tools-0.20.dist-info → roms_tools-1.0.1.dist-info}/METADATA +7 -3
- roms_tools-1.0.1.dist-info/RECORD +31 -0
- {roms_tools-0.20.dist-info → roms_tools-1.0.1.dist-info}/WHEEL +1 -1
- roms_tools/setup/atmospheric_forcing.py +0 -935
- roms_tools/tests/test_atmospheric_forcing.py +0 -1645
- roms_tools-0.20.dist-info/RECORD +0 -28
- {roms_tools-0.20.dist-info → roms_tools-1.0.1.dist-info}/LICENSE +0 -0
- {roms_tools-0.20.dist-info → roms_tools-1.0.1.dist-info}/top_level.txt +0 -0
roms_tools/setup/tides.py
CHANGED

@@ -3,165 +3,21 @@ import xarray as xr
 import numpy as np
 import yaml
 import importlib.metadata
+from typing import Dict, Union
 
 from dataclasses import dataclass, field, asdict
 from roms_tools.setup.grid import Grid
 from roms_tools.setup.plot import _plot
 from roms_tools.setup.fill import fill_and_interpolate
-from roms_tools.setup.datasets import Dataset
+from roms_tools.setup.datasets import TPXODataset
 from roms_tools.setup.utils import (
     nan_check,
     interpolate_from_rho_to_u,
     interpolate_from_rho_to_v,
 )
-from typing import Dict, List
 import matplotlib.pyplot as plt
 
 
-@dataclass(frozen=True, kw_only=True)
-class TPXO(Dataset):
-    """
-    Represents tidal data on original grid.
-
-    Parameters
-    ----------
-    filename : str
-        The path to the TPXO dataset.
-    var_names : List[str], optional
-        List of variable names that are required in the dataset. Defaults to
-        ["h_Re", "h_Im", "sal_Re", "sal_Im", "u_Re", "u_Im", "v_Re", "v_Im"].
-    dim_names : Dict[str, str], optional
-        Dictionary specifying the names of dimensions in the dataset. Defaults to
-        {"longitude": "ny", "latitude": "nx"}.
-
-    Attributes
-    ----------
-    ds : xr.Dataset
-        The xarray Dataset containing TPXO tidal model data.
-    """
-
-    filename: str
-    var_names: List[str] = field(
-        default_factory=lambda: [
-            "h_Re",
-            "h_Im",
-            "sal_Re",
-            "sal_Im",
-            "u_Re",
-            "u_Im",
-            "v_Re",
-            "v_Im",
-            "depth",
-        ]
-    )
-    dim_names: Dict[str, str] = field(
-        default_factory=lambda: {"longitude": "ny", "latitude": "nx", "ntides": "nc"}
-    )
-    ds: xr.Dataset = field(init=False, repr=False)
-
-    def __post_init__(self):
-        # Perform any necessary dataset initialization or modifications here
-        ds = super().load_data()
-
-        # Clean up dataset
-        ds = ds.assign_coords(
-            {
-                "omega": ds["omega"],
-                "nx": ds["lon_r"].isel(
-                    ny=0
-                ),  # lon_r is constant along ny, i.e., is only a function of nx
-                "ny": ds["lat_r"].isel(
-                    nx=0
-                ),  # lat_r is constant along nx, i.e., is only a function of ny
-            }
-        )
-        ds = ds.rename({"nx": "longitude", "ny": "latitude"})
-
-        object.__setattr__(
-            self,
-            "dim_names",
-            {
-                "latitude": "latitude",
-                "longitude": "longitude",
-                "ntides": self.dim_names["ntides"],
-            },
-        )
-        # Select relevant fields
-        ds = super().select_relevant_fields(ds)
-
-        # Check whether the data covers the entire globe
-        is_global = self.check_if_global(ds)
-
-        if is_global:
-            ds = self.concatenate_longitudes(ds)
-
-        object.__setattr__(self, "ds", ds)
-
-    def check_number_constituents(self, ntides: int):
-        """
-        Checks if the number of constituents in the dataset is at least `ntides`.
-
-        Parameters
-        ----------
-        ntides : int
-            The required number of tidal constituents.
-
-        Raises
-        ------
-        ValueError
-            If the number of constituents in the dataset is less than `ntides`.
-        """
-        if len(self.ds[self.dim_names["ntides"]]) < ntides:
-            raise ValueError(
-                f"The dataset contains fewer than {ntides} tidal constituents."
-            )
-
-    def get_corrected_tides(self, model_reference_date, allan_factor):
-        # Get equilibrium tides
-        tpc = compute_equilibrium_tide(self.ds["longitude"], self.ds["latitude"]).isel(
-            nc=self.ds.nc
-        )
-        # Correct for SAL
-        tsc = allan_factor * (self.ds["sal_Re"] + 1j * self.ds["sal_Im"])
-        tpc = tpc - tsc
-
-        # Elevations and transports
-        thc = self.ds["h_Re"] + 1j * self.ds["h_Im"]
-        tuc = self.ds["u_Re"] + 1j * self.ds["u_Im"]
-        tvc = self.ds["v_Re"] + 1j * self.ds["v_Im"]
-
-        # Apply correction for phases and amplitudes
-        pf, pu, aa = egbert_correction(model_reference_date)
-        pf = pf.isel(nc=self.ds.nc)
-        pu = pu.isel(nc=self.ds.nc)
-        aa = aa.isel(nc=self.ds.nc)
-
-        tpxo_reference_date = datetime(1992, 1, 1)
-        dt = (model_reference_date - tpxo_reference_date).days * 3600 * 24
-
-        thc = pf * thc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-        tuc = pf * tuc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-        tvc = pf * tvc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-        tpc = pf * tpc * np.exp(1j * (self.ds["omega"] * dt + pu + aa))
-
-        tides = {
-            "ssh_Re": thc.real,
-            "ssh_Im": thc.imag,
-            "u_Re": tuc.real,
-            "u_Im": tuc.imag,
-            "v_Re": tvc.real,
-            "v_Im": tvc.imag,
-            "pot_Re": tpc.real,
-            "pot_Im": tpc.imag,
-            "omega": self.ds["omega"],
-        }
-
-        for k in tides.keys():
-            tides[k] = tides[k].rename({"nc": "ntides"})
-
-        return tides
-
-
 @dataclass(frozen=True, kw_only=True)
 class TidalForcing:
     """
@@ -171,16 +27,16 @@ class TidalForcing:
     ----------
     grid : Grid
         The grid object representing the ROMS grid associated with the tidal forcing data.
-
-
+    source : Dict[str, Union[str, None]]
+        Dictionary specifying the source of the tidal data:
+        - "name" (str): Name of the data source (e.g., "TPXO").
+        - "path" (str): Path to the tidal data file. Can contain wildcards.
     ntides : int, optional
         Number of constituents to consider. Maximum number is 14. Default is 10.
-    model_reference_date : datetime, optional
-        The reference date for the ROMS simulation. Default is datetime(2000, 1, 1).
-    source : str, optional
-        The source of the tidal data. Default is "TPXO".
     allan_factor : float, optional
         The Allan factor used in tidal model computation. Default is 2.0.
+    model_reference_date : datetime, optional
+        The reference date for the ROMS simulation. Default is datetime(2000, 1, 1).
 
     Attributes
     ----------
@@ -189,27 +45,31 @@ class TidalForcing:
 
     Examples
     --------
-    >>>
-
-
+    >>> tidal_forcing = TidalForcing(
+    ...     grid=grid, source={"name": "TPXO", "path": "tpxo_data.nc"}
+    ... )
     """
 
     grid: Grid
-
+    source: Dict[str, Union[str, None]]
     ntides: int = 10
-    model_reference_date: datetime = datetime(2000, 1, 1)
-    source: str = "TPXO"
     allan_factor: float = 2.0
+    model_reference_date: datetime = datetime(2000, 1, 1)
+
     ds: xr.Dataset = field(init=False, repr=False)
 
     def __post_init__(self):
-        if self.source == "TPXO":
-
+        if "name" not in self.source.keys():
+            raise ValueError("`source` must include a 'name'.")
+        if "path" not in self.source.keys():
+            raise ValueError("`source` must include a 'path'.")
+        if self.source["name"] == "TPXO":
+            data = TPXODataset(filename=self.source["path"])
         else:
-            raise ValueError('Only "TPXO" is a valid option for source.')
+            raise ValueError('Only "TPXO" is a valid option for source["name"].')
 
         data.check_number_constituents(self.ntides)
-        # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in
+        # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in longitude away from Greenwich meridian
         lon = self.grid.ds.lon_rho
         lat = self.grid.ds.lat_rho
         angle = self.grid.ds.angle
@@ -220,14 +80,9 @@ class TidalForcing:
             lon = xr.where(lon < 0, lon + 360, lon)
             straddle = False
 
-        #
-        #
-        #
-        # A) Since the horizontal dimensions consist of a single chunk, selecting a subdomain before interpolation is a lot more performant.
-        # B) Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2. Specifically, discontinuous longitudes
-        # can lead to artifacts in the interpolation process. Specifically, if there is a data gap if data is not global,
-        # discontinuous longitudes could result in values that appear to come from a distant location instead of producing NaNs.
-        # These NaNs are important as they can be identified and handled appropriately by the nan_check function.
+        # Restrict data to relevant subdomain to achieve better performance and to avoid discontinuous longitudes introduced by converting
+        # to a different longitude range (+- 360 degrees). Discontinuous longitudes can lead to artifacts in the interpolation process that
+        # would not be detected by the nan_check function.
         data.choose_subdomain(
             latitude_range=[lat.min().values, lat.max().values],
             longitude_range=[lon.min().values, lon.max().values],
@@ -235,7 +90,7 @@ class TidalForcing:
             straddle=straddle,
         )
 
-        tides = data.get_corrected_tides(self.model_reference_date, self.allan_factor)
+        tides = self._get_corrected_tides(data)
 
         # select desired number of constituents
         for k in tides.keys():
@@ -326,7 +181,7 @@ class TidalForcing:
 
         ds.attrs["roms_tools_version"] = roms_tools_version
 
-        ds.attrs["source"] = self.source
+        ds.attrs["source"] = self.source["name"]
         ds.attrs["model_reference_date"] = str(self.model_reference_date)
         ds.attrs["allan_factor"] = self.allan_factor
 
@@ -430,10 +285,9 @@ class TidalForcing:
         # Extract tidal forcing data
         tidal_forcing_data = {
             "TidalForcing": {
-                "
+                "source": self.source,
                 "ntides": self.ntides,
                 "model_reference_date": self.model_reference_date.isoformat(),
-                "source": self.source,
                 "allan_factor": self.allan_factor,
             }
         }
@@ -494,6 +348,54 @@ class TidalForcing:
         # Create and return an instance of TidalForcing
         return cls(grid=grid, **tidal_forcing_params)
 
+    def _get_corrected_tides(self, data):
+
+        # Get equilibrium tides
+        tpc = compute_equilibrium_tide(
+            data.ds[data.dim_names["longitude"]], data.ds[data.dim_names["latitude"]]
+        )
+        tpc = tpc.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+        # Correct for SAL
+        tsc = self.allan_factor * (
+            data.ds[data.var_names["sal_Re"]] + 1j * data.ds[data.var_names["sal_Im"]]
+        )
+        tpc = tpc - tsc
+
+        # Elevations and transports
+        thc = data.ds[data.var_names["ssh_Re"]] + 1j * data.ds[data.var_names["ssh_Im"]]
+        tuc = data.ds[data.var_names["u_Re"]] + 1j * data.ds[data.var_names["u_Im"]]
+        tvc = data.ds[data.var_names["v_Re"]] + 1j * data.ds[data.var_names["v_Im"]]
+
+        # Apply correction for phases and amplitudes
+        pf, pu, aa = egbert_correction(self.model_reference_date)
+        pf = pf.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+        pu = pu.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+        aa = aa.isel(**{data.dim_names["ntides"]: data.ds[data.dim_names["ntides"]]})
+
+        dt = (self.model_reference_date - data.reference_date).days * 3600 * 24
+
+        thc = pf * thc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+        tuc = pf * tuc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+        tvc = pf * tvc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+        tpc = pf * tpc * np.exp(1j * (data.ds["omega"] * dt + pu + aa))
+
+        tides = {
+            "ssh_Re": thc.real,
+            "ssh_Im": thc.imag,
+            "u_Re": tuc.real,
+            "u_Im": tuc.imag,
+            "v_Re": tvc.real,
+            "v_Im": tvc.imag,
+            "pot_Re": tpc.real,
+            "pot_Im": tpc.imag,
+            "omega": data.ds["omega"],
+        }
+
+        for k in tides.keys():
+            tides[k] = tides[k].rename({data.dim_names["ntides"]: "ntides"})
+
+        return tides
+
 
 def modified_julian_days(year, month, day, hour=0):
     """
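Note: the correction applied by the new private helper `_get_corrected_tides` is the same as in the removed `TPXO.get_corrected_tides`; the main difference is that the reference date now comes from the dataset object (`data.reference_date`) instead of being hard-coded to datetime(1992, 1, 1), and variable/dimension names go through the `var_names`/`dim_names` indirection of `TPXODataset`. Each complex constituent field is scaled by a nodal factor `pf` and rotated in phase by `omega * dt + pu + aa`. A minimal, self-contained sketch of that correction for one hypothetical constituent (the stand-in values below are invented; in roms-tools, `pf`, `pu`, and `aa` come from `egbert_correction` and `omega` from the TPXO file):

import numpy as np
from datetime import datetime

# Hypothetical stand-ins for a single tidal constituent.
omega = 1.405189e-4            # M2 angular frequency [rad/s]
h = 0.3 + 0.1j                 # complex elevation, h_Re + 1j * h_Im
pf, pu, aa = 0.97, 0.05, 1.73  # nodal factor, nodal phase, astronomical argument

# Offset in seconds between the model and data reference dates, mirroring
# dt = (self.model_reference_date - data.reference_date).days * 3600 * 24
dt = (datetime(2000, 1, 1) - datetime(1992, 1, 1)).days * 3600 * 24

# Amplitude and phase correction, as in _get_corrected_tides
h_corrected = pf * h * np.exp(1j * (omega * dt + pu + aa))
print(h_corrected.real, h_corrected.imag)  # -> ssh_Re, ssh_Im

Moving this logic onto `TidalForcing` leaves the new `TPXODataset` responsible only for describing the data layout, which is consistent with the broader dataset refactor in this release.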
roms_tools/setup/topography.py
CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 import gcm_filters
 from scipy.interpolate import RegularGridInterpolator
 from scipy.ndimage import label
-from roms_tools.setup.
+from roms_tools.setup.download import fetch_topo
 from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
 import warnings
 from itertools import count
roms_tools/setup/utils.py
CHANGED

@@ -1,4 +1,8 @@
 import xarray as xr
+import numpy as np
+from typing import Union
+import pandas as pd
+import cftime
 
 
 def nan_check(field, mask) -> None:
@@ -160,3 +164,189 @@ def extrapolate_deepest_to_bottom(field: xr.DataArray, dim: str) -> xr.DataArray
     )
 
     return field_interpolated
+
+
+def assign_dates_to_climatology(ds: xr.Dataset, time_dim: str) -> xr.Dataset:
+    """
+    Assigns climatology dates to the dataset's time dimension.
+
+    This function updates the dataset's time coordinates to reflect climatological dates.
+    It defines fixed day increments for each month and assigns these to the specified time dimension.
+    The increments represent the cumulative days at mid-month for each month.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The xarray Dataset to which climatological dates will be assigned.
+    time_dim : str
+        The name of the time dimension in the dataset that will be updated with climatological dates.
+
+    Returns
+    -------
+    xr.Dataset
+        The updated xarray Dataset with climatological dates assigned to the specified time dimension.
+    """
+    # Define the days in each month and convert to timedelta
+    increments = [15, 30, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30]
+    days = np.cumsum(increments)
+    timedelta_ns = np.array(days, dtype="timedelta64[D]").astype("timedelta64[ns]")
+    time = xr.DataArray(timedelta_ns, dims=[time_dim])
+    ds = ds.assign_coords({"time": time})
+    return ds
+
+
+def interpolate_from_climatology(
+    field: Union[xr.DataArray, xr.Dataset],
+    time_dim_name: str,
+    time: Union[xr.DataArray, pd.DatetimeIndex],
+) -> Union[xr.DataArray, xr.Dataset]:
+    """
+    Interpolates the given field temporally based on the specified time points.
+
+    If `field` is an xarray.Dataset, this function applies the interpolation to all data variables in the dataset.
+
+    Parameters
+    ----------
+    field : xarray.DataArray or xarray.Dataset
+        The field data to be interpolated. Can be a single DataArray or a Dataset.
+    time_dim_name : str
+        The name of the dimension in `field` that represents time.
+    time : xarray.DataArray or pandas.DatetimeIndex
+        The target time points for interpolation.
+
+    Returns
+    -------
+    xarray.DataArray or xarray.Dataset
+        The field values interpolated to the specified time points. The type matches the input type.
+    """
+
+    def interpolate_single_field(data_array: xr.DataArray) -> xr.DataArray:
+
+        if isinstance(time, xr.DataArray):
+            # Extract day of year from xarray.DataArray
+            day_of_year = time.dt.dayofyear
+        else:
+            if np.size(time) == 1:
+                day_of_year = time.timetuple().tm_yday
+            else:
+                day_of_year = np.array([t.timetuple().tm_yday for t in time])
+
+        data_array[time_dim_name] = data_array[time_dim_name].dt.days
+
+        # Concatenate across the beginning and end of the year
+        time_concat = xr.concat(
+            [
+                data_array[time_dim_name][-1] - 365.25,
+                data_array[time_dim_name],
+                365.25 + data_array[time_dim_name][0],
+            ],
+            dim=time_dim_name,
+        )
+        data_array_concat = xr.concat(
+            [
+                data_array.isel(**{time_dim_name: -1}),
+                data_array,
+                data_array.isel(**{time_dim_name: 0}),
+            ],
+            dim=time_dim_name,
+        )
+        data_array_concat[time_dim_name] = time_concat
+
+        # Interpolate to specified times
+        data_array_interpolated = data_array_concat.interp(
+            **{time_dim_name: day_of_year}, method="linear"
+        )
+
+        if np.size(time) == 1:
+            data_array_interpolated = data_array_interpolated.expand_dims(
+                {time_dim_name: 1}
+            )
+        return data_array_interpolated
+
+    if isinstance(field, xr.DataArray):
+        return interpolate_single_field(field)
+    elif isinstance(field, xr.Dataset):
+        interpolated_data_vars = {
+            var: interpolate_single_field(data_array)
+            for var, data_array in field.data_vars.items()
+        }
+        return xr.Dataset(interpolated_data_vars, attrs=field.attrs)
+    else:
+        raise TypeError("Input 'field' must be an xarray.DataArray or xarray.Dataset.")
+
+
+def is_cftime_datetime(data_array: xr.DataArray) -> bool:
+    """
+    Checks if the xarray DataArray contains cftime datetime objects.
+
+    Parameters
+    ----------
+    data_array : xr.DataArray
+        The xarray DataArray to be checked for cftime datetime objects.
+
+    Returns
+    -------
+    bool
+        True if the DataArray contains cftime datetime objects, False otherwise.
+
+    Raises
+    ------
+    TypeError
+        If the values in the DataArray are not of type numpy.ndarray or list.
+    """
+    # List of cftime datetime types
+    cftime_types = (
+        cftime.DatetimeNoLeap,
+        cftime.DatetimeJulian,
+        cftime.DatetimeGregorian,
+    )
+
+    # Check if any of the coordinate values are of cftime type
+    if isinstance(data_array.values, (np.ndarray, list)):
+        # Check the dtype of the array; numpy datetime64 indicates it's not cftime
+        if data_array.values.dtype == "datetime64[ns]":
+            return False
+
+        # Check if any of the values in the array are instances of cftime types
+        return any(isinstance(value, cftime_types) for value in data_array.values)
+
+    # Handle unexpected types
+    raise TypeError("DataArray values must be of type numpy.ndarray or list.")
+
+
+def convert_cftime_to_datetime(data_array: np.ndarray) -> np.ndarray:
+    """
+    Converts cftime datetime objects to numpy datetime64 objects in a numpy ndarray.
+
+    Parameters
+    ----------
+    data_array : np.ndarray
+        The numpy ndarray containing cftime datetime objects to be converted.
+
+    Returns
+    -------
+    np.ndarray
+        The ndarray with cftime datetimes converted to numpy datetime64 objects.
+
+    Notes
+    -----
+    This function is intended to be used with numpy ndarrays. If you need to convert
+    cftime datetime objects in an xarray.DataArray, please use the appropriate function
+    to handle xarray.DataArray conversions.
+    """
+    # List of cftime datetime types
+    cftime_types = (
+        cftime.DatetimeNoLeap,
+        cftime.DatetimeJulian,
+        cftime.DatetimeGregorian,
+    )
+
+    # Define a conversion function for cftime to numpy datetime64
+    def convert_datetime(dt):
+        if isinstance(dt, cftime_types):
+            # Convert to ISO format and then to nanosecond precision
            return np.datetime64(dt.isoformat(), "ns")
+        return np.datetime64(dt, "ns")
+
+    return np.vectorize(convert_datetime)(data_array)
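Note: the new climatology helpers are designed to be used together: `assign_dates_to_climatology` stamps mid-month `timedelta64` offsets onto a 12-entry climatology, and `interpolate_from_climatology` pads the first and last entries across the year boundary before interpolating linearly in day-of-year, so target dates in early January or late December still interpolate sensibly. A small usage sketch with invented monthly values (the `sst` field and target dates below are hypothetical):

import cftime
import numpy as np
import pandas as pd
import xarray as xr

from roms_tools.setup.utils import (
    assign_dates_to_climatology,
    convert_cftime_to_datetime,
    interpolate_from_climatology,
)

# Hypothetical 12-month climatology of a scalar field.
ds = xr.Dataset({"sst": ("time", 10 + 5 * np.sin(2 * np.pi * np.arange(12) / 12))})
ds = assign_dates_to_climatology(ds, "time")  # mid-month timedelta64 coordinates

# Interpolate onto arbitrary calendar dates; only the day of year is used.
target = pd.DatetimeIndex(["2012-02-10", "2012-08-01"])
print(interpolate_from_climatology(ds, "time", target)["sst"].values)

# Convert cftime timestamps (e.g., from a no-leap model calendar) to datetime64.
print(convert_cftime_to_datetime(np.array([cftime.DatetimeNoLeap(2000, 1, 15)])))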