ngiab-data-preprocess 4.2.1-py3-none-any.whl → 4.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_processing/create_realization.py +42 -50
- data_processing/dask_utils.py +92 -0
- data_processing/dataset_utils.py +127 -44
- data_processing/datasets.py +18 -29
- data_processing/file_paths.py +7 -7
- data_processing/forcings.py +102 -102
- data_processing/gpkg_utils.py +18 -18
- data_processing/graph_utils.py +4 -4
- data_processing/s3fs_utils.py +1 -1
- data_processing/subset.py +1 -2
- data_sources/source_validation.py +57 -32
- map_app/__main__.py +3 -2
- map_app/static/css/main.css +33 -10
- map_app/static/css/toggle.css +8 -5
- map_app/static/js/main.js +232 -90
- map_app/templates/index.html +31 -9
- map_app/views.py +8 -8
- ngiab_data_cli/__main__.py +31 -28
- ngiab_data_cli/arguments.py +0 -1
- ngiab_data_cli/forcing_cli.py +10 -19
- {ngiab_data_preprocess-4.2.1.dist-info → ngiab_data_preprocess-4.3.0.dist-info}/METADATA +15 -13
- ngiab_data_preprocess-4.3.0.dist-info/RECORD +43 -0
- {ngiab_data_preprocess-4.2.1.dist-info → ngiab_data_preprocess-4.3.0.dist-info}/WHEEL +1 -1
- map_app/static/resources/dark-style.json +0 -11068
- map_app/static/resources/light-style.json +0 -11068
- ngiab_data_preprocess-4.2.1.dist-info/RECORD +0 -44
- {ngiab_data_preprocess-4.2.1.dist-info → ngiab_data_preprocess-4.3.0.dist-info}/entry_points.txt +0 -0
- {ngiab_data_preprocess-4.2.1.dist-info → ngiab_data_preprocess-4.3.0.dist-info}/licenses/LICENSE +0 -0
- {ngiab_data_preprocess-4.2.1.dist-info → ngiab_data_preprocess-4.3.0.dist-info}/top_level.txt +0 -0
data_processing/create_realization.py
CHANGED
@@ -1,30 +1,32 @@
 import json
+import logging
 import multiprocessing
+import shutil
 import sqlite3
 from datetime import datetime
 from pathlib import Path
-import
-import requests
+from typing import Dict, Optional

 import pandas
+import requests
 import s3fs
 import xarray as xr
-import
-from collections import defaultdict
-from dask.distributed import Client, LocalCluster
+from data_processing.dask_utils import temp_cluster
 from data_processing.file_paths import file_paths
 from data_processing.gpkg_utils import (
     GeoPackage,
+    get_cat_to_nex_flowpairs,
     get_cat_to_nhd_feature_id,
     get_table_crs_short,
-    get_cat_to_nex_flowpairs,
 )
-from tqdm.rich import tqdm
 from pyproj import Transformer
+from tqdm.rich import tqdm

 logger = logging.getLogger(__name__)

-def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
+
+@temp_cluster
+def get_approximate_gw_storage(paths: file_paths, start_date: datetime) -> Dict[str, int]:
     # get the gw levels from the NWM output on a given start date
     # this kind of works in place of warmstates for now
     year = start_date.strftime("%Y")
@@ -34,17 +36,10 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
     fs = s3fs.S3FileSystem(anon=True)
     nc_url = f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/GWOUT/{year}/{formatted_dt}.GWOUT_DOMAIN1"

-    # make sure there's a dask cluster running
-    try:
-        client = Client.current()
-    except ValueError:
-        cluster = LocalCluster()
-        client = Client(cluster)
-
     with fs.open(nc_url) as file_obj:
-        ds = xr.open_dataset(file_obj)
+        ds = xr.open_dataset(file_obj)  # type: ignore

-    water_levels = dict()
+    water_levels: Dict[str, int] = dict()
     for cat, feature in tqdm(cat_to_feature.items()):
         # this value is in CM, we need meters to match max_gw_depth
         # xarray says it's in mm, with 0.1 scale factor. calling .values doesn't apply the scale
@@ -78,7 +73,9 @@ def make_cfe_config(
         slope=row["mean.slope_1km"],
         smcmax=row["mean.smcmax_soil_layers_stag=2"],
         smcwlt=row["mean.smcwlt_soil_layers_stag=2"],
-        max_gw_storage=row["mean.Zmax"]/1000
+        max_gw_storage=row["mean.Zmax"] / 1000
+        if row["mean.Zmax"] is not None
+        else "0.011[m]",  # mean.Zmax is in mm!
         gw_Coeff=row["mean.Coeff"] if row["mean.Coeff"] is not None else "0.0018[m h-1]",
         gw_Expon=row["mode.Expon"],
         gw_storage="{:.5}".format(gw_storage_ratio),
@@ -92,7 +89,6 @@ def make_cfe_config(
 def make_noahowp_config(
     base_dir: Path, divide_conf_df: pandas.DataFrame, start_time: datetime, end_time: datetime
 ) -> None:
-
     divide_conf_df.set_index("divide_id", inplace=True)
     start_datetime = start_time.strftime("%Y%m%d%H%M")
     end_datetime = end_time.strftime("%Y%m%d%H%M")
@@ -110,15 +106,15 @@ def make_noahowp_config(
                 end_datetime=end_datetime,
                 lat=divide_conf_df.loc[divide, "latitude"],
                 lon=divide_conf_df.loc[divide, "longitude"],
-                terrain_slope=
-                azimuth=
-                ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),
-                IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),
+                terrain_slope=divide_conf_df.loc[divide, "mean.slope_1km"],
+                azimuth=divide_conf_df.loc[divide, "circ_mean.aspect"],
+                ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),  # type: ignore
+                IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),  # type: ignore
             )
         )


-def get_model_attributes_modspatialite(hydrofabric: Path):
+def get_model_attributes_modspatialite(hydrofabric: Path) -> pandas.DataFrame:
     # modspatialite is faster than pyproj but can't be added as a pip dependency
     # This incantation took a while
     with GeoPackage(hydrofabric) as conn:
@@ -149,7 +145,7 @@ def get_model_attributes_modspatialite(hydrofabric: Path):
     return divide_conf_df


-def get_model_attributes_pyproj(hydrofabric: Path):
+def get_model_attributes_pyproj(hydrofabric: Path) -> pandas.DataFrame:
     # if modspatialite is not available, use pyproj
     with sqlite3.connect(hydrofabric) as conn:
         sql = """
@@ -182,7 +178,8 @@ def get_model_attributes_pyproj(hydrofabric: Path):

     return divide_conf_df

-def get_model_attributes(hydrofabric: Path):
+
+def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
     try:
         with GeoPackage(hydrofabric) as conn:
             conf_df = pandas.read_sql_query(
@@ -205,30 +202,31 @@ def get_model_attributes(hydrofabric: Path):
             )
     except sqlite3.OperationalError:
         with sqlite3.connect(hydrofabric) as conn:
-            conf_df = pandas.read_sql_query(
+            conf_df = pandas.read_sql_query(
+                "SELECT* FROM 'divide-attributes';",
+                conn,
+            )
     source_crs = get_table_crs_short(hydrofabric, "divides")
     transformer = Transformer.from_crs(source_crs, "EPSG:4326", always_xy=True)
-    lon, lat = transformer.transform(
-        conf_df["centroid_x"].values, conf_df["centroid_y"].values
-    )
+    lon, lat = transformer.transform(conf_df["centroid_x"].values, conf_df["centroid_y"].values)
     conf_df["longitude"] = lon
     conf_df["latitude"] = lat

     conf_df.drop(columns=["centroid_x", "centroid_y"], axis=1, inplace=True)
     return conf_df

+
 def make_em_config(
     hydrofabric: Path,
     output_dir: Path,
     template_path: Path = file_paths.template_em_config,
 ):
-
     # test if modspatialite is available
     try:
         divide_conf_df = get_model_attributes_modspatialite(hydrofabric)
     except Exception as e:
         logger.warning(f"mod_spatialite not available, using pyproj instead: {e}")
-        logger.warning(
+        logger.warning("Install mod_spatialite for improved performance")
         divide_conf_df = get_model_attributes_pyproj(hydrofabric)

     cat_config_dir = output_dir / "cat_config" / "empirical_model"
@@ -255,8 +253,7 @@ def make_em_config(

 def configure_troute(
     cat_id: str, config_dir: Path, start_time: datetime, end_time: datetime
-) ->
-
+) -> None:
     with open(file_paths.template_troute_config, "r") as file:
         troute_template = file.read()
     time_step_size = 300
@@ -269,7 +266,7 @@ def configure_troute(
         geo_file_path=f"./config/{cat_id}_subset.gpkg",
         start_datetime=start_time.strftime("%Y-%m-%d %H:%M:%S"),
         nts=nts,
-        max_loop_size=nts,
+        max_loop_size=nts,
     )

     with open(config_dir / "troute.yaml", "w") as file:
@@ -301,9 +298,7 @@ def create_em_realization(cat_id: str, start_time: datetime, end_time: datetime)
         f.write(em_config)

     configure_troute(cat_id, paths.config_dir, start_time, end_time)
-    make_ngen_realization_json(
-        paths.config_dir, template_path, start_time, end_time
-    )
+    make_ngen_realization_json(paths.config_dir, template_path, start_time, end_time)
     make_em_config(paths.geopackage_path, paths.config_dir)
     # create some partitions for parallelization
     paths.setup_run_folders()
@@ -315,7 +310,7 @@ def create_realization(
     start_time: datetime,
     end_time: datetime,
     use_nwm_gw: bool = False,
-    gage_id: str = None,
+    gage_id: Optional[str] = None,
 ):
     paths = file_paths(cat_id)

@@ -324,15 +319,14 @@ def create_realization(
     if gage_id is not None:
         # try and download s3:communityhydrofabric/hydrofabrics/community/gage_parameters/gage_id
         # if it doesn't exist, use the default
-
-
-
+        url = f"https://communityhydrofabric.s3.us-east-1.amazonaws.com/hydrofabrics/community/gage_parameters/{gage_id}.json"
+        response = requests.get(url)
+        if response.status_code == 200:
             new_template = requests.get(url).json()
-            template_path = paths.config_dir / "
+            template_path = paths.config_dir / "downloaded_params.json"
             with open(template_path, "w") as f:
                 json.dump(new_template, f)
-
-            logger.warning(f"Failed to download gage parameters")
+            logger.info(f"downloaded calibrated parameters for {gage_id}")

     conf_df = get_model_attributes(paths.geopackage_path)

@@ -347,21 +341,19 @@ def create_realization(

     configure_troute(cat_id, paths.config_dir, start_time, end_time)

-    make_ngen_realization_json(
-        paths.config_dir, template_path, start_time, end_time
-    )
+    make_ngen_realization_json(paths.config_dir, template_path, start_time, end_time)

     # create some partitions for parallelization
     paths.setup_run_folders()
     create_partitions(paths)


-def create_partitions(paths:
+def create_partitions(paths: file_paths, num_partitions: Optional[int] = None) -> None:
     if num_partitions is None:
         num_partitions = multiprocessing.cpu_count()

     cat_to_nex_pairs = get_cat_to_nex_flowpairs(hydrofabric=paths.geopackage_path)
-    nexus = defaultdict(list)
+    # nexus = defaultdict(list)

     # for cat, nex in cat_to_nex_pairs:
     #     nexus[nex].append(cat)
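The gage-parameter change above only swaps in the downloaded template when the request succeeds, instead of warning after the fact. A minimal standalone sketch of that fallback pattern (the URL layout is taken from the diff; fetch_gage_parameters and its arguments are illustrative, not part of the package):

    import json
    from pathlib import Path

    import requests


    def fetch_gage_parameters(gage_id: str, config_dir: Path, default_template: Path) -> Path:
        """Prefer calibrated gage parameters, falling back to the packaged template."""
        url = (
            "https://communityhydrofabric.s3.us-east-1.amazonaws.com/"
            f"hydrofabrics/community/gage_parameters/{gage_id}.json"
        )
        response = requests.get(url)
        if response.status_code != 200:
            # no calibrated parameters for this gage: keep the default template
            return default_template
        downloaded = config_dir / "downloaded_params.json"
        with open(downloaded, "w") as f:
            json.dump(response.json(), f)
        return downloaded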
data_processing/dask_utils.py
ADDED
@@ -0,0 +1,92 @@
+import logging
+
+from dask.distributed import Client
+
+logger = logging.getLogger(__name__)
+
+
+def shutdown_cluster():
+    try:
+        client = Client.current()
+        client.shutdown()
+    except ValueError:
+        logger.debug("No cluster found to shutdown")
+
+
+def no_cluster(func):
+    """
+    Decorator that ensures the wrapped function runs with no active Dask cluster.
+
+    This decorator attempts to shut down any existing Dask cluster before
+    executing the wrapped function. If no cluster is found, it logs a debug message
+    and continues execution.
+
+    Parameters:
+        func: The function to be executed without a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function that will be executed without a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        shutdown_cluster()
+        result = func(*args, **kwargs)
+        return result
+
+    return wrapper
+
+
+def use_cluster(func):
+    """
+    Decorator that ensures the wrapped function has access to a Dask cluster.
+
+    If a Dask cluster is already running, it uses the existing one.
+    If no cluster is available, it creates a new one before executing the function.
+    The cluster remains active after the function completes.
+
+    Parameters:
+        func: The function to be executed with a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function with access to a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        try:
+            client = Client.current()
+        except ValueError:
+            client = Client()
+        result = func(*args, **kwargs)
+        return result
+
+    return wrapper
+
+
+def temp_cluster(func):
+    """
+    Decorator that provides a temporary Dask cluster for the wrapped function.
+
+    If a Dask cluster is already running, it uses the existing one and leaves it running.
+    If no cluster exists, it creates a temporary one and shuts it down after
+    the function completes.
+
+    Parameters:
+        func: The function to be executed with a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function with access to a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        cluster_was_running = True
+        try:
+            client = Client.current()
+        except ValueError:
+            cluster_was_running = False
+            client = Client()
+        result = func(*args, **kwargs)
+        if not cluster_was_running:
+            client.shutdown()
+        return result
+
+    return wrapper
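The three decorators above differ only in cluster lifecycle: no_cluster shuts down any running cluster before the call, use_cluster attaches to an existing cluster or starts one that stays up, and temp_cluster starts a cluster only if none exists and shuts down only the one it created. A minimal usage sketch (the decorated functions are illustrative, not from the package):

    import xarray as xr

    from data_processing.dask_utils import temp_cluster, use_cluster


    @use_cluster
    def open_forcings(path: str) -> xr.Dataset:
        # runs on an existing Dask cluster, or starts one that stays alive afterwards
        return xr.open_mfdataset(path, parallel=True, engine="h5netcdf")


    @temp_cluster
    def quick_mean(path: str) -> float:
        # starts a cluster only if none exists and shuts that one down when done
        ds = xr.open_mfdataset(path, parallel=True, engine="h5netcdf")
        return float(ds.to_array().mean().compute())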
data_processing/dataset_utils.py
CHANGED
@@ -1,19 +1,22 @@
 import logging
 import os
+from datetime import datetime
 from pathlib import Path
-from typing import Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union

 import geopandas as gpd
 import numpy as np
 import xarray as xr
-from
-import
+from xarray.core.types import InterpOptions
+from dask.distributed import Client, progress, Future
+from data_processing.dask_utils import use_cluster

 logger = logging.getLogger(__name__)

 # known ngen variable names
 # https://github.com/CIROH-UA/ngen/blob/4fb5bb68dc397298bca470dfec94db2c1dcb42fe/include/forcing/AorcForcing.hpp#L77

+
 def validate_dataset_format(dataset: xr.Dataset) -> None:
     """
     Validate the format of the dataset.
@@ -41,8 +44,9 @@ def validate_dataset_format(dataset: xr.Dataset) -> None:
     if "name" not in dataset.attrs:
         raise ValueError("Dataset must have a name attribute to identify it")

+
 def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) -> Tuple[str, str]:
-
+    """
     Ensure that all selected times are in the passed dataset.

     Parameters
@@ -60,7 +64,7 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->
         start_time, or if not available, earliest available timestep in dataset.
     str
         end_time, or if not available, latest available timestep in dataset.
-
+    """
     end_time_in_dataset = dataset.time.isel(time=-1).values
     start_time_in_dataset = dataset.time.isel(time=0).values
     if np.datetime64(start_time) < start_time_in_dataset:
@@ -77,7 +81,10 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->


 def clip_dataset_to_bounds(
-    dataset: xr.Dataset,
+    dataset: xr.Dataset,
+    bounds: Tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]],
+    start_time: str,
+    end_time: str,
 ) -> xr.Dataset:
     """
     Clip the dataset to specified geographical bounds.
@@ -86,14 +93,14 @@ def clip_dataset_to_bounds(
     ----------
     dataset : xr.Dataset
         Dataset to be clipped.
-    bounds : tuple[float, float, float, float]
-        Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
+    bounds : tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]]
+        Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
         bounds[2] is x_max, bounds[3] is y_max.
     start_time : str
         Desired start time in YYYY/MM/DD HH:MM:SS format.
     end_time : str
         Desired end time in YYYY/MM/DD HH:MM:SS format.
-
+
     Returns
     -------
     xr.Dataset
@@ -110,33 +117,103 @@
     return dataset


-def
-
-
+def interpolate_nan_values(
+    dataset: xr.Dataset,
+    variables: Optional[List[str]] = None,
+    dim: str = "time",
+    method: InterpOptions = "nearest",
+    fill_value: str = "extrapolate",
+) -> None:
+    """
+    Interpolates NaN values in specified (or all numeric time-dependent)
+    variables of an xarray.Dataset. Operates inplace on the dataset.

-
-
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        The input dataset.
+    variables : Optional[List[str]], optional
+        A list of variable names to process. If None (default),
+        all numeric variables containing the specified dimension will be processed.
+    dim : str, optional
+        The dimension along which to interpolate (default is "time").
+    method : str, optional
+        Interpolation method to use (e.g., "linear", "nearest", "cubic").
+        Default is "nearest".
+    fill_value : str, optional
+        Method for filling NaNs at the start/end of the series after interpolation.
+        Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
+        Default is "extrapolate".
+    """
+    for name, var in dataset.data_vars.items():
+        # if the variable is non-numeric, skip
+        if not np.issubdtype(var.dtype, np.number):
+            continue
+        # if there are no NANs, skip
+        if not var.isnull().any().compute():
+            continue
+
+        dataset[name] = var.interpolate_na(
+            dim=dim,
+            method=method,
+            fill_value=fill_value if method in ["nearest", "linear"] else None,
+        )

-    # sort of terrible work around for half downloaded files
-    temp_path = cached_nc_path.with_suffix(".downloading.nc")
-    if os.path.exists(temp_path):
-        os.remove(temp_path)

-
-
-
-
+@use_cluster
+def save_dataset(
+    ds_to_save: xr.Dataset,
+    target_path: Path,
+    engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+):
+    """
+    Helper function to compute and save an xarray.Dataset to a NetCDF file.
+    Uses a temporary file and rename for atomicity.
+    """
+    if not target_path.parent.exists():
+        target_path.parent.mkdir(parents=True, exist_ok=True)
+
+    temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+    if temp_file_path.exists():
+        os.remove(temp_file_path)

     client = Client.current()
-    future = client.compute(
-
+    future: Future = client.compute(
+        ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=False)
+    )  # type: ignore
+    logger.debug(
+        f"NetCDF write task submitted to Dask. Waiting for completion to {temp_file_path}..."
+    )
     progress(future)
     future.result()
+    os.rename(str(temp_file_path), str(target_path))
+    logger.info(f"Successfully saved data to: {target_path}")
+
+
+@use_cluster
+def save_to_cache(
+    stores: xr.Dataset, cached_nc_path: Path, interpolate_nans: bool = True
+) -> xr.Dataset:
+    """
+    Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth.
+    """
+    logger.info(f"Processing dataset for caching. Final cache target: {cached_nc_path}")

-
+    # lasily cast all numbers to f32
+    for name, var in stores.data_vars.items():
+        if np.issubdtype(var.dtype, np.number):
+            stores[name] = var.astype("float32", casting="same_kind")

-
-
+    # save dataset locally before manipulating it
+    save_dataset(stores, cached_nc_path)
+    stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+
+    if interpolate_nans:
+        interpolate_nan_values(dataset=stores)
+        save_dataset(stores, cached_nc_path)
+        stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+
+    return stores


 def check_local_cache(
@@ -144,9 +221,8 @@ def check_local_cache(
     start_time: str,
     end_time: str,
     gdf: gpd.GeoDataFrame,
-    remote_dataset: xr.Dataset
+    remote_dataset: xr.Dataset,
 ) -> Union[xr.Dataset, None]:
-
     merged_data = None

     if not os.path.exists(cached_nc_path):
@@ -155,9 +231,7 @@ def check_local_cache(

     logger.info("Found cached nc file")
     # open the cached file and check that the time range is correct
-    cached_data = xr.open_mfdataset(
-        cached_nc_path, parallel=True, engine="h5netcdf"
-    )
+    cached_data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")

     if "name" not in cached_data.attrs or "name" not in remote_dataset.attrs:
         logger.warning("No name attribute found to compare datasets")
@@ -166,9 +240,9 @@ def check_local_cache(
         logger.warning("Cached data from different source, .name attr doesn't match")
         return

-    range_in_cache = cached_data.time[0].values <= np.datetime64(
-
-
+    range_in_cache = cached_data.time[0].values <= np.datetime64(start_time) and cached_data.time[
+        -1
+    ].values >= np.datetime64(end_time)

     if not range_in_cache:
         # the cache does not contain the desired time range
@@ -186,10 +260,8 @@ def check_local_cache(
     if range_in_cache:
         logger.info("Time range is within cached data")
         logger.debug(f"Opened cached nc file: [{cached_nc_path}]")
-        merged_data = clip_dataset_to_bounds(
-
-        )
-        logger.debug("Clipped stores")
+        merged_data = clip_dataset_to_bounds(cached_data, gdf.total_bounds, start_time, end_time)
+        logger.debug("Clipped stores")

     return merged_data

@@ -197,16 +269,27 @@ def check_local_cache(
 def save_and_clip_dataset(
     dataset: xr.Dataset,
     gdf: gpd.GeoDataFrame,
-    start_time: datetime
-    end_time: datetime
+    start_time: datetime,
+    end_time: datetime,
     cache_location: Path,
 ) -> xr.Dataset:
     """convenience function clip the remote dataset, and either load from cache or save to cache if it's not present"""
     gdf = gdf.to_crs(dataset.crs)

-    cached_data = check_local_cache(
+    cached_data = check_local_cache(
+        cache_location,
+        start_time,  # type: ignore
+        end_time,  # type: ignore
+        gdf,
+        dataset,
+    )

     if not cached_data:
-        clipped_data = clip_dataset_to_bounds(
+        clipped_data = clip_dataset_to_bounds(
+            dataset,
+            gdf.total_bounds,
+            start_time,  # type: ignore
+            end_time,  # type: ignore
+        )
         cached_data = save_to_cache(clipped_data, cache_location)
-    return cached_data
+    return cached_data