ngiab-data-preprocess 4.2.2-py3-none-any.whl → 4.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,15 +3,15 @@ import logging
 import multiprocessing
 import shutil
 import sqlite3
-from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
+from typing import Dict, Optional
 
 import pandas
 import requests
 import s3fs
 import xarray as xr
-from dask.distributed import Client, LocalCluster
+from data_processing.dask_utils import temp_cluster
 from data_processing.file_paths import file_paths
 from data_processing.gpkg_utils import (
     GeoPackage,
@@ -25,7 +25,8 @@ from tqdm.rich import tqdm
 logger = logging.getLogger(__name__)
 
 
-def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
+@temp_cluster
+def get_approximate_gw_storage(paths: file_paths, start_date: datetime) -> Dict[str, int]:
     # get the gw levels from the NWM output on a given start date
     # this kind of works in place of warmstates for now
     year = start_date.strftime("%Y")
@@ -35,17 +36,10 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
     fs = s3fs.S3FileSystem(anon=True)
     nc_url = f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/GWOUT/{year}/{formatted_dt}.GWOUT_DOMAIN1"
 
-    # make sure there's a dask cluster running
-    try:
-        client = Client.current()
-    except ValueError:
-        cluster = LocalCluster()
-        client = Client(cluster)
-
     with fs.open(nc_url) as file_obj:
-        ds = xr.open_dataset(file_obj)
+        ds = xr.open_dataset(file_obj)  # type: ignore
 
-    water_levels = dict()
+    water_levels: Dict[str, int] = dict()
     for cat, feature in tqdm(cat_to_feature.items()):
         # this value is in CM, we need meters to match max_gw_depth
         # xarray says it's in mm, with 0.1 scale factor. calling .values doesn't apply the scale
@@ -114,13 +108,13 @@ def make_noahowp_config(
                 lon=divide_conf_df.loc[divide, "longitude"],
                 terrain_slope=divide_conf_df.loc[divide, "mean.slope_1km"],
                 azimuth=divide_conf_df.loc[divide, "circ_mean.aspect"],
-                ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),
-                IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),
+                ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),  # type: ignore
+                IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),  # type: ignore
             )
         )
 
 
-def get_model_attributes_modspatialite(hydrofabric: Path):
+def get_model_attributes_modspatialite(hydrofabric: Path) -> pandas.DataFrame:
     # modspatialite is faster than pyproj but can't be added as a pip dependency
     # This incantation took a while
     with GeoPackage(hydrofabric) as conn:
@@ -151,7 +145,7 @@ def get_model_attributes_modspatialite(hydrofabric: Path):
     return divide_conf_df
 
 
-def get_model_attributes_pyproj(hydrofabric: Path):
+def get_model_attributes_pyproj(hydrofabric: Path) -> pandas.DataFrame:
     # if modspatialite is not available, use pyproj
     with sqlite3.connect(hydrofabric) as conn:
         sql = """
@@ -185,7 +179,7 @@ def get_model_attributes_pyproj(hydrofabric: Path):
     return divide_conf_df
 
 
-def get_model_attributes(hydrofabric: Path):
+def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
     try:
         with GeoPackage(hydrofabric) as conn:
             conf_df = pandas.read_sql_query(
@@ -259,7 +253,7 @@ def make_em_config(
 
 def configure_troute(
     cat_id: str, config_dir: Path, start_time: datetime, end_time: datetime
-) -> int:
+) -> None:
     with open(file_paths.template_troute_config, "r") as file:
         troute_template = file.read()
     time_step_size = 300
@@ -316,7 +310,7 @@ def create_realization(
     start_time: datetime,
     end_time: datetime,
     use_nwm_gw: bool = False,
-    gage_id: str = None,
+    gage_id: Optional[str] = None,
 ):
     paths = file_paths(cat_id)
 
@@ -354,12 +348,12 @@ def create_realization(
     create_partitions(paths)
 
 
-def create_partitions(paths: Path, num_partitions: int = None) -> None:
+def create_partitions(paths: file_paths, num_partitions: Optional[int] = None) -> None:
     if num_partitions is None:
         num_partitions = multiprocessing.cpu_count()
 
     cat_to_nex_pairs = get_cat_to_nex_flowpairs(hydrofabric=paths.geopackage_path)
-    nexus = defaultdict(list)
+    # nexus = defaultdict(list)
 
     # for cat, nex in cat_to_nex_pairs:
     #     nexus[nex].append(cat)
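Note: the hunks above replace the inline LocalCluster bootstrapping with the new temp_cluster decorator, so callers no longer manage the Dask cluster themselves. A minimal sketch of the resulting call pattern, assuming this file is importable as data_processing.create_realization (module name inferred from the package layout; the catchment id and date below are placeholders):

    from datetime import datetime

    from data_processing.create_realization import get_approximate_gw_storage
    from data_processing.file_paths import file_paths

    paths = file_paths("cat-1234")  # hypothetical output folder name
    # temp_cluster starts a Dask cluster for this call and shuts it down
    # afterwards, unless one was already running
    gw_levels = get_approximate_gw_storage(paths, datetime(2020, 1, 1))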
@@ -0,0 +1,92 @@
+import logging
+
+from dask.distributed import Client
+
+logger = logging.getLogger(__name__)
+
+
+def shutdown_cluster():
+    try:
+        client = Client.current()
+        client.shutdown()
+    except ValueError:
+        logger.debug("No cluster found to shutdown")
+
+
+def no_cluster(func):
+    """
+    Decorator that ensures the wrapped function runs with no active Dask cluster.
+
+    This decorator attempts to shut down any existing Dask cluster before
+    executing the wrapped function. If no cluster is found, it logs a debug message
+    and continues execution.
+
+    Parameters:
+        func: The function to be executed without a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function that will be executed without a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        shutdown_cluster()
+        result = func(*args, **kwargs)
+        return result
+
+    return wrapper
+
+
+def use_cluster(func):
+    """
+    Decorator that ensures the wrapped function has access to a Dask cluster.
+
+    If a Dask cluster is already running, it uses the existing one.
+    If no cluster is available, it creates a new one before executing the function.
+    The cluster remains active after the function completes.
+
+    Parameters:
+        func: The function to be executed with a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function with access to a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        try:
+            client = Client.current()
+        except ValueError:
+            client = Client()
+        result = func(*args, **kwargs)
+        return result
+
+    return wrapper
+
+
+def temp_cluster(func):
+    """
+    Decorator that provides a temporary Dask cluster for the wrapped function.
+
+    If a Dask cluster is already running, it uses the existing one and leaves it running.
+    If no cluster exists, it creates a temporary one and shuts it down after
+    the function completes.
+
+    Parameters:
+        func: The function to be executed with a Dask cluster
+
+    Returns:
+        wrapper: The wrapped function with access to a Dask cluster
+    """
+
+    def wrapper(*args, **kwargs):
+        cluster_was_running = True
+        try:
+            client = Client.current()
+        except ValueError:
+            cluster_was_running = False
+            client = Client()
+        result = func(*args, **kwargs)
+        if not cluster_was_running:
+            client.shutdown()
+        return result
+
+    return wrapper
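Note: the three decorators in this new module differ only in cluster lifecycle. A short usage sketch (the decorated functions below are illustrative, not part of the package):

    from data_processing.dask_utils import no_cluster, temp_cluster, use_cluster

    @use_cluster
    def heavy_io():
        # reuses a running Dask cluster, or starts one and leaves it running
        ...

    @temp_cluster
    def one_off_job():
        # starts a cluster only if none exists, and shuts that one down when done
        ...

    @no_cluster
    def serial_step():
        # shuts down any running cluster before executing
        ...

Because temp_cluster leaves a pre-existing cluster untouched, use_cluster-decorated helpers such as save_to_cache can be called from inside a temp_cluster-wrapped function without tearing down the caller's cluster.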
@@ -1,19 +1,22 @@
 import logging
 import os
+from datetime import datetime
 from pathlib import Path
-from typing import Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import geopandas as gpd
 import numpy as np
 import xarray as xr
-from dask.distributed import Client, progress
-import datetime
+from xarray.core.types import InterpOptions
+from dask.distributed import Client, progress, Future
+from data_processing.dask_utils import use_cluster
 
 logger = logging.getLogger(__name__)
 
 # known ngen variable names
 # https://github.com/CIROH-UA/ngen/blob/4fb5bb68dc397298bca470dfec94db2c1dcb42fe/include/forcing/AorcForcing.hpp#L77
 
+
 def validate_dataset_format(dataset: xr.Dataset) -> None:
     """
     Validate the format of the dataset.
@@ -41,8 +44,9 @@ def validate_dataset_format(dataset: xr.Dataset) -> None:
     if "name" not in dataset.attrs:
         raise ValueError("Dataset must have a name attribute to identify it")
 
+
 def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) -> Tuple[str, str]:
-    '''
+    """
     Ensure that all selected times are in the passed dataset.
 
     Parameters
@@ -60,7 +64,7 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->
         start_time, or if not available, earliest available timestep in dataset.
     str
         end_time, or if not available, latest available timestep in dataset.
-    '''
+    """
     end_time_in_dataset = dataset.time.isel(time=-1).values
     start_time_in_dataset = dataset.time.isel(time=0).values
     if np.datetime64(start_time) < start_time_in_dataset:
@@ -77,7 +81,10 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->
 
 
 def clip_dataset_to_bounds(
-    dataset: xr.Dataset, bounds: Tuple[float, float, float, float], start_time: str, end_time: str
+    dataset: xr.Dataset,
+    bounds: Tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]],
+    start_time: str,
+    end_time: str,
 ) -> xr.Dataset:
     """
     Clip the dataset to specified geographical bounds.
@@ -86,14 +93,14 @@ def clip_dataset_to_bounds(
     ----------
     dataset : xr.Dataset
         Dataset to be clipped.
-    bounds : tuple[float, float, float, float]
-        Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
+    bounds : tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]]
+        Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
         bounds[2] is x_max, bounds[3] is y_max.
     start_time : str
         Desired start time in YYYY/MM/DD HH:MM:SS format.
     end_time : str
         Desired end time in YYYY/MM/DD HH:MM:SS format.
-
+
     Returns
     -------
     xr.Dataset
@@ -110,33 +117,103 @@ def clip_dataset_to_bounds(
     return dataset
 
 
-def save_to_cache(stores: xr.Dataset, cached_nc_path: Path) -> xr.Dataset:
-    """Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth."""
-    logger.info("Downloading and caching forcing data, this may take a while")
+def interpolate_nan_values(
+    dataset: xr.Dataset,
+    variables: Optional[List[str]] = None,
+    dim: str = "time",
+    method: InterpOptions = "nearest",
+    fill_value: str = "extrapolate",
+) -> None:
+    """
+    Interpolates NaN values in specified (or all numeric time-dependent)
+    variables of an xarray.Dataset. Operates inplace on the dataset.
 
-    if not cached_nc_path.parent.exists():
-        cached_nc_path.parent.mkdir(parents=True)
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        The input dataset.
+    variables : Optional[List[str]], optional
+        A list of variable names to process. If None (default),
+        all numeric variables containing the specified dimension will be processed.
+    dim : str, optional
+        The dimension along which to interpolate (default is "time").
+    method : str, optional
+        Interpolation method to use (e.g., "linear", "nearest", "cubic").
+        Default is "nearest".
+    fill_value : str, optional
+        Method for filling NaNs at the start/end of the series after interpolation.
+        Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
+        Default is "extrapolate".
+    """
+    for name, var in dataset.data_vars.items():
+        # if the variable is non-numeric, skip
+        if not np.issubdtype(var.dtype, np.number):
+            continue
+        # if there are no NANs, skip
+        if not var.isnull().any().compute():
+            continue
+
+        dataset[name] = var.interpolate_na(
+            dim=dim,
+            method=method,
+            fill_value=fill_value if method in ["nearest", "linear"] else None,
+        )
 
-    # sort of terrible work around for half downloaded files
-    temp_path = cached_nc_path.with_suffix(".downloading.nc")
-    if os.path.exists(temp_path):
-        os.remove(temp_path)
 
-    ## Cast every single variable to float32 to save space to save a lot of memory issues later
-    ## easier to do it now in this slow download step than later in the steps without dask
-    for var in stores.data_vars:
-        stores[var] = stores[var].astype("float32")
+@use_cluster
+def save_dataset(
+    ds_to_save: xr.Dataset,
+    target_path: Path,
+    engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+):
+    """
+    Helper function to compute and save an xarray.Dataset to a NetCDF file.
+    Uses a temporary file and rename for atomicity.
+    """
+    if not target_path.parent.exists():
+        target_path.parent.mkdir(parents=True, exist_ok=True)
+
+    temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+    if temp_file_path.exists():
+        os.remove(temp_file_path)
 
     client = Client.current()
-    future = client.compute(stores.to_netcdf(temp_path, compute=False))
-    # Display progress bar
+    future: Future = client.compute(
+        ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=False)
+    )  # type: ignore
+    logger.debug(
+        f"NetCDF write task submitted to Dask. Waiting for completion to {temp_file_path}..."
+    )
     progress(future)
     future.result()
+    os.rename(str(temp_file_path), str(target_path))
+    logger.info(f"Successfully saved data to: {target_path}")
+
+
+@use_cluster
+def save_to_cache(
+    stores: xr.Dataset, cached_nc_path: Path, interpolate_nans: bool = True
+) -> xr.Dataset:
+    """
+    Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth.
+    """
+    logger.info(f"Processing dataset for caching. Final cache target: {cached_nc_path}")
 
-    os.rename(temp_path, cached_nc_path)
+    # lasily cast all numbers to f32
+    for name, var in stores.data_vars.items():
+        if np.issubdtype(var.dtype, np.number):
+            stores[name] = var.astype("float32", casting="same_kind")
 
-    data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
-    return data
+    # save dataset locally before manipulating it
+    save_dataset(stores, cached_nc_path)
+    stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+
+    if interpolate_nans:
+        interpolate_nan_values(dataset=stores)
+        save_dataset(stores, cached_nc_path)
+        stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+
+    return stores
 
 
 def check_local_cache(
@@ -144,9 +221,8 @@ def check_local_cache(
     start_time: str,
     end_time: str,
     gdf: gpd.GeoDataFrame,
-    remote_dataset: xr.Dataset
+    remote_dataset: xr.Dataset,
 ) -> Union[xr.Dataset, None]:
-
     merged_data = None
 
     if not os.path.exists(cached_nc_path):
@@ -155,9 +231,7 @@ def check_local_cache(
 
     logger.info("Found cached nc file")
     # open the cached file and check that the time range is correct
-    cached_data = xr.open_mfdataset(
-        cached_nc_path, parallel=True, engine="h5netcdf"
-    )
+    cached_data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
 
     if "name" not in cached_data.attrs or "name" not in remote_dataset.attrs:
         logger.warning("No name attribute found to compare datasets")
@@ -166,9 +240,9 @@ def check_local_cache(
         logger.warning("Cached data from different source, .name attr doesn't match")
         return
 
-    range_in_cache = cached_data.time[0].values <= np.datetime64(
-        start_time
-    ) and cached_data.time[-1].values >= np.datetime64(end_time)
+    range_in_cache = cached_data.time[0].values <= np.datetime64(start_time) and cached_data.time[
+        -1
+    ].values >= np.datetime64(end_time)
 
     if not range_in_cache:
         # the cache does not contain the desired time range
@@ -186,10 +260,8 @@ def check_local_cache(
     if range_in_cache:
         logger.info("Time range is within cached data")
         logger.debug(f"Opened cached nc file: [{cached_nc_path}]")
-        merged_data = clip_dataset_to_bounds(
-            cached_data, gdf.total_bounds, start_time, end_time
-        )
-        logger.debug("Clipped stores")
+        merged_data = clip_dataset_to_bounds(cached_data, gdf.total_bounds, start_time, end_time)
+        logger.debug("Clipped stores")
 
     return merged_data
 
@@ -197,16 +269,27 @@ def check_local_cache(
 def save_and_clip_dataset(
     dataset: xr.Dataset,
     gdf: gpd.GeoDataFrame,
-    start_time: datetime.datetime,
-    end_time: datetime.datetime,
+    start_time: datetime,
+    end_time: datetime,
     cache_location: Path,
 ) -> xr.Dataset:
     """convenience function clip the remote dataset, and either load from cache or save to cache if it's not present"""
     gdf = gdf.to_crs(dataset.crs)
 
-    cached_data = check_local_cache(cache_location, start_time, end_time, gdf, dataset)
+    cached_data = check_local_cache(
+        cache_location,
+        start_time,  # type: ignore
+        end_time,  # type: ignore
+        gdf,
+        dataset,
+    )
 
     if not cached_data:
-        clipped_data = clip_dataset_to_bounds(dataset, gdf.total_bounds, start_time, end_time)
+        clipped_data = clip_dataset_to_bounds(
+            dataset,
+            gdf.total_bounds,
+            start_time,  # type: ignore
+            end_time,  # type: ignore
+        )
         cached_data = save_to_cache(clipped_data, cache_location)
-    return cached_data
+    return cached_data
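Note: save_to_cache now writes the clipped dataset to disk with h5netcdf, reopens it, optionally interpolates NaN values along the time dimension, and writes it out again. A rough sketch of calling the new interpolate_nan_values helper on its own (the tiny dataset below is a made-up stand-in, and interpolation with method="nearest" requires scipy):

    import numpy as np
    import xarray as xr

    from data_processing.dataset_utils import interpolate_nan_values

    ds = xr.Dataset(
        {"precip": ("time", [0.1, np.nan, 0.3])},
        coords={"time": np.array(["2020-01-01T00", "2020-01-01T01", "2020-01-01T02"], dtype="datetime64[ns]")},
    )
    interpolate_nan_values(ds)  # in place; the gap is filled from the nearest neighbour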
@@ -1,35 +1,31 @@
 import logging
+from typing import Optional
 
 import s3fs
-from data_processing.s3fs_utils import S3ParallelFileSystem
 import xarray as xr
-from dask.distributed import Client, LocalCluster
+from data_processing.dask_utils import use_cluster
 from data_processing.dataset_utils import validate_dataset_format
-
+from data_processing.s3fs_utils import S3ParallelFileSystem
 
 logger = logging.getLogger(__name__)
 
 
-def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
+@use_cluster
+def load_v3_retrospective_zarr(forcing_vars: Optional[list[str]] = None) -> xr.Dataset:
     """Load zarr datasets from S3 within the specified time range."""
     # if a LocalCluster is not already running, start one
     if not forcing_vars:
         forcing_vars = ["lwdown", "precip", "psfc", "q2d", "swdown", "t2d", "u2d", "v2d"]
-    try:
-        client = Client.current()
-    except ValueError:
-        cluster = LocalCluster()
-        client = Client(cluster)
+
     s3_urls = [
-        f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr"
-        for var in forcing_vars
+        f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr" for var in forcing_vars
     ]
     # default cache is readahead which is detrimental to performance in this case
     fs = S3ParallelFileSystem(anon=True, default_cache_type="none")  # default_block_size
     s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
     # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
     # most of the data is read once and written to disk but some of the coordinate data is read multiple times
-    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
+    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)  # type: ignore
 
     # set the crs attribute to conform with the format
     esri_pe_string = dataset.crs.esri_pe_string
@@ -54,7 +50,8 @@ def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
     return dataset
 
 
-def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
+@use_cluster
+def load_aorc_zarr(start_year: Optional[int] = None, end_year: Optional[int] = None) -> xr.Dataset:
     """Load the aorc zarr dataset from S3."""
     if not start_year or not end_year:
         logger.warning("No start or end year provided, defaulting to 1979-2023")
@@ -63,11 +60,6 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
         start_year = 1979
     if not end_year:
         end_year = 2023
-    try:
-        client = Client.current()
-    except ValueError:
-        cluster = LocalCluster()
-        client = Client(cluster)
 
     logger.info(f"Loading AORC zarr datasets from {start_year} to {end_year}")
     estimated_time_s = ((end_year - start_year) * 2.5) + 3.5
@@ -75,9 +67,9 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
     logger.info(f"This should take roughly {estimated_time_s} seconds")
     fs = S3ParallelFileSystem(anon=True, default_cache_type="none")
     s3_url = "s3://noaa-nws-aorc-v1-1-1km/"
-    urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year+1)]
+    urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year + 1)]
     filestores = [s3fs.S3Map(url, s3=fs) for url in urls]
-    dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True)
+    dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True)  # type: ignore
     dataset.attrs["crs"] = "+proj=longlat +datum=WGS84 +no_defs"
     dataset.attrs["name"] = "aorc_1km_zarr"
     # rename latitude and longitude to x and y
@@ -87,32 +79,29 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
     return dataset
 
 
+@use_cluster
 def load_swe_zarr() -> xr.Dataset:
     """Load the swe zarr dataset from S3."""
-    s3_urls = [
-        f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"
-    ]
+    s3_urls = ["s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"]
     # default cache is readahead which is detrimental to performance in this case
     fs = S3ParallelFileSystem(anon=True, default_cache_type="none")  # default_block_size
     s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
     # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
     # most of the data is read once and written to disk but some of the coordinate data is read multiple times
-    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
-
+    dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)  # type: ignore
+
     # set the crs attribute to conform with the format
     esri_pe_string = dataset.crs.esri_pe_string
     dataset = dataset.drop_vars(["crs"])
     dataset.attrs["crs"] = esri_pe_string
     # drop everything except SNEQV
     vars_to_drop = list(dataset.data_vars)
-    vars_to_drop.remove('SNEQV')
+    vars_to_drop.remove("SNEQV")
     dataset = dataset.drop_vars(vars_to_drop)
     dataset.attrs["name"] = "v3_swe_zarr"
 
     # rename the data vars to work with ngen
-    variables = {
-        "SNEQV": "swe"
-    }
+    variables = {"SNEQV": "swe"}
     dataset = dataset.rename_vars(variables)
 
     validate_dataset_format(dataset)
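Note: with use_cluster applied, these zarr loaders reuse whatever Dask cluster the caller already has, or start one that stays up, so the lazy open_mfdataset calls and any later compute share the same workers. A minimal sketch, assuming this module is importable as data_processing.zarr_utils (module name inferred; the year range is illustrative):

    from data_processing.zarr_utils import load_aorc_zarr

    # opens lazily; a Dask cluster is created on demand and left running
    aorc = load_aorc_zarr(start_year=2020, end_year=2021)
    print(aorc.attrs["name"])  # "aorc_1km_zarr"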
@@ -1,11 +1,13 @@
 from pathlib import Path
-
+from typing import Optional
+from datetime import datetime
 
 class file_paths:
     """
     This class contains all of the file paths used in the data processing
     workflow.
     """
+
     config_file = Path("~/.ngiab/preprocessor").expanduser()
     hydrofabric_dir = Path("~/.ngiab/hydrofabric/v2.2").expanduser()
     hydrofabric_download_log = Path("~/.ngiab/hydrofabric/v2.2/download_log.json").expanduser()
@@ -31,7 +33,7 @@ class file_paths:
     template_em_config = data_sources / "em-catchment-template.yml"
     template_em_model_config = data_sources / "em-config.yml"
 
-    def __init__(self, folder_name: str = None, output_dir: Path = None):
+    def __init__(self, folder_name: Optional[str] = None, output_dir: Optional[Path] = None):
         """
         Initialize the file_paths class with a the name of the output subfolder.
         OR the path to the output folder you want to use.
@@ -53,7 +55,7 @@ class file_paths:
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 
     @classmethod
-    def get_working_dir(cls) -> Path:
+    def get_working_dir(cls) -> Path | None:
         try:
             with open(cls.config_file, "r") as f:
                 return Path(f.readline().strip()).expanduser()
@@ -67,9 +69,7 @@ class file_paths:
 
     @classmethod
     def root_output_dir(cls) -> Path:
-        if cls.get_working_dir() is not None:
-            return cls.get_working_dir()
-        return Path(__file__).parent.parent.parent / "output"
+        return cls.get_working_dir() or Path(__file__).parent.parent.parent / "output"
 
     @property
     def subset_dir(self) -> Path:
@@ -102,7 +102,7 @@ class file_paths:
     def append_cli_command(self, command: list[str]) -> None:
         current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         command_string = " ".join(command)
-        history_file = self.metadata_dir / "cli_commands_history.txt"
+        history_file = self.metadata_dir / "cli_commands_history.txt"
         if not history_file.parent.exists():
             history_file.parent.mkdir(parents=True, exist_ok=True)
         with open(self.metadata_dir / "cli_commands_history.txt", "a") as f:
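Note: get_working_dir can now return None, and root_output_dir folds the old if/return pair into a single or-expression, so a missing ~/.ngiab/preprocessor config file falls back to the package-relative output directory. A small sketch of the resulting behaviour (the folder name is a placeholder; constructing file_paths also creates its cache directory):

    from data_processing.file_paths import file_paths

    # returns the configured working dir, or <package>/output when
    # ~/.ngiab/preprocessor does not exist
    print(file_paths.root_output_dir())

    paths = file_paths("cat-1234")  # hypothetical subset folder name
    print(paths.subset_dir)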