hydrodl2 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydrodl2/__init__.py +122 -0
- hydrodl2/_version.py +34 -0
- hydrodl2/api/__init__.py +3 -0
- hydrodl2/api/methods.py +144 -0
- hydrodl2/core/calc/__init__.py +11 -0
- hydrodl2/core/calc/batch_jacobian.pye +501 -0
- hydrodl2/core/calc/fdj.py +92 -0
- hydrodl2/core/calc/uh_routing.py +105 -0
- hydrodl2/core/calc/utils.py +59 -0
- hydrodl2/core/utils/__init__.py +7 -0
- hydrodl2/core/utils/clean_temp.sh +8 -0
- hydrodl2/core/utils/utils.py +63 -0
- hydrodl2/models/hbv/hbv.py +596 -0
- hydrodl2/models/hbv/hbv_1_1p.py +608 -0
- hydrodl2/models/hbv/hbv_2.py +670 -0
- hydrodl2/models/hbv/hbv_2_hourly.py +897 -0
- hydrodl2/models/hbv/hbv_2_mts.py +377 -0
- hydrodl2/models/hbv/hbv_adj.py +712 -0
- hydrodl2/modules/__init__.py +2 -0
- hydrodl2/modules/data_assimilation/variational_prcp_da.py +1 -0
- hydrodl2-1.3.0.dist-info/METADATA +184 -0
- hydrodl2-1.3.0.dist-info/RECORD +24 -0
- hydrodl2-1.3.0.dist-info/WHEEL +4 -0
- hydrodl2-1.3.0.dist-info/licenses/LICENSE +31 -0
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
|
|
6
|
+
from hydrodl2.models.hbv.hbv_2 import Hbv_2
|
|
7
|
+
from hydrodl2.models.hbv.hbv_2_hourly import Hbv_2_hourly
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Hbv_2_mts(torch.nn.Module):
|
|
11
|
+
"""HBV 2.0, multi timescale, distributed UH.
|
|
12
|
+
|
|
13
|
+
Multi-component, multi-scale, differentiable PyTorch HBV model with rainfall
|
|
14
|
+
runoff simulation on unit basins.
|
|
15
|
+
|
|
16
|
+
Authors
|
|
17
|
+
-------
|
|
18
|
+
- Wencong Yang
|
|
19
|
+
- (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
|
|
20
|
+
- (HBV-light Version 2) Seibert, 2005
|
|
21
|
+
(https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).
|
|
22
|
+
|
|
23
|
+
Parameters
|
|
24
|
+
----------
|
|
25
|
+
config
|
|
26
|
+
Configuration dictionary.
|
|
27
|
+
device
|
|
28
|
+
Device to run the model on.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
low_freq_config: Optional[dict[str, Any]] = None,
|
|
34
|
+
high_freq_config: Optional[dict[str, Any]] = None,
|
|
35
|
+
device: Optional[torch.device] = None,
|
|
36
|
+
) -> None:
|
|
37
|
+
super().__init__()
|
|
38
|
+
self.device = device if device is not None else torch.device('cpu')
|
|
39
|
+
self.dtype = torch.float32
|
|
40
|
+
self.low_freq_model = Hbv_2(low_freq_config, device=device)
|
|
41
|
+
self.low_freq_model.initialize = True
|
|
42
|
+
self.high_freq_model = Hbv_2_hourly(high_freq_config, device=device)
|
|
43
|
+
self._state_cache = [None, None]
|
|
44
|
+
self.states = (None, None)
|
|
45
|
+
self.load_from_cache = False
|
|
46
|
+
self.use_from_cache = False
|
|
47
|
+
|
|
48
|
+
# # learnable transfer
|
|
49
|
+
# self.state_transfer_model = torch.nn.ModuleDict(
|
|
50
|
+
# {
|
|
51
|
+
# name: torch.nn.Sequential(
|
|
52
|
+
# torch.nn.Linear(
|
|
53
|
+
# self.low_freq_model.nmul, self.high_freq_model.nmul
|
|
54
|
+
# ),
|
|
55
|
+
# torch.nn.ReLU(),
|
|
56
|
+
# )
|
|
57
|
+
# for name in self.high_freq_model.state_names
|
|
58
|
+
# }
|
|
59
|
+
# )
|
|
60
|
+
# Identity state transfer
|
|
61
|
+
self.state_transfer_model = torch.nn.ModuleDict(
|
|
62
|
+
{name: torch.nn.Identity() for name in self.high_freq_model.state_names}
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
self.train_spatial_chunk_size = high_freq_config['train_spatial_chunk_size']
|
|
66
|
+
self.simulate_spatial_chunk_size = high_freq_config[
|
|
67
|
+
'simulate_spatial_chunk_size'
|
|
68
|
+
]
|
|
69
|
+
self.simulate_temporal_chunk_size = high_freq_config[
|
|
70
|
+
'simulate_temporal_chunk_size'
|
|
71
|
+
]
|
|
72
|
+
self.spatial_chunk_size = self.train_spatial_chunk_size
|
|
73
|
+
self.simulate_mode = False
|
|
74
|
+
|
|
75
|
+
# warmup steps for routing during training.
|
|
76
|
+
self.train_warmup = high_freq_config['train_warmup']
|
|
77
|
+
|
|
78
|
+
def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
|
|
79
|
+
"""Return internal states for high and low frequency models."""
|
|
80
|
+
lof_states = self.low_freq_model.get_states()
|
|
81
|
+
hif_states = self.high_freq_model.get_states()
|
|
82
|
+
return (lof_states, hif_states)
|
|
83
|
+
|
|
84
|
+
def load_states(
|
|
85
|
+
self,
|
|
86
|
+
state_tuple: tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]],
|
|
87
|
+
) -> None:
|
|
88
|
+
"""Load internal model states sideload low frequency states."""
|
|
89
|
+
if not isinstance(state_tuple, tuple) or len(state_tuple) != 2:
|
|
90
|
+
raise ValueError("`states` must be a tuple of two tuples of tensors.")
|
|
91
|
+
self._state_cache = tuple(
|
|
92
|
+
tuple(s[-1].detach().to(self.device, dtype=self.dtype) for s in states)
|
|
93
|
+
for states in state_tuple
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
if self.load_from_cache:
|
|
97
|
+
# Only sideload low-frequency states.
|
|
98
|
+
self.low_freq_model.load_states(state_tuple[0])
|
|
99
|
+
|
|
100
|
+
    def _forward(
        self,
        x_dict: dict[str, torch.Tensor],
        parameters: tuple[list[torch.Tensor], list[torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        """Base forward.

        Runs the low-frequency model (unless cached states are reused) to
        produce initial states, transfers those states and the parameters to
        the high-frequency model, then runs its physics-based module `_PBM`.

        Parameters
        ----------
        x_dict
            Input tensors; reads keys 'x_phy_low_freq', 'x_phy_high_freq',
            'ac_all', 'elev_all', 'outlet_topo', 'areas', and optionally
            'muwts'.
        parameters
            Pair ``(low_freq_parameters, high_freq_parameters)``.

        Returns
        -------
        dict[str, torch.Tensor]
            Predictions from the high-frequency model.
        """
        # 1. Transfer states
        low_freq_parameters, high_freq_parameters = parameters

        # NOTE(review): the guard checks `self._state_cache[1]` but then
        # reads `self.states[1]` — confirm the caller keeps these in sync.
        if self.use_from_cache and (self._state_cache[1] is not None):
            states = self.states[1]
        else:
            low_freq_x_dict = {
                'x_phy': x_dict['x_phy_low_freq'],
                'ac_all': x_dict['ac_all'],
                'elev_all': x_dict['elev_all'],
                'muwts': x_dict.get('muwts', None),
            }

            # Reset then run the low-frequency (warmup) model to build states.
            self.low_freq_model.states = None
            self.low_freq_model(
                low_freq_x_dict,
                low_freq_parameters,
            )

            # Low-frequency states at last timestep
            self._state_cache[0] = self.low_freq_model.states
            states = self.state_transfer(self.low_freq_model.states)

        # 2. Transfer parameters
        phy_dy_params_dict, phy_static_params_dict, distr_params_dict = (
            self.param_transfer(
                low_freq_parameters,
                high_freq_parameters,
            )
        )

        # Run the model
        x = x_dict['x_phy_high_freq']

        # Broadcast per-unit attributes across the nmul parallel components.
        Ac = x_dict['ac_all'].unsqueeze(-1).expand(-1, self.high_freq_model.nmul)
        Elevation = (
            x_dict['elev_all'].unsqueeze(-1).expand(-1, self.high_freq_model.nmul)
        )
        outlet_topo = x_dict['outlet_topo']
        areas = x_dict['areas']

        predictions, hif_states = self.high_freq_model._PBM(
            forcing=x,
            Ac=Ac,
            Elevation=Elevation,
            states=tuple(states),
            phy_dy_params_dict=phy_dy_params_dict,
            phy_static_params_dict=phy_static_params_dict,
            outlet_topo=outlet_topo,
            areas=areas,
            distr_params_dict=distr_params_dict,
        )

        # State caching: detach so cached states do not hold the graph.
        self._state_cache[1] = tuple(s.detach() for s in hif_states)
        if self.load_from_cache:
            new_states = []

            # low-frequency states remain the same
            new_states.append(self._state_cache[0])

            # high-frequency states updated to the last timestep
            new_states.append(tuple(s[-1] for s in hif_states))
            self.states = tuple(new_states)

        # Temp: save initial states
        # torch.save(tuple(tuple(s.detach().cpu() for s in states) for states in self._state_cache), "/projects/mhpi/leoglonz/ciroh-ua/dhbv2_mts/ngen_resources/data/dhbv2_mts/models/hfv2.2_15yr/initial_states_2009.pt")

        return predictions
|
|
175
|
+
|
|
176
|
+
def forward(
|
|
177
|
+
self,
|
|
178
|
+
x_dict: dict[str, torch.Tensor],
|
|
179
|
+
parameters: tuple[list[torch.Tensor], list[torch.Tensor]],
|
|
180
|
+
) -> dict[str, torch.Tensor]:
|
|
181
|
+
"""Foward supports spatial and temporal chunking.
|
|
182
|
+
|
|
183
|
+
x_dict and parameters can be in cpu for simulation mode to save GPU
|
|
184
|
+
memory.
|
|
185
|
+
"""
|
|
186
|
+
device = self.device
|
|
187
|
+
n_units = x_dict['areas'].shape[0]
|
|
188
|
+
spatial_chunk_size = self.spatial_chunk_size
|
|
189
|
+
temporal_chunk_size = self.simulate_temporal_chunk_size
|
|
190
|
+
train_warmup = self.train_warmup
|
|
191
|
+
|
|
192
|
+
if (not self.simulate_mode) and (n_units <= spatial_chunk_size):
|
|
193
|
+
self.high_freq_model.use_distr_routing = False
|
|
194
|
+
return self._forward(x_dict, parameters)
|
|
195
|
+
|
|
196
|
+
# Chunked runoff generation for simulation mode or large training batches
|
|
197
|
+
self.high_freq_model.use_distr_routing = False
|
|
198
|
+
preds_list = []
|
|
199
|
+
prog_bar = tqdm(
|
|
200
|
+
range(0, n_units, spatial_chunk_size),
|
|
201
|
+
desc="Spatial runoff chunks",
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
for i in prog_bar:
|
|
205
|
+
end_idx = min(i + spatial_chunk_size, n_units)
|
|
206
|
+
reach_idx = (x_dict['outlet_topo'] == 1).nonzero(as_tuple=False)
|
|
207
|
+
idxs_in_chunk = (reach_idx[:, 1] >= i) & (reach_idx[:, 1] < end_idx)
|
|
208
|
+
|
|
209
|
+
chunk_x_dict = {
|
|
210
|
+
'x_phy_low_freq': x_dict['x_phy_low_freq'][:, i:end_idx].to(device),
|
|
211
|
+
'x_phy_high_freq': x_dict['x_phy_high_freq'][:, i:end_idx].to(device),
|
|
212
|
+
'ac_all': x_dict['ac_all'][i:end_idx].to(device),
|
|
213
|
+
'elev_all': x_dict['elev_all'][i:end_idx].to(device),
|
|
214
|
+
'areas': x_dict['areas'][i:end_idx].to(device),
|
|
215
|
+
'outlet_topo': x_dict['outlet_topo'][:, i:end_idx].to(device),
|
|
216
|
+
}
|
|
217
|
+
chunk_parameters = (
|
|
218
|
+
[
|
|
219
|
+
parameters[0][0][:, i:end_idx].to(
|
|
220
|
+
device
|
|
221
|
+
), # low-freq dynamic phy params
|
|
222
|
+
parameters[0][1][i:end_idx].to(
|
|
223
|
+
device
|
|
224
|
+
), # low-freq static phy params
|
|
225
|
+
],
|
|
226
|
+
[
|
|
227
|
+
parameters[1][0][:, i:end_idx].to(
|
|
228
|
+
device
|
|
229
|
+
), # high-freq dynamic phy params
|
|
230
|
+
parameters[1][1][i:end_idx].to(
|
|
231
|
+
device
|
|
232
|
+
), # high-freq static phy params
|
|
233
|
+
parameters[1][2][idxs_in_chunk].to(
|
|
234
|
+
device
|
|
235
|
+
), # high-freq distributed params
|
|
236
|
+
],
|
|
237
|
+
)
|
|
238
|
+
chunk_predictions = self._forward(chunk_x_dict, chunk_parameters)
|
|
239
|
+
preds_list.append(chunk_predictions)
|
|
240
|
+
|
|
241
|
+
predictions = self.concat_spatial_chunks(preds_list)
|
|
242
|
+
runoff = predictions['Qs']
|
|
243
|
+
high_freq_length = runoff.shape[0]
|
|
244
|
+
|
|
245
|
+
# Chunked routing
|
|
246
|
+
_, _, _, distr_params = self.high_freq_model.unpack_parameters(parameters[1])
|
|
247
|
+
distr_params_dict = self.high_freq_model._descale_distr_parameters(distr_params)
|
|
248
|
+
distr_params_dict = {
|
|
249
|
+
key: value.to(device) for key, value in distr_params_dict.items()
|
|
250
|
+
}
|
|
251
|
+
outlet_topo = x_dict['outlet_topo'].to(device)
|
|
252
|
+
areas = x_dict['areas'].to(device)
|
|
253
|
+
|
|
254
|
+
preds_list = []
|
|
255
|
+
prog_bar = tqdm(
|
|
256
|
+
range(train_warmup, high_freq_length, temporal_chunk_size),
|
|
257
|
+
desc="Temporal routing chunks",
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
for t in prog_bar:
|
|
261
|
+
end_t = min(t + temporal_chunk_size, high_freq_length)
|
|
262
|
+
chunk_runoff = runoff[t - train_warmup : end_t]
|
|
263
|
+
chunk_predictions = self.high_freq_model.distr_routing(
|
|
264
|
+
Qs=chunk_runoff,
|
|
265
|
+
distr_params_dict=distr_params_dict,
|
|
266
|
+
outlet_topo=outlet_topo,
|
|
267
|
+
areas=areas,
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
# Remove routing warmup for all but first chunk
|
|
271
|
+
if t > train_warmup:
|
|
272
|
+
chunk_predictions = {
|
|
273
|
+
key: value[train_warmup:]
|
|
274
|
+
for key, value in chunk_predictions.items()
|
|
275
|
+
}
|
|
276
|
+
preds_list.append(chunk_predictions)
|
|
277
|
+
|
|
278
|
+
routing_predictions = self.concat_temporal_chunks(preds_list)
|
|
279
|
+
predictions['streamflow'] = routing_predictions['Qs_rout']
|
|
280
|
+
|
|
281
|
+
return predictions
|
|
282
|
+
|
|
283
|
+
def set_mode(self, is_simulate: bool):
|
|
284
|
+
"""Set simulate mode."""
|
|
285
|
+
if is_simulate:
|
|
286
|
+
self.spatial_chunk_size = self.simulate_spatial_chunk_size
|
|
287
|
+
self.simulate_mode = True
|
|
288
|
+
else:
|
|
289
|
+
self.spatial_chunk_size = self.train_spatial_chunk_size
|
|
290
|
+
self.simulate_mode = False
|
|
291
|
+
|
|
292
|
+
    def param_transfer(
        self,
        low_freq_parameters: list[torch.Tensor],
        high_freq_parameters: list[torch.Tensor],
    ):
        """Map low-frequency parameters to high-frequency parameters.

        Unpacks and descales both parameter sets, reusing the low-frequency
        (warmup) static physical parameters and appending only the static
        parameters unique to the high-frequency model. As a side effect,
        stores the high-frequency routing parameters on
        ``self.high_freq_model.routing_param_dict`` when routing is enabled.

        Returns
        -------
        tuple
            ``(phy_dy_params_dict, phy_static_params_dict, distr_params_dict)``
            descaled for the high-frequency model.
        """
        warmup_phy_dy_params, warmup_phy_static_params, warmup_routing_params = (
            self.low_freq_model._unpack_parameters(low_freq_parameters)
        )

        phy_dy_params, phy_static_params, routing_params, distr_params = (
            self.high_freq_model._unpack_parameters(high_freq_parameters)
        )
        # New dynamic params
        phy_dy_params_dict = self.high_freq_model._descale_phy_dy_parameters(
            phy_dy_params, dy_list=self.high_freq_model.dynamic_params
        )

        # Keep warmup static params, add high-freq specific static params
        static_param_names = [
            param
            for param in self.high_freq_model.phy_param_names
            if param not in self.high_freq_model.dynamic_params
        ]
        warmup_static_param_names = [
            param
            for param in self.low_freq_model.phy_param_names
            if param not in self.low_freq_model.dynamic_params
        ]
        # Column indexes of static params that exist only in the
        # high-frequency model (not covered by the warmup model).
        var_indexes = [
            i
            for i, param in enumerate(static_param_names)
            if param not in warmup_static_param_names
        ]
        # NOTE(review): the concat assumes warmup statics come first and the
        # high-freq-only columns follow, matching `static_param_names` order —
        # confirm against `_descale_phy_stat_parameters`.
        phy_static_params_dict = self.high_freq_model._descale_phy_stat_parameters(
            torch.concat(
                [warmup_phy_static_params, phy_static_params[:, var_indexes]], dim=1
            ),
            stat_list=static_param_names,
        )
        # New distributed params
        distr_params_dict = self.high_freq_model._descale_distr_parameters(distr_params)

        # New routing params
        if self.high_freq_model.routing:
            self.high_freq_model.routing_param_dict = (
                self.high_freq_model._descale_rout_parameters(routing_params)
            )

        return phy_dy_params_dict, phy_static_params_dict, distr_params_dict
|
|
342
|
+
|
|
343
|
+
def state_transfer(self, states: list[torch.Tensor]):
|
|
344
|
+
"""Map low-frequency states to high-frequency states."""
|
|
345
|
+
states_dict = dict(zip(self.high_freq_model.state_names, states))
|
|
346
|
+
return [
|
|
347
|
+
self.state_transfer_model[key](states_dict[key])
|
|
348
|
+
for key in states_dict.keys()
|
|
349
|
+
]
|
|
350
|
+
|
|
351
|
+
@staticmethod
|
|
352
|
+
def concat_spatial_chunks(pred_list: list[dict[str, torch.Tensor]]):
|
|
353
|
+
"""Concatenate spatial chunk pedictions."""
|
|
354
|
+
output = {}
|
|
355
|
+
for key in pred_list[0].keys():
|
|
356
|
+
if pred_list[0][key].ndim == 3:
|
|
357
|
+
output[key] = torch.cat(
|
|
358
|
+
[preds[key] for preds in pred_list], dim=1
|
|
359
|
+
) # (window_size, n_units, nmul)
|
|
360
|
+
else:
|
|
361
|
+
output[key] = torch.cat(
|
|
362
|
+
[preds[key] for preds in pred_list], dim=0
|
|
363
|
+
) # (n_units, nmul) or (n_units,)
|
|
364
|
+
return output
|
|
365
|
+
|
|
366
|
+
@staticmethod
|
|
367
|
+
def concat_temporal_chunks(pred_list: list[dict[str, torch.Tensor]]):
|
|
368
|
+
"""Concatenate temporal chunk predictions."""
|
|
369
|
+
output = {}
|
|
370
|
+
for key in pred_list[0].keys():
|
|
371
|
+
if pred_list[0][key].ndim == 3:
|
|
372
|
+
output[key] = torch.cat(
|
|
373
|
+
[preds[key] for preds in pred_list], dim=0
|
|
374
|
+
) # (window_size, n, nmul)
|
|
375
|
+
else:
|
|
376
|
+
output[key] = pred_list[0][key] # (n_units, nmul) or (n_units,)
|
|
377
|
+
return output
|