jinns 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
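The main change in 0.5.0, visible throughout this diff, is support for separable PINNs (`SPINN`) alongside standard `PINN` networks: the differential operators in `jinns.loss._operators` are split into reverse-mode variants (`_laplacian_rev`, `_div_rev`, `_u_dot_nabla_times_u_rev`, built on `jax.grad`) and forward-mode variants (`_laplacian_fwd`, `_div_fwd`, `_u_dot_nabla_times_u_fwd`, built on `jax.jvp`), and each `evaluate` method now dispatches on `isinstance(u, PINN)` versus `isinstance(u, SPINN)`. A minimal sketch of the two differentiation styles in plain JAX (an illustration only, not the jinns API):

    import jax
    import jax.numpy as jnp

    def u(x):
        # hypothetical stand-in for a trained scalar network
        return jnp.sin(x)

    # Reverse mode (PINN-style operators): compose jax.grad with itself.
    d2u_rev = jax.grad(jax.grad(u))

    # Forward mode (SPINN-style operators): a jvp of a jvp with unit tangents.
    def d2u_fwd(x):
        du = lambda x: jax.jvp(u, (x,), (jnp.ones_like(x),))[1]
        return jax.jvp(du, (x,), (jnp.ones_like(x),))[1]

    x = jnp.array(0.3)
    print(d2u_rev(x), d2u_fwd(x))  # both evaluate to -sin(0.3)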
@@ -1,12 +1,18 @@
  import jax
- from jax import jit, grad, jacrev
+ from jax import jit, grad, jacrev, jacfwd
  import jax.numpy as jnp
+ from jinns.utils._utils import _get_grid
+ from jinns.utils._pinn import PINN
+ from jinns.utils._spinn import SPINN
  from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
  from jinns.loss._operators import (
-     _laplacian,
-     _div,
+     _laplacian_rev,
+     _laplacian_fwd,
+     _div_rev,
+     _div_fwd,
      _vectorial_laplacian,
-     _u_dot_nabla_times_u,
+     _u_dot_nabla_times_u_rev,
+     _u_dot_nabla_times_u_fwd,
  )


@@ -48,145 +54,6 @@ class FisherKPP(PDENonStatio):
          """
          Evaluate the dynamic loss at :math:`(t,x)`.

-         **Note:** In practice this `u` is vectorized and `t` and `x` have a
-         batch dimension.
-
-         Parameters
-         ---------
-         t
-             A time point
-         x
-             A point in :math:`\Omega`
-         u
-             The PINN
-         params
-             The dictionary of parameters of the model.
-             Typically, it is a dictionary of
-             dictionaries: `eq_params` and `nn_params`, respectively the
-             differential equation parameters and the neural network parameters
-         """
-         nn_params, eq_params = self.set_stop_gradient(params)
-
-         eq_params = self._eval_heterogeneous_parameters(
-             eq_params, t, x, self.eq_params_heterogeneity
-         )
-
-         du_dt = grad(u, 0)(t, x, nn_params)[0]
-
-         lap = _laplacian(u, nn_params, eq_params, x, t)
-
-         return du_dt + self.Tmax * (
-             -eq_params["D"] * lap
-             - u(t, x, nn_params, eq_params)
-             * (eq_params["r"] - eq_params["g"] * u(t, x, nn_params, eq_params))
-         )
-
-
- class Malthus(ODE):
-     r"""
-     Return a Malthus dynamic loss term following the PINN logic:
-
-     .. math::
-         \frac{\partial}{\partial t} u(t)=ru(t)
-
-     """
-
-     def __init__(self, Tmax=1, derivatives="nn_params", eq_params_heterogeneity=None):
-         """
-         Parameters
-         ----------
-         Tmax
-             Tmax needs to be given when the PINN time input is normalized in
-             [0, 1], i.e. we have performed renormalization of the differential
-             equation
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
-
-     def evaluate(self, t, u, params):
-         """
-         Evaluate the dynamic loss at `t`.
-         For stability we implement the dynamic loss in log space.
-
-         **Note:** In practice this `u` is vectorized and `t` has a
-         batch dimension.
-
-         Parameters
-         ---------
-         t
-             A time point
-         u
-             The PINN
-         params
-             The dictionary of parameters of the model.
-             Typically, it is a dictionary of
-             dictionaries: `eq_params` and `nn_params`, respectively the
-             differential equation parameters and the neural network parameters
-         """
-         nn_params, eq_params = self.set_stop_gradient(params)
-
-         eq_params = self._eval_heterogeneous_parameters(
-             eq_params, t, x, self.eq_params_heterogeneity
-         )
-
-         # NOTE the log formulation of the loss for stability
-         du_dt = grad(lambda t: jnp.log(u(t, nn_params, eq_params)), 0)(t)
-         return du_dt - eq_params["growth_rate"]
-
-
- class BurgerEquation(PDENonStatio):
-     r"""
-     Return the Burgers dynamic loss term (in 1 space dimension):
-
-     .. math::
-         \frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
-         u(t,x) - \theta \frac{\partial^2}{\partial x^2} u(t,x) = 0
-
-     """
-
-     def __init__(self, Tmax=1, derivatives="nn_params", eq_params_heterogeneity=None):
-         """
-         Parameters
-         ----------
-         Tmax
-             Tmax needs to be given when the PINN time input is normalized in
-             [0, 1], i.e. we have performed renormalization of the differential
-             equation
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
-
-     def evaluate(self, t, x, u, params):
-         """
-         Evaluate the dynamic loss at :math:`(t,x)`.
-
-         **Note:** In practice this `u` is vectorized and `t` and `x` have a
-         batch dimension.
-
          Parameters
          ---------
          t
@@ -201,548 +68,49 @@ class BurgerEquation(PDENonStatio):
              dictionaries: `eq_params` and `nn_params`, respectively the
              differential equation parameters and the neural network parameters
          """
-         nn_params, eq_params = self.set_stop_gradient(params)
-
-         eq_params = self._eval_heterogeneous_parameters(
-             eq_params, t, x, self.eq_params_heterogeneity
-         )
-
-         du_dt = grad(u, 0)
-         du_dx = grad(u, 1)
-         du2_dx2 = grad(
-             lambda t, x, nn_params, eq_params: du_dx(t, x, nn_params, eq_params)[0],
-             1,
-         )
-
-         return du_dt(t, x, nn_params, eq_params)[0] + self.Tmax * (
-             u(t, x, nn_params, eq_params) * du_dx(t, x, nn_params, eq_params)[0]
-             - eq_params["nu"] * du2_dx2(t, x, nn_params, eq_params)[0]
-         )
-
-
- class GeneralizedLotkaVolterra(ODE):
-     r"""
-     Return a dynamic loss from an equation of a Generalized Lotka-Volterra
-     system. Say we implement the equation for population :math:`i`:
-
-     .. math::
-         \frac{\partial}{\partial t}u_i(t) = r_iu_i(t) - \sum_{j\neq i}\alpha_{ij}u_j(t)
-         -\alpha_{i,i}u_i(t) + c_iu_i(t) + \sum_{j \neq i} c_ju_j(t)
-
-     with :math:`r_i` the growth rate parameter, :math:`c_i` the carrying
-     capacities and :math:`\alpha_{ij}` the interaction terms.
-
-     """
-
-     def __init__(
-         self,
-         key_main,
-         keys_other,
-         Tmax=1,
-         derivatives="nn_params",
-         eq_params_heterogeneity=None,
-     ):
-         """
-         Parameters
-         ----------
-         key_main
-             The dictionary key (in the dictionaries ``u`` and ``params`` that
-             are arguments of the ``evaluate`` function) of the main population
-             :math:`i` of the particular equation of the system implemented
-             by this dynamic loss
-         keys_other
-             The list of dictionary keys (in the dictionaries ``u`` and ``params`` that
-             are arguments of the ``evaluate`` function) of the other
-             populations that appear in the equation of the system implemented
-             by this dynamic loss
-         Tmax
-             Tmax needs to be given when the PINN time input is normalized in
-             [0, 1], i.e. we have performed renormalization of the differential
-             equation
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
-         self.key_main = key_main
-         self.keys_other = keys_other
-
-     def evaluate(self, t, u_dict, params_dict):
-         """
-         Evaluate the dynamic loss at `t`.
-         For stability we implement the dynamic loss in log space.
-
-         **Note:** In practice each `u` from `u_dict` is vectorized and `t` has a
-         batch dimension.
-
-         Parameters
-         ---------
-         t
-             A time point
-         u_dict
-             A dictionary of PINNs. Must have the same keys as `params_dict`
-         params_dict
-             The dictionary of dictionaries of parameters of the model.
-             Typically, each sub-dictionary is a dictionary
-             with keys: `eq_params` and `nn_params`, respectively the
-             differential equation parameters and the neural network parameters.
-             Must have the same keys as `u_dict`
-         """
-         nn_params, eq_params = self.set_stop_gradient(params_dict)
-
-         u_nn_params = nn_params[self.key_main]
-         u_eq_params = eq_params[self.key_main]
-
-         u = u_dict[self.key_main]
-         du_dt = grad(lambda t: jnp.log(u(t, u_nn_params, u_eq_params)), 0)(t)
-         carrying_term = u_eq_params["carrying_capacity"] * u(
-             t, u_nn_params, u_eq_params
-         )
-         for k in self.keys_other:
-             carrying_term += u_eq_params["carrying_capacity"] * u_dict[k](
-                 t, nn_params[k], eq_params[k]
-             )
-         # NOTE the following assumes interaction term with oneself is at idx 0
-         interaction_terms = u_eq_params["interactions"][0] * u(
-             t, u_nn_params, u_eq_params
-         )
-         for i, k in enumerate(self.keys_other):
-             interaction_terms += u_eq_params["interactions"][i + 1] * u_dict[k](
-                 t, nn_params[k], eq_params[k]
+         if isinstance(u, PINN):
+             nn_params, eq_params = self.set_stop_gradient(params)
+             eq_params = self._eval_heterogeneous_parameters(
+                 eq_params, t, x, self.eq_params_heterogeneity
              )

-         return du_dt + self.Tmax * (
-             -u_eq_params["growth_rate"] - interaction_terms + carrying_term
-         )
-
-
- class FPEStatioLoss1D(PDEStatio):
-     r"""
-     Return the dynamic loss for a stationary Fokker-Planck equation in one
-     dimension:
-
-     .. math::
-         -\frac{\partial}{\partial x}\left[\mu(x)u(x)\right] +
-         \frac{\partial^2}{\partial x^2}\left[D(x)u(x)\right]=0
-
-     where :math:`\mu(x)` is the drift term and :math:`D(x)` is the diffusion
-     term.
-
-     The drift and diffusion terms are not specified here, hence this class
-     is `abstract`.
-     Other classes inherit from FPEStatioLoss1D and define the drift and
-     diffusion terms, which then define several other dynamic losses
-     (Ornstein-Uhlenbeck, Cox-Ingersoll-Ross, ...)
-     """
-
-     def __init__(self, derivatives="nn_params", eq_params_heterogeneity=None):
-         """
-         Parameters
-         ----------
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(derivatives, eq_params_heterogeneity)
-
-     def evaluate(self, x, u, params):
-         """
-         Evaluate the dynamic loss at `x`.
-
-         **Note:** In practice this `u` is vectorized and `x` has a
-         batch dimension.
-
-         Parameters
-         ---------
-         x
-             A point in :math:`\Omega`
-         u
-             The PINN
-         params
-             The dictionary of parameters of the model.
-             Typically, it is a dictionary of
-             dictionaries: `eq_params` and `nn_params`, respectively the
-             differential equation parameters and the neural network parameters
-         """
-         nn_params, eq_params = self.set_stop_gradient(params)
-
-         # (drift * u)'
-         order_1 = grad(
-             lambda x: (self.drift(x, eq_params) * u(x, nn_params, eq_params))[0],
-             0,
-         )(x)
-
-         # (diffusion * u)''
-         order_2 = grad(
-             lambda x: grad(
-                 lambda x: (self.diffusion(x, eq_params) * u(x, nn_params, eq_params))[
-                     0
-                 ],
-                 0,
-             )(x)[0],
-             0,
-         )(x)
-
-         return -order_1 + order_2
-
-
- class OU_FPEStatioLoss1D(FPEStatioLoss1D):
-     r"""
-     Return the dynamic loss whose solution is the probability density
-     function of a stationary Ornstein-Uhlenbeck process in one dimension:
-
-     .. math::
-         -\frac{\partial}{\partial x}\left[(\alpha(\mu - x))u(x)\right] +
-         \frac{\partial^2}{\partial x^2}\left[\frac{\sigma^2}{2}u(x)\right]=0
-
-     """
-
-     def __init__(self, derivatives="nn_params", eq_params_heterogeneity=None):
-         """
-         Parameters
-         ----------
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(derivatives, eq_params_heterogeneity)
-
-     def drift(self, x, eq_params):
-         r"""
-         Return the drift term
-
-         Parameters
-         ----------
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return eq_params["alpha"] * (eq_params["mu"] - x)
-
-     def diffusion(self, x, eq_params):
-         r"""
-         Return the computation of the diffusion tensor term in 1D
-
-         Parameters
-         ----------
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return 0.5 * eq_params["sigma"] ** 2
-
-
- class CIR_FPEStatioLoss1D(FPEStatioLoss1D):
-     r"""
-     Return the dynamic loss whose solution is the probability density
-     function of a stationary Cox-Ingersoll-Ross process in one dimension:
-
-     .. math::
-         -\frac{\partial}{\partial x}\left[(\mu - \alpha x)u(x)\right] +
-         \frac{\partial^2}{\partial x^2}\left[\frac{\sigma^2}{2}xu(x)\right]=0
-
-     """
-
-     def __init__(self, derivatives="nn_params", eq_params_heterogeneity=None):
-         """
-         Parameters
-         ----------
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(derivatives, eq_params_heterogeneity)
-
-     def drift(self, x, eq_params):
-         r"""
-         Return the drift term
-
-         Parameters
-         ----------
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return eq_params["mu"] - eq_params["alpha"] * x
-
-     def diffusion(self, x, eq_params):
-         r"""
-         Return the computation of the diffusion tensor term in 1D
-
-         Parameters
-         ----------
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return 0.5 * (eq_params["sigma"] ** 2) * x
-
-
- class FPENonStatioLoss1D(PDENonStatio):
-     r"""
-     Return the dynamic loss for a non-stationary Fokker-Planck equation in one
-     dimension:
-
-     .. math::
-         -\frac{\partial}{\partial x}\left[\mu(t, x)u(t, x)\right] +
-         \frac{\partial^2}{\partial x^2}\left[D(t, x)u(t, x)\right] =
-         \frac{\partial}{\partial t}u(t,x)
-
-     where :math:`\mu(t, x)` is the drift term and :math:`D(t, x)` is the diffusion
-     term.
-
-     The drift and diffusion terms are not specified here, hence this class
-     is `abstract`.
-     Other classes inherit from FPENonStatioLoss1D and define the drift and
-     diffusion terms, which then define several other dynamic losses
-     (Ornstein-Uhlenbeck, Cox-Ingersoll-Ross, ...)
-     """
-
-     def __init__(self, Tmax=1, derivatives="nn_params", eq_params_heterogeneity=None):
-         """
-         Parameters
-         ----------
-         Tmax
-             Tmax needs to be given when the PINN time input is normalized in
-             [0, 1], i.e. we have performed renormalization of the differential
-             equation
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
-
-     def evaluate(self, t, x, u, params):
-         """
-         Evaluate the dynamic loss at :math:`(t,x)`.
-
-         **Note:** In practice this `u` is vectorized and `t` and `x` have a
-         batch dimension.
-
-         Parameters
-         ---------
-         t
-             A time point
-         x
-             A point in :math:`\Omega`
-         u
-             The PINN
-         params
-             The dictionary of parameters of the model.
-             Typically, it is a dictionary of
-             dictionaries: `eq_params` and `nn_params`, respectively the
-             differential equation parameters and the neural network parameters
-         """
-         nn_params, eq_params = self.set_stop_gradient(params)
-         # (drift * u)'
-
-         order_1 = grad(
-             lambda t, x: (self.drift(t, x, eq_params) * u(t, x, nn_params, eq_params))[
-                 0
-             ],
-             1,
-         )(t, x)
-
-         # (diffusion * u)''
-         order_2 = grad(
-             lambda t, x: grad(
-                 lambda t, x: (
-                     self.diffusion(t, x, eq_params) * u(t, x, nn_params, eq_params)
-                 )[0],
-                 1,
-             )(t, x)[0],
-             1,
-         )(t, x)
-
-         du_dt = grad(u, 0)(t, x, nn_params, eq_params)
-
-         return -du_dt + self.Tmax * (-order_1 + order_2)
-
-
- class OU_FPENonStatioLoss1D(FPENonStatioLoss1D):
-     r"""
-     Return the dynamic loss whose solution is the probability density
-     function of a non-stationary Ornstein-Uhlenbeck process in one dimension:
-
-     .. math::
-         -\frac{\partial}{\partial x}\left[(\alpha(\mu - x))u(t,x)\right] +
-         \frac{\partial^2}{\partial x^2}\left[\frac{\sigma^2}{2}u(t,x)\right] =
-         \frac{\partial}{\partial t}u(t,x)
-
-     """
-
-     def __init__(self, Tmax=1, derivatives="nn_params", eq_params_heterogeneity=None):
-         """
-         Parameters
-         ----------
-         Tmax
-             Tmax needs to be given when the PINN time input is normalized in
-             [0, 1], i.e. we have performed renormalization of the differential
-             equation
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         eq_params_heterogeneity
-             Default None. A dict with the keys being the same as in eq_params
-             and the value being `time`, `space`, `both` or None which corresponds to
-             the heterogeneity of a given parameter. A value can be missing, in
-             this case there is no heterogeneity (=None). If
-             eq_params_heterogeneity is None this means there is no
-             heterogeneity for any parameter.
-         """
-         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
-
-     def drift(self, t, x, eq_params):
-         r"""
-         Return the drift term
-
-         Parameters
-         ----------
-         t
-             A time point
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return eq_params["alpha"] * (eq_params["mu"] - x)
-
-     def diffusion(self, t, x, eq_params):
-         r"""
-         Return the computation of the diffusion tensor term in 1D
-
-         Parameters
-         ----------
-         t
-             A time point
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return 0.5 * eq_params["sigma"] ** 2
-
-
- class CIR_FPENonStatioLoss1D(FPENonStatioLoss1D):
-     r"""
-     Return the dynamic loss whose solution is the probability density
-     function of a stationary Cox-Ingersoll-Ross process in one dimension:
-
-     .. math::
-         -\frac{\partial}{\partial x}\left[(\mu - \alpha x)u(x)\right] +
-         \frac{\partial^2}{\partial x^2}\left[\frac{\sigma^2}{2}xu(x)\right] =
-         \frac{\partial}{\partial t}u(t,x)
-
-     """
-
-     def __init__(self, Tmax=1, derivatives="nn_params"):
-         """
-         Parameters
-         ----------
-         Tmax
-             Tmax needs to be given when the PINN time input is normalized in
-             [0, 1], i.e. we have performed renormalization of the differential
-             equation
-         derivatives
-             A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
-             with respect to which set of parameters gradients of the dynamic
-             loss are computed. Default "nn_params", this is what is typically
-             done in solving forward problems, when we only estimate the
-             equation solution with a PINN.
-         """
-         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
-
-     def drift(self, t, x, eq_params):
-         r"""
-         Return the drift term
-
-         Parameters
-         ----------
-         t
-             A time point
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return eq_params["mu"] - eq_params["alpha"] * x
-
-     def diffusion(self, t, x, eq_params):
-         r"""
-         Return the computation of the diffusion tensor term in 1D
-
-         Parameters
-         ----------
-         t
-             A time point
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return 0.5 * (eq_params["sigma"] ** 2) * x
+             # Note that the last dim of u is necessarily 1
+             u_ = lambda t, x: u(t, x, nn_params, eq_params)[0]

+             du_dt = grad(u_, 0)(t, x)

- class Sinus_FPENonStatioLoss1D(FPENonStatioLoss1D):
+             lap = _laplacian_rev(u, nn_params, eq_params, x, t)[..., None]
+
+             return du_dt + self.Tmax * (
+                 -eq_params["D"] * lap
+                 - u(t, x, nn_params, eq_params)
+                 * (eq_params["r"] - eq_params["g"] * u(t, x, nn_params, eq_params))
+             )
+         elif isinstance(u, SPINN):
+             nn_params, eq_params = self.set_stop_gradient(params)
+             x_grid = _get_grid(x)
+             eq_params = self._eval_heterogeneous_parameters(
+                 eq_params, t, x_grid, self.eq_params_heterogeneity
+             )
+
+             u_tx, du_dt = jax.jvp(
+                 lambda t: u(t, x, nn_params, eq_params),
+                 (t,),
+                 (jnp.ones_like(t),),
+             )
+             lap = _laplacian_fwd(u, nn_params, eq_params, x, t)[..., None]
+             return du_dt + self.Tmax * (
+                 -eq_params["D"] * lap
+                 - u_tx * (eq_params["r"][..., None] - eq_params["g"] * u_tx)
+             )
+
+
+ class Malthus(ODE):
      r"""
-     Return the dynamic loss whose solution is the probability density
-     function of a non-stationary Ornstein-Uhlenbeck process in one dimension:
+     Return a Malthus dynamic loss term following the PINN logic:

      .. math::
-         -\frac{\partial}{\partial x}\left[\sin(x)u(x)\right] +
-         \frac{\partial^2}{\partial x^2}\left[\frac{1}{2}u(x)\right] =
-         \frac{\partial}{\partial t}u(t,x)
+         \frac{\partial}{\partial t} u(t)=ru(t)

      """

@@ -770,62 +138,57 @@ class Sinus_FPENonStatioLoss1D(FPENonStatioLoss1D):
          """
          super().__init__(Tmax, derivatives, eq_params_heterogeneity)

-     def drift(self, t, x, eq_params):
-         r"""
-         Return the drift term
+     def evaluate(self, t, u, params):
+         """
+         Evaluate the dynamic loss at `t`.
+         For stability we implement the dynamic loss in log space.

          Parameters
-         ----------
+         ---------
          t
              A time point
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
+         u
+             The PINN
+         params
+             The dictionary of parameters of the model.
+             Typically, it is a dictionary of
+             dictionaries: `eq_params` and `nn_params`, respectively the
+             differential equation parameters and the neural network parameters
          """
-         return jnp.sin(x)
+         nn_params, eq_params = self.set_stop_gradient(params)

-     def diffusion(self, t, x, eq_params):
-         r"""
-         Return the computation of the diffusion tensor term in 1D
+         eq_params = self._eval_heterogeneous_parameters(
+             eq_params, t, x, self.eq_params_heterogeneity
+         )

-         Parameters
-         ----------
-         t
-             A time point
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return 0.5 * jnp.ones((1))
+         # NOTE the log formulation of the loss for stability
+         du_dt = grad(lambda t: jnp.log(u(t, nn_params, eq_params)), 0)(t)
+         return du_dt - eq_params["growth_rate"]


- class FPEStatioLoss2D(PDEStatio):
+ class BurgerEquation(PDENonStatio):
      r"""
-     Return the dynamic loss for a stationary Fokker-Planck equation in two
-     dimensions:
+     Return the Burgers dynamic loss term (in 1 space dimension):

      .. math::
-         -\sum_{i=1}^2\frac{\partial}{\partial \mathbf{x}}
-         \left[\mu(\mathbf{x})u(\mathbf{x})\right] +
-         \sum_{i=1}^2\sum_{j=1}^2\frac{\partial^2}{\partial x_i \partial x_j}
-         \left[D(\mathbf{x})u(\mathbf{x})\right]=0
-
-     where :math:`\mu(\mathbf{x})` is the drift term and :math:`D(\mathbf{x})` is the diffusion
-     term.
+         \frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
+         u(t,x) - \theta \frac{\partial^2}{\partial x^2} u(t,x) = 0

-     The drift and diffusion terms are not specified here, hence this class
-     is `abstract`.
-     Other classes inherit from FPEStatioLoss2D and define the drift and
-     diffusion terms, which then define several other dynamic losses
-     (Ornstein-Uhlenbeck, Cox-Ingersoll-Ross, ...)
      """

-     def __init__(self, derivatives="nn_params", eq_params_heterogeneity=None):
+     def __init__(
+         self,
+         Tmax=1,
+         derivatives="nn_params",
+         eq_params_heterogeneity=None,
+     ):
          """
          Parameters
          ----------
+         Tmax
+             Tmax needs to be given when the PINN time input is normalized in
+             [0, 1], i.e. we have performed renormalization of the differential
+             equation
          derivatives
              A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
              with respect to which set of parameters gradients of the dynamic
@@ -840,20 +203,16 @@ class FPEStatioLoss2D(PDEStatio):
              eq_params_heterogeneity is None this means there is no
              heterogeneity for any parameter.
          """
-         super().__init__(derivatives, eq_params_heterogeneity)
+         super().__init__(Tmax, derivatives, eq_params_heterogeneity)

-     def evaluate(self, x, u, params):
+     def evaluate(self, t, x, u, params):
          """
-         Evaluate the dynamic loss at :math:`\mathbf{x}`.
-
-         **Note:** For computational purposes we use compositions of calls to
-         `jax.grad` instead of a call to `jax.hessian`
-
-         **Note:** In practice this `u` is vectorized and :math:`\mathbf{x}` has a
-         batch dimension.
+         Evaluate the dynamic loss at :math:`(t,x)`.

          Parameters
          ---------
+         t
+             A time point
          x
              A point in :math:`\Omega`
          u
@@ -864,73 +223,88 @@ class FPEStatioLoss2D(PDEStatio):
              dictionaries: `eq_params` and `nn_params`, respectively the
              differential equation parameters and the neural network parameters
          """
-         nn_params, eq_params = self.set_stop_gradient(params)
+         if isinstance(u, PINN):
+             nn_params, eq_params = self.set_stop_gradient(params)
+             eq_params = self._eval_heterogeneous_parameters(
+                 eq_params, t, x, self.eq_params_heterogeneity
+             )

-         order_1 = (
-             grad(
-                 lambda x: self.drift(x, eq_params)[0] * u(x, nn_params, eq_params),
-                 0,
-             )(x)[0]
-             + grad(
-                 lambda x: self.drift(x, eq_params)[1] * u(x, nn_params, eq_params),
-                 0,
-             )(x)[1]
-         )
+             # Note that the last dim of u is necessarily 1
+             u_ = lambda t, x: u(t, x, nn_params, eq_params)[0]
+             du_dt = grad(u_, 0)
+             du_dx = grad(u_, 1)
+             d2u_dx2 = grad(
+                 lambda t, x: du_dx(t, x)[0],
+                 1,
+             )

-         order_2 = (
-             grad(
-                 lambda x: grad(
-                     lambda x: u(x, nn_params, eq_params)
-                     * self.diffusion(x, eq_params)[0, 0],
-                     0,
-                 )(x)[0],
-                 0,
-             )(x)[0]
-             + grad(
-                 lambda x: grad(
-                     lambda x: u(x, nn_params, eq_params)
-                     * self.diffusion(x, eq_params)[1, 0],
-                     0,
-                 )(x)[1],
-                 0,
-             )(x)[0]
-             + grad(
-                 lambda x: grad(
-                     lambda x: u(x, nn_params, eq_params)
-                     * self.diffusion(x, eq_params)[0, 1],
-                     0,
-                 )(x)[0],
-                 0,
-             )(x)[1]
-             + grad(
-                 lambda x: grad(
-                     lambda x: u(x, nn_params, eq_params)
-                     * self.diffusion(x, eq_params)[1, 1],
-                     0,
-                 )(x)[1],
-                 0,
-             )(x)[1]
-         )
-         return -order_1 + order_2
+             return du_dt(t, x) + self.Tmax * (
+                 u(t, x, nn_params, eq_params) * du_dx(t, x)
+                 - eq_params["nu"] * d2u_dx2(t, x)
+             )

+         elif isinstance(u, SPINN):
+             nn_params, eq_params = self.set_stop_gradient(params)
+             x_grid = _get_grid(x)
+             eq_params = self._eval_heterogeneous_parameters(
+                 eq_params, t, x_grid, self.eq_params_heterogeneity
+             )
+             # d=2 JVP calls are expected since we have time and x
+             # then with a batch of size B, we have Bd JVP calls
+             u_tx, du_dt = jax.jvp(
+                 lambda t: u(t, x, nn_params, eq_params),
+                 (t,),
+                 (jnp.ones_like(t),),
+             )
+             du_dx_fun = lambda x: jax.jvp(
+                 lambda x: u(t, x, nn_params, eq_params),
+                 (x,),
+                 (jnp.ones_like(x),),
+             )[1]
+             du_dx, d2u_dx2 = jax.jvp(du_dx_fun, (x,), (jnp.ones_like(x),))
+             # Note that ones_like(x) works because x is Bx1 !
+             return du_dt + self.Tmax * (u_tx * du_dx - eq_params["nu"] * d2u_dx2)

- class OU_FPEStatioLoss2D(FPEStatioLoss2D):
+
+ class GeneralizedLotkaVolterra(ODE):
      r"""
-     Return the dynamic loss for a stationary Fokker-Planck equation in two
-     dimensions:
+     Return a dynamic loss from an equation of a Generalized Lotka-Volterra
+     system. Say we implement the equation for population :math:`i`:

      .. math::
-         -\sum_{i=1}^2\frac{\partial}{\partial \mathbf{x}}
-         \left[(\alpha(\mu - \mathbf{x}))u(\mathbf{x})\right] +
-         \sum_{i=1}^2\sum_{j=1}^2\frac{\partial^2}{\partial x_i \partial x_j}
-         \left[\frac{\sigma^2}{2}u(\mathbf{x})\right]=0
+         \frac{\partial}{\partial t}u_i(t) = r_iu_i(t) - \sum_{j\neq i}\alpha_{ij}u_j(t)
+         -\alpha_{i,i}u_i(t) + c_iu_i(t) + \sum_{j \neq i} c_ju_j(t)
+
+     with :math:`r_i` the growth rate parameter, :math:`c_i` the carrying
+     capacities and :math:`\alpha_{ij}` the interaction terms.

      """

-     def __init__(self, derivatives="nn_params", eq_params_heterogeneity=None):
+     def __init__(
+         self,
+         key_main,
+         keys_other,
+         Tmax=1,
+         derivatives="nn_params",
+         eq_params_heterogeneity=None,
+     ):
          """
          Parameters
          ----------
+         key_main
+             The dictionary key (in the dictionaries ``u`` and ``params`` that
+             are arguments of the ``evaluate`` function) of the main population
+             :math:`i` of the particular equation of the system implemented
+             by this dynamic loss
+         keys_other
+             The list of dictionary keys (in the dictionaries ``u`` and ``params`` that
+             are arguments of the ``evaluate`` function) of the other
+             populations that appear in the equation of the system implemented
+             by this dynamic loss
+         Tmax
+             Tmax needs to be given when the PINN time input is normalized in
+             [0, 1], i.e. we have performed renormalization of the differential
+             equation
          derivatives
              A string. Either ``nn_params``, ``eq_params``, ``both``. Determines
              with respect to which set of parameters gradients of the dynamic
@@ -945,52 +319,53 @@ class OU_FPEStatioLoss2D(FPEStatioLoss2D):
              eq_params_heterogeneity is None this means there is no
              heterogeneity for any parameter.
          """
-         super().__init__(derivatives, eq_params_heterogeneity)
-
-     def drift(self, x, eq_params):
-         r"""
-         Return the drift term
+         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
+         self.key_main = key_main
+         self.keys_other = keys_other

-         Parameters
-         ----------
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
+     def evaluate(self, t, u_dict, params_dict):
          """
-         return eq_params["alpha"] * (eq_params["mu"] - x)
-
-     def sigma_mat(self, x, eq_params):
-         r"""
-         Return the square root of the diffusion tensor in the sense of the outer
-         product used to create the diffusion tensor
+         Evaluate the dynamic loss at `t`.
+         For stability we implement the dynamic loss in log space.

          Parameters
-         ----------
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
+         ---------
+         t
+             A time point
+         u_dict
+             A dictionary of PINNs. Must have the same keys as `params_dict`
+         params_dict
+             The dictionary of dictionaries of parameters of the model.
+             Typically, each sub-dictionary is a dictionary
+             with keys: `eq_params` and `nn_params`, respectively the
+             differential equation parameters and the neural network parameters.
+             Must have the same keys as `u_dict`
          """
-         return jnp.diag(eq_params["sigma"])
+         nn_params, eq_params = self.set_stop_gradient(params_dict)

-     def diffusion(self, x, eq_params):
-         r"""
-         Return the computation of the diffusion tensor term in 2D (or
-         higher)
+         u_nn_params = nn_params[self.key_main]
+         u_eq_params = eq_params[self.key_main]

-         Parameters
-         ----------
-         x
-             A point in :math:`\Omega`
-         eq_params
-             A dictionary containing the equation parameters
-         """
-         return 0.5 * (
-             jnp.matmul(
-                 self.sigma_mat(x, eq_params),
-                 jnp.transpose(self.sigma_mat(x, eq_params)),
+         u = u_dict[self.key_main]
+         du_dt = grad(lambda t: jnp.log(u(t, u_nn_params, u_eq_params)), 0)(t)
+         carrying_term = u_eq_params["carrying_capacity"] * u(
+             t, u_nn_params, u_eq_params
+         )
+         for k in self.keys_other:
+             carrying_term += u_eq_params["carrying_capacity"] * u_dict[k](
+                 t, nn_params[k], eq_params[k]
+             )
+         # NOTE the following assumes interaction term with oneself is at idx 0
+         interaction_terms = u_eq_params["interactions"][0] * u(
+             t, u_nn_params, u_eq_params
+         )
+         for i, k in enumerate(self.keys_other):
+             interaction_terms += u_eq_params["interactions"][i + 1] * u_dict[k](
+                 t, nn_params[k], eq_params[k]
              )
+
+         return du_dt + self.Tmax * (
+             -u_eq_params["growth_rate"] - interaction_terms + carrying_term
          )


@@ -1044,9 +419,6 @@ class FPENonStatioLoss2D(PDENonStatio):
          """
          Evaluate the dynamic loss at :math:`(t,\mathbf{x})`.

-         **Note:** In practice this `u` is vectorized and `t` and
-         :math:`\mathbf{x}` have a batch dimension.
-
          Parameters
          ---------
          t
@@ -1061,59 +433,129 @@ class FPENonStatioLoss2D(PDENonStatio):
              dictionaries: `eq_params` and `nn_params`, respectively the
              differential equation parameters and the neural network parameters
          """
-         nn_params, eq_params = self.set_stop_gradient(params)
+         if isinstance(u, PINN):
+             nn_params, eq_params = self.set_stop_gradient(params)
+             eq_params = self._eval_heterogeneous_parameters(
+                 eq_params, t, x, self.eq_params_heterogeneity
+             )

-         order_1 = (
-             grad(
-                 lambda t, x: self.drift(t, x, eq_params)[0]
-                 * u(t, x, nn_params, eq_params),
-                 1,
-             )(t, x)[0]
-             + grad(
-                 lambda t, x: self.drift(t, x, eq_params)[1]
-                 * u(t, x, nn_params, eq_params),
-                 1,
-             )(t, x)[1]
-         )
+             # Note that the last dim of u is necessarily 1
+             u_ = lambda t, x: u(t, x, nn_params, eq_params)[0]

-         order_2 = (
-             grad(
-                 lambda t, x: grad(
-                     lambda t, x: u(t, x, nn_params, eq_params)
-                     * self.diffusion(t, x, eq_params)[0, 0],
+             order_1 = (
+                 grad(
+                     lambda t, x: self.drift(t, x, eq_params)[0] * u_(t, x),
                      1,
-                 )(t, x)[0],
-                 1,
-             )(t, x)[0]
-             + grad(
-                 lambda t, x: grad(
-                     lambda t, x: u(t, x, nn_params, eq_params)
-                     * self.diffusion(t, x, eq_params)[1, 0],
+                 )(
+                     t, x
+                 )[0:1]
+                 + grad(
+                     lambda t, x: self.drift(t, x, eq_params)[1] * u_(t, x),
                      1,
-                 )(t, x)[1],
-                 1,
-             )(t, x)[0]
-             + grad(
-                 lambda t, x: grad(
-                     lambda t, x: u(t, x, nn_params, eq_params)
-                     * self.diffusion(t, x, eq_params)[0, 1],
+                 )(
+                     t, x
+                 )[1:2]
+             )
+
+             order_2 = (
+                 grad(
+                     lambda t, x: grad(
+                         lambda t, x: u_(t, x) * self.diffusion(t, x, eq_params)[0, 0],
+                         1,
+                     )(t, x)[0],
                      1,
-                 )(t, x)[0],
-                 1,
-             )(t, x)[1]
-             + grad(
-                 lambda t, x: grad(
-                     lambda t, x: u(t, x, nn_params, eq_params)
-                     * self.diffusion(t, x, eq_params)[1, 1],
+                 )(t, x)[0:1]
+                 + grad(
+                     lambda t, x: grad(
+                         lambda t, x: u_(t, x) * self.diffusion(t, x, eq_params)[1, 0],
+                         1,
+                     )(t, x)[1],
                      1,
-                 )(t, x)[1],
-                 1,
-             )(t, x)[1]
-         )
+                 )(t, x)[0:1]
+                 + grad(
+                     lambda t, x: grad(
+                         lambda t, x: u_(t, x) * self.diffusion(t, x, eq_params)[0, 1],
+                         1,
+                     )(t, x)[0],
+                     1,
+                 )(t, x)[1:2]
+                 + grad(
+                     lambda t, x: grad(
+                         lambda t, x: u_(t, x) * self.diffusion(t, x, eq_params)[1, 1],
+                         1,
+                     )(t, x)[1],
+                     1,
+                 )(t, x)[1:2]
+             )
+
+             du_dt = grad(u_, 0)(t, x)
+
+             return -du_dt + self.Tmax * (-order_1 + order_2)
+
+         elif isinstance(u, SPINN):
+             nn_params, eq_params = self.set_stop_gradient(params)
+             x_grid = _get_grid(x)
+             eq_params = self._eval_heterogeneous_parameters(
+                 eq_params, t, x_grid, self.eq_params_heterogeneity
+             )
+
+             _, du_dt = jax.jvp(
+                 lambda t: u(t, x, nn_params, eq_params),
+                 (t,),
+                 (jnp.ones_like(t),),
+             )
+
+             # in forward AD we do not have the results for all the input
+             # dimensions at once (as is the case with grad), so we write
+             # two jvp calls
+             tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
+             tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+             _, dau_dx1 = jax.jvp(
+                 lambda x: self.drift(t, _get_grid(x), eq_params)[None, ..., 0:1]
+                 * u(t, x, nn_params, eq_params)[..., 0:1],
+                 (x,),
+                 (tangent_vec_0,),
+             )
+             _, dau_dx2 = jax.jvp(
+                 lambda x: self.drift(t, _get_grid(x), eq_params)[None, ..., 1:2]
+                 * u(t, x, nn_params, eq_params)[..., 0:1],
+                 (x,),
+                 (tangent_vec_1,),
+             )

-         du_dt = grad(u, 0)(t, x, nn_params, eq_params)
+             dsu_dx1_fun = lambda x, i, j: jax.jvp(
+                 lambda x: self.diffusion(t, _get_grid(x), eq_params, i, j)[
+                     None, None, None, None
+                 ]
+                 * u(t, x, nn_params, eq_params)[..., 0:1],
+                 (x,),
+                 (tangent_vec_0,),
+             )[1]
+             dsu_dx2_fun = lambda x, i, j: jax.jvp(
+                 lambda x: self.diffusion(t, _get_grid(x), eq_params, i, j)[
+                     None, None, None, None
+                 ]
+                 * u(t, x, nn_params, eq_params)[..., 0:1],
+                 (x,),
+                 (tangent_vec_1,),
+             )[1]
+             _, d2su_dx12 = jax.jvp(
+                 lambda x: dsu_dx1_fun(x, 0, 0), (x,), (tangent_vec_0,)
+             )
+             _, d2su_dx1dx2 = jax.jvp(
+                 lambda x: dsu_dx1_fun(x, 0, 1), (x,), (tangent_vec_1,)
+             )
+             _, d2su_dx22 = jax.jvp(
+                 lambda x: dsu_dx2_fun(x, 1, 1), (x,), (tangent_vec_1,)
+             )
+             _, d2su_dx2dx1 = jax.jvp(
+                 lambda x: dsu_dx2_fun(x, 1, 0), (x,), (tangent_vec_0,)
+             )

-         return -du_dt + self.Tmax * (-order_1 + order_2)
+             return -du_dt + self.Tmax * (
+                 -(dau_dx1 + dau_dx2)
+                 + (d2su_dx12 + d2su_dx22 + d2su_dx1dx2 + d2su_dx2dx1)
+             )


  class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
@@ -1184,9 +626,10 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
          eq_params
              A dictionary containing the equation parameters
          """
+
          return jnp.diag(eq_params["sigma"])

-     def diffusion(self, t, x, eq_params):
+     def diffusion(self, t, x, eq_params, i=None, j=None):
          r"""
          Return the computation of the diffusion tensor term in 2D (or
          higher)
@@ -1200,12 +643,20 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
          eq_params
              A dictionary containing the equation parameters
          """
-         return 0.5 * (
-             jnp.matmul(
-                 self.sigma_mat(t, x, eq_params),
-                 jnp.transpose(self.sigma_mat(t, x, eq_params)),
+         if i is None or j is None:
+             return 0.5 * (
+                 jnp.matmul(
+                     self.sigma_mat(t, x, eq_params),
+                     jnp.transpose(self.sigma_mat(t, x, eq_params)),
+                 )
+             )
+         else:
+             return 0.5 * (
+                 jnp.matmul(
+                     self.sigma_mat(t, x, eq_params),
+                     jnp.transpose(self.sigma_mat(t, x, eq_params)),
+                 )[i, j]
              )
-         )


  class ConvectionDiffusionNonStatio(FPENonStatioLoss2D):
@@ -1326,9 +777,6 @@ class MassConservation2DStatio(PDEStatio):
          Evaluate the dynamic loss at `\mathbf{x}`.
          For stability we implement the dynamic loss in log space.

-         **Note:** In practice each `u` from `u_dict` is vectorized and
-         `\mathbf{x}` has a batch dimension.
-
          Parameters
          ---------
          x
@@ -1342,14 +790,25 @@ class MassConservation2DStatio(PDEStatio):
              differential equation parameters and the neural network parameters.
              Must have the same keys as `u_dict`
          """
-         nn_params, eq_params = self.set_stop_gradient(params_dict)
+         if isinstance(u_dict[self.nn_key], PINN):
+             nn_params, eq_params = self.set_stop_gradient(params_dict)
+
+             nn_params = nn_params[self.nn_key]
+             eq_params = eq_params

-         nn_params = nn_params[self.nn_key]
-         eq_params = eq_params
+             u = u_dict[self.nn_key]

-         u = u_dict[self.nn_key]
+             return _div_rev(u, nn_params, eq_params, x)[..., None]

-         return _div(u, nn_params, eq_params, x)
+         elif isinstance(u_dict[self.nn_key], SPINN):
+             nn_params, eq_params = self.set_stop_gradient(params_dict)
+
+             nn_params = nn_params[self.nn_key]
+             eq_params = eq_params
+
+             u = u_dict[self.nn_key]
+
+             return _div_fwd(u, nn_params, eq_params, x)[..., None]


  class NavierStokes2DStatio(PDEStatio):
@@ -1422,9 +881,6 @@ class NavierStokes2DStatio(PDEStatio):
          Evaluate the dynamic loss at `\mathbf{x}`.
          For stability we implement the dynamic loss in log space.

-         **Note:** In practice each `u` from `u_dict` is vectorized and
-         `\mathbf{x}` has a batch dimension.
-
          Parameters
          ---------
          x
@@ -1438,35 +894,77 @@ class NavierStokes2DStatio(PDEStatio):
              dictionaries: `eq_params` and `nn_params`, respectively the
              Must have the same keys as `u_dict`
          """
-         nn_params, eq_params = self.set_stop_gradient(params_dict)
+         if isinstance(u_dict[self.u_key], PINN):
+             nn_params, eq_params = self.set_stop_gradient(params_dict)

-         u_nn_params = nn_params[self.u_key]
-         p_nn_params = nn_params[self.p_key]
-         eq_params = eq_params
+             u_nn_params = nn_params[self.u_key]
+             p_nn_params = nn_params[self.p_key]
+             eq_params = eq_params

-         u = u_dict[self.u_key]
+             u = u_dict[self.u_key]

-         u_dot_nabla_x_u = _u_dot_nabla_times_u(u, u_nn_params, eq_params, x)
+             u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(u, u_nn_params, eq_params, x)

-         p = lambda x: u_dict[self.p_key](x, p_nn_params, eq_params)
-         jac_p = jacrev(p, 0)(x)  # compute the gradient
+             p = lambda x: u_dict[self.p_key](x, p_nn_params, eq_params)
+             jac_p = jacrev(p, 0)(x)  # compute the gradient

-         vec_laplacian_u = _vectorial_laplacian(
-             u, u_nn_params, eq_params, x, u_vec_ndim=2
-         )
+             vec_laplacian_u = _vectorial_laplacian(
+                 u, u_nn_params, eq_params, x, u_vec_ndim=2
+             )

-         # dynamic loss on x axis
-         result_x = (
-             u_dot_nabla_x_u[0]
-             + 1 / eq_params["rho"] * jac_p[0]
-             - eq_params["nu"] * vec_laplacian_u[0]
-         )
-         # dynamic loss on y axis
-         result_y = (
-             u_dot_nabla_x_u[1]
-             + 1 / eq_params["rho"] * jac_p[1]
-             - eq_params["nu"] * vec_laplacian_u[1]
-         )
+             # dynamic loss on x axis
+             result_x = (
+                 u_dot_nabla_x_u[0]
+                 + 1 / eq_params["rho"] * jac_p[0, 0]
+                 - eq_params["nu"] * vec_laplacian_u[0]
+             )
+
+             # dynamic loss on y axis
+             result_y = (
+                 u_dot_nabla_x_u[1]
+                 + 1 / eq_params["rho"] * jac_p[0, 1]
+                 - eq_params["nu"] * vec_laplacian_u[1]
+             )
+
+             # output is 2D
+             return jnp.stack([result_x, result_y], axis=-1)
+
+         elif isinstance(u_dict[self.u_key], SPINN):
+             nn_params, eq_params = self.set_stop_gradient(params_dict)
+
+             u_nn_params = nn_params[self.u_key]
+             p_nn_params = nn_params[self.p_key]
+             eq_params = eq_params
+
+             u = u_dict[self.u_key]
+
+             u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(u, u_nn_params, eq_params, x)
+
+             p = lambda x: u_dict[self.p_key](x, p_nn_params, eq_params)
+
+             tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
+             _, dp_dx = jax.jvp(p, (x,), (tangent_vec_0,))
+             tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+             _, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))
+
+             vec_laplacian_u = jnp.moveaxis(
+                 _vectorial_laplacian(u, u_nn_params, eq_params, x, u_vec_ndim=2),
+                 source=0,
+                 destination=-1,
+             )
+
+             # dynamic loss on x axis
+             result_x = (
+                 u_dot_nabla_x_u[..., 0]
+                 + 1 / eq_params["rho"] * dp_dx.squeeze()
+                 - eq_params["nu"] * vec_laplacian_u[..., 0]
+             )
+             # dynamic loss on y axis
+             result_y = (
+                 u_dot_nabla_x_u[..., 1]
+                 + 1 / eq_params["rho"] * dp_dy.squeeze()
+                 - eq_params["nu"] * vec_laplacian_u[..., 1]
+             )

-         # output is 2D
-         return jnp.stack([result_x, result_y], axis=-1)
+             # output is 2D
+             return jnp.stack([result_x, result_y], axis=-1)
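
Editor's note on the forward-mode pattern used in the SPINN branches above: a single `jax.jvp` call returns one directional derivative, so the new code builds one tangent batch per input dimension (`tangent_vec_0`, `tangent_vec_1`) and calls `jax.jvp` once per direction, whereas `jax.grad` yields all partials of a scalar at once. A self-contained sketch of the pattern in plain JAX, with a hypothetical stand-in network `f` (not the jinns API):

    import jax
    import jax.numpy as jnp

    def f(x):
        # hypothetical stand-in network mapping a (B, 2) batch to (B, 1)
        return jnp.sum(jnp.sin(x), axis=-1, keepdims=True)

    x = jnp.ones((4, 2))  # batch of 4 points in R^2

    # One jvp per input dimension: each tangent batch selects one partial.
    e1 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
    e2 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
    _, df_dx1 = jax.jvp(f, (x,), (e1,))  # shape (4, 1): df/dx1 over the batch
    _, df_dx2 = jax.jvp(f, (x,), (e2,))  # shape (4, 1): df/dx2 over the batch
    # df_dx1[i] equals cos(x[i, 0]); df_dx2[i] equals cos(x[i, 1])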