hydrodl2 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydrodl2/__init__.py +122 -0
- hydrodl2/_version.py +34 -0
- hydrodl2/api/__init__.py +3 -0
- hydrodl2/api/methods.py +144 -0
- hydrodl2/core/calc/__init__.py +11 -0
- hydrodl2/core/calc/batch_jacobian.pye +501 -0
- hydrodl2/core/calc/fdj.py +92 -0
- hydrodl2/core/calc/uh_routing.py +105 -0
- hydrodl2/core/calc/utils.py +59 -0
- hydrodl2/core/utils/__init__.py +7 -0
- hydrodl2/core/utils/clean_temp.sh +8 -0
- hydrodl2/core/utils/utils.py +63 -0
- hydrodl2/models/hbv/hbv.py +596 -0
- hydrodl2/models/hbv/hbv_1_1p.py +608 -0
- hydrodl2/models/hbv/hbv_2.py +670 -0
- hydrodl2/models/hbv/hbv_2_hourly.py +897 -0
- hydrodl2/models/hbv/hbv_2_mts.py +377 -0
- hydrodl2/models/hbv/hbv_adj.py +712 -0
- hydrodl2/modules/__init__.py +2 -0
- hydrodl2/modules/data_assimilation/variational_prcp_da.py +1 -0
- hydrodl2-1.3.0.dist-info/METADATA +184 -0
- hydrodl2-1.3.0.dist-info/RECORD +24 -0
- hydrodl2-1.3.0.dist-info/WHEEL +4 -0
- hydrodl2-1.3.0.dist-info/licenses/LICENSE +31 -0
|
@@ -0,0 +1,712 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
# import sourcedefender
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from hydrodl2.core.calc import (
|
|
7
|
+
change_param_range,
|
|
8
|
+
finite_difference_jacobian_p,
|
|
9
|
+
uh_conv,
|
|
10
|
+
uh_gamma,
|
|
11
|
+
)
|
|
12
|
+
from hydrodl2.core.calc.batch_jacobian import batchJacobian
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HbvAdj(torch.nn.Module):
    """
    Multi-component PyTorch HBV model using implicit numerical scheme and
    gradient tracking supported by adjoint method.

    Author
    ------
    Yalan Song

    Publication
    -----------
    - Song, Y., Knoben, W. J. M., Clark, M. P., Feng, D., Lawson, K. E., & Shen,
    C. (2024). When ancient numerical demons meet physics-informed machine
    learning: Adjoint-based gradients for implicit differentiable modeling.
    Hydrology and Earth System Sciences Discussions, 1-35.
    https://doi.org/10.5194/hess-2023-258

    Parameters
    ----------
    config
        Configuration dictionary. Recognised keys are 'warm_up', 'dy_drop',
        'dynamic_params' (a mapping keyed by class name), 'variables',
        'routing', 'comprout', 'nearzero', 'nmul' and 'ad_efficient'. Any
        key that is absent keeps its default.
    device
        Device to run the model on. When None, CUDA is used if available,
        otherwise the CPU.
    """

    def __init__(
        self,
        config: Optional[dict[str, Any]] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.name = 'HBV Adjoint'
        self.config = config
        self.initialize = False
        self.warm_up = 0
        self.dynamic_params = []
        self.dy_drop = 0.0
        self.variables = ['prcp', 'tmean', 'pet']
        self.routing = True
        self.comprout = False
        self.nearzero = 1e-5
        self.nmul = 1
        self.ad_efficient = True
        self.device = device
        # [lower, upper] range of every physical parameter. The key order
        # defines the column order expected in the network output.
        self.parameter_bounds = {
            'parBETA': [1.0, 6.0],
            'parFC': [50, 1000],
            'parK0': [0.05, 0.9],
            'parK1': [0.01, 0.5],
            'parK2': [0.001, 0.2],
            'parLP': [0.2, 1],
            'parPERC': [0, 10],
            'parUZL': [0, 100],
            'parTT': [-2.5, 2.5],
            'parCFMAX': [0.5, 10],
            'parCFR': [0, 0.1],
            'parCWH': [0, 0.2],
        }
        self.routing_parameter_bounds = {
            'rout_a': [0, 2.9],
            'rout_b': [0, 6.5],
        }

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if config is not None:
            # Overwrite defaults with config values.
            self.warm_up = config.get('warm_up', self.warm_up)
            self.dy_drop = config.get('dy_drop', self.dy_drop)
            # 'dynamic_params' is optional; a missing or empty entry keeps
            # the default (no dynamic parameters) instead of raising KeyError.
            dynamic_by_model = config.get('dynamic_params') or {}
            self.dynamic_params = dynamic_by_model.get(
                self.__class__.__name__, self.dynamic_params
            )
            self.variables = config.get('variables', self.variables)
            self.routing = config.get('routing', self.routing)
            self.comprout = config.get('comprout', self.comprout)
            self.nearzero = config.get('nearzero', self.nearzero)
            self.nmul = config.get('nmul', self.nmul)
            self.ad_efficient = config.get('ad_efficient', self.ad_efficient)
            if 'parBETAET' in self.dynamic_params:
                self.parameter_bounds['parBETAET'] = [0.3, 5]

        self.set_parameters()

    def set_parameters(self) -> None:
        """Collect parameter names and count the learnable parameters.

        Sets ``phy_param_names``, ``routing_param_names`` (empty when routing
        is disabled) and ``learnable_param_count``, which is the number of
        physical parameters times ``nmul`` plus the number of routing
        parameters.
        """
        self.phy_param_names = list(self.parameter_bounds)
        if self.routing:
            self.routing_param_names = list(self.routing_parameter_bounds)
        else:
            self.routing_param_names = []

        self.learnable_param_count = len(self.phy_param_names) * self.nmul + len(
            self.routing_param_names
        )

    def unpack_parameters(
        self,
        parameters: torch.Tensor,
        n_steps: int,
        n_grid: int,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Extract physical model and routing parameters from NN output.

        Parameters
        ----------
        parameters
            Unprocessed, learned parameters from a neural network. Indexed
            as [step, grid, feature]; the first
            ``len(parameter_bounds) * nmul`` features are physical
            parameters and the remainder are routing parameters.
        n_steps
            Number of time steps in the input data.
        n_grid
            Number of grid cells in the input data.

        Returns
        -------
        tuple[torch.Tensor, Optional[torch.Tensor]]
            Sigmoid-normalized physical parameters reshaped to
            [n_steps, n_grid * nmul, n_params], and the sigmoid-normalized
            routing parameters taken from the last time step, or None when
            routing is disabled.
        """
        phy_param_count = len(self.parameter_bounds)
        n_phy_features = phy_param_count * self.nmul

        # Physical parameters
        phy_params = torch.sigmoid(parameters[:, :, :n_phy_features]).view(
            parameters.shape[0],
            parameters.shape[1],
            phy_param_count,
            self.nmul,
        )
        # Merge the multi-components into batch dimension for parallel Jacobian
        phy_params = phy_params.permute([0, 3, 1, 2])
        phy_params = phy_params.reshape(n_steps, n_grid * self.nmul, phy_param_count)

        # Routing parameters (only defined when routing is enabled).
        routing_params = None
        if self.routing:
            routing_params = torch.sigmoid(parameters[-1, :, n_phy_features:])
        return phy_params, routing_params

    def make_phy_parameters(
        self,
        phy_params: torch.Tensor,
        name_list: list,
        dy_list: list,
    ) -> torch.Tensor:
        """Build the per-step parameter tensor from static and dynamic parts.

        No descaling happens here: values stay normalized. Every parameter
        is first made static by broadcasting its last-step value over all
        steps. Parameters named in ``dy_list`` then keep their time-varying
        values, except where a Bernoulli mask with probability ``dy_drop``
        switches a batch element back to the static value.

        Parameters
        ----------
        phy_params
            Normalized physical parameters, [n_steps, n_grid, n_params].
        name_list
            Physical parameter names, in the column order of `phy_params`.
        dy_list
            Names of the parameters treated as dynamic.

        Returns
        -------
        torch.Tensor
            Tensor of physical parameters with the shape of `phy_params`.
        """
        n_steps, n_grid, _ = phy_params.size()
        static_params = phy_params[-1, :, :].unsqueeze(0).repeat([n_steps, 1, 1])
        if not dy_list:
            return static_params

        # Drop probability, one entry per batch element.
        drop_prob = torch.ones([1, n_grid]) * self.dy_drop
        combined = torch.clone(static_params)
        for col, name in enumerate(name_list):
            if name not in dy_list:
                continue
            # A fresh mask per parameter: 1 means "use the static value".
            drop_mask = torch.bernoulli(drop_prob).detach_().to(self.device)
            combined[:, :, col] = (
                phy_params[:, :, col] * (1 - drop_mask)
                + static_params[:, :, col] * drop_mask
            )
        return combined

    def descale_rout_parameters(
        self,
        rout_params: torch.Tensor,
        name_list: list,
    ) -> dict[str, torch.Tensor]:
        """Descale routing parameters.

        Parameters
        ----------
        rout_params
            Normalized routing parameters, one column per name.
        name_list
            List of routing parameter names.

        Returns
        -------
        dict
            Dictionary of descaled routing parameters keyed by name, each
            mapped through `change_param_range` with its routing bounds.
        """
        return {
            name: change_param_range(
                param=rout_params[:, col],
                bounds=self.routing_parameter_bounds[name],
            )
            for col, name in enumerate(name_list)
        }

    def _expand_forcing(self, forcing: torch.Tensor, bsnew: int) -> torch.Tensor:
        """Repeat forcings `nmul` times and fold the copies into the batch."""
        expanded = forcing.unsqueeze(1).repeat([1, self.nmul, 1, 1])
        return expanded.view(expanded.shape[0], bsnew, expanded.shape[-1])

    def forward(
        self,
        x_dict: dict[str, torch.Tensor],
        parameters: torch.Tensor,
    ) -> Union[tuple, dict[str, torch.Tensor]]:
        """Forward pass for HBV Adj.

        Parameters
        ----------
        x_dict
            Dictionary of input forcing data; 'x_phy' holds the forcings
            indexed as [step, grid, variable].
        parameters
            Unprocessed, learned parameters from a neural network.

        Returns
        -------
        Union[tuple, dict]
            Dictionary with key 'flow_sim': the simulated flow after unit
            hydrograph routing, or the unrouted flow averaged over
            components when routing is disabled.
        """
        # Unpack input data.
        x = x_dict['x_phy']

        n_steps, bs, _ = x.size()
        bsnew = bs * self.nmul
        phy_params, routing_params = self.unpack_parameters(parameters, n_steps, bs)

        nS = 5  # For this version of HBV, we have 5 state variables
        nflux = 1  # currently only return streamflow
        y_init = torch.zeros((bsnew, nS)).to(self.device)
        delta_t = torch.tensor(1.0).to(device=self.device)  # Daily model

        if self.warm_up > 0:
            # Spin the states up with static parameters only.
            phy_params_warmup = self.make_phy_parameters(
                phy_params[: self.warm_up, :, :], self.phy_param_names, []
            )
            x_warmup = self._expand_forcing(x[: self.warm_up, :, :], bsnew)
            f_warm_up = HBV(x_warmup, self.parameter_bounds)
            M_warm_up = MOL(
                f_warm_up,
                nS,
                nflux,
                self.warm_up,
                bsDefault=bsnew,
                mtd=0,
                dtDefault=delta_t,
                ad_efficient=self.ad_efficient,
            )
            y0 = M_warm_up.nsteps_pDyn(phy_params_warmup, y_init)[-1, :, :]
        else:
            # Without warm-up, initialize state variables with zeros.
            y0 = y_init

        phy_params_run = self.make_phy_parameters(
            phy_params[self.warm_up :, :, :], self.phy_param_names, self.dynamic_params
        )

        xTrain = self._expand_forcing(x[self.warm_up :, :, :], bsnew)
        nt = phy_params_run.shape[0]
        simulation = torch.zeros((nt, bsnew, nflux)).to(self.device)

        f = HBV(xTrain, self.parameter_bounds)
        M = MOL(
            f,
            nS,
            nflux,
            nt,
            bsDefault=bsnew,
            dtDefault=delta_t,
            mtd=0,
            ad_efficient=self.ad_efficient,
        )

        # Newton iterations with adjoint
        ySolution = M.nsteps_pDyn(phy_params_run, y0)

        # Re-evaluate the fluxes at the solved states of every step.
        for day in range(nt):
            _, flux = f(
                ySolution[day, :, :], phy_params_run[day, :, :], day, expanded_num=[1]
            )
            simulation[day, :, :] = flux * delta_t

        if self.nmul > 1:
            # Average over the components folded into the batch dimension.
            simulation = simulation.view(nt, self.nmul, bs, nflux).mean(dim=1)

        if not self.routing:
            # No routing parameters were learned; return the unrouted flow.
            return {
                'flow_sim': simulation,
            }

        routy_params_dict = self.descale_rout_parameters(
            routing_params, self.routing_param_names
        )
        routa = routy_params_dict['rout_a'].unsqueeze(0).repeat(nt, 1).unsqueeze(-1)
        routb = routy_params_dict['rout_b'].unsqueeze(0).repeat(nt, 1).unsqueeze(-1)

        UH = uh_gamma(routa, routb, lenF=15)  # lenF: filter length
        rf = simulation.permute([1, 2, 0])  # dim: gage*var*time
        UH = UH.permute([1, 2, 0])  # dim: gage*var*time
        Qsrout = uh_conv(rf, UH).permute([2, 0, 1])

        # Return all sim results.
        return {
            'flow_sim': Qsrout,
        }
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
class HBV(torch.nn.Module):
    """Right-hand side of the HBV state equations.

    Parameters
    ----------
    climate_data
        Forcing tensor indexed as [step, batch, variable]. Variable 0 is
        read as precipitation, 1 as temperature and 2 as potential
        evapotranspiration.
    parameter_bounds
        Mapping from parameter name to its [lower, upper] range.
    """

    # Column of `theta` that holds each normalized parameter.
    _PARAM_COLUMNS = (
        'parBETA',
        'parFC',
        'parK0',
        'parK1',
        'parK2',
        'parLP',
        'parPERC',
        'parUZL',
        'parTT',
        'parCFMAX',
        'parCFR',
        'parCWH',
    )
    # Column of the optional evapotranspiration shape parameter.
    _BETAET_COLUMN = 12

    def __init__(self, climate_data, parameter_bounds):
        super().__init__()
        self.climate_data = climate_data
        self.parameter_bounds = parameter_bounds

    def _descale(self, theta, name, column):
        """Map one normalized `theta` column onto the bounds of `name`."""
        lower = self.parameter_bounds[name][0]
        upper = self.parameter_bounds[name][1]
        return lower + theta[:, column] * (upper - lower)

    def forward(self, y, theta, t, expanded_num, returnFlux=False, aux=None):
        """Evaluate the state derivatives and the outflow at step `t`.

        Parameters
        ----------
        y
            States indexed as [batch, state]; the columns are snowpack,
            meltwater, soil moisture, upper zone and lower zone storage.
        theta
            Normalized parameters, one column per parameter.
        t
            Time index into `climate_data`.
        expanded_num
            Repeat counts; the forcing at `t` is repeated row-wise by each
            count and the results are concatenated along the batch axis.
        returnFlux
            When True, return the individual fluxes instead of the
            derivatives.
        aux
            Unused.

        Returns
        -------
        tuple
            ``(dS, fluxes)`` by default, or
            ``(fluxes, q0, q1, q2, et)`` when `returnFlux` is True.
        """
        (Beta, FC, K0, K1, K2, LP, PERC, UZL, TT, CFMAX, CFR, CWH) = (
            self._descale(theta, name, column)
            for column, name in enumerate(self._PARAM_COLUMNS)
        )
        if 'parBETAET' in self.parameter_bounds:
            BETAET = self._descale(theta, 'parBETAET', self._BETAET_COLUMN)
        else:
            # The bound (and its theta column) only exists when the caller
            # registered parBETAET; an exponent of one leaves the
            # evaporation factor linear instead of raising KeyError.
            BETAET = torch.ones_like(Beta)

        PRECS = 0
        # Stores, clamped to stay non-negative. Soil moisture gets a small
        # positive floor.
        SNOWPACK = torch.clamp(y[:, 0], min=PRECS)
        MELTWATER = torch.clamp(y[:, 1], min=PRECS)
        SM = torch.clamp(y[:, 2], min=1e-8)
        SUZ = torch.clamp(y[:, 3], min=PRECS)
        SLZ = torch.clamp(y[:, 4], min=PRECS)
        dS = torch.zeros(y.shape[0], y.shape[1]).to(y)
        fluxes = torch.zeros((y.shape[0], 1)).to(y)

        climate_in0 = self.climate_data[int(t), :, :]  # climate at this step
        climate_in = torch.cat(
            [climate_in0.repeat_interleave(count, dim=0) for count in expanded_num],
            dim=0,
        )

        P = climate_in[:, 0]
        T = climate_in[:, 1]
        Ep = climate_in[:, 2]

        # Flux functions
        flux_sf = self.snowfall(P, T, TT)
        flux_refr = self.refreeze(CFR, CFMAX, T, TT, MELTWATER)
        flux_melt = self.melt(CFMAX, T, TT, SNOWPACK)
        flux_rf = self.rainfall(P, T, TT)
        flux_Isnow = self.Isnow(MELTWATER, CWH, SNOWPACK)
        flux_PEFF = self.Peff(SM, FC, Beta, flux_rf, flux_Isnow)
        flux_ex = self.excess(SM, FC)
        flux_et = self.evap(SM, FC, LP, Ep, BETAET)
        flux_perc = self.percolation(PERC, SUZ)
        flux_q0 = self.interflow(K0, SUZ, UZL)
        flux_q1 = self.baseflow(K1, SUZ)
        flux_q2 = self.baseflow(K2, SLZ)

        # Store ODEs
        dS[:, 0] = flux_sf + flux_refr - flux_melt
        dS[:, 1] = flux_melt - flux_refr - flux_Isnow
        dS[:, 2] = flux_Isnow + flux_rf - flux_PEFF - flux_ex - flux_et
        dS[:, 3] = flux_PEFF + flux_ex - flux_perc - flux_q0 - flux_q1
        dS[:, 4] = flux_perc - flux_q2

        fluxes[:, 0] = flux_q0 + flux_q1 + flux_q2

        if returnFlux:
            return (
                fluxes,
                flux_q0.unsqueeze(-1),
                flux_q1.unsqueeze(-1),
                flux_q2.unsqueeze(-1),
                flux_et.unsqueeze(-1),
            )
        return dS, fluxes

    def snowfall(self, P, T, TT):
        """Precipitation counted as snow: `P` where `T < TT`, else zero."""
        return torch.mul(P, (T < TT))

    def refreeze(self, CFR, CFMAX, T, TT, MELTWATER):
        """Refreezing of meltwater, limited by the available meltwater."""
        refreezing = torch.clamp(CFR * CFMAX * (TT - T), min=0.0)
        return torch.min(refreezing, MELTWATER)

    def melt(self, CFMAX, T, TT, SNOWPACK):
        """Snowmelt, limited by the available snowpack."""
        potential_melt = torch.clamp(CFMAX * (T - TT), min=0.0)
        return torch.min(potential_melt, SNOWPACK)

    def rainfall(self, P, T, TT):
        """Precipitation counted as rain: `P` where `T >= TT`, else zero."""
        return torch.mul(P, (T >= TT))

    def Isnow(self, MELTWATER, CWH, SNOWPACK):
        """Meltwater above the holding capacity ``CWH * SNOWPACK``."""
        return torch.clamp(MELTWATER - (CWH * SNOWPACK), min=0.0)

    def Peff(self, SM, FC, Beta, flux_rf, flux_Isnow):
        """Effective precipitation: inflow scaled by soil wetness."""
        soil_wetness = torch.clamp((SM / FC) ** Beta, min=0.0, max=1.0)
        return (flux_rf + flux_Isnow) * soil_wetness

    def excess(self, SM, FC):
        """Soil moisture above `FC`."""
        return torch.clamp(SM - FC, min=0.0)

    def evap(self, SM, FC, LP, Ep, BETAET):
        """Actual evapotranspiration, limited by the soil moisture."""
        evapfactor = torch.clamp((SM / (LP * FC)) ** BETAET, min=0.0, max=1.0)
        return torch.min(SM, Ep * evapfactor)

    def interflow(self, K0, SUZ, UZL):
        """Outflow from upper zone storage above the threshold `UZL`."""
        return K0 * torch.clamp(SUZ - UZL, min=0.0)

    def percolation(self, PERC, SUZ):
        """Percolation, limited by the upper zone storage."""
        return torch.min(SUZ, PERC)

    def baseflow(self, K, S):
        """Linear outflow ``K * S``."""
        return K * S
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
# Linear solver used for the Newton updates and the adjoint system.
matrixSolve = torch.linalg.solve


class NewtonSolve(torch.autograd.Function):
    """Newton solver for the adjoint model.

    The forward pass solves ``G = 0`` for the state with Newton iterations
    and stores the Jacobians of the residual. The backward pass turns the
    incoming gradient into parameter gradients by solving one linear
    system with the transposed state Jacobian.
    """

    @staticmethod
    def forward(
        ctx, p, p2, t, G, x0=None, auxG=None, batchP=True, eval=False, ad_efficient=True
    ):
        """Solve the residual equation for the state.

        Parameters
        ----------
        p
            Parameters of the residual function.
        p2
            Second differentiable input of the residual (the previous state
            when called from `MOL.nsteps_pDyn`), or None.
        t
            Time index forwarded to `G`.
        G
            Residual function. Called as ``G(x, p, p2, t, [1], auxG)``, or
            as ``G(x, p, t, auxG)`` when `p2` is None.
        x0
            Initial guess. Defaults to `p2`.
        auxG
            Auxiliary data forwarded to `G`.
        batchP
            Whether the parameters are batched. Only True is supported.
        eval
            When True, no parameter Jacobians are computed or saved.
        ad_efficient
            Must be True: the slow AD Jacobian is not available, and the
            original finite-difference branch could never be reached.

        Returns
        -------
        torch.Tensor
            The solved state, detached and cast to float32.

        Raises
        ------
        NotImplementedError
            If `ad_efficient` is False, or `batchP` is False outside eval.
        ValueError
            If neither `x0` nor `p2` is given.
        RuntimeError
            If the state Jacobian contains NaN or Inf.
        """
        if not ad_efficient:
            raise NotImplementedError(
                "ad_efficient=False is not supported: the slow AD Jacobian is disabled"
            )
        if x0 is None:
            if p2 is None:
                raise ValueError("either x0 or p2 must be provided as initial guess")
            x0 = p2

        def residual(state):
            # MOL.forward takes an explicit expansion list before auxG.
            if p2 is None:
                return G(state, p, t, auxG)
            return G(state, p, p2, t, [1], auxG)

        def state_jacobian(res, state):
            jac = batchJacobian(res, state, graphed=True)
            if torch.isnan(jac).any() or torch.isinf(jac).any():
                raise RuntimeError("Jacobian matrix is NaN")
            return jac

        x = x0.clone().detach()
        i = 0
        max_iter = 3
        gtol = 1e-3

        # Initial residual and Jacobian; autograd must be on to build them.
        torch.set_grad_enabled(True)
        x.requires_grad = True
        gg = residual(x)
        dGdx = state_jacobian(gg, x)
        x = x.detach()
        torch.set_grad_enabled(False)

        resnorm = torch.linalg.norm(
            gg, float('inf'), dim=[1]
        )  # calculate norm of the residuals
        # Inflated so that the first iteration reuses the initial Jacobian.
        resnorm0 = 100 * resnorm

        while (torch.max(resnorm) > gtol) and i <= max_iter:
            i += 1
            if torch.max(resnorm / resnorm0) > 0.2:
                # Convergence is slow: refresh the Jacobian at the current
                # state. Otherwise the previous Jacobian is reused.
                torch.set_grad_enabled(True)
                x.requires_grad = True
                gg = residual(x)
                dGdx = state_jacobian(gg, x)
                x = x.detach()
                torch.set_grad_enabled(False)

            if dGdx.ndim == gg.ndim:  # same dimension, must be scalar.
                dx = (gg / dGdx).detach()
            else:
                dx = matrixSolve(dGdx, gg).detach()
            x = x - dx

            # Residual at the new state, kept on the graph so the parameter
            # Jacobians can be taken from it after the loop.
            torch.set_grad_enabled(True)
            x.requires_grad = True
            gg = residual(x)
            torch.set_grad_enabled(False)
            resnorm0 = resnorm  # old resnorm
            resnorm = torch.linalg.norm(gg, float('inf'), dim=[1])

        torch.set_grad_enabled(True)
        x = x.detach()

        dGdp = None
        dGdp2 = None
        if not eval:
            if not batchP:
                raise NotImplementedError(
                    "nonbatchp (like NN) pathway not debugged through yet"
                )
            # dGdp is needed only upon convergence.
            if p2 is None:
                dGdp = batchJacobian(gg, p, graphed=True)
            else:
                dGdp, dGdp2 = batchJacobian(gg, (p, p2), graphed=True)

        # Reusing the Jacobians here saves one forward run in backward.
        # Entries that were not computed are stored as None.
        ctx.save_for_backward(
            None if dGdp is None else dGdp.float(),
            None if dGdp2 is None else dGdp2.float(),
            dGdx.float(),
        )
        return x.float()

    @staticmethod
    def backward(ctx, dLdx):
        """Turn the state gradient into gradients for `p` and `p2`.

        Returns one entry per argument of `forward` (nine in total), as
        autograd requires; every non-differentiable argument gets None.
        """
        with torch.no_grad():
            dGdp, dGdp2, dGdx = ctx.saved_tensors
            dGdxT = torch.permute(dGdx, (0, 2, 1))
            lambTneg = matrixSolve(dGdxT, dLdx)
            if lambTneg.ndim <= 2:
                lambTneg = torch.unsqueeze(lambTneg, 2)
            lambT = torch.permute(lambTneg, (0, 2, 1))

            dLdp = None
            if dGdp is not None:
                # bmm leaves a singleton middle axis; drop it.
                dLdp = torch.squeeze(-torch.bmm(lambT, dGdp), 1)
            dLdp2 = None
            if dGdp2 is not None:
                dLdp2 = torch.squeeze(-torch.bmm(lambT, dGdp2), 1)
            return dLdp, dLdp2, None, None, None, None, None, None, None
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
class MOL(torch.nn.Module):
    """
    Method of Lines time integrator as a nonlinear equation
    G(x, p, xt, t, auxG)=0.
    RHS is preloaded at construct and is the equation for the right hand side
    of the equation.

    Parameters
    ----------
    rhsFunc
        Callable evaluating the right-hand side; called as
        ``rhsFunc(x, p, t, expand_num, returnFlux=False, aux=aux)`` and
        expected to return a pair whose first item is the derivative.
    ny
        Number of state variables.
    nflux
        Number of fluxes (stored only).
    rho
        Number of time steps integrated by `nsteps_pDyn`.
    bsDefault
        Batch size of the solution tensor.
    mtd
        Time discretization: 0 for backward Euler, 1 for Crank-Nicolson.
    dtDefault
        Time step size.
    solveAdj
        Solver applied at every step.
    eval
        Forwarded to the solver.
    ad_efficient
        Forwarded to the solver.
    """

    def __init__(
        self,
        rhsFunc,
        ny,
        nflux,
        rho,
        bsDefault=1,
        mtd=0,
        dtDefault=0,
        solveAdj=NewtonSolve.apply,
        eval=False,
        ad_efficient=True,
    ):
        super().__init__()
        self.mtd = mtd  # time discretization method. =0 for backward Euler
        self.rhs = rhsFunc
        self.delta_t = dtDefault
        self.bs = bsDefault
        self.ny = ny
        self.nflux = nflux
        self.rho = rho
        self.solveAdj = solveAdj
        self.eval = eval
        self.ad_efficient = ad_efficient

    def forward(self, x, p, xt, t, expand_num, auxG):  # take one step
        """Residual of one implicit time step.

        Parameters
        ----------
        x
            Candidate state x^{t+1} being solved for.
        p
            Parameters forwarded to the right-hand side.
        xt
            Known state x^{t}.
        t
            Time index forwarded to the right-hand side.
        expand_num
            Forwarded to the right-hand side.
        auxG
            Pair ``(dt, aux)``: the step size and auxiliary data for the
            right-hand side.

        Returns
        -------
        torch.Tensor
            The residual, which is zero when `x` satisfies the scheme.

        Raises
        ------
        ValueError
            If `mtd` is neither 0 nor 1.
        """
        dt, aux = auxG  # expand auxiliary data

        if self.mtd not in (0, 1):
            raise ValueError(f"unsupported time discretization method: {self.mtd!r}")

        rhs, _ = self.rhs(
            x, p, t, expand_num, returnFlux=False, aux=aux
        )  # should return [nb,ng]
        if self.mtd == 0:  # backward Euler
            return (x - xt) / dt - rhs

        # Crank-Nicolson: average the right-hand side at x and at xt.
        rhst, _ = self.rhs(
            xt, p, t, expand_num, returnFlux=False, aux=aux
        )  # should return [nb,ng]
        return (x - xt) / dt - (rhs + rhst) * 0.5

    def nsteps_pDyn(self, pDyn, x0):
        """Integrate `rho` steps with per-step parameters.

        Parameters
        ----------
        pDyn
            Parameters indexed as [step, batch, parameter].
        x0
            Initial state, used as the known state of the first step.

        Returns
        -------
        torch.Tensor
            Solved states of shape [rho, bs, ny]; entry `t` is the state
            obtained by step `t`.
        """
        ySolution = torch.zeros((self.rho, self.bs, self.ny)).to(pDyn)
        xt = x0.clone().requires_grad_()
        auxG = (self.delta_t, None)

        for t in range(self.rho):
            # Pass the bound `forward` as the residual function.
            x = self.solveAdj(
                pDyn[t, :, :],
                xt,
                t,
                self.forward,
                None,
                auxG,
                True,
                self.eval,
                self.ad_efficient,
            )
            ySolution[t, :, :] = x
            xt = x

        return ySolution
|