hydrodl2 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,897 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ import torch
4
+
5
+ from hydrodl2.core.calc import change_param_range, uh_conv, uh_gamma
6
+
7
+
8
+ class Hbv_2_hourly(torch.nn.Module):
9
+ """Hourly HBV 2.0.
10
+
11
+ Multi-component, multi-scale, differentiable PyTorch HBV model with rainfall
12
+ runoff simulation on unit basins.
13
+
14
+ Authors
15
+ -------
16
+ - Wencong Yang, Leo Lonzarich, Yalan Song
17
+ - (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
18
+ - (HBV-light Version 2) Seibert, 2005
19
+ (https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).
20
+
21
+ Parameters
22
+ ----------
23
+ config
24
+ Configuration dictionary.
25
+ device
26
+ Device to run the model on.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ config: Optional[dict[str, Any]] = None,
32
+ device: Optional[torch.device] = None,
33
+ ) -> None:
34
+ super().__init__()
35
+ self.name = 'HBV 2.0 Hourly'
36
+ self.config = config
37
+ self.initialize = False
38
+ self.warm_up = 0
39
+ self.pred_cutoff = 0
40
+ self.warm_up_states = True
41
+ self.dynamic_params = []
42
+ self.dy_drop = 0.0
43
+ self.variables = ['prcp', 'tmean', 'pet']
44
+ self.routing = False
45
+ self.lenF = 72
46
+ self.comprout = False
47
+ self.muwts = None
48
+ self.nearzero = 1e-5
49
+ self.nmul = 1
50
+ self.cache_states = False
51
+ self.device = device
52
+
53
+ self.states, self._states_cache = None, None
54
+
55
+ self._qs_buffer = []
56
+ self._max_history = 100 # Safe buffer size > lenF (72)
57
+
58
+ self.dt = 1.0 / 24
59
+ self.use_distr_routing = True
60
+ self.infiltration = True
61
+ self.lag_uh = True
62
+
63
+ self.state_names = [
64
+ 'SNOWPACK', # Snowpack storage
65
+ 'MELTWATER', # Meltwater storage
66
+ 'SM', # Soil moisture storage
67
+ 'SUZ', # Upper groundwater storage
68
+ 'SLZ', # Lower groundwater storage
69
+ ]
70
+ self.flux_names = [
71
+ 'streamflow', # Routed Streamflow
72
+ 'srflow', # Routed surface runoff
73
+ 'ssflow', # Routed subsurface flow
74
+ 'gwflow', # Routed groundwater flow
75
+ 'AET_hydro', # Actual ET
76
+ 'PET_hydro', # Potential ET
77
+ 'SWE', # Snow water equivalent
78
+ 'streamflow_no_rout', # Streamflow
79
+ 'srflow_no_rout', # Surface runoff
80
+ 'ssflow_no_rout', # Subsurface flow
81
+ 'gwflow_no_rout', # Groundwater flow
82
+ 'recharge', # Recharge
83
+ 'excs', # Excess stored water
84
+ 'evapfactor', # Evaporation factor
85
+ 'tosoil', # Infiltration
86
+ 'percolation', # Percolation
87
+ 'capillary', # Capillary rise
88
+ 'BFI', # Baseflow index
89
+ ]
90
+
91
+ self.parameter_bounds = {
92
+ 'parBETA': [1.0, 6.0],
93
+ 'parFC': [50, 1000],
94
+ 'parK0': [0.05, 0.9],
95
+ 'parK1': [0.01, 0.5],
96
+ 'parK2': [0.001, 0.2],
97
+ 'parLP': [0.2, 1],
98
+ 'parPERC': [0, 10],
99
+ 'parUZL': [0, 100],
100
+ 'parTT': [-2.5, 2.5],
101
+ 'parCFMAX': [0.5, 10],
102
+ 'parCFR': [0, 0.1],
103
+ 'parCWH': [0, 0.2],
104
+ 'parBETAET': [0.3, 5],
105
+ 'parC': [0, 1],
106
+ 'parRT': [0, 20],
107
+ 'parAC': [0, 2500],
108
+ # Infiltration parameters for hourly
109
+ 'parF0': [
110
+ 5.0 / self.dt,
111
+ 120.0 / self.dt,
112
+ ], # dry (max) infiltration capacity, mm/day
113
+ 'parFMIN': [0.0, 1.0], # wet (min) capacity ratio
114
+ 'parALPHA': [0.5, 5.0], # shape of f(s); larger -> more thresholdy
115
+ }
116
+ self.routing_parameter_bounds = {
117
+ 'route_a': [0, 5.0],
118
+ 'route_b': [0, 12.0],
119
+ }
120
+ self.distr_parameter_bounds = {
121
+ 'route_a': [0, 5.0],
122
+ 'route_b': [0, 12.0],
123
+ 'route_tau': [0, 48.0],
124
+ }
125
+
126
+ if not device:
127
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
128
+
129
+ if not self.infiltration:
130
+ self.parameter_bounds.pop('parF0')
131
+ self.parameter_bounds.pop('parFMIN')
132
+ self.parameter_bounds.pop('parALPHA')
133
+ if not self.lag_uh:
134
+ self.distr_parameter_bounds.pop('route_tau')
135
+
136
+ if config is not None:
137
+ # Overwrite defaults with config values.
138
+ self.warm_up = config.get('warm_up', self.warm_up)
139
+ self.warm_up_states = config.get('warm_up_states', self.warm_up_states)
140
+ self.dy_drop = config.get('dy_drop', self.dy_drop)
141
+ self.dynamic_params = config['dynamic_params'].get(
142
+ self.__class__.__name__, self.dynamic_params
143
+ )
144
+ self.variables = config.get('variables', self.variables)
145
+ self.routing = config.get('routing', self.routing)
146
+ self.comprout = config.get('comprout', self.comprout)
147
+ self.nearzero = config.get('nearzero', self.nearzero)
148
+ self.nmul = config.get('nmul', self.nmul)
149
+ self.cache_states = config.get('cache_states', self.cache_states)
150
+ self._set_parameters()
151
+
152
+ def _init_states(self, ngrid: int) -> tuple[torch.Tensor]:
153
+ """Initialize model states to zero."""
154
+
155
+ def make_state():
156
+ return torch.full(
157
+ (ngrid, self.nmul), 0.001, dtype=torch.float32, device=self.device
158
+ )
159
+
160
+ return tuple(make_state() for _ in range(len(self.state_names)))
161
+
162
    def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
        """Return internal model states.

        Returns
        -------
        tuple[torch.Tensor, ...]
            A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ),
            or None if no forward pass has populated the cache yet.
        """
        # NOTE(review): `forward()` appears to write `self._state_cache`
        # (no 's'), not `self._states_cache` read here — if so this always
        # returns None; confirm and unify the attribute name.
        return self._states_cache
171
+
172
+ def load_states(
173
+ self,
174
+ states: tuple[torch.Tensor, ...],
175
+ ) -> None:
176
+ """Load internal model states and set to model device and type.
177
+
178
+ Parameters
179
+ ----------
180
+ states
181
+ A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
182
+ """
183
+ for state in states:
184
+ if not isinstance(state, torch.Tensor):
185
+ raise ValueError("Each element in `states` must be a tensor.")
186
+ nstates = len(self.state_names)
187
+ if not (isinstance(states, tuple) and len(states) == nstates):
188
+ raise ValueError(f"`states` must be a tuple of {nstates} tensors.")
189
+
190
+ self.states = tuple(
191
+ s.detach().to(self.device, dtype=torch.float32) for s in states
192
+ )
193
+
194
+ def _set_parameters(self) -> None:
195
+ """Get physical parameters."""
196
+ self.phy_param_names = self.parameter_bounds.keys()
197
+ if self.routing:
198
+ self.routing_param_names = self.routing_parameter_bounds.keys()
199
+ else:
200
+ self.routing_param_names = []
201
+
202
+ self.learnable_param_count1 = len(self.dynamic_params) * self.nmul
203
+ self.learnable_param_count2 = (
204
+ len(self.phy_param_names) - len(self.dynamic_params)
205
+ ) * self.nmul + len(self.routing_param_names)
206
+ self.learnable_param_count3 = len(self.distr_parameter_bounds)
207
+ self.learnable_param_count = (
208
+ self.learnable_param_count1
209
+ + self.learnable_param_count2
210
+ + self.learnable_param_count3
211
+ )
212
+
213
+ def _unpack_parameters(
214
+ self,
215
+ parameters: torch.Tensor,
216
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
217
+ """Extract physical model and routing parameters from NN output.
218
+
219
+ Parameters
220
+ ----------
221
+ parameters
222
+ Unprocessed, learned parameters from a neural network.
223
+
224
+ Returns
225
+ -------
226
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
227
+ Tuple of physical and routing parameters.
228
+ """
229
+ phy_param_count = len(self.parameter_bounds)
230
+ dy_param_count = len(self.dynamic_params)
231
+ dif_count = phy_param_count - dy_param_count
232
+
233
+ # Physical dynamic parameters
234
+ phy_dy_params = parameters[0].view(
235
+ parameters[0].shape[0],
236
+ parameters[0].shape[1],
237
+ dy_param_count,
238
+ self.nmul,
239
+ )
240
+
241
+ # Physical static parameters
242
+ phy_static_params = parameters[1][:, : dif_count * self.nmul].view(
243
+ parameters[1].shape[0],
244
+ dif_count,
245
+ self.nmul,
246
+ )
247
+
248
+ # Routing parameters
249
+ routing_params = None
250
+ if self.routing:
251
+ routing_params = parameters[1][:, dif_count * self.nmul :]
252
+
253
+ # Distributed routing parameters
254
+ distr_params = parameters[2]
255
+
256
+ return (phy_dy_params, phy_static_params, routing_params, distr_params)
257
+
258
    def _descale_phy_dy_parameters(
        self,
        phy_dy_params: torch.Tensor,
        dy_list: list,
    ) -> dict[str, torch.Tensor]:
        """Descale dynamic physical parameters.

        Randomly replaces the dynamic (per-timestep) parameter series with its
        final-timestep (static) value at a `dy_drop` rate per basin, then maps
        the normalized values onto their physical bounds.

        Parameters
        ----------
        phy_dy_params
            Normalized dynamic physical parameters, [time, grid, n_dyn, nmul].
        dy_list
            List of dynamic parameter names, in the order they occupy
            `phy_dy_params`.

        Returns
        -------
        dict
            Dictionary of descaled physical parameters, [time, grid, nmul].
        """
        nsteps = phy_dy_params.shape[0]
        ngrid = phy_dy_params.shape[1]

        # TODO: Fix; if dynamic parameters are not entered in config as they are
        # in HBV params list, then descaling mismatch will occur.
        param_dict = {}
        # Per-basin Bernoulli mask; dy_drop is the probability of replacing a
        # basin's dynamic series with its static (last-step) value.
        pmat = torch.ones([1, ngrid, 1]) * self.dy_drop
        for i, name in enumerate(dy_list):
            # Static fallback: last timestep broadcast across all timesteps.
            staPar = phy_dy_params[-1, :, i, :].unsqueeze(0).repeat([nsteps, 1, 1])

            dynPar = phy_dy_params[:, :, i, :]
            drmask = torch.bernoulli(pmat).detach_().to(self.device)

            # Mix dynamic and static series according to the drop mask.
            comPar = dynPar * (1 - drmask) + staPar * drmask
            param_dict[name] = change_param_range(
                param=comPar,
                bounds=self.parameter_bounds[name],
            )
        return param_dict
296
+
297
    def _descale_phy_stat_parameters(
        self,
        phy_stat_params: torch.Tensor,
        stat_list: list,
    ) -> dict[str, torch.Tensor]:
        """Descale static physical parameters.

        Maps each normalized parameter column onto its physical range from
        `self.parameter_bounds`.

        Parameters
        ----------
        phy_stat_params
            Normalized static physical parameters, [grid, n_static, nmul].
        stat_list
            Names of the static (non-dynamic) physical parameters, in the
            order they occupy `phy_stat_params`.

        Returns
        -------
        dict
            Dictionary of descaled static physical parameters, [grid, nmul].
        """
        parameter_dict = {}
        for i, name in enumerate(stat_list):
            param = phy_stat_params[:, i, :]

            parameter_dict[name] = change_param_range(
                param=param,
                bounds=self.parameter_bounds[name],
            )
        return parameter_dict
323
+
324
    def _descale_route_parameters(
        self,
        routing_params: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Descale lumped routing parameters.

        Parameters
        ----------
        routing_params
            Normalized routing parameters, [grid, n_routing_params].

        Returns
        -------
        dict
            Dictionary of descaled routing parameters keyed by
            `routing_parameter_bounds` names.
        """
        parameter_dict = {}
        for i, name in enumerate(self.routing_parameter_bounds.keys()):
            param = routing_params[:, i]

            parameter_dict[name] = change_param_range(
                param=param,
                bounds=self.routing_parameter_bounds[name],
            )
        return parameter_dict
349
+
350
+ def _descale_distr_parameters(
351
+ self,
352
+ distr_params: torch.Tensor,
353
+ ) -> dict[str, torch.Tensor]:
354
+ """Descale distributed routing parameters.
355
+
356
+ Parameters
357
+ ----------
358
+ distr_params
359
+ Normalized distributed routing parameters.
360
+
361
+ Returns
362
+ -------
363
+ dict
364
+ Dictionary of descaled distributed routing parameters.
365
+ """
366
+ parameter_dict = {}
367
+ for i, name in enumerate(self.distr_parameter_bounds.keys()):
368
+ param = distr_params[:, i]
369
+
370
+ parameter_dict[name] = change_param_range(
371
+ param=param,
372
+ bounds=self.distr_parameter_bounds[name],
373
+ )
374
+ return parameter_dict
375
+
376
+ def forward(
377
+ self,
378
+ x_dict: dict[str, torch.Tensor],
379
+ parameters: torch.Tensor,
380
+ ) -> tuple[dict[str, torch.Tensor], tuple]:
381
+ """Forward pass.
382
+
383
+ Parameters
384
+ ----------
385
+ x_dict
386
+ Dictionary of input forcing data.
387
+ parameters
388
+ Unprocessed, learned parameters from a neural network.
389
+
390
+ Returns
391
+ -------
392
+ tuple[dict, tuple]
393
+ Tuple or dictionary of model outputs.
394
+ """
395
+ # Unpack input data.
396
+ x = x_dict['x_phy']
397
+ Ac = x_dict['ac_all'].unsqueeze(-1).repeat(1, self.nmul)
398
+ Elevation = x_dict['elev_all'].unsqueeze(-1).repeat(1, self.nmul)
399
+ outlet_topo = x_dict['outlet_topo']
400
+ areas = x_dict['areas']
401
+ self.muwts = x_dict.get('muwts', None)
402
+ ngrid = x.shape[1]
403
+
404
+ # Unpack parameters.
405
+ phy_dy_params, phy_static_params, routing_params, distr_params = (
406
+ self._unpack_parameters(parameters)
407
+ )
408
+
409
+ if self.routing:
410
+ self.routing_param_dict = self._descale_route_parameters(routing_params)
411
+ phy_dy_params_dict = self._descale_phy_dy_parameters(
412
+ phy_dy_params,
413
+ dy_list=self.dynamic_params,
414
+ )
415
+ phy_static_params_dict = self._descale_phy_stat_parameters(
416
+ phy_static_params,
417
+ stat_list=[
418
+ param
419
+ for param in self.phy_param_names
420
+ if param not in self.dynamic_params
421
+ ],
422
+ )
423
+
424
+ if (not self.states) or (not self.cache_states):
425
+ current_states = self._init_states(ngrid)
426
+ else:
427
+ current_states = self.states
428
+
429
+ distr_params_dict = self._descale_distr_parameters(distr_params)
430
+
431
+ fluxes, states = self._PBM(
432
+ x,
433
+ Ac,
434
+ Elevation,
435
+ current_states,
436
+ phy_dy_params_dict,
437
+ phy_static_params_dict,
438
+ outlet_topo,
439
+ areas,
440
+ distr_params_dict,
441
+ )
442
+
443
+ # State caching
444
+ self._state_cache = states
445
+
446
+ if self.cache_states:
447
+ self.states = tuple(s[-1].detach() for s in self._state_cache)
448
+
449
+ return fluxes
450
+
451
    def _PBM(
        self,
        forcing: torch.Tensor,
        Ac: torch.Tensor,
        Elevation: torch.Tensor,
        states: tuple,
        phy_dy_params_dict: dict,
        phy_static_params_dict: dict,
        outlet_topo: torch.Tensor,
        areas: torch.Tensor,
        distr_params_dict: dict,
    ) -> Union[tuple, dict[str, torch.Tensor]]:
        """Run through process-based model (PBM).

        Steps the HBV water balance (snow, soil, infiltration excess, upper
        and lower groundwater boxes) through every timestep, then optionally
        applies gamma-unit-hydrograph routing and distributed routing to gages.

        Flux outputs are in mm/hour.

        Parameters
        ----------
        forcing
            Input forcing data, [time, grid, vars].
        Ac
            Upstream-area attribute per unit basin, [grid, nmul].
        Elevation
            Elevation attribute per unit basin, [grid, nmul].
        states
            Initial model states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
        phy_dy_params_dict
            Descaled dynamic physical parameters, each [time, grid, nmul].
        phy_static_params_dict
            Descaled static physical parameters, each [grid, nmul].
        outlet_topo
            Gage-to-unit topology matrix, (n_gages, n_units).
        areas
            Unit basin areas, (n_units,).
        distr_params_dict
            Descaled distributed routing parameters.

        Returns
        -------
        Union[tuple, dict]
            (flux dict, states tuple); the flux dict is empty when
            `self.initialize` is True (warm-up-only run).
        """
        dt = self.dt
        SNOWPACK, MELTWATER, SM, SUZ, SLZ = states

        # Forcings. Inputs are per-timestep totals; dividing by dt converts to
        # rates — presumably mm/day given dt = 1/24 (TODO confirm units).
        P = forcing[:, :, self.variables.index('prcp')] / dt  # Precipitation
        T = forcing[:, :, self.variables.index('tmean')]  # Mean air temp
        PET = forcing[:, :, self.variables.index('pet')] / dt  # Potential ET
        nsteps, ngrid = P.shape

        # Expand dims to accommodate for nmul models.
        Pm = P.unsqueeze(2).repeat(1, 1, self.nmul)
        Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
        PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

        # Apply correction factor to precipitation
        # P = parPCORR.repeat(nsteps, 1) * P

        # Initialize time series of model variables in shape [time, basins, nmul].
        Qsimmu = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.001
        Q0_sim = (
            torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
        )
        Q1_sim = (
            torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
        )
        Q2_sim = (
            torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.0001
        )

        AET = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        recharge_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        excs_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        evapfactor_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        tosoil_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        PERC_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SWE_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        capillary_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)

        # NOTE: new for MTS -- Save model states for all time steps.
        SNOWPACK_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        MELTWATER_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SM_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SUZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
        SLZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)

        param_dict = {}
        for t in range(nsteps):
            # NOTE: new for MTS -- numerical guardrail for long-sequence running
            SNOWPACK = torch.clamp(SNOWPACK, min=0.0)
            MELTWATER = torch.clamp(MELTWATER, min=0.0)
            SM = torch.clamp(SM, min=self.nearzero)
            SUZ = torch.clamp(SUZ, min=self.nearzero)
            SLZ = torch.clamp(SLZ, min=self.nearzero)
            # ------------------------------------------------------------------

            # Get dynamic parameter values per timestep.
            for key in phy_dy_params_dict.keys():
                param_dict[key] = phy_dy_params_dict[key][t, :, :]
            for key in phy_static_params_dict.keys():
                param_dict[key] = phy_static_params_dict[key][:, :]

            # Separate precipitation into liquid and solid components.
            # High-elevation (>= 2000) basins use a fixed 4.0 threshold
            # instead of the learned parTT.
            PRECIP = Pm[t, :, :]
            parTT_new = (Elevation >= 2000).type(torch.float32) * 4.0 + (
                Elevation < 2000
            ).type(torch.float32) * param_dict['parTT']
            RAIN = torch.mul(PRECIP, (Tm[t, :, :] >= parTT_new).type(torch.float32))
            SNOW = torch.mul(PRECIP, (Tm[t, :, :] < parTT_new).type(torch.float32))

            # Snow -------------------------------
            SNOWPACK = SNOWPACK + SNOW * dt
            melt = param_dict['parCFMAX'] * (Tm[t, :, :] - parTT_new)
            # melt[melt < 0.0] = 0.0
            melt = torch.clamp(melt, min=0.0)
            # melt[melt > SNOWPACK] = SNOWPACK[melt > SNOWPACK]
            melt = torch.min(melt * dt, SNOWPACK)
            MELTWATER = MELTWATER + melt
            SNOWPACK = SNOWPACK - melt
            refreezing = (
                param_dict['parCFR']
                * param_dict['parCFMAX']
                * (parTT_new - Tm[t, :, :])
            )
            # refreezing[refreezing < 0.0] = 0.0
            # refreezing[refreezing > MELTWATER] = MELTWATER[refreezing > MELTWATER]
            refreezing = torch.clamp(refreezing, min=0.0)
            refreezing = torch.min(refreezing * dt, MELTWATER)
            SNOWPACK = SNOWPACK + refreezing
            MELTWATER = MELTWATER - refreezing
            # Meltwater above the holding capacity (parCWH * SNOWPACK) drains.
            tosoil = (MELTWATER - (param_dict['parCWH'] * SNOWPACK)) / dt
            tosoil = torch.clamp(tosoil, min=0.0)
            MELTWATER = MELTWATER - tosoil * dt

            # NOTE: new for MTS -- Hortonian Infiltration Excess
            if self.infiltration:
                # Hortonian infiltration excess: infiltration capacity as a function of wetness
                W = RAIN + tosoil
                s = torch.clamp(
                    SM / param_dict['parFC'], 0.0, 1.0 - 0.01
                )  # relative wetness, safe guard for pow and bf/fp16
                parFMIN = param_dict['parFMIN'] * param_dict['parF0']
                with torch.amp.autocast(
                    device_type='cuda', enabled=False
                ):  # torch.pow not stable with bf/fp16 when base ~ 0
                    fcap = parFMIN + (param_dict['parF0'] - parFMIN) * torch.pow(
                        1.0 - s, param_dict['parALPHA']
                    )
                infiltration = torch.minimum(W, fcap)  # goes into soil
                IE = torch.clamp(W - fcap, min=0.0)  # Hortonian excess

                # Soil and evaporation using Infiltration
                soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
                soil_wetness = torch.clamp(soil_wetness, 0.0, 1.0)
                recharge = infiltration * soil_wetness
                SM = SM + (infiltration - recharge) * dt
            else:
                soil_wetness = (SM / param_dict['parFC']) ** param_dict['parBETA']
                soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
                recharge = (RAIN + tosoil) * soil_wetness
                SM = SM + (RAIN + tosoil - recharge) * dt
            # ------------------------------------------------------

            # Soil moisture above field capacity spills as excess.
            excess = (SM - param_dict['parFC']) / dt
            excess = torch.clamp(excess, min=0.0)
            SM = SM - excess * dt
            # NOTE: Different from HBV 1.0. Add static/dynamicET shape parameter parBETAET.
            evapfactor = (
                SM / (param_dict['parLP'] * param_dict['parFC'])
            ) ** param_dict['parBETAET']
            evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
            ETact = PETm[t, :, :] * evapfactor
            # Actual ET cannot exceed available soil moisture this step.
            ETact = torch.min(SM, ETact * dt) / dt
            SM = torch.clamp(SM - ETact * dt, min=self.nearzero)

            # Capillary rise (HBV 1.1p mod) -------------------------------
            capillary = (
                torch.min(
                    SLZ,
                    param_dict['parC']
                    * SLZ
                    * (1.0 - torch.clamp(SM / param_dict['parFC'], max=1.0))
                    * dt,
                )
                / dt
            )

            SM = torch.clamp(SM + capillary * dt, min=self.nearzero)
            SLZ = torch.clamp(SLZ - capillary * dt, min=self.nearzero)

            # Groundwater boxes -------------------------------
            SUZ = SUZ + (recharge + excess) * dt
            PERC = torch.min(SUZ, param_dict['parPERC'] * dt) / dt
            SUZ = SUZ - PERC * dt
            # Q0: fast near-surface flow above threshold parUZL.
            Q0 = param_dict['parK0'] * torch.clamp(SUZ - param_dict['parUZL'], min=0.0)
            SUZ = SUZ - Q0 * dt
            Q1 = param_dict['parK1'] * SUZ
            SUZ = SUZ - Q1 * dt
            SLZ = SLZ + PERC * dt

            # Lateral flow term driven by upstream area Ac, with a smooth
            # exponential taper for Ac >= 2500.
            LF = torch.clamp(
                (Ac - param_dict['parAC']) / 1000, min=-1, max=1
            ) * param_dict['parRT'] * (Ac < 2500) + torch.exp(
                torch.clamp(-(Ac - 2500) / 50, min=-10.0, max=0.0)
            ) * param_dict['parRT'] * (Ac >= 2500)
            SLZ = torch.clamp(SLZ + LF * dt, min=0.0)

            Q2 = param_dict['parK2'] * SLZ
            SLZ = SLZ - Q2 * dt

            # NOTE: new for MTS -- Add Hortonian Infiltration Excess
            if self.infiltration:
                Qsimmu[t, :, :] = Q0 + Q1 + Q2 + IE
            else:
                Qsimmu[t, :, :] = Q0 + Q1 + Q2
            # ------------------------------------------------------

            Q0_sim[t, :, :] = Q0
            Q1_sim[t, :, :] = Q1
            Q2_sim[t, :, :] = Q2
            AET[t, :, :] = ETact
            SWE_sim[t, :, :] = SNOWPACK
            capillary_sim[t, :, :] = capillary

            recharge_sim[t, :, :] = recharge
            excs_sim[t, :, :] = excess
            evapfactor_sim[t, :, :] = evapfactor
            tosoil_sim[t, :, :] = tosoil
            PERC_sim[t, :, :] = PERC

            # NOTE: new for MTS -- Save model states for all time steps.
            SNOWPACK_sim[t, :, :] = SNOWPACK
            MELTWATER_sim[t, :, :] = MELTWATER
            SM_sim[t, :, :] = SM
            SUZ_sim[t, :, :] = SUZ
            SLZ_sim[t, :, :] = SLZ

        # Get the average or weighted average using learned weights.
        if self.muwts is None:
            Qsimavg = Qsimmu.mean(-1)
        else:
            Qsimavg = (Qsimmu * self.muwts).sum(-1)

        # Run routing
        if self.routing:
            # Routing for all components or just the average.
            if self.comprout:
                # All components; reshape to [time, gages * num models]
                Qsim = Qsimmu.view(nsteps, ngrid * self.nmul)
            else:
                # Average, then do routing.
                Qsim = Qsimavg

            UH = uh_gamma(
                self.routing_param_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
                self.routing_param_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
                lenF=self.lenF,
            )
            rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0])  # [gages,vars,time]
            UH = UH.permute([1, 2, 0])  # [gages,vars,time]
            Qsrout = uh_conv(rf, UH).permute([2, 0, 1])

            # Routing individually for Q0, Q1, and Q2, all w/ dims [gages,vars,time].
            # rf_Q0 = Q0_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            # Q0_rout = uh_conv(rf_Q0, UH).permute([2, 0, 1])
            # rf_Q1 = Q1_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            # Q1_rout = uh_conv(rf_Q1, UH).permute([2, 0, 1])
            # rf_Q2 = Q2_sim.mean(-1, keepdim=True).permute([1, 2, 0])
            # Q2_rout = uh_conv(rf_Q2, UH).permute([2, 0, 1])

            if self.comprout:
                # Qs is now shape [time, [gages*num models], vars]
                Qstemp = Qsrout.view(nsteps, ngrid, self.nmul)
                if self.muwts is None:
                    Qs = Qstemp.mean(-1, keepdim=True)
                else:
                    Qs = (Qstemp * self.muwts).sum(-1, keepdim=True)
            else:
                Qs = Qsrout

        else:
            # No routing, only output the average of all model sims.
            Qs = torch.unsqueeze(Qsimavg, -1)
            # Q0_rout = Q1_rout = Q2_rout = None

        states = (SNOWPACK_sim, MELTWATER_sim, SM_sim, SUZ_sim, SLZ_sim)

        if self.initialize:
            # If initialize is True, only return warmed-up storages.
            return {}, states
        else:
            # Baseflow index (BFI) calculation
            # BFI_sim = (
            #     100
            #     * (torch.sum(Q2_rout, dim=0) / (torch.sum(Qs, dim=0) + self.nearzero))[
            #         :, 0
            #     ]
            # )

            # Return all sim results.
            flux_dict = {
                'Qs': Qs * dt,  # Routed Streamflow for units
                # 'srflow': Q0_rout * dt,  # Routed surface runoff
                # 'ssflow': Q1_rout * dt,  # Routed subsurface flow
                # 'gwflow': Q2_rout * dt,  # Routed groundwater flow
                # 'AET_hydro': AET.mean(-1, keepdim=True) * dt,  # Actual ET
                # 'PET_hydro': PETm.mean(-1, keepdim=True) * dt,  # Potential ET
                # 'SWE': SWE_sim.mean(-1, keepdim=True),  # Snow water equivalent
                # 'streamflow_no_rout': Qsim.unsqueeze(dim=2) * dt,  # Streamflow
                # 'srflow_no_rout': Q0_sim.mean(-1, keepdim=True) * dt,  # Surface runoff
                # 'ssflow_no_rout': Q1_sim.mean(-1, keepdim=True) * dt,  # Subsurface flow
                # 'gwflow_no_rout': Q2_sim.mean(-1, keepdim=True) * dt,  # Groundwater flow
                # 'recharge': recharge_sim.mean(-1, keepdim=True) * dt,  # Recharge
                # 'excs': excs_sim.mean(-1, keepdim=True) * dt,  # Excess stored water
                # 'evapfactor': evapfactor_sim.mean(-1, keepdim=True),  # Evaporation factor
                # 'tosoil': tosoil_sim.mean(-1, keepdim=True) * dt,  # Infiltration
                # 'percolation': PERC_sim.mean(-1, keepdim=True) * dt,  # Percolation
                # 'capillary': capillary_sim.mean(-1, keepdim=True) * dt,  # Capillary rise
                # 'BFI': BFI_sim * dt,  # Baseflow index
            }

            if not self.warm_up_states:
                # Trim the warm-up period off the returned series.
                for key in flux_dict.keys():
                    if key != 'BFI':
                        flux_dict[key] = flux_dict[key][self.pred_cutoff :, :, :]

            if self.use_distr_routing:
                # 1. Get current Raw Runoff

                # 2. Manage Buffer
                if self.cache_states:
                    self._qs_buffer.append((Qs * dt).detach())
                    # Keep buffer reasonable size (at least lenF + tau)
                    if len(self._qs_buffer) > self._max_history:
                        self._qs_buffer.pop(0)

                    # Create history tensor for convolution
                    # [History+1, Units, 1]
                    qs_history = torch.cat(self._qs_buffer, dim=0)
                else:
                    # If not caching (e.g. training), use what we have
                    qs_history = Qs * dt

                # Distributed routing for streamflow at gages
                distr_out_dict = self.distr_routing(
                    Qs=qs_history,
                    distr_params_dict=distr_params_dict,
                    outlet_topo=outlet_topo,
                    areas=areas,
                )
                flux_dict['streamflow'] = distr_out_dict['Qs_rout']

                if self.cache_states:
                    # If we passed in history [T], we get out [T]. We only want index -1.
                    flux_dict['streamflow'] = distr_out_dict['Qs_rout'][-1:]
                else:
                    flux_dict['streamflow'] = distr_out_dict['Qs_rout']

            return flux_dict, states
799
+
800
    def distr_routing(
        self,
        Qs: torch.Tensor,
        distr_params_dict: dict,
        outlet_topo: torch.Tensor,
        areas: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Route area-weighted unit-basin runoff to gages.

        Each (gage, unit) pair marked in `outlet_topo` is convolved with a
        gamma unit hydrograph (optionally lag-shifted), summed per gage, and
        normalized by total upstream area.

        :param Qs: (nsteps, n_units, 1) runoff per unit basin.
        :param distr_params_dict: dict of (n_pairs, n_params) descaled
            distributed routing parameters ('route_a', 'route_b', and
            'route_tau' when `lag_uh` is enabled).
        :param outlet_topo: (n_gages, n_units) 0/1 connectivity matrix.
        :param areas: (n_units,) unit basin areas.
        :return: dict with 'Qs_rout' of shape (nsteps, n_gages, 1).
        """
        device = areas.device
        nsteps = Qs.size(0)
        max_lag = self.lenF

        # extract per-pair series
        Qs_weighted = (
            Qs * areas[None, :, None]
        )  # area-weighted runoff, (nsteps, n_units, 1)
        # Each nonzero entry of outlet_topo defines one (gage, unit) pair.
        reach_idx = (outlet_topo == 1).nonzero(as_tuple=False)
        pair_rows = reach_idx[:, 0].to(device).long()
        pair_cols = reach_idx[:, 1].to(device).long()
        Qs_pairs = Qs_weighted[:, pair_cols, :]  # (nsteps, n_pairs, 1)

        # routing via convolution
        UH = uh_gamma(
            distr_params_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
            distr_params_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
            lenF=max_lag,
        )
        if self.lag_uh:  # add a lag to the unit hydrograph
            UH = self._frac_shift1d(UH, distr_params_dict['route_tau'])
        rf = Qs_pairs.permute([1, 2, 0]).contiguous()  # (n_pairs, 1, nsteps)
        UH = UH.permute([1, 2, 0]).contiguous()  # (n_pairs, 1, nsteps)
        Qs_lagged = uh_conv(rf, UH).squeeze(1).contiguous()  # (n_pairs, nsteps)

        # Group-sum: scatter_add_ along rows
        n_gages = int(outlet_topo.shape[0])
        Qs_rout = torch.zeros(
            n_gages, Qs_lagged.shape[1], device=Qs_lagged.device, dtype=Qs_lagged.dtype
        )
        Qs_rout.scatter_add_(
            0, pair_rows.view(-1, 1).expand(-1, Qs_lagged.shape[1]), Qs_lagged
        )  # (n_gages, nsteps)

        # Normalize by upstream area
        denom = (outlet_topo * areas[None, :]).sum(dim=1).unsqueeze(1).clamp(min=1e-6)
        Qs_rout = Qs_rout / denom
        Qs_rout = Qs_rout.T.unsqueeze(-1)  # (nsteps, n_gages, 1)

        # output
        output = {'Qs_rout': Qs_rout}
        return output
856
+
857
+ @staticmethod
858
+ def _frac_shift1d(w, tau):
859
+ """
860
+ Differentiable fractional shift: return w(t - tau) by mixing k- and (k+1)-step shifts.
861
+ For tau = k + f (0<=f<1): y[t] = (1-f)*w[t-k] + f*w[t-(k+1)].
862
+ w: [T,B,V].
863
+ tau: [B,V] (>=0 recommended).
864
+ """
865
+ T, B, V = w.shape
866
+ device, dtype = w.device, w.dtype
867
+
868
+ # Decompose tau = k + f
869
+ tau = tau.view(1, B, V).to(dtype)
870
+ k = torch.floor(tau) # [1,B,V]
871
+ f = tau - k # [1,B,V]
872
+
873
+ # Time indices 0..T-1
874
+ t = torch.arange(T, device=device, dtype=dtype).view(T, 1, 1) # [T,1,1]
875
+
876
+ # Target indices for the two integer shifts
877
+ i0 = t - k # corresponds to shift by k
878
+ i1 = t - (k + 1) # corresponds to shift by k+1
879
+
880
+ # Gather with clamp + explicit zeroing (true zero padding)
881
+ i0c = i0.clamp(0, T - 1).long()
882
+ i1c = i1.clamp(0, T - 1).long()
883
+
884
+ w0 = torch.gather(w, 0, i0c)
885
+ w1 = torch.gather(w, 0, i1c)
886
+
887
+ mask0 = (i0 >= 0) & (i0 <= T - 1)
888
+ mask1 = (i1 >= 0) & (i1 <= T - 1)
889
+ w0 = w0 * mask0.to(dtype)
890
+ w1 = w1 * mask1.to(dtype)
891
+
892
+ # Linear blend: (1-f)*k-shift + f*(k+1)-shift
893
+ y = (1.0 - f) * w0 + f * w1
894
+
895
+ # Renormalize to unit mass per (B,V) -> may cause instability
896
+ # y = y / y.sum(0).clamp_min(1e-6)
897
+ return y # [T,B,V]