hydrodl2 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,712 @@
+ from typing import Any, Optional
+
+ # import sourcedefender
+ import torch
+
+ from hydrodl2.core.calc import (
+     change_param_range,
+     finite_difference_jacobian_p,
+     uh_conv,
+     uh_gamma,
+ )
+ from hydrodl2.core.calc.batch_jacobian import batchJacobian
+
+
+ class HbvAdj(torch.nn.Module):
+     """
+     Multi-component PyTorch HBV model using an implicit numerical scheme,
+     with gradient tracking supported by the adjoint method.
+
+     Author
+     ------
+     Yalan Song
+
+     Publication
+     -----------
+     - Song, Y., Knoben, W. J. M., Clark, M. P., Feng, D., Lawson, K. E., & Shen,
+       C. (2024). When ancient numerical demons meet physics-informed machine
+       learning: Adjoint-based gradients for implicit differentiable modeling.
+       Hydrology and Earth System Sciences Discussions, 1-35.
+       https://doi.org/10.5194/hess-2023-258
+
+     Parameters
+     ----------
+     config
+         Configuration dictionary.
+     device
+         Device to run the model on.
+     """
+
+     def __init__(
+         self,
+         config: Optional[dict[str, Any]] = None,
+         device: Optional[torch.device] = None,
+     ) -> None:
+         super().__init__()
+         self.name = 'HBV Adjoint'
+         self.config = config
+         self.initialize = False
+         self.warm_up = 0
+         self.dynamic_params = []
+         self.dy_drop = 0.0
+         self.variables = ['prcp', 'tmean', 'pet']
+         self.routing = True
+         self.comprout = False
+         self.nearzero = 1e-5
+         self.nmul = 1
+         self.ad_efficient = True
+         self.device = device
+         self.parameter_bounds = {
+             'parBETA': [1.0, 6.0],
+             'parFC': [50, 1000],
+             'parK0': [0.05, 0.9],
+             'parK1': [0.01, 0.5],
+             'parK2': [0.001, 0.2],
+             'parLP': [0.2, 1],
+             'parPERC': [0, 10],
+             'parUZL': [0, 100],
+             'parTT': [-2.5, 2.5],
+             'parCFMAX': [0.5, 10],
+             'parCFR': [0, 0.1],
+             'parCWH': [0, 0.2],
+         }
+         self.routing_parameter_bounds = {
+             'rout_a': [0, 2.9],
+             'rout_b': [0, 6.5],
+         }
+
+         if not device:
+             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         if config is not None:
+             # Overwrite defaults with config values.
+             self.warm_up = config.get('warm_up', self.warm_up)
+             self.dy_drop = config.get('dy_drop', self.dy_drop)
+             self.dynamic_params = config['dynamic_params'].get(
+                 self.__class__.__name__, self.dynamic_params
+             )
+             self.variables = config.get('variables', self.variables)
+             self.routing = config.get('routing', self.routing)
+             self.comprout = config.get('comprout', self.comprout)
+             self.nearzero = config.get('nearzero', self.nearzero)
+             self.nmul = config.get('nmul', self.nmul)
+             self.ad_efficient = config.get('ad_efficient', self.ad_efficient)
+             if 'parBETAET' in self.dynamic_params:
+                 self.parameter_bounds['parBETAET'] = [0.3, 5]
+
+         self.set_parameters()
+
+     def set_parameters(self) -> None:
+         """Collect physical and routing parameter names."""
+         self.phy_param_names = self.parameter_bounds.keys()
+         if self.routing:
+             self.routing_param_names = self.routing_parameter_bounds.keys()
+         else:
+             self.routing_param_names = []
+
+         self.learnable_param_count = len(self.phy_param_names) * self.nmul + len(
+             self.routing_param_names
+         )
+
+     def unpack_parameters(
+         self,
+         parameters: torch.Tensor,
+         n_steps: int,
+         n_grid: int,
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """Extract physical model and routing parameters from NN output.
+
+         Parameters
+         ----------
+         parameters
+             Unprocessed, learned parameters from a neural network.
+         n_steps
+             Number of time steps in the input data.
+         n_grid
+             Number of grid cells in the input data.
+
+         Returns
+         -------
+         tuple[torch.Tensor, Optional[torch.Tensor]]
+             Tuple of physical and routing parameters. The routing parameters
+             are None when routing is disabled.
+         """
+         phy_param_count = len(self.parameter_bounds)
+
+         # Physical parameters
+         phy_params = torch.sigmoid(
+             parameters[:, :, : phy_param_count * self.nmul]
+         ).view(
+             parameters.shape[0],
+             parameters.shape[1],
+             phy_param_count,
+             self.nmul,
+         )
+         # Merge the multi-components into the batch dimension for a parallel
+         # Jacobian computation.
+         phy_params = phy_params.permute([0, 3, 1, 2])
+         bsnew = n_grid * self.nmul
+         phy_params = phy_params.reshape(n_steps, bsnew, phy_param_count)
+
+         # Routing parameters (taken from the last time step).
+         routing_params = None
+         if self.routing:
+             routing_params = torch.sigmoid(
+                 parameters[-1, :, phy_param_count * self.nmul :],
+             )
+         return phy_params, routing_params
+
+     def make_phy_parameters(
+         self,
+         phy_params: torch.Tensor,
+         name_list: list,
+         dy_list: list,
+     ) -> torch.Tensor:
+         """Descale physical parameters.
+
+         Parameters
+         ----------
+         phy_params
+             Normalized physical parameters.
+         name_list
+             List of physical parameter names.
+         dy_list
+             List of dynamic parameter names.
+
+         Returns
+         -------
+         torch.Tensor
+             Tensor of physical parameters.
+         """
+         n_steps, n_grid, nfea = phy_params.size()
+         # Static parameters: repeat the last time step across all steps.
+         parstaFull = phy_params[-1, :, :].unsqueeze(0).repeat([n_steps, 1, 1])
+         # Precompute a probability mask for dynamic parameters.
+         if dy_list:
+             pmat = torch.ones([1, n_grid]) * self.dy_drop
+             parhbvFull = torch.clone(parstaFull)
+             for i, name in enumerate(name_list):
+                 if name in dy_list:
+                     staPar = parstaFull[:, :, i]
+                     dynPar = phy_params[:, :, i]
+                     # Randomly drop some dynamic parameters back to static.
+                     drmask = torch.bernoulli(pmat).detach_().to(self.device)
+                     comPar = dynPar * (1 - drmask) + staPar * drmask
+                     parhbvFull[:, :, i] = comPar
+             return parhbvFull
+         else:
+             return parstaFull
+
+     def descale_rout_parameters(
+         self,
+         rout_params: torch.Tensor,
+         name_list: list,
+     ) -> dict[str, torch.Tensor]:
+         """Descale routing parameters.
+
+         Parameters
+         ----------
+         rout_params
+             Normalized routing parameters.
+         name_list
+             List of routing parameter names.
+
+         Returns
+         -------
+         dict
+             Dictionary of descaled routing parameters.
+         """
+         parameter_dict = {}
+         for i, name in enumerate(name_list):
+             param = rout_params[:, i]
+             parameter_dict[name] = change_param_range(
+                 param=param,
+                 bounds=self.routing_parameter_bounds[name],
+             )
+         return parameter_dict
+
+     def forward(
+         self,
+         x_dict: dict[str, torch.Tensor],
+         parameters: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         """Forward pass for HBV Adj.
+
+         Parameters
+         ----------
+         x_dict
+             Dictionary of input forcing data.
+         parameters
+             Unprocessed, learned parameters from a neural network.
+
+         Returns
+         -------
+         dict
+             Dictionary of model outputs.
+         """
+         # Unpack input data.
+         x = x_dict['x_phy']
+
+         n_steps, bs, _ = x.size()
+         bsnew = bs * self.nmul
+         phy_params, routing_params = self.unpack_parameters(parameters, n_steps, bs)
+
+         nS = 5  # This version of HBV has 5 state variables.
+         # Without warm-up, state variables are initialized with zeros.
+         y_init = torch.zeros((bsnew, nS)).to(self.device)
+         nflux = 1  # Currently only streamflow is returned.
+         delta_t = torch.tensor(1.0).to(device=self.device)  # Daily model
+         if self.warm_up > 0:
+             phy_params_warmup = self.make_phy_parameters(
+                 phy_params[: self.warm_up, :, :], self.phy_param_names, []
+             )
+             x_warmup = x[: self.warm_up, :, :].unsqueeze(1).repeat([1, self.nmul, 1, 1])
+             x_warmup = x_warmup.view(x_warmup.shape[0], bsnew, x_warmup.shape[-1])
+             f_warm_up = HBV(x_warmup, self.parameter_bounds)
+             M_warm_up = MOL(
+                 f_warm_up,
+                 nS,
+                 nflux,
+                 self.warm_up,
+                 bsDefault=bsnew,
+                 mtd=0,
+                 dtDefault=delta_t,
+                 ad_efficient=self.ad_efficient,
+             )
+             y0 = M_warm_up.nsteps_pDyn(phy_params_warmup, y_init)[-1, :, :]
+         else:
+             y0 = y_init
+
+         phy_params_run = self.make_phy_parameters(
+             phy_params[self.warm_up :, :, :], self.phy_param_names, self.dynamic_params
+         )
+         rout_params_dict = self.descale_rout_parameters(
+             routing_params, self.routing_param_names
+         )
+
+         xTrain = x[self.warm_up :, :, :].unsqueeze(1).repeat([1, self.nmul, 1, 1])
+         xTrain = xTrain.view(xTrain.shape[0], bsnew, xTrain.shape[-1])
+
+         nt = phy_params_run.shape[0]
+
+         simulation = torch.zeros((nt, bsnew, nflux)).to(self.device)
+
+         f = HBV(xTrain, self.parameter_bounds)
+
+         M = MOL(
+             f,
+             nS,
+             nflux,
+             nt,
+             bsDefault=bsnew,
+             dtDefault=delta_t,
+             mtd=0,
+             ad_efficient=self.ad_efficient,
+         )
+
+         # Newton iterations with adjoint-based gradients.
+         ySolution = M.nsteps_pDyn(phy_params_run, y0)
+
+         for day in range(0, nt):
+             _, flux = f(
+                 ySolution[day, :, :], phy_params_run[day, :, :], day, expanded_num=[1]
+             )
+             simulation[day, :, :] = flux * delta_t
+
+         if self.nmul > 1:
+             simulation = simulation.view(nt, self.nmul, bs, nflux)
+             simulation = simulation.mean(dim=1)
+
+         routa = rout_params_dict['rout_a'].unsqueeze(0).repeat(nt, 1).unsqueeze(-1)
+         routb = rout_params_dict['rout_b'].unsqueeze(0).repeat(nt, 1).unsqueeze(-1)
+
+         UH = uh_gamma(routa, routb, lenF=15)  # lenF: unit hydrograph filter length
+         rf = simulation.permute([1, 2, 0])  # dims: [gage, var, time]
+         UH = UH.permute([1, 2, 0])  # dims: [gage, var, time]
+         Qsrout = uh_conv(rf, UH).permute([2, 0, 1])
+
+         # Return all simulation results.
+         return {
+             'flow_sim': Qsrout,
+         }
+
+
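For orientation, a minimal usage sketch of HbvAdj follows. It is illustrative only: the import path, config keys, and tensor shapes are assumptions inferred from the class above. Note that HBV.forward below reads theta[:, 12], so 'parBETAET' must be listed as a dynamic parameter for all 13 parameter bounds to exist.

import torch
from hydrodl2.models.hbv.hbv_adj import HbvAdj  # assumed module path

config = {
    'warm_up': 10,
    'dynamic_params': {'HbvAdj': ['parBETAET']},  # adds the 13th bound
}
model = HbvAdj(config=config, device=torch.device('cpu'))

n_steps, n_grid = 50, 4
x_phy = torch.rand(n_steps, n_grid, 3)  # forcings: (prcp, tmean, pet)
# Stand-in for raw NN output; requires_grad mimics a learnable source.
params = torch.randn(
    n_steps, n_grid, model.learnable_param_count, requires_grad=True
)

out = model({'x_phy': x_phy}, params)
print(out['flow_sim'].shape)  # expected: [n_steps - warm_up, n_grid, 1]
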
+ class HBV(torch.nn.Module):
+     """HBV model."""
+
+     def __init__(self, climate_data, parameter_bounds):
+         super().__init__()
+         self.climate_data = climate_data
+         self.parameter_bounds = parameter_bounds
+
+     def forward(self, y, theta, t, expanded_num, returnFlux=False, aux=None):
+         """Forward method."""
+         # Descale parameters from [0, 1] to their physical bounds. Note that
+         # theta[:, 12] requires 'parBETAET' to be present in parameter_bounds
+         # (13 parameters in total).
+         Beta = self.parameter_bounds['parBETA'][0] + theta[:, 0] * (
+             self.parameter_bounds['parBETA'][1] - self.parameter_bounds['parBETA'][0]
+         )
+         FC = self.parameter_bounds['parFC'][0] + theta[:, 1] * (
+             self.parameter_bounds['parFC'][1] - self.parameter_bounds['parFC'][0]
+         )
+         K0 = self.parameter_bounds['parK0'][0] + theta[:, 2] * (
+             self.parameter_bounds['parK0'][1] - self.parameter_bounds['parK0'][0]
+         )
+         K1 = self.parameter_bounds['parK1'][0] + theta[:, 3] * (
+             self.parameter_bounds['parK1'][1] - self.parameter_bounds['parK1'][0]
+         )
+         K2 = self.parameter_bounds['parK2'][0] + theta[:, 4] * (
+             self.parameter_bounds['parK2'][1] - self.parameter_bounds['parK2'][0]
+         )
+         LP = self.parameter_bounds['parLP'][0] + theta[:, 5] * (
+             self.parameter_bounds['parLP'][1] - self.parameter_bounds['parLP'][0]
+         )
+         PERC = self.parameter_bounds['parPERC'][0] + theta[:, 6] * (
+             self.parameter_bounds['parPERC'][1] - self.parameter_bounds['parPERC'][0]
+         )
+         UZL = self.parameter_bounds['parUZL'][0] + theta[:, 7] * (
+             self.parameter_bounds['parUZL'][1] - self.parameter_bounds['parUZL'][0]
+         )
+         TT = self.parameter_bounds['parTT'][0] + theta[:, 8] * (
+             self.parameter_bounds['parTT'][1] - self.parameter_bounds['parTT'][0]
+         )
+         CFMAX = self.parameter_bounds['parCFMAX'][0] + theta[:, 9] * (
+             self.parameter_bounds['parCFMAX'][1] - self.parameter_bounds['parCFMAX'][0]
+         )
+         CFR = self.parameter_bounds['parCFR'][0] + theta[:, 10] * (
+             self.parameter_bounds['parCFR'][1] - self.parameter_bounds['parCFR'][0]
+         )
+         CWH = self.parameter_bounds['parCWH'][0] + theta[:, 11] * (
+             self.parameter_bounds['parCWH'][1] - self.parameter_bounds['parCWH'][0]
+         )
+         BETAET = self.parameter_bounds['parBETAET'][0] + theta[:, 12] * (
+             self.parameter_bounds['parBETAET'][1]
+             - self.parameter_bounds['parBETAET'][0]
+         )
+
+         PRECS = 0
+         # Stores
+         SNOWPACK = torch.clamp(y[:, 0], min=PRECS)  # SNOWPACK
+         MELTWATER = torch.clamp(y[:, 1], min=PRECS)  # MELTWATER
+         SM = torch.clamp(y[:, 2], min=1e-8)  # SM
+         SUZ = torch.clamp(y[:, 3], min=PRECS)  # SUZ
+         SLZ = torch.clamp(y[:, 4], min=PRECS)  # SLZ
+         dS = torch.zeros(y.shape[0], y.shape[1]).to(y)
+         fluxes = torch.zeros((y.shape[0], 1)).to(y)
+
+         climate_in0 = self.climate_data[int(t), :, :]  # climate at this step
+
+         for idx, expanded_num_i in enumerate(expanded_num):
+             if idx == 0:
+                 climate_in = climate_in0.repeat_interleave(expanded_num_i, dim=0)
+             else:
+                 climate_in = torch.cat(
+                     [climate_in, climate_in0.repeat_interleave(expanded_num_i, dim=0)],
+                     dim=0,
+                 )
+
+         P = climate_in[:, 0]
+         Ep = climate_in[:, 2]
+         T = climate_in[:, 1]
+
+         # Flux functions
+         flux_sf = self.snowfall(P, T, TT)
+         flux_refr = self.refreeze(CFR, CFMAX, T, TT, MELTWATER)
+         flux_melt = self.melt(CFMAX, T, TT, SNOWPACK)
+         flux_rf = self.rainfall(P, T, TT)
+         flux_Isnow = self.Isnow(MELTWATER, CWH, SNOWPACK)
+         flux_PEFF = self.Peff(SM, FC, Beta, flux_rf, flux_Isnow)
+         flux_ex = self.excess(SM, FC)
+         flux_et = self.evap(SM, FC, LP, Ep, BETAET)
+         flux_perc = self.percolation(PERC, SUZ)
+         flux_q0 = self.interflow(K0, SUZ, UZL)
+         flux_q1 = self.baseflow(K1, SUZ)
+         flux_q2 = self.baseflow(K2, SLZ)
+
+         # Store ODEs
+         dS[:, 0] = flux_sf + flux_refr - flux_melt
+         dS[:, 1] = flux_melt - flux_refr - flux_Isnow
+         dS[:, 2] = flux_Isnow + flux_rf - flux_PEFF - flux_ex - flux_et
+         dS[:, 3] = flux_PEFF + flux_ex - flux_perc - flux_q0 - flux_q1
+         dS[:, 4] = flux_perc - flux_q2
+
+         fluxes[:, 0] = flux_q0 + flux_q1 + flux_q2
+
+         if returnFlux:
+             return (
+                 fluxes,
+                 flux_q0.unsqueeze(-1),
+                 flux_q1.unsqueeze(-1),
+                 flux_q2.unsqueeze(-1),
+                 flux_et.unsqueeze(-1),
+             )
+         else:
+             return dS, fluxes
+
+     def snowfall(self, P, T, TT):
+         """Snowfall."""
+         return torch.mul(P, (T < TT))
+
+     def refreeze(self, CFR, CFMAX, T, TT, MELTWATER):
+         """Refreezing."""
+         refreezing = CFR * CFMAX * (TT - T)
+         refreezing = torch.clamp(refreezing, min=0.0)
+         return torch.min(refreezing, MELTWATER)
+
+     def melt(self, CFMAX, T, TT, SNOWPACK):
+         """Snowmelt."""
+         melt = CFMAX * (T - TT)
+         melt = torch.clamp(melt, min=0.0)
+         return torch.min(melt, SNOWPACK)
+
+     def rainfall(self, P, T, TT):
+         """Rainfall."""
+         return torch.mul(P, (T >= TT))
+
+     def Isnow(self, MELTWATER, CWH, SNOWPACK):
+         """Snowmelt to soil water."""
+         tosoil = MELTWATER - (CWH * SNOWPACK)
+         tosoil = torch.clamp(tosoil, min=0.0)
+         return tosoil
+
+     def Peff(self, SM, FC, Beta, flux_rf, flux_Isnow):
+         """Effective precipitation."""
+         soil_wetness = (SM / FC) ** Beta
+         soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
+         return (flux_rf + flux_Isnow) * soil_wetness
+
+     def excess(self, SM, FC):
+         """Excess water."""
+         excess = SM - FC
+         return torch.clamp(excess, min=0.0)
+
+     def evap(self, SM, FC, LP, Ep, BETAET):
+         """Evapotranspiration."""
+         evapfactor = (SM / (LP * FC)) ** BETAET
+         evapfactor = torch.clamp(evapfactor, min=0.0, max=1.0)
+         ETact = Ep * evapfactor
+         return torch.min(SM, ETact)
+
+     def interflow(self, K0, SUZ, UZL):
+         """Interflow."""
+         return K0 * torch.clamp(SUZ - UZL, min=0.0)
+
+     def percolation(self, PERC, SUZ):
+         """Percolation."""
+         return torch.min(SUZ, PERC)
+
+     def baseflow(self, K, S):
+         """Baseflow function."""
+         return K * S
+
+
+ matrixSolve = torch.linalg.solve
+
+
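To make the control flow in NewtonSolve below easier to follow, here is a self-contained sketch of the kind of implicit step it drives to convergence: backward Euler poses G(x) = (x - x_t)/dt - f(x) = 0, and Newton iterations solve it. The right-hand side f and all shapes here are made-up stand-ins, not the package's HBV.

import torch

def f(x, k):
    """Toy right-hand side: linear decay, standing in for the HBV dS/dt."""
    return -k * x

def backward_euler_step(x_t, k, dt=1.0, iters=3):
    """Solve G(x) = (x - x_t)/dt - f(x) = 0 for x with Newton's method."""
    x = x_t.clone()
    for _ in range(iters):
        x = x.detach().requires_grad_(True)
        g = (x - x_t) / dt - f(x, k)  # residual G(x)
        # Elementwise dG/dx: each g[i] depends only on x[i] here, so the
        # Jacobian is diagonal and the grad of the sum recovers it.
        dgdx = torch.autograd.grad(g.sum(), x)[0]
        x = (x - g / dgdx).detach()  # Newton update
    return x

x_t = torch.ones(4)
k = torch.tensor([0.1, 0.2, 0.5, 1.0])
print(backward_euler_step(x_t, k))  # equals x_t / (1 + k*dt) analytically
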
+ class NewtonSolve(torch.autograd.Function):
+     """Newton solver for the adjoint model."""
+
+     @staticmethod
+     def forward(
+         ctx, p, p2, t, G, x0=None, auxG=None, batchP=True, eval=False, ad_efficient=True
+     ):
+         """Forward solver."""
+         useAD_jac = True
+         if x0 is None and p2 is not None:
+             x0 = p2
+
+         x = x0.clone().detach()
+         i = 0
+         max_iter = 3
+         gtol = 1e-3
+
+         if useAD_jac:
+             torch.set_grad_enabled(True)
+
+         x.requires_grad = True
+
+         # Here G is MOL.forward, whose signature includes expand_num;
+         # expand_num=[1] leaves the batch unexpanded.
+         if p2 is None:
+             gg = G(x, p, t, auxG)
+         else:
+             gg = G(x, p, p2, t, [1], auxG)
+         if ad_efficient:
+             dGdx = batchJacobian(gg, x, graphed=True)
+         # else:
+         #     dGdx = batchJacobian_AD_slow(gg, x, graphed=True)
+         if torch.isnan(dGdx).any() or torch.isinf(dGdx).any():
+             raise RuntimeError("Jacobian matrix contains NaN or Inf")
+         x = x.detach()
+
+         torch.set_grad_enabled(False)
+         # Norm of the residuals.
+         resnorm = torch.linalg.norm(gg, float('inf'), dim=[1])
+         resnorm0 = 100 * resnorm
+
+         while (torch.max(resnorm) > gtol) and i <= max_iter:
+             i += 1
+             # Refresh the Jacobian only when the residual is not shrinking fast.
+             if torch.max(resnorm / resnorm0) > 0.2:
+                 if useAD_jac:
+                     torch.set_grad_enabled(True)
+
+                 x.requires_grad = True
+
+                 if p2 is None:
+                     gg = G(x, p, t, auxG)
+                 else:
+                     gg = G(x, p, p2, t, [1], auxG)
+                 if ad_efficient:
+                     dGdx = batchJacobian(gg, x, graphed=True)
+                 # else:
+                 #     dGdx = batchJacobian_AD_slow(gg, x, graphed=True)
+                 if torch.isnan(dGdx).any() or torch.isinf(dGdx).any():
+                     raise RuntimeError("Jacobian matrix contains NaN or Inf")
+
+                 x = x.detach()
+
+                 torch.set_grad_enabled(False)
+
+             if dGdx.ndim == gg.ndim:  # same dimension, must be scalar
+                 dx = (gg / dGdx).detach()
+             else:
+                 dx = matrixSolve(dGdx, gg).detach()
+             x = x - dx
+             if useAD_jac:
+                 torch.set_grad_enabled(True)
+             x.requires_grad = True
+             if p2 is None:
+                 gg = G(x, p, t, auxG)
+             else:
+                 gg = G(x, p, p2, t, [1], auxG)
+             torch.set_grad_enabled(False)
+             resnorm0 = resnorm  # old resnorm
+             resnorm = torch.linalg.norm(gg, float('inf'), dim=[1])
+
+         torch.set_grad_enabled(True)
+         x = x.detach()
+         if not eval:
+             if batchP:
+                 # dGdp is needed only upon convergence.
+                 if p2 is None:
+                     if ad_efficient:
+                         dGdp = batchJacobian(gg, p, graphed=True)
+                         dGdp2 = None
+                     else:
+                         # dGdp = batchJacobian_AD_slow(gg, p, graphed=True)
+                         dGdp2 = None
+                 else:
+                     if ad_efficient:
+                         dGdp, dGdp2 = batchJacobian(gg, (p, p2), graphed=True)
+                     else:
+                         # One extra Newton step, then finite-difference
+                         # Jacobians with respect to the parameters.
+                         dx = matrixSolve(dGdx, gg)
+                         x = x - dx
+                         gg = G(x, p, p2, t, [1], auxG)
+                         resnorm0 = resnorm  # old resnorm
+                         resnorm = torch.linalg.norm(gg, float('inf'), dim=[1])
+
+                         dGdp, dGdp2 = finite_difference_jacobian_p(
+                             G, x, p, p2, t, 1e-6, auxG
+                         )
+             else:
+                 raise NotImplementedError(
+                     "non-batched p (e.g., NN) pathway not debugged through yet"
+                 )
+             ctx.save_for_backward(dGdp.float(), dGdp2.float(), dGdx.float())
+             # Saving these here avoids one extra forward run. They can also be
+             # moved to the CPU, or, if memory is a problem, save x and re-run
+             # G during the backward pass.
+         del gg
+         return x.float()
+
+     @staticmethod
+     def backward(ctx, dLdx):
+         """Backward pass via the discrete adjoint."""
+         with torch.no_grad():
+             dGdp, dGdp2, dGdx = ctx.saved_tensors
+             dGdxT = torch.permute(dGdx, (0, 2, 1))
+             # Solve the adjoint system (dG/dx)^T lambda = dL/dx.
+             lambTneg = matrixSolve(dGdxT, dLdx)
+             if lambTneg.ndim <= 2:
+                 lambTneg = torch.unsqueeze(lambTneg, 2)
+             dLdp = -torch.bmm(torch.permute(lambTneg, (0, 2, 1)), dGdp)
+             # Squeeze the singleton dim left by the batched vector-matrix product.
+             dLdp = torch.squeeze(dLdp, 1)
+             if dGdp2 is None:
+                 dLdp2 = None
+             else:
+                 dLdp2 = -torch.bmm(torch.permute(lambTneg, (0, 2, 1)), dGdp2)
+                 dLdp2 = torch.squeeze(dLdp2, 1)
+             return dLdp, dLdp2, None, None, None, None, None, None
+
+
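For reference, the backward pass above follows the standard discrete adjoint identities for an implicitly defined step $G(x, p, p_2) = 0$, in the same symbols as the code:

$$\left(\frac{\partial G}{\partial x}\right)^{\!\top} \lambda = \frac{\partial L}{\partial x}, \qquad \frac{dL}{dp} = -\lambda^{\top} \frac{\partial G}{\partial p}, \qquad \frac{dL}{dp_2} = -\lambda^{\top} \frac{\partial G}{\partial p_2},$$

so each time step costs one batched linear solve in the backward pass instead of backpropagating through the Newton iterations.
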
+ class MOL(torch.nn.Module):
+     """
+     Method-of-lines time integrator, posed as a nonlinear equation
+     G(x, p, xt, t, expand_num, auxG) = 0.
+
+     The right-hand-side function rhsFunc is bound at construction and
+     defines the equation's right-hand side.
+     """
+
+     def __init__(
+         self,
+         rhsFunc,
+         ny,
+         nflux,
+         rho,
+         bsDefault=1,
+         mtd=0,
+         dtDefault=0,
+         solveAdj=NewtonSolve.apply,
+         eval=False,
+         ad_efficient=True,
+     ):
+         super().__init__()
+         self.mtd = mtd  # time discretization method; 0 for backward Euler
+         self.rhs = rhsFunc
+         self.delta_t = dtDefault
+         self.bs = bsDefault
+         self.ny = ny
+         self.nflux = nflux
+         self.rho = rho
+         self.solveAdj = solveAdj
+         self.eval = eval
+         self.ad_efficient = ad_efficient
+
+     def forward(self, x, p, xt, t, expand_num, auxG):  # take one step
+         """Residual of one implicit time step."""
+         # xt is x^{t}; we are solving for x^{t+1}.
+         dt, aux = auxG  # expand auxiliary data
+
+         if self.mtd == 0:  # backward Euler
+             rhs, _ = self.rhs(
+                 x, p, t, expand_num, returnFlux=False, aux=aux
+             )  # should return [nb, ng]
+             gg = (x - xt) / dt - rhs
+         elif self.mtd == 1:  # Crank-Nicolson
+             rhs, _ = self.rhs(
+                 x, p, t, expand_num, returnFlux=False, aux=aux
+             )  # should return [nb, ng]
+             rhst, _ = self.rhs(
+                 xt, p, t, expand_num, returnFlux=False, aux=aux
+             )  # should return [nb, ng]
+             gg = (x - xt) / dt - (rhs + rhst) * 0.5
+         return gg
+
+     def nsteps_pDyn(self, pDyn, x0):
+         """Integrate rho steps with dynamic parameters via the adjoint solver."""
+         bs = self.bs
+         ny = self.ny
+         delta_t = self.delta_t
+         rho = self.rho
+         ySolution = torch.zeros((rho, bs, ny)).to(pDyn)
+         ySolution[0, :, :] = x0
+
+         xt = x0.clone().requires_grad_()
+
+         auxG = (delta_t, None)
+
+         for t in range(rho):
+             p = pDyn[t, :, :]
+
+             x = self.solveAdj(
+                 p, xt, t, self.forward, None, auxG, True, self.eval, self.ad_efficient
+             )
+
+             ySolution[t, :, :] = x
+             xt = x
+
+         return ySolution
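The MOL integrator can also be driven directly. A minimal sketch, assuming the classes above are in scope; the shapes are made up, and the bounds dictionary extends HbvAdj's defaults with parBETAET, which HBV.forward requires:

import torch

nt, bs, ny, nflux = 10, 8, 5, 1
forcings = torch.rand(nt, bs, 3)  # [time, batch, (prcp, tmean, pet)]
bounds = dict(
    HbvAdj(device=torch.device('cpu')).parameter_bounds, parBETAET=[0.3, 5]
)

f = HBV(forcings, bounds)
M = MOL(f, ny, nflux, nt, bsDefault=bs, mtd=0, dtDefault=torch.tensor(1.0))

theta = torch.rand(nt, bs, len(bounds), requires_grad=True)  # in [0, 1]
y0 = torch.zeros(bs, ny)
states = M.nsteps_pDyn(theta, y0)  # [nt, bs, ny]; gradients via NewtonSolve
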
@@ -0,0 +1,2 @@
+ # Augmentations to dMG differentiable models will sit in this directory, and
+ # will be interfaced with via the Model Handler.