hydrodl2 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydrodl2/__init__.py +122 -0
- hydrodl2/_version.py +34 -0
- hydrodl2/api/__init__.py +3 -0
- hydrodl2/api/methods.py +144 -0
- hydrodl2/core/calc/__init__.py +11 -0
- hydrodl2/core/calc/batch_jacobian.py +501 -0
- hydrodl2/core/calc/fdj.py +92 -0
- hydrodl2/core/calc/uh_routing.py +105 -0
- hydrodl2/core/calc/utils.py +59 -0
- hydrodl2/core/utils/__init__.py +7 -0
- hydrodl2/core/utils/clean_temp.sh +8 -0
- hydrodl2/core/utils/utils.py +63 -0
- hydrodl2/models/hbv/hbv.py +596 -0
- hydrodl2/models/hbv/hbv_1_1p.py +608 -0
- hydrodl2/models/hbv/hbv_2.py +670 -0
- hydrodl2/models/hbv/hbv_2_hourly.py +897 -0
- hydrodl2/models/hbv/hbv_2_mts.py +377 -0
- hydrodl2/models/hbv/hbv_adj.py +712 -0
- hydrodl2/modules/__init__.py +2 -0
- hydrodl2/modules/data_assimilation/variational_prcp_da.py +1 -0
- hydrodl2-1.3.0.dist-info/METADATA +184 -0
- hydrodl2-1.3.0.dist-info/RECORD +24 -0
- hydrodl2-1.3.0.dist-info/WHEEL +4 -0
- hydrodl2-1.3.0.dist-info/licenses/LICENSE +31 -0
|
@@ -0,0 +1,596 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from hydrodl2.core.calc import change_param_range, uh_conv, uh_gamma
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Hbv(torch.nn.Module):
|
|
9
|
+
"""HBV 1.0 ~.
|
|
10
|
+
|
|
11
|
+
Multi-component, differentiable PyTorch HBV model with option to run without
|
|
12
|
+
internal state warmup.
|
|
13
|
+
|
|
14
|
+
Authors
|
|
15
|
+
-------
|
|
16
|
+
- Farshid Rahmani, Yalan Song, Leo Lonzarich
|
|
17
|
+
- (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
|
|
18
|
+
- (HBV-light Version 2) Seibert, 2005
|
|
19
|
+
(https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).
|
|
20
|
+
|
|
21
|
+
Publication
|
|
22
|
+
-----------
|
|
23
|
+
- Dapeng Feng, Jiangtao Liu, Kathryn Lawson, Chaopeng Shen. "Differentiable,
|
|
24
|
+
learnable, regionalized process-based models with multiphysical outputs
|
|
25
|
+
can approach state-of-the-art hydrologic prediction accuracy." Water
|
|
26
|
+
Resources Research (2020), 58, e2022WR032404.
|
|
27
|
+
https://doi.org/10.1029/2022WR032404.
|
|
28
|
+
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
config
|
|
32
|
+
Configuration dictionary.
|
|
33
|
+
device
|
|
34
|
+
Device to run the model on.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
config: Optional[dict[str, Any]] = None,
|
|
40
|
+
device: Optional[torch.device] = None,
|
|
41
|
+
) -> None:
|
|
42
|
+
super().__init__()
|
|
43
|
+
self.name = 'HBV 1.0'
|
|
44
|
+
self.config = config
|
|
45
|
+
self.initialize = False
|
|
46
|
+
self.warm_up = 0
|
|
47
|
+
self.pred_cutoff = 0
|
|
48
|
+
self.warm_up_states = True
|
|
49
|
+
self.dynamic_params = []
|
|
50
|
+
self.dy_drop = 0.0
|
|
51
|
+
self.variables = ['prcp', 'tmean', 'pet']
|
|
52
|
+
self.routing = True
|
|
53
|
+
self.comprout = False
|
|
54
|
+
self.nearzero = 1e-5
|
|
55
|
+
self.nmul = 1
|
|
56
|
+
self.cache_states = False
|
|
57
|
+
self.device = device
|
|
58
|
+
|
|
59
|
+
self.states, self._states_cache = None, None
|
|
60
|
+
|
|
61
|
+
self.state_names = [
|
|
62
|
+
'SNOWPACK', # Snowpack storage
|
|
63
|
+
'MELTWATER', # Meltwater storage
|
|
64
|
+
'SM', # Soil moisture storage
|
|
65
|
+
'SUZ', # Upper groundwater storage
|
|
66
|
+
'SLZ', # Lower groundwater storage
|
|
67
|
+
]
|
|
68
|
+
self.flux_names = [
|
|
69
|
+
'streamflow', # Routed Streamflow
|
|
70
|
+
'srflow', # Routed surface runoff
|
|
71
|
+
'ssflow', # Routed subsurface flow
|
|
72
|
+
'gwflow', # Routed groundwater flow
|
|
73
|
+
'AET_hydro', # Actual ET
|
|
74
|
+
'PET_hydro', # Potential ET
|
|
75
|
+
'SWE', # Snow water equivalent
|
|
76
|
+
'streamflow_no_rout', # Streamflow
|
|
77
|
+
'srflow_no_rout', # Surface runoff
|
|
78
|
+
'ssflow_no_rout', # Subsurface flow
|
|
79
|
+
'gwflow_no_rout', # Groundwater flow
|
|
80
|
+
'recharge', # Recharge
|
|
81
|
+
'excs', # Excess stored water
|
|
82
|
+
'evapfactor', # Evaporation factor
|
|
83
|
+
'tosoil', # Infiltration
|
|
84
|
+
'percolation', # Percolation
|
|
85
|
+
'BFI', # Baseflow index
|
|
86
|
+
]
|
|
87
|
+
|
|
88
|
+
self.parameter_bounds = {
|
|
89
|
+
'parBETA': [1.0, 6.0],
|
|
90
|
+
'parFC': [50, 1000],
|
|
91
|
+
'parK0': [0.05, 0.9],
|
|
92
|
+
'parK1': [0.01, 0.5],
|
|
93
|
+
'parK2': [0.001, 0.2],
|
|
94
|
+
'parLP': [0.2, 1],
|
|
95
|
+
'parPERC': [0, 10],
|
|
96
|
+
'parUZL': [0, 100],
|
|
97
|
+
'parTT': [-2.5, 2.5],
|
|
98
|
+
'parCFMAX': [0.5, 10],
|
|
99
|
+
'parCFR': [0, 0.1],
|
|
100
|
+
'parCWH': [0, 0.2],
|
|
101
|
+
}
|
|
102
|
+
self.routing_parameter_bounds = {
|
|
103
|
+
'route_a': [0, 2.9],
|
|
104
|
+
'route_b': [0, 6.5],
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
if not device:
|
|
108
|
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
109
|
+
|
|
110
|
+
if config is not None:
|
|
111
|
+
# Overwrite defaults with config values.
|
|
112
|
+
self.warm_up = config.get('warm_up', self.warm_up)
|
|
113
|
+
self.warm_up_states = config.get('warm_up_states', self.warm_up_states)
|
|
114
|
+
self.dy_drop = config.get('dy_drop', self.dy_drop)
|
|
115
|
+
self.dynamic_params = config['dynamic_params'].get(
|
|
116
|
+
self.__class__.__name__, self.dynamic_params
|
|
117
|
+
)
|
|
118
|
+
self.variables = config.get('variables', self.variables)
|
|
119
|
+
self.routing = config.get('routing', self.routing)
|
|
120
|
+
self.comprout = config.get('comprout', self.comprout)
|
|
121
|
+
self.nearzero = config.get('nearzero', self.nearzero)
|
|
122
|
+
self.nmul = config.get('nmul', self.nmul)
|
|
123
|
+
self.cache_states = config.get('cache_states', False)
|
|
124
|
+
if 'parBETAET' in self.dynamic_params:
|
|
125
|
+
self.parameter_bounds['parBETAET'] = [0.3, 5]
|
|
126
|
+
self._set_parameters()
|
|
127
|
+
|
|
128
|
+
def _init_states(self, ngrid: int) -> tuple[torch.Tensor]:
|
|
129
|
+
"""Initialize model states to zero."""
|
|
130
|
+
|
|
131
|
+
def make_state():
|
|
132
|
+
return torch.full(
|
|
133
|
+
(ngrid, self.nmul), 0.001, dtype=torch.float32, device=self.device
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
return tuple(make_state() for _ in range(len(self.state_names)))
|
|
137
|
+
|
|
138
|
+
def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
|
|
139
|
+
"""Return internal model states.
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
tuple[torch.Tensor, ...]
|
|
144
|
+
A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
|
|
145
|
+
"""
|
|
146
|
+
return self._states_cache
|
|
147
|
+
|
|
148
|
+
def load_states(
|
|
149
|
+
self,
|
|
150
|
+
states: tuple[torch.Tensor, ...],
|
|
151
|
+
) -> None:
|
|
152
|
+
"""Load internal model states and set to model device and type.
|
|
153
|
+
|
|
154
|
+
Parameters
|
|
155
|
+
----------
|
|
156
|
+
states
|
|
157
|
+
A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
|
|
158
|
+
"""
|
|
159
|
+
for state in states:
|
|
160
|
+
if not isinstance(state, torch.Tensor):
|
|
161
|
+
raise ValueError("Each element in `states` must be a tensor.")
|
|
162
|
+
nstates = len(self.state_names)
|
|
163
|
+
if not (isinstance(states, tuple) and len(states) == nstates):
|
|
164
|
+
raise ValueError(f"`states` must be a tuple of {nstates} tensors.")
|
|
165
|
+
|
|
166
|
+
self.states = tuple(
|
|
167
|
+
s.detach().to(self.device, dtype=torch.float32) for s in states
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
def _set_parameters(self) -> None:
|
|
171
|
+
"""Get physical parameters."""
|
|
172
|
+
self.phy_param_names = self.parameter_bounds.keys()
|
|
173
|
+
if self.routing:
|
|
174
|
+
self.routing_param_names = self.routing_parameter_bounds.keys()
|
|
175
|
+
else:
|
|
176
|
+
self.routing_param_names = []
|
|
177
|
+
|
|
178
|
+
self.learnable_param_count = len(self.phy_param_names) * self.nmul + len(
|
|
179
|
+
self.routing_param_names
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
def _unpack_parameters(
|
|
183
|
+
self,
|
|
184
|
+
parameters: torch.Tensor,
|
|
185
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
186
|
+
"""Extract physical model and routing parameters from NN output.
|
|
187
|
+
|
|
188
|
+
Parameters
|
|
189
|
+
----------
|
|
190
|
+
parameters
|
|
191
|
+
Unprocessed, learned parameters from a neural network.
|
|
192
|
+
|
|
193
|
+
Returns
|
|
194
|
+
-------
|
|
195
|
+
tuple[torch.Tensor, torch.Tensor]
|
|
196
|
+
Tuple of physical and routing parameters.
|
|
197
|
+
"""
|
|
198
|
+
phy_param_count = len(self.parameter_bounds)
|
|
199
|
+
|
|
200
|
+
# Physical parameters
|
|
201
|
+
phy_params = torch.sigmoid(
|
|
202
|
+
parameters[:, :, : phy_param_count * self.nmul]
|
|
203
|
+
).view(
|
|
204
|
+
parameters.shape[0],
|
|
205
|
+
parameters.shape[1],
|
|
206
|
+
phy_param_count,
|
|
207
|
+
self.nmul,
|
|
208
|
+
)
|
|
209
|
+
# Routing parameters
|
|
210
|
+
routing_params = None
|
|
211
|
+
if self.routing:
|
|
212
|
+
routing_params = torch.sigmoid(
|
|
213
|
+
parameters[-1, :, phy_param_count * self.nmul :],
|
|
214
|
+
)
|
|
215
|
+
return (phy_params, routing_params)
|
|
216
|
+
|
|
217
|
+
def _descale_phy_parameters(
    self,
    phy_params: torch.Tensor,
    dy_list: list,
) -> dict[str, torch.Tensor]:
    """Descale physical parameters.

    Static parameters are taken from the final timestep and broadcast over
    time; dynamic parameters keep their full time series, with a Bernoulli
    dropout mask randomly substituting static values (regularization).

    Parameters
    ----------
    phy_params
        Normalized physical parameters, shape [time, basins, n_params, nmul].
    dy_list
        List of dynamic parameter names.

    Returns
    -------
    dict
        Dictionary of descaled physical parameters.
    """
    nsteps = phy_params.shape[0]
    ngrid = phy_params.shape[1]

    param_dict = {}
    # PERF: build the dropout-probability tensor once, directly on the
    # model device; the original built it on CPU and moved every Bernoulli
    # draw to the device inside the loop.
    pmat = torch.full((1, ngrid, 1), self.dy_drop, device=self.device)
    for i, name in enumerate(self.parameter_bounds.keys()):
        # Static value: final-timestep parameter repeated over all steps.
        staPar = phy_params[-1, :, i, :].unsqueeze(0).repeat([nsteps, 1, 1])
        if name in dy_list:
            dynPar = phy_params[:, :, i, :]
            # Mask of 1s (prob. dy_drop) marking where the static value
            # replaces the dynamic one.
            drmask = torch.bernoulli(pmat).detach_()
            comPar = dynPar * (1 - drmask) + staPar * drmask
            param_dict[name] = change_param_range(
                param=comPar,
                bounds=self.parameter_bounds[name],
            )
        else:
            param_dict[name] = change_param_range(
                param=staPar,
                bounds=self.parameter_bounds[name],
            )
    return param_dict
|
|
257
|
+
|
|
258
|
+
def _descale_route_parameters(
    self,
    routing_params: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Map normalized routing parameters onto their physical ranges.

    Parameters
    ----------
    routing_params
        Normalized routing parameters, one column per parameter, in the
        order of ``self.routing_parameter_bounds``.

    Returns
    -------
    dict
        Mapping of routing-parameter name to its descaled tensor.
    """
    return {
        name: change_param_range(
            param=routing_params[:, idx],
            bounds=self.routing_parameter_bounds[name],
        )
        for idx, name in enumerate(self.routing_parameter_bounds.keys())
    }
|
|
283
|
+
|
|
284
|
+
def forward(
    self,
    x_dict: dict[str, torch.Tensor],
    parameters: torch.Tensor,
) -> Union[tuple, tuple[dict[str, torch.Tensor], tuple]]:
    """Forward pass.

    Parameters
    ----------
    x_dict
        Dictionary of input forcing data; must contain 'x_phy' of shape
        [time, basins, variables], optionally 'muwts' component weights.
    parameters
        Unprocessed, learned parameters from a neural network.

    Returns
    -------
    Union[tuple, tuple[dict, tuple]]
        Tuple or dictionary of model outputs.
    """
    # Unpack input data.
    x = x_dict['x_phy']
    self.muwts = x_dict.get('muwts', None)
    ngrid = x.shape[1]

    # Unpack parameters.
    phy_params, routing_params = self._unpack_parameters(parameters)
    if self.routing:
        self.routing_param_dict = self._descale_route_parameters(routing_params)

    # Initialization
    if self.warm_up_states:
        warm_up = self.warm_up
    else:
        # No state warm up: run the full model for warm_up days, then trim
        # the first warm_up steps from the outputs (see _PBM's pred_cutoff).
        self.pred_cutoff = self.warm_up
        warm_up = 0

    # Reuse cached states only when caching is enabled and states exist.
    if (not self.states) or (not self.cache_states):
        current_states = self._init_states(ngrid)
    else:
        current_states = self.states

    # Warm-up model states - run the model only on warm_up days first.
    if warm_up > 0:
        with torch.no_grad():
            phy_param_warmup_dict = self._descale_phy_parameters(
                phy_params[:warm_up, :, :],
                dy_list=[],
            )
            # a. Save current model settings.
            init_flag, route_flag = self.initialize, self.routing

            # b. Set temporary model settings for warm-up.
            self.initialize, self.routing = True, False

            current_states = self._PBM(
                x[:warm_up, :, :],
                current_states,
                phy_param_warmup_dict,
            )

            # c. Restore model settings.
            self.initialize, self.routing = init_flag, route_flag

    # Run the model for remainder of the simulation period.
    phy_params_dict = self._descale_phy_parameters(
        phy_params[warm_up:, :, :],
        dy_list=self.dynamic_params,
    )
    fluxes, states = self._PBM(x[warm_up:, :, :], current_states, phy_params_dict)

    # State caching.
    # CONSISTENCY FIX: cache as a tuple (was a list) so get_states() matches
    # its documented return type and round-trips through load_states().
    self._states_cache = tuple(s.detach() for s in states)

    if self.cache_states:
        self.states = self._states_cache

    return fluxes
|
|
362
|
+
|
|
363
|
+
def _PBM(
    self,
    forcing: torch.Tensor,
    states: tuple,
    full_param_dict: dict,
) -> Union[tuple, dict[str, torch.Tensor]]:
    """Run through process-based model (PBM).

    Steps the HBV water balance (snow, soil, two groundwater boxes)
    through time, then optionally routes runoff with a gamma unit
    hydrograph. When ``self.initialize`` is True only the warmed-up
    state tuple is returned.

    Parameters
    ----------
    forcing
        Input forcing data, shape [time, basins, variables]; variable
        order follows ``self.variables``.
    states
        Initial model states (SNOWPACK, MELTWATER, SM, SUZ, SLZ), each
        assumed shape [basins, nmul] — confirm against _init_states.
    full_param_dict
        Dictionary of model parameters, each tensor shaped
        [time, basins, nmul].

    Returns
    -------
    Union[tuple, dict]
        The state tuple when ``self.initialize`` is True, otherwise
        ``(flux_dict, states)``.
    """
    SNOWPACK, MELTWATER, SM, SUZ, SLZ = states

    # Forcings
    P = forcing[:, :, self.variables.index('prcp')]  # Precipitation
    T = forcing[:, :, self.variables.index('tmean')]  # Mean air temp
    PET = forcing[:, :, self.variables.index('pet')]  # Potential ET
    nsteps, ngrid = P.shape

    # Expand dims to accommodate for nmul models.
    Pm = P.unsqueeze(2).repeat(1, 1, self.nmul)
    Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
    PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

    # Apply correction factor to precipitation
    # P = parPCORR.repeat(nsteps, 1) * P

    # Initialize time series of model variables in shape [time, basins, nmul].
    # Small positive offsets keep later ratios/divisions away from zero.
    Qsimmu = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.001
    Q0_sim = (
        torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
    )
    Q1_sim = (
        torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
    )
    Q2_sim = (
        torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
    )

    # AET = PET_coef * PET
    AET = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    recharge_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    excs_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    evapfactor_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    tosoil_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    PERC_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    SWE_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)

    param_dict = {}
    for t in range(nsteps):
        # Get dynamic parameter values per timestep.
        for key in full_param_dict.keys():
            param_dict[key] = full_param_dict[key][t, :, :]

        # Separate precipitation into liquid and solid components by the
        # temperature threshold parTT (differentiable via multiply-by-mask).
        PRECIP = Pm[t, :, :]
        RAIN = torch.mul(
            PRECIP, (Tm[t, :, :] >= param_dict['parTT']).type(torch.float32)
        )
        SNOW = torch.mul(
            PRECIP, (Tm[t, :, :] < param_dict['parTT']).type(torch.float32)
        )

        # Snow -------------------------------
        SNOWPACK = SNOWPACK + SNOW
        # Degree-day melt, capped below by 0 and above by available snowpack.
        melt = param_dict['parCFMAX'] * (Tm[t, :, :] - param_dict['parTT'])
        # melt[melt < 0.0] = 0.0
        melt = torch.clamp(melt, min=0.0)
        # melt[melt > SNOWPACK] = SNOWPACK[melt > SNOWPACK]
        melt = torch.min(melt, SNOWPACK)
        MELTWATER = MELTWATER + melt
        SNOWPACK = SNOWPACK - melt
        # Refreezing of meltwater back into the snowpack on cold days.
        refreezing = (
            param_dict['parCFR']
            * param_dict['parCFMAX']
            * (param_dict['parTT'] - Tm[t, :, :])
        )
        # refreezing[refreezing < 0.0] = 0.0
        # refreezing[refreezing > MELTWATER] = MELTWATER[refreezing > MELTWATER]
        refreezing = torch.clamp(refreezing, min=0.0)
        refreezing = torch.min(refreezing, MELTWATER)
        SNOWPACK = SNOWPACK + refreezing
        MELTWATER = MELTWATER - refreezing
        # Meltwater above the snowpack holding capacity infiltrates the soil.
        tosoil = MELTWATER - (param_dict['parCWH'] * SNOWPACK)
        tosoil = torch.clamp(tosoil, min=0.0)
        MELTWATER = MELTWATER - tosoil

        # Soil and evaporation -------------------------------
        # Wetness-dependent partitioning of input water into recharge.
        soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
        # soil_wetness[soil_wetness < 0.0] = 0.0
        # soil_wetness[soil_wetness > 1.0] = 1.0
        soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
        recharge = (RAIN + tosoil) * soil_wetness

        SM = SM + RAIN + tosoil - recharge

        # Water above field capacity spills directly to the upper zone.
        excess = SM - param_dict['parFC']
        excess = torch.clamp(excess, min=0.0)
        SM = SM - excess
        # parBETAET only has effect when it is a dynamic parameter.
        evapfactor = SM / (param_dict['parLP'] * param_dict['parFC'])
        if 'parBETAET' in param_dict:
            evapfactor = evapfactor ** param_dict['parBETAET']
        evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
        # Actual ET is PET scaled by soil-moisture availability, bounded by SM.
        ETact = PETm[t, :, :] * evapfactor
        ETact = torch.min(SM, ETact)
        # Floor at nearzero so SM never reaches exactly 0 (keeps grads finite).
        SM = torch.clamp(SM - ETact, min=self.nearzero)

        # Groundwater boxes -------------------------------
        SUZ = SUZ + recharge + excess
        PERC = torch.min(SUZ, param_dict['parPERC'])
        SUZ = SUZ - PERC
        # Q0: fast near-surface flow above threshold parUZL.
        Q0 = param_dict['parK0'] * torch.clamp(SUZ - param_dict['parUZL'], min=0.0)
        SUZ = SUZ - Q0
        # Q1: interflow from the upper zone.
        Q1 = param_dict['parK1'] * SUZ
        SUZ = SUZ - Q1
        SLZ = SLZ + PERC
        # Q2: baseflow from the lower zone.
        Q2 = param_dict['parK2'] * SLZ
        SLZ = SLZ - Q2

        Qsimmu[t, :, :] = Q0 + Q1 + Q2
        Q0_sim[t, :, :] = Q0
        Q1_sim[t, :, :] = Q1
        Q2_sim[t, :, :] = Q2
        AET[t, :, :] = ETact
        SWE_sim[t, :, :] = SNOWPACK

        recharge_sim[t, :, :] = recharge
        excs_sim[t, :, :] = excess
        evapfactor_sim[t, :, :] = evapfactor
        tosoil_sim[t, :, :] = tosoil
        PERC_sim[t, :, :] = PERC

    # Get the average or weighted average using learned weights.
    if self.muwts is None:
        Qsimavg = Qsimmu.mean(-1)
    else:
        Qsimavg = (Qsimmu * self.muwts).sum(-1)

    # Run routing
    if self.routing:
        # Routing for all components or just the average.
        if self.comprout:
            # All components; reshape to [time, gages * num models]
            Qsim = Qsimmu.view(nsteps, ngrid * self.nmul)
        else:
            # Average, then do routing.
            Qsim = Qsimavg

        # Gamma unit hydrograph built from the (static) routing parameters.
        UH = uh_gamma(
            self.routing_param_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
            self.routing_param_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
            lenF=15,
        )
        rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0])  # [gages,vars,time]
        UH = UH.permute([1, 2, 0])  # [gages,vars,time]
        Qsrout = uh_conv(rf, UH).permute([2, 0, 1])

        # Routing individually for Q0, Q1, and Q2, all w/ dims [gages,vars,time].
        rf_Q0 = Q0_sim.mean(-1, keepdim=True).permute([1, 2, 0])
        Q0_rout = uh_conv(rf_Q0, UH).permute([2, 0, 1])
        rf_Q1 = Q1_sim.mean(-1, keepdim=True).permute([1, 2, 0])
        Q1_rout = uh_conv(rf_Q1, UH).permute([2, 0, 1])
        rf_Q2 = Q2_sim.mean(-1, keepdim=True).permute([1, 2, 0])
        Q2_rout = uh_conv(rf_Q2, UH).permute([2, 0, 1])

        if self.comprout:
            # Qs is now shape [time, [gages*num models], vars]
            Qstemp = Qsrout.view(nsteps, ngrid, self.nmul)
            if self.muwts is None:
                Qs = Qstemp.mean(-1, keepdim=True)
            else:
                Qs = (Qstemp * self.muwts).sum(-1, keepdim=True)
        else:
            Qs = Qsrout

    else:
        # No routing, only output the average of all model sims.
        Qs = torch.unsqueeze(Qsimavg, -1)
        Q0_rout = Q1_rout = Q2_rout = None
        # NOTE(review): on this path `Qsim` is never bound and Q2_rout is
        # None, so building flux_dict below (Qsim.unsqueeze and
        # torch.sum(Q2_rout, ...)) raises when self.initialize is False.
        # It appears routing=False is only used together with
        # initialize=True (state warm-up in forward) — confirm with callers.

    states = (SNOWPACK, MELTWATER, SM, SUZ, SLZ)

    if self.initialize:
        # If initialize is True, only return warmed-up storages.
        return states
    else:
        # Baseflow index (BFI) calculation
        BFI_sim = (
            100
            * (torch.sum(Q2_rout, dim=0) / (torch.sum(Qs, dim=0) + self.nearzero))[
                :, 0
            ]
        )

        # Return all sim results.
        # NOTE(review): with comprout=True, Qsim has shape
        # [time, gages*nmul], so 'streamflow_no_rout' differs in shape from
        # the other [time, gages, 1] fluxes — confirm downstream handling.
        flux_dict = {
            'streamflow': Qs,  # Routed Streamflow
            'srflow': Q0_rout,  # Routed surface runoff
            'ssflow': Q1_rout,  # Routed subsurface flow
            'gwflow': Q2_rout,  # Routed groundwater flow
            'AET_hydro': AET.mean(-1, keepdim=True),  # Actual ET
            'PET_hydro': PETm.mean(-1, keepdim=True),  # Potential ET
            'SWE': SWE_sim.mean(-1, keepdim=True),  # Snow water equivalent
            'streamflow_no_rout': Qsim.unsqueeze(dim=2),  # Streamflow
            'srflow_no_rout': Q0_sim.mean(-1, keepdim=True),  # Surface runoff
            'ssflow_no_rout': Q1_sim.mean(-1, keepdim=True),  # Subsurface flow
            'gwflow_no_rout': Q2_sim.mean(-1, keepdim=True),  # Groundwater flow
            'recharge': recharge_sim.mean(-1, keepdim=True),  # Recharge
            'excs': excs_sim.mean(-1, keepdim=True),  # Excess stored water
            'evapfactor': evapfactor_sim.mean(
                -1, keepdim=True
            ),  # Evaporation factor
            'tosoil': tosoil_sim.mean(-1, keepdim=True),  # Infiltration
            'percolation': PERC_sim.mean(-1, keepdim=True),  # Percolation
            'BFI': BFI_sim,  # Baseflow index
        }

        # When states were not warmed up, drop the first pred_cutoff steps
        # (they were simulated from cold-start states).
        if not self.warm_up_states:
            for key in flux_dict.keys():
                if key != 'BFI':
                    flux_dict[key] = flux_dict[key][self.pred_cutoff :, :, :]
        return flux_dict, states
|