hydrodl2 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,670 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ import torch
4
+
5
+ from hydrodl2.core.calc import change_param_range, uh_conv, uh_gamma
6
+
7
+
8
class Hbv_2(torch.nn.Module):
    """HBV 2.0.

    Multi-component, multi-scale, differentiable PyTorch HBV model with rainfall
    runoff simulation on unit basins.

    Authors
    -------
    - Yalan Song, Leo Lonzarich, Wencong Yang
    - (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
    - (HBV-light Version 2) Seibert, 2005
      (https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).

    Publication
    -----------
    - Yalan Song, Tadd Bindas, Chaopeng Shen, et al. "High-resolution
      national-scale water modeling is enhanced by multiscale differentiable
      physics-informed machine learning." Water Resources Research (2025).
      https://doi.org/10.1029/2024WR038928.

    Parameters
    ----------
    config
        Configuration dictionary. Recognized keys: ``warm_up``,
        ``warm_up_states``, ``dy_drop``, ``dynamic_params`` (a mapping from
        model class name to the list of dynamic parameter names),
        ``variables``, ``routing``, ``comprout``, ``nearzero``, ``nmul`` and
        ``cache_states``. Missing keys keep their defaults.
    device
        Device to run the model on. When falsy, CUDA is used if available,
        otherwise CPU.
    """

    def __init__(
        self,
        config: Optional[dict[str, Any]] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.name = 'HBV 2.0'
        self.config = config
        self.initialize = False
        self.warm_up = 0
        self.pred_cutoff = 0
        self.warm_up_states = True
        self.dynamic_params = []
        self.dy_drop = 0.0
        self.variables = ['prcp', 'tmean', 'pet']
        self.routing = False
        self.lenF = 15
        self.comprout = False
        self.muwts = None
        self.nearzero = 1e-5
        self.nmul = 1
        self.cache_states = False
        self.device = device

        # `states` holds the initial states for the next forward call (used
        # only when cache_states is True); `_state_cache` holds the full state
        # time series produced by the most recent forward call.
        self.states, self._state_cache = None, None

        self.state_names = [
            'SNOWPACK',  # Snowpack storage
            'MELTWATER',  # Meltwater storage
            'SM',  # Soil moisture storage
            'SUZ',  # Upper groundwater storage
            'SLZ',  # Lower groundwater storage
        ]
        self.flux_names = [
            'streamflow',  # Routed Streamflow
            'srflow',  # Routed surface runoff
            'ssflow',  # Routed subsurface flow
            'gwflow',  # Routed groundwater flow
            'AET_hydro',  # Actual ET
            'PET_hydro',  # Potential ET
            'SWE',  # Snow water equivalent
            'streamflow_no_rout',  # Streamflow
            'srflow_no_rout',  # Surface runoff
            'ssflow_no_rout',  # Subsurface flow
            'gwflow_no_rout',  # Groundwater flow
            'recharge',  # Recharge
            'excs',  # Excess stored water
            'evapfactor',  # Evaporation factor
            'tosoil',  # Infiltration
            'percolation',  # Percolation
            'capillary',  # Capillary rise
            'BFI',  # Baseflow index
        ]

        # [lower, upper] bounds used to descale normalized NN outputs.
        self.parameter_bounds = {
            'parBETA': [1.0, 6.0],
            'parFC': [50, 1000],
            'parK0': [0.05, 0.9],
            'parK1': [0.01, 0.5],
            'parK2': [0.001, 0.2],
            'parLP': [0.2, 1],
            'parPERC': [0, 10],
            'parUZL': [0, 100],
            'parTT': [-2.5, 2.5],
            'parCFMAX': [0.5, 10],
            'parCFR': [0, 0.1],
            'parCWH': [0, 0.2],
            'parBETAET': [0.3, 5],
            'parC': [0, 1],
            'parRT': [0, 20],
            'parAC': [0, 2500],
        }
        self.routing_parameter_bounds = {
            'route_a': [0, 2.9],
            'route_b': [0, 6.5],
        }

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if config is not None:
            # Overwrite defaults with config values.
            self.warm_up = config.get('warm_up', self.warm_up)
            self.warm_up_states = config.get('warm_up_states', self.warm_up_states)
            self.dy_drop = config.get('dy_drop', self.dy_drop)
            # A config without (or with an empty) 'dynamic_params' section
            # means "no dynamic parameters" rather than a KeyError.
            dynamic_cfg = config.get('dynamic_params') or {}
            self.dynamic_params = dynamic_cfg.get(
                self.__class__.__name__, self.dynamic_params
            )
            self.variables = config.get('variables', self.variables)
            self.routing = config.get('routing', self.routing)
            self.comprout = config.get('comprout', self.comprout)
            self.nearzero = config.get('nearzero', self.nearzero)
            self.nmul = config.get('nmul', self.nmul)
            self.cache_states = config.get('cache_states', self.cache_states)
        self._set_parameters()

    def _init_states(self, ngrid: int) -> tuple[torch.Tensor, ...]:
        """Create fresh initial model states.

        Returns one float32 tensor of shape ``(ngrid, nmul)`` per entry in
        ``self.state_names``. Each tensor is filled with 0.001, a small positive
        value rather than zero.
        """
        return tuple(
            torch.full(
                (ngrid, self.nmul), 0.001, dtype=torch.float32, device=self.device
            )
            for _ in self.state_names
        )

    def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
        """Return the state time series from the most recent forward pass.

        Returns
        -------
        tuple[torch.Tensor, ...] or None
            ``(SNOWPACK, MELTWATER, SM, SUZ, SLZ)``, each with shape
            ``[time, grid, nmul]`` (every timestep, not just the last one).
            Returns None before the first forward call.
        """
        return self._state_cache

    def load_states(
        self,
        states: tuple[torch.Tensor, ...],
    ) -> None:
        """Load initial model states, detached and cast to the model device/dtype.

        Loaded states are only used by ``forward`` when ``cache_states`` is True.

        Parameters
        ----------
        states
            A tuple ``(SNOWPACK, MELTWATER, SM, SUZ, SLZ)`` of tensors. When
            passed on to ``forward``, the states are combined elementwise with
            ``[grid, nmul]`` tensors.

        Raises
        ------
        ValueError
            If ``states`` is not a tuple of the right length, or an element is
            not a tensor.
        """
        nstates = len(self.state_names)
        # Validate the container first so non-iterables get a clear error.
        if not (isinstance(states, tuple) and len(states) == nstates):
            raise ValueError(f"`states` must be a tuple of {nstates} tensors.")
        if not all(isinstance(state, torch.Tensor) for state in states):
            raise ValueError("Each element in `states` must be a tensor.")

        self.states = tuple(
            s.detach().to(self.device, dtype=torch.float32) for s in states
        )

    def _set_parameters(self) -> None:
        """Set parameter names and the number of NN outputs the model expects."""
        self.phy_param_names = self.parameter_bounds.keys()
        if self.routing:
            self.routing_param_names = self.routing_parameter_bounds.keys()
        else:
            self.routing_param_names = []

        # Dynamic parameters: one per nmul model, at every timestep.
        self.learnable_param_count1 = len(self.dynamic_params) * self.nmul
        # Static parameters (one per nmul model) plus one set of routing
        # parameters shared by all nmul models.
        self.learnable_param_count2 = (
            len(self.phy_param_names) - len(self.dynamic_params)
        ) * self.nmul + len(self.routing_param_names)
        self.learnable_param_count = (
            self.learnable_param_count1 + self.learnable_param_count2
        )

    def _unpack_parameters(
        self,
        parameters: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Extract physical model and routing parameters from NN output.

        Parameters
        ----------
        parameters
            Two-element sequence of unprocessed, learned parameters.
            ``parameters[0]`` is reshaped to
            ``[dim0, dim1, n_dynamic, nmul]``. ``parameters[1]`` holds the
            static parameters in its first ``n_static * nmul`` columns,
            followed by routing parameters when routing is enabled.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
            Dynamic physical, static physical, and routing parameters. Routing
            parameters are None when routing is disabled.
        """
        phy_param_count = len(self.parameter_bounds)
        dy_param_count = len(self.dynamic_params)
        dif_count = phy_param_count - dy_param_count

        # Physical dynamic parameters. `reshape` also tolerates
        # non-contiguous input.
        phy_dy_params = parameters[0].reshape(
            parameters[0].shape[0],
            parameters[0].shape[1],
            dy_param_count,
            self.nmul,
        )

        # Physical static parameters
        phy_static_params = parameters[1][:, : dif_count * self.nmul].reshape(
            parameters[1].shape[0],
            dif_count,
            self.nmul,
        )

        # Routing parameters
        routing_params = None
        if self.routing:
            routing_params = parameters[1][:, dif_count * self.nmul :]

        return (phy_dy_params, phy_static_params, routing_params)

    def _descale_phy_dy_parameters(
        self,
        phy_dy_params: torch.Tensor,
        dy_list: list,
    ) -> dict[str, torch.Tensor]:
        """Descale dynamic physical parameters, with optional dynamic dropout.

        For each grid cell, a Bernoulli mask with probability ``self.dy_drop``
        replaces the time-varying value with the last-timestep value held
        constant over time.

        Parameters
        ----------
        phy_dy_params
            Normalized dynamic parameters, ``[nsteps, ngrid, n_dynamic, nmul]``.
        dy_list
            Names of the dynamic parameters, in the same order as axis 2.

        Returns
        -------
        dict
            Mapping of parameter name to descaled ``[nsteps, ngrid, nmul]``
            tensor.
        """
        nsteps = phy_dy_params.shape[0]
        ngrid = phy_dy_params.shape[1]

        # TODO: Fix; if dynamic parameters are not entered in config as they are
        # in HBV params list, then descaling mismatch will occur.
        param_dict = {}
        pmat = torch.ones([1, ngrid, 1]) * self.dy_drop
        for i, name in enumerate(dy_list):
            # Last-timestep value broadcast over all timesteps.
            staPar = phy_dy_params[-1, :, i, :].unsqueeze(0).repeat([nsteps, 1, 1])

            dynPar = phy_dy_params[:, :, i, :]
            drmask = torch.bernoulli(pmat).detach_().to(self.device)

            comPar = dynPar * (1 - drmask) + staPar * drmask
            param_dict[name] = change_param_range(
                param=comPar,
                bounds=self.parameter_bounds[name],
            )
        return param_dict

    def _descale_phy_stat_parameters(
        self,
        phy_stat_params: torch.Tensor,
        stat_list: list,
    ) -> dict[str, torch.Tensor]:
        """Descale static physical parameters.

        Parameters
        ----------
        phy_stat_params
            Normalized static parameters, ``[ngrid, n_static, nmul]``.
        stat_list
            Names of the static parameters, in the same order as axis 1.

        Returns
        -------
        dict
            Mapping of parameter name to descaled ``[ngrid, nmul]`` tensor.
        """
        parameter_dict = {}
        for i, name in enumerate(stat_list):
            param = phy_stat_params[:, i, :]

            parameter_dict[name] = change_param_range(
                param=param,
                bounds=self.parameter_bounds[name],
            )
        return parameter_dict

    def _descale_route_parameters(
        self,
        routing_params: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Descale routing parameters.

        Parameters
        ----------
        routing_params
            Normalized routing parameters, one column per entry in
            ``self.routing_parameter_bounds``.

        Returns
        -------
        dict
            Dictionary of descaled routing parameters (one value per row).
        """
        parameter_dict = {}
        for i, name in enumerate(self.routing_parameter_bounds.keys()):
            param = routing_params[:, i]

            parameter_dict[name] = change_param_range(
                param=param,
                bounds=self.routing_parameter_bounds[name],
            )
        return parameter_dict

    def forward(
        self,
        x_dict: dict[str, torch.Tensor],
        parameters: tuple[torch.Tensor, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Forward pass.

        Parameters
        ----------
        x_dict
            Dictionary of inputs. The following keys are read:
            - ``'x_phy'``: forcings indexed ``[time, grid, variable]``.
            - ``'ac_all'`` and ``'elev_all'``: presumably one value per grid
              cell; each is repeated across the ``nmul`` models.
            - ``'muwts'`` (optional): weights for combining the ``nmul``
              models.
        parameters
            Unprocessed, learned parameters from a neural network:
            ``parameters[0]`` is dynamic and ``parameters[1]`` is static,
            plus routing parameters when routing is enabled.

        Returns
        -------
        dict[str, torch.Tensor]
            Flux outputs keyed by ``self.flux_names``. The dict is empty when
            ``self.initialize`` is True.
        """
        # Unpack input data.
        x = x_dict['x_phy']
        Ac = x_dict['ac_all'].unsqueeze(-1).repeat(1, self.nmul)
        Elevation = x_dict['elev_all'].unsqueeze(-1).repeat(1, self.nmul)
        self.muwts = x_dict.get('muwts', None)
        ngrid = x.shape[1]

        # Unpack parameters.
        phy_dy_params, phy_static_params, routing_params = self._unpack_parameters(
            parameters
        )

        if self.routing:
            self.routing_param_dict = self._descale_route_parameters(routing_params)
        phy_dy_params_dict = self._descale_phy_dy_parameters(
            phy_dy_params,
            dy_list=self.dynamic_params,
        )
        phy_static_params_dict = self._descale_phy_stat_parameters(
            phy_static_params,
            stat_list=[
                param
                for param in self.phy_param_names
                if param not in self.dynamic_params
            ],
        )

        # Resume from cached states only when caching is on and states exist.
        if self.cache_states and self.states:
            current_states = self.states
        else:
            current_states = self._init_states(ngrid)

        fluxes, states = self._PBM(
            x,
            Ac,
            Elevation,
            current_states,
            phy_dy_params_dict,
            phy_static_params_dict,
        )

        # State caching: keep the full series, and the last step (detached)
        # as the starting point for the next call.
        self._state_cache = states

        if self.cache_states:
            self.states = tuple(s[-1].detach() for s in self._state_cache)

        return fluxes

    def _unit_hydrograph(self, nsteps: int, repeats: int = 1) -> torch.Tensor:
        """Build gamma unit hydrographs from the descaled routing parameters.

        Parameters
        ----------
        nsteps
            Number of timesteps; the routing parameters are repeated along time.
        repeats
            When > 1, each grid cell's routing parameters are repeated this
            many times in place (``repeat_interleave``). This matches a
            ``[time, grid * nmul]`` series produced by ``view`` on a
            ``[time, grid, nmul]`` tensor.

        Returns
        -------
        torch.Tensor
            The ``uh_gamma`` output permuted with ``[1, 2, 0]``, the layout
            passed to ``uh_conv``.
        """
        route_a = self.routing_param_dict['route_a']
        route_b = self.routing_param_dict['route_b']
        if repeats > 1:
            route_a = route_a.repeat_interleave(repeats)
            route_b = route_b.repeat_interleave(repeats)
        UH = uh_gamma(
            route_a.repeat(nsteps, 1).unsqueeze(-1),
            route_b.repeat(nsteps, 1).unsqueeze(-1),
            lenF=self.lenF,
        )
        return UH.permute([1, 2, 0])  # [gages,vars,time]

    def _PBM(
        self,
        forcing: torch.Tensor,
        Ac: torch.Tensor,
        Elevation: torch.Tensor,
        states: tuple,
        phy_dy_params_dict: dict,
        phy_static_params_dict: dict,
    ) -> Union[tuple, dict[str, torch.Tensor]]:
        """Run through process-based model (PBM).

        Flux outputs are in mm/day.

        Parameters
        ----------
        forcing
            Input forcing data, indexed ``[time, grid, variable]``; the column
            order is given by ``self.variables``.
        Ac
            Per-grid values, ``[grid, nmul]``, used in the lower-zone leakage
            term.
        Elevation
            Per-grid values, ``[grid, nmul]``. Cells with value >= 2000 use a
            fixed threshold temperature of 4.0.
        states
            Initial ``(SNOWPACK, MELTWATER, SM, SUZ, SLZ)`` states.
        phy_dy_params_dict
            Descaled dynamic parameters, each ``[time, grid, nmul]``.
        phy_static_params_dict
            Descaled static parameters, each ``[grid, nmul]``.

        Returns
        -------
        tuple
            ``(flux_dict, states)``. ``flux_dict`` is empty when
            ``self.initialize`` is True. ``states`` holds the full
            ``[time, grid, nmul]`` series of each storage.
        """
        SNOWPACK, MELTWATER, SM, SUZ, SLZ = states

        # Forcings
        P = forcing[:, :, self.variables.index('prcp')]  # Precipitation
        T = forcing[:, :, self.variables.index('tmean')]  # Mean air temp
        PET = forcing[:, :, self.variables.index('pet')]  # Potential ET
        nsteps, ngrid = P.shape

        # Expand dims to accommodate nmul models.
        Pm = P.unsqueeze(2).repeat(1, 1, self.nmul)
        Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
        PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

        # Time series of model variables in shape [time, basins, nmul].
        series_shape = Pm.size()

        def new_series(fill: float = 0.0) -> torch.Tensor:
            return torch.full(
                series_shape, fill, dtype=torch.float32, device=self.device
            )

        Qsimmu = new_series(0.001)
        Q0_sim = new_series(0.0001)
        Q1_sim = new_series(0.0001)
        Q2_sim = new_series(0.0001)

        AET = new_series()
        recharge_sim = new_series()
        excs_sim = new_series()
        evapfactor_sim = new_series()
        tosoil_sim = new_series()
        PERC_sim = new_series()
        SWE_sim = new_series()
        capillary_sim = new_series()

        # NOTE: new for MTS -- Save model states for all time steps.
        SNOWPACK_sim = new_series()
        MELTWATER_sim = new_series()
        SM_sim = new_series()
        SUZ_sim = new_series()
        SLZ_sim = new_series()

        # Loop-invariant terms, hoisted out of the time loop.
        # Cells with Elevation >= 2000 use a fixed threshold temperature of 4.0
        # instead of the learned parTT.
        high_elev = (Elevation >= 2000).type(torch.float32)
        low_elev = (Elevation < 2000).type(torch.float32)
        # The leakage term switches form at Ac == 2500.
        small_ac = Ac < 2500
        large_ac = Ac >= 2500
        ac_decay = torch.exp(torch.clamp(-(Ac - 2500) / 50, min=-10.0, max=0.0))

        for t in range(nsteps):
            # Dynamic parameter values at this timestep, plus static params.
            param_dict = {key: value[t] for key, value in phy_dy_params_dict.items()}
            param_dict.update(phy_static_params_dict)

            # Separate precipitation into liquid and solid components.
            PRECIP = Pm[t, :, :]
            parTT_new = high_elev * 4.0 + low_elev * param_dict['parTT']
            RAIN = torch.mul(PRECIP, (Tm[t, :, :] >= parTT_new).type(torch.float32))
            SNOW = torch.mul(PRECIP, (Tm[t, :, :] < parTT_new).type(torch.float32))

            # Snow -------------------------------
            SNOWPACK = SNOWPACK + SNOW
            melt = param_dict['parCFMAX'] * (Tm[t, :, :] - parTT_new)
            melt = torch.clamp(melt, min=0.0)
            melt = torch.min(melt, SNOWPACK)  # cannot melt more than stored
            MELTWATER = MELTWATER + melt
            SNOWPACK = SNOWPACK - melt
            refreezing = (
                param_dict['parCFR']
                * param_dict['parCFMAX']
                * (parTT_new - Tm[t, :, :])
            )
            refreezing = torch.clamp(refreezing, min=0.0)
            refreezing = torch.min(refreezing, MELTWATER)
            SNOWPACK = SNOWPACK + refreezing
            MELTWATER = MELTWATER - refreezing
            # Meltwater above the snowpack's holding capacity goes to soil.
            tosoil = MELTWATER - (param_dict['parCWH'] * SNOWPACK)
            tosoil = torch.clamp(tosoil, min=0.0)
            MELTWATER = MELTWATER - tosoil

            # Soil and evaporation -------------------------------
            soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
            soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
            recharge = (RAIN + tosoil) * soil_wetness

            SM = SM + RAIN + tosoil - recharge

            excess = SM - param_dict['parFC']
            excess = torch.clamp(excess, min=0.0)
            SM = SM - excess
            # NOTE: Different from HBV 1.0. Add static/dynamic ET shape
            # parameter parBETAET.
            evapfactor = (
                SM / (param_dict['parLP'] * param_dict['parFC'])
            ) ** param_dict['parBETAET']
            evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
            ETact = PETm[t, :, :] * evapfactor
            ETact = torch.min(SM, ETact)
            SM = torch.clamp(SM - ETact, min=self.nearzero)

            # Capillary rise (HBV 1.1p mod) -------------------------------
            capillary = torch.min(
                SLZ,
                param_dict['parC']
                * SLZ
                * (1.0 - torch.clamp(SM / param_dict['parFC'], max=1.0)),
            )

            SM = torch.clamp(SM + capillary, min=self.nearzero)
            SLZ = torch.clamp(SLZ - capillary, min=self.nearzero)

            # Groundwater boxes -------------------------------
            SUZ = SUZ + recharge + excess
            PERC = torch.min(SUZ, param_dict['parPERC'])
            SUZ = SUZ - PERC
            Q0 = param_dict['parK0'] * torch.clamp(SUZ - param_dict['parUZL'], min=0.0)
            SUZ = SUZ - Q0
            Q1 = param_dict['parK1'] * SUZ
            SUZ = SUZ - Q1
            SLZ = SLZ + PERC

            # Lower-zone leakage/gain term added to SLZ (can be negative).
            LF = torch.clamp(
                (Ac - param_dict['parAC']) / 1000, min=-1, max=1
            ) * param_dict['parRT'] * small_ac + ac_decay * param_dict[
                'parRT'
            ] * large_ac
            SLZ = torch.clamp(SLZ + LF, min=0.0)

            Q2 = param_dict['parK2'] * SLZ
            SLZ = SLZ - Q2

            # --- Outputs ---
            Qsimmu[t, :, :] = Q0 + Q1 + Q2
            Q0_sim[t, :, :] = Q0
            Q1_sim[t, :, :] = Q1
            Q2_sim[t, :, :] = Q2
            AET[t, :, :] = ETact
            SWE_sim[t, :, :] = SNOWPACK
            capillary_sim[t, :, :] = capillary

            recharge_sim[t, :, :] = recharge
            excs_sim[t, :, :] = excess
            evapfactor_sim[t, :, :] = evapfactor
            tosoil_sim[t, :, :] = tosoil
            PERC_sim[t, :, :] = PERC

            # NOTE: new for MTS -- Save model states for all time steps.
            SNOWPACK_sim[t, :, :] = SNOWPACK
            MELTWATER_sim[t, :, :] = MELTWATER
            SM_sim[t, :, :] = SM
            SUZ_sim[t, :, :] = SUZ
            SLZ_sim[t, :, :] = SLZ

        # Get the average or weighted average using learned weights.
        if self.muwts is None:
            Qsimavg = Qsimmu.mean(-1)
        else:
            Qsimavg = (Qsimmu * self.muwts).sum(-1)

        # Run routing
        if self.routing:
            # Routing parameters exist once per grid cell.
            UH = self._unit_hydrograph(nsteps)

            # Routing for all components or just the average.
            if self.comprout:
                # All components; reshape to [time, gages * num models]. The UH
                # must match that batch size, so repeat it once per model.
                Qsim = Qsimmu.view(nsteps, ngrid * self.nmul)
                UH_sim = self._unit_hydrograph(nsteps, repeats=self.nmul)
            else:
                # Average, then do routing.
                Qsim = Qsimavg
                UH_sim = UH

            rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0])  # [gages,vars,time]
            Qsrout = uh_conv(rf, UH_sim).permute([2, 0, 1])

            # Routing individually for Q0, Q1, and Q2, all w/ dims [gages,vars,time].
            rf_Q0 = Q0_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q0_rout = uh_conv(rf_Q0, UH).permute([2, 0, 1])
            rf_Q1 = Q1_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q1_rout = uh_conv(rf_Q1, UH).permute([2, 0, 1])
            rf_Q2 = Q2_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            Q2_rout = uh_conv(rf_Q2, UH).permute([2, 0, 1])

            if self.comprout:
                # Qs is now shape [time, [gages*num models], vars]
                Qstemp = Qsrout.view(nsteps, ngrid, self.nmul)
                if self.muwts is None:
                    Qs = Qstemp.mean(-1, keepdim=True)
                else:
                    Qs = (Qstemp * self.muwts).sum(-1, keepdim=True)
            else:
                Qs = Qsrout

        else:
            # No routing, only output the average of all model sims.
            Qs = torch.unsqueeze(Qsimavg, -1)
            Q0_rout = Q0_sim.mean(-1, keepdim=True)
            Q1_rout = Q1_sim.mean(-1, keepdim=True)
            Q2_rout = Q2_sim.mean(-1, keepdim=True)

        states = (SNOWPACK_sim, MELTWATER_sim, SM_sim, SUZ_sim, SLZ_sim)

        if self.initialize:
            # If initialize is True, only return warmed-up storages.
            return {}, states

        # Baseflow index (BFI): percent of total (routed) flow that is
        # groundwater flow, summed over time; one value per grid cell.
        BFI_sim = (
            100
            * (torch.sum(Q2_rout, dim=0) / (torch.sum(Qs, dim=0) + self.nearzero))[
                :, 0
            ]
        )

        # Return all sim results.
        flux_dict = {
            'streamflow': Qs,  # Routed Streamflow
            'srflow': Q0_rout,  # Routed surface runoff
            'ssflow': Q1_rout,  # Routed subsurface flow
            'gwflow': Q2_rout,  # Routed groundwater flow
            'AET_hydro': AET.mean(-1, keepdim=True),  # Actual ET
            'PET_hydro': PETm.mean(-1, keepdim=True),  # Potential ET
            'SWE': SWE_sim.mean(-1, keepdim=True),  # Snow water equivalent
            # Model-averaged unrouted streamflow, [time, grid, 1] in every
            # routing mode (comprout previously leaked [time, grid*nmul, 1]).
            'streamflow_no_rout': Qsimavg.unsqueeze(-1),  # Streamflow
            'srflow_no_rout': Q0_sim.mean(-1, keepdim=True),  # Surface runoff
            'ssflow_no_rout': Q1_sim.mean(-1, keepdim=True),  # Subsurface flow
            'gwflow_no_rout': Q2_sim.mean(-1, keepdim=True),  # Groundwater flow
            'recharge': recharge_sim.mean(-1, keepdim=True),  # Recharge
            'excs': excs_sim.mean(-1, keepdim=True),  # Excess stored water
            'evapfactor': evapfactor_sim.mean(-1, keepdim=True),  # Evaporation factor
            'tosoil': tosoil_sim.mean(-1, keepdim=True),  # Infiltration
            'percolation': PERC_sim.mean(-1, keepdim=True),  # Percolation
            'capillary': capillary_sim.mean(-1, keepdim=True),  # Capillary rise
            'BFI': BFI_sim,  # Baseflow index
        }

        # NOTE(review): nothing in this class changes pred_cutoff from its
        # initial 0, so this slice is a no-op unless pred_cutoff is set
        # externally; warm_up is read from config but never used here.
        if not self.warm_up_states:
            for key in flux_dict.keys():
                if key != 'BFI':
                    flux_dict[key] = flux_dict[key][self.pred_cutoff :, :, :]
        return flux_dict, states