ngiab-data-preprocess 4.5.0__py3-none-any.whl → 4.6.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,21 +1,21 @@
  import json
  import logging
  import multiprocessing
+ import os
  import shutil
  import sqlite3
  from datetime import datetime
  from pathlib import Path
  from typing import Dict, Optional
- import psutil
- import os

  import numpy as np
  import pandas
+ import psutil
  import requests
  import s3fs
  import xarray as xr
  from data_processing.dask_utils import temp_cluster
- from data_processing.file_paths import file_paths
+ from data_processing.file_paths import FilePaths
  from data_processing.gpkg_utils import (
      get_cat_to_nhd_feature_id,
      get_table_crs_short,
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)


  @temp_cluster
- def get_approximate_gw_storage(paths: file_paths, start_date: datetime) -> Dict[str, int]:
+ def get_approximate_gw_storage(paths: FilePaths, start_date: datetime) -> Dict[str, int]:
      # get the gw levels from the NWM output on a given start date
      # this kind of works in place of warmstates for now
      year = start_date.strftime("%Y")
@@ -50,11 +50,9 @@ def get_approximate_gw_storage(paths: file_paths, start_date: datetime) -> Dict[
      return water_levels


- def make_cfe_config(
-     divide_conf_df: pandas.DataFrame, files: file_paths, water_levels: dict
- ) -> None:
+ def make_cfe_config(divide_conf_df: pandas.DataFrame, files: FilePaths, water_levels: dict) -> None:
      """Parses parameters from NOAHOWP_CFE DataFrame and returns a dictionary of catchment configurations."""
-     with open(file_paths.template_cfe_config, "r") as f:
+     with open(FilePaths.template_cfe_config, "r") as f:
          cfe_template = f.read()
      cat_config_dir = files.config_dir / "cat_config" / "CFE"
      cat_config_dir.mkdir(parents=True, exist_ok=True)
@@ -92,7 +90,7 @@ def make_noahowp_config(
  ) -> None:
      start_datetime = start_time.strftime("%Y%m%d%H%M")
      end_datetime = end_time.strftime("%Y%m%d%H%M")
-     with open(file_paths.template_noahowp_config, "r") as file:
+     with open(FilePaths.template_noahowp_config, "r") as file:
          template = file.read()

      cat_config_dir = base_dir / "cat_config" / "NOAH-OWP-M"
@@ -137,7 +135,7 @@ def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
  def make_lstm_config(
      hydrofabric: Path,
      output_dir: Path,
-     template_path: Path = file_paths.template_lstm_config,
+     template_path: Path = FilePaths.template_lstm_config,
  ):
      # test if modspatialite is available

@@ -169,7 +167,7 @@ def make_lstm_config(
                  lat=row["latitude"],
                  lon=row["longitude"],
                  slope_mean=row["mean_slope_mpkm"],
-                 elevation_mean=row["mean.elevation"] / 1000, # convert mm in hf to m
+                 elevation_mean=row["mean.elevation"] / 100,  # convert cm in hf to m
              )
          )

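The divisor change above tracks the units of mean.elevation in the hydrofabric: 4.6.0 treats the value as centimeters rather than millimeters. A quick check with an illustrative, made-up value:

    elevation_cm = 123456.0            # hypothetical mean.elevation read from the hydrofabric
    elevation_m = elevation_cm / 100   # new divisor: 1234.56 m
    # the old divisor of 1000 would have yielded 123.456 m, ten times too small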
@@ -177,17 +175,19 @@
  def configure_troute(
      cat_id: str, config_dir: Path, start_time: datetime, end_time: datetime
  ) -> None:
-     with open(file_paths.template_troute_config, "r") as file:
+     with open(FilePaths.template_troute_config, "r") as file:
          troute_template = file.read()
      time_step_size = 300
-     gpkg_file_path=f"{config_dir}/{cat_id}_subset.gpkg"
+     gpkg_file_path = f"{config_dir}/{cat_id}_subset.gpkg"
      nts = (end_time - start_time).total_seconds() / time_step_size
      with sqlite3.connect(gpkg_file_path) as conn:
          ncats_df = pandas.read_sql_query("SELECT COUNT(id) FROM 'divides';", conn)
-         ncats = ncats_df['COUNT(id)'][0]
+         ncats = ncats_df["COUNT(id)"][0]

-     est_bytes_required = nts * ncats * 45 # extremely rough calculation based on about 3 tests :)
-     local_ram_available = 0.8 * psutil.virtual_memory().available # buffer to not accidentally explode machine
+     est_bytes_required = nts * ncats * 45  # extremely rough calculation based on about 3 tests :)
+     local_ram_available = (
+         0.8 * psutil.virtual_memory().available
+     )  # buffer to not accidentally explode machine

      if est_bytes_required > local_ram_available:
          max_loop_size = nts // (est_bytes_required // local_ram_available)
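The reformatted block keeps the 4.5.0 sizing logic: estimate the routing output size as timesteps x catchments x ~45 bytes, compare it against 80% of free RAM, and shrink the loop size when the estimate does not fit. A standalone sketch with made-up numbers (the else branch here is an assumption, it is not shown in this hunk):

    import psutil

    nts = (365 * 24 * 3600) / 300            # a year of 300-second routing timesteps
    ncats = 50_000                           # hypothetical number of divides
    est_bytes_required = nts * ncats * 45    # ~45 bytes per catchment-timestep
    local_ram_available = 0.8 * psutil.virtual_memory().available

    if est_bytes_required > local_ram_available:
        # split the run into smaller loops so each loop's output fits in RAM
        max_loop_size = nts // (est_bytes_required // local_ram_available)
    else:
        max_loop_size = nts                  # assumption: everything fits in one loop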
@@ -210,7 +210,7 @@ def configure_troute(
          start_datetime=start_time.strftime("%Y-%m-%d %H:%M:%S"),
          nts=nts,
          max_loop_size=max_loop_size,
-         binary_nexus_file_folder_comment=binary_nexus_file_folder_comment
+         binary_nexus_file_folder_comment=binary_nexus_file_folder_comment,
      )

      with open(config_dir / "troute.yaml", "w") as file:
@@ -231,11 +231,26 @@ def make_ngen_realization_json(
          json.dump(realization, file, indent=4)


- def create_lstm_realization(cat_id: str, start_time: datetime, end_time: datetime):
-     paths = file_paths(cat_id)
-     template_path = file_paths.template_lstm_realization_config
+ def create_lstm_realization(
+     cat_id: str, start_time: datetime, end_time: datetime, use_rust: bool = False
+ ):
+     paths = FilePaths(cat_id)
+     realization_path = paths.config_dir / "realization.json"
      configure_troute(cat_id, paths.config_dir, start_time, end_time)
-     make_ngen_realization_json(paths.config_dir, template_path, start_time, end_time)
+     # python version of the lstm
+     python_template_path = FilePaths.template_lstm_realization_config
+     make_ngen_realization_json(paths.config_dir, python_template_path, start_time, end_time)
+     realization_path.rename(paths.config_dir / "python_lstm_real.json")
+     # rust version of the lstm
+     rust_template_path = FilePaths.template_lstm_rust_realization_config
+     make_ngen_realization_json(paths.config_dir, rust_template_path, start_time, end_time)
+     realization_path.rename(paths.config_dir / "rust_lstm_real.json")
+
+     if use_rust:
+         (paths.config_dir / "rust_lstm_real.json").rename(realization_path)
+     else:
+         (paths.config_dir / "python_lstm_real.json").rename(realization_path)
+
      make_lstm_config(paths.geopackage_path, paths.config_dir)
      # create some partitions for parallelization
      paths.setup_run_folders()
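create_lstm_realization now renders both LSTM realization templates and keeps whichever one use_rust selects as config/realization.json; the non-selected variant stays in the config directory under its own name. A hypothetical call (argument values are made up):

    from datetime import datetime

    create_lstm_realization(
        "cat-1496145",
        start_time=datetime(2022, 1, 1),
        end_time=datetime(2022, 2, 1),
        use_rust=True,   # rust_lstm_real.json becomes realization.json
    )
    # python_lstm_real.json is left alongside it, so switching back is a file rename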
@@ -248,7 +263,7 @@ def create_realization(
      use_nwm_gw: bool = False,
      gage_id: Optional[str] = None,
  ):
-     paths = file_paths(cat_id)
+     paths = FilePaths(cat_id)

      template_path = paths.template_cfe_nowpm_realization_config

@@ -263,6 +278,8 @@ def create_realization(
              with open(template_path, "w") as f:
                  json.dump(new_template, f)
              logger.info(f"downloaded calibrated parameters for {gage_id}")
+         else:
+             logger.warning(f"could not download parameters for {gage_id}, using default template")

      conf_df = get_model_attributes(paths.geopackage_path)

@@ -2,14 +2,13 @@ import logging
  import os
  from datetime import datetime
  from pathlib import Path
- from typing import List, Literal, Optional, Tuple, Union
+ from typing import Literal, Tuple, Union

  import geopandas as gpd
  import numpy as np
  import xarray as xr
  from dask.distributed import Client, Future, progress
  from data_processing.dask_utils import no_cluster, temp_cluster
- from xarray.core.types import InterpOptions

  logger = logging.getLogger(__name__)

@@ -116,78 +115,6 @@ def clip_dataset_to_bounds(
      logger.info("Selected time range and clipped to bounds")
      return dataset

-
- @no_cluster
- def interpolate_nan_values(
-     dataset: xr.Dataset,
-     variables: Optional[List[str]] = None,
-     dim: str = "time",
-     method: InterpOptions = "nearest",
-     fill_value: str = "extrapolate",
- ) -> bool:
-     """
-     Interpolates NaN values in specified (or all numeric time-dependent)
-     variables of an xarray.Dataset. Operates inplace on the dataset.
-
-     Parameters
-     ----------
-     dataset : xr.Dataset
-         The input dataset.
-     variables : Optional[List[str]], optional
-         A list of variable names to process. If None (default),
-         all numeric variables containing the specified dimension will be processed.
-     dim : str, optional
-         The dimension along which to interpolate (default is "time").
-     method : str, optional
-         Interpolation method to use (e.g., "linear", "nearest", "cubic").
-         Default is "nearest".
-     fill_value : str, optional
-         Method for filling NaNs at the start/end of the series after interpolation.
-         Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
-         Default is "extrapolate".
-     """
-     interpolation_used = False
-     for name, var in dataset.data_vars.items():
-         # if the variable is non-numeric, skip
-         if not np.issubdtype(var.dtype, np.number):
-             continue
-         # if there are no NANs, skip
-         if not var.isnull().any().compute():
-             continue
-
-         dataset[name] = var.interpolate_na(
-             dim=dim,
-             method=method,
-             fill_value=fill_value if method in ["nearest", "linear"] else None,
-         )
-         interpolation_used = True
-     return interpolation_used
-
-
- @no_cluster
- def save_dataset_no_cluster(
-     ds_to_save: xr.Dataset,
-     target_path: Path,
-     engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
- ):
-     """
-     This explicitly does not use dask distributed.
-     Helper function to compute and save an xarray.Dataset to a NetCDF file.
-     Uses a temporary file and rename for avoid leaving a half written file.
-     """
-     if not target_path.parent.exists():
-         target_path.parent.mkdir(parents=True, exist_ok=True)
-
-     temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
-     if temp_file_path.exists():
-         os.remove(temp_file_path)
-
-     ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=True)
-
-     os.rename(str(temp_file_path), str(target_path))
-     logger.info(f"Successfully saved data to: {target_path}")
-
-
  @temp_cluster
  def save_dataset(
      ds_to_save: xr.Dataset,
@@ -195,7 +122,8 @@ def save_dataset(
      engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
  ):
      """
-     Helper function to compute and save an xarray.Dataset to a NetCDF file.
+     Helper function to compute and save an xarray.Dataset (specifically, the raw
+     forcing data) to a NetCDF file.
      Uses a temporary file and rename for atomicity.
      """
      if not target_path.parent.exists():
@@ -221,7 +149,7 @@

  @no_cluster
  def save_to_cache(
-     stores: xr.Dataset, cached_nc_path: Path, interpolate_nans: bool = True
+     stores: xr.Dataset, cached_nc_path: Path
  ) -> xr.Dataset:
      """
      Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth.
@@ -236,16 +164,6 @@
      # save dataset locally before manipulating it
      save_dataset(stores, cached_nc_path)

-     if interpolate_nans:
-         stores = xr.open_mfdataset(
-             cached_nc_path,
-             parallel=True,
-             engine="h5netcdf",
-         )
-         was_interpolated = interpolate_nan_values(dataset=stores)
-         if was_interpolated:
-             save_dataset_no_cluster(stores, cached_nc_path)
-
      stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
      return stores

@@ -3,7 +3,7 @@ from pathlib import Path
  from typing import Optional


- class file_paths:
+ class FilePaths:
      """
      This class contains all of the file paths used in the data processing
      workflow.
@@ -29,13 +29,14 @@ class file_paths:
      template_troute_config = data_sources / "ngen-routing-template.yaml"
      template_cfe_nowpm_realization_config = data_sources / "cfe-nowpm-realization-template.json"
      template_lstm_realization_config = data_sources / "lstm-realization-template.json"
+     template_lstm_rust_realization_config = data_sources / "lstm-rust-realization-template.json"
      template_noahowp_config = data_sources / "noah-owp-modular-init.namelist.input"
      template_cfe_config = data_sources / "cfe-template.ini"
      template_lstm_config = data_sources / "lstm-catchment-template.yml"

      def __init__(self, folder_name: Optional[str] = None, output_dir: Optional[Path] = None):
          """
-         Initialize the file_paths class with a the name of the output subfolder.
+         Initialize the FilePaths class with a the name of the output subfolder.
          OR the path to the output folder you want to use.
          use one or the other, not both

@@ -49,8 +50,8 @@ class file_paths:
              self.folder_name = folder_name
              self.output_dir = self.root_output_dir() / folder_name
          if output_dir:
-             self.output_dir = output_dir
-             self.folder_name = str(output_dir.stem)
+             self.output_dir = Path(output_dir)
+             self.folder_name = self.output_dir.stem

          self.cache_dir.mkdir(parents=True, exist_ok=True)

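Wrapping output_dir in Path() means a plain string is now accepted, and folder_name is derived from the coerced path. A small sketch (paths are illustrative, and instantiating FilePaths creates its cache directory as a side effect):

    p1 = FilePaths(output_dir=Path("/data/ngiab/cat-1496145"))
    p2 = FilePaths(output_dir="/data/ngiab/cat-1496145")   # str is coerced to Path
    assert p1.folder_name == p2.folder_name == "cat-1496145"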
@@ -89,7 +90,13 @@

      @property
      def metadata_dir(self) -> Path:
-         return self.subset_dir / "metadata"
+         meta_dir = self.subset_dir / "metadata"
+         meta_dir.mkdir(parents=True, exist_ok=True)
+         return meta_dir
+
+     @property
+     def forcing_progress_file(self) -> Path:
+         return self.metadata_dir / "forcing_progress.json"

      @property
      def geopackage_path(self) -> Path:
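The new forcing_progress_file property builds on metadata_dir, which now creates itself on first access, so callers can write the progress file without an explicit mkdir. A short sketch (the folder name is hypothetical):

    paths = FilePaths("cat-1496145")
    progress_path = paths.forcing_progress_file   # .../metadata/forcing_progress.json under the subset dir
    assert progress_path.parent.is_dir()          # metadata/ was created by the property access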
@@ -1,3 +1,4 @@
+ import json
  import logging
  import multiprocessing
  import os
@@ -16,7 +17,7 @@ import psutil
  import xarray as xr
  from data_processing.dask_utils import no_cluster, use_cluster
  from data_processing.dataset_utils import validate_dataset_format
- from data_processing.file_paths import file_paths
+ from data_processing.file_paths import FilePaths
  from exactextract import exact_extract
  from exactextract.raster import NumPyRasterSource
  from rich.progress import (
@@ -26,6 +27,7 @@ from rich.progress import (
      TimeElapsedColumn,
      TimeRemainingColumn,
  )
+ from xarray.core.types import InterpOptions

  logger = logging.getLogger(__name__)
  # Suppress the specific warning from numpy to keep the cli output clean
@@ -63,10 +65,16 @@ def weighted_sum_of_cells(
      Each element contains the averaged forcing value for the whole catchment
      over one timestep.
      """
-     result = np.zeros(flat_raster.shape[0])
+     # early exit for divide by zero
+     if np.all(factors == 0):
+         return np.zeros(flat_raster.shape[0])
+
+     selected_cells = flat_raster[:, cell_ids]
+     has_nan = np.isnan(selected_cells).any(axis=1)
      result = np.sum(flat_raster[:, cell_ids] * factors, axis=1)
      sum_of_weights = np.sum(factors)
      result /= sum_of_weights
+     result[has_nan] = np.nan
      return result


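With the change above, any timestep whose selected cells contain a NaN now comes out of the weighted average as NaN (to be filled in later by interpolation) instead of a value silently computed around the gap. A tiny illustration with made-up numbers:

    import numpy as np

    flat_raster = np.array([[1.0, 2.0, 3.0],
                            [4.0, np.nan, 6.0]])   # 2 timesteps x 3 cells
    cell_ids = np.array([0, 1])
    factors = np.array([0.25, 0.75])               # coverage-weighted fractions

    selected = flat_raster[:, cell_ids]
    has_nan = np.isnan(selected).any(axis=1)
    result = np.sum(selected * factors, axis=1) / np.sum(factors)
    result[has_nan] = np.nan
    # result -> [1.75, nan]; the second timestep is flagged rather than averaged around the gap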
@@ -305,6 +313,46 @@ def get_units(dataset: xr.Dataset) -> dict:
      return units


+ def interpolate_nan_values(
+     dataset: xr.Dataset,
+     dim: str = "time",
+     method: InterpOptions = "linear",
+     fill_value: str = "extrapolate",
+ ) -> bool:
+     """
+     Interpolates NaN values in specified (or all numeric time-dependent)
+     variables of an xarray.Dataset. Operates inplace on the dataset.
+
+     Parameters
+     ----------
+     dataset : xr.Dataset
+         The input dataset.
+     dim : str, optional
+         The dimension along which to interpolate (default is "time").
+     method : str, optional
+         Interpolation method to use (e.g., "linear", "nearest", "cubic").
+         Default is "linear".
+     fill_value : str, optional
+         Method for filling NaNs at the start/end of the series after interpolation.
+         Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
+         Default is "extrapolate".
+     """
+     for name, var in dataset.data_vars.items():
+         # if the variable is non-numeric, skip
+         if not np.issubdtype(var.dtype, np.number):
+             continue
+         # if there are no NANs, skip
+         if not var.isnull().any().compute():
+             continue
+         logger.info("Interpolating NaN values in %s", name)
+
+         dataset[name] = var.interpolate_na(
+             dim=dim,
+             method=method,
+             fill_value=fill_value if method in ["nearest", "linear"] else None,
+         )
+
+
  @no_cluster
  def compute_zonal_stats(
      gdf: gpd.GeoDataFrame, gridded_data: xr.Dataset, forcings_dir: Path
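interpolate_nan_values now lives in the forcings module and defaults to linear interpolation with extrapolation at the series edges. A toy example of the underlying interpolate_na call it wraps (not package data; the extrapolating linear fill relies on scipy being available):

    import numpy as np
    import pandas as pd
    import xarray as xr

    ds = xr.Dataset(
        {"temperature": ("time", [10.0, np.nan, 14.0, np.nan])},
        coords={"time": pd.date_range("2022-01-01", periods=4, freq="h")},
    )
    ds["temperature"] = ds["temperature"].interpolate_na(
        dim="time", method="linear", fill_value="extrapolate"
    )
    # -> [10.0, 12.0, 14.0, 16.0]; the interior gap is interpolated, the trailing NaN extrapolated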
@@ -334,6 +382,17 @@ def compute_zonal_stats(

      cat_chunks: List[pd.DataFrame] = np.array_split(catchments, num_partitions)  # type: ignore

+     progress_file = FilePaths(output_dir=forcings_dir.parent.stem).forcing_progress_file
+     ex_var_name = list(gridded_data.data_vars)[0]
+     example_time_chunks = get_index_chunks(gridded_data[ex_var_name])
+     all_steps = len(example_time_chunks) * len(gridded_data.data_vars)
+     logger.info(
+         f"Total steps: {all_steps}, Number of time chunks: {len(example_time_chunks)}, Number of variables: {len(gridded_data.data_vars)}"
+     )
+     steps_completed = 0
+     with open(progress_file, "w") as f:
+         json.dump({"total_steps": all_steps, "steps_completed": steps_completed}, f)
+
      progress = Progress(
          TextColumn("[progress.description]{task.description}"),
          BarColumn(),
@@ -391,6 +450,9 @@ def compute_zonal_stats(
              concatenated_da.to_dataset(name=data_var_name).to_netcdf(
                  forcings_dir / "temp" / f"{data_var_name}_timechunk_{i}.nc"
              )
+             steps_completed += 1
+             with open(progress_file, "w") as f:
+                 json.dump({"total_steps": all_steps, "steps_completed": steps_completed}, f)
      # Merge the chunks back together
      datasets = [
          xr.open_dataset(forcings_dir / "temp" / f"{data_var_name}_timechunk_{i}.nc")
@@ -413,6 +475,8 @@ def compute_zonal_stats(
          f"Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds"
      )
      write_outputs(forcings_dir, units)
+     time.sleep(1)  # wait for progress bar to update
+     progress_file.unlink()


  @use_cluster
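The forcing_progress.json file written during zonal-stats computation (and removed once write_outputs finishes) gives external tools something to poll. A hypothetical reader, not part of the package:

    import json
    from pathlib import Path

    def read_forcing_progress(progress_file: Path) -> float | None:
        # Return percent complete, or None once the file is gone (run finished or not started).
        if not progress_file.exists():
            return None
        with open(progress_file) as f:
            state = json.load(f)
        return 100 * state["steps_completed"] / state["total_steps"]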
@@ -455,7 +519,6 @@ def write_outputs(forcings_dir: Path, units: dict) -> None:
      for var in final_ds.data_vars:
          final_ds[var] = final_ds[var].astype(np.float32)

-     logger.info("Saving to disk")
      # The format for the netcdf is to support a legacy format
      # which is why it's a little "unorthodox"
      # There are no coordinates, just dimensions, catchment ids are stored in a 1d data var
@@ -481,7 +544,9 @@ def write_outputs(forcings_dir: Path, units: dict) -> None:
      final_ds["Time"].attrs["epoch_start"] = (
          "01/01/1970 00:00:00"  # not needed but suppresses the ngen warning
      )
+     interpolate_nan_values(final_ds)

+     logger.info("Saving to disk")
      final_ds.to_netcdf(forcings_dir / "forcings.nc", engine="netcdf4")
      # close the datasets
      _ = [result.close() for result in results]
@@ -493,8 +558,8 @@ def write_outputs(forcings_dir: Path, units: dict) -> None:
      temp_forcings_dir.rmdir()


- def setup_directories(cat_id: str) -> file_paths:
-     forcing_paths = file_paths(cat_id)
+ def setup_directories(cat_id: str) -> FilePaths:
+     forcing_paths = FilePaths(cat_id)
      # delete everything in the forcing folder except the cached nc file
      for file in forcing_paths.forcings_dir.glob("*.*"):
          if file != forcing_paths.cached_nc_file:
@@ -2,10 +2,10 @@ import logging
  import sqlite3
  import struct
  from pathlib import Path
- from typing import List, Tuple, Dict
+ from typing import Dict, List, Tuple

  import pyproj
- from data_processing.file_paths import file_paths
+ from data_processing.file_paths import FilePaths
  from shapely.geometry import Point
  from shapely.geometry.base import BaseGeometry
  from shapely.ops import transform
@@ -28,7 +28,7 @@ class GeoPackage:
          self.conn.close()


- def verify_indices(gpkg: Path = file_paths.conus_hydrofabric) -> None:
+ def verify_indices(gpkg: Path = FilePaths.conus_hydrofabric) -> None:
      """
      Verify that the indices in the specified geopackage are correct.
      If they are not, create the correct indices.
@@ -74,7 +74,7 @@ def create_empty_gpkg(gpkg: Path) -> None:
      """
      Create an empty geopackage with the necessary tables and indices.
      """
-     with open(file_paths.template_sql) as f:
+     with open(FilePaths.template_sql) as f:
          sql_script = f.read()

      with sqlite3.connect(gpkg) as conn:
@@ -85,7 +85,7 @@ def add_triggers_to_gpkg(gpkg: Path) -> None:
      """
      Adds geopackage triggers required to maintain spatial index integrity
      """
-     with open(file_paths.triggers_sql) as f:
+     with open(FilePaths.triggers_sql) as f:
          triggers = f.read()
      with sqlite3.connect(gpkg) as conn:
          conn.executescript(triggers)
@@ -93,8 +93,6 @@ def add_triggers_to_gpkg(gpkg: Path) -> None:
      logger.debug(f"Added triggers to subset gpkg {gpkg}")


-
-
  def blob_to_geometry(blob: bytes) -> BaseGeometry | None:
      """
      Convert a blob to a geometry.
@@ -178,7 +176,7 @@ def get_catid_from_point(coords: Dict[str, float]) -> str:

      """
      logger.info(f"Getting catid for {coords}")
-     q = file_paths.conus_hydrofabric
+     q = FilePaths.conus_hydrofabric
      point = Point(coords["lng"], coords["lat"])
      point = convert_to_5070(point)
      with sqlite3.connect(q) as con:
@@ -261,7 +259,7 @@ def update_geopackage_metadata(gpkg: Path) -> None:
      Update the contents of the gpkg_contents table in the specified geopackage.
      """
      # table_name, data_type, identifier, description, last_change, min_x, min_y, max_x, max_y, srs_id
-     tables = get_feature_tables(file_paths.conus_hydrofabric)
+     tables = get_feature_tables(FilePaths.conus_hydrofabric)
      con = sqlite3.connect(gpkg)
      for table in tables:
          min_x = con.execute(f"SELECT MIN(minx) FROM rtree_{table}_geom").fetchone()[0]
@@ -336,7 +334,7 @@ def subset_table_by_vpu(table: str, vpu: str, hydrofabric: Path, subset_gpkg_nam

      insert_data(dest_db, table, contents)

-     if table in get_feature_tables(file_paths.conus_hydrofabric):
+     if table in get_feature_tables(FilePaths.conus_hydrofabric):
          fids = [str(x[0]) for x in contents]
          copy_rTree_tables(table, fids, source_db, dest_db)

@@ -389,7 +387,7 @@ def subset_table(table: str, ids: List[str], hydrofabric: Path, subset_gpkg_name

      insert_data(dest_db, table, contents)

-     if table in get_feature_tables(file_paths.conus_hydrofabric):
+     if table in get_feature_tables(FilePaths.conus_hydrofabric):
          fids = [str(x[0]) for x in contents]
          copy_rTree_tables(table, fids, source_db, dest_db)

@@ -436,7 +434,7 @@ def get_table_crs(gpkg: str, table: str) -> str:
      return crs


- def get_cat_from_gage_id(gage_id: str, gpkg: Path = file_paths.conus_hydrofabric) -> str:
+ def get_cat_from_gage_id(gage_id: str, gpkg: Path = FilePaths.conus_hydrofabric) -> str:
      """
      Get the catchment id associated with a gage id.

@@ -476,7 +474,7 @@ def get_cat_from_gage_id(gage_id: str, gpkg: Path = file_paths.conus_hydrofabric
      return cat_id


- def get_cat_to_nex_flowpairs(hydrofabric: Path = file_paths.conus_hydrofabric) -> List[Tuple]:
+ def get_cat_to_nex_flowpairs(hydrofabric: Path = FilePaths.conus_hydrofabric) -> List[Tuple]:
      """
      Retrieves the from and to IDs from the specified hydrofabric.

@@ -484,7 +482,7 @@ def get_cat_to_nex_flowpairs(hydrofabric: Path = file_paths.conus_hydrofabric) -
      The true network flows catchment to waterbody to nexus, this bypasses the waterbody and returns catchment to nexus.

      Args:
-         hydrofabric (Path, optional): The file path to the hydrofabric. Defaults to file_paths.conus_hydrofabric.
+         hydrofabric (Path, optional): The file path to the hydrofabric. Defaults to FilePaths.conus_hydrofabric.
      Returns:
          List[tuple]: A list of tuples containing the from and to IDs.
      """
@@ -518,7 +516,7 @@ def get_available_tables(gpkg: Path) -> List[str]:
      return tables


- def get_cat_to_nhd_feature_id(gpkg: Path = file_paths.conus_hydrofabric) -> Dict[str, int]:
+ def get_cat_to_nhd_feature_id(gpkg: Path = FilePaths.conus_hydrofabric) -> Dict[str, int]:
      available_tables = get_available_tables(gpkg)
      possible_tables = ["flowpath_edge_list", "network"]

@@ -5,13 +5,13 @@ from pathlib import Path
  from typing import List, Optional, Set, Union

  import igraph as ig
- from data_processing.file_paths import file_paths
+ from data_processing.file_paths import FilePaths

  logger = logging.getLogger(__name__)


  def get_from_to_id_pairs(
-     hydrofabric: Path = file_paths.conus_hydrofabric, ids: Optional[Set | List] = None
+     hydrofabric: Path = FilePaths.conus_hydrofabric, ids: Optional[Set | List] = None
  ) -> List[tuple]:
      """
      Retrieves the from and to IDs from the specified hydrofabric.
@@ -19,7 +19,7 @@ def get_from_to_id_pairs(
      This function reads the from and to IDs from the specified hydrofabric and returns them as a list of tuples.

      Args:
-         hydrofabric (Path, optional): The file path to the hydrofabric. Defaults to file_paths.conus_hydrofabric.
+         hydrofabric (Path, optional): The file path to the hydrofabric. Defaults to FilePaths.conus_hydrofabric.
          ids (Set, optional): A set of IDs to filter the results. Defaults to None.
      Returns:
          List[tuple]: A list of tuples containing the from and to IDs.
@@ -96,10 +96,10 @@ def get_graph() -> ig.Graph:
      Returns:
          ig.Graph: The hydrological network graph.
      """
-     pickled_graph_path = file_paths.hydrofabric_graph
+     pickled_graph_path = FilePaths.hydrofabric_graph
      if not pickled_graph_path.exists():
          logger.debug("Graph pickle does not exist, creating a new graph.")
-         network_graph = create_graph_from_gpkg(file_paths.conus_hydrofabric)
+         network_graph = create_graph_from_gpkg(FilePaths.conus_hydrofabric)
          network_graph.write_pickle(pickled_graph_path)
      else:
          try:
data_processing/subset.py CHANGED
@@ -3,7 +3,7 @@ import os
  from pathlib import Path
  from typing import List, Union

- from data_processing.file_paths import file_paths
+ from data_processing.file_paths import FilePaths
  from data_processing.gpkg_utils import (
      add_triggers_to_gpkg,
      create_empty_gpkg,
@@ -73,7 +73,7 @@ def create_subset_gpkg(


  def subset_vpu(
-     vpu_id: str, output_gpkg_path: Path, hydrofabric: Path = file_paths.conus_hydrofabric
+     vpu_id: str, output_gpkg_path: Path, hydrofabric: Path = FilePaths.conus_hydrofabric
  ):
      if os.path.exists(output_gpkg_path):
          response = Prompt.ask(
@@ -95,7 +95,7 @@

  def subset(
      cat_ids: str | List[str],
-     hydrofabric: Path = file_paths.conus_hydrofabric,
+     hydrofabric: Path = FilePaths.conus_hydrofabric,
      output_gpkg_path: Path = Path(),
      include_outlet: bool = True,
      override_gpkg: bool = True,
@@ -106,7 +106,7 @@
          # if the name isn't provided, use the first upstream id
          upstream_ids = sorted(upstream_ids)
          output_folder_name = upstream_ids[0]
-         paths = file_paths(output_folder_name)
+         paths = FilePaths(output_folder_name)
          output_gpkg_path = paths.geopackage_path

      create_subset_gpkg(upstream_ids, hydrofabric, output_gpkg_path, override_gpkg=override_gpkg)