hydrodl2 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydrodl2/__init__.py +122 -0
- hydrodl2/_version.py +34 -0
- hydrodl2/api/__init__.py +3 -0
- hydrodl2/api/methods.py +144 -0
- hydrodl2/core/calc/__init__.py +11 -0
- hydrodl2/core/calc/batch_jacobian.pye +501 -0
- hydrodl2/core/calc/fdj.py +92 -0
- hydrodl2/core/calc/uh_routing.py +105 -0
- hydrodl2/core/calc/utils.py +59 -0
- hydrodl2/core/utils/__init__.py +7 -0
- hydrodl2/core/utils/clean_temp.sh +8 -0
- hydrodl2/core/utils/utils.py +63 -0
- hydrodl2/models/hbv/hbv.py +596 -0
- hydrodl2/models/hbv/hbv_1_1p.py +608 -0
- hydrodl2/models/hbv/hbv_2.py +670 -0
- hydrodl2/models/hbv/hbv_2_hourly.py +897 -0
- hydrodl2/models/hbv/hbv_2_mts.py +377 -0
- hydrodl2/models/hbv/hbv_adj.py +712 -0
- hydrodl2/modules/__init__.py +2 -0
- hydrodl2/modules/data_assimilation/variational_prcp_da.py +1 -0
- hydrodl2-1.3.0.dist-info/METADATA +184 -0
- hydrodl2-1.3.0.dist-info/RECORD +24 -0
- hydrodl2-1.3.0.dist-info/WHEEL +4 -0
- hydrodl2-1.3.0.dist-info/licenses/LICENSE +31 -0
|
@@ -0,0 +1,897 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from hydrodl2.core.calc import change_param_range, uh_conv, uh_gamma
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Hbv_2_hourly(torch.nn.Module):
|
|
9
|
+
"""Hourly HBV 2.0.
|
|
10
|
+
|
|
11
|
+
Multi-component, multi-scale, differentiable PyTorch HBV model with rainfall
|
|
12
|
+
runoff simulation on unit basins.
|
|
13
|
+
|
|
14
|
+
Authors
|
|
15
|
+
-------
|
|
16
|
+
- Wencong Yang, Leo Lonzarich, Yalan Song
|
|
17
|
+
- (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
|
|
18
|
+
- (HBV-light Version 2) Seibert, 2005
|
|
19
|
+
(https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
config
|
|
24
|
+
Configuration dictionary.
|
|
25
|
+
device
|
|
26
|
+
Device to run the model on.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def __init__(
    self,
    config: Optional[dict[str, Any]] = None,
    device: Optional[torch.device] = None,
) -> None:
    """Set model defaults, parameter bounds, and config overrides.

    Parameters
    ----------
    config
        Optional configuration dictionary. Recognized keys (``warm_up``,
        ``warm_up_states``, ``dy_drop``, ``dynamic_params``, ``variables``,
        ``routing``, ``comprout``, ``nearzero``, ``nmul``, ``cache_states``,
        ``infiltration``, ``lag_uh``) override the defaults set here. Every
        key is optional; when ``dynamic_params`` is missing, no physical
        parameter is treated as dynamic.
    device
        Device to run the model on. When ``None``, CUDA is used if it is
        available, otherwise CPU.
    """
    super().__init__()
    self.name = 'HBV 2.0 Hourly'
    self.config = config
    self.initialize = False
    self.warm_up = 0
    self.pred_cutoff = 0
    self.warm_up_states = True
    self.dynamic_params = []
    self.dy_drop = 0.0
    self.variables = ['prcp', 'tmean', 'pet']
    self.routing = False
    self.lenF = 72
    self.comprout = False
    self.muwts = None
    self.nearzero = 1e-5
    self.nmul = 1
    self.cache_states = False
    self.device = device

    self.states, self._states_cache = None, None

    # Per-call runoff history used by distributed routing when
    # ``cache_states`` is enabled.
    self._qs_buffer = []
    self._max_history = 100  # Safe buffer size > lenF (72)

    self.dt = 1.0 / 24  # Timestep as a fraction of a day (hourly model).
    self.use_distr_routing = True
    self.infiltration = True
    self.lag_uh = True

    self.state_names = [
        'SNOWPACK',  # Snowpack storage
        'MELTWATER',  # Meltwater storage
        'SM',  # Soil moisture storage
        'SUZ',  # Upper groundwater storage
        'SLZ',  # Lower groundwater storage
    ]
    self.flux_names = [
        'streamflow',  # Routed Streamflow
        'srflow',  # Routed surface runoff
        'ssflow',  # Routed subsurface flow
        'gwflow',  # Routed groundwater flow
        'AET_hydro',  # Actual ET
        'PET_hydro',  # Potential ET
        'SWE',  # Snow water equivalent
        'streamflow_no_rout',  # Streamflow
        'srflow_no_rout',  # Surface runoff
        'ssflow_no_rout',  # Subsurface flow
        'gwflow_no_rout',  # Groundwater flow
        'recharge',  # Recharge
        'excs',  # Excess stored water
        'evapfactor',  # Evaporation factor
        'tosoil',  # Infiltration
        'percolation',  # Percolation
        'capillary',  # Capillary rise
        'BFI',  # Baseflow index
    ]

    self.parameter_bounds = {
        'parBETA': [1.0, 6.0],
        'parFC': [50, 1000],
        'parK0': [0.05, 0.9],
        'parK1': [0.01, 0.5],
        'parK2': [0.001, 0.2],
        'parLP': [0.2, 1],
        'parPERC': [0, 10],
        'parUZL': [0, 100],
        'parTT': [-2.5, 2.5],
        'parCFMAX': [0.5, 10],
        'parCFR': [0, 0.1],
        'parCWH': [0, 0.2],
        'parBETAET': [0.3, 5],
        'parC': [0, 1],
        'parRT': [0, 20],
        'parAC': [0, 2500],
        # Infiltration parameters for hourly
        'parF0': [
            5.0 / self.dt,
            120.0 / self.dt,
        ],  # dry (max) infiltration capacity, mm/day
        'parFMIN': [0.0, 1.0],  # wet (min) capacity ratio
        'parALPHA': [0.5, 5.0],  # shape of f(s); larger -> more thresholdy
    }
    self.routing_parameter_bounds = {
        'route_a': [0, 5.0],
        'route_b': [0, 12.0],
    }
    self.distr_parameter_bounds = {
        'route_a': [0, 5.0],
        'route_b': [0, 12.0],
        'route_tau': [0, 48.0],
    }

    if device is None:
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if config is not None:
        # Overwrite defaults with config values.
        self.warm_up = config.get('warm_up', self.warm_up)
        self.warm_up_states = config.get('warm_up_states', self.warm_up_states)
        self.dy_drop = config.get('dy_drop', self.dy_drop)
        # ``dynamic_params`` is keyed by model class name. Configs that omit
        # it (or set it to None) fall back to "no dynamic parameters".
        dynamic_params_cfg = config.get('dynamic_params') or {}
        self.dynamic_params = dynamic_params_cfg.get(
            self.__class__.__name__, self.dynamic_params
        )
        self.variables = config.get('variables', self.variables)
        self.routing = config.get('routing', self.routing)
        self.comprout = config.get('comprout', self.comprout)
        self.nearzero = config.get('nearzero', self.nearzero)
        self.nmul = config.get('nmul', self.nmul)
        self.cache_states = config.get('cache_states', self.cache_states)
        # Component switches. These are read before pruning the bounds below
        # so that disabling a component actually removes its parameters.
        self.infiltration = config.get('infiltration', self.infiltration)
        self.lag_uh = config.get('lag_uh', self.lag_uh)

    # Drop the parameters of disabled components so the learnable-parameter
    # counts match what the process model and routing will read.
    if not self.infiltration:
        for name in ('parF0', 'parFMIN', 'parALPHA'):
            self.parameter_bounds.pop(name)
    if not self.lag_uh:
        self.distr_parameter_bounds.pop('route_tau')

    # Always derive parameter names/counts, with or without a config.
    self._set_parameters()
|
|
151
|
+
|
|
152
|
+
def _init_states(self, ngrid: int) -> tuple[torch.Tensor, ...]:
    """Create cold-start storages for every entry in ``self.state_names``.

    Each storage is a separate ``(ngrid, nmul)`` float32 tensor filled with
    a small positive value (0.001), not zero.

    Parameters
    ----------
    ngrid
        Number of basins/grid cells.

    Returns
    -------
    tuple[torch.Tensor, ...]
        One freshly allocated tensor per state name, in ``state_names`` order.
    """
    shape = (ngrid, self.nmul)
    return tuple(
        torch.full(shape, 0.001, dtype=torch.float32, device=self.device)
        for _ in self.state_names
    )
|
|
161
|
+
|
|
162
|
+
def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
    """Return the most recently cached model states.

    Returns
    -------
    tuple[torch.Tensor, ...] or None
        The cached ``(SNOWPACK, MELTWATER, SM, SUZ, SLZ)`` tuple, or ``None``
        if nothing has been cached yet.
    """
    cached = self._states_cache
    return cached
|
|
171
|
+
|
|
172
|
+
def load_states(
    self,
    states: tuple[torch.Tensor, ...],
) -> None:
    """Load internal model states and move them to the model device.

    Each state is detached from any autograd graph and cast to ``float32``.

    Parameters
    ----------
    states
        A tuple with one tensor per entry of ``self.state_names``
        (SNOWPACK, MELTWATER, SM, SUZ, SLZ).

    Raises
    ------
    ValueError
        If ``states`` is not a tuple of the expected length, or if any
        element is not a tensor.
    """
    nstates = len(self.state_names)
    # Validate the container before iterating it, so that non-iterable input
    # (e.g. ``None``) raises the documented ValueError instead of TypeError.
    if not (isinstance(states, tuple) and len(states) == nstates):
        raise ValueError(f"`states` must be a tuple of {nstates} tensors.")
    if not all(isinstance(state, torch.Tensor) for state in states):
        raise ValueError("Each element in `states` must be a tensor.")

    self.states = tuple(
        s.detach().to(self.device, dtype=torch.float32) for s in states
    )
|
|
193
|
+
|
|
194
|
+
def _set_parameters(self) -> None:
    """Derive parameter-name views and learnable-parameter counts.

    Sets ``phy_param_names``, ``routing_param_names`` and the four
    ``learnable_param_count*`` attributes from the bounds dictionaries,
    ``dynamic_params`` and ``nmul``.
    """
    self.phy_param_names = self.parameter_bounds.keys()
    self.routing_param_names = (
        self.routing_parameter_bounds.keys() if self.routing else []
    )

    n_dynamic = len(self.dynamic_params)
    n_static = len(self.phy_param_names) - n_dynamic

    # Group 1: dynamic physical params; group 2: static physical params plus
    # lumped routing params; group 3: distributed routing params.
    self.learnable_param_count1 = n_dynamic * self.nmul
    self.learnable_param_count2 = n_static * self.nmul + len(
        self.routing_param_names
    )
    self.learnable_param_count3 = len(self.distr_parameter_bounds)
    self.learnable_param_count = sum(
        (
            self.learnable_param_count1,
            self.learnable_param_count2,
            self.learnable_param_count3,
        )
    )
|
|
212
|
+
|
|
213
|
+
def _unpack_parameters(
    self,
    parameters: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Split raw NN outputs into physical, routing and distributed parts.

    Parameters
    ----------
    parameters
        Indexable collection of three tensors:

        - ``parameters[0]``: dynamic physical parameters. Its first two axes
          are kept, and the remainder must be viewable as
          ``(n_dynamic, nmul)``.
        - ``parameters[1]``: static physical parameters in the first
          ``n_static * nmul`` columns, followed by lumped routing parameters
          when ``self.routing`` is set.
        - ``parameters[2]``: distributed-routing parameters (passed through).

    Returns
    -------
    tuple
        ``(phy_dy_params, phy_static_params, routing_params, distr_params)``:

        - ``phy_dy_params`` is a view of shape
          ``(parameters[0].shape[0], parameters[0].shape[1], n_dynamic, nmul)``.
        - ``phy_static_params`` is a view of shape
          ``(parameters[1].shape[0], n_static, nmul)``.
        - ``routing_params`` holds the remaining columns of ``parameters[1]``,
          or is ``None`` when ``self.routing`` is false.
        - ``distr_params`` is ``parameters[2]`` unchanged.
    """
    dy_param_count = len(self.dynamic_params)
    static_count = len(self.parameter_bounds) - dy_param_count
    static_width = static_count * self.nmul

    # Physical dynamic parameters
    dy_raw = parameters[0]
    phy_dy_params = dy_raw.view(
        dy_raw.shape[0],
        dy_raw.shape[1],
        dy_param_count,
        self.nmul,
    )

    # Physical static parameters: leading columns of parameters[1].
    static_raw = parameters[1]
    phy_static_params = static_raw[:, :static_width].view(
        static_raw.shape[0],
        static_count,
        self.nmul,
    )

    # Lumped routing parameters: trailing columns of parameters[1].
    routing_params = static_raw[:, static_width:] if self.routing else None

    # Distributed routing parameters
    distr_params = parameters[2]

    return (phy_dy_params, phy_static_params, routing_params, distr_params)
|
|
257
|
+
|
|
258
|
+
def _descale_phy_dy_parameters(
    self,
    phy_dy_params: torch.Tensor,
    dy_list: list,
) -> dict[str, torch.Tensor]:
    """Map normalized dynamic physical parameters onto their bounds.

    For each parameter, a Bernoulli mask with probability ``self.dy_drop`` is
    drawn once per basin column. Where the mask is 1, the time series is
    replaced by the parameter's value at the last timestep, repeated over all
    timesteps. The result is then rescaled with ``change_param_range``.

    Parameters
    ----------
    phy_dy_params
        Normalized dynamic parameters with axes
        ``(time, basin, parameter, nmul)``.
    dy_list
        Names of the dynamic parameters, ordered like the third axis of
        ``phy_dy_params``.

    Returns
    -------
    dict[str, torch.Tensor]
        Descaled parameters keyed by name.
    """
    # TODO: Fix; if dynamic parameters are not entered in config in the same
    # order as in the HBV params list, a descaling mismatch will occur.
    n_time, n_basin = phy_dy_params.shape[0], phy_dy_params.shape[1]
    # Drop probabilities are built on CPU; the sampled mask is moved to the
    # model device afterwards.
    drop_prob = torch.ones([1, n_basin, 1]) * self.dy_drop

    descaled = {}
    for idx, pname in enumerate(dy_list):
        series = phy_dy_params[:, :, idx, :]
        last_value = phy_dy_params[-1, :, idx, :].unsqueeze(0)
        frozen = last_value.repeat([n_time, 1, 1])
        keep_static = torch.bernoulli(drop_prob).detach_().to(self.device)
        blended = series * (1 - keep_static) + frozen * keep_static
        descaled[pname] = change_param_range(
            param=blended,
            bounds=self.parameter_bounds[pname],
        )
    return descaled
|
|
296
|
+
|
|
297
|
+
def _descale_phy_stat_parameters(
    self,
    phy_stat_params: torch.Tensor,
    stat_list: list,
) -> dict[str, torch.Tensor]:
    """Descale static physical parameters.

    Parameters
    ----------
    phy_stat_params
        Normalized static parameters with axes ``(basin, parameter, nmul)``.
    stat_list
        Names of the static parameters, ordered like the second axis of
        ``phy_stat_params``.

    Returns
    -------
    dict[str, torch.Tensor]
        Descaled static parameters keyed by name, each of shape
        ``(basin, nmul)``.
    """
    bounds = self.parameter_bounds
    return {
        name: change_param_range(
            param=phy_stat_params[:, col, :],
            bounds=bounds[name],
        )
        for col, name in enumerate(stat_list)
    }
|
|
323
|
+
|
|
324
|
+
def _descale_route_parameters(
    self,
    routing_params: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Descale lumped routing parameters.

    Parameters
    ----------
    routing_params
        Normalized routing parameters. Column ``i`` corresponds to the
        ``i``-th key of ``self.routing_parameter_bounds``.

    Returns
    -------
    dict[str, torch.Tensor]
        Descaled routing parameters keyed by name (``route_a`` and
        ``route_b`` with the default bounds).
    """
    return {
        name: change_param_range(param=routing_params[:, col], bounds=bounds)
        for col, (name, bounds) in enumerate(self.routing_parameter_bounds.items())
    }
|
|
349
|
+
|
|
350
|
+
def _descale_distr_parameters(
    self,
    distr_params: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Map normalized distributed-routing parameters onto their bounds.

    Parameters
    ----------
    distr_params
        Normalized parameters. Column ``i`` corresponds to the ``i``-th key
        of ``self.distr_parameter_bounds``.

    Returns
    -------
    dict[str, torch.Tensor]
        Descaled distributed-routing parameters keyed by name.
    """
    descaled = {}
    for col, (pname, pbounds) in enumerate(self.distr_parameter_bounds.items()):
        descaled[pname] = change_param_range(
            param=distr_params[:, col],
            bounds=pbounds,
        )
    return descaled
|
|
375
|
+
|
|
376
|
+
def forward(
    self,
    x_dict: dict[str, torch.Tensor],
    parameters: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Forward pass.

    Parameters
    ----------
    x_dict
        Dictionary of inputs. Keys read here: ``x_phy`` (forcings),
        ``ac_all``, ``elev_all``, ``outlet_topo``, ``areas`` and, optionally,
        ``muwts``.
    parameters
        Unprocessed, learned parameters from a neural network. This is an
        indexable collection of three tensors; see ``_unpack_parameters``.

    Returns
    -------
    dict[str, torch.Tensor]
        Flux dictionary produced by ``_PBM``. It is empty when
        ``self.initialize`` is set. The full state time series is stored and
        can be retrieved with ``get_states()``.
    """
    # Unpack input data.
    x = x_dict['x_phy']
    Ac = x_dict['ac_all'].unsqueeze(-1).repeat(1, self.nmul)
    Elevation = x_dict['elev_all'].unsqueeze(-1).repeat(1, self.nmul)
    outlet_topo = x_dict['outlet_topo']
    areas = x_dict['areas']
    self.muwts = x_dict.get('muwts', None)
    ngrid = x.shape[1]

    # Unpack parameters.
    phy_dy_params, phy_static_params, routing_params, distr_params = (
        self._unpack_parameters(parameters)
    )

    if self.routing:
        self.routing_param_dict = self._descale_route_parameters(routing_params)
    phy_dy_params_dict = self._descale_phy_dy_parameters(
        phy_dy_params,
        dy_list=self.dynamic_params,
    )
    phy_static_params_dict = self._descale_phy_stat_parameters(
        phy_static_params,
        stat_list=[
            param
            for param in self.phy_param_names
            if param not in self.dynamic_params
        ],
    )

    # Start from cached end-of-run states only when caching is enabled and
    # states are available; otherwise cold-start.
    if (not self.states) or (not self.cache_states):
        current_states = self._init_states(ngrid)
    else:
        current_states = self.states

    distr_params_dict = self._descale_distr_parameters(distr_params)

    fluxes, states = self._PBM(
        x,
        Ac,
        Elevation,
        current_states,
        phy_dy_params_dict,
        phy_static_params_dict,
        outlet_topo,
        areas,
        distr_params_dict,
    )

    # State caching: store under ``_states_cache``, the attribute that
    # ``__init__`` initializes and ``get_states`` returns.
    self._states_cache = states

    if self.cache_states:
        # Keep only the last timestep as the warm start for the next call.
        self.states = tuple(s[-1].detach() for s in self._states_cache)

    return fluxes
|
|
450
|
+
|
|
451
|
+
def _PBM(
    self,
    forcing: torch.Tensor,
    Ac: torch.Tensor,
    Elevation: torch.Tensor,
    states: tuple,
    phy_dy_params_dict: dict,
    phy_static_params_dict: dict,
    outlet_topo: torch.Tensor,
    areas: torch.Tensor,
    distr_params_dict: dict,
) -> tuple[dict[str, torch.Tensor], tuple]:
    """Run the process-based model (PBM) over every timestep.

    Forcings are divided by ``dt`` so internal fluxes are rates per day.
    Reported fluxes are multiplied back by ``dt`` and are therefore amounts
    per hourly step (mm/hour, assuming forcings are mm per step).

    Parameters
    ----------
    forcing
        Input forcing data indexed ``[time, basin, variable]``. The variable
        axis follows ``self.variables`` (``prcp``, ``tmean``, ``pet``).
    Ac
        Per-basin ``ac_all`` attribute, repeated across ``nmul`` columns by
        ``forward``.
    Elevation
        Per-basin elevation, repeated across ``nmul`` columns by ``forward``.
    states
        Initial ``(SNOWPACK, MELTWATER, SM, SUZ, SLZ)`` storages.
    phy_dy_params_dict
        Descaled dynamic parameters, each indexed ``[time, basin, nmul]``.
    phy_static_params_dict
        Descaled static parameters, each indexed ``[basin, nmul]``.
    outlet_topo
        Gage-to-unit topology, forwarded to ``distr_routing``.
    areas
        Unit areas, forwarded to ``distr_routing``.
    distr_params_dict
        Descaled distributed-routing parameters, forwarded to
        ``distr_routing``.

    Returns
    -------
    tuple[dict, tuple]
        ``(flux_dict, states)``. ``flux_dict`` is empty when
        ``self.initialize`` is set. ``states`` holds the full
        ``[time, basin, nmul]`` series of each storage.
    """
    dt = self.dt
    SNOWPACK, MELTWATER, SM, SUZ, SLZ = states

    # Forcings (precipitation and PET converted from per-step to per-day).
    P = forcing[:, :, self.variables.index('prcp')] / dt  # Precipitation
    T = forcing[:, :, self.variables.index('tmean')]  # Mean air temp
    PET = forcing[:, :, self.variables.index('pet')] / dt  # Potential ET
    nsteps, ngrid = P.shape

    # Expand dims to accommodate nmul models.
    Pm = P.unsqueeze(2).repeat(1, 1, self.nmul)
    Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
    PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

    # Apply correction factor to precipitation
    # P = parPCORR.repeat(nsteps, 1) * P

    # Initialize time series of model variables in shape [time, basins, nmul].
    Qsimmu = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.001
    Q0_sim = (
        torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
    )
    Q1_sim = (
        torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
    )
    Q2_sim = (
        torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
    )

    AET = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    recharge_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    excs_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    evapfactor_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    tosoil_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    PERC_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    SWE_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    capillary_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)

    # NOTE: new for MTS -- Save model states for all time steps.
    SNOWPACK_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    MELTWATER_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    SM_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    SUZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
    SLZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)

    param_dict = {}
    for t in range(nsteps):
        # NOTE: new for MTS -- numerical guardrail for long-sequence running
        SNOWPACK = torch.clamp(SNOWPACK, min=0.0)
        MELTWATER = torch.clamp(MELTWATER, min=0.0)
        SM = torch.clamp(SM, min=self.nearzero)
        SUZ = torch.clamp(SUZ, min=self.nearzero)
        SLZ = torch.clamp(SLZ, min=self.nearzero)
        # ------------------------------------------------------------------

        # Get dynamic parameter values per timestep.
        for key in phy_dy_params_dict.keys():
            param_dict[key] = phy_dy_params_dict[key][t, :, :]
        for key in phy_static_params_dict.keys():
            param_dict[key] = phy_static_params_dict[key][:, :]

        # Separate precipitation into liquid and solid components.
        PRECIP = Pm[t, :, :]
        # Basins with Elevation >= 2000 use a fixed threshold temperature of
        # 4.0 instead of the learned parTT.
        # NOTE(review): elevation units and the origin of 2000 / 4.0 are not
        # visible here -- confirm.
        parTT_new = (Elevation >= 2000).type(torch.float32) * 4.0 + (
            Elevation < 2000
        ).type(torch.float32) * param_dict['parTT']
        RAIN = torch.mul(PRECIP, (Tm[t, :, :] >= parTT_new).type(torch.float32))
        SNOW = torch.mul(PRECIP, (Tm[t, :, :] < parTT_new).type(torch.float32))

        # Snow -------------------------------
        # Storages are depths; rates are multiplied by dt before they are
        # added to or subtracted from storages.
        SNOWPACK = SNOWPACK + SNOW * dt
        melt = param_dict['parCFMAX'] * (Tm[t, :, :] - parTT_new)
        # melt[melt < 0.0] = 0.0
        melt = torch.clamp(melt, min=0.0)
        # melt[melt > SNOWPACK] = SNOWPACK[melt > SNOWPACK]
        melt = torch.min(melt * dt, SNOWPACK)  # per-step amount, capped by snowpack
        MELTWATER = MELTWATER + melt
        SNOWPACK = SNOWPACK - melt
        refreezing = (
            param_dict['parCFR']
            * param_dict['parCFMAX']
            * (parTT_new - Tm[t, :, :])
        )
        # refreezing[refreezing < 0.0] = 0.0
        # refreezing[refreezing > MELTWATER] = MELTWATER[refreezing > MELTWATER]
        refreezing = torch.clamp(refreezing, min=0.0)
        refreezing = torch.min(refreezing * dt, MELTWATER)
        SNOWPACK = SNOWPACK + refreezing
        MELTWATER = MELTWATER - refreezing
        # Meltwater above the snowpack's holding capacity (parCWH * SNOWPACK)
        # is released to the soil as a rate.
        tosoil = (MELTWATER - (param_dict['parCWH'] * SNOWPACK)) / dt
        tosoil = torch.clamp(tosoil, min=0.0)
        MELTWATER = MELTWATER - tosoil * dt

        # NOTE: new for MTS -- Hortonian Infiltration Excess
        if self.infiltration:
            # Hortonian infiltration excess: infiltration capacity as a function of wetness
            W = RAIN + tosoil
            s = torch.clamp(
                SM / param_dict['parFC'], 0.0, 1.0 - 0.01
            )  # relative wetness, safe guard for pow and bf/fp16
            # Wet-limit capacity is a fraction (parFMIN) of the dry capacity.
            parFMIN = param_dict['parFMIN'] * param_dict['parF0']
            # Capacity decays from parF0 (dry) toward parFMIN (wet) as
            # (1 - s) ** parALPHA.
            with torch.amp.autocast(
                device_type='cuda', enabled=False
            ):  # torch.pow not stable with bf/fp16 when base ~ 0
                fcap = parFMIN + (param_dict['parF0'] - parFMIN) * torch.pow(
                    1.0 - s, param_dict['parALPHA']
                )
            infiltration = torch.minimum(W, fcap)  # goes into soil
            IE = torch.clamp(W - fcap, min=0.0)  # Hortonian excess

            # Soil and evaporation using Infiltration
            soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
            soil_wetness = torch.clamp(soil_wetness, 0.0, 1.0)
            recharge = infiltration * soil_wetness
            SM = SM + (infiltration - recharge) * dt
        else:
            soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
            soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
            recharge = (RAIN + tosoil) * soil_wetness
            SM = SM + (RAIN + tosoil - recharge) * dt
        # ------------------------------------------------------

        # Soil moisture above field capacity spills to the upper zone.
        excess = (SM - param_dict['parFC']) / dt
        excess = torch.clamp(excess, min=0.0)
        SM = SM - excess * dt
        # NOTE: Different from HBV 1.0. Add static/dynamicET shape parameter parBETAET.
        evapfactor = (
            SM / (param_dict['parLP'] * param_dict['parFC'])
        ) ** param_dict['parBETAET']
        evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
        ETact = PETm[t, :, :] * evapfactor
        ETact = torch.min(SM, ETact * dt) / dt  # cannot evaporate more than SM
        SM = torch.clamp(SM - ETact * dt, min=self.nearzero)

        # Capillary rise (HBV 1.1p mod) -------------------------------
        # Upward flux from SLZ to SM, scaled by the soil deficit and capped
        # by the available SLZ storage.
        capillary = (
            torch.min(
                SLZ,
                param_dict['parC']
                * SLZ
                * (1.0 - torch.clamp(SM / param_dict['parFC'], max=1.0))
                * dt,
            )
            / dt
        )

        SM = torch.clamp(SM + capillary * dt, min=self.nearzero)
        SLZ = torch.clamp(SLZ - capillary * dt, min=self.nearzero)

        # Groundwater boxes -------------------------------
        SUZ = SUZ + (recharge + excess) * dt
        PERC = torch.min(SUZ, param_dict['parPERC'] * dt) / dt
        SUZ = SUZ - PERC * dt
        # Fast flow only from storage above the parUZL threshold.
        Q0 = param_dict['parK0'] * torch.clamp(SUZ - param_dict['parUZL'], min=0.0)
        SUZ = SUZ - Q0 * dt
        Q1 = param_dict['parK1'] * SUZ
        SUZ = SUZ - Q1 * dt
        SLZ = SLZ + PERC * dt

        # Extra SLZ flux scaled by parRT. For Ac < 2500 it is a linear ramp of
        # (Ac - parAC) / 1000 clipped to [-1, 1]; for Ac >= 2500 it is an
        # exponential decay in Ac.
        # NOTE(review): the physical meaning of Ac and of 1000/2500/50 is not
        # visible here -- confirm.
        LF = torch.clamp(
            (Ac - param_dict['parAC']) / 1000, min=-1, max=1
        ) * param_dict['parRT'] * (Ac < 2500) + torch.exp(
            torch.clamp(-(Ac - 2500) / 50, min=-10.0, max=0.0)
        ) * param_dict['parRT'] * (Ac >= 2500)
        SLZ = torch.clamp(SLZ + LF * dt, min=0.0)

        Q2 = param_dict['parK2'] * SLZ
        SLZ = SLZ - Q2 * dt

        # NOTE: new for MTS -- Add Hortonian Infiltration Excess
        if self.infiltration:
            Qsimmu[t, :, :] = Q0 + Q1 + Q2 + IE
        else:
            Qsimmu[t, :, :] = Q0 + Q1 + Q2
        # ------------------------------------------------------

        Q0_sim[t, :, :] = Q0
        Q1_sim[t, :, :] = Q1
        Q2_sim[t, :, :] = Q2
        AET[t, :, :] = ETact
        SWE_sim[t, :, :] = SNOWPACK
        capillary_sim[t, :, :] = capillary

        recharge_sim[t, :, :] = recharge
        excs_sim[t, :, :] = excess
        evapfactor_sim[t, :, :] = evapfactor
        tosoil_sim[t, :, :] = tosoil
        PERC_sim[t, :, :] = PERC

        # NOTE: new for MTS -- Save model states for all time steps.
        SNOWPACK_sim[t, :, :] = SNOWPACK
        MELTWATER_sim[t, :, :] = MELTWATER
        SM_sim[t, :, :] = SM
        SUZ_sim[t, :, :] = SUZ
        SLZ_sim[t, :, :] = SLZ

    # Get the average or weighted average using learned weights.
    if self.muwts is None:
        Qsimavg = Qsimmu.mean(-1)
    else:
        Qsimavg = (Qsimmu * self.muwts).sum(-1)

    # Run routing
    if self.routing:
        # Routing for all components or just the average.
        if self.comprout:
            # All components; reshape to [time, gages * num models]
            Qsim = Qsimmu.view(nsteps, ngrid * self.nmul)
        else:
            # Average, then do routing.
            Qsim = Qsimavg

        UH = uh_gamma(
            self.routing_param_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
            self.routing_param_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
            lenF=self.lenF,
        )
        rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0])  # [gages,vars,time]
        UH = UH.permute([1, 2, 0])  # [gages,vars,time]
        Qsrout = uh_conv(rf, UH).permute([2, 0, 1])

        # Routing individually for Q0, Q1, and Q2, all w/ dims [gages,vars,time].
        # rf_Q0 = Q0_sim.mean(-1, keepdim=True).permute([1, 2, 0])
        # Q0_rout = uh_conv(rf_Q0, UH).permute([2, 0, 1])
        # rf_Q1 = Q1_sim.mean(-1, keepdim=True).permute([1, 2, 0])
        # Q1_rout = uh_conv(rf_Q1, UH).permute([2, 0, 1])
        # rf_Q2 = Q2_sim.mean(-1, keepdim=True).permute([1, 2, 0])
        # Q2_rout = uh_conv(rf_Q2, UH).permute([2, 0, 1])

        if self.comprout:
            # Qs is now shape [time, [gages*num models], vars]
            Qstemp = Qsrout.view(nsteps, ngrid, self.nmul)
            if self.muwts is None:
                Qs = Qstemp.mean(-1, keepdim=True)
            else:
                Qs = (Qstemp * self.muwts).sum(-1, keepdim=True)
        else:
            Qs = Qsrout

    else:
        # No routing, only output the average of all model sims.
        Qs = torch.unsqueeze(Qsimavg, -1)
        # Q0_rout = Q1_rout = Q2_rout = None

    states = (SNOWPACK_sim, MELTWATER_sim, SM_sim, SUZ_sim, SLZ_sim)

    if self.initialize:
        # If initialize is True, only return warmed-up storages.
        return {}, states
    else:
        # Baseflow index (BFI) calculation
        # BFI_sim = (
        #     100
        #     * (torch.sum(Q2_rout, dim=0) / (torch.sum(Qs, dim=0) + self.nearzero))[
        #         :, 0
        #     ]
        # )

        # Return all sim results (rates converted back to per-step amounts).
        flux_dict = {
            'Qs': Qs * dt,  # Routed Streamflow for units
            # 'srflow': Q0_rout * dt,  # Routed surface runoff
            # 'ssflow': Q1_rout * dt,  # Routed subsurface flow
            # 'gwflow': Q2_rout * dt,  # Routed groundwater flow
            # 'AET_hydro': AET.mean(-1, keepdim=True) * dt,  # Actual ET
            # 'PET_hydro': PETm.mean(-1, keepdim=True) * dt,  # Potential ET
            # 'SWE': SWE_sim.mean(-1, keepdim=True),  # Snow water equivalent
            # 'streamflow_no_rout': Qsim.unsqueeze(dim=2) * dt,  # Streamflow
            # 'srflow_no_rout': Q0_sim.mean(-1, keepdim=True) * dt,  # Surface runoff
            # 'ssflow_no_rout': Q1_sim.mean(-1, keepdim=True) * dt,  # Subsurface flow
            # 'gwflow_no_rout': Q2_sim.mean(-1, keepdim=True) * dt,  # Groundwater flow
            # 'recharge': recharge_sim.mean(-1, keepdim=True) * dt,  # Recharge
            # 'excs': excs_sim.mean(-1, keepdim=True) * dt,  # Excess stored water
            # 'evapfactor': evapfactor_sim.mean(-1, keepdim=True),  # Evaporation factor
            # 'tosoil': tosoil_sim.mean(-1, keepdim=True) * dt,  # Infiltration
            # 'percolation': PERC_sim.mean(-1, keepdim=True) * dt,  # Percolation
            # 'capillary': capillary_sim.mean(-1, keepdim=True) * dt,  # Capillary rise
            # 'BFI': BFI_sim * dt,  # Baseflow index
        }

        # Trim warm-up steps from time-indexed fluxes. This runs before the
        # distributed-routing block, so 'streamflow' below is not trimmed.
        if not self.warm_up_states:
            for key in flux_dict.keys():
                if key != 'BFI':
                    flux_dict[key] = flux_dict[key][self.pred_cutoff :, :, :]

        if self.use_distr_routing:
            # 1. Get current Raw Runoff

            # 2. Manage Buffer
            if self.cache_states:
                # Buffered runoff is detached, so no gradient flows through
                # 'streamflow' in cache_states mode.
                self._qs_buffer.append((Qs * dt).detach())
                # Keep buffer reasonable size (at least lenF + tau).
                # NOTE: the limit counts forward calls, not timesteps.
                if len(self._qs_buffer) > self._max_history:
                    self._qs_buffer.pop(0)

                # Create history tensor for convolution
                # [History+1, Units, 1]
                qs_history = torch.cat(self._qs_buffer, dim=0)
            else:
                # If not caching (e.g. training), use what we have
                qs_history = Qs * dt

            # Distributed routing for streamflow at gages
            distr_out_dict = self.distr_routing(
                Qs=qs_history,
                distr_params_dict=distr_params_dict,
                outlet_topo=outlet_topo,
                areas=areas,
            )
            # NOTE(review): redundant -- overwritten by the branch just below.
            flux_dict['streamflow'] = distr_out_dict['Qs_rout']

            if self.cache_states:
                # If we passed in history [T], we get out [T]. We only want index -1.
                flux_dict['streamflow'] = distr_out_dict['Qs_rout'][-1:]
            else:
                flux_dict['streamflow'] = distr_out_dict['Qs_rout']

        return flux_dict, states
|
|
799
|
+
|
|
800
|
+
def distr_routing(
    self,
    Qs: torch.Tensor,
    distr_params_dict: dict,
    outlet_topo: torch.Tensor,
    areas: torch.Tensor,
):
    """Route unit runoff to gages through per-(gage, unit) unit hydrographs.

    Each entry of ``outlet_topo`` equal to 1 defines a (gage, unit) pair. For
    every pair, the unit's area-weighted runoff is convolved (``uh_conv``)
    with a gamma unit hydrograph (``uh_gamma``). When ``self.lag_uh`` is set,
    that hydrograph is first delayed by a fractional lag. Routed pair series
    are summed per gage and divided by that gage's upstream area.

    :param Qs: runoff per unit, (nsteps, n_units, 1).
    :param distr_params_dict: routing parameters with keys 'route_a' and
        'route_b' (plus 'route_tau' when ``self.lag_uh``), one entry per pair.
        They are presumably in the row-major order of ``outlet_topo == 1``;
        TODO confirm against the caller.
    :param outlet_topo: gage-to-unit connectivity, (n_gages, n_units).
    :param areas: unit areas, (n_units,).
    :return: dict with key 'Qs_rout' holding the area-normalized routed
        runoff at each gage, (nsteps, n_gages, 1).
    """
    n_steps = Qs.size(0)
    target = areas.device

    # One row per (gage, unit) pair: column 0 is the gage, column 1 the unit.
    links = (outlet_topo == 1).nonzero(as_tuple=False)
    gage_of_pair = links[:, 0].to(target).long()
    unit_of_pair = links[:, 1].to(target).long()

    # Area-weighted runoff of each pair's unit, laid out for uh_conv as
    # (n_pairs, 1, nsteps).
    weighted = Qs * areas[None, :, None]
    pair_runoff = weighted[:, unit_of_pair, :].permute([1, 2, 0]).contiguous()

    def _tiled(key):
        # Repeat a per-pair parameter along a leading time axis for uh_gamma.
        return distr_params_dict[key].repeat(n_steps, 1).unsqueeze(-1)

    kernel = uh_gamma(_tiled('route_a'), _tiled('route_b'), lenF=self.lenF)
    if self.lag_uh:
        # Fractional delay of the unit hydrograph along its time axis (dim 0).
        kernel = self._frac_shift1d(kernel, distr_params_dict['route_tau'])
    # Kernel shape is (n_pairs, 1, uh_len). uh_len is whatever uh_gamma
    # returned for lenF=self.lenF, which is not necessarily nsteps.
    kernel = kernel.permute([1, 2, 0]).contiguous()

    routed = uh_conv(pair_runoff, kernel).squeeze(1).contiguous()  # (n_pairs, nsteps)
    n_out = routed.shape[1]

    # Sum routed pair series into their gage rows.
    totals = torch.zeros(
        int(outlet_topo.shape[0]), n_out, device=routed.device, dtype=routed.dtype
    )
    totals.scatter_add_(0, gage_of_pair.view(-1, 1).expand(-1, n_out), routed)

    # Upstream area per gage. It is weighted by the raw outlet_topo entries,
    # so it matches the pair selection above only when outlet_topo is 0/1.
    upstream_area = (outlet_topo * areas[None, :]).sum(dim=1, keepdim=True)
    totals = totals / upstream_area.clamp(min=1e-6)

    return {'Qs_rout': totals.T.unsqueeze(-1)}  # (nsteps, n_gages, 1)
|
|
856
|
+
|
|
857
|
+
@staticmethod
def _frac_shift1d(w, tau):
    """Differentiably delay ``w`` along time by a (fractional) number of steps.

    Returns ``y[t] = w(t - tau)`` by linearly mixing the integer shifts by
    ``k`` and ``k + 1`` steps. With ``tau = k + f``, where ``k = floor(tau)``
    and ``0 <= f < 1``::

        y[t] = (1 - f) * w[t - k] + f * w[t - (k + 1)]

    Source indices outside ``[0, T - 1]`` contribute exactly zero (true zero
    padding). This holds even when ``w`` contains non-finite values at the
    edges. Mass shifted past the end of the window is dropped, and ``y`` is
    NOT renormalized.

    Parameters
    ----------
    w : torch.Tensor
        Series to shift, shape [T, B, V] with time on dim 0.
    tau : torch.Tensor or float
        Shift in time steps (>= 0 recommended). Either exactly ``B * V``
        elements, reshaped to [B, V] as in the original contract, or anything
        broadcastable to [B, V], e.g. a Python scalar or a [B, 1] tensor.
        It is moved to ``w``'s device and dtype. Gradients flow through ``f``.

    Returns
    -------
    torch.Tensor
        Shifted series, shape [T, B, V], same dtype and device as ``w``.
    """
    T, B, V = w.shape
    device, dtype = w.device, w.dtype

    # Bring tau onto w's device/dtype. as_tensor keeps the autograd graph of
    # a tensor input and also accepts plain Python numbers.
    tau = torch.as_tensor(tau, dtype=dtype, device=device)
    if tau.numel() == B * V:
        # Unlike view, reshape also accepts non-contiguous inputs.
        tau = tau.reshape(1, B, V)
    else:
        tau = torch.broadcast_to(tau, (B, V)).unsqueeze(0)

    # Decompose tau = k + f.
    k = torch.floor(tau)  # [1, B, V], integer-valued
    f = tau - k  # [1, B, V], fractional part in [0, 1)

    # Do the index arithmetic in int64. This is exact for any T, whereas an
    # arange in a low-precision dtype (e.g. bfloat16) cannot represent every
    # step index. Shifts beyond [-(T + 1), T] already move every sample out
    # of the window, so clamping k there keeps the int conversion safe
    # without changing the result.
    k_idx = k.clamp(-(T + 1), T).long()
    t = torch.arange(T, device=device).view(T, 1, 1)  # [T, 1, 1]
    i0 = t - k_idx  # source index for the k-step shift, [T, B, V]
    i1 = i0 - 1  # source index for the (k+1)-step shift

    def _take(idx):
        # Gather w at idx along time; out-of-range indices give exactly zero.
        valid = (idx >= 0) & (idx <= T - 1)
        vals = torch.gather(w, 0, idx.clamp(0, T - 1))
        return torch.where(valid, vals, torch.zeros_like(vals))

    # Linear blend: (1 - f) * k-shift + f * (k + 1)-shift.
    y = (1.0 - f) * _take(i0) + f * _take(i1)

    # Renormalizing to unit mass per (B, V), i.e.
    # y / y.sum(0).clamp_min(1e-6), is deliberately not applied because it
    # may cause instability.
    return y  # [T, B, V]
|