ngiab-data-preprocess: 4.2.1-py3-none-any.whl → 4.3.0-py3-none-any.whl

This diff compares the contents of publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
@@ -1,30 +1,32 @@
  import json
+ import logging
  import multiprocessing
+ import shutil
  import sqlite3
  from datetime import datetime
  from pathlib import Path
- import shutil
- import requests
+ from typing import Dict, Optional

  import pandas
+ import requests
  import s3fs
  import xarray as xr
- import logging
- from collections import defaultdict
- from dask.distributed import Client, LocalCluster
+ from data_processing.dask_utils import temp_cluster
  from data_processing.file_paths import file_paths
  from data_processing.gpkg_utils import (
      GeoPackage,
+     get_cat_to_nex_flowpairs,
      get_cat_to_nhd_feature_id,
      get_table_crs_short,
-     get_cat_to_nex_flowpairs,
  )
- from tqdm.rich import tqdm
  from pyproj import Transformer
+ from tqdm.rich import tqdm

  logger = logging.getLogger(__name__)

- def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
+
+ @temp_cluster
+ def get_approximate_gw_storage(paths: file_paths, start_date: datetime) -> Dict[str, int]:
      # get the gw levels from the NWM output on a given start date
      # this kind of works in place of warmstates for now
      year = start_date.strftime("%Y")
@@ -34,17 +36,10 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime):
      fs = s3fs.S3FileSystem(anon=True)
      nc_url = f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/GWOUT/{year}/{formatted_dt}.GWOUT_DOMAIN1"

-     # make sure there's a dask cluster running
-     try:
-         client = Client.current()
-     except ValueError:
-         cluster = LocalCluster()
-         client = Client(cluster)
-
      with fs.open(nc_url) as file_obj:
-         ds = xr.open_dataset(file_obj)
+         ds = xr.open_dataset(file_obj)  # type: ignore

-     water_levels = dict()
+     water_levels: Dict[str, int] = dict()
      for cat, feature in tqdm(cat_to_feature.items()):
          # this value is in CM, we need meters to match max_gw_depth
          # xarray says it's in mm, with 0.1 scale factor. calling .values doesn't apply the scale
@@ -78,7 +73,9 @@ def make_cfe_config(
          slope=row["mean.slope_1km"],
          smcmax=row["mean.smcmax_soil_layers_stag=2"],
          smcwlt=row["mean.smcwlt_soil_layers_stag=2"],
-         max_gw_storage=row["mean.Zmax"]/1000 if row["mean.Zmax"] is not None else "0.011[m]", # mean.Zmax is in mm!
+         max_gw_storage=row["mean.Zmax"] / 1000
+         if row["mean.Zmax"] is not None
+         else "0.011[m]",  # mean.Zmax is in mm!
          gw_Coeff=row["mean.Coeff"] if row["mean.Coeff"] is not None else "0.0018[m h-1]",
          gw_Expon=row["mode.Expon"],
          gw_storage="{:.5}".format(gw_storage_ratio),
@@ -92,7 +89,6 @@
  def make_noahowp_config(
      base_dir: Path, divide_conf_df: pandas.DataFrame, start_time: datetime, end_time: datetime
  ) -> None:
-
      divide_conf_df.set_index("divide_id", inplace=True)
      start_datetime = start_time.strftime("%Y%m%d%H%M")
      end_datetime = end_time.strftime("%Y%m%d%H%M")
@@ -110,15 +106,15 @@ def make_noahowp_config(
                  end_datetime=end_datetime,
                  lat=divide_conf_df.loc[divide, "latitude"],
                  lon=divide_conf_df.loc[divide, "longitude"],
-                 terrain_slope= divide_conf_df.loc[divide, "mean.slope_1km"],
-                 azimuth= divide_conf_df.loc[divide, "circ_mean.aspect"],
-                 ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),
-                 IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),
+                 terrain_slope=divide_conf_df.loc[divide, "mean.slope_1km"],
+                 azimuth=divide_conf_df.loc[divide, "circ_mean.aspect"],
+                 ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]),  # type: ignore
+                 IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]),  # type: ignore
              )
          )


- def get_model_attributes_modspatialite(hydrofabric: Path):
+ def get_model_attributes_modspatialite(hydrofabric: Path) -> pandas.DataFrame:
      # modspatialite is faster than pyproj but can't be added as a pip dependency
      # This incantation took a while
      with GeoPackage(hydrofabric) as conn:
@@ -149,7 +145,7 @@ def get_model_attributes_modspatialite(hydrofabric: Path):
      return divide_conf_df


- def get_model_attributes_pyproj(hydrofabric: Path):
+ def get_model_attributes_pyproj(hydrofabric: Path) -> pandas.DataFrame:
      # if modspatialite is not available, use pyproj
      with sqlite3.connect(hydrofabric) as conn:
          sql = """
@@ -182,7 +178,8 @@ def get_model_attributes_pyproj(hydrofabric: Path):

      return divide_conf_df

- def get_model_attributes(hydrofabric: Path):
+
+ def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
      try:
          with GeoPackage(hydrofabric) as conn:
              conf_df = pandas.read_sql_query(
@@ -205,30 +202,31 @@ def get_model_attributes(hydrofabric: Path):
              )
      except sqlite3.OperationalError:
          with sqlite3.connect(hydrofabric) as conn:
-             conf_df = pandas.read_sql_query("SELECT* FROM 'divide-attributes';", conn,)
+             conf_df = pandas.read_sql_query(
+                 "SELECT* FROM 'divide-attributes';",
+                 conn,
+             )
      source_crs = get_table_crs_short(hydrofabric, "divides")
      transformer = Transformer.from_crs(source_crs, "EPSG:4326", always_xy=True)
-     lon, lat = transformer.transform(
-         conf_df["centroid_x"].values, conf_df["centroid_y"].values
-     )
+     lon, lat = transformer.transform(conf_df["centroid_x"].values, conf_df["centroid_y"].values)
      conf_df["longitude"] = lon
      conf_df["latitude"] = lat

      conf_df.drop(columns=["centroid_x", "centroid_y"], axis=1, inplace=True)
      return conf_df

+
  def make_em_config(
      hydrofabric: Path,
      output_dir: Path,
      template_path: Path = file_paths.template_em_config,
  ):
-
      # test if modspatialite is available
      try:
          divide_conf_df = get_model_attributes_modspatialite(hydrofabric)
      except Exception as e:
          logger.warning(f"mod_spatialite not available, using pyproj instead: {e}")
-         logger.warning(f"Install mod_spatialite for improved performance")
+         logger.warning("Install mod_spatialite for improved performance")
          divide_conf_df = get_model_attributes_pyproj(hydrofabric)

      cat_config_dir = output_dir / "cat_config" / "empirical_model"
@@ -255,8 +253,7 @@

  def configure_troute(
      cat_id: str, config_dir: Path, start_time: datetime, end_time: datetime
- ) -> int:
-
+ ) -> None:
      with open(file_paths.template_troute_config, "r") as file:
          troute_template = file.read()
      time_step_size = 300
@@ -269,7 +266,7 @@
          geo_file_path=f"./config/{cat_id}_subset.gpkg",
          start_datetime=start_time.strftime("%Y-%m-%d %H:%M:%S"),
          nts=nts,
-         max_loop_size=nts,
+         max_loop_size=nts,
      )

      with open(config_dir / "troute.yaml", "w") as file:
@@ -301,9 +298,7 @@ def create_em_realization(cat_id: str, start_time: datetime, end_time: datetime)
          f.write(em_config)

      configure_troute(cat_id, paths.config_dir, start_time, end_time)
-     make_ngen_realization_json(
-         paths.config_dir, template_path, start_time, end_time
-     )
+     make_ngen_realization_json(paths.config_dir, template_path, start_time, end_time)
      make_em_config(paths.geopackage_path, paths.config_dir)
      # create some partitions for parallelization
      paths.setup_run_folders()
@@ -315,7 +310,7 @@ def create_realization(
      start_time: datetime,
      end_time: datetime,
      use_nwm_gw: bool = False,
-     gage_id: str = None,
+     gage_id: Optional[str] = None,
  ):
      paths = file_paths(cat_id)

@@ -324,15 +319,14 @@
      if gage_id is not None:
          # try and download s3:communityhydrofabric/hydrofabrics/community/gage_parameters/gage_id
          # if it doesn't exist, use the default
-         try:
-             url = f"https://communityhydrofabric.s3.us-east-1.amazonaws.com/hydrofabrics/community/gage_parameters/{gage_id}.json"
-
+         url = f"https://communityhydrofabric.s3.us-east-1.amazonaws.com/hydrofabrics/community/gage_parameters/{gage_id}.json"
+         response = requests.get(url)
+         if response.status_code == 200:
              new_template = requests.get(url).json()
-             template_path = paths.config_dir / "calibrated_params.json"
+             template_path = paths.config_dir / "downloaded_params.json"
              with open(template_path, "w") as f:
                  json.dump(new_template, f)
-         except Exception as e:
-             logger.warning(f"Failed to download gage parameters")
+             logger.info(f"downloaded calibrated parameters for {gage_id}")

      conf_df = get_model_attributes(paths.geopackage_path)

@@ -347,21 +341,19 @@

      configure_troute(cat_id, paths.config_dir, start_time, end_time)

-     make_ngen_realization_json(
-         paths.config_dir, template_path, start_time, end_time
-     )
+     make_ngen_realization_json(paths.config_dir, template_path, start_time, end_time)

      # create some partitions for parallelization
      paths.setup_run_folders()
      create_partitions(paths)


- def create_partitions(paths: Path, num_partitions: int = None) -> None:
+ def create_partitions(paths: file_paths, num_partitions: Optional[int] = None) -> None:
      if num_partitions is None:
          num_partitions = multiprocessing.cpu_count()

      cat_to_nex_pairs = get_cat_to_nex_flowpairs(hydrofabric=paths.geopackage_path)
-     nexus = defaultdict(list)
+     # nexus = defaultdict(list)

      # for cat, nex in cat_to_nex_pairs:
      #     nexus[nex].append(cat)
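
For orientation (not part of the package diff): the public entry point touched above keeps the same call shape, so a hedged usage sketch against the names visible in this diff looks roughly like the following. The module path, catchment id, dates, and gage id are assumptions for illustration only.

    from datetime import datetime

    # Module path assumed; the diff viewer does not show the name of the file being changed.
    from data_processing.create_realization import create_realization

    create_realization(
        cat_id="cat-12345",               # hypothetical catchment identifier
        start_time=datetime(2020, 1, 1),
        end_time=datetime(2020, 1, 31),
        use_nwm_gw=True,                  # presumably toggles the NWM-retrospective groundwater initialisation
        gage_id="01234567",               # hypothetical; calibrated parameters are downloaded when the S3 object exists (HTTP 200)
    )

If no calibrated parameter file is found for the gage, the new code silently falls back to the default realization template instead of logging a download warning.
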
@@ -0,0 +1,92 @@
+ import logging
+
+ from dask.distributed import Client
+
+ logger = logging.getLogger(__name__)
+
+
+ def shutdown_cluster():
+     try:
+         client = Client.current()
+         client.shutdown()
+     except ValueError:
+         logger.debug("No cluster found to shutdown")
+
+
+ def no_cluster(func):
+     """
+     Decorator that ensures the wrapped function runs with no active Dask cluster.
+
+     This decorator attempts to shut down any existing Dask cluster before
+     executing the wrapped function. If no cluster is found, it logs a debug message
+     and continues execution.
+
+     Parameters:
+         func: The function to be executed without a Dask cluster
+
+     Returns:
+         wrapper: The wrapped function that will be executed without a Dask cluster
+     """
+
+     def wrapper(*args, **kwargs):
+         shutdown_cluster()
+         result = func(*args, **kwargs)
+         return result
+
+     return wrapper
+
+
+ def use_cluster(func):
+     """
+     Decorator that ensures the wrapped function has access to a Dask cluster.
+
+     If a Dask cluster is already running, it uses the existing one.
+     If no cluster is available, it creates a new one before executing the function.
+     The cluster remains active after the function completes.
+
+     Parameters:
+         func: The function to be executed with a Dask cluster
+
+     Returns:
+         wrapper: The wrapped function with access to a Dask cluster
+     """
+
+     def wrapper(*args, **kwargs):
+         try:
+             client = Client.current()
+         except ValueError:
+             client = Client()
+         result = func(*args, **kwargs)
+         return result
+
+     return wrapper
+
+
+ def temp_cluster(func):
+     """
+     Decorator that provides a temporary Dask cluster for the wrapped function.
+
+     If a Dask cluster is already running, it uses the existing one and leaves it running.
+     If no cluster exists, it creates a temporary one and shuts it down after
+     the function completes.
+
+     Parameters:
+         func: The function to be executed with a Dask cluster
+
+     Returns:
+         wrapper: The wrapped function with access to a Dask cluster
+     """
+
+     def wrapper(*args, **kwargs):
+         cluster_was_running = True
+         try:
+             client = Client.current()
+         except ValueError:
+             cluster_was_running = False
+             client = Client()
+         result = func(*args, **kwargs)
+         if not cluster_was_running:
+             client.shutdown()
+         return result
+
+     return wrapper
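
The decorators in this new module are what the rest of the release leans on (@temp_cluster on get_approximate_gw_storage above, @use_cluster on save_dataset and save_to_cache below). A minimal usage sketch, assuming the module is importable as data_processing.dask_utils as the new import in the first file indicates; the wrapped function bodies here are placeholders:

    from dask.distributed import Client
    from data_processing.dask_utils import no_cluster, temp_cluster, use_cluster

    @use_cluster
    def heavy_step():
        # Runs with a Dask cluster available: reuses a running one, otherwise starts
        # a new Client that stays up after the call returns.
        return Client.current().dashboard_link

    @temp_cluster
    def one_off_step():
        # Same cluster lookup, but a cluster started here is shut down again
        # once the function finishes.
        return Client.current().dashboard_link

    @no_cluster
    def serial_step():
        # Any running cluster is shut down before this executes.
        pass

Note that the wrappers never pass the client into the wrapped function; code that needs it (such as save_dataset below) calls dask.distributed.Client.current() itself.
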
@@ -1,19 +1,22 @@
  import logging
  import os
+ from datetime import datetime
  from pathlib import Path
- from typing import Tuple, Union
+ from typing import List, Literal, Optional, Tuple, Union

  import geopandas as gpd
  import numpy as np
  import xarray as xr
- from dask.distributed import Client, progress
- import datetime
+ from xarray.core.types import InterpOptions
+ from dask.distributed import Client, progress, Future
+ from data_processing.dask_utils import use_cluster

  logger = logging.getLogger(__name__)

  # known ngen variable names
  # https://github.com/CIROH-UA/ngen/blob/4fb5bb68dc397298bca470dfec94db2c1dcb42fe/include/forcing/AorcForcing.hpp#L77

+
  def validate_dataset_format(dataset: xr.Dataset) -> None:
      """
      Validate the format of the dataset.
@@ -41,8 +44,9 @@ def validate_dataset_format(dataset: xr.Dataset) -> None:
      if "name" not in dataset.attrs:
          raise ValueError("Dataset must have a name attribute to identify it")

+
  def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) -> Tuple[str, str]:
-     '''
+     """
      Ensure that all selected times are in the passed dataset.

      Parameters
@@ -60,7 +64,7 @@ def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) ->
          start_time, or if not available, earliest available timestep in dataset.
      str
          end_time, or if not available, latest available timestep in dataset.
-     '''
+     """
      end_time_in_dataset = dataset.time.isel(time=-1).values
      start_time_in_dataset = dataset.time.isel(time=0).values
      if np.datetime64(start_time) < start_time_in_dataset:
@@ -77,7 +81,10 @@


  def clip_dataset_to_bounds(
-     dataset: xr.Dataset, bounds: Tuple[float, float, float, float], start_time: str, end_time: str
+     dataset: xr.Dataset,
+     bounds: Tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]],
+     start_time: str,
+     end_time: str,
  ) -> xr.Dataset:
      """
      Clip the dataset to specified geographical bounds.
@@ -86,14 +93,14 @@
      ----------
      dataset : xr.Dataset
          Dataset to be clipped.
-     bounds : tuple[float, float, float, float]
-         Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
+     bounds : tuple[float, float, float, float] | np.ndarray[tuple[int], np.dtype[np.float64]]
+         Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min,
          bounds[2] is x_max, bounds[3] is y_max.
      start_time : str
          Desired start time in YYYY/MM/DD HH:MM:SS format.
      end_time : str
          Desired end time in YYYY/MM/DD HH:MM:SS format.
-
+
      Returns
      -------
      xr.Dataset
@@ -110,33 +117,103 @@
      return dataset


- def save_to_cache(stores: xr.Dataset, cached_nc_path: Path) -> xr.Dataset:
-     """Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth."""
-     logger.info("Downloading and caching forcing data, this may take a while")
+ def interpolate_nan_values(
+     dataset: xr.Dataset,
+     variables: Optional[List[str]] = None,
+     dim: str = "time",
+     method: InterpOptions = "nearest",
+     fill_value: str = "extrapolate",
+ ) -> None:
+     """
+     Interpolates NaN values in specified (or all numeric time-dependent)
+     variables of an xarray.Dataset. Operates inplace on the dataset.

-     if not cached_nc_path.parent.exists():
-         cached_nc_path.parent.mkdir(parents=True)
+     Parameters
+     ----------
+     dataset : xr.Dataset
+         The input dataset.
+     variables : Optional[List[str]], optional
+         A list of variable names to process. If None (default),
+         all numeric variables containing the specified dimension will be processed.
+     dim : str, optional
+         The dimension along which to interpolate (default is "time").
+     method : str, optional
+         Interpolation method to use (e.g., "linear", "nearest", "cubic").
+         Default is "nearest".
+     fill_value : str, optional
+         Method for filling NaNs at the start/end of the series after interpolation.
+         Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
+         Default is "extrapolate".
+     """
+     for name, var in dataset.data_vars.items():
+         # if the variable is non-numeric, skip
+         if not np.issubdtype(var.dtype, np.number):
+             continue
+         # if there are no NANs, skip
+         if not var.isnull().any().compute():
+             continue
+
+         dataset[name] = var.interpolate_na(
+             dim=dim,
+             method=method,
+             fill_value=fill_value if method in ["nearest", "linear"] else None,
+         )

-     # sort of terrible work around for half downloaded files
-     temp_path = cached_nc_path.with_suffix(".downloading.nc")
-     if os.path.exists(temp_path):
-         os.remove(temp_path)

-     ## Cast every single variable to float32 to save space to save a lot of memory issues later
-     ## easier to do it now in this slow download step than later in the steps without dask
-     for var in stores.data_vars:
-         stores[var] = stores[var].astype("float32")
+ @use_cluster
+ def save_dataset(
+     ds_to_save: xr.Dataset,
+     target_path: Path,
+     engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+ ):
+     """
+     Helper function to compute and save an xarray.Dataset to a NetCDF file.
+     Uses a temporary file and rename for atomicity.
+     """
+     if not target_path.parent.exists():
+         target_path.parent.mkdir(parents=True, exist_ok=True)
+
+     temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+     if temp_file_path.exists():
+         os.remove(temp_file_path)

      client = Client.current()
-     future = client.compute(stores.to_netcdf(temp_path, compute=False))
-     # Display progress bar
+     future: Future = client.compute(
+         ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=False)
+     )  # type: ignore
+     logger.debug(
+         f"NetCDF write task submitted to Dask. Waiting for completion to {temp_file_path}..."
+     )
      progress(future)
      future.result()
+     os.rename(str(temp_file_path), str(target_path))
+     logger.info(f"Successfully saved data to: {target_path}")
+
+
+ @use_cluster
+ def save_to_cache(
+     stores: xr.Dataset, cached_nc_path: Path, interpolate_nans: bool = True
+ ) -> xr.Dataset:
+     """
+     Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth.
+     """
+     logger.info(f"Processing dataset for caching. Final cache target: {cached_nc_path}")

-     os.rename(temp_path, cached_nc_path)
+     # lasily cast all numbers to f32
+     for name, var in stores.data_vars.items():
+         if np.issubdtype(var.dtype, np.number):
+             stores[name] = var.astype("float32", casting="same_kind")

-     data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
-     return data
+     # save dataset locally before manipulating it
+     save_dataset(stores, cached_nc_path)
+     stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+
+     if interpolate_nans:
+         interpolate_nan_values(dataset=stores)
+         save_dataset(stores, cached_nc_path)
+         stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+
+     return stores


  def check_local_cache(
@@ -144,9 +221,8 @@
      start_time: str,
      end_time: str,
      gdf: gpd.GeoDataFrame,
-     remote_dataset: xr.Dataset
+     remote_dataset: xr.Dataset,
  ) -> Union[xr.Dataset, None]:
-
      merged_data = None

      if not os.path.exists(cached_nc_path):
@@ -155,9 +231,7 @@

      logger.info("Found cached nc file")
      # open the cached file and check that the time range is correct
-     cached_data = xr.open_mfdataset(
-         cached_nc_path, parallel=True, engine="h5netcdf"
-     )
+     cached_data = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")

      if "name" not in cached_data.attrs or "name" not in remote_dataset.attrs:
          logger.warning("No name attribute found to compare datasets")
@@ -166,9 +240,9 @@
          logger.warning("Cached data from different source, .name attr doesn't match")
          return

-     range_in_cache = cached_data.time[0].values <= np.datetime64(
-         start_time
-     ) and cached_data.time[-1].values >= np.datetime64(end_time)
+     range_in_cache = cached_data.time[0].values <= np.datetime64(start_time) and cached_data.time[
+         -1
+     ].values >= np.datetime64(end_time)

      if not range_in_cache:
          # the cache does not contain the desired time range
@@ -186,10 +260,8 @@
      if range_in_cache:
          logger.info("Time range is within cached data")
          logger.debug(f"Opened cached nc file: [{cached_nc_path}]")
-         merged_data = clip_dataset_to_bounds(
-             cached_data, gdf.total_bounds, start_time, end_time
-         )
-         logger.debug("Clipped stores")
+         merged_data = clip_dataset_to_bounds(cached_data, gdf.total_bounds, start_time, end_time)
+         logger.debug("Clipped stores")

      return merged_data

@@ -197,16 +269,27 @@
  def save_and_clip_dataset(
      dataset: xr.Dataset,
      gdf: gpd.GeoDataFrame,
-     start_time: datetime.datetime,
-     end_time: datetime.datetime,
+     start_time: datetime,
+     end_time: datetime,
      cache_location: Path,
  ) -> xr.Dataset:
      """convenience function clip the remote dataset, and either load from cache or save to cache if it's not present"""
      gdf = gdf.to_crs(dataset.crs)

-     cached_data = check_local_cache(cache_location, start_time, end_time, gdf, dataset)
+     cached_data = check_local_cache(
+         cache_location,
+         start_time,  # type: ignore
+         end_time,  # type: ignore
+         gdf,
+         dataset,
+     )

      if not cached_data:
-         clipped_data = clip_dataset_to_bounds(dataset, gdf.total_bounds, start_time, end_time)
+         clipped_data = clip_dataset_to_bounds(
+             dataset,
+             gdf.total_bounds,
+             start_time,  # type: ignore
+             end_time,  # type: ignore
+         )
          cached_data = save_to_cache(clipped_data, cache_location)
-     return cached_data
+     return cached_data
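
For reference, the NaN handling that the reworked save_to_cache applies through interpolate_nan_values reduces to the xarray interpolate_na call shown above. A standalone sketch of that behaviour on a toy dataset (scipy is required for method="nearest"; the variable name and values are made up):

    import numpy as np
    import pandas as pd
    import xarray as xr

    times = pd.date_range("2020-01-01", periods=6, freq="h")
    ds = xr.Dataset(
        {"TMP_2maboveground": ("time", [10.0, np.nan, np.nan, 13.0, 14.0, np.nan])},
        coords={"time": times},
    )

    # The same per-variable call the new helper makes: nearest-neighbour fill along
    # time, with fill_value="extrapolate" so leading/trailing gaps also take the
    # closest valid value.
    ds["TMP_2maboveground"] = ds["TMP_2maboveground"].interpolate_na(
        dim="time", method="nearest", fill_value="extrapolate"
    )
    print(ds["TMP_2maboveground"].values)  # expected: [10. 10. 13. 13. 14. 14.]

Because save_to_cache writes the gap-filled dataset back with save_dataset and reopens it with open_mfdataset, callers get back a dataset backed by the cached, interpolated file.
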