ngiab-data-preprocess 4.2.1__py3-none-any.whl → 4.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,35 +1,31 @@
  import logging
+ from typing import Optional

  import s3fs
- from data_processing.s3fs_utils import S3ParallelFileSystem
  import xarray as xr
- from dask.distributed import Client, LocalCluster
+ from data_processing.dask_utils import use_cluster
  from data_processing.dataset_utils import validate_dataset_format
-
+ from data_processing.s3fs_utils import S3ParallelFileSystem

  logger = logging.getLogger(__name__)


- def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
+ @use_cluster
+ def load_v3_retrospective_zarr(forcing_vars: Optional[list[str]] = None) -> xr.Dataset:
      """Load zarr datasets from S3 within the specified time range."""
      # if a LocalCluster is not already running, start one
      if not forcing_vars:
          forcing_vars = ["lwdown", "precip", "psfc", "q2d", "swdown", "t2d", "u2d", "v2d"]
-     try:
-         client = Client.current()
-     except ValueError:
-         cluster = LocalCluster()
-         client = Client(cluster)
+
      s3_urls = [
-         f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr"
-         for var in forcing_vars
+         f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr" for var in forcing_vars
      ]
      # default cache is readahead which is detrimental to performance in this case
      fs = S3ParallelFileSystem(anon=True, default_cache_type="none")  # default_block_size
      s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
      # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
      # most of the data is read once and written to disk but some of the coordinate data is read multiple times
-     dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
+     dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)  # type: ignore

      # set the crs attribute to conform with the format
      esri_pe_string = dataset.crs.esri_pe_string
@@ -54,7 +50,8 @@ def load_v3_retrospective_zarr(forcing_vars: list[str] = None) -> xr.Dataset:
      return dataset


- def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
+ @use_cluster
+ def load_aorc_zarr(start_year: Optional[int] = None, end_year: Optional[int] = None) -> xr.Dataset:
      """Load the aorc zarr dataset from S3."""
      if not start_year or not end_year:
          logger.warning("No start or end year provided, defaulting to 1979-2023")
@@ -63,11 +60,6 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
          start_year = 1979
      if not end_year:
          end_year = 2023
-     try:
-         client = Client.current()
-     except ValueError:
-         cluster = LocalCluster()
-         client = Client(cluster)

      logger.info(f"Loading AORC zarr datasets from {start_year} to {end_year}")
      estimated_time_s = ((end_year - start_year) * 2.5) + 3.5
@@ -75,9 +67,9 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
      logger.info(f"This should take roughly {estimated_time_s} seconds")
      fs = S3ParallelFileSystem(anon=True, default_cache_type="none")
      s3_url = "s3://noaa-nws-aorc-v1-1-1km/"
-     urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year+1)]
+     urls = [f"{s3_url}{i}.zarr" for i in range(start_year, end_year + 1)]
      filestores = [s3fs.S3Map(url, s3=fs) for url in urls]
-     dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True)
+     dataset = xr.open_mfdataset(filestores, parallel=True, engine="zarr", cache=True)  # type: ignore
      dataset.attrs["crs"] = "+proj=longlat +datum=WGS84 +no_defs"
      dataset.attrs["name"] = "aorc_1km_zarr"
      # rename latitude and longitude to x and y
@@ -87,32 +79,29 @@ def load_aorc_zarr(start_year: int = None, end_year: int = None) -> xr.Dataset:
      return dataset


+ @use_cluster
  def load_swe_zarr() -> xr.Dataset:
      """Load the swe zarr dataset from S3."""
-     s3_urls = [
-         f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"
-     ]
+     s3_urls = ["s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/ldasout.zarr"]
      # default cache is readahead which is detrimental to performance in this case
      fs = S3ParallelFileSystem(anon=True, default_cache_type="none")  # default_block_size
      s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
      # the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
      # most of the data is read once and written to disk but some of the coordinate data is read multiple times
-     dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
-
+     dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)  # type: ignore
+
      # set the crs attribute to conform with the format
      esri_pe_string = dataset.crs.esri_pe_string
      dataset = dataset.drop_vars(["crs"])
      dataset.attrs["crs"] = esri_pe_string
      # drop everything except SNEQV
      vars_to_drop = list(dataset.data_vars)
-     vars_to_drop.remove('SNEQV')
+     vars_to_drop.remove("SNEQV")
      dataset = dataset.drop_vars(vars_to_drop)
      dataset.attrs["name"] = "v3_swe_zarr"

      # rename the data vars to work with ngen
-     variables = {
-         "SNEQV": "swe"
-     }
+     variables = {"SNEQV": "swe"}
      dataset = dataset.rename_vars(variables)

      validate_dataset_format(dataset)
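
Each of these loaders previously spun up its own Dask client inline with a try/except around Client.current(); in 4.3.0 that boilerplate moves into the @use_cluster decorator imported from data_processing.dask_utils, a module whose contents are not part of this diff. A minimal sketch of what such a decorator could look like, assuming it simply wraps the removed inline logic around the decorated function:

    import functools

    from dask.distributed import Client, LocalCluster


    def use_cluster(func):
        """Hypothetical sketch: make sure a Dask client exists before calling func."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                client = Client.current()  # reuse an already-running client if there is one
            except ValueError:
                client = Client(LocalCluster())  # otherwise start a LocalCluster, mirroring the removed code
            result = func(*args, **kwargs)  # the local reference keeps the client alive during the call
            return result

        return wrapper

Whatever the real implementation looks like, the visible effect in this file is that the three load_* functions no longer manage the cluster lifecycle themselves.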
@@ -1,11 +1,13 @@
  from pathlib import Path
-
+ from typing import Optional
+ from datetime import datetime

  class file_paths:
      """
      This class contains all of the file paths used in the data processing
      workflow.
      """
+
      config_file = Path("~/.ngiab/preprocessor").expanduser()
      hydrofabric_dir = Path("~/.ngiab/hydrofabric/v2.2").expanduser()
      hydrofabric_download_log = Path("~/.ngiab/hydrofabric/v2.2/download_log.json").expanduser()
@@ -31,7 +33,7 @@ class file_paths:
      template_em_config = data_sources / "em-catchment-template.yml"
      template_em_model_config = data_sources / "em-config.yml"

-     def __init__(self, folder_name: str = None, output_dir: Path = None):
+     def __init__(self, folder_name: Optional[str] = None, output_dir: Optional[Path] = None):
          """
          Initialize the file_paths class with a the name of the output subfolder.
          OR the path to the output folder you want to use.
@@ -53,7 +55,7 @@ class file_paths:
          self.cache_dir.mkdir(parents=True, exist_ok=True)

      @classmethod
-     def get_working_dir(cls) -> Path:
+     def get_working_dir(cls) -> Path | None:
          try:
              with open(cls.config_file, "r") as f:
                  return Path(f.readline().strip()).expanduser()
@@ -67,9 +69,7 @@ class file_paths:
 
      @classmethod
      def root_output_dir(cls) -> Path:
-         if cls.get_working_dir() is not None:
-             return cls.get_working_dir()
-         return Path(__file__).parent.parent.parent / "output"
+         return cls.get_working_dir() or Path(__file__).parent.parent.parent / "output"

      @property
      def subset_dir(self) -> Path:
@@ -102,7 +102,7 @@ class file_paths:
      def append_cli_command(self, command: list[str]) -> None:
          current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
          command_string = " ".join(command)
-         history_file = self.metadata_dir / "cli_commands_history.txt"
+         history_file = self.metadata_dir / "cli_commands_history.txt"
          if not history_file.parent.exists():
              history_file.parent.mkdir(parents=True, exist_ok=True)
          with open(self.metadata_dir / "cli_commands_history.txt", "a") as f:
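
In this file, get_working_dir now advertises that it can return None (Path | None), and root_output_dir collapses the old is-not-None branch into a single expression. Because pathlib.Path objects are always truthy and only None is falsy, the `or` fallback behaves the same as the removed check while also reading the config file only once. A small illustration with a hypothetical helper (pick_output_dir is not part of the package):

    from pathlib import Path
    from typing import Optional


    def pick_output_dir(working_dir: Optional[Path], default: Path) -> Path:
        # Path instances are always truthy, so `or` only falls back when working_dir is None,
        # which matches the removed `if ... is not None` branch.
        return working_dir or default


    assert pick_output_dir(None, Path("output")) == Path("output")
    assert pick_output_dir(Path("/data/ngiab"), Path("output")) == Path("/data/ngiab")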
@@ -3,32 +3,29 @@ import multiprocessing
  import os
  import time
  import warnings
- from datetime import datetime
  from functools import partial
  from math import ceil
  from multiprocessing import shared_memory
  from pathlib import Path
-
- from dask.distributed import Client, LocalCluster
+ from typing import List, Tuple

  import geopandas as gpd
  import numpy as np
  import pandas as pd
  import psutil
  import xarray as xr
- from data_processing.file_paths import file_paths
+ from data_processing.dask_utils import no_cluster, use_cluster
  from data_processing.dataset_utils import validate_dataset_format
+ from data_processing.file_paths import file_paths
  from exactextract import exact_extract
  from exactextract.raster import NumPyRasterSource
  from rich.progress import (
-     Progress,
      BarColumn,
+     Progress,
      TextColumn,
      TimeElapsedColumn,
      TimeRemainingColumn,
  )
- from typing import Tuple
-

  logger = logging.getLogger(__name__)
  # Suppress the specific warning from numpy to keep the cli output clean
@@ -40,13 +37,13 @@ warnings.filterwarnings(
  )


- def weighted_sum_of_cells(flat_raster: np.ndarray,
-                           cell_ids: np.ndarray,
-                           factors: np.ndarray) -> np.ndarray:
-     '''
+ def weighted_sum_of_cells(
+     flat_raster: np.ndarray, cell_ids: np.ndarray, factors: np.ndarray
+ ) -> np.ndarray:
+     """
      Take an average of each forcing variable in a catchment. Create an output
-     array initialized with zeros, and then sum up the forcing variable and
-     divide by the sum of the cell weights to get an averaged forcing variable
+     array initialized with zeros, and then sum up the forcing variable and
+     divide by the sum of the cell weights to get an averaged forcing variable
      for the entire catchment.

      Parameters
@@ -65,7 +62,7 @@ def weighted_sum_of_cells(flat_raster: np.ndarray,
          An one-dimensional array, where each element corresponds to a timestep.
          Each element contains the averaged forcing value for the whole catchment
          over one timestep.
-     '''
+     """
      result = np.zeros(flat_raster.shape[0])
      result = np.sum(flat_raster[:, cell_ids] * factors, axis=1)
      sum_of_weights = np.sum(factors)
@@ -73,12 +70,10 @@ def weighted_sum_of_cells(flat_raster: np.ndarray,
      return result


- def get_cell_weights(raster: xr.Dataset,
-                      gdf: gpd.GeoDataFrame,
-                      wkt: str) -> pd.DataFrame:
-     '''
-     Get the cell weights (coverage) for each cell in a divide. Coverage is
-     defined as the fraction (a float in [0,1]) of a raster cell that overlaps
+ def get_cell_weights(raster: xr.Dataset, gdf: gpd.GeoDataFrame, wkt: str) -> pd.DataFrame:
+     """
+     Get the cell weights (coverage) for each cell in a divide. Coverage is
+     defined as the fraction (a float in [0,1]) of a raster cell that overlaps
      with the polygon in the passed gdf.

      Parameters
@@ -96,35 +91,37 @@ def get_cell_weights(raster: xr.Dataset,
      pd.DataFrame
          DataFrame indexed by divide_id that contains information about coverage
          for each raster cell in gridded forcing file.
-     '''
-     xmin = raster.x[0]
-     xmax = raster.x[-1]
-     ymin = raster.y[0]
-     ymax = raster.y[-1]
+     """
+     xmin = min(raster.x)
+     xmax = max(raster.x)
+     ymin = min(raster.y)
+     ymax = max(raster.y)
      data_vars = list(raster.data_vars)
      rastersource = NumPyRasterSource(
          raster[data_vars[0]], srs_wkt=wkt, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax
      )
-     output = exact_extract(
+     output: pd.DataFrame = exact_extract(
          rastersource,
          gdf,
          ["cell_id", "coverage"],
          include_cols=["divide_id"],
          output="pandas",
-     )
+     )  # type: ignore
      return output.set_index("divide_id")


  def add_APCP_SURFACE_to_dataset(dataset: xr.Dataset) -> xr.Dataset:
-     '''Convert precipitation value to correct units.'''
+     """Convert precipitation value to correct units."""
      # precip_rate is mm/s
      # cfe says input atmosphere_water__liquid_equivalent_precipitation_rate is mm/h
      # nom says prcpnonc input is mm/s
      # technically should be kg/m^2/s at 1kg = 1l it equates to mm/s
      # nom says qinsur output is m/s, hopefully qinsur is converted to mm/h by ngen
      dataset["APCP_surface"] = dataset["precip_rate"] * 3600
-     dataset["APCP_surface"].attrs["units"] = "mm h^-1" # ^-1 notation copied from source data
-     dataset["APCP_surface"].attrs["source_note"] = "This is just the precip_rate variable converted to mm/h by multiplying by 3600"
+     dataset["APCP_surface"].attrs["units"] = "mm h^-1"  # ^-1 notation copied from source data
+     dataset["APCP_surface"].attrs["source_note"] = (
+         "This is just the precip_rate variable converted to mm/h by multiplying by 3600"
+     )
      return dataset


@@ -132,14 +129,14 @@ def add_precip_rate_to_dataset(dataset: xr.Dataset) -> xr.Dataset:
      # the inverse of the function above
      dataset["precip_rate"] = dataset["APCP_surface"] / 3600
      dataset["precip_rate"].attrs["units"] = "mm s^-1"
-     dataset["precip_rate"].attrs[
-         "source_note"
-     ] = "This is just the APCP_surface variable converted to mm/s by dividing by 3600"
+     dataset["precip_rate"].attrs["source_note"] = (
+         "This is just the APCP_surface variable converted to mm/s by dividing by 3600"
+     )
      return dataset


  def get_index_chunks(data: xr.DataArray) -> list[tuple[int, int]]:
-     '''
+     """
      Take a DataArray and calculate the start and end index for each chunk based
      on the available memory.

@@ -153,7 +150,7 @@ def get_index_chunks(data: xr.DataArray) -> list[tuple[int, int]]:
      list[Tuple[int, int]]
          Each element in the list represents a chunk of data. The tuple within
          the chunk indicates the start index and end index of the chunk.
-     '''
+     """
      array_memory_usage = data.nbytes
      free_memory = psutil.virtual_memory().available * 0.8  # 80% of available memory
      # limit the chunk to 20gb, makes things more stable
@@ -166,15 +163,13 @@ def get_index_chunks(data: xr.DataArray) -> list[tuple[int, int]]:
      return index_chunks


- def create_shared_memory(lazy_array: xr.Dataset) -> Tuple[
-         shared_memory.SharedMemory,
-         np.dtype,
-         np.dtype
- ]:
-     '''
-     Create a shared memory object so that multiple processes can access loaded
+ def create_shared_memory(
+     lazy_array: xr.DataArray,
+ ) -> Tuple[shared_memory.SharedMemory, Tuple[int, ...], np.dtype]:
+     """
+     Create a shared memory object so that multiple processes can access loaded
      data.
-
+
      Parameters
      ----------
      lazy_array : xr.Dataset
@@ -183,22 +178,22 @@ def create_shared_memory(lazy_array: xr.Dataset) -> Tuple[
      Returns
      -------
      shared_memory.SharedMemory
-         A specific block of memory allocated by the OS of the size of
+         A specific block of memory allocated by the OS of the size of
          lazy_array.
-     np.dtype.shape
+     Tuple[int, ...]
          A shape object with dimensions (# timesteps, # of raster cells) in
          reference to lazy_array.
      np.dtype
          Data type of objects in lazy_array.
-     '''
-     logger.debug(f"Creating shared memory size {lazy_array.nbytes/ 10**6} Mb.")
+     """
+     logger.debug(f"Creating shared memory size {lazy_array.nbytes / 10**6} Mb.")
      shm = shared_memory.SharedMemory(create=True, size=lazy_array.nbytes)
      shared_array = np.ndarray(lazy_array.shape, dtype=np.float32, buffer=shm.buf)
      # if your data is not float32, xarray will do an automatic conversion here
      # which consumes a lot more memory, forcings downloaded with this tool will work
      for start, end in get_index_chunks(lazy_array):
-         # copy data from lazy to shared memory one chunk at a time
-         shared_array[start:end] = lazy_array[start:end]
+         # copy data from lazy to shared memory one chunk at a time
+         shared_array[start:end] = lazy_array[start:end]

      time, x, y = shared_array.shape
      shared_array = shared_array.reshape(time, -1)
@@ -206,14 +201,16 @@ def create_shared_memory(lazy_array: xr.Dataset) -> Tuple[
      return shm, shared_array.shape, shared_array.dtype


- def process_chunk_shared(variable: str,
-                          times: np.ndarray,
-                          shm_name: str,
-                          shape: np.dtype.shape,
-                          dtype: np.dtype,
-                          chunk: gpd.GeoDataFrame) -> xr.DataArray:
-     '''
-     Process the gridded forcings chunk loaded into a SharedMemory block.
+ def process_chunk_shared(
+     variable: str,
+     times: np.ndarray,
+     shm_name: str,
+     shape: Tuple[int, ...],
+     dtype: np.dtype,
+     chunk: pd.DataFrame,
+ ) -> xr.DataArray:
+     """
+     Process the gridded forcings chunk loaded into a SharedMemory block.

      Parameters
      ----------
@@ -235,7 +232,7 @@ def process_chunk_shared(variable: str,
      -------
      xr.DataArray
          Averaged forcings data for each timestep for each catchment.
-     '''
+     """
      existing_shm = shared_memory.SharedMemory(name=shm_name)
      raster = np.ndarray(shape, dtype=dtype, buffer=existing_shm.buf)
      results = []
@@ -256,10 +253,10 @@ def process_chunk_shared(variable: str,
      return xr.concat(results, dim="catchment")


- def get_cell_weights_parallel(gdf: gpd.GeoDataFrame,
-                               input_forcings: xr.Dataset,
-                               num_partitions: int) -> pd.DataFrame:
-     '''
+ def get_cell_weights_parallel(
+     gdf: gpd.GeoDataFrame, input_forcings: xr.Dataset, num_partitions: int
+ ) -> pd.DataFrame:
+     """
      Execute get_cell_weights with multiprocessing, with chunking for the passed
      GeoDataFrame to conserve memory usage.

@@ -277,29 +274,30 @@ def get_cell_weights_parallel(gdf: gpd.GeoDataFrame,
      pd.DataFrame
          DataFrame indexed by divide_id that contains information about coverage
          for each raster cell and each timestep in gridded forcing file.
-     '''
+     """
      gdf_chunks = np.array_split(gdf, num_partitions)
-     wkt = gdf.crs.to_wkt()
+     wkt = gdf.crs.to_wkt()  # type: ignore
      one_timestep = input_forcings.isel(time=0).compute()
      with multiprocessing.Pool() as pool:
          args = [(one_timestep, gdf_chunk, wkt) for gdf_chunk in gdf_chunks]
          catchments = pool.starmap(get_cell_weights, args)
      return pd.concat(catchments)

+
  def get_units(dataset: xr.Dataset) -> dict:
-     '''
+     """
      Return dictionary of units for each variable in dataset.
-
+
      Parameters
      ----------
      dataset : xr.Dataset
          Dataset with variables and units.
-
+
      Returns
      -------
-     dict
+     dict
          {variable name: unit}
-     '''
+     """
      units = {}
      for var in dataset.data_vars:
          if dataset[var].attrs["units"]:
@@ -307,12 +305,13 @@ def get_units(dataset: xr.Dataset) -> dict:
      return units


+ @no_cluster
  def compute_zonal_stats(
      gdf: gpd.GeoDataFrame, gridded_data: xr.Dataset, forcings_dir: Path
  ) -> None:
-     '''
-     Compute zonal statistics in parallel for all timesteps over all desired
-     catchments. Create chunks of catchments and within those, chunks of
+     """
+     Compute zonal statistics in parallel for all timesteps over all desired
+     catchments. Create chunks of catchments and within those, chunks of
      timesteps for memory management.

      Parameters
@@ -323,7 +322,7 @@ def compute_zonal_stats(
          Gridded forcing data that intersects with desired catchments.
      forcings_dir : Path
          Path to directory where outputs are to be stored.
-     '''
+     """
      logger.info("Computing zonal stats in parallel for all timesteps")
      timer_start = time.time()
      num_partitions = multiprocessing.cpu_count() - 1
@@ -333,7 +332,7 @@ def compute_zonal_stats(
      catchments = get_cell_weights_parallel(gdf, gridded_data, num_partitions)
      units = get_units(gridded_data)

-     cat_chunks = np.array_split(catchments, num_partitions)
+     cat_chunks: List[pd.DataFrame] = np.array_split(catchments, num_partitions)  # type: ignore

      progress = Progress(
          TextColumn("[progress.description]{task.description}"),
@@ -352,25 +351,28 @@ def compute_zonal_stats(
          "[cyan]Processing variables...", total=len(gridded_data.data_vars), elapsed=0
      )
      progress.start()
-     for variable in list(gridded_data.data_vars):
+     for data_var_name in list(gridded_data.data_vars):
+         data_var_name: str
          progress.update(variable_task, advance=1)
-         progress.update(variable_task, description=f"Processing {variable}")
+         progress.update(variable_task, description=f"Processing {data_var_name}")

          # to make sure this fits in memory, we need to chunk the data
-         time_chunks = get_index_chunks(gridded_data[variable])
+         time_chunks = get_index_chunks(gridded_data[data_var_name])
          chunk_task = progress.add_task("[purple] processing chunks", total=len(time_chunks))
          for i, times in enumerate(time_chunks):
              progress.update(chunk_task, advance=1)
              start, end = times
              # select the chunk of time we want to process
-             data_chunk = gridded_data[variable].isel(time=slice(start, end))
+             data_chunk = gridded_data[data_var_name].isel(time=slice(start, end))
              # put it in shared memory
              shm, shape, dtype = create_shared_memory(data_chunk)
              times = data_chunk.time.values
              # create a partial function to pass to the multiprocessing pool
-             partial_process_chunk = partial(process_chunk_shared,variable,times,shm.name,shape,dtype)
+             partial_process_chunk = partial(
+                 process_chunk_shared, data_var_name, times, shm.name, shape, dtype
+             )

-             logger.debug(f"Processing variable: {variable}")
+             logger.debug(f"Processing variable: {data_var_name}")
              # process the chunks of catchments in parallel
              with multiprocessing.Pool(num_partitions) as pool:
                  variable_data = pool.map(partial_process_chunk, cat_chunks)
@@ -378,24 +380,24 @@ def compute_zonal_stats(
              # clean up the shared memory
              shm.close()
              shm.unlink()
-             logger.debug(f"Processed variable: {variable}")
+             logger.debug(f"Processed variable: {data_var_name}")
              concatenated_da = xr.concat(variable_data, dim="catchment")
              # delete the data to free up memory
              del variable_data
-             logger.debug(f"Concatenated variable: {variable}")
+             logger.debug(f"Concatenated variable: {data_var_name}")
              # write this to disk now to save memory
              # xarray will monitor memory usage, but it doesn't account for the shared memory used to store the raster
              # This reduces memory usage by about 60%
-             concatenated_da.to_dataset(name=variable).to_netcdf(
-                 forcings_dir / "temp" / f"{variable}_timechunk_{i}.nc"
+             concatenated_da.to_dataset(name=data_var_name).to_netcdf(
+                 forcings_dir / "temp" / f"{data_var_name}_timechunk_{i}.nc"
              )
          # Merge the chunks back together
          datasets = [
-             xr.open_dataset(forcings_dir / "temp" / f"{variable}_timechunk_{i}.nc")
+             xr.open_dataset(forcings_dir / "temp" / f"{data_var_name}_timechunk_{i}.nc")
              for i in range(len(time_chunks))
          ]
          result = xr.concat(datasets, dim="time")
-         result.to_netcdf(forcings_dir / "temp" / f"{variable}.nc")
+         result.to_netcdf(forcings_dir / "temp" / f"{data_var_name}.nc")
          # close the datasets
          result.close()
          _ = [dataset.close() for dataset in datasets]
@@ -413,8 +415,9 @@ def compute_zonal_stats(
      write_outputs(forcings_dir, units)


+ @use_cluster
  def write_outputs(forcings_dir: Path, units: dict) -> None:
-     '''
+     """
      Write outputs to disk in the form of a NetCDF file, using dask clusters to
      facilitate parallel computing.

@@ -423,20 +426,13 @@ def write_outputs(forcings_dir: Path, units: dict) -> None:
      forcings_dir : Path
          Path to directory where outputs are to be stored.
      variables : dict
-         Preset dictionary where the keys are forcing variable names and the
+         Preset dictionary where the keys are forcing variable names and the
          values are units.
      units : dict
-         Dictionary where the keys are forcing variable names and the values are
+         Dictionary where the keys are forcing variable names and the values are
          units. Differs from variables, as this dictionary depends on the gridded
          forcing dataset.
-     '''
-
-     # start a dask cluster if there isn't one already running
-     try:
-         client = Client.current()
-     except ValueError:
-         cluster = LocalCluster()
-         client = Client(cluster)
+     """
      temp_forcings_dir = forcings_dir / "temp"
      # Combine all variables into a single dataset using dask
      results = [xr.open_dataset(file, chunks="auto") for file in temp_forcings_dir.glob("*.nc")]
@@ -473,14 +469,18 @@ def write_outputs(forcings_dir: Path, units: dict) -> None:
      time_array = (
          final_ds.time.astype("datetime64[s]").astype(np.int64).values // 10**9
      )  ## convert from ns to s
-     time_array = time_array.astype(np.int32) ## convert to int32 to save space
-     final_ds = final_ds.drop_vars(["catchment", "time"]) ## drop the original time and catchment vars
-     final_ds = final_ds.rename_dims({"catchment": "catchment-id"}) # rename the catchment dimension
+     time_array = time_array.astype(np.int32)  ## convert to int32 to save space
+     final_ds = final_ds.drop_vars(
+         ["catchment", "time"]
+     )  ## drop the original time and catchment vars
+     final_ds = final_ds.rename_dims({"catchment": "catchment-id"})  # rename the catchment dimension
      # add the time as a 2d data var, yes this is wasting disk space.
      final_ds["Time"] = (("catchment-id", "time"), [time_array for _ in range(len(final_ds["ids"]))])
      # set the time unit
      final_ds["Time"].attrs["units"] = "s"
-     final_ds["Time"].attrs["epoch_start"] = "01/01/1970 00:00:00" # not needed but suppresses the ngen warning
+     final_ds["Time"].attrs["epoch_start"] = (
+         "01/01/1970 00:00:00"  # not needed but suppresses the ngen warning
+     )

      final_ds.to_netcdf(forcings_dir / "forcings.nc", engine="netcdf4")
      # close the datasets
@@ -508,7 +508,7 @@ def setup_directories(cat_id: str) -> file_paths:
  def create_forcings(dataset: xr.Dataset, output_folder_name: str) -> None:
      validate_dataset_format(dataset)
      forcing_paths = setup_directories(output_folder_name)
-     print(f"forcing path {output_folder_name} {forcing_paths.forcings_dir}")
+     logger.debug(f"forcing path {output_folder_name} {forcing_paths.forcings_dir}")
      gdf = gpd.read_file(forcing_paths.geopackage_path, layer="divides")
      logger.debug(f"gdf bounds: {gdf.total_bounds}")
      gdf = gdf.to_crs(dataset.crs)
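
In the forcings module, compute_zonal_stats does its parallelism with multiprocessing.Pool and shared memory, so it is now marked @no_cluster, while write_outputs keeps @use_cluster for its dask-backed open_dataset/to_netcdf work. Since data_processing.dask_utils is not included in this diff, the following is only one plausible reading of no_cluster, shown for illustration: a decorator that shuts down any active Dask client before the wrapped function runs, so dask workers do not compete with the process pool for CPU and memory.

    import functools

    from dask.distributed import Client


    def no_cluster(func):
        """Hypothetical sketch: close any running Dask client before calling func."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                client = Client.current()
            except ValueError:
                client = None  # no client is running, nothing to shut down
            if client is not None:
                client.shutdown()  # stop the scheduler and its workers
            return func(*args, **kwargs)

        return wrapper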