hydrodl2 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,608 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ import torch
4
+
5
+ from hydrodl2.core.calc import change_param_range, uh_conv, uh_gamma
6
+
7
+
8
+ class Hbv_1_1p(torch.nn.Module):
9
+ """HBV 1.1p ~
10
+ Multi-component, differentiable Pytorch HBV model with a capillary rise
11
+ modification and option to run without internal state warmup.
12
+
13
+ Authors
14
+ -------
15
+ - Yalan Song, Farshid Rahmani, Leo Lonzarich
16
+ - (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
17
+ - (HBV-light Version 2) Seibert, 2005
18
+ (https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).
19
+
20
+ Publication
21
+ -----------
22
+ - Yalan Song, Kamlesh Sawadekar, Jonathan Frame, et al. "Physics-informed,
23
+ Differentiable Hydrologic Models for Capturing Unseen Extreme Events."
24
+ ESS Open Archive (2025).
25
+ https://doi.org/10.22541/essoar.172304428.82707157/v2.
26
+
27
+ Parameters
28
+ ----------
29
+ config
30
+ Configuration dictionary.
31
+ device
32
+ Device to run the model on.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ config: Optional[dict[str, Any]] = None,
38
+ device: Optional[torch.device] = None,
39
+ ) -> None:
40
+ super().__init__()
41
+ self.name = 'HBV 1.1p'
42
+ self.config = config
43
+ self.initialize = False
44
+ self.warm_up = 0
45
+ self.pred_cutoff = 0
46
+ self.warm_up_states = True
47
+ self.dynamic_params = []
48
+ self.dy_drop = 0.0
49
+ self.variables = ['prcp', 'tmean', 'pet']
50
+ self.routing = True
51
+ self.comprout = False
52
+ self.nearzero = 1e-5
53
+ self.nmul = 1
54
+ self.cache_states = False
55
+ self.device = device
56
+
57
+ self.states, self._states_cache = None, None
58
+
59
+ self.state_names = [
60
+ 'SNOWPACK', # Snowpack storage
61
+ 'MELTWATER', # Meltwater storage
62
+ 'SM', # Soil moisture storage
63
+ 'SUZ', # Upper groundwater storage
64
+ 'SLZ', # Lower groundwater storage
65
+ ]
66
+ self.flux_names = [
67
+ 'streamflow', # Routed Streamflow
68
+ 'srflow', # Routed surface runoff
69
+ 'ssflow', # Routed subsurface flow
70
+ 'gwflow', # Routed groundwater flow
71
+ 'AET_hydro', # Actual ET
72
+ 'PET_hydro', # Potential ET
73
+ 'SWE', # Snow water equivalent
74
+ 'streamflow_no_rout', # Streamflow
75
+ 'srflow_no_rout', # Surface runoff
76
+ 'ssflow_no_rout', # Subsurface flow
77
+ 'gwflow_no_rout', # Groundwater flow
78
+ 'recharge', # Recharge
79
+ 'excs', # Excess stored water
80
+ 'evapfactor', # Evaporation factor
81
+ 'tosoil', # Infiltration
82
+ 'percolation', # Percolation
83
+ 'capillary', # Capillary rise
84
+ 'BFI', # Baseflow index
85
+ ]
86
+
87
+ self.parameter_bounds = {
88
+ 'parBETA': [1.0, 6.0],
89
+ 'parFC': [50, 1000],
90
+ 'parK0': [0.05, 0.9],
91
+ 'parK1': [0.01, 0.5],
92
+ 'parK2': [0.001, 0.2],
93
+ 'parLP': [0.2, 1],
94
+ 'parPERC': [0, 10],
95
+ 'parUZL': [0, 100],
96
+ 'parTT': [-2.5, 2.5],
97
+ 'parCFMAX': [0.5, 10],
98
+ 'parCFR': [0, 0.1],
99
+ 'parCWH': [0, 0.2],
100
+ 'parBETAET': [0.3, 5],
101
+ 'parC': [0, 1],
102
+ }
103
+ self.routing_parameter_bounds = {
104
+ 'route_a': [0, 2.9],
105
+ 'route_b': [0, 6.5],
106
+ }
107
+
108
+ if not device:
109
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
110
+
111
+ if config is not None:
112
+ # Overwrite defaults with config values.
113
+ self.warm_up = config.get('warm_up', self.warm_up)
114
+ self.warm_up_states = config.get('warm_up_states', self.warm_up_states)
115
+ self.dy_drop = config.get('dy_drop', self.dy_drop)
116
+ self.dynamic_params = config['dynamic_params'].get(
117
+ self.__class__.__name__, self.dynamic_params
118
+ )
119
+ self.variables = config.get('variables', self.variables)
120
+ self.routing = config.get('routing', self.routing)
121
+ self.comprout = config.get('comprout', self.comprout)
122
+ self.nearzero = config.get('nearzero', self.nearzero)
123
+ self.nmul = config.get('nmul', self.nmul)
124
+ self.cache_states = config.get('cache_states', False)
125
+ self._set_parameters()
126
+
127
+ def _init_states(self, ngrid: int) -> tuple[torch.Tensor]:
128
+ """Initialize model states to zero."""
129
+
130
+ def make_state():
131
+ return torch.full(
132
+ (ngrid, self.nmul), 0.001, dtype=torch.float32, device=self.device
133
+ )
134
+
135
+ return tuple(make_state() for _ in range(len(self.state_names)))
136
+
137
    def get_states(self) -> Optional[torch.Tensor]:
        """Return the internal model states cached by the last forward pass.

        Returns
        -------
        Optional states (SNOWPACK, MELTWATER, SM, SUZ, SLZ), detached from
        the graph at the end of the most recent forward pass, or None if
        the model has not been run yet.

        NOTE(review): `forward` stores a *list* of detached tensors in
        `_states_cache`, not a tuple, so `load_states(get_states())` would
        fail its tuple check — confirm intended container type.
        """
        return self._states_cache
147
+ def load_states(
148
+ self,
149
+ states: tuple[torch.Tensor, ...],
150
+ ) -> None:
151
+ """Load internal model states and set to model device and type.
152
+
153
+ Parameters
154
+ ----------
155
+ states
156
+ A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
157
+ """
158
+ for state in states:
159
+ if not isinstance(state, torch.Tensor):
160
+ raise ValueError("Each element in `states` must be a tensor.")
161
+ nstates = len(self.state_names)
162
+ if not (isinstance(states, tuple) and len(states) == nstates):
163
+ raise ValueError(f"`states` must be a tuple of {nstates} tensors.")
164
+
165
+ self.states = tuple(
166
+ s.detach().to(self.device, dtype=torch.float32) for s in states
167
+ )
168
+
169
+ def _set_parameters(self) -> None:
170
+ """Get physical parameters."""
171
+ self.phy_param_names = self.parameter_bounds.keys()
172
+ if self.routing:
173
+ self.routing_param_names = self.routing_parameter_bounds.keys()
174
+ else:
175
+ self.routing_param_names = []
176
+
177
+ self.learnable_param_count = len(self.phy_param_names) * self.nmul + len(
178
+ self.routing_param_names
179
+ )
180
+
181
+ def _unpack_parameters(
182
+ self,
183
+ parameters: torch.Tensor,
184
+ ) -> tuple[torch.Tensor, torch.Tensor]:
185
+ """Extract physical model and routing parameters from NN output.
186
+
187
+ Parameters
188
+ ----------
189
+ parameters
190
+ Unprocessed, learned parameters from a neural network.
191
+
192
+ Returns
193
+ -------
194
+ tuple[torch.Tensor, torch.Tensor]
195
+ Tuple of physical and routing parameters.
196
+ """
197
+ phy_param_count = len(self.parameter_bounds)
198
+
199
+ # Physical parameters
200
+ phy_params = torch.sigmoid(
201
+ parameters[:, :, : phy_param_count * self.nmul]
202
+ ).view(
203
+ parameters.shape[0],
204
+ parameters.shape[1],
205
+ phy_param_count,
206
+ self.nmul,
207
+ )
208
+ # Routing parameters
209
+ routing_params = None
210
+ if self.routing:
211
+ routing_params = torch.sigmoid(
212
+ parameters[-1, :, phy_param_count * self.nmul :],
213
+ )
214
+ return (phy_params, routing_params)
215
+
216
+ def _descale_phy_parameters(
217
+ self,
218
+ phy_params: torch.Tensor,
219
+ dy_list: list,
220
+ ) -> dict[str, torch.Tensor]:
221
+ """Descale physical parameters.
222
+
223
+ Parameters
224
+ ----------
225
+ phy_params
226
+ Normalized physical parameters.
227
+ dy_list
228
+ List of dynamic parameters.
229
+
230
+ Returns
231
+ -------
232
+ dict
233
+ Dictionary of descaled physical parameters.
234
+ """
235
+ nsteps = phy_params.shape[0]
236
+ ngrid = phy_params.shape[1]
237
+
238
+ param_dict = {}
239
+ pmat = torch.ones([1, ngrid, 1]) * self.dy_drop
240
+ for i, name in enumerate(self.parameter_bounds.keys()):
241
+ staPar = phy_params[-1, :, i, :].unsqueeze(0).repeat([nsteps, 1, 1])
242
+ if name in dy_list:
243
+ dynPar = phy_params[:, :, i, :]
244
+ drmask = torch.bernoulli(pmat).detach_().to(self.device)
245
+ comPar = dynPar * (1 - drmask) + staPar * drmask
246
+ param_dict[name] = change_param_range(
247
+ param=comPar,
248
+ bounds=self.parameter_bounds[name],
249
+ )
250
+ else:
251
+ param_dict[name] = change_param_range(
252
+ param=staPar,
253
+ bounds=self.parameter_bounds[name],
254
+ )
255
+ return param_dict
256
+
257
+ def _descale_route_parameters(
258
+ self,
259
+ routing_params: torch.Tensor,
260
+ ) -> dict[str, torch.Tensor]:
261
+ """Descale routing parameters.
262
+
263
+ Parameters
264
+ ----------
265
+ routing_params
266
+ Normalized routing parameters.
267
+
268
+ Returns
269
+ -------
270
+ dict
271
+ Dictionary of descaled routing parameters.
272
+ """
273
+ parameter_dict = {}
274
+ for i, name in enumerate(self.routing_parameter_bounds.keys()):
275
+ param = routing_params[:, i]
276
+
277
+ parameter_dict[name] = change_param_range(
278
+ param=param,
279
+ bounds=self.routing_parameter_bounds[name],
280
+ )
281
+ return parameter_dict
282
+
283
+ def forward(
284
+ self,
285
+ x_dict: dict[str, torch.Tensor],
286
+ parameters: torch.Tensor,
287
+ ) -> Union[tuple, tuple[dict[str, torch.Tensor], tuple]]:
288
+ """Forward pass.
289
+
290
+ Parameters
291
+ ----------
292
+ x_dict
293
+ Dictionary of input forcing data.
294
+ parameters
295
+ Unprocessed, learned parameters from a neural network.
296
+
297
+ Returns
298
+ -------
299
+ Union[tuple, tuple[dict, tuple]]
300
+ Tuple or dictionary of model outputs.
301
+ """
302
+ # Unpack input data.
303
+ x = x_dict['x_phy']
304
+ self.muwts = x_dict.get('muwts', None)
305
+ ngrid = x.shape[1]
306
+
307
+ # Unpack parameters.
308
+ phy_params, routing_params = self._unpack_parameters(parameters)
309
+ if self.routing:
310
+ self.routing_param_dict = self._descale_route_parameters(routing_params)
311
+
312
+ # Initialization
313
+ if self.warm_up_states:
314
+ warm_up = self.warm_up
315
+ else:
316
+ # No state warm up: run the full model for warm_up days.
317
+ self.pred_cutoff = self.warm_up
318
+ warm_up = 0
319
+
320
+ if (not self.states) or (not self.cache_states):
321
+ current_states = self._init_states(ngrid)
322
+ else:
323
+ current_states = self.states
324
+
325
+ # Warm-up model states - run the model only on warm_up days first.
326
+ if warm_up > 0:
327
+ with torch.no_grad():
328
+ phy_param_warmup_dict = self._descale_phy_parameters(
329
+ phy_params[:warm_up, :, :],
330
+ dy_list=[],
331
+ )
332
+ # a. Save current model settings.
333
+ init_flag, route_flag = self.initialize, self.routing
334
+
335
+ # b. Set temporary model settings for warm-up.
336
+ self.initialize, self.routing = True, False
337
+
338
+ current_states = self._PBM(
339
+ x[:warm_up, :, :],
340
+ current_states,
341
+ phy_param_warmup_dict,
342
+ )
343
+
344
+ # c. Restore model settings.
345
+ self.initialize, self.routing = init_flag, route_flag
346
+
347
+ # Run the model for remainder of the simulation period.
348
+ phy_params_dict = self._descale_phy_parameters(
349
+ phy_params[warm_up:, :, :],
350
+ dy_list=self.dynamic_params,
351
+ )
352
+ fluxes, states = self._PBM(x[warm_up:, :, :], current_states, phy_params_dict)
353
+
354
+ # State caching
355
+ self._states_cache = [s.detach() for s in states]
356
+
357
+ if self.cache_states:
358
+ self.states = self._states_cache
359
+
360
+ return fluxes
361
+
362
+ def _PBM(
363
+ self,
364
+ forcing: torch.Tensor,
365
+ states: tuple,
366
+ full_param_dict: dict,
367
+ ) -> Union[tuple, dict[str, torch.Tensor]]:
368
+ """Run through process-based model (PBM).
369
+
370
+ Parameters
371
+ ----------
372
+ forcing
373
+ Input forcing data.
374
+ states
375
+ Initial model states.
376
+ full_param_dict
377
+ Dictionary of model parameters.
378
+
379
+ Returns
380
+ -------
381
+ Union[tuple, dict]
382
+ Tuple or dictionary of model outputs.
383
+ """
384
+ SNOWPACK, MELTWATER, SM, SUZ, SLZ = states
385
+
386
+ # Forcings
387
+ P = forcing[:, :, self.variables.index('prcp')] # Precipitation
388
+ T = forcing[:, :, self.variables.index('tmean')] # Mean air temp
389
+ PET = forcing[:, :, self.variables.index('pet')] # Potential ET
390
+ nsteps, ngrid = P.shape
391
+
392
+ # Expand dims to accomodate for nmul models.
393
+ Pm = P.unsqueeze(2).repeat(1, 1, self.nmul)
394
+ Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
395
+ PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)
396
+
397
+ # Apply correction factor to precipitation
398
+ # P = parPCORR.repeat(nsteps, 1) * P
399
+
400
+ # Initialize time series of model variables in shape [time, basins, nmul].
401
+ Qsimmu = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.001
402
+ Q0_sim = (
403
+ torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
404
+ )
405
+ Q1_sim = (
406
+ torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
407
+ )
408
+ Q2_sim = (
409
+ torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
410
+ )
411
+
412
+ AET = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
413
+ recharge_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
414
+ excs_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
415
+ evapfactor_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
416
+ tosoil_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
417
+ PERC_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
418
+ SWE_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
419
+ capillary_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
420
+
421
+ param_dict = {}
422
+ for t in range(nsteps):
423
+ # Get dynamic parameter values per timestep.
424
+ for key in full_param_dict.keys():
425
+ param_dict[key] = full_param_dict[key][t, :, :]
426
+
427
+ # Separate precipitation into liquid and solid components.
428
+ PRECIP = Pm[t, :, :]
429
+ RAIN = torch.mul(
430
+ PRECIP, (Tm[t, :, :] >= param_dict['parTT']).type(torch.float32)
431
+ )
432
+ SNOW = torch.mul(
433
+ PRECIP, (Tm[t, :, :] < param_dict['parTT']).type(torch.float32)
434
+ )
435
+
436
+ # Snow -------------------------------
437
+ SNOWPACK = SNOWPACK + SNOW
438
+ melt = param_dict['parCFMAX'] * (Tm[t, :, :] - param_dict['parTT'])
439
+ # melt[melt < 0.0] = 0.0
440
+ melt = torch.clamp(melt, min=0.0)
441
+ # melt[melt > SNOWPACK] = SNOWPACK[melt > SNOWPACK]
442
+ melt = torch.min(melt, SNOWPACK)
443
+ MELTWATER = MELTWATER + melt
444
+ SNOWPACK = SNOWPACK - melt
445
+ refreezing = (
446
+ param_dict['parCFR']
447
+ * param_dict['parCFMAX']
448
+ * (param_dict['parTT'] - Tm[t, :, :])
449
+ )
450
+ # refreezing[refreezing < 0.0] = 0.0
451
+ # refreezing[refreezing > MELTWATER] = MELTWATER[refreezing > MELTWATER]
452
+ refreezing = torch.clamp(refreezing, min=0.0)
453
+ refreezing = torch.min(refreezing, MELTWATER)
454
+ SNOWPACK = SNOWPACK + refreezing
455
+ MELTWATER = MELTWATER - refreezing
456
+ tosoil = MELTWATER - (param_dict['parCWH'] * SNOWPACK)
457
+ tosoil = torch.clamp(tosoil, min=0.0)
458
+ MELTWATER = MELTWATER - tosoil
459
+
460
+ # Soil and evaporation -------------------------------
461
+ soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
462
+ # soil_wetness[soil_wetness < 0.0] = 0.0
463
+ # soil_wetness[soil_wetness > 1.0] = 1.0
464
+ soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
465
+ recharge = (RAIN + tosoil) * soil_wetness
466
+
467
+ SM = SM + RAIN + tosoil - recharge
468
+
469
+ excess = SM - param_dict['parFC']
470
+ excess = torch.clamp(excess, min=0.0)
471
+ SM = SM - excess
472
+ # NOTE: Different from HBV 1.0. Add static/dynamicET shape parameter parBETAET.
473
+ evapfactor = (
474
+ SM / (param_dict['parLP'] * param_dict['parFC'])
475
+ ) ** param_dict['parBETAET']
476
+ evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
477
+ ETact = PETm[t, :, :] * evapfactor
478
+ ETact = torch.min(SM, ETact)
479
+ SM = torch.clamp(SM - ETact, min=self.nearzero)
480
+
481
+ # Capillary rise (HBV 1.1p mod) -------------------------------
482
+ capillary = torch.min(
483
+ SLZ,
484
+ param_dict['parC']
485
+ * SLZ
486
+ * (1.0 - torch.clamp(SM / param_dict['parFC'], max=1.0)),
487
+ )
488
+
489
+ SM = torch.clamp(SM + capillary, min=self.nearzero)
490
+ SLZ = torch.clamp(SLZ - capillary, min=self.nearzero)
491
+
492
+ # Groundwater boxes -------------------------------
493
+ SUZ = SUZ + recharge + excess
494
+ PERC = torch.min(SUZ, param_dict['parPERC'])
495
+ SUZ = SUZ - PERC
496
+ Q0 = param_dict['parK0'] * torch.clamp(SUZ - param_dict['parUZL'], min=0.0)
497
+ SUZ = SUZ - Q0
498
+ Q1 = param_dict['parK1'] * SUZ
499
+ SUZ = SUZ - Q1
500
+ SLZ = SLZ + PERC
501
+ Q2 = param_dict['parK2'] * SLZ
502
+ SLZ = SLZ - Q2
503
+
504
+ Qsimmu[t, :, :] = Q0 + Q1 + Q2
505
+ Q0_sim[t, :, :] = Q0
506
+ Q1_sim[t, :, :] = Q1
507
+ Q2_sim[t, :, :] = Q2
508
+ AET[t, :, :] = ETact
509
+ SWE_sim[t, :, :] = SNOWPACK
510
+ capillary_sim[t, :, :] = capillary
511
+
512
+ recharge_sim[t, :, :] = recharge
513
+ excs_sim[t, :, :] = excess
514
+ evapfactor_sim[t, :, :] = evapfactor
515
+ tosoil_sim[t, :, :] = tosoil
516
+ PERC_sim[t, :, :] = PERC
517
+
518
+ # Get the average or weighted average using learned weights.
519
+ if self.muwts is None:
520
+ Qsimavg = Qsimmu.mean(-1)
521
+ else:
522
+ Qsimavg = (Qsimmu * self.muwts).sum(-1)
523
+
524
+ # Run routing
525
+ if self.routing:
526
+ # Routing for all components or just the average.
527
+ if self.comprout:
528
+ # All components; reshape to [time, gages * num models]
529
+ Qsim = Qsimmu.view(nsteps, ngrid * self.nmul)
530
+ else:
531
+ # Average, then do routing.
532
+ Qsim = Qsimavg
533
+
534
+ UH = uh_gamma(
535
+ self.routing_param_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
536
+ self.routing_param_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
537
+ lenF=15,
538
+ )
539
+ rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0]) # [gages,vars,time]
540
+ UH = UH.permute([1, 2, 0]) # [gages,vars,time]
541
+ Qsrout = uh_conv(rf, UH).permute([2, 0, 1])
542
+
543
+ # Routing individually for Q0, Q1, and Q2, all w/ dims [gages,vars,time].
544
+ rf_Q0 = Q0_sim.mean(-1, keepdim=True).permute([1, 2, 0])
545
+ Q0_rout = uh_conv(rf_Q0, UH).permute([2, 0, 1])
546
+ rf_Q1 = Q1_sim.mean(-1, keepdim=True).permute([1, 2, 0])
547
+ Q1_rout = uh_conv(rf_Q1, UH).permute([2, 0, 1])
548
+ rf_Q2 = Q2_sim.mean(-1, keepdim=True).permute([1, 2, 0])
549
+ Q2_rout = uh_conv(rf_Q2, UH).permute([2, 0, 1])
550
+
551
+ if self.comprout:
552
+ # Qs is now shape [time, [gages*num models], vars]
553
+ Qstemp = Qsrout.view(nsteps, ngrid, self.nmul)
554
+ if self.muwts is None:
555
+ Qs = Qstemp.mean(-1, keepdim=True)
556
+ else:
557
+ Qs = (Qstemp * self.muwts).sum(-1, keepdim=True)
558
+ else:
559
+ Qs = Qsrout
560
+
561
+ else:
562
+ # No routing, only output the average of all model sims.
563
+ Qs = torch.unsqueeze(Qsimavg, -1)
564
+ Q0_rout = Q1_rout = Q2_rout = None
565
+
566
+ states = (SNOWPACK, MELTWATER, SM, SUZ, SLZ)
567
+
568
+ if self.initialize:
569
+ # If initialize is True, only return warmed-up storages.
570
+ return states
571
+ else:
572
+ # Baseflow index (BFI) calculation
573
+ BFI_sim = (
574
+ 100
575
+ * (torch.sum(Q2_rout, dim=0) / (torch.sum(Qs, dim=0) + self.nearzero))[
576
+ :, 0
577
+ ]
578
+ )
579
+
580
+ # Return all sim results.
581
+ flux_dict = {
582
+ 'streamflow': Qs, # Routed Streamflow
583
+ 'srflow': Q0_rout, # Routed surface runoff
584
+ 'ssflow': Q1_rout, # Routed subsurface flow
585
+ 'gwflow': Q2_rout, # Routed groundwater flow
586
+ 'AET_hydro': AET.mean(-1, keepdim=True), # Actual ET
587
+ 'PET_hydro': PETm.mean(-1, keepdim=True), # Potential ET
588
+ 'SWE': SWE_sim.mean(-1, keepdim=True), # Snow water equivalent
589
+ 'streamflow_no_rout': Qsim.unsqueeze(dim=2), # Streamflow
590
+ 'srflow_no_rout': Q0_sim.mean(-1, keepdim=True), # Surface runoff
591
+ 'ssflow_no_rout': Q1_sim.mean(-1, keepdim=True), # Subsurface flow
592
+ 'gwflow_no_rout': Q2_sim.mean(-1, keepdim=True), # Groundwater flow
593
+ 'recharge': recharge_sim.mean(-1, keepdim=True), # Recharge
594
+ 'excs': excs_sim.mean(-1, keepdim=True), # Excess stored water
595
+ 'evapfactor': evapfactor_sim.mean(
596
+ -1, keepdim=True
597
+ ), # Evaporation factor
598
+ 'tosoil': tosoil_sim.mean(-1, keepdim=True), # Infiltration
599
+ 'percolation': PERC_sim.mean(-1, keepdim=True), # Percolation
600
+ 'capillary': capillary_sim.mean(-1, keepdim=True), # Capillary rise
601
+ 'BFI': BFI_sim, # Baseflow index
602
+ }
603
+
604
+ if not self.warm_up_states:
605
+ for key in flux_dict.keys():
606
+ if key != 'BFI':
607
+ flux_dict[key] = flux_dict[key][self.pred_cutoff :, :, :]
608
+ return flux_dict, states