ngiab-data-preprocess 4.2.2__py3-none-any.whl → 4.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,15 +3,17 @@ import logging
  import multiprocessing
  import shutil
  import sqlite3
- from collections import defaultdict
  from datetime import datetime
  from pathlib import Path
+ from typing import Dict, Optional
+ import psutil
+ import os

  import pandas
  import requests
  import s3fs
  import xarray as xr
- from dask.distributed import Client, LocalCluster
+ from data_processing.dask_utils import temp_cluster
  from data_processing.file_paths import file_paths
  from data_processing.gpkg_utils import (
  GeoPackage,
@@ -25,7 +27,8 @@ from tqdm.rich import tqdm
  logger = logging.getLogger(__name__)


- def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
+ @temp_cluster
+ def get_approximate_gw_storage(paths: file_paths, start_date: datetime) -> Dict[str, int]:
  # get the gw levels from the NWM output on a given start date
  # this kind of works in place of warmstates for now
  year = start_date.strftime("%Y")
@@ -35,17 +38,10 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
  fs = s3fs.S3FileSystem(anon=True)
  nc_url = f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/GWOUT/{year}/{formatted_dt}.GWOUT_DOMAIN1"

- # make sure there's a dask cluster running
- try:
- client = Client.current()
- except ValueError:
- cluster = LocalCluster()
- client = Client(cluster)
-
  with fs.open(nc_url) as file_obj:
- ds = xr.open_dataset(file_obj)
+ ds = xr.open_dataset(file_obj) # type: ignore

- water_levels = dict()
+ water_levels: Dict[str, int] = dict()
  for cat, feature in tqdm(cat_to_feature.items()):
  # this value is in CM, we need meters to match max_gw_depth
  # xarray says it's in mm, with 0.1 scale factor. calling .values doesn't apply the scale
@@ -114,13 +110,13 @@ def make_noahowp_config(
  lon=divide_conf_df.loc[divide, "longitude"],
  terrain_slope=divide_conf_df.loc[divide, "mean.slope_1km"],
  azimuth=divide_conf_df.loc[divide, "circ_mean.aspect"],
- ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),
- IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),
+ ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]), # type: ignore
+ IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]), # type: ignore
  )
  )


- def get_model_attributes_modspatialite(hydrofabric: Path):
+ def get_model_attributes_modspatialite(hydrofabric: Path) -> pandas.DataFrame:
  # modspatialite is faster than pyproj but can't be added as a pip dependency
  # This incantation took a while
  with GeoPackage(hydrofabric) as conn:
@@ -151,7 +147,7 @@ def get_model_attributes_modspatialite(hydrofabric: Path):
  return divide_conf_df


- def get_model_attributes_pyproj(hydrofabric: Path):
+ def get_model_attributes_pyproj(hydrofabric: Path) -> pandas.DataFrame:
  # if modspatialite is not available, use pyproj
  with sqlite3.connect(hydrofabric) as conn:
  sql = """
@@ -185,7 +181,7 @@ def get_model_attributes_pyproj(hydrofabric: Path):
  return divide_conf_df


- def get_model_attributes(hydrofabric: Path):
+ def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
  try:
  with GeoPackage(hydrofabric) as conn:
  conf_df = pandas.read_sql_query(
@@ -259,11 +255,31 @@ def make_em_config(

  def configure_troute(
  cat_id: str, config_dir: Path, start_time: datetime, end_time: datetime
- ) -> int:
+ ) -> None:
  with open(file_paths.template_troute_config, "r") as file:
  troute_template = file.read()
  time_step_size = 300
+ gpkg_file_path = f"{config_dir}/{cat_id}_subset.gpkg"
  nts = (end_time - start_time).total_seconds() / time_step_size
+ with sqlite3.connect(gpkg_file_path) as conn:
+ ncats_df = pandas.read_sql_query("SELECT COUNT(id) FROM 'divides';", conn)
+ ncats = ncats_df['COUNT(id)'][0]
+
+ est_bytes_required = nts * ncats * 45 # extremely rough calculation based on about 3 tests :)
+ local_ram_available = 0.8 * psutil.virtual_memory().available # buffer to not accidentally explode machine
+
+ if est_bytes_required > local_ram_available:
+ max_loop_size = nts // (est_bytes_required // local_ram_available)
+ binary_nexus_file_folder_comment = ""
+ parent_dir = config_dir.parent
+ output_parquet_path = Path(f"{parent_dir}/outputs/parquet/")
+
+ if not output_parquet_path.exists():
+ os.makedirs(output_parquet_path)
+ else:
+ max_loop_size = nts
+ binary_nexus_file_folder_comment = "#"
+
  filled_template = troute_template.format(
  # hard coded to 5 minutes
  time_step_size=time_step_size,
@@ -272,7 +288,8 @@ def configure_troute(
  geo_file_path=f"./config/{cat_id}_subset.gpkg",
  start_datetime=start_time.strftime("%Y-%m-%d %H:%M:%S"),
  nts=nts,
- max_loop_size=nts,
+ max_loop_size=max_loop_size,
+ binary_nexus_file_folder_comment=binary_nexus_file_folder_comment
  )

  with open(config_dir / "troute.yaml", "w") as file:
@@ -316,7 +333,7 @@ def create_realization(
  start_time: datetime,
  end_time: datetime,
  use_nwm_gw: bool = False,
- gage_id: str = None,
+ gage_id: Optional[str] = None,
  ):
  paths = file_paths(cat_id)

@@ -354,12 +371,12 @@ def create_realization(
  create_partitions(paths)


- def create_partitions(paths: Path, num_partitions: int = None) -> None:
+ def create_partitions(paths: file_paths, num_partitions: Optional[int] = None) -> None:
  if num_partitions is None:
  num_partitions = multiprocessing.cpu_count()

  cat_to_nex_pairs = get_cat_to_nex_flowpairs(hydrofabric=paths.geopackage_path)
- nexus = defaultdict(list)
+ # nexus = defaultdict(list)

  # for cat, nex in cat_to_nex_pairs:
  # nexus[nex].append(cat)
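
The new memory check in configure_troute estimates the t-route output footprint as nts * ncats * 45 bytes and compares it to 80% of the currently available RAM; if the estimate does not fit, the run is broken into smaller loops and the parquet output folder is enabled by un-commenting the binary nexus file folder line in the template. A rough worked example of the same arithmetic, using a made-up run length and divide count rather than values from the package:

    import psutil

    # illustration only: a one-year run at 300 s timesteps over 500 divides
    nts = 365 * 24 * 3600 / 300                    # 105,120 routing timesteps
    ncats = 500                                    # divides counted from the subset geopackage
    est_bytes_required = nts * ncats * 45          # ~2.4 GB with the rough 45-bytes-per-value guess
    local_ram_available = 0.8 * psutil.virtual_memory().available

    if est_bytes_required > local_ram_available:
        # split the run into loops small enough to fit, as configure_troute does
        max_loop_size = nts // (est_bytes_required // local_ram_available)
    else:
        max_loop_size = nts

For the numbers above the estimate is about 2.4 GB, so any machine with more than roughly 3 GB of free RAM keeps max_loop_size equal to nts and skips the parquet fallback.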
@@ -0,0 +1,92 @@
+ import logging
+
+ from dask.distributed import Client
+
+ logger = logging.getLogger(__name__)
+
+
+ def shutdown_cluster():
+ try:
+ client = Client.current()
+ client.shutdown()
+ except ValueError:
+ logger.debug("No cluster found to shutdown")
+
+
+ def no_cluster(func):
+ """
+ Decorator that ensures the wrapped function runs with no active Dask cluster.
+
+ This decorator attempts to shut down any existing Dask cluster before
+ executing the wrapped function. If no cluster is found, it logs a debug message
+ and continues execution.
+
+ Parameters:
+ func: The function to be executed without a Dask cluster
+
+ Returns:
+ wrapper: The wrapped function that will be executed without a Dask cluster
+ """
+
+ def wrapper(*args, **kwargs):
+ shutdown_cluster()
+ result = func(*args, **kwargs)
+ return result
+
+ return wrapper
+
+
+ def use_cluster(func):
+ """
+ Decorator that ensures the wrapped function has access to a Dask cluster.
+
+ If a Dask cluster is already running, it uses the existing one.
+ If no cluster is available, it creates a new one before executing the function.
+ The cluster remains active after the function completes.
+
+ Parameters:
+ func: The function to be executed with a Dask cluster
+
+ Returns:
+ wrapper: The wrapped function with access to a Dask cluster
+ """
+
+ def wrapper(*args, **kwargs):
+ try:
+ client = Client.current()
+ except ValueError:
+ client = Client()
+ result = func(*args, **kwargs)
+ return result
+
+ return wrapper
+
+
+ def temp_cluster(func):
+ """
+ Decorator that provides a temporary Dask cluster for the wrapped function.
+
+ If a Dask cluster is already running, it uses the existing one and leaves it running.
+ If no cluster exists, it creates a temporary one and shuts it down after
+ the function completes.
+
+ Parameters:
+ func: The function to be executed with a Dask cluster
+
+ Returns:
+ wrapper: The wrapped function with access to a Dask cluster
+ """
+
+ def wrapper(*args, **kwargs):
+ cluster_was_running = True
+ try:
+ client = Client.current()
+ except ValueError:
+ cluster_was_running = False
+ client = Client()
+ result = func(*args, **kwargs)
+ if not cluster_was_running:
+ client.shutdown()
+ return result
+
+ return wrapper
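
The three decorators in the new dask_utils module encode a simple cluster policy: use_cluster guarantees a Dask client exists and leaves it running for later calls, temp_cluster creates one only if needed and shuts it down again afterwards, and no_cluster shuts any running client down before work that should stay in-process. A minimal usage sketch (the decorated functions below are hypothetical, not part of the package):

    from data_processing.dask_utils import no_cluster, temp_cluster, use_cluster

    @use_cluster
    def heavy_zarr_read():
        # runs with a shared Dask client that stays alive for later calls
        ...

    @temp_cluster
    def one_off_download():
        # gets a client if none exists and shuts it down again afterwards
        ...

    @no_cluster
    def local_only_postprocess():
        # any running cluster is shut down first; the work stays in-process
        ...

Because use_cluster leaves the client alive, repeated loader calls reuse the same scheduler instead of paying the startup cost on every call.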
@@ -1,19 +1,22 @@
  import logging
  import os
+ from datetime import datetime
  from pathlib import Path
- from typing import Tuple, Union
+ from typing import List, Literal, Optional, Tuple, Union

  import geopandas as gpd
  import numpy as np
  import xarray as xr
- from dask.distributed import Client, progress
- import datetime
+ from dask.distributed import Client, Future, progress
+ from data_processing.dask_utils import no_cluster, temp_cluster
+ from xarray.core.types import InterpOptions

  logger = logging.getLogger(__name__)

  # known ngen variable names
  # https://github.com/CIROH-UA/ngen/blob/4fb5bb68dc397298bca470dfec94db2c1dcb42fe/include/forcing/AorcForcing.hpp#L77

+
  def validate_dataset_format(dataset: xr.Dataset) -> None:
  """
  Validate the format of the dataset.
@@ -41,8 +44,9 @@ def validate_dataset_format(dataset: xr.Dataset) -> None:
  if "name" not in dataset.attrs:
  raise ValueError("Dataset must have a name attribute to identify it")

+
  def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) -> Tuple[str, str]:
- '''
+ """
  Ensure that all selected times are in the passed dataset.

  Parameters
@@ -60,7 +64,7 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->
  start_time, or if not available, earliest available timestep in dataset.
  str
  end_time, or if not available, latest available timestep in dataset.
- '''
+ """
  end_time_in_dataset = dataset.time.isel(time=-1).values
  start_time_in_dataset = dataset.time.isel(time=0).values
  if np.datetime64(start_time) < start_time_in_dataset:
@@ -77,7 +81,10 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->


  def clip_dataset_to_bounds(
- dataset: xr.Dataset, bounds: Tuple[float, float, float, float], start_time: str, end_time: str
+ dataset: xr.Dataset,
+ bounds: Tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]],
+ start_time: str,
+ end_time: str,
  ) -> xr.Dataset:
  """
  Clip the dataset to specified geographical bounds.
@@ -86,14 +93,14 @@ def clip_dataset_to_bounds(
  ----------
  dataset : xr.Dataset
  Dataset to be clipped.
- bounds : tuple[float, float, float, float]
- Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
+ bounds : tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]]
+ Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
  bounds[2] is x_max, bounds[3] is y_max.
  start_time : str
  Desired start time in YYYY/MM/DD HH:MM:SS format.
  end_time : str
  Desired end time in YYYY/MM/DD HH:MM:SS format.
-
+
  Returns
  -------
  xr.Dataset
@@ -110,33 +117,137 @@ def clip_dataset_to_bounds(
  return dataset


- def save_to_cache(stores: xr.Dataset, cached_nc_path: Path) -> xr.Dataset:
- """Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth."""
- logger.info("Downloading and caching forcing data, this may take a while")
+ @no_cluster
+ def interpolate_nan_values(
+ dataset: xr.Dataset,
+ variables: Optional[List[str]] = None,
+ dim: str = "time",
+ method: InterpOptions = "nearest",
+ fill_value: str = "extrapolate",
+ ) -> bool:
+ """
+ Interpolates NaN values in specified (or all numeric time-dependent)
+ variables of an xarray.Dataset. Operates in place on the dataset.
+
+ Parameters
+ ----------
+ dataset : xr.Dataset
+ The input dataset.
+ variables : Optional[List[str]], optional
+ A list of variable names to process. If None (default),
+ all numeric variables containing the specified dimension will be processed.
+ dim : str, optional
+ The dimension along which to interpolate (default is "time").
+ method : str, optional
+ Interpolation method to use (e.g., "linear", "nearest", "cubic").
+ Default is "nearest".
+ fill_value : str, optional
+ Method for filling NaNs at the start/end of the series after interpolation.
+ Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
+ Default is "extrapolate".
+ """
+ interpolation_used = False
+ for name, var in dataset.data_vars.items():
+ # if the variable is non-numeric, skip
+ if not np.issubdtype(var.dtype, np.number):
+ continue
+ # if there are no NaNs, skip
+ if not var.isnull().any().compute():
+ continue
+
+ dataset[name] = var.interpolate_na(
+ dim=dim,
+ method=method,
+ fill_value=fill_value if method in ["nearest", "linear"] else None,
+ )
+ interpolation_used = True
+ return interpolation_used
+
+
+ @no_cluster
+ def save_dataset_no_cluster(
+ ds_to_save: xr.Dataset,
+ target_path: Path,
+ engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+ ):
+ """
+ This explicitly does not use dask distributed.
+ Helper function to compute and save an xarray.Dataset to a NetCDF file.
+ Uses a temporary file and rename to avoid leaving a half-written file.
+ """
+ if not target_path.parent.exists():
+ target_path.parent.mkdir(parents=True, exist_ok=True)

- if not cached_nc_path.parent.exists():
- cached_nc_path.parent.mkdir(parents=True)
+ temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+ if temp_file_path.exists():
+ os.remove(temp_file_path)

- # sort of terrible work around for half downloaded files
- temp_path = cached_nc_path.with_suffix(".downloading.nc")
- if os.path.exists(temp_path):
- os.remove(temp_path)
+ ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=True)

- ## Cast every single variable to float32 to save space to save a lot of memory issues later
- ## easier to do it now in this slow download step than later in the steps without dask
- for var in stores.data_vars:
- stores[var] = stores[var].astype("float32")
+ os.rename(str(temp_file_path), str(target_path))
+ logger.info(f"Successfully saved data to: {target_path}")
+
+
+ @temp_cluster
+ def save_dataset(
+ ds_to_save: xr.Dataset,
+ target_path: Path,
+ engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+ ):
+ """
+ Helper function to compute and save an xarray.Dataset to a NetCDF file.
+ Uses a temporary file and rename for atomicity.
+ """
+ if not target_path.parent.exists():
+ target_path.parent.mkdir(parents=True, exist_ok=True)
+
+ temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+ if temp_file_path.exists():
+ os.remove(temp_file_path)

  client = Client.current()
- future = client.compute(stores.to_netcdf(temp_path, compute=False))
- # Display progress bar
+ future: Future = client.compute(
+ ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=False)
+ ) # type: ignore
+ logger.debug(
+ f"NetCDF write task submitted to Dask. Waiting for completion to {temp_file_path}..."
+ )
+ logger.info("For more detailed progress, see the Dask dashboard http://localhost:8787/status")
  progress(future)
  future.result()
+ os.rename(str(temp_file_path), str(target_path))
+ logger.info(f"Successfully saved data to: {target_path}")
+

- os.rename(temp_path, cached_nc_path)
+ @no_cluster
+ def save_to_cache(
+ stores: xr.Dataset, cached_nc_path: Path, interpolate_nans: bool = True
+ ) -> xr.Dataset:
+ """
+ Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth.
+ """
+ logger.debug(f"Processing dataset for caching. Final cache target: {cached_nc_path}")

- data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
- return data
+ # lazily cast all numbers to f32
+ for name, var in stores.data_vars.items():
+ if np.issubdtype(var.dtype, np.number):
+ stores[name] = var.astype("float32", casting="same_kind")
+
+ # save dataset locally before manipulating it
+ save_dataset(stores, cached_nc_path)
+
+ if interpolate_nans:
+ stores = xr.open_mfdataset(
+ cached_nc_path,
+ parallel=True,
+ engine="h5netcdf",
+ )
+ was_interpolated = interpolate_nan_values(dataset=stores)
+ if was_interpolated:
+ save_dataset_no_cluster(stores, cached_nc_path)
+
+ stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+ return stores


  def check_local_cache(
@@ -144,9 +255,8 @@ def check_local_cache(
  start_time: str,
  end_time: str,
  gdf: gpd.GeoDataFrame,
- remote_dataset: xr.Dataset
+ remote_dataset: xr.Dataset,
  ) -> Union[xr.Dataset, None]:
-
  merged_data = None

  if not os.path.exists(cached_nc_path):
@@ -155,9 +265,7 @@ def check_local_cache(

  logger.info("Found cached nc file")
  # open the cached file and check that the time range is correct
- cached_data = xr.open_mfdataset(
- cached_nc_path, parallel=True, engine="h5netcdf"
- )
+ cached_data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")

  if "name" not in cached_data.attrs or "name" not in remote_dataset.attrs:
  logger.warning("No name attribute found to compare datasets")
@@ -166,9 +274,9 @@ def check_local_cache(
  logger.warning("Cached data from different source, .name attr doesn't match")
  return

- range_in_cache = cached_data.time[0].values <= np.datetime64(
- start_time
- ) and cached_data.time[-1].values >= np.datetime64(end_time)
+ range_in_cache = cached_data.time[0].values <= np.datetime64(start_time) and cached_data.time[
+ -1
+ ].values >= np.datetime64(end_time)

  if not range_in_cache:
  # the cache does not contain the desired time range
@@ -186,10 +294,8 @@ def check_local_cache(
  if range_in_cache:
  logger.info("Time range is within cached data")
  logger.debug(f"Opened cached nc file: [{cached_nc_path}]")
- merged_data = clip_dataset_to_bounds(
- cached_data, gdf.total_bounds, start_time, end_time
- )
- logger.debug("Clipped stores")
+ merged_data = clip_dataset_to_bounds(cached_data, gdf.total_bounds, start_time, end_time)
+ logger.debug("Clipped stores")

  return merged_data

@@ -197,16 +303,27 @@ def check_local_cache(
  def save_and_clip_dataset(
  dataset: xr.Dataset,
  gdf: gpd.GeoDataFrame,
- start_time: datetime.datetime,
- end_time: datetime.datetime,
+ start_time: datetime,
+ end_time: datetime,
  cache_location: Path,
  ) -> xr.Dataset:
  """convenience function clip the remote dataset, and either load from cache or save to cache if it's not present"""
  gdf = gdf.to_crs(dataset.crs)

- cached_data = check_local_cache(cache_location, start_time, end_time, gdf, dataset)
+ cached_data = check_local_cache(
+ cache_location,
+ start_time, # type: ignore
+ end_time, # type: ignore
+ gdf,
+ dataset,
+ )

  if not cached_data:
- clipped_data = clip_dataset_to_bounds(dataset, gdf.total_bounds, start_time, end_time)
+ clipped_data = clip_dataset_to_bounds(
+ dataset,
+ gdf.total_bounds,
+ start_time, # type: ignore
+ end_time, # type: ignore
+ )
  cached_data = save_to_cache(clipped_data, cache_location)
- return cached_data
+ return cached_data
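
save_to_cache now splits its work across the new helpers: numeric variables are cast to float32, save_dataset computes and writes the data on a temporary Dask cluster through a .saving.nc file that is renamed into place, and the cache is then reopened so NaN gaps can be interpolated and, if anything changed, rewritten without a cluster. The write-then-rename step on its own looks roughly like this sketch (the toy dataset and cache path are illustrative, not from the package):

    import os
    from pathlib import Path

    import numpy as np
    import xarray as xr

    # toy dataset standing in for clipped forcing data
    ds = xr.Dataset({"t2d": ("time", np.arange(24, dtype="float32"))})

    target = Path("cache/forcing.nc")  # hypothetical cache location
    target.parent.mkdir(parents=True, exist_ok=True)

    # write to a sibling temp file first, then rename, so a crash mid-write
    # never leaves a truncated file at the final path
    tmp = target.with_name(target.name + ".saving.nc")
    if tmp.exists():
        os.remove(tmp)
    ds.to_netcdf(tmp, engine="h5netcdf")
    os.rename(tmp, target)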
@@ -1,35 +1,31 @@
1
1
  import logging
2
+ from typing import Optional
2
3
 
3
4
  import s3fs
4
- from data_processing.s3fs_utils import S3ParallelFileSystem
5
5
  import xarray as xr
6
- from dask.distributed import Client, LocalCluster
6
+ from data_processing.dask_utils import use_cluster
7
7
  from data_processing.dataset_utils import validate_dataset_format
8
-
8
+ from data_processing.s3fs_utils import S3ParallelFileSystem
9
9
 
10
10
  logger = logging.getLogger(__name__)
11
11
 
12
12
 
13
- def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
13
+ @use_cluster
14
+ def load_v3_retrospective_zarr(forcing_vars: Optional[list[str]] = None) -> xr.Dataset:
14
15
  """Load zarr datasets from S3 within the specified time range."""
15
16
  # if a LocalCluster is not already running, start one
16
17
  if not forcing_vars:
17
18
  forcing_vars = ["lwdown", "precip", "psfc", "q2d", "swdown", "t2d", "u2d", "v2d"]
18
- try:
19
- client = Client.current()
20
- except ValueError:
21
- cluster = LocalCluster()
22
- client = Client(cluster)
19
+
23
20
  s3_urls = [
24
- f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr"
25
- for var in forcing_vars
21
+ f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr" for var in forcing_vars
26
22
  ]
27
23
  # default cache is readahead which is detrimental to performance in this case
28
24
  fs = S3ParallelFileSystem(anon=True, default_cache_type="none") # default_block_size
29
25
  s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
30
26
  # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
31
27
  # most of the data is read once and written to disk but some of the coordinate data is read multiple times
32
- dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
28
+ dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True) # type: ignore
33
29
 
34
30
  # set the crs attribute to conform with the format
35
31
  esri_pe_string = dataset.crs.esri_pe_string
@@ -54,7 +50,8 @@ def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
54
50
  return dataset
55
51
 
56
52
 
57
- def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
53
+ @use_cluster
54
+ def load_aorc_zarr(start_year: Optional[int] = None, end_year: Optional[int] = None) -> xr.Dataset:
58
55
  """Load the aorc zarr dataset from S3."""
59
56
  if not start_year or not end_year:
60
57
  logger.warning("No start or end year provided, defaulting to 1979-2023")
@@ -63,11 +60,6 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
63
60
  start_year = 1979
64
61
  if not end_year:
65
62
  end_year = 2023
66
- try:
67
- client = Client.current()
68
- except ValueError:
69
- cluster = LocalCluster()
70
- client = Client(cluster)
71
63
 
72
64
  logger.info(f"Loading AORC zarr datasets from {start_year} to {end_year}")
73
65
  estimated_time_s = ((end_year - start_year) * 2.5) + 3.5
@@ -75,9 +67,9 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
75
67
  logger.info(f"This should take roughly {estimated_time_s} seconds")
76
68
  fs = S3ParallelFileSystem(anon=True, default_cache_type="none")
77
69
  s3_url = "s3://noaa-nws-aorc-v1-1-1km/"
78
- urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year+1)]
70
+ urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year + 1)]
79
71
  filestores = [s3fs.S3Map(url, s3=fs) for url in urls]
80
- dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True)
72
+ dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True) # type: ignore
81
73
  dataset.attrs["crs"] = "+proj=longlat +datum=WGS84 +no_defs"
82
74
  dataset.attrs["name"] = "aorc_1km_zarr"
83
75
  # rename latitude and longitude to x and y
@@ -87,32 +79,29 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
87
79
  return dataset
88
80
 
89
81
 
82
+ @use_cluster
90
83
  def load_swe_zarr() -> xr.Dataset:
91
84
  """Load the swe zarr dataset from S3."""
92
- s3_urls = [
93
- f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"
94
- ]
85
+ s3_urls = ["s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"]
95
86
  # default cache is readahead which is detrimental to performance in this case
96
87
  fs = S3ParallelFileSystem(anon=True, default_cache_type="none") # default_block_size
97
88
  s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
98
89
  # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
99
90
  # most of the data is read once and written to disk but some of the coordinate data is read multiple times
100
- dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
101
-
91
+ dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True) # type: ignore
92
+
102
93
  # set the crs attribute to conform with the format
103
94
  esri_pe_string = dataset.crs.esri_pe_string
104
95
  dataset = dataset.drop_vars(["crs"])
105
96
  dataset.attrs["crs"] = esri_pe_string
106
97
  # drop everything except SNEQV
107
98
  vars_to_drop = list(dataset.data_vars)
108
- vars_to_drop.remove('SNEQV')
99
+ vars_to_drop.remove("SNEQV")
109
100
  dataset = dataset.drop_vars(vars_to_drop)
110
101
  dataset.attrs["name"] = "v3_swe_zarr"
111
102
 
112
103
  # rename the data vars to work with ngen
113
- variables = {
114
- "SNEQV": "swe"
115
- }
104
+ variables = {"SNEQV": "swe"}
116
105
  dataset = dataset.rename_vars(variables)
117
106
 
118
107
  validate_dataset_format(dataset)