ngiab-data-preprocess 4.3.0__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,10 @@ import sqlite3
  from datetime import datetime
  from pathlib import Path
  from typing import Dict, Optional
+ import psutil
+ import os
 
+ import numpy as np
  import pandas
  import requests
  import s3fs
@@ -14,8 +17,6 @@ import xarray as xr
  from data_processing.dask_utils import temp_cluster
  from data_processing.file_paths import file_paths
  from data_processing.gpkg_utils import (
- GeoPackage,
- get_cat_to_nex_flowpairs,
  get_cat_to_nhd_feature_id,
  get_table_crs_short,
  )
@@ -89,7 +90,6 @@ def make_cfe_config(
  def make_noahowp_config(
  base_dir: Path, divide_conf_df: pandas.DataFrame, start_time: datetime, end_time: datetime
  ) -> None:
- divide_conf_df.set_index("divide_id", inplace=True)
  start_datetime = start_time.strftime("%Y%m%d%H%M")
  end_datetime = end_time.strftime("%Y%m%d%H%M")
  with open(file_paths.template_noahowp_config, "r") as file:
@@ -98,155 +98,78 @@ def make_noahowp_config(
  cat_config_dir = base_dir / "cat_config" / "NOAH-OWP-M"
  cat_config_dir.mkdir(parents=True, exist_ok=True)
 
- for divide in divide_conf_df.index:
- with open(cat_config_dir / f"{divide}.input", "w") as file:
+ for _, row in divide_conf_df.iterrows():
+ with open(cat_config_dir / f"{row['divide_id']}.input", "w") as file:
  file.write(
  template.format(
  start_datetime=start_datetime,
  end_datetime=end_datetime,
- lat=divide_conf_df.loc[divide, "latitude"],
- lon=divide_conf_df.loc[divide, "longitude"],
- terrain_slope=divide_conf_df.loc[divide, "mean.slope_1km"],
- azimuth=divide_conf_df.loc[divide, "circ_mean.aspect"],
- ISLTYP=int(divide_conf_df.loc[divide, "mode.ISLTYP"]), # type: ignore
- IVGTYP=int(divide_conf_df.loc[divide, "mode.IVGTYP"]), # type: ignore
+ lat=row["latitude"],
+ lon=row["longitude"],
+ terrain_slope=row["mean.slope_1km"],
+ azimuth=row["circ_mean.aspect"],
+ ISLTYP=int(row["mode.ISLTYP"]), # type: ignore
+ IVGTYP=int(row["mode.IVGTYP"]), # type: ignore
  )
  )
 
 
- def get_model_attributes_modspatialite(hydrofabric: Path) -> pandas.DataFrame:
- # modspatialite is faster than pyproj but can't be added as a pip dependency
- # This incantation took a while
- with GeoPackage(hydrofabric) as conn:
- sql = """WITH source_crs AS (
- SELECT organization || ':' || organization_coordsys_id AS crs_string
- FROM gpkg_spatial_ref_sys
- WHERE srs_id = (
- SELECT srs_id
- FROM gpkg_geometry_columns
- WHERE table_name = 'divides'
- )
- )
- SELECT
- d.divide_id,
- d.areasqkm,
- da."mean.slope",
- da."mean.slope_1km",
- da."mean.elevation",
- ST_X(Transform(MakePoint(da.centroid_x, da.centroid_y), 4326, NULL,
- (SELECT crs_string FROM source_crs), 'EPSG:4326')) AS longitude,
- ST_Y(Transform(MakePoint(da.centroid_x, da.centroid_y), 4326, NULL,
- (SELECT crs_string FROM source_crs), 'EPSG:4326')) AS latitude
- FROM divides AS d
- JOIN 'divide-attributes' AS da ON d.divide_id = da.divide_id
- """
- divide_conf_df = pandas.read_sql_query(sql, conn)
- divide_conf_df.set_index("divide_id", inplace=True)
- return divide_conf_df
-
-
- def get_model_attributes_pyproj(hydrofabric: Path) -> pandas.DataFrame:
- # if modspatialite is not available, use pyproj
+ def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
  with sqlite3.connect(hydrofabric) as conn:
- sql = """
- SELECT
- d.divide_id,
- d.areasqkm,
- da."mean.slope",
- da."mean.slope_1km",
- da."mean.elevation",
- da.centroid_x,
- da.centroid_y
- FROM divides AS d
- JOIN 'divide-attributes' AS da ON d.divide_id = da.divide_id
- """
- divide_conf_df = pandas.read_sql_query(sql, conn)
-
+ conf_df = pandas.read_sql_query(
+ """
+ SELECT
+ d.areasqkm,
+ da.*
+ FROM divides AS d
+ JOIN 'divide-attributes' AS da ON d.divide_id = da.divide_id
+ """,
+ conn,
+ )
  source_crs = get_table_crs_short(hydrofabric, "divides")
-
  transformer = Transformer.from_crs(source_crs, "EPSG:4326", always_xy=True)
-
- lon, lat = transformer.transform(
- divide_conf_df["centroid_x"].values, divide_conf_df["centroid_y"].values
- )
-
- divide_conf_df["longitude"] = lon
- divide_conf_df["latitude"] = lat
-
- divide_conf_df.drop(columns=["centroid_x", "centroid_y"], axis=1, inplace=True)
- divide_conf_df.set_index("divide_id", inplace=True)
-
- return divide_conf_df
-
-
- def get_model_attributes(hydrofabric: Path) -> pandas.DataFrame:
- try:
- with GeoPackage(hydrofabric) as conn:
- conf_df = pandas.read_sql_query(
- """WITH source_crs AS (
- SELECT organization || ':' || organization_coordsys_id AS crs_string
- FROM gpkg_spatial_ref_sys
- WHERE srs_id = (
- SELECT srs_id
- FROM gpkg_geometry_columns
- WHERE table_name = 'divides'
- )
- )
- SELECT
- *,
- ST_X(Transform(MakePoint(centroid_x, centroid_y), 4326, NULL,
- (SELECT crs_string FROM source_crs), 'EPSG:4326')) AS longitude,
- ST_Y(Transform(MakePoint(centroid_x, centroid_y), 4326, NULL,
- (SELECT crs_string FROM source_crs), 'EPSG:4326')) AS latitude FROM 'divide-attributes';""",
- conn,
- )
- except sqlite3.OperationalError:
- with sqlite3.connect(hydrofabric) as conn:
- conf_df = pandas.read_sql_query(
- "SELECT* FROM 'divide-attributes';",
- conn,
- )
- source_crs = get_table_crs_short(hydrofabric, "divides")
- transformer = Transformer.from_crs(source_crs, "EPSG:4326", always_xy=True)
- lon, lat = transformer.transform(conf_df["centroid_x"].values, conf_df["centroid_y"].values)
- conf_df["longitude"] = lon
- conf_df["latitude"] = lat
-
- conf_df.drop(columns=["centroid_x", "centroid_y"], axis=1, inplace=True)
+ lon, lat = transformer.transform(conf_df["centroid_x"].values, conf_df["centroid_y"].values)
+ conf_df["longitude"] = lon
+ conf_df["latitude"] = lat
  return conf_df
 
 
- def make_em_config(
+ def make_lstm_config(
  hydrofabric: Path,
  output_dir: Path,
- template_path: Path = file_paths.template_em_config,
+ template_path: Path = file_paths.template_lstm_config,
  ):
  # test if modspatialite is available
- try:
- divide_conf_df = get_model_attributes_modspatialite(hydrofabric)
- except Exception as e:
- logger.warning(f"mod_spatialite not available, using pyproj instead: {e}")
- logger.warning("Install mod_spatialite for improved performance")
- divide_conf_df = get_model_attributes_pyproj(hydrofabric)
-
- cat_config_dir = output_dir / "cat_config" / "empirical_model"
+
+ divide_conf_df = get_model_attributes(hydrofabric)
+
+ cat_config_dir = output_dir / "cat_config" / "lstm"
  if cat_config_dir.exists():
  shutil.rmtree(cat_config_dir)
  cat_config_dir.mkdir(parents=True, exist_ok=True)
 
+ # convert the mean.slope from degrees 0-90 where 90 is flat and 0 is vertical to m/km
+ # flip 0 and 90 degree values
+ divide_conf_df["flipped_mean_slope"] = abs(divide_conf_df["mean.slope"] - 90)
+ # Convert degrees to meters per kilometer
+ divide_conf_df["mean_slope_mpkm"] = (
+ np.tan(np.radians(divide_conf_df["flipped_mean_slope"])) * 1000
+ )
+
  with open(template_path, "r") as file:
  template = file.read()
 
- for divide in divide_conf_df.index:
+ for _, row in divide_conf_df.iterrows():
+ divide = row["divide_id"]
  with open(cat_config_dir / f"{divide}.yml", "w") as file:
  file.write(
  template.format(
- area_sqkm=divide_conf_df.loc[divide, "areasqkm"],
+ area_sqkm=row["areasqkm"],
  divide_id=divide,
- lat=divide_conf_df.loc[divide, "latitude"],
- lon=divide_conf_df.loc[divide, "longitude"],
- slope_mean=divide_conf_df.loc[divide, "mean.slope"],
- elevation_mean=divide_conf_df.loc[divide, "mean.slope"],
+ lat=row["latitude"],
+ lon=row["longitude"],
+ slope_mean=row["mean_slope_mpkm"],
+ elevation_mean=row["mean.elevation"] / 1000, # convert mm in hf to m
  )
  )
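Note on the slope conversion added to make_lstm_config above: the hydrofabric's mean.slope is read as degrees with 90 meaning flat and 0 meaning vertical, so the value is flipped before tan() turns it into a rise in metres per kilometre. A standalone sketch of the same arithmetic (the helper name and the 45-degree check are illustrative, not part of the package):

    import numpy as np

    def slope_deg_to_m_per_km(slope_deg_flat90: float) -> float:
        # 90 = flat -> 0 degrees of incline, 0 = vertical -> 90 degrees of incline
        flipped = abs(slope_deg_flat90 - 90)
        # rise in metres over a 1 km horizontal run
        return float(np.tan(np.radians(flipped)) * 1000)

    # a surface 45 degrees from flat rises roughly 1000 m over 1 km
    assert round(slope_deg_to_m_per_km(45)) == 1000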
 
@@ -257,7 +180,27 @@ def configure_troute(
  with open(file_paths.template_troute_config, "r") as file:
  troute_template = file.read()
  time_step_size = 300
+ gpkg_file_path=f"{config_dir}/{cat_id}_subset.gpkg"
  nts = (end_time - start_time).total_seconds() / time_step_size
+ with sqlite3.connect(gpkg_file_path) as conn:
+ ncats_df = pandas.read_sql_query("SELECT COUNT(id) FROM 'divides';", conn)
+ ncats = ncats_df['COUNT(id)'][0]
+
+ est_bytes_required = nts * ncats * 45 # extremely rough calculation based on about 3 tests :)
+ local_ram_available = 0.8 * psutil.virtual_memory().available # buffer to not accidentally explode machine
+
+ if est_bytes_required > local_ram_available:
+ max_loop_size = nts // (est_bytes_required // local_ram_available)
+ binary_nexus_file_folder_comment = ""
+ parent_dir = config_dir.parent
+ output_parquet_path = Path(f"{parent_dir}/outputs/parquet/")
+
+ if not output_parquet_path.exists():
+ os.makedirs(output_parquet_path)
+ else:
+ max_loop_size = nts
+ binary_nexus_file_folder_comment = "#"
+
  filled_template = troute_template.format(
  # hard coded to 5 minutes
  time_step_size=time_step_size,
@@ -266,7 +209,8 @@ def configure_troute(
  geo_file_path=f"./config/{cat_id}_subset.gpkg",
  start_datetime=start_time.strftime("%Y-%m-%d %H:%M:%S"),
  nts=nts,
- max_loop_size=nts,
+ max_loop_size=max_loop_size,
+ binary_nexus_file_folder_comment=binary_nexus_file_folder_comment
  )
 
  with open(config_dir / "troute.yaml", "w") as file:
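Note: the new configure_troute logic above estimates the routed output size as nts * ncats * 45 bytes (a rough constant, per the in-code comment), compares it against 80% of the RAM psutil reports as available, and shrinks t-route's max_loop_size when the estimate does not fit, enabling the parquet output folder at the same time. A condensed sketch of that decision, using the same constants (the helper name is illustrative):

    import psutil

    def choose_max_loop_size(nts: int, ncats: int, bytes_per_step: int = 45) -> tuple:
        """Return (max_loop_size, comment_prefix) following the logic in configure_troute."""
        est_bytes_required = nts * ncats * bytes_per_step
        local_ram_available = 0.8 * psutil.virtual_memory().available
        if est_bytes_required > local_ram_available:
            # break the run into RAM-sized chunks and uncomment the parquet output line
            return int(nts // (est_bytes_required // local_ram_available)), ""
        # everything fits: one loop, keep binary_nexus_file_folder commented out
        return nts, "#"

The two returned values are dropped into the troute.yaml template through the new binary_nexus_file_folder_comment placeholder shown at the bottom of this diff.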
@@ -287,22 +231,14 @@ def make_ngen_realization_json(
  json.dump(realization, file, indent=4)
 
 
- def create_em_realization(cat_id: str, start_time: datetime, end_time: datetime):
+ def create_lstm_realization(cat_id: str, start_time: datetime, end_time: datetime):
  paths = file_paths(cat_id)
- template_path = file_paths.template_em_realization_config
- em_config = file_paths.template_em_model_config
- # move em_config to paths.config_dir
- with open(em_config, "r") as f:
- em_config = f.read()
- with open(paths.config_dir / "em-config.yml", "w") as f:
- f.write(em_config)
-
+ template_path = file_paths.template_lstm_realization_config
  configure_troute(cat_id, paths.config_dir, start_time, end_time)
  make_ngen_realization_json(paths.config_dir, template_path, start_time, end_time)
- make_em_config(paths.geopackage_path, paths.config_dir)
+ make_lstm_config(paths.geopackage_path, paths.config_dir)
  # create some partitions for parallelization
  paths.setup_run_folders()
- create_partitions(paths)
 
 
  def create_realization(
@@ -345,48 +281,3 @@ def create_realization(
 
  # create some partitions for parallelization
  paths.setup_run_folders()
- create_partitions(paths)
-
-
- def create_partitions(paths: file_paths, num_partitions: Optional[int] = None) -> None:
- if num_partitions is None:
- num_partitions = multiprocessing.cpu_count()
-
- cat_to_nex_pairs = get_cat_to_nex_flowpairs(hydrofabric=paths.geopackage_path)
- # nexus = defaultdict(list)
-
- # for cat, nex in cat_to_nex_pairs:
- # nexus[nex].append(cat)
-
- num_partitions = min(num_partitions, len(cat_to_nex_pairs))
- # partition_size = ceil(len(nexus) / num_partitions)
- # num_nexus = len(nexus)
- # nexus = list(nexus.items())
- # partitions = []
- # for i in range(0, num_nexus, partition_size):
- # part = {}
- # part["id"] = i // partition_size
- # part["cat-ids"] = []
- # part["nex-ids"] = []
- # part["remote-connections"] = []
- # for j in range(i, i + partition_size):
- # if j < num_nexus:
- # part["cat-ids"].extend(nexus[j][1])
- # part["nex-ids"].append(nexus[j][0])
- # partitions.append(part)
-
- # with open(paths.subset_dir / f"partitions_{num_partitions}.json", "w") as f:
- # f.write(json.dumps({"partitions": partitions}, indent=4))
-
- # write this to a metadata file to save on repeated file io to recalculate
- with open(paths.metadata_dir / "num_partitions", "w") as f:
- f.write(str(num_partitions))
-
-
- if __name__ == "__main__":
- cat_id = "cat-1643991"
- start_time = datetime(2010, 1, 1, 0, 0, 0)
- end_time = datetime(2010, 1, 2, 0, 0, 0)
- # output_interval = 3600
- # nts = 2592
- create_realization(cat_id, start_time, end_time)
@@ -7,9 +7,9 @@ from typing import List, Literal, Optional, Tuple, Union
  import geopandas as gpd
  import numpy as np
  import xarray as xr
+ from dask.distributed import Client, Future, progress
+ from data_processing.dask_utils import no_cluster, temp_cluster
  from xarray.core.types import InterpOptions
- from dask.distributed import Client, progress, Future
- from data_processing.dask_utils import use_cluster
 
  logger = logging.getLogger(__name__)
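Note: the use_cluster decorator is replaced by two imports, no_cluster and temp_cluster, from data_processing.dask_utils. That module is not part of this diff, so the sketch below is only a guess at the usual shape of such decorators: temp_cluster runs the wrapped call inside a short-lived local Dask cluster, while no_cluster ensures the call runs without a distributed client.

    # Hypothetical sketch; the real implementations live in data_processing/dask_utils.py,
    # which is unchanged in this diff.
    import functools
    from dask.distributed import Client, LocalCluster

    def temp_cluster(func):
        """Run the wrapped function with a temporary local Dask cluster."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with LocalCluster() as cluster, Client(cluster):
                return func(*args, **kwargs)
        return wrapper

    def no_cluster(func):
        """Marker: the wrapped function must not depend on a distributed client."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper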
 
@@ -117,13 +117,14 @@ def clip_dataset_to_bounds(
  return dataset
 
 
+ @no_cluster
  def interpolate_nan_values(
  dataset: xr.Dataset,
  variables: Optional[List[str]] = None,
  dim: str = "time",
  method: InterpOptions = "nearest",
  fill_value: str = "extrapolate",
- ) -> None:
+ ) -> bool:
  """
  Interpolates NaN values in specified (or all numeric time-dependent)
  variables of an xarray.Dataset. Operates inplace on the dataset.
@@ -145,6 +146,7 @@
  Set to "extrapolate" to fill with the nearest valid value when using 'nearest' or 'linear'.
  Default is "extrapolate".
  """
+ interpolation_used = False
  for name, var in dataset.data_vars.items():
  # if the variable is non-numeric, skip
  if not np.issubdtype(var.dtype, np.number):
@@ -158,9 +160,35 @@
  method=method,
  fill_value=fill_value if method in ["nearest", "linear"] else None,
  )
+ interpolation_used = True
+ return interpolation_used
 
 
- @use_cluster
+ @no_cluster
+ def save_dataset_no_cluster(
+ ds_to_save: xr.Dataset,
+ target_path: Path,
+ engine: Literal["netcdf4", "scipy", "h5netcdf"] = "h5netcdf",
+ ):
+ """
+ This explicitly does not use dask distributed.
+ Helper function to compute and save an xarray.Dataset to a NetCDF file.
+ Uses a temporary file and rename to avoid leaving a half-written file.
+ """
+ if not target_path.parent.exists():
+ target_path.parent.mkdir(parents=True, exist_ok=True)
+
+ temp_file_path = target_path.with_name(target_path.name + ".saving.nc")
+ if temp_file_path.exists():
+ os.remove(temp_file_path)
+
+ ds_to_save.to_netcdf(temp_file_path, engine=engine, compute=True)
+
+ os.rename(str(temp_file_path), str(target_path))
+ logger.info(f"Successfully saved data to: {target_path}")
+
+
+ @temp_cluster
  def save_dataset(
  ds_to_save: xr.Dataset,
  target_path: Path,
@@ -184,20 +212,21 @@ def save_dataset(
  logger.debug(
  f"NetCDF write task submitted to Dask. Waiting for completion to {temp_file_path}..."
  )
+ logger.info("For more detailed progress, see the Dask dashboard http://localhost:8787/status")
  progress(future)
  future.result()
  os.rename(str(temp_file_path), str(target_path))
  logger.info(f"Successfully saved data to: {target_path}")
 
 
- @use_cluster
+ @no_cluster
  def save_to_cache(
  stores: xr.Dataset, cached_nc_path: Path, interpolate_nans: bool = True
  ) -> xr.Dataset:
  """
  Compute the store and save it to a cached netCDF file. This is not required but will save time and bandwidth.
  """
- logger.info(f"Processing dataset for caching. Final cache target: {cached_nc_path}")
+ logger.debug(f"Processing dataset for caching. Final cache target: {cached_nc_path}")
 
  # lazily cast all numbers to f32
  for name, var in stores.data_vars.items():
@@ -206,13 +235,18 @@ def save_to_cache(
 
  # save dataset locally before manipulating it
  save_dataset(stores, cached_nc_path)
- stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
 
  if interpolate_nans:
- interpolate_nan_values(dataset=stores)
- save_dataset(stores, cached_nc_path)
- stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
+ stores = xr.open_mfdataset(
+ cached_nc_path,
+ parallel=True,
+ engine="h5netcdf",
+ )
+ was_interpolated = interpolate_nan_values(dataset=stores)
+ if was_interpolated:
+ save_dataset_no_cluster(stores, cached_nc_path)
 
+ stores = xr.open_mfdataset(cached_nc_path, parallel=True, engine="h5netcdf")
  return stores
 
 
@@ -1,6 +1,7 @@
+ from datetime import datetime
  from pathlib import Path
  from typing import Optional
- from datetime import datetime
+
 
  class file_paths:
  """
@@ -27,11 +28,10 @@ class file_paths:
  dev_file = Path(__file__).parent.parent.parent / ".dev"
  template_troute_config = data_sources / "ngen-routing-template.yaml"
  template_cfe_nowpm_realization_config = data_sources / "cfe-nowpm-realization-template.json"
- template_em_realization_config = data_sources / "em-realization-template.json"
+ template_lstm_realization_config = data_sources / "lstm-realization-template.json"
  template_noahowp_config = data_sources / "noah-owp-modular-init.namelist.input"
  template_cfe_config = data_sources / "cfe-template.ini"
- template_em_config = data_sources / "em-catchment-template.yml"
- template_em_model_config = data_sources / "em-config.yml"
+ template_lstm_config = data_sources / "lstm-catchment-template.yml"
 
  def __init__(self, folder_name: Optional[str] = None, output_dir: Optional[Path] = None):
  """
@@ -169,7 +169,6 @@ def get_upstream_cats(names: Union[str, List[str]]) -> Set[str]:
  node_index = graph.vs.find(cat=name).index
  else:
  node_index = graph.vs.find(name=name).index
- node_index = graph.vs.find(cat=name).index
  upstream_nodes = graph.subcomponent(node_index, mode="IN")
  for node in upstream_nodes:
  parent_ids.add(graph.vs[node]["name"])
@@ -178,7 +177,6 @@ def get_upstream_cats(names: Union[str, List[str]]) -> Set[str]:
  logger.critical(f"Catchment {name} not found in the hydrofabric graph.")
  except ValueError:
  logger.critical(f"Catchment {name} not found in the hydrofabric graph.")
-
  # sometimes returns None, which isn't helpful
  if None in cat_ids:
  cat_ids.remove(None)
data_processing/subset.py CHANGED
@@ -12,9 +12,11 @@ from data_processing.gpkg_utils import (
  update_geopackage_metadata,
  )
  from data_processing.graph_utils import get_upstream_ids
+ from rich.console import Console
+ from rich.prompt import Prompt
 
  logger = logging.getLogger(__name__)
-
+ console = Console()
  subset_tables = [
  "divides",
  "divide-attributes", # requires divides
@@ -30,15 +32,33 @@ subset_tables = [
 
 
  def create_subset_gpkg(
- ids: Union[List[str], str], hydrofabric: Path, output_gpkg_path: Path, is_vpu: bool = False
+ ids: Union[List[str], str],
+ hydrofabric: Path,
+ output_gpkg_path: Path,
+ is_vpu: bool = False,
+ override_gpkg: bool = True,
  ):
  # ids is a list of nexus and wb ids, or a single vpu id
  if not isinstance(ids, list):
  ids = [ids]
  output_gpkg_path.parent.mkdir(parents=True, exist_ok=True)
 
- if os.path.exists(output_gpkg_path):
- os.remove(output_gpkg_path)
+ if not override_gpkg:
+ if os.path.exists(output_gpkg_path):
+ response = Prompt.ask(
+ f"Subset geopackage at {output_gpkg_path} already exists. Are you sure you want to overwrite it?",
+ default="n",
+ choices=["y", "n"],
+ )
+ if response == "y":
+ console.print(f"Removing {output_gpkg_path}...", style="yellow")
+ os.remove(output_gpkg_path)
+ else:
+ console.print("Exiting...", style="bold red")
+ exit()
+ else:
+ if os.path.exists(output_gpkg_path):
+ os.remove(output_gpkg_path)
 
  create_empty_gpkg(output_gpkg_path)
  logger.info(f"Subsetting tables: {subset_tables}")
@@ -55,8 +75,18 @@
  def subset_vpu(
  vpu_id: str, output_gpkg_path: Path, hydrofabric: Path = file_paths.conus_hydrofabric
  ):
- if output_gpkg_path.exists():
- os.remove(output_gpkg_path)
+ if os.path.exists(output_gpkg_path):
+ response = Prompt.ask(
+ f"Subset geopackage at {output_gpkg_path} already exists. Are you sure you want to overwrite it?",
+ default="n",
+ choices=["y", "n"],
+ )
+ if response == "y":
+ console.print(f"Removing {output_gpkg_path}...", style="yellow")
+ os.remove(output_gpkg_path)
+ else:
+ console.print("Exiting...", style="bold red")
+ exit()
 
  create_subset_gpkg(vpu_id, hydrofabric, output_gpkg_path=output_gpkg_path, is_vpu=True)
  logger.info(f"Subset complete for VPU {vpu_id}")
@@ -68,6 +98,7 @@ def subset(
  hydrofabric: Path = file_paths.conus_hydrofabric,
  output_gpkg_path: Path = Path(),
  include_outlet: bool = True,
+ override_gpkg: bool = True,
  ):
  upstream_ids = list(get_upstream_ids(cat_ids, include_outlet))
 
@@ -78,6 +109,6 @@
  paths = file_paths(output_folder_name)
  output_gpkg_path = paths.geopackage_path
 
- create_subset_gpkg(upstream_ids, hydrofabric, output_gpkg_path)
+ create_subset_gpkg(upstream_ids, hydrofabric, output_gpkg_path, override_gpkg=override_gpkg)
  logger.info(f"Subset complete for {len(upstream_ids)} features (catchments + nexuses)")
  logger.debug(f"Subset complete for {upstream_ids} catchments")
@@ -0,0 +1,17 @@
+ time_step: "1 hour"
+ area_sqkm: {area_sqkm} # areasqkm
+ basin_id: {divide_id}
+ basin_name: {divide_id}
+ elev_mean: {elevation_mean} # mean.elevation
+ initial_state: zero
+ lat: {lat} # needs calculating
+ lon: {lon} # needs calculating
+ slope_mean: {slope_mean} # mean.slope
+ train_cfg_file:
+ - /ngen/ngen/extern/lstm/trained_neuralhydrology_models/nh_AORC_hourly_25yr_1210_112435_7/config.yml
+ - /ngen/ngen/extern/lstm/trained_neuralhydrology_models/nh_AORC_hourly_25yr_1210_112435_8/config.yml
+ - /ngen/ngen/extern/lstm/trained_neuralhydrology_models/nh_AORC_hourly_25yr_1210_112435_9/config.yml
+ - /ngen/ngen/extern/lstm/trained_neuralhydrology_models/nh_AORC_hourly_25yr_seq999_seed101_0701_143442/config.yml
+ - /ngen/ngen/extern/lstm/trained_neuralhydrology_models/nh_AORC_hourly_25yr_seq999_seed103_2701_171540/config.yml
+ - /ngen/ngen/extern/lstm/trained_neuralhydrology_models/nh_AORC_hourly_slope_elev_precip_temp_seq999_seed101_2801_191806/config.yml
+ verbose: 0
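Note: this new per-catchment template is filled by make_lstm_config (earlier in this diff) with plain str.format, producing one YAML file per divide_id. Roughly (the values and paths are illustrative):

    # Sketch of how the template above is consumed.
    with open("lstm-catchment-template.yml") as f:
        template = f.read()

    rendered = template.format(
        area_sqkm=12.3,
        divide_id="cat-1643991",
        elevation_mean=0.25,  # mean.elevation / 1000
        lat=35.0,
        lon=-85.0,
        slope_mean=42.0,      # mean_slope_mpkm
    )
    with open("cat-1643991.yml", "w") as f:
        f.write(rendered)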
@@ -5,25 +5,22 @@
  "name": "bmi_multi",
  "params": {
  "name": "bmi_multi",
- "model_type_name": "empirical_model",
+ "model_type_name": "lstm",
  "forcing_file": "",
  "init_config": "",
  "allow_exceed_end_time": true,
  "main_output_variable": "land_surface_water__runoff_depth",
- "modules": [
+ "modules": [
  {
  "name": "bmi_python",
  "params": {
  "name": "bmi_python",
  "python_type": "lstm.bmi_lstm.bmi_LSTM",
- "model_type_name": "bmi_empirical_model",
- "init_config": "./config/cat_config/empirical_model/{{id}}.yml",
+ "model_type_name": "bmi_lstm",
+ "init_config": "./config/cat_config/lstm/{{id}}.yml",
  "allow_exceed_end_time": true,
  "main_output_variable": "land_surface_water__runoff_depth",
- "uses_forcing_file": false,
- "variables_names_map": {
- "atmosphere_water__liquid_equivalent_precipitation_rate": "APCP_surface"
- }
+ "uses_forcing_file": false
  }
  }
  ]
@@ -62,7 +62,7 @@ compute_parameters:
  qlat_input_folder: ./outputs/ngen/
  qlat_file_pattern_filter: "nex-*"
 
- #binary_nexus_file_folder: ./outputs/parquet/ # if nexus_file_pattern_filter="nex-*" and you want it to reformat them as parquet, you need this
+ {binary_nexus_file_folder_comment}binary_nexus_file_folder: ./outputs/parquet/ # if nexus_file_pattern_filter="nex-*" and you want it to reformat them as parquet, you need this
  #coastal_boundary_input_file : channel_forcing/schout_1.nc
  nts: {nts} #288 for 1day
  max_loop_size: {max_loop_size} # [number of timesteps]
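Note: the routing template now exposes a {binary_nexus_file_folder_comment} placeholder so configure_troute can toggle the parquet output line by substituting either an empty string or "#":

    line = "{binary_nexus_file_folder_comment}binary_nexus_file_folder: ./outputs/parquet/"
    print(line.format(binary_nexus_file_folder_comment=""))   # binary_nexus_file_folder: ./outputs/parquet/
    print(line.format(binary_nexus_file_folder_comment="#"))  # #binary_nexus_file_folder: ./outputs/parquet/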