jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLoss.py
CHANGED
|
@@ -6,7 +6,7 @@ from __future__ import (
|
|
|
6
6
|
annotations,
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
|
-
from typing import TYPE_CHECKING
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
10
|
from jaxtyping import Float
|
|
11
11
|
import jax
|
|
12
12
|
from jax import grad
|
|
@@ -29,10 +29,11 @@ from jinns.loss._operators import (
|
|
|
29
29
|
_u_dot_nabla_times_u_fwd,
|
|
30
30
|
)
|
|
31
31
|
|
|
32
|
-
from jaxtyping import Array
|
|
32
|
+
from jaxtyping import Array
|
|
33
33
|
|
|
34
34
|
if TYPE_CHECKING:
|
|
35
|
-
from jinns.parameters import Params
|
|
35
|
+
from jinns.parameters import Params
|
|
36
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class FisherKPP(PDENonStatio):
|
|
@@ -54,10 +55,10 @@ class FisherKPP(PDENonStatio):
|
|
|
54
55
|
|
|
55
56
|
def equation(
|
|
56
57
|
self,
|
|
57
|
-
t_x: Float[Array, "1+dim"],
|
|
58
|
-
u:
|
|
59
|
-
params: Params,
|
|
60
|
-
) -> Float[Array, "1"]:
|
|
58
|
+
t_x: Float[Array, " 1+dim"],
|
|
59
|
+
u: AbstractPINN,
|
|
60
|
+
params: Params[Array],
|
|
61
|
+
) -> Float[Array, " 1"]:
|
|
61
62
|
r"""
|
|
62
63
|
Evaluate the dynamic loss at $(t, x)$.
|
|
63
64
|
|
|
@@ -74,13 +75,14 @@ class FisherKPP(PDENonStatio):
|
|
|
74
75
|
dictionaries: `eq_params` and `nn_params`, respectively the
|
|
75
76
|
differential equation parameters and the neural network parameter
|
|
76
77
|
"""
|
|
78
|
+
assert u.eq_type != "ODE", "Cannot compute the loss for ODE PINNs"
|
|
77
79
|
if isinstance(u, PINN):
|
|
78
80
|
# Note that the last dim of u is nec. 1
|
|
79
81
|
u_ = lambda t_x: u(t_x, params)[0]
|
|
80
82
|
|
|
81
83
|
du_dt = grad(u_)(t_x)[0]
|
|
82
84
|
|
|
83
|
-
lap = laplacian_rev(t_x, u, params
|
|
85
|
+
lap = laplacian_rev(t_x, u, params)[..., None]
|
|
84
86
|
|
|
85
87
|
return du_dt + self.Tmax * (
|
|
86
88
|
-params.eq_params["D"] * lap
|
|
@@ -96,7 +98,7 @@ class FisherKPP(PDENonStatio):
|
|
|
96
98
|
(t_x,),
|
|
97
99
|
(v0,),
|
|
98
100
|
)
|
|
99
|
-
lap = laplacian_fwd(t_x, u, params
|
|
101
|
+
lap = laplacian_fwd(t_x, u, params)
|
|
100
102
|
|
|
101
103
|
return du_dt + self.Tmax * (
|
|
102
104
|
-params.eq_params["D"] * lap
|
|
@@ -108,14 +110,14 @@ class FisherKPP(PDENonStatio):
|
|
|
108
110
|
class GeneralizedLotkaVolterra(ODE):
|
|
109
111
|
r"""
|
|
110
112
|
Return a dynamic loss from an equation of a Generalized Lotka Volterra
|
|
111
|
-
system. Say we implement the
|
|
113
|
+
system. Say we implement the system of equations, for several populations $i$
|
|
112
114
|
|
|
113
115
|
$$
|
|
114
|
-
|
|
115
|
-
|
|
116
|
+
\frac{\partial}{\partial t}u_i(t) = \alpha_iu_i(t)
|
|
117
|
+
-\sum_j\gamma_{j,i}u_j(t)u_i(t) - \beta_iu_i(t)\sum_{i'}u_{i'}(t), i\in\{1, 2, 3\}
|
|
116
118
|
$$
|
|
117
|
-
|
|
118
|
-
|
|
119
|
+
|
|
120
|
+
where $\alpha$ are the growth rates, $\gamma$ are the interaction terms and $\beta$ are the capacity terms.
|
|
119
121
|
|
|
120
122
|
Parameters
|
|
121
123
|
----------
|
|
@@ -142,16 +144,12 @@ class GeneralizedLotkaVolterra(ODE):
|
|
|
142
144
|
heterogeneity for no parameters.
|
|
143
145
|
"""
|
|
144
146
|
|
|
145
|
-
# they should be static because they are list of strings
|
|
146
|
-
key_main: list[str] = eqx.field(static=True)
|
|
147
|
-
keys_other: list[str] = eqx.field(static=True)
|
|
148
|
-
|
|
149
147
|
def equation(
|
|
150
148
|
self,
|
|
151
|
-
t: Float[Array, "1"],
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
) -> Float[Array, "1"]:
|
|
149
|
+
t: Float[Array, " 1"],
|
|
150
|
+
u: AbstractPINN,
|
|
151
|
+
params: Params[Array],
|
|
152
|
+
) -> Float[Array, " 1"]:
|
|
155
153
|
"""
|
|
156
154
|
Evaluate the dynamic loss at `t`.
|
|
157
155
|
For stability we implement the dynamic loss in log space.
|
|
@@ -160,33 +158,23 @@ class GeneralizedLotkaVolterra(ODE):
|
|
|
160
158
|
---------
|
|
161
159
|
t
|
|
162
160
|
A time point
|
|
163
|
-
|
|
164
|
-
A
|
|
165
|
-
|
|
166
|
-
The
|
|
167
|
-
top level are "nn_params" and "eq_params"
|
|
161
|
+
u
|
|
162
|
+
A vectorial PINN with as many outputs as there are populations
|
|
163
|
+
params
|
|
164
|
+
The parameters in a Params object
|
|
168
165
|
"""
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
carrying_term += params_main.eq_params["carrying_capacity"] * u_dict[k](
|
|
182
|
-
t, params_k
|
|
183
|
-
)
|
|
184
|
-
interaction_terms += params_main.eq_params["interactions"][i + 1] * u_dict[
|
|
185
|
-
k
|
|
186
|
-
](t, params_k)
|
|
187
|
-
|
|
188
|
-
return du_dt + self.Tmax * (
|
|
189
|
-
-params_main.eq_params["growth_rate"] - interaction_terms + carrying_term
|
|
166
|
+
du_dt = jax.jacrev(lambda t: jnp.log(u(t, params)))(t)
|
|
167
|
+
carrying_term = params.eq_params["carrying_capacities"] * jnp.sum(u(t, params))
|
|
168
|
+
interactions_terms = jax.tree.map(
|
|
169
|
+
lambda interactions_for_i: jnp.sum(
|
|
170
|
+
interactions_for_i * u(t, params).squeeze()
|
|
171
|
+
),
|
|
172
|
+
params.eq_params["interactions"],
|
|
173
|
+
is_leaf=eqx.is_array,
|
|
174
|
+
)
|
|
175
|
+
interactions_terms = jnp.array([*(interactions_terms)])
|
|
176
|
+
return du_dt.squeeze() + self.Tmax * (
|
|
177
|
+
-params.eq_params["growth_rates"] + interactions_terms + carrying_term
|
|
190
178
|
)
|
|
191
179
|
|
|
192
180
|
|
|
@@ -216,10 +204,10 @@ class BurgersEquation(PDENonStatio):
|
|
|
216
204
|
|
|
217
205
|
def equation(
|
|
218
206
|
self,
|
|
219
|
-
t_x: Float[Array, "1+dim"],
|
|
220
|
-
u:
|
|
221
|
-
params: Params,
|
|
222
|
-
) -> Float[Array, "1"]:
|
|
207
|
+
t_x: Float[Array, " 1+dim"],
|
|
208
|
+
u: AbstractPINN,
|
|
209
|
+
params: Params[Array],
|
|
210
|
+
) -> Float[Array, " 1"]:
|
|
223
211
|
r"""
|
|
224
212
|
Evaluate the dynamic loss at :math:`(t,x)`.
|
|
225
213
|
|
|
@@ -312,10 +300,10 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
312
300
|
|
|
313
301
|
def equation(
|
|
314
302
|
self,
|
|
315
|
-
t_x: Float[Array, "1+dim"],
|
|
316
|
-
u:
|
|
317
|
-
params: Params,
|
|
318
|
-
) -> Float[Array, "1"]:
|
|
303
|
+
t_x: Float[Array, " 1+dim"],
|
|
304
|
+
u: AbstractPINN,
|
|
305
|
+
params: Params[Array],
|
|
306
|
+
) -> Float[Array, " 1"]:
|
|
319
307
|
r"""
|
|
320
308
|
Evaluate the dynamic loss at $(t,\mathbf{x})$.
|
|
321
309
|
|
|
@@ -506,94 +494,36 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
506
494
|
jnp.transpose(self.sigma_mat(x, eq_params)),
|
|
507
495
|
)
|
|
508
496
|
)
|
|
509
|
-
return
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
jnp.
|
|
513
|
-
|
|
497
|
+
return (
|
|
498
|
+
0.5
|
|
499
|
+
* (
|
|
500
|
+
jnp.matmul(
|
|
501
|
+
self.sigma_mat(x, eq_params),
|
|
502
|
+
jnp.transpose(self.sigma_mat(x, eq_params)),
|
|
503
|
+
)[i, j]
|
|
504
|
+
)
|
|
514
505
|
)
|
|
515
506
|
|
|
516
507
|
|
|
517
|
-
class
|
|
508
|
+
class NavierStokesMassConservation2DStatio(PDEStatio):
|
|
518
509
|
r"""
|
|
519
|
-
|
|
510
|
+
Dynamic loss for the system of equations (stationary Navier Stokes in 2D,
|
|
511
|
+
mass conservation). For a 2D velocity field $\mathbf{u}$ and a 1D pressure
|
|
512
|
+
field $p$, this reads:
|
|
520
513
|
|
|
521
514
|
$$
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
Parameters
|
|
529
|
-
----------
|
|
530
|
-
nn_key
|
|
531
|
-
A dictionary key which identifies, in `u_dict` the PINN that
|
|
532
|
-
appears in the mass conservation equation.
|
|
533
|
-
eq_params_heterogeneity
|
|
534
|
-
Default None. A dict with the keys being the same as in eq_params
|
|
535
|
-
and the value being `time`, `space`, `both` or None which corresponds to
|
|
536
|
-
the heterogeneity of a given parameter. A value can be missing, in
|
|
537
|
-
this case there is no heterogeneity (=None). If
|
|
538
|
-
eq_params_heterogeneity is None this means there is no
|
|
539
|
-
heterogeneity for no parameters.
|
|
540
|
-
"""
|
|
541
|
-
|
|
542
|
-
# an str field should be static (not a valid JAX type)
|
|
543
|
-
nn_key: str = eqx.field(static=True)
|
|
544
|
-
|
|
545
|
-
def equation(
|
|
546
|
-
self,
|
|
547
|
-
x: Float[Array, "dim"],
|
|
548
|
-
u_dict: Dict[str, eqx.Module],
|
|
549
|
-
params_dict: ParamsDict,
|
|
550
|
-
) -> Float[Array, "1"]:
|
|
551
|
-
r"""
|
|
552
|
-
Evaluate the dynamic loss at `\mathbf{x}`.
|
|
553
|
-
For stability we implement the dynamic loss in log space.
|
|
554
|
-
|
|
555
|
-
Parameters
|
|
556
|
-
---------
|
|
557
|
-
x
|
|
558
|
-
A point in $\Omega\subset\mathbb{R}^2$
|
|
559
|
-
u_dict
|
|
560
|
-
A dictionary of PINNs. Must have the same keys as `params_dict`
|
|
561
|
-
params_dict
|
|
562
|
-
The dictionary of dictionaries of parameters of the model.
|
|
563
|
-
Typically, each sub-dictionary is a dictionary
|
|
564
|
-
with keys: `eq_params` and `nn_params`, respectively the
|
|
565
|
-
differential equation parameters and the neural network parameter.
|
|
566
|
-
Must have the same keys as `u_dict`
|
|
567
|
-
"""
|
|
568
|
-
params = params_dict.extract_params(self.nn_key)
|
|
569
|
-
|
|
570
|
-
if isinstance(u_dict[self.nn_key], PINN):
|
|
571
|
-
u = u_dict[self.nn_key]
|
|
572
|
-
|
|
573
|
-
return divergence_rev(x, u, params)[..., None]
|
|
574
|
-
|
|
575
|
-
if isinstance(u_dict[self.nn_key], SPINN):
|
|
576
|
-
u = u_dict[self.nn_key]
|
|
577
|
-
|
|
578
|
-
return divergence_fwd(x, u, params)[..., None]
|
|
579
|
-
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
class NavierStokes2DStatio(PDEStatio):
|
|
583
|
-
r"""
|
|
584
|
-
Return the dynamic loss for all the components of the stationary Navier Stokes
|
|
585
|
-
equation which is a 2D vectorial PDE.
|
|
586
|
-
|
|
587
|
-
$$
|
|
588
|
-
(\mathbf{u}\cdot\nabla)\mathbf{u} + \frac{1}{\rho}\nabla p - \theta
|
|
589
|
-
\nabla^2\mathbf{u}=0,
|
|
515
|
+
\begin{cases}
|
|
516
|
+
&(\mathbf{u}\cdot\nabla)\mathbf{u} + \frac{1}{\rho}\nabla p - \theta
|
|
517
|
+
\nabla^2\mathbf{u}=0,\\
|
|
518
|
+
&\nabla\cdot\mathbf{u}=0,
|
|
519
|
+
\end{cases}
|
|
590
520
|
$$
|
|
591
521
|
|
|
592
522
|
or, in 2D,
|
|
593
523
|
|
|
594
|
-
|
|
595
524
|
$$
|
|
596
|
-
|
|
525
|
+
\begin{cases}
|
|
526
|
+
&\begin{pmatrix}u_x\frac{\partial}{\partial x} u_x +
|
|
597
527
|
u_y\frac{\partial}{\partial y} u_x, \\
|
|
598
528
|
u_x\frac{\partial}{\partial x} u_y + u_y\frac{\partial}{\partial y} u_y \end{pmatrix} +
|
|
599
529
|
\frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p, \\ \frac{\partial}{\partial y} p \end{pmatrix}
|
|
@@ -602,24 +532,15 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
602
532
|
\frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2}
|
|
603
533
|
u_x, \\
|
|
604
534
|
\frac{\partial^2}{\partial x^2} u_y + \frac{\partial^2}{\partial y^2} u_y
|
|
605
|
-
\end{pmatrix} = 0
|
|
535
|
+
\end{pmatrix} = 0,\\
|
|
536
|
+
&\frac{\partial}{\partial x}u_x +
|
|
537
|
+
\frac{\partial}{\partial y}u_y = 0,
|
|
538
|
+
\end{cases}
|
|
606
539
|
$$
|
|
607
540
|
with $\theta$ the viscosity coefficient and $\rho$ the density coefficient.
|
|
608
541
|
|
|
609
542
|
Parameters
|
|
610
543
|
----------
|
|
611
|
-
u_key
|
|
612
|
-
A dictionary key which indices the NN u in `u_dict`
|
|
613
|
-
the PINN with the role of the velocity in the equation.
|
|
614
|
-
Its input is bimensional (points in $\Omega\subset\mathbb{R}^2$).
|
|
615
|
-
Its output is bimensional as it represents a velocity vector
|
|
616
|
-
field
|
|
617
|
-
p_key
|
|
618
|
-
A dictionary key which indices the NN p in `u_dict`
|
|
619
|
-
the PINN with the role of the pressure in the equation.
|
|
620
|
-
Its input is bimensional (points in $\Omega\subset\mathbb{R}^2).
|
|
621
|
-
Its output is unidimensional as it represents a pressure scalar
|
|
622
|
-
field
|
|
623
544
|
eq_params_heterogeneity
|
|
624
545
|
Default None. A dict with the keys being the same as in eq_params
|
|
625
546
|
and the value being `time`, `space`, `both` or None which corresponds to
|
|
@@ -629,15 +550,12 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
629
550
|
heterogeneity for no parameters.
|
|
630
551
|
"""
|
|
631
552
|
|
|
632
|
-
u_key: str = eqx.field(static=True)
|
|
633
|
-
p_key: str = eqx.field(static=True)
|
|
634
|
-
|
|
635
553
|
def equation(
|
|
636
554
|
self,
|
|
637
|
-
x: Float[Array, "dim"],
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
) -> Float[Array, "
|
|
555
|
+
x: Float[Array, " dim"],
|
|
556
|
+
u_p: AbstractPINN,
|
|
557
|
+
params: Params[Array],
|
|
558
|
+
) -> Float[Array, " 3"]:
|
|
641
559
|
r"""
|
|
642
560
|
Evaluate the dynamic loss at `x`.
|
|
643
561
|
For stability we implement the dynamic loss in log space.
|
|
@@ -646,72 +564,81 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
646
564
|
---------
|
|
647
565
|
x
|
|
648
566
|
A point in $\Omega\subset\mathbb{R}^2$
|
|
649
|
-
|
|
650
|
-
A
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
with keys: `eq_params` and `nn_params`, respectively the
|
|
655
|
-
differential equation parameters and the neural network parameter.
|
|
656
|
-
Must have the same keys as `u_dict`
|
|
567
|
+
u_p
|
|
568
|
+
A PINN with 3 outputs of the 2D velocity field and the 1D pressure
|
|
569
|
+
field
|
|
570
|
+
params
|
|
571
|
+
The parameters in a Params object
|
|
657
572
|
"""
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
u = u_dict[self.u_key]
|
|
573
|
+
assert u_p.eq_type != "ODE", "Cannot compute the loss for ODE PINNs"
|
|
574
|
+
if isinstance(u_p, PINN):
|
|
575
|
+
u = lambda x, params: u_p(x, params)[0:2]
|
|
576
|
+
p = lambda x, params: u_p(x, params)[2:3]
|
|
663
577
|
|
|
664
|
-
|
|
578
|
+
# NAVIER STOKES
|
|
579
|
+
u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(x, u, params)
|
|
665
580
|
|
|
666
|
-
|
|
667
|
-
jac_p = jax.jacrev(p, 0)(x) # compute the gradient
|
|
581
|
+
jac_p = jax.jacrev(p, 0)(x, params) # compute the gradient
|
|
668
582
|
|
|
669
|
-
vec_laplacian_u = vectorial_laplacian_rev(
|
|
583
|
+
vec_laplacian_u = vectorial_laplacian_rev(
|
|
584
|
+
x, u, params, dim_out=2, eq_type=u_p.eq_type
|
|
585
|
+
)
|
|
670
586
|
|
|
671
587
|
# dynamic loss on x axis
|
|
672
588
|
result_x = (
|
|
673
589
|
u_dot_nabla_x_u[0]
|
|
674
|
-
+ 1 /
|
|
675
|
-
-
|
|
590
|
+
+ 1 / params.eq_params["rho"] * jac_p[0, 0]
|
|
591
|
+
- params.eq_params["nu"] * vec_laplacian_u[0]
|
|
676
592
|
)
|
|
677
593
|
|
|
678
594
|
# dynamic loss on y axis
|
|
679
595
|
result_y = (
|
|
680
596
|
u_dot_nabla_x_u[1]
|
|
681
|
-
+ 1 /
|
|
682
|
-
-
|
|
597
|
+
+ 1 / params.eq_params["rho"] * jac_p[0, 1]
|
|
598
|
+
- params.eq_params["nu"] * vec_laplacian_u[1]
|
|
683
599
|
)
|
|
684
600
|
|
|
685
|
-
#
|
|
686
|
-
return jnp.stack([result_x, result_y], axis=-1)
|
|
601
|
+
# MASS CONSERVATION
|
|
687
602
|
|
|
688
|
-
|
|
689
|
-
u = u_dict[self.u_key]
|
|
603
|
+
mc = divergence_rev(x, u, params, eq_type=u_p.eq_type)
|
|
690
604
|
|
|
691
|
-
|
|
605
|
+
# output is 3D
|
|
606
|
+
if mc.ndim == 0 and not result_x.ndim == 0:
|
|
607
|
+
mc = mc[None]
|
|
608
|
+
return jnp.stack([result_x, result_y, mc], axis=-1)
|
|
692
609
|
|
|
693
|
-
|
|
610
|
+
if isinstance(u_p, SPINN):
|
|
611
|
+
u = lambda x, params: u_p(x, params)[..., 0:2]
|
|
612
|
+
p = lambda x: u_p(x, params)[..., 2:3]
|
|
613
|
+
|
|
614
|
+
# NAVIER STOKES
|
|
615
|
+
u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(x, u, params)
|
|
694
616
|
|
|
695
617
|
tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
|
|
696
618
|
_, dp_dx = jax.jvp(p, (x,), (tangent_vec_0,))
|
|
697
619
|
tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
|
|
698
620
|
_, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))
|
|
699
621
|
|
|
700
|
-
vec_laplacian_u = vectorial_laplacian_fwd(
|
|
622
|
+
vec_laplacian_u = vectorial_laplacian_fwd(
|
|
623
|
+
x, u, params, dim_out=2, eq_type=u_p.eq_type
|
|
624
|
+
)
|
|
701
625
|
|
|
702
626
|
# dynamic loss on x axis
|
|
703
627
|
result_x = (
|
|
704
628
|
u_dot_nabla_x_u[..., 0]
|
|
705
|
-
+ 1 /
|
|
706
|
-
-
|
|
629
|
+
+ 1 / params.eq_params["rho"] * dp_dx.squeeze()
|
|
630
|
+
- params.eq_params["nu"] * vec_laplacian_u[..., 0]
|
|
707
631
|
)
|
|
708
632
|
# dynamic loss on y axis
|
|
709
633
|
result_y = (
|
|
710
634
|
u_dot_nabla_x_u[..., 1]
|
|
711
|
-
+ 1 /
|
|
712
|
-
-
|
|
635
|
+
+ 1 / params.eq_params["rho"] * dp_dy.squeeze()
|
|
636
|
+
- params.eq_params["nu"] * vec_laplacian_u[..., 1]
|
|
713
637
|
)
|
|
714
638
|
|
|
715
|
-
#
|
|
716
|
-
|
|
639
|
+
# MASS CONSERVATION
|
|
640
|
+
mc = divergence_fwd(x, u, params, eq_type=u_p.eq_type)[..., None]
|
|
641
|
+
|
|
642
|
+
# output is (..., 3)
|
|
643
|
+
return jnp.stack([result_x[..., None], result_y[..., None], mc], axis=-1)
|
|
717
644
|
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|