ngiab-data-preprocess 4.2.2-py3-none-any.whl → 4.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_processing/create_realization.py +39 -22
- data_processing/dask_utils.py +92 -0
- data_processing/dataset_utils.py +161 -44
- data_processing/datasets.py +18 -29
- data_processing/file_paths.py +7 -7
- data_processing/forcings.py +40 -38
- data_processing/gpkg_utils.py +13 -13
- data_processing/graph_utils.py +4 -6
- data_processing/s3fs_utils.py +1 -1
- data_processing/subset.py +39 -8
- data_sources/ngen-routing-template.yaml +1 -1
- data_sources/source_validation.py +72 -34
- map_app/__main__.py +3 -2
- map_app/static/css/main.css +14 -3
- map_app/static/js/data_processing.js +31 -55
- map_app/static/js/main.js +224 -106
- map_app/templates/index.html +10 -1
- map_app/views.py +17 -3
- ngiab_data_cli/__main__.py +32 -29
- ngiab_data_cli/arguments.py +0 -1
- ngiab_data_cli/forcing_cli.py +10 -19
- ngiab_data_preprocess-4.4.0.dist-info/METADATA +308 -0
- ngiab_data_preprocess-4.4.0.dist-info/RECORD +43 -0
- {ngiab_data_preprocess-4.2.2.dist-info → ngiab_data_preprocess-4.4.0.dist-info}/WHEEL +1 -1
- ngiab_data_preprocess-4.2.2.dist-info/METADATA +0 -258
- ngiab_data_preprocess-4.2.2.dist-info/RECORD +0 -42
- {ngiab_data_preprocess-4.2.2.dist-info → ngiab_data_preprocess-4.4.0.dist-info}/entry_points.txt +0 -0
- {ngiab_data_preprocess-4.2.2.dist-info → ngiab_data_preprocess-4.4.0.dist-info}/licenses/LICENSE +0 -0
- {ngiab_data_preprocess-4.2.2.dist-info → ngiab_data_preprocess-4.4.0.dist-info}/top_level.txt +0 -0
data_processing/create_realization.py
CHANGED

```diff
@@ -3,15 +3,17 @@ import logging
 import multiprocessing
 import shutil
 import sqlite3
-from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
+from typing import Dict, Optional
+import psutil
+import os

 import pandas
 import requests
 import s3fs
 import xarray as xr
-from
+from data_processing.dask_utils import temp_cluster
 from data_processing.file_paths import file_paths
 from data_processing.gpkg_utils import (
     GeoPackage,
@@ -25,7 +27,8 @@ from tqdm.rich import tqdm
 logger = logging.getLogger(__name__)


-def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
+@temp_cluster
+def get_approximate_gw_storage(paths: file_paths, start_date: datetime) -> Dict[str, int]:
     # get the gw levels from the NWM output on a given start date
     # this kind of works in place of warmstates for now
     year = start_date.strftime("%Y")
@@ -35,17 +38,10 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
     fs = s3fs.S3FileSystem(anon=True)
     nc_url = f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/GWOUT/{year}/{formatted_dt}.GWOUT_DOMAIN1"

-    # make sure there's a dask cluster running
-    try:
-        client = Client.current()
-    except ValueError:
-        cluster = LocalCluster()
-        client = Client(cluster)
-
     with fs.open(nc_url) as file_obj:
-        ds = xr.open_dataset(file_obj)
+        ds = xr.open_dataset(file_obj) # type: ignore

-    water_levels = dict()
+    water_levels: Dict[str, int] = dict()
     for cat, feature in tqdm(cat_to_feature.items()):
         # this value is in CM, we need meters to match max_gw_depth
         # xarray says it's in mm, with 0.1 scale factor. calling .values doesn't apply the scale
@@ -114,13 +110,13 @@ def make_noahowp_config(
             lon=divide_conf_df.loc[divide, "longitude"],
             terrain_slope=divide_conf_df.loc[divide, "mean.slope_1km"],
             azimuth=divide_conf_df.loc[divide, "circ_mean.aspect"],
-            ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),
-            IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),
+            ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]), # type: ignore
+            IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]), # type: ignore
         )
     )


-def get_model_attributes_modspatialite(hydrofabric: Path):
+def get_model_attributes_modspatialite(hydrofabric: Path) -> pandas.DataFrame:
     # modspatialite is faster than pyproj but can't be added as a pip dependency
     # This incantation took a while
     with GeoPackage(hydrofabric) as conn:
@@ -151,7 +147,7 @@ def get_model_attributes_modspatialite(hydrofabric: Path):
     return divide_conf_df


-def get_model_attributes_pyproj(hydrofabric: Path):
+def get_model_attributes_pyproj(hydrofabric: Path) -> pandas.DataFrame:
     # if modspatialite is not available, use pyproj
     with sqlite3.connect(hydrofabric) as conn:
         sql = """
@@ -185,7 +181,7 @@ def get_model_attributes_pyproj(hydrofabric: Path):
     return divide_conf_df


-def get_model_attributes(hydrofabric: Path):
+def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
     try:
         with GeoPackage(hydrofabric) as conn:
             conf_df = pandas.read_sql_query(
@@ -259,11 +255,31 @@ def make_em_config(

 def configure_troute(
     cat_id: str, config_dir: Path, start_time: datetime, end_time: datetime
-) ->
+) -> None:
     with open(file_paths.template_troute_config, "r") as file:
         troute_template = file.read()
     time_step_size = 300
+    gpkg_file_path=f"{config_dir}/{cat_id}_subset.gpkg"
     nts = (end_time - start_time).total_seconds() / time_step_size
+    with sqlite3.connect(gpkg_file_path) as conn:
+        ncats_df = pandas.read_sql_query("SELECT COUNT(id) FROM 'divides';", conn)
+        ncats = ncats_df['COUNT(id)'][0]
+
+    est_bytes_required = nts * ncats * 45 # extremely rough calculation based on about 3 tests :)
+    local_ram_available = 0.8 * psutil.virtual_memory().available # buffer to not accidentally explode machine
+
+    if est_bytes_required > local_ram_available:
+        max_loop_size = nts // (est_bytes_required // local_ram_available)
+        binary_nexus_file_folder_comment = ""
+        parent_dir = config_dir.parent
+        output_parquet_path = Path(f"{parent_dir}/outputs/parquet/")
+
+        if not output_parquet_path.exists():
+            os.makedirs(output_parquet_path)
+    else:
+        max_loop_size = nts
+        binary_nexus_file_folder_comment = "#"
+
     filled_template = troute_template.format(
         # hard coded to 5 minutes
         time_step_size=time_step_size,
@@ -272,7 +288,8 @@ def configure_troute(
         geo_file_path=f"./config/{cat_id}_subset.gpkg",
         start_datetime=start_time.strftime("%Y-%m-%d %H:%M:%S"),
         nts=nts,
-        max_loop_size=
+        max_loop_size=max_loop_size,
+        binary_nexus_file_folder_comment=binary_nexus_file_folder_comment
     )

     with open(config_dir / "troute.yaml", "w") as file:
@@ -316,7 +333,7 @@ def create_realization(
     start_time: datetime,
     end_time: datetime,
     use_nwm_gw: bool = False,
-    gage_id: str = None,
+    gage_id: Optional[str] = None,
 ):
     paths = file_paths(cat_id)

@@ -354,12 +371,12 @@ def create_realization(
     create_partitions(paths)


-def create_partitions(paths:
+def create_partitions(paths: file_paths, num_partitions: Optional[int] = None) -> None:
     if num_partitions is None:
         num_partitions = multiprocessing.cpu_count()

     cat_to_nex_pairs = get_cat_to_nex_flowpairs(hydrofabric=paths.geopackage_path)
-    nexus = defaultdict(list)
+    # nexus = defaultdict(list)

     # for cat, nex in cat_to_nex_pairs:
     #     nexus[nex].append(cat)
```
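The new block in `configure_troute` sizes t-route's `max_loop_size` from a rough memory estimate: about 45 bytes per timestep per divide, compared against 80% of the RAM `psutil` reports as available. If the whole run fits, `max_loop_size` stays at `nts` and the parquet output line stays commented out (`binary_nexus_file_folder_comment = "#"`); otherwise the run is split into memory-sized loops and an `outputs/parquet/` directory is created for streamed output. A minimal sketch of that arithmetic, using made-up run sizes and RAM figures rather than anything from the package:

```python
def estimate_max_loop_size(nts: float, ncats: int, available_bytes: int) -> float:
    """Mirror the heuristic added to configure_troute: ~45 bytes per timestep per divide,
    checked against 80% of the reported available RAM."""
    est_bytes_required = nts * ncats * 45        # very rough per-value cost
    local_ram_available = 0.8 * available_bytes  # keep a safety buffer
    if est_bytes_required > local_ram_available:
        # split the run into roughly memory-sized loops
        return nts // (est_bytes_required // local_ram_available)
    return nts


# hypothetical numbers: a 1-year run at 300 s timesteps over 20,000 divides on a 16 GB machine
nts = 365 * 24 * 3600 / 300  # 105,120 timesteps
print(estimate_max_loop_size(nts, ncats=20_000, available_bytes=16 * 2**30))  # -> 17520.0
```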
data_processing/dask_utils.py
ADDED

```diff
@@ -0,0 +1,92 @@
+import logging
+
+from dask.distributed import Client
+
+logger = logging.getLogger(__name__)
+
+
+def shutdown_cluster():
+    try:
+        client = Client.current()
+        client.shutdown()
+    except ValueError:
+        logger.debug("No cluster found to shutdown")
+
+
+def no_cluster(func):
+    """
+    Decorator that ensures the wrapped function runs with no active Dask cluster.
+
+    This decorator attempts to shut down any existing Dask cluster before
+    executing the wrapped function. If no cluster is found, it logs a debug message
+    and continues execution.
+
+    Parameters:
+        func: The function to be executed without a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function that will be executed without a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        shutdown_cluster()
+        result = func(*args, **kwargs)
+        return result
+
+    return wrapper
+
+
+def use_cluster(func):
+    """
+    Decorator that ensures the wrapped function has access to a Dask cluster.
+
+    If a Dask cluster is already running, it uses the existing one.
+    If no cluster is available, it creates a new one before executing the function.
+    The cluster remains active after the function completes.
+
+    Parameters:
+        func: The function to be executed with a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function with access to a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        try:
+            client = Client.current()
+        except ValueError:
+            client = Client()
+        result = func(*args, **kwargs)
+        return result
+
+    return wrapper
+
+
+def temp_cluster(func):
+    """
+    Decorator that provides a temporary Dask cluster for the wrapped function.
+
+    If a Dask cluster is already running, it uses the existing one and leaves it running.
+    If no cluster exists, it creates a temporary one and shuts it down after
+    the function completes.
+
+    Parameters:
+        func: The function to be executed with a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function with access to a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        cluster_was_running = True
+        try:
+            client = Client.current()
+        except ValueError:
+            cluster_was_running = False
+            client = Client()
+        result = func(*args, **kwargs)
+        if not cluster_was_running:
+            client.shutdown()
+        return result
+
+    return wrapper
```
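The three decorators differ only in cluster lifecycle: `no_cluster` shuts down any running cluster before the call, `use_cluster` starts one if needed and leaves it running afterwards, and `temp_cluster` starts one only for the call and tears it down again if it had to create it. A small usage sketch; the decorated functions below are illustrative stand-ins, not part of the package:

```python
from dask.distributed import Client

from data_processing.dask_utils import no_cluster, temp_cluster, use_cluster


@use_cluster
def lazy_open():
    # a cluster is guaranteed to exist here and stays up after the call returns
    return Client.current()


@temp_cluster
def one_off_job():
    # if this call had to start a cluster, it is shut down again on return
    return Client.current().scheduler_info()["workers"]


@no_cluster
def plain_write():
    # any running cluster has been shut down before this body executes
    return "wrote without dask.distributed"


if __name__ == "__main__":
    lazy_open()    # starts a cluster and leaves it running
    one_off_job()  # reuses the running cluster, so it is left alone
    plain_write()  # tears the cluster down first
```

None of the wrappers use `functools.wraps` or pass the client in explicitly; decorated functions are expected to pick up the active cluster implicitly through `Client.current()`, as `get_approximate_gw_storage` above and `save_dataset` below do.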
data_processing/dataset_utils.py
CHANGED
```diff
@@ -1,19 +1,22 @@
 import logging
 import os
+from datetime import datetime
 from pathlib import Path
-from typing import Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union

 import geopandas as gpd
 import numpy as np
 import xarray as xr
-from dask.distributed import Client,
-import
+from dask.distributed import Client, Future, progress
+from data_processing.dask_utils import no_cluster, temp_cluster
+from xarray.core.types import InterpOptions

 logger = logging.getLogger(__name__)

 # known ngen variable names
 # https://github.com/CIROH-UA/ngen/blob/4fb5bb68dc397298bca470dfec94db2c1dcb42fe/include/forcing/AorcForcing.hpp#L77

+
 def validate_dataset_format(dataset: xr.Dataset) -> None:
     """
     Validate the format of the dataset.
@@ -41,8 +44,9 @@ def validate_dataset_format(dataset: xr.Dataset) -> None:
     if "name" not in dataset.attrs:
         raise ValueError("Dataset must have a name attribute to identify it")

+
 def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) -> Tuple[str, str]:
-
+    """
     Ensure that all selected times are in the passed dataset.

     Parameters
@@ -60,7 +64,7 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->
         start_time, or if not available, earliest available timestep in dataset.
     str
         end_time, or if not available, latest available timestep in dataset.
-
+    """
     end_time_in_dataset = dataset.time.isel(time=-1).values
     start_time_in_dataset = dataset.time.isel(time=0).values
     if np.datetime64(start_time) < start_time_in_dataset:
@@ -77,7 +81,10 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->


 def clip_dataset_to_bounds(
-    dataset: xr.Dataset,
+    dataset: xr.Dataset,
+    bounds: Tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]],
+    start_time: str,
+    end_time: str,
 ) -> xr.Dataset:
     """
     Clip the dataset to specified geographical bounds.
@@ -86,14 +93,14 @@ def clip_dataset_to_bounds(
     ----------
     dataset : xr.Dataset
         Dataset to be clipped.
-    bounds : tuple[float, float, float, float]
-        Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
+    bounds : tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]]
+        Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
         bounds[2] is x_max, bounds[3] is y_max.
     start_time : str
         Desired start time in YYYY/MM/DD HH:MM:SS format.
     end_time : str
         Desired end time in YYYY/MM/DD HH:MM:SS format.
-
+
     Returns
     -------
     xr.Dataset
@@ -110,33 +117,137 @@ def clip_dataset_to_bounds(
     return dataset


-
-
-
+@no_cluster
+def interpolate_nan_values(
+    dataset: xr.Dataset,
+    variables: Optional[List[str]] = None,
+    dim: str = "time",
+    method: InterpOptions = "nearest",
+    fill_value: str = "extrapolate",
+) -> bool:
+    """
+    Interpolates NaN values in specified (or all numeric time-dependent)
+    variables of an xarray.Dataset. Operates inplace on the dataset.
+
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        The input dataset.
+    variables : Optional[List[str]], optional
+        A list of variable names to process. If None (default),
+        all numeric variables containing the specified dimension will be processed.
+    dim : str, optional
+        The dimension along which to interpolate (default is "time").
+    method : str, optional
+        Interpolation method to use (e.g., "linear", "nearest", "cubic").
+        Default is "nearest".
+    fill_value : str, optional
+        Method for filling NaNs at the start/end of the series after interpolation.
+        Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
+        Default is "extrapolate".
+    """
+    interpolation_used = False
+    for name, var in dataset.data_vars.items():
+        # if the variable is non-numeric, skip
+        if not np.issubdtype(var.dtype, np.number):
+            continue
+        # if there are no NANs, skip
+        if not var.isnull().any().compute():
+            continue
+
+        dataset[name] = var.interpolate_na(
+            dim=dim,
+            method=method,
+            fill_value=fill_value if method in ["nearest", "linear"] else None,
+        )
+        interpolation_used = True
+    return interpolation_used
+
+
+@no_cluster
+def save_dataset_no_cluster(
+    ds_to_save: xr.Dataset,
+    target_path: Path,
+    engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+):
+    """
+    This explicitly does not use dask distributed.
+    Helper function to compute and save an xarray.Dataset to a NetCDF file.
+    Uses a temporary file and rename for avoid leaving a half written file.
+    """
+    if not target_path.parent.exists():
+        target_path.parent.mkdir(parents=True, exist_ok=True)

-
-
+    temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+    if temp_file_path.exists():
+        os.remove(temp_file_path)

-
-    temp_path = cached_nc_path.with_suffix(".downloading.nc")
-    if os.path.exists(temp_path):
-        os.remove(temp_path)
+    ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=True)

-
-
-
-
+    os.rename(str(temp_file_path), str(target_path))
+    logger.info(f"Successfully saved data to: {target_path}")
+
+
+@temp_cluster
+def save_dataset(
+    ds_to_save: xr.Dataset,
+    target_path: Path,
+    engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+):
+    """
+    Helper function to compute and save an xarray.Dataset to a NetCDF file.
+    Uses a temporary file and rename for atomicity.
+    """
+    if not target_path.parent.exists():
+        target_path.parent.mkdir(parents=True, exist_ok=True)
+
+    temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+    if temp_file_path.exists():
+        os.remove(temp_file_path)

     client = Client.current()
-    future = client.compute(
-
+    future: Future = client.compute(
+        ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=False)
+    ) # type: ignore
+    logger.debug(
+        f"NetCDF write task submitted to Dask. Waiting for completion to {temp_file_path}..."
+    )
+    logger.info("For more detailed progress, see the Dask dashboard http://localhost:8787/status")
     progress(future)
     future.result()
+    os.rename(str(temp_file_path), str(target_path))
+    logger.info(f"Successfully saved data to: {target_path}")
+

-
+@no_cluster
+def save_to_cache(
+    stores: xr.Dataset, cached_nc_path: Path, interpolate_nans: bool = True
+) -> xr.Dataset:
+    """
+    Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth.
+    """
+    logger.debug(f"Processing dataset for caching. Final cache target: {cached_nc_path}")

-
-
+    # lasily cast all numbers to f32
+    for name, var in stores.data_vars.items():
+        if np.issubdtype(var.dtype, np.number):
+            stores[name] = var.astype("float32", casting="same_kind")
+
+    # save dataset locally before manipulating it
+    save_dataset(stores, cached_nc_path)
+
+    if interpolate_nans:
+        stores = xr.open_mfdataset(
+            cached_nc_path,
+            parallel=True,
+            engine="h5netcdf",
+        )
+        was_interpolated = interpolate_nan_values(dataset=stores)
+        if was_interpolated:
+            save_dataset_no_cluster(stores, cached_nc_path)
+
+        stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+    return stores


 def check_local_cache(
@@ -144,9 +255,8 @@ def check_local_cache(
     start_time: str,
     end_time: str,
     gdf: gpd.GeoDataFrame,
-    remote_dataset: xr.Dataset
+    remote_dataset: xr.Dataset,
 ) -> Union[xr.Dataset, None]:
-
     merged_data = None

     if not os.path.exists(cached_nc_path):
@@ -155,9 +265,7 @@ def check_local_cache(

     logger.info("Found cached nc file")
     # open the cached file and check that the time range is correct
-    cached_data = xr.open_mfdataset(
-        cached_nc_path, parallel=True, engine="h5netcdf"
-    )
+    cached_data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")

     if "name" not in cached_data.attrs or "name" not in remote_dataset.attrs:
         logger.warning("No name attribute found to compare datasets")
@@ -166,9 +274,9 @@ def check_local_cache(
         logger.warning("Cached data from different source, .name attr doesn't match")
         return

-    range_in_cache = cached_data.time[0].values <= np.datetime64(
-
-
+    range_in_cache = cached_data.time[0].values <= np.datetime64(start_time) and cached_data.time[
+        -1
+    ].values >= np.datetime64(end_time)

     if not range_in_cache:
         # the cache does not contain the desired time range
@@ -186,10 +294,8 @@ def check_local_cache(
     if range_in_cache:
         logger.info("Time range is within cached data")
         logger.debug(f"Opened cached nc file: [{cached_nc_path}]")
-        merged_data = clip_dataset_to_bounds(
-
-        )
-        logger.debug("Clipped stores")
+        merged_data = clip_dataset_to_bounds(cached_data, gdf.total_bounds, start_time, end_time)
+        logger.debug("Clipped stores")

     return merged_data

@@ -197,16 +303,27 @@ def check_local_cache(
 def save_and_clip_dataset(
     dataset: xr.Dataset,
     gdf: gpd.GeoDataFrame,
-    start_time: datetime
-    end_time: datetime
+    start_time: datetime,
+    end_time: datetime,
     cache_location: Path,
 ) -> xr.Dataset:
     """convenience function clip the remote dataset, and either load from cache or save to cache if it's not present"""
     gdf = gdf.to_crs(dataset.crs)

-    cached_data = check_local_cache(
+    cached_data = check_local_cache(
+        cache_location,
+        start_time, # type: ignore
+        end_time, # type: ignore
+        gdf,
+        dataset,
+    )

     if not cached_data:
-        clipped_data = clip_dataset_to_bounds(
+        clipped_data = clip_dataset_to_bounds(
+            dataset,
+            gdf.total_bounds,
+            start_time, # type: ignore
+            end_time, # type: ignore
+        )
         cached_data = save_to_cache(clipped_data, cache_location)
-    return cached_data
+    return cached_data
```
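Both new save helpers avoid the partially written caches the old `.downloading.nc` path could leave behind: they write to a `<name>.saving.nc` sibling and only `os.rename` it over the target once the NetCDF write has finished, after which `save_to_cache` reopens the file, interpolates any NaNs in place, and rewrites it without the distributed cluster if anything was filled. A stripped-down sketch of the temp-file-and-rename pattern, independent of the package and assuming the default `h5netcdf` engine is installed:

```python
import os
from pathlib import Path

import numpy as np
import xarray as xr


def atomic_to_netcdf(ds: xr.Dataset, target: Path, engine: str = "h5netcdf") -> None:
    # write next to the target so the final rename stays on one filesystem
    target.parent.mkdir(parents=True, exist_ok=True)
    tmp = target.with_name(target.name + ".saving.nc")
    if tmp.exists():
        os.remove(tmp)
    ds.to_netcdf(tmp, engine=engine)
    # rename is atomic on POSIX, so readers never see a half-written cache
    os.rename(tmp, target)


ds = xr.Dataset({"t2d": ("time", np.arange(3, dtype="float32"))})
atomic_to_netcdf(ds, Path("example_cache.nc"))
```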
data_processing/datasets.py
CHANGED
```diff
@@ -1,35 +1,31 @@
 import logging
+from typing import Optional

 import s3fs
-from data_processing.s3fs_utils import S3ParallelFileSystem
 import xarray as xr
-from
+from data_processing.dask_utils import use_cluster
 from data_processing.dataset_utils import validate_dataset_format
-
+from data_processing.s3fs_utils import S3ParallelFileSystem

 logger = logging.getLogger(__name__)


-def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
+@use_cluster
+def load_v3_retrospective_zarr(forcing_vars: Optional[list[str]] = None) -> xr.Dataset:
     """Load zarr datasets from S3 within the specified time range."""
     # if a LocalCluster is not already running, start one
     if not forcing_vars:
         forcing_vars = ["lwdown", "precip", "psfc", "q2d", "swdown", "t2d", "u2d", "v2d"]
-    try:
-        client = Client.current()
-    except ValueError:
-        cluster = LocalCluster()
-        client = Client(cluster)
+
     s3_urls = [
-        f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr"
-        for var in forcing_vars
+        f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr" for var in forcing_vars
     ]
     # default cache is readahead which is detrimental to performance in this case
     fs = S3ParallelFileSystem(anon=True, default_cache_type="none") # default_block_size
     s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
     # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
     # most of the data is read once and written to disk but some of the coordinate data is read multiple times
-    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
+    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True) # type: ignore

     # set the crs attribute to conform with the format
     esri_pe_string = dataset.crs.esri_pe_string
@@ -54,7 +50,8 @@ def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
     return dataset


-def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
+@use_cluster
+def load_aorc_zarr(start_year: Optional[int] = None, end_year: Optional[int] = None) -> xr.Dataset:
     """Load the aorc zarr dataset from S3."""
     if not start_year or not end_year:
         logger.warning("No start or end year provided, defaulting to 1979-2023")
@@ -63,11 +60,6 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
         start_year = 1979
     if not end_year:
         end_year = 2023
-    try:
-        client = Client.current()
-    except ValueError:
-        cluster = LocalCluster()
-        client = Client(cluster)

     logger.info(f"Loading AORC zarr datasets from {start_year} to {end_year}")
     estimated_time_s = ((end_year - start_year) * 2.5) + 3.5
@@ -75,9 +67,9 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
     logger.info(f"This should take roughly {estimated_time_s} seconds")
     fs = S3ParallelFileSystem(anon=True, default_cache_type="none")
     s3_url = "s3://noaa-nws-aorc-v1-1-1km/"
-    urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year+1)]
+    urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year + 1)]
     filestores = [s3fs.S3Map(url, s3=fs) for url in urls]
-    dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True)
+    dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True) # type: ignore
     dataset.attrs["crs"] = "+proj=longlat +datum=WGS84 +no_defs"
     dataset.attrs["name"] = "aorc_1km_zarr"
     # rename latitude and longitude to x and y
@@ -87,32 +79,29 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
     return dataset


+@use_cluster
 def load_swe_zarr() -> xr.Dataset:
     """Load the swe zarr dataset from S3."""
-    s3_urls = [
-        f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"
-    ]
+    s3_urls = ["s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"]
     # default cache is readahead which is detrimental to performance in this case
     fs = S3ParallelFileSystem(anon=True, default_cache_type="none") # default_block_size
     s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
     # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
     # most of the data is read once and written to disk but some of the coordinate data is read multiple times
-    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
-
+    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True) # type: ignore
+
     # set the crs attribute to conform with the format
     esri_pe_string = dataset.crs.esri_pe_string
     dataset = dataset.drop_vars(["crs"])
     dataset.attrs["crs"] = esri_pe_string
     # drop everything except SNEQV
     vars_to_drop = list(dataset.data_vars)
-    vars_to_drop.remove(
+    vars_to_drop.remove("SNEQV")
     dataset = dataset.drop_vars(vars_to_drop)
     dataset.attrs["name"] = "v3_swe_zarr"

     # rename the data vars to work with ngen
-    variables = {
-        "SNEQV": "swe"
-    }
+    variables = {"SNEQV": "swe"}
     dataset = dataset.rename_vars(variables)

     validate_dataset_format(dataset)
```
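With `@use_cluster` applied, the loaders now manage their own Dask client and return lazy zarr-backed views. A hedged end-to-end sketch combining `load_v3_retrospective_zarr` with `save_and_clip_dataset` from `dataset_utils`; the geopackage path, dates, and cache file below are placeholders, not values from the package:

```python
from datetime import datetime
from pathlib import Path

import geopandas as gpd

from data_processing.dataset_utils import save_and_clip_dataset
from data_processing.datasets import load_v3_retrospective_zarr

# placeholder inputs: any polygon layer inside the CONUS forcing extent will do
divides = gpd.read_file("cat-12345_subset.gpkg", layer="divides")
cache_nc = Path("cache/forcings.nc")

remote = load_v3_retrospective_zarr()  # lazy view; @use_cluster leaves the cluster running
clipped = save_and_clip_dataset(
    remote,
    divides,
    datetime(2020, 1, 1),   # start_time
    datetime(2020, 1, 31),  # end_time
    cache_nc,
)
print(clipped)
```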