hydrodl2 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,377 @@
1
+ from typing import Any, Optional
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+
6
+ from hydrodl2.models.hbv.hbv_2 import Hbv_2
7
+ from hydrodl2.models.hbv.hbv_2_hourly import Hbv_2_hourly
8
+
9
+
10
class Hbv_2_mts(torch.nn.Module):
    """HBV 2.0, multi timescale, distributed UH.

    Multi-component, multi-scale, differentiable PyTorch HBV model with rainfall
    runoff simulation on unit basins. A low-frequency HBV model (`Hbv_2`) is run
    first for warmup/state initialization; its terminal states are transferred
    to a high-frequency hourly model (`Hbv_2_hourly`) for simulation and
    distributed routing.

    Authors
    -------
    - Wencong Yang
    - (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
    - (HBV-light Version 2) Seibert, 2005
      (https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).

    Parameters
    ----------
    low_freq_config
        Configuration dictionary for the low-frequency (warmup) model.
    high_freq_config
        Configuration dictionary for the high-frequency model. Must provide
        'train_spatial_chunk_size', 'simulate_spatial_chunk_size',
        'simulate_temporal_chunk_size', and 'train_warmup'.
    device
        Device to run the model on; defaults to CPU when not given.
    """

    def __init__(
        self,
        low_freq_config: Optional[dict[str, Any]] = None,
        high_freq_config: Optional[dict[str, Any]] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.device = device if device is not None else torch.device('cpu')
        self.dtype = torch.float32
        self.low_freq_model = Hbv_2(low_freq_config, device=device)
        self.low_freq_model.initialize = True
        self.high_freq_model = Hbv_2_hourly(high_freq_config, device=device)
        # [low-frequency states, high-frequency states]. Kept as a *list* so
        # `_forward` can update the slots in place.
        self._state_cache = [None, None]
        self.states = (None, None)
        self.load_from_cache = False
        self.use_from_cache = False

        # Identity state transfer from low- to high-frequency states.
        # NOTE: a learnable per-state Linear+ReLU transfer was prototyped here
        # previously; the identity mapping is currently used instead.
        self.state_transfer_model = torch.nn.ModuleDict(
            {name: torch.nn.Identity() for name in self.high_freq_model.state_names}
        )

        self.train_spatial_chunk_size = high_freq_config['train_spatial_chunk_size']
        self.simulate_spatial_chunk_size = high_freq_config[
            'simulate_spatial_chunk_size'
        ]
        self.simulate_temporal_chunk_size = high_freq_config[
            'simulate_temporal_chunk_size'
        ]
        self.spatial_chunk_size = self.train_spatial_chunk_size
        self.simulate_mode = False

        # Warmup steps for routing during training.
        self.train_warmup = high_freq_config['train_warmup']

    def get_states(self) -> tuple[Any, Any]:
        """Return internal states as (low-frequency, high-frequency)."""
        lof_states = self.low_freq_model.get_states()
        hif_states = self.high_freq_model.get_states()
        return (lof_states, hif_states)

    def load_states(
        self,
        state_tuple: tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]],
    ) -> None:
        """Load internal model states; sideload low-frequency states.

        Caches the final timestep (``s[-1]``) of every state tensor, detached
        and moved to the model device/dtype. When `load_from_cache` is set,
        the low-frequency states are additionally loaded into the
        low-frequency model.

        Parameters
        ----------
        state_tuple
            Tuple of (low-frequency states, high-frequency states), each a
            tuple of tensors.

        Raises
        ------
        ValueError
            If `state_tuple` is not a 2-tuple.
        """
        if not isinstance(state_tuple, tuple) or len(state_tuple) != 2:
            raise ValueError("`states` must be a tuple of two tuples of tensors.")
        # BUGFIX: build the cache as a list, not a tuple — `_forward` assigns
        # to `self._state_cache[0]` / `[1]`, which raises TypeError on a tuple.
        self._state_cache = [
            tuple(s[-1].detach().to(self.device, dtype=self.dtype) for s in states)
            for states in state_tuple
        ]

        if self.load_from_cache:
            # Only sideload low-frequency states.
            self.low_freq_model.load_states(state_tuple[0])

    def _forward(
        self,
        x_dict: dict[str, torch.Tensor],
        parameters: tuple[list[torch.Tensor], list[torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        """Base forward: warmup run, state/parameter transfer, high-freq run.

        Parameters
        ----------
        x_dict
            Forcing/attribute dictionary with keys 'x_phy_low_freq',
            'x_phy_high_freq', 'ac_all', 'elev_all', 'outlet_topo', 'areas',
            and optionally 'muwts'.
        parameters
            Tuple of (low-frequency, high-frequency) parameter lists.

        Returns
        -------
        dict[str, torch.Tensor]
            High-frequency model predictions.
        """
        # 1. Transfer states
        low_freq_parameters, high_freq_parameters = parameters

        if self.use_from_cache and (self._state_cache[1] is not None):
            # NOTE(review): the guard checks `_state_cache[1]` but reads
            # `self.states[1]`; the two are only both populated when
            # `load_from_cache` is enabled — confirm they stay in sync.
            states = self.states[1]
        else:
            low_freq_x_dict = {
                'x_phy': x_dict['x_phy_low_freq'],
                'ac_all': x_dict['ac_all'],
                'elev_all': x_dict['elev_all'],
                'muwts': x_dict.get('muwts', None),
            }

            # Run the low-frequency warmup from scratch.
            self.low_freq_model.states = None
            self.low_freq_model(
                low_freq_x_dict,
                low_freq_parameters,
            )

            # Low-frequency states at last timestep.
            self._state_cache[0] = self.low_freq_model.states
            states = self.state_transfer(self.low_freq_model.states)

        # 2. Transfer parameters
        phy_dy_params_dict, phy_static_params_dict, distr_params_dict = (
            self.param_transfer(
                low_freq_parameters,
                high_freq_parameters,
            )
        )

        # Run the high-frequency model.
        x = x_dict['x_phy_high_freq']

        Ac = x_dict['ac_all'].unsqueeze(-1).expand(-1, self.high_freq_model.nmul)
        Elevation = (
            x_dict['elev_all'].unsqueeze(-1).expand(-1, self.high_freq_model.nmul)
        )
        outlet_topo = x_dict['outlet_topo']
        areas = x_dict['areas']

        predictions, hif_states = self.high_freq_model._PBM(
            forcing=x,
            Ac=Ac,
            Elevation=Elevation,
            states=tuple(states),
            phy_dy_params_dict=phy_dy_params_dict,
            phy_static_params_dict=phy_static_params_dict,
            outlet_topo=outlet_topo,
            areas=areas,
            distr_params_dict=distr_params_dict,
        )

        # State caching: detach so cached states don't retain the graph.
        self._state_cache[1] = tuple(s.detach() for s in hif_states)
        if self.load_from_cache:
            new_states = []

            # Low-frequency states remain the same.
            new_states.append(self._state_cache[0])

            # High-frequency states updated to the last timestep.
            new_states.append(tuple(s[-1] for s in hif_states))
            self.states = tuple(new_states)

        return predictions

    def forward(
        self,
        x_dict: dict[str, torch.Tensor],
        parameters: tuple[list[torch.Tensor], list[torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        """Forward pass supporting spatial and temporal chunking.

        `x_dict` and `parameters` can live on CPU in simulation mode to save
        GPU memory; each chunk is moved to `self.device` on demand.
        """
        device = self.device
        n_units = x_dict['areas'].shape[0]
        spatial_chunk_size = self.spatial_chunk_size
        temporal_chunk_size = self.simulate_temporal_chunk_size
        train_warmup = self.train_warmup

        if (not self.simulate_mode) and (n_units <= spatial_chunk_size):
            # NOTE(review): distributed routing is disabled on this path as
            # well, so no 'streamflow' key is produced here — confirm callers
            # of the small-batch training path expect that.
            self.high_freq_model.use_distr_routing = False
            return self._forward(x_dict, parameters)

        # Chunked runoff generation for simulation mode or large training batches.
        self.high_freq_model.use_distr_routing = False
        preds_list = []
        # Hoisted out of the chunk loop: reach indices are chunk-independent.
        reach_idx = (x_dict['outlet_topo'] == 1).nonzero(as_tuple=False)
        prog_bar = tqdm(
            range(0, n_units, spatial_chunk_size),
            desc="Spatial runoff chunks",
        )

        for i in prog_bar:
            end_idx = min(i + spatial_chunk_size, n_units)
            idxs_in_chunk = (reach_idx[:, 1] >= i) & (reach_idx[:, 1] < end_idx)

            chunk_x_dict = {
                'x_phy_low_freq': x_dict['x_phy_low_freq'][:, i:end_idx].to(device),
                'x_phy_high_freq': x_dict['x_phy_high_freq'][:, i:end_idx].to(device),
                'ac_all': x_dict['ac_all'][i:end_idx].to(device),
                'elev_all': x_dict['elev_all'][i:end_idx].to(device),
                'areas': x_dict['areas'][i:end_idx].to(device),
                'outlet_topo': x_dict['outlet_topo'][:, i:end_idx].to(device),
            }
            chunk_parameters = (
                [
                    parameters[0][0][:, i:end_idx].to(
                        device
                    ),  # low-freq dynamic phy params
                    parameters[0][1][i:end_idx].to(
                        device
                    ),  # low-freq static phy params
                ],
                [
                    parameters[1][0][:, i:end_idx].to(
                        device
                    ),  # high-freq dynamic phy params
                    parameters[1][1][i:end_idx].to(
                        device
                    ),  # high-freq static phy params
                    parameters[1][2][idxs_in_chunk].to(
                        device
                    ),  # high-freq distributed params
                ],
            )
            chunk_predictions = self._forward(chunk_x_dict, chunk_parameters)
            preds_list.append(chunk_predictions)

        predictions = self.concat_spatial_chunks(preds_list)
        runoff = predictions['Qs']
        high_freq_length = runoff.shape[0]

        # Chunked routing over time using the full (unchunked) topology.
        _, _, _, distr_params = self.high_freq_model.unpack_parameters(parameters[1])
        distr_params_dict = self.high_freq_model._descale_distr_parameters(distr_params)
        distr_params_dict = {
            key: value.to(device) for key, value in distr_params_dict.items()
        }
        outlet_topo = x_dict['outlet_topo'].to(device)
        areas = x_dict['areas'].to(device)

        preds_list = []
        prog_bar = tqdm(
            range(train_warmup, high_freq_length, temporal_chunk_size),
            desc="Temporal routing chunks",
        )

        for t in prog_bar:
            end_t = min(t + temporal_chunk_size, high_freq_length)
            # Each chunk re-runs `train_warmup` steps of history to spin up
            # the routing before the chunk's own window.
            chunk_runoff = runoff[t - train_warmup : end_t]
            chunk_predictions = self.high_freq_model.distr_routing(
                Qs=chunk_runoff,
                distr_params_dict=distr_params_dict,
                outlet_topo=outlet_topo,
                areas=areas,
            )

            # Remove routing warmup for all but the first chunk.
            if t > train_warmup:
                chunk_predictions = {
                    key: value[train_warmup:]
                    for key, value in chunk_predictions.items()
                }
            preds_list.append(chunk_predictions)

        routing_predictions = self.concat_temporal_chunks(preds_list)
        predictions['streamflow'] = routing_predictions['Qs_rout']

        return predictions

    def set_mode(self, is_simulate: bool) -> None:
        """Toggle simulate mode and select the matching spatial chunk size."""
        if is_simulate:
            self.spatial_chunk_size = self.simulate_spatial_chunk_size
            self.simulate_mode = True
        else:
            self.spatial_chunk_size = self.train_spatial_chunk_size
            self.simulate_mode = False

    def param_transfer(
        self,
        low_freq_parameters: list[torch.Tensor],
        high_freq_parameters: list[torch.Tensor],
    ):
        """Map low-frequency parameters to high-frequency parameters.

        Dynamic and distributed parameters come from the high-frequency set;
        static parameters reuse the low-frequency (warmup) values and append
        only the statics that exist exclusively in the high-frequency model.
        As a side effect, sets `high_freq_model.routing_param_dict` when
        routing is enabled.

        Returns
        -------
        tuple
            (dynamic params dict, static params dict, distributed params dict).
        """
        warmup_phy_dy_params, warmup_phy_static_params, warmup_routing_params = (
            self.low_freq_model._unpack_parameters(low_freq_parameters)
        )

        phy_dy_params, phy_static_params, routing_params, distr_params = (
            self.high_freq_model._unpack_parameters(high_freq_parameters)
        )
        # New dynamic params.
        phy_dy_params_dict = self.high_freq_model._descale_phy_dy_parameters(
            phy_dy_params, dy_list=self.high_freq_model.dynamic_params
        )

        # Keep warmup static params, add high-freq-specific static params.
        static_param_names = [
            param
            for param in self.high_freq_model.phy_param_names
            if param not in self.high_freq_model.dynamic_params
        ]
        warmup_static_param_names = [
            param
            for param in self.low_freq_model.phy_param_names
            if param not in self.low_freq_model.dynamic_params
        ]
        var_indexes = [
            i
            for i, param in enumerate(static_param_names)
            if param not in warmup_static_param_names
        ]
        phy_static_params_dict = self.high_freq_model._descale_phy_stat_parameters(
            torch.concat(
                [warmup_phy_static_params, phy_static_params[:, var_indexes]], dim=1
            ),
            stat_list=static_param_names,
        )
        # New distributed params.
        distr_params_dict = self.high_freq_model._descale_distr_parameters(distr_params)

        # New routing params.
        if self.high_freq_model.routing:
            self.high_freq_model.routing_param_dict = (
                self.high_freq_model._descale_rout_parameters(routing_params)
            )

        return phy_dy_params_dict, phy_static_params_dict, distr_params_dict

    def state_transfer(self, states: list[torch.Tensor]):
        """Map low-frequency states to high-frequency states.

        States are paired positionally with `high_freq_model.state_names` and
        passed through the (currently identity) per-state transfer modules.
        """
        # NOTE(review): assumes `states` is ordered the same as
        # `high_freq_model.state_names` — confirm against Hbv_2 output order.
        states_dict = dict(zip(self.high_freq_model.state_names, states))
        return [
            self.state_transfer_model[key](states_dict[key])
            for key in states_dict.keys()
        ]

    @staticmethod
    def concat_spatial_chunks(pred_list: list[dict[str, torch.Tensor]]):
        """Concatenate spatial chunk predictions along the unit dimension."""
        output = {}
        for key in pred_list[0].keys():
            if pred_list[0][key].ndim == 3:
                output[key] = torch.cat(
                    [preds[key] for preds in pred_list], dim=1
                )  # (window_size, n_units, nmul)
            else:
                output[key] = torch.cat(
                    [preds[key] for preds in pred_list], dim=0
                )  # (n_units, nmul) or (n_units,)
        return output

    @staticmethod
    def concat_temporal_chunks(pred_list: list[dict[str, torch.Tensor]]):
        """Concatenate temporal chunk predictions along the time dimension.

        Non-3D (time-independent) entries are taken from the first chunk only.
        """
        output = {}
        for key in pred_list[0].keys():
            if pred_list[0][key].ndim == 3:
                output[key] = torch.cat(
                    [preds[key] for preds in pred_list], dim=0
                )  # (window_size, n, nmul)
            else:
                output[key] = pred_list[0][key]  # (n_units, nmul) or (n_units,)
        return output