hydrodl2 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,596 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ import torch
4
+
5
+ from hydrodl2.core.calc import change_param_range, uh_conv, uh_gamma
6
+
7
+
8
class Hbv(torch.nn.Module):
    """HBV 1.0.

    Multi-component, differentiable PyTorch HBV model with option to run without
    internal state warmup.

    Authors
    -------
    -   Farshid Rahmani, Yalan Song, Leo Lonzarich
    -   (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
    -   (HBV-light Version 2) Seibert, 2005
        (https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).

    Publication
    -----------
    -   Dapeng Feng, Jiangtao Liu, Kathryn Lawson, Chaopeng Shen. "Differentiable,
        learnable, regionalized process-based models with multiphysical outputs
        can approach state-of-the-art hydrologic prediction accuracy." Water
        Resources Research (2022), 58, e2022WR032404.
        https://doi.org/10.1029/2022WR032404.

    Parameters
    ----------
    config
        Configuration dictionary. Recognised keys are ``warm_up``,
        ``warm_up_states``, ``dy_drop``, ``dynamic_params`` (a mapping keyed by
        class name), ``variables``, ``routing``, ``comprout``, ``nearzero``,
        ``nmul`` and ``cache_states``. Missing keys fall back to the defaults
        set in ``__init__``.
    device
        Device to run the model on. When omitted, CUDA is used if available,
        otherwise the CPU.
    """

    def __init__(
        self,
        config: Optional[dict[str, Any]] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.name = 'HBV 1.0'
        self.config = config
        self.initialize = False
        self.warm_up = 0
        self.pred_cutoff = 0
        self.warm_up_states = True
        self.dynamic_params = []
        self.dy_drop = 0.0
        self.variables = ['prcp', 'tmean', 'pet']
        self.routing = True
        self.comprout = False
        self.nearzero = 1e-5
        self.nmul = 1
        self.cache_states = False
        self.device = device
        # Optional multi-model weights; overwritten from `x_dict` in `forward`.
        # Defined here so `_PBM` can also be called on its own.
        self.muwts = None

        self.states, self._states_cache = None, None

        self.state_names = [
            'SNOWPACK',  # Snowpack storage
            'MELTWATER',  # Meltwater storage
            'SM',  # Soil moisture storage
            'SUZ',  # Upper groundwater storage
            'SLZ',  # Lower groundwater storage
        ]
        self.flux_names = [
            'streamflow',  # Routed Streamflow
            'srflow',  # Routed surface runoff
            'ssflow',  # Routed subsurface flow
            'gwflow',  # Routed groundwater flow
            'AET_hydro',  # Actual ET
            'PET_hydro',  # Potential ET
            'SWE',  # Snow water equivalent
            'streamflow_no_rout',  # Streamflow
            'srflow_no_rout',  # Surface runoff
            'ssflow_no_rout',  # Subsurface flow
            'gwflow_no_rout',  # Groundwater flow
            'recharge',  # Recharge
            'excs',  # Excess stored water
            'evapfactor',  # Evaporation factor
            'tosoil',  # Infiltration
            'percolation',  # Percolation
            'BFI',  # Baseflow index
        ]

        self.parameter_bounds = {
            'parBETA': [1.0, 6.0],
            'parFC': [50, 1000],
            'parK0': [0.05, 0.9],
            'parK1': [0.01, 0.5],
            'parK2': [0.001, 0.2],
            'parLP': [0.2, 1],
            'parPERC': [0, 10],
            'parUZL': [0, 100],
            'parTT': [-2.5, 2.5],
            'parCFMAX': [0.5, 10],
            'parCFR': [0, 0.1],
            'parCWH': [0, 0.2],
        }
        self.routing_parameter_bounds = {
            'route_a': [0, 2.9],
            'route_b': [0, 6.5],
        }

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if config is not None:
            # Overwrite defaults with config values.
            self.warm_up = config.get('warm_up', self.warm_up)
            self.warm_up_states = config.get('warm_up_states', self.warm_up_states)
            self.dy_drop = config.get('dy_drop', self.dy_drop)
            # 'dynamic_params' is optional like every other key; a missing or
            # empty entry keeps the default (no dynamic parameters).
            self.dynamic_params = (config.get('dynamic_params') or {}).get(
                self.__class__.__name__, self.dynamic_params
            )
            self.variables = config.get('variables', self.variables)
            self.routing = config.get('routing', self.routing)
            self.comprout = config.get('comprout', self.comprout)
            self.nearzero = config.get('nearzero', self.nearzero)
            self.nmul = config.get('nmul', self.nmul)
            self.cache_states = config.get('cache_states', False)

        # parBETAET is only a model parameter when it is requested as dynamic.
        if 'parBETAET' in self.dynamic_params:
            self.parameter_bounds['parBETAET'] = [0.3, 5]

        # Always derive parameter names/counts, including when config is None.
        self._set_parameters()

    def _init_states(self, ngrid: int) -> tuple[torch.Tensor, ...]:
        """Create default initial model states.

        Parameters
        ----------
        ngrid
            Number of grid cells / basins.

        Returns
        -------
        tuple[torch.Tensor, ...]
            One float32 tensor of shape ``(ngrid, nmul)`` per entry in
            ``state_names``, each filled with 0.001 on the model device.
        """
        return tuple(
            torch.full(
                (ngrid, self.nmul), 0.001, dtype=torch.float32, device=self.device
            )
            for _ in self.state_names
        )

    def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
        """Return internal model states.

        Returns
        -------
        tuple[torch.Tensor, ...]
            A tuple containing the detached states (SNOWPACK, MELTWATER, SM,
            SUZ, SLZ) from the most recent forward pass, or None if the model
            has not been run yet.
        """
        return self._states_cache

    def load_states(
        self,
        states: tuple[torch.Tensor, ...],
    ) -> None:
        """Load internal model states and set to model device and type.

        Parameters
        ----------
        states
            A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).

        Raises
        ------
        ValueError
            If `states` is not a tuple of the expected length, or if any
            element is not a tensor.
        """
        nstates = len(self.state_names)
        # Check the container first so non-iterables give a ValueError too.
        if not (isinstance(states, tuple) and len(states) == nstates):
            raise ValueError(f"`states` must be a tuple of {nstates} tensors.")
        if not all(isinstance(state, torch.Tensor) for state in states):
            raise ValueError("Each element in `states` must be a tensor.")

        self.states = tuple(
            s.detach().to(self.device, dtype=torch.float32) for s in states
        )

    def _set_parameters(self) -> None:
        """Record parameter names and the number of learnable parameters.

        Physical parameters are counted once per sub-model (``nmul``); routing
        parameters are counted once and only when routing is enabled.
        """
        self.phy_param_names = self.parameter_bounds.keys()
        if self.routing:
            self.routing_param_names = self.routing_parameter_bounds.keys()
        else:
            self.routing_param_names = []

        self.learnable_param_count = len(self.phy_param_names) * self.nmul + len(
            self.routing_param_names
        )

    def _unpack_parameters(
        self,
        parameters: torch.Tensor,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Extract physical model and routing parameters from NN output.

        Parameters
        ----------
        parameters
            Unprocessed, learned parameters from a neural network. Indexed as
            a 3-D tensor whose last axis holds ``n_phy * nmul`` physical
            values followed by the routing values.

        Returns
        -------
        tuple[torch.Tensor, Optional[torch.Tensor]]
            Sigmoid-normalized physical parameters reshaped to
            ``(dim0, dim1, n_phy, nmul)``, and sigmoid-normalized routing
            parameters taken from the last index of the first axis (None when
            routing is disabled).
        """
        phy_param_count = len(self.parameter_bounds)
        n_phy_values = phy_param_count * self.nmul

        # Physical parameters
        phy_params = torch.sigmoid(parameters[:, :, :n_phy_values]).view(
            parameters.shape[0],
            parameters.shape[1],
            phy_param_count,
            self.nmul,
        )
        # Routing parameters
        routing_params = None
        if self.routing:
            routing_params = torch.sigmoid(parameters[-1, :, n_phy_values:])
        return (phy_params, routing_params)

    def _descale_phy_parameters(
        self,
        phy_params: torch.Tensor,
        dy_list: list,
    ) -> dict[str, torch.Tensor]:
        """Descale physical parameters.

        Parameters
        ----------
        phy_params
            Normalized physical parameters, indexed as
            ``[step, grid, parameter, nmul]``.
        dy_list
            List of dynamic parameters.

        Returns
        -------
        dict
            Dictionary of descaled physical parameters. Parameters not in
            `dy_list` use the value from the last step repeated over all steps.
        """
        nsteps = phy_params.shape[0]
        ngrid = phy_params.shape[1]

        param_dict = {}
        # Per-grid probability of replacing a dynamic parameter by its static
        # value; the mask broadcasts over steps and sub-models.
        pmat = torch.ones([1, ngrid, 1]) * self.dy_drop
        for i, (name, bounds) in enumerate(self.parameter_bounds.items()):
            staPar = phy_params[-1, :, i, :].unsqueeze(0).repeat([nsteps, 1, 1])
            if name in dy_list:
                dynPar = phy_params[:, :, i, :]
                drmask = torch.bernoulli(pmat).detach_().to(self.device)
                # Where the mask is 1 the static value is used instead.
                scaled = dynPar * (1 - drmask) + staPar * drmask
            else:
                scaled = staPar
            param_dict[name] = change_param_range(param=scaled, bounds=bounds)
        return param_dict

    def _descale_route_parameters(
        self,
        routing_params: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Descale routing parameters.

        Parameters
        ----------
        routing_params
            Normalized routing parameters; column ``i`` holds the i-th entry
            of ``routing_parameter_bounds``.

        Returns
        -------
        dict
            Dictionary of descaled routing parameters.
        """
        return {
            name: change_param_range(param=routing_params[:, i], bounds=bounds)
            for i, (name, bounds) in enumerate(self.routing_parameter_bounds.items())
        }

    def forward(
        self,
        x_dict: dict[str, torch.Tensor],
        parameters: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Forward pass.

        Parameters
        ----------
        x_dict
            Dictionary of input forcing data. Must contain ``'x_phy'``; may
            contain ``'muwts'`` (weights for combining sub-models).
        parameters
            Unprocessed, learned parameters from a neural network.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary of model fluxes keyed by the names in ``flux_names``.
            The final states are detached and kept for `get_states`.
        """
        # Unpack input data.
        x = x_dict['x_phy']
        self.muwts = x_dict.get('muwts', None)
        ngrid = x.shape[1]

        # Unpack parameters.
        phy_params, routing_params = self._unpack_parameters(parameters)
        if self.routing:
            self.routing_param_dict = self._descale_route_parameters(routing_params)

        # Initialization
        if self.warm_up_states:
            warm_up = self.warm_up
        else:
            # No state warm up: run the full model for warm_up days.
            self.pred_cutoff = self.warm_up
            warm_up = 0

        if (not self.states) or (not self.cache_states):
            current_states = self._init_states(ngrid)
        else:
            current_states = self.states

        # Warm-up model states - run the model only on warm_up days first.
        if warm_up > 0:
            with torch.no_grad():
                phy_param_warmup_dict = self._descale_phy_parameters(
                    phy_params[:warm_up, :, :],
                    dy_list=[],
                )
                # a. Save current model settings.
                init_flag, route_flag = self.initialize, self.routing

                # b. Set temporary model settings for warm-up.
                self.initialize, self.routing = True, False
                try:
                    current_states = self._PBM(
                        x[:warm_up, :, :],
                        current_states,
                        phy_param_warmup_dict,
                    )
                finally:
                    # c. Restore model settings even if the warm-up run fails,
                    # so the model is not left stuck in warm-up mode.
                    self.initialize, self.routing = init_flag, route_flag

        # Run the model for remainder of the simulation period.
        phy_params_dict = self._descale_phy_parameters(
            phy_params[warm_up:, :, :],
            dy_list=self.dynamic_params,
        )
        fluxes, states = self._PBM(x[warm_up:, :, :], current_states, phy_params_dict)

        # State caching. Stored as a tuple so the result of `get_states` can
        # be passed straight back to `load_states`.
        self._states_cache = tuple(s.detach() for s in states)

        if self.cache_states:
            self.states = self._states_cache

        return fluxes

    def _PBM(
        self,
        forcing: torch.Tensor,
        states: tuple,
        full_param_dict: dict,
    ) -> Union[tuple, tuple[dict[str, torch.Tensor], tuple]]:
        """Run through process-based model (PBM).

        Parameters
        ----------
        forcing
            Input forcing data, indexed as ``[step, grid, variable]`` with the
            variable order given by ``self.variables``.
        states
            Initial model states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
        full_param_dict
            Dictionary of model parameters, each indexed by step on its
            first axis.

        Returns
        -------
        Union[tuple, tuple[dict, tuple]]
            When ``self.initialize`` is True, only the tuple of final states.
            Otherwise ``(flux_dict, states)``. When routing is disabled the
            "routed" entries hold the unrouted values.
        """
        SNOWPACK, MELTWATER, SM, SUZ, SLZ = states

        # Forcings
        P = forcing[:, :, self.variables.index('prcp')]  # Precipitation
        T = forcing[:, :, self.variables.index('tmean')]  # Mean air temp
        PET = forcing[:, :, self.variables.index('pet')]  # Potential ET
        nsteps, ngrid = P.shape

        # Expand dims to accommodate nmul models.
        Pm = P.unsqueeze(2).repeat(1, 1, self.nmul)
        Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
        PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

        def series(fill: float = 0.0) -> torch.Tensor:
            # Time series buffer in shape [time, basins, nmul].
            return torch.full(
                Pm.size(), fill, dtype=torch.float32, device=self.device
            )

        # Initialize time series of model variables.
        Qsimmu = series(0.001)
        Q0_sim = series(0.0001)
        Q1_sim = series(0.0001)
        Q2_sim = series(0.0001)

        AET = series()
        recharge_sim = series()
        excs_sim = series()
        evapfactor_sim = series()
        tosoil_sim = series()
        PERC_sim = series()
        SWE_sim = series()

        for t in range(nsteps):
            # Get dynamic parameter values per timestep.
            param_dict = {key: value[t] for key, value in full_param_dict.items()}
            temp = Tm[t]
            par_tt = param_dict['parTT']

            # Separate precipitation into liquid and solid components.
            PRECIP = Pm[t]
            RAIN = torch.mul(PRECIP, (temp >= par_tt).type(torch.float32))
            SNOW = torch.mul(PRECIP, (temp < par_tt).type(torch.float32))

            # Snow -------------------------------
            SNOWPACK = SNOWPACK + SNOW
            melt = param_dict['parCFMAX'] * (temp - par_tt)
            melt = torch.clamp(melt, min=0.0)
            # Cannot melt more than the snowpack holds.
            melt = torch.min(melt, SNOWPACK)
            MELTWATER = MELTWATER + melt
            SNOWPACK = SNOWPACK - melt
            refreezing = (
                param_dict['parCFR'] * param_dict['parCFMAX'] * (par_tt - temp)
            )
            refreezing = torch.clamp(refreezing, min=0.0)
            # Cannot refreeze more than the available meltwater.
            refreezing = torch.min(refreezing, MELTWATER)
            SNOWPACK = SNOWPACK + refreezing
            MELTWATER = MELTWATER - refreezing
            tosoil = MELTWATER - (param_dict['parCWH'] * SNOWPACK)
            tosoil = torch.clamp(tosoil, min=0.0)
            MELTWATER = MELTWATER - tosoil

            # Soil and evaporation -------------------------------
            soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
            soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
            recharge = (RAIN + tosoil) * soil_wetness

            SM = SM + RAIN + tosoil - recharge

            excess = SM - param_dict['parFC']
            excess = torch.clamp(excess, min=0.0)
            SM = SM - excess
            # parBETAET only has effect when it is a dynamic parameter.
            evapfactor = SM / (param_dict['parLP'] * param_dict['parFC'])
            if 'parBETAET' in param_dict:
                evapfactor = evapfactor ** param_dict['parBETAET']
            evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
            ETact = PETm[t] * evapfactor
            ETact = torch.min(SM, ETact)
            # Keep soil moisture strictly positive.
            SM = torch.clamp(SM - ETact, min=self.nearzero)

            # Groundwater boxes -------------------------------
            SUZ = SUZ + recharge + excess
            PERC = torch.min(SUZ, param_dict['parPERC'])
            SUZ = SUZ - PERC
            Q0 = param_dict['parK0'] * torch.clamp(SUZ - param_dict['parUZL'], min=0.0)
            SUZ = SUZ - Q0
            Q1 = param_dict['parK1'] * SUZ
            SUZ = SUZ - Q1
            SLZ = SLZ + PERC
            Q2 = param_dict['parK2'] * SLZ
            SLZ = SLZ - Q2

            Qsimmu[t, :, :] = Q0 + Q1 + Q2
            Q0_sim[t, :, :] = Q0
            Q1_sim[t, :, :] = Q1
            Q2_sim[t, :, :] = Q2
            AET[t, :, :] = ETact
            SWE_sim[t, :, :] = SNOWPACK

            recharge_sim[t, :, :] = recharge
            excs_sim[t, :, :] = excess
            evapfactor_sim[t, :, :] = evapfactor
            tosoil_sim[t, :, :] = tosoil
            PERC_sim[t, :, :] = PERC

        states = (SNOWPACK, MELTWATER, SM, SUZ, SLZ)

        if self.initialize:
            # If initialize is True, only return warmed-up storages; routing
            # and flux assembly are skipped because they would be discarded.
            return states

        # Get the average or weighted average using learned weights.
        if self.muwts is None:
            Qsimavg = Qsimmu.mean(-1)
        else:
            Qsimavg = (Qsimmu * self.muwts).sum(-1)

        # Run routing
        if self.routing:
            # Routing for all components or just the average.
            if self.comprout:
                # All components; reshape to [time, gages * num models]
                Qsim = Qsimmu.view(nsteps, ngrid * self.nmul)
            else:
                # Average, then do routing.
                Qsim = Qsimavg

            UH = uh_gamma(
                self.routing_param_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
                self.routing_param_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
                lenF=15,
            )
            rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0])  # [gages,vars,time]
            UH = UH.permute([1, 2, 0])  # [gages,vars,time]
            Qsrout = uh_conv(rf, UH).permute([2, 0, 1])

            # Routing individually for Q0, Q1, and Q2, all w/ dims [gages,vars,time].
            rf_Q0 = Q0_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q0_rout = uh_conv(rf_Q0, UH).permute([2, 0, 1])
            rf_Q1 = Q1_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q1_rout = uh_conv(rf_Q1, UH).permute([2, 0, 1])
            rf_Q2 = Q2_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q2_rout = uh_conv(rf_Q2, UH).permute([2, 0, 1])

            if self.comprout:
                # Qs is now shape [time, [gages*num models], vars]
                Qstemp = Qsrout.view(nsteps, ngrid, self.nmul)
                if self.muwts is None:
                    Qs = Qstemp.mean(-1, keepdim=True)
                else:
                    Qs = (Qstemp * self.muwts).sum(-1, keepdim=True)
            else:
                Qs = Qsrout
        else:
            # No routing: output the average of all model sims and pass the
            # unrouted component means through as the "routed" fluxes so the
            # flux dictionary and BFI remain well defined.
            Qsim = Qsimavg
            Qs = torch.unsqueeze(Qsimavg, -1)
            Q0_rout = Q0_sim.mean(-1, keepdim=True)
            Q1_rout = Q1_sim.mean(-1, keepdim=True)
            Q2_rout = Q2_sim.mean(-1, keepdim=True)

        # Baseflow index (BFI) calculation: percentage of total simulated
        # flow contributed by the Q2 (groundwater) component.
        BFI_sim = (
            100
            * (torch.sum(Q2_rout, dim=0) / (torch.sum(Qs, dim=0) + self.nearzero))[
                :, 0
            ]
        )

        # Return all sim results.
        flux_dict = {
            'streamflow': Qs,  # Routed Streamflow
            'srflow': Q0_rout,  # Routed surface runoff
            'ssflow': Q1_rout,  # Routed subsurface flow
            'gwflow': Q2_rout,  # Routed groundwater flow
            'AET_hydro': AET.mean(-1, keepdim=True),  # Actual ET
            'PET_hydro': PETm.mean(-1, keepdim=True),  # Potential ET
            'SWE': SWE_sim.mean(-1, keepdim=True),  # Snow water equivalent
            'streamflow_no_rout': Qsim.unsqueeze(dim=2),  # Streamflow
            'srflow_no_rout': Q0_sim.mean(-1, keepdim=True),  # Surface runoff
            'ssflow_no_rout': Q1_sim.mean(-1, keepdim=True),  # Subsurface flow
            'gwflow_no_rout': Q2_sim.mean(-1, keepdim=True),  # Groundwater flow
            'recharge': recharge_sim.mean(-1, keepdim=True),  # Recharge
            'excs': excs_sim.mean(-1, keepdim=True),  # Excess stored water
            'evapfactor': evapfactor_sim.mean(-1, keepdim=True),  # Evaporation factor
            'tosoil': tosoil_sim.mean(-1, keepdim=True),  # Infiltration
            'percolation': PERC_sim.mean(-1, keepdim=True),  # Percolation
            'BFI': BFI_sim,  # Baseflow index
        }

        if not self.warm_up_states:
            # Without state warm-up the first `pred_cutoff` steps act as the
            # warm-up period and are trimmed from every time series.
            for key in flux_dict:
                if key != 'BFI':
                    flux_dict[key] = flux_dict[key][self.pred_cutoff :, :, :]
        return flux_dict, states