hydrodl2 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydrodl2/__init__.py +122 -0
- hydrodl2/_version.py +34 -0
- hydrodl2/api/__init__.py +3 -0
- hydrodl2/api/methods.py +144 -0
- hydrodl2/core/calc/__init__.py +11 -0
- hydrodl2/core/calc/batch_jacobian.pye +501 -0
- hydrodl2/core/calc/fdj.py +92 -0
- hydrodl2/core/calc/uh_routing.py +105 -0
- hydrodl2/core/calc/utils.py +59 -0
- hydrodl2/core/utils/__init__.py +7 -0
- hydrodl2/core/utils/clean_temp.sh +8 -0
- hydrodl2/core/utils/utils.py +63 -0
- hydrodl2/models/hbv/hbv.py +596 -0
- hydrodl2/models/hbv/hbv_1_1p.py +608 -0
- hydrodl2/models/hbv/hbv_2.py +670 -0
- hydrodl2/models/hbv/hbv_2_hourly.py +897 -0
- hydrodl2/models/hbv/hbv_2_mts.py +377 -0
- hydrodl2/models/hbv/hbv_adj.py +712 -0
- hydrodl2/modules/__init__.py +2 -0
- hydrodl2/modules/data_assimilation/variational_prcp_da.py +1 -0
- hydrodl2-1.3.0.dist-info/METADATA +184 -0
- hydrodl2-1.3.0.dist-info/RECORD +24 -0
- hydrodl2-1.3.0.dist-info/WHEEL +4 -0
- hydrodl2-1.3.0.dist-info/licenses/LICENSE +31 -0
|
@@ -0,0 +1,670 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from hydrodl2.core.calc import change_param_range, uh_conv, uh_gamma
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Hbv_2(torch.nn.Module):
    """HBV 2.0.

    Multi-component, multi-scale, differentiable PyTorch HBV model with rainfall
    runoff simulation on unit basins.

    Authors
    -------
    - Yalan Song, Leo Lonzarich, Wencong Yang
    - (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
    - (HBV-light Version 2) Seibert, 2005
      (https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).

    Publication
    -----------
    - Yalan Song, Tadd Bindas, Chaopeng Shen, et al. "High-resolution
      national-scale water modeling is enhanced by multiscale differentiable
      physics-informed machine learning." Water Resources Research (2025).
      https://doi.org/10.1029/2024WR038928.

    Parameters
    ----------
    config
        Configuration dictionary.
    device
        Device to run the model on.
    """

    def __init__(
        self,
        config: Optional[dict[str, Any]] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.name = 'HBV 2.0'
        self.config = config
        self.initialize = False  # When True, _PBM returns only warmed-up states.
        self.warm_up = 0
        self.pred_cutoff = 0
        self.warm_up_states = True
        self.dynamic_params = []
        self.dy_drop = 0.0  # Prob. of replacing a dynamic param by its static value.
        self.variables = ['prcp', 'tmean', 'pet']
        self.routing = False
        self.lenF = 15  # Unit hydrograph length used by uh_gamma.
        self.comprout = False
        self.muwts = None  # Optional weights for averaging the nmul components.
        self.nearzero = 1e-5
        self.nmul = 1  # Number of parallel component models.
        self.cache_states = False
        self.device = device

        self.states, self._state_cache = None, None

        self.state_names = [
            'SNOWPACK',  # Snowpack storage
            'MELTWATER',  # Meltwater storage
            'SM',  # Soil moisture storage
            'SUZ',  # Upper groundwater storage
            'SLZ',  # Lower groundwater storage
        ]
        self.flux_names = [
            'streamflow',  # Routed Streamflow
            'srflow',  # Routed surface runoff
            'ssflow',  # Routed subsurface flow
            'gwflow',  # Routed groundwater flow
            'AET_hydro',  # Actual ET
            'PET_hydro',  # Potential ET
            'SWE',  # Snow water equivalent
            'streamflow_no_rout',  # Streamflow
            'srflow_no_rout',  # Surface runoff
            'ssflow_no_rout',  # Subsurface flow
            'gwflow_no_rout',  # Groundwater flow
            'recharge',  # Recharge
            'excs',  # Excess stored water
            'evapfactor',  # Evaporation factor
            'tosoil',  # Infiltration
            'percolation',  # Percolation
            'capillary',  # Capillary rise
            'BFI',  # Baseflow index
        ]

        self.parameter_bounds = {
            'parBETA': [1.0, 6.0],
            'parFC': [50, 1000],
            'parK0': [0.05, 0.9],
            'parK1': [0.01, 0.5],
            'parK2': [0.001, 0.2],
            'parLP': [0.2, 1],
            'parPERC': [0, 10],
            'parUZL': [0, 100],
            'parTT': [-2.5, 2.5],
            'parCFMAX': [0.5, 10],
            'parCFR': [0, 0.1],
            'parCWH': [0, 0.2],
            'parBETAET': [0.3, 5],
            'parC': [0, 1],
            'parRT': [0, 20],
            'parAC': [0, 2500],
        }
        self.routing_parameter_bounds = {
            'route_a': [0, 2.9],
            'route_b': [0, 6.5],
        }

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if config is not None:
            # Overwrite defaults with config values.
            self.warm_up = config.get('warm_up', self.warm_up)
            self.warm_up_states = config.get('warm_up_states', self.warm_up_states)
            self.dy_drop = config.get('dy_drop', self.dy_drop)
            # Robustness fix: the original indexed config['dynamic_params']
            # directly and raised KeyError for configs without that section.
            self.dynamic_params = config.get('dynamic_params', {}).get(
                self.__class__.__name__, self.dynamic_params
            )
            self.variables = config.get('variables', self.variables)
            self.routing = config.get('routing', self.routing)
            self.comprout = config.get('comprout', self.comprout)
            self.nearzero = config.get('nearzero', self.nearzero)
            self.nmul = config.get('nmul', self.nmul)
            self.cache_states = config.get('cache_states', self.cache_states)
        self._set_parameters()
|
|
131
|
+
|
|
132
|
+
def _init_states(self, ngrid: int) -> tuple[torch.Tensor]:
|
|
133
|
+
"""Initialize model states to zero."""
|
|
134
|
+
|
|
135
|
+
def make_state():
|
|
136
|
+
return torch.full(
|
|
137
|
+
(ngrid, self.nmul), 0.001, dtype=torch.float32, device=self.device
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
return tuple(make_state() for _ in range(len(self.state_names)))
|
|
141
|
+
|
|
142
|
+
def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
|
|
143
|
+
"""Return internal model states.
|
|
144
|
+
|
|
145
|
+
Returns
|
|
146
|
+
-------
|
|
147
|
+
tuple[torch.Tensor, ...]
|
|
148
|
+
A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
|
|
149
|
+
"""
|
|
150
|
+
return self._state_cache
|
|
151
|
+
|
|
152
|
+
def load_states(
|
|
153
|
+
self,
|
|
154
|
+
states: tuple[torch.Tensor, ...],
|
|
155
|
+
) -> None:
|
|
156
|
+
"""Load internal model states and set to model device and type.
|
|
157
|
+
|
|
158
|
+
Parameters
|
|
159
|
+
----------
|
|
160
|
+
states
|
|
161
|
+
A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
|
|
162
|
+
"""
|
|
163
|
+
for state in states:
|
|
164
|
+
if not isinstance(state, torch.Tensor):
|
|
165
|
+
raise ValueError("Each element in `states` must be a tensor.")
|
|
166
|
+
nstates = len(self.state_names)
|
|
167
|
+
if not (isinstance(states, tuple) and len(states) == nstates):
|
|
168
|
+
raise ValueError(f"`states` must be a tuple of {nstates} tensors.")
|
|
169
|
+
|
|
170
|
+
self.states = tuple(
|
|
171
|
+
s.detach().to(self.device, dtype=torch.float32) for s in states
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
def _set_parameters(self) -> None:
|
|
175
|
+
"""Get physical parameters."""
|
|
176
|
+
self.phy_param_names = self.parameter_bounds.keys()
|
|
177
|
+
if self.routing:
|
|
178
|
+
self.routing_param_names = self.routing_parameter_bounds.keys()
|
|
179
|
+
else:
|
|
180
|
+
self.routing_param_names = []
|
|
181
|
+
|
|
182
|
+
self.learnable_param_count1 = len(self.dynamic_params) * self.nmul
|
|
183
|
+
self.learnable_param_count2 = (
|
|
184
|
+
len(self.phy_param_names) - len(self.dynamic_params)
|
|
185
|
+
) * self.nmul + len(self.routing_param_names)
|
|
186
|
+
self.learnable_param_count = (
|
|
187
|
+
self.learnable_param_count1 + self.learnable_param_count2
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
def _unpack_parameters(
|
|
191
|
+
self,
|
|
192
|
+
parameters: torch.Tensor,
|
|
193
|
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
194
|
+
"""Extract physical model and routing parameters from NN output.
|
|
195
|
+
|
|
196
|
+
Parameters
|
|
197
|
+
----------
|
|
198
|
+
parameters
|
|
199
|
+
Unprocessed, learned parameters from a neural network.
|
|
200
|
+
|
|
201
|
+
Returns
|
|
202
|
+
-------
|
|
203
|
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
|
204
|
+
Tuple of physical and routing parameters.
|
|
205
|
+
"""
|
|
206
|
+
phy_param_count = len(self.parameter_bounds)
|
|
207
|
+
dy_param_count = len(self.dynamic_params)
|
|
208
|
+
dif_count = phy_param_count - dy_param_count
|
|
209
|
+
|
|
210
|
+
# Physical dynamic parameters
|
|
211
|
+
phy_dy_params = parameters[0].view(
|
|
212
|
+
parameters[0].shape[0],
|
|
213
|
+
parameters[0].shape[1],
|
|
214
|
+
dy_param_count,
|
|
215
|
+
self.nmul,
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
# Physical static parameters
|
|
219
|
+
phy_static_params = parameters[1][:, : dif_count * self.nmul].view(
|
|
220
|
+
parameters[1].shape[0],
|
|
221
|
+
dif_count,
|
|
222
|
+
self.nmul,
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
# Routing parameters
|
|
226
|
+
routing_params = None
|
|
227
|
+
if self.routing:
|
|
228
|
+
routing_params = parameters[1][:, dif_count * self.nmul :]
|
|
229
|
+
|
|
230
|
+
return (phy_dy_params, phy_static_params, routing_params)
|
|
231
|
+
|
|
232
|
+
def _descale_phy_dy_parameters(
|
|
233
|
+
self,
|
|
234
|
+
phy_dy_params: torch.Tensor,
|
|
235
|
+
dy_list: list,
|
|
236
|
+
) -> dict[str, torch.Tensor]:
|
|
237
|
+
"""Descale physical parameters.
|
|
238
|
+
|
|
239
|
+
Parameters
|
|
240
|
+
----------
|
|
241
|
+
phy_params
|
|
242
|
+
Normalized physical parameters.
|
|
243
|
+
dy_list
|
|
244
|
+
List of dynamic parameters.
|
|
245
|
+
|
|
246
|
+
Returns
|
|
247
|
+
-------
|
|
248
|
+
dict
|
|
249
|
+
Dictionary of descaled physical parameters.
|
|
250
|
+
"""
|
|
251
|
+
nsteps = phy_dy_params.shape[0]
|
|
252
|
+
ngrid = phy_dy_params.shape[1]
|
|
253
|
+
|
|
254
|
+
# TODO: Fix; if dynamic parameters are not entered in config as they are
|
|
255
|
+
# in HBV params list, then descaling misamtch will occur.
|
|
256
|
+
param_dict = {}
|
|
257
|
+
pmat = torch.ones([1, ngrid, 1]) * self.dy_drop
|
|
258
|
+
for i, name in enumerate(dy_list):
|
|
259
|
+
staPar = phy_dy_params[-1, :, i, :].unsqueeze(0).repeat([nsteps, 1, 1])
|
|
260
|
+
|
|
261
|
+
dynPar = phy_dy_params[:, :, i, :]
|
|
262
|
+
drmask = torch.bernoulli(pmat).detach_().to(self.device)
|
|
263
|
+
|
|
264
|
+
comPar = dynPar * (1 - drmask) + staPar * drmask
|
|
265
|
+
param_dict[name] = change_param_range(
|
|
266
|
+
param=comPar,
|
|
267
|
+
bounds=self.parameter_bounds[name],
|
|
268
|
+
)
|
|
269
|
+
return param_dict
|
|
270
|
+
|
|
271
|
+
def _descale_phy_stat_parameters(
|
|
272
|
+
self,
|
|
273
|
+
phy_stat_params: torch.Tensor,
|
|
274
|
+
stat_list: list,
|
|
275
|
+
) -> torch.Tensor:
|
|
276
|
+
"""Descale routing parameters.
|
|
277
|
+
|
|
278
|
+
Parameters
|
|
279
|
+
----------
|
|
280
|
+
routing_params
|
|
281
|
+
Normalized routing parameters.
|
|
282
|
+
|
|
283
|
+
Returns
|
|
284
|
+
-------
|
|
285
|
+
dict
|
|
286
|
+
Dictionary of descaled routing parameters.
|
|
287
|
+
"""
|
|
288
|
+
parameter_dict = {}
|
|
289
|
+
for i, name in enumerate(stat_list):
|
|
290
|
+
param = phy_stat_params[:, i, :]
|
|
291
|
+
|
|
292
|
+
parameter_dict[name] = change_param_range(
|
|
293
|
+
param=param,
|
|
294
|
+
bounds=self.parameter_bounds[name],
|
|
295
|
+
)
|
|
296
|
+
return parameter_dict
|
|
297
|
+
|
|
298
|
+
def _descale_route_parameters(
|
|
299
|
+
self,
|
|
300
|
+
routing_params: torch.Tensor,
|
|
301
|
+
) -> torch.Tensor:
|
|
302
|
+
"""Descale routing parameters.
|
|
303
|
+
|
|
304
|
+
Parameters
|
|
305
|
+
----------
|
|
306
|
+
routing_params
|
|
307
|
+
Normalized routing parameters.
|
|
308
|
+
|
|
309
|
+
Returns
|
|
310
|
+
-------
|
|
311
|
+
dict
|
|
312
|
+
Dictionary of descaled routing parameters.
|
|
313
|
+
"""
|
|
314
|
+
parameter_dict = {}
|
|
315
|
+
for i, name in enumerate(self.routing_parameter_bounds.keys()):
|
|
316
|
+
param = routing_params[:, i]
|
|
317
|
+
|
|
318
|
+
parameter_dict[name] = change_param_range(
|
|
319
|
+
param=param,
|
|
320
|
+
bounds=self.routing_parameter_bounds[name],
|
|
321
|
+
)
|
|
322
|
+
return parameter_dict
|
|
323
|
+
|
|
324
|
+
def forward(
|
|
325
|
+
self,
|
|
326
|
+
x_dict: dict[str, torch.Tensor],
|
|
327
|
+
parameters: torch.Tensor,
|
|
328
|
+
) -> tuple[dict[str, torch.Tensor], tuple]:
|
|
329
|
+
"""Forward pass.
|
|
330
|
+
|
|
331
|
+
Parameters
|
|
332
|
+
----------
|
|
333
|
+
x_dict
|
|
334
|
+
Dictionary of input forcing data.
|
|
335
|
+
parameters
|
|
336
|
+
Unprocessed, learned parameters from a neural network.
|
|
337
|
+
|
|
338
|
+
Returns
|
|
339
|
+
-------
|
|
340
|
+
tuple[dict, tuple]
|
|
341
|
+
Tuple or dictionary of model outputs.
|
|
342
|
+
"""
|
|
343
|
+
# Unpack input data.
|
|
344
|
+
x = x_dict['x_phy']
|
|
345
|
+
Ac = x_dict['ac_all'].unsqueeze(-1).repeat(1, self.nmul)
|
|
346
|
+
Elevation = x_dict['elev_all'].unsqueeze(-1).repeat(1, self.nmul)
|
|
347
|
+
self.muwts = x_dict.get('muwts', None)
|
|
348
|
+
ngrid = x.shape[1]
|
|
349
|
+
|
|
350
|
+
# Unpack parameters.
|
|
351
|
+
phy_dy_params, phy_static_params, routing_params = self._unpack_parameters(
|
|
352
|
+
parameters
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
if self.routing:
|
|
356
|
+
self.routing_param_dict = self._descale_route_parameters(routing_params)
|
|
357
|
+
phy_dy_params_dict = self._descale_phy_dy_parameters(
|
|
358
|
+
phy_dy_params,
|
|
359
|
+
dy_list=self.dynamic_params,
|
|
360
|
+
)
|
|
361
|
+
phy_static_params_dict = self._descale_phy_stat_parameters(
|
|
362
|
+
phy_static_params,
|
|
363
|
+
stat_list=[
|
|
364
|
+
param
|
|
365
|
+
for param in self.phy_param_names
|
|
366
|
+
if param not in self.dynamic_params
|
|
367
|
+
],
|
|
368
|
+
)
|
|
369
|
+
|
|
370
|
+
if (not self.states) or (not self.cache_states):
|
|
371
|
+
current_states = self._init_states(ngrid)
|
|
372
|
+
else:
|
|
373
|
+
current_states = self.states
|
|
374
|
+
|
|
375
|
+
fluxes, states = self._PBM(
|
|
376
|
+
x,
|
|
377
|
+
Ac,
|
|
378
|
+
Elevation,
|
|
379
|
+
current_states,
|
|
380
|
+
phy_dy_params_dict,
|
|
381
|
+
phy_static_params_dict,
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
# State caching
|
|
385
|
+
self._state_cache = states
|
|
386
|
+
|
|
387
|
+
if self.cache_states:
|
|
388
|
+
self.states = tuple(s[-1].detach() for s in self._state_cache)
|
|
389
|
+
|
|
390
|
+
return fluxes
|
|
391
|
+
|
|
392
|
+
    def _PBM(
        self,
        forcing: torch.Tensor,
        Ac: torch.Tensor,
        Elevation: torch.Tensor,
        states: tuple,
        phy_dy_params_dict: dict,
        phy_static_params_dict: dict,
    ) -> Union[tuple, dict[str, torch.Tensor]]:
        """Run through process-based model (PBM).

        Flux outputs are in mm/day.

        Parameters
        ----------
        forcing
            Input forcing data, shape [time, grid, var]; variable order given
            by ``self.variables``.
        Ac
            Upstream/accumulated area per grid cell, [grid, nmul]
            (assumed km^2 given the 2500 threshold below -- TODO confirm).
        Elevation
            Elevation per grid cell, [grid, nmul] (assumed meters -- the 2000
            threshold below suggests m a.s.l.; confirm with caller).
        states
            Initial model states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
        phy_dy_params_dict
            Descaled dynamic parameters, each [time, grid, nmul].
        phy_static_params_dict
            Descaled static parameters, each [grid, nmul].

        Returns
        -------
        Union[tuple, dict]
            (flux dict, state-trajectory tuple); the flux dict is empty when
            ``self.initialize`` is True (warm-up-only run).
        """
        SNOWPACK, MELTWATER, SM, SUZ, SLZ = states

        # Forcings
        P = forcing[:, :, self.variables.index('prcp')]  # Precipitation
        T = forcing[:, :, self.variables.index('tmean')]  # Mean air temp
        PET = forcing[:, :, self.variables.index('pet')]  # Potential ET
        nsteps, ngrid = P.shape

        # Expand dims to accomodate for nmul models.
        Pm = P.unsqueeze(2).repeat(1, 1, self.nmul)
        Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
        PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

        # Apply correction factor to precipitation
        # P = parPCORR.repeat(nsteps, 1) * P

        # Initialize time series of model variables in shape [time, basins, nmul].
        # Small positive offsets keep early routed flows strictly positive.
        Qsimmu = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.001
        Q0_sim = (
            torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
        )
        Q1_sim = (
            torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
        )
        Q2_sim = (
            torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
        )

        AET = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        recharge_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        excs_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        evapfactor_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        tosoil_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        PERC_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SWE_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        capillary_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)

        # NOTE: new for MTS -- Save model states for all time steps.
        SNOWPACK_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        MELTWATER_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SM_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SUZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SLZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)

        param_dict = {}
        for t in range(nsteps):
            # Get dynamic parameter values per timestep.
            for key in phy_dy_params_dict.keys():
                param_dict[key] = phy_dy_params_dict[key][t, :, :]
            # Static parameters are re-copied each step into the same dict.
            for key in phy_static_params_dict.keys():
                param_dict[key] = phy_static_params_dict[key][:, :]

            # Separate precipitation into liquid and solid components.
            PRECIP = Pm[t, :, :]
            # Above 2000 (elevation units assumed m) the rain/snow threshold
            # temperature is fixed at 4.0; below, the learned parTT is used.
            parTT_new = (Elevation >= 2000).type(torch.float32) * 4.0 + (
                Elevation < 2000
            ).type(torch.float32) * param_dict['parTT']
            RAIN = torch.mul(PRECIP, (Tm[t, :, :] >= parTT_new).type(torch.float32))
            SNOW = torch.mul(PRECIP, (Tm[t, :, :] < parTT_new).type(torch.float32))

            # Snow -------------------------------
            SNOWPACK = SNOWPACK + SNOW
            # Degree-day melt, limited to available snowpack.
            melt = param_dict['parCFMAX'] * (Tm[t, :, :] - parTT_new)
            # melt[melt < 0.0] = 0.0
            melt = torch.clamp(melt, min=0.0)
            # melt[melt > SNOWPACK] = SNOWPACK[melt > SNOWPACK]
            melt = torch.min(melt, SNOWPACK)
            MELTWATER = MELTWATER + melt
            SNOWPACK = SNOWPACK - melt
            # Refreezing of held meltwater back into the snowpack.
            refreezing = (
                param_dict['parCFR']
                * param_dict['parCFMAX']
                * (parTT_new - Tm[t, :, :])
            )
            # refreezing[refreezing < 0.0] = 0.0
            # refreezing[refreezing > MELTWATER] = MELTWATER[refreezing > MELTWATER]
            refreezing = torch.clamp(refreezing, min=0.0)
            refreezing = torch.min(refreezing, MELTWATER)
            SNOWPACK = SNOWPACK + refreezing
            MELTWATER = MELTWATER - refreezing
            # Meltwater exceeding the pack's holding capacity infiltrates.
            tosoil = MELTWATER - (param_dict['parCWH'] * SNOWPACK)
            tosoil = torch.clamp(tosoil, min=0.0)
            MELTWATER = MELTWATER - tosoil

            # Soil and evaporation -------------------------------
            soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
            # soil_wetness[soil_wetness < 0.0] = 0.0
            # soil_wetness[soil_wetness > 1.0] = 1.0
            soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
            recharge = (RAIN + tosoil) * soil_wetness

            SM = SM + RAIN + tosoil - recharge

            # Water above field capacity spills directly to the upper zone.
            excess = SM - param_dict['parFC']
            excess = torch.clamp(excess, min=0.0)
            SM = SM - excess
            # NOTE: Different from HBV 1.0. Add static/dynamicET shape parameter parBETAET.
            evapfactor = (
                SM / (param_dict['parLP'] * param_dict['parFC'])
            ) ** param_dict['parBETAET']
            evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
            ETact = PETm[t, :, :] * evapfactor
            ETact = torch.min(SM, ETact)
            # Floor at nearzero keeps SM positive for the power terms above.
            SM = torch.clamp(SM - ETact, min=self.nearzero)

            # Capillary rise (HBV 1.1p mod) -------------------------------
            # Upward flux from the lower zone, scaled by soil-moisture deficit.
            capillary = torch.min(
                SLZ,
                param_dict['parC']
                * SLZ
                * (1.0 - torch.clamp(SM / param_dict['parFC'], max=1.0)),
            )

            SM = torch.clamp(SM + capillary, min=self.nearzero)
            SLZ = torch.clamp(SLZ - capillary, min=self.nearzero)

            # Groundwater boxes -------------------------------
            SUZ = SUZ + recharge + excess
            PERC = torch.min(SUZ, param_dict['parPERC'])
            SUZ = SUZ - PERC
            # Q0: fast near-surface flow from storage above threshold parUZL.
            Q0 = param_dict['parK0'] * torch.clamp(SUZ - param_dict['parUZL'], min=0.0)
            SUZ = SUZ - Q0
            Q1 = param_dict['parK1'] * SUZ
            SUZ = SUZ - Q1
            SLZ = SLZ + PERC

            # Lateral flow term: linear in area below the 2500 threshold,
            # exponentially decaying above it (Ac units assumed km^2 -- confirm).
            LF = torch.clamp(
                (Ac - param_dict['parAC']) / 1000, min=-1, max=1
            ) * param_dict['parRT'] * (Ac < 2500) + torch.exp(
                torch.clamp(-(Ac - 2500) / 50, min=-10.0, max=0.0)
            ) * param_dict['parRT'] * (Ac >= 2500)
            SLZ = torch.clamp(SLZ + LF, min=0.0)

            Q2 = param_dict['parK2'] * SLZ
            SLZ = SLZ - Q2

            # --- Outputs ---
            Qsimmu[t, :, :] = Q0 + Q1 + Q2
            Q0_sim[t, :, :] = Q0
            Q1_sim[t, :, :] = Q1
            Q2_sim[t, :, :] = Q2
            AET[t, :, :] = ETact
            SWE_sim[t, :, :] = SNOWPACK
            capillary_sim[t, :, :] = capillary

            recharge_sim[t, :, :] = recharge
            excs_sim[t, :, :] = excess
            evapfactor_sim[t, :, :] = evapfactor
            tosoil_sim[t, :, :] = tosoil
            PERC_sim[t, :, :] = PERC

            # NOTE: new for MTS -- Save model states for all time steps.
            SNOWPACK_sim[t, :, :] = SNOWPACK
            MELTWATER_sim[t, :, :] = MELTWATER
            SM_sim[t, :, :] = SM
            SUZ_sim[t, :, :] = SUZ
            SLZ_sim[t, :, :] = SLZ

        # Get the average or weighted average using learned weights.
        if self.muwts is None:
            Qsimavg = Qsimmu.mean(-1)
        else:
            Qsimavg = (Qsimmu * self.muwts).sum(-1)

        # Run routing
        if self.routing:
            # Routing for all components or just the average.
            if self.comprout:
                # All components; reshape to [time, gages * num models]
                Qsim = Qsimmu.view(nsteps, ngrid * self.nmul)
            else:
                # Average, then do routing.
                Qsim = Qsimavg

            # Gamma-distribution unit hydrograph from learned routing params.
            UH = uh_gamma(
                self.routing_param_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
                self.routing_param_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
                lenF=self.lenF,
            )
            rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0])  # [gages,vars,time]
            UH = UH.permute([1, 2, 0])  # [gages,vars,time]
            Qsrout = uh_conv(rf, UH).permute([2, 0, 1])

            # Routing individually for Q0, Q1, and Q2, all w/ dims [gages,vars,time].
            rf_Q0 = Q0_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q0_rout = uh_conv(rf_Q0, UH).permute([2, 0, 1])
            rf_Q1 = Q1_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q1_rout = uh_conv(rf_Q1, UH).permute([2, 0, 1])
            rf_Q2 = Q2_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q2_rout = uh_conv(rf_Q2, UH).permute([2, 0, 1])

            if self.comprout:
                # Qs is now shape [time, [gages*num models], vars]
                Qstemp = Qsrout.view(nsteps, ngrid, self.nmul)
                if self.muwts is None:
                    Qs = Qstemp.mean(-1, keepdim=True)
                else:
                    Qs = (Qstemp * self.muwts).sum(-1, keepdim=True)
            else:
                Qs = Qsrout

        else:
            # No routing, only output the average of all model sims.
            Qsim = Qsimavg
            Qs = torch.unsqueeze(Qsimavg, -1)
            Q0_rout = Q0_sim.mean(-1, keepdim=True)
            Q1_rout = Q1_sim.mean(-1, keepdim=True)
            Q2_rout = Q2_sim.mean(-1, keepdim=True)

        states = (SNOWPACK_sim, MELTWATER_sim, SM_sim, SUZ_sim, SLZ_sim)

        if self.initialize:
            # If initialize is True, only return warmed-up storages.
            return {}, states
        else:
            # Baseflow index (BFI) calculation
            BFI_sim = (
                100
                * (torch.sum(Q2_rout, dim=0) / (torch.sum(Qs, dim=0) + self.nearzero))[
                    :, 0
                ]
            )

            # Return all sim results.
            # NOTE(review): when comprout is True, Qsim here is the component-
            # flattened tensor, so 'streamflow_no_rout' has shape
            # [time, gages*nmul, 1] rather than [time, gages, 1] -- confirm
            # intended.
            flux_dict = {
                'streamflow': Qs,  # Routed Streamflow
                'srflow': Q0_rout,  # Routed surface runoff
                'ssflow': Q1_rout,  # Routed subsurface flow
                'gwflow': Q2_rout,  # Routed groundwater flow
                'AET_hydro': AET.mean(-1, keepdim=True),  # Actual ET
                'PET_hydro': PETm.mean(-1, keepdim=True),  # Potential ET
                'SWE': SWE_sim.mean(-1, keepdim=True),  # Snow water equivalent
                'streamflow_no_rout': Qsim.unsqueeze(dim=2),  # Streamflow
                'srflow_no_rout': Q0_sim.mean(-1, keepdim=True),  # Surface runoff
                'ssflow_no_rout': Q1_sim.mean(-1, keepdim=True),  # Subsurface flow
                'gwflow_no_rout': Q2_sim.mean(-1, keepdim=True),  # Groundwater flow
                'recharge': recharge_sim.mean(-1, keepdim=True),  # Recharge
                'excs': excs_sim.mean(-1, keepdim=True),  # Excess stored water
                'evapfactor': evapfactor_sim.mean(
                    -1, keepdim=True
                ),  # Evaporation factor
                'tosoil': tosoil_sim.mean(-1, keepdim=True),  # Infiltration
                'percolation': PERC_sim.mean(-1, keepdim=True),  # Percolation
                'capillary': capillary_sim.mean(-1, keepdim=True),  # Capillary rise
                'BFI': BFI_sim,  # Baseflow index
            }

            if not self.warm_up_states:
                # Trim the prediction window when warm-up states were not used.
                for key in flux_dict.keys():
                    if key != 'BFI':
                        flux_dict[key] = flux_dict[key][self.pred_cutoff :, :, :]
            return flux_dict, states
|