jinns 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +292 -309
- jinns/loss/_LossPDE.py +625 -1010
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +95 -44
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/METADATA +4 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLoss.py
CHANGED
|
@@ -2,12 +2,21 @@
|
|
|
2
2
|
Implements several dynamic losses
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, Dict
|
|
10
|
+
from jaxtyping import Float
|
|
5
11
|
import jax
|
|
6
|
-
from jax import grad
|
|
12
|
+
from jax import grad
|
|
7
13
|
import jax.numpy as jnp
|
|
8
|
-
|
|
14
|
+
import equinox as eqx
|
|
15
|
+
|
|
9
16
|
from jinns.utils._pinn import PINN
|
|
10
17
|
from jinns.utils._spinn import SPINN
|
|
18
|
+
|
|
19
|
+
from jinns.utils._utils import _get_grid
|
|
11
20
|
from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
|
|
12
21
|
from jinns.loss._operators import (
|
|
13
22
|
_laplacian_rev,
|
|
@@ -19,52 +28,44 @@ from jinns.loss._operators import (
|
|
|
19
28
|
_u_dot_nabla_times_u_fwd,
|
|
20
29
|
)
|
|
21
30
|
|
|
31
|
+
from jaxtyping import Array, Float
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from jinns.parameters import Params, ParamsDict
|
|
35
|
+
|
|
22
36
|
|
|
23
37
|
class FisherKPP(PDENonStatio):
|
|
24
38
|
r"""
|
|
25
|
-
Return the Fisher KPP dynamic loss term. Dimension of
|
|
39
|
+
Return the Fisher KPP dynamic loss term. Dimension of $x$ can be
|
|
26
40
|
arbitrary
|
|
27
41
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
42
|
+
$$
|
|
43
|
+
\frac{\partial}{\partial t} u(t,x)=D\Delta u(t,x) + u(t,x)(r(x) - \gamma(x)u(t,x))
|
|
44
|
+
$$
|
|
31
45
|
"""
|
|
32
46
|
|
|
33
|
-
def
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
equation
|
|
41
|
-
eq_params_heterogeneity
|
|
42
|
-
Default None. A dict with the keys being the same as in eq_params
|
|
43
|
-
and the value being `time`, `space`, `both` or None which corresponds to
|
|
44
|
-
the heterogeneity of a given parameter. A value can be missing, in
|
|
45
|
-
this case there is no heterogeneity (=None). If
|
|
46
|
-
eq_params_heterogeneity is None this means there is no
|
|
47
|
-
heterogeneity for no parameters.
|
|
48
|
-
"""
|
|
49
|
-
super().__init__(Tmax, eq_params_heterogeneity)
|
|
50
|
-
|
|
51
|
-
@PDENonStatio.evaluate_heterogeneous_parameters
|
|
52
|
-
def evaluate(self, t, x, u, params):
|
|
47
|
+
def equation(
|
|
48
|
+
self,
|
|
49
|
+
t: Float[Array, "1"],
|
|
50
|
+
x: Float[Array, "dim"],
|
|
51
|
+
u: eqx.Module,
|
|
52
|
+
params: Params,
|
|
53
|
+
) -> Float[Array, "1"]:
|
|
53
54
|
r"""
|
|
54
|
-
Evaluate the dynamic loss at
|
|
55
|
+
Evaluate the dynamic loss at $(t,x)$.
|
|
55
56
|
|
|
56
57
|
Parameters
|
|
57
58
|
---------
|
|
58
59
|
t
|
|
59
|
-
A time point
|
|
60
|
+
A time point.
|
|
60
61
|
x
|
|
61
|
-
A point in
|
|
62
|
+
A point in $\Omega$.
|
|
62
63
|
u
|
|
63
64
|
The PINN
|
|
64
65
|
params
|
|
65
66
|
The dictionary of parameters of the model.
|
|
66
67
|
Typically, it is a dictionary of
|
|
67
|
-
dictionaries: `eq_params` and `nn_params
|
|
68
|
+
dictionaries: `eq_params` and `nn_params`, respectively the
|
|
68
69
|
differential equation parameters and the neural network parameter
|
|
69
70
|
"""
|
|
70
71
|
if isinstance(u, PINN):
|
|
@@ -76,12 +77,9 @@ class FisherKPP(PDENonStatio):
|
|
|
76
77
|
lap = _laplacian_rev(t, x, u, params)[..., None]
|
|
77
78
|
|
|
78
79
|
return du_dt + self.Tmax * (
|
|
79
|
-
-params
|
|
80
|
+
-params.eq_params["D"] * lap
|
|
80
81
|
- u(t, x, params)
|
|
81
|
-
* (
|
|
82
|
-
params["eq_params"]["r"]
|
|
83
|
-
- params["eq_params"]["g"] * u(t, x, params)
|
|
84
|
-
)
|
|
82
|
+
* (params.eq_params["r"] - params.eq_params["g"] * u(t, x, params))
|
|
85
83
|
)
|
|
86
84
|
if isinstance(u, SPINN):
|
|
87
85
|
u_tx, du_dt = jax.jvp(
|
|
@@ -91,49 +89,129 @@ class FisherKPP(PDENonStatio):
|
|
|
91
89
|
)
|
|
92
90
|
lap = _laplacian_fwd(t, x, u, params)[..., None]
|
|
93
91
|
return du_dt + self.Tmax * (
|
|
94
|
-
-params
|
|
92
|
+
-params.eq_params["D"] * lap
|
|
95
93
|
- u_tx
|
|
96
|
-
* (
|
|
97
|
-
params["eq_params"]["r"][..., None]
|
|
98
|
-
- params["eq_params"]["g"] * u_tx
|
|
99
|
-
)
|
|
94
|
+
* (params.eq_params["r"][..., None] - params.eq_params["g"] * u_tx)
|
|
100
95
|
)
|
|
101
96
|
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
102
97
|
|
|
103
98
|
|
|
104
|
-
class
|
|
99
|
+
class GeneralizedLotkaVolterra(ODE):
|
|
105
100
|
r"""
|
|
106
|
-
Return
|
|
107
|
-
|
|
108
|
-
.. math::
|
|
109
|
-
\frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
|
|
110
|
-
u(t,x) - \theta \frac{\partial^2}{\partial x^2} u(t,x) = 0
|
|
101
|
+
Return a dynamic loss from an equation of a Generalized Lotka Volterra
|
|
102
|
+
system. Say we implement the equation for population $i$
|
|
111
103
|
|
|
104
|
+
$$
|
|
105
|
+
\frac{\partial}{\partial t}u_i(t) = r_iu_i(t) - \sum_{j\neq i}\alpha_{ij}u_j(t)
|
|
106
|
+
-\alpha_{i,i}u_i(t) + c_iu_i(t) + \sum_{j \neq i} c_ju_j(t)
|
|
107
|
+
$$
|
|
108
|
+
with $r_i$ the growth rate parameter, $c_i$ the carrying
|
|
109
|
+
capacities and $\alpha_{ij}$ the interaction terms.
|
|
110
|
+
|
|
111
|
+
Parameters
|
|
112
|
+
----------
|
|
113
|
+
key_main
|
|
114
|
+
The dictionary key (in the dictionaries `u` and `params` that
|
|
115
|
+
are arguments of the `evaluate` function) of the main population
|
|
116
|
+
$i$ of the particular equation of the system implemented
|
|
117
|
+
by this dynamic loss
|
|
118
|
+
keys_other
|
|
119
|
+
The list of dictionary keys (in the dictionaries `u` and `params` that
|
|
120
|
+
are arguments of the `evaluate` function) of the other
|
|
121
|
+
populations that appear in the equation of the system implemented
|
|
122
|
+
by this dynamic loss
|
|
123
|
+
Tmax
|
|
124
|
+
Tmax needs to be given when the PINN time input is normalized in
|
|
125
|
+
$[0, 1]$, ie. we have performed renormalization of the differential
|
|
126
|
+
equation.
|
|
127
|
+
eq_params_heterogeneity
|
|
128
|
+
Default None. A dict with the keys being the same as in eq_params
|
|
129
|
+
and the value being `time`, `space`, `both` or None which corresponds to
|
|
130
|
+
the heterogeneity of a given parameter. A value can be missing, in
|
|
131
|
+
this case there is no heterogeneity (=None). If
|
|
132
|
+
eq_params_heterogeneity is None this means there is no
|
|
133
|
+
heterogeneity for no parameters.
|
|
112
134
|
"""
|
|
113
135
|
|
|
114
|
-
|
|
136
|
+
# they should be static because they are list of strings
|
|
137
|
+
key_main: list[str] = eqx.field(static=True)
|
|
138
|
+
keys_other: list[str] = eqx.field(static=True)
|
|
139
|
+
|
|
140
|
+
def equation(
|
|
115
141
|
self,
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
142
|
+
t: Float[Array, "1"],
|
|
143
|
+
u_dict: Dict[str, eqx.Module],
|
|
144
|
+
params_dict: ParamsDict,
|
|
145
|
+
) -> Float[Array, "1"]:
|
|
119
146
|
"""
|
|
147
|
+
Evaluate the dynamic loss at `t`.
|
|
148
|
+
For stability we implement the dynamic loss in log space.
|
|
149
|
+
|
|
120
150
|
Parameters
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
the heterogeneity of a given parameter. A value can be missing, in
|
|
130
|
-
this case there is no heterogeneity (=None). If
|
|
131
|
-
eq_params_heterogeneity is None this means there is no
|
|
132
|
-
heterogeneity for no parameters.
|
|
151
|
+
---------
|
|
152
|
+
t
|
|
153
|
+
A time point
|
|
154
|
+
u_dict
|
|
155
|
+
A dictionary of PINNS. Must have the same keys as `params_dict`
|
|
156
|
+
params_dict
|
|
157
|
+
The dictionary of dictionaries of parameters of the model. Keys at
|
|
158
|
+
top level are "nn_params" and "eq_params"
|
|
133
159
|
"""
|
|
134
|
-
|
|
160
|
+
params_main = params_dict.extract_params(self.key_main)
|
|
161
|
+
|
|
162
|
+
u = u_dict[self.key_main]
|
|
163
|
+
# need to index with [0] since u output is nec (1,)
|
|
164
|
+
du_dt = grad(lambda t: jnp.log(u(t, params_main)[0]), 0)(t)
|
|
165
|
+
carrying_term = params_main.eq_params["carrying_capacity"] * u(t, params_main)
|
|
166
|
+
# NOTE the following assumes interaction term with oneself is at idx 0
|
|
167
|
+
interaction_terms = params_main.eq_params["interactions"][0] * u(t, params_main)
|
|
168
|
+
|
|
169
|
+
# TODO write this for loop with tree_util functions?
|
|
170
|
+
for i, k in enumerate(self.keys_other):
|
|
171
|
+
params_k = params_dict.extract_params(k)
|
|
172
|
+
carrying_term += params_main.eq_params["carrying_capacity"] * u_dict[k](
|
|
173
|
+
t, params_k
|
|
174
|
+
)
|
|
175
|
+
interaction_terms += params_main.eq_params["interactions"][i + 1] * u_dict[
|
|
176
|
+
k
|
|
177
|
+
](t, params_k)
|
|
178
|
+
|
|
179
|
+
return du_dt + self.Tmax * (
|
|
180
|
+
-params_main.eq_params["growth_rate"] - interaction_terms + carrying_term
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class BurgerEquation(PDENonStatio):
|
|
185
|
+
r"""
|
|
186
|
+
Return the Burger dynamic loss term (in 1 space dimension):
|
|
135
187
|
|
|
136
|
-
|
|
188
|
+
$$
|
|
189
|
+
\frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
|
|
190
|
+
u(t,x) - \theta \frac{\partial^2}{\partial x^2} u(t,x) = 0
|
|
191
|
+
$$
|
|
192
|
+
|
|
193
|
+
Parameters
|
|
194
|
+
----------
|
|
195
|
+
Tmax
|
|
196
|
+
Tmax needs to be given when the PINN time input is normalized in
|
|
197
|
+
[0, 1], ie. we have performed renormalization of the differential
|
|
198
|
+
equation
|
|
199
|
+
eq_params_heterogeneity
|
|
200
|
+
Default None. A dict with the keys being the same as in eq_params
|
|
201
|
+
and the value being `time`, `space`, `both` or None which corresponds to
|
|
202
|
+
the heterogeneity of a given parameter. A value can be missing, in
|
|
203
|
+
this case there is no heterogeneity (=None). If
|
|
204
|
+
eq_params_heterogeneity is None this means there is no
|
|
205
|
+
heterogeneity for no parameters.
|
|
206
|
+
"""
|
|
207
|
+
|
|
208
|
+
def equation(
|
|
209
|
+
self,
|
|
210
|
+
t: Float[Array, "1"],
|
|
211
|
+
x: Float[Array, "dim"],
|
|
212
|
+
u: eqx.Module,
|
|
213
|
+
params: Params,
|
|
214
|
+
) -> Float[Array, "1"]:
|
|
137
215
|
r"""
|
|
138
216
|
Evaluate the dynamic loss at :math:`(t,x)`.
|
|
139
217
|
|
|
@@ -142,14 +220,11 @@ class BurgerEquation(PDENonStatio):
|
|
|
142
220
|
t
|
|
143
221
|
A time point
|
|
144
222
|
x
|
|
145
|
-
A point in
|
|
223
|
+
A point in $\Omega$
|
|
146
224
|
u
|
|
147
225
|
The PINN
|
|
148
226
|
params
|
|
149
227
|
The dictionary of parameters of the model.
|
|
150
|
-
Typically, it is a dictionary of
|
|
151
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
152
|
-
differential equation parameters and the neural network parameter
|
|
153
228
|
"""
|
|
154
229
|
if isinstance(u, PINN):
|
|
155
230
|
# Note that the last dim of u is nec. 1
|
|
@@ -162,8 +237,7 @@ class BurgerEquation(PDENonStatio):
|
|
|
162
237
|
)
|
|
163
238
|
|
|
164
239
|
return du_dt(t, x) + self.Tmax * (
|
|
165
|
-
u(t, x, params) * du_dx(t, x)
|
|
166
|
-
- params["eq_params"]["nu"] * d2u_dx2(t, x)
|
|
240
|
+
u(t, x, params) * du_dx(t, x) - params.eq_params["nu"] * d2u_dx2(t, x)
|
|
167
241
|
)
|
|
168
242
|
|
|
169
243
|
if isinstance(u, SPINN):
|
|
@@ -181,161 +255,68 @@ class BurgerEquation(PDENonStatio):
|
|
|
181
255
|
)[1]
|
|
182
256
|
du_dx, d2u_dx2 = jax.jvp(du_dx_fun, (x,), (jnp.ones_like(x),))
|
|
183
257
|
# Note that ones_like(x) works because x is Bx1 !
|
|
184
|
-
return du_dt + self.Tmax * (
|
|
185
|
-
u_tx * du_dx - params["eq_params"]["nu"] * d2u_dx2
|
|
186
|
-
)
|
|
258
|
+
return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params["nu"] * d2u_dx2)
|
|
187
259
|
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
188
260
|
|
|
189
261
|
|
|
190
|
-
class GeneralizedLotkaVolterra(ODE):
|
|
191
|
-
r"""
|
|
192
|
-
Return a dynamic loss from an equation of a Generalized Lotka Volterra
|
|
193
|
-
system. Say we implement the equation for population :math:`i`:
|
|
194
|
-
|
|
195
|
-
.. math::
|
|
196
|
-
\frac{\partial}{\partial t}u_i(t) = r_iu_i(t) - \sum_{j\neq i}\alpha_{ij}u_j(t)
|
|
197
|
-
-\alpha_{i,i}u_i(t) + c_iu_i(t) + \sum_{j \neq i} c_ju_j(t)
|
|
198
|
-
|
|
199
|
-
with :math:`r_i` the growth rate parameter, :math:`c_i` the carrying
|
|
200
|
-
capacities and :math:`\alpha_{ij}` the interaction terms.
|
|
201
|
-
|
|
202
|
-
"""
|
|
203
|
-
|
|
204
|
-
def __init__(
|
|
205
|
-
self,
|
|
206
|
-
key_main,
|
|
207
|
-
keys_other,
|
|
208
|
-
Tmax=1,
|
|
209
|
-
eq_params_heterogeneity=None,
|
|
210
|
-
):
|
|
211
|
-
"""
|
|
212
|
-
Parameters
|
|
213
|
-
----------
|
|
214
|
-
key_main
|
|
215
|
-
The dictionary key (in the dictionaries ``u`` and ``params`` that
|
|
216
|
-
are arguments of the ``evaluate`` function) of the main population
|
|
217
|
-
:math:`i` of the particular equation of the system implemented
|
|
218
|
-
by this dynamic loss
|
|
219
|
-
keys_other
|
|
220
|
-
The list of dictionary keys (in the dictionaries ``u`` and ``params`` that
|
|
221
|
-
are arguments of the ``evaluate`` function) of the other
|
|
222
|
-
populations that appear in the equation of the system implemented
|
|
223
|
-
by this dynamic loss
|
|
224
|
-
Tmax
|
|
225
|
-
Tmax needs to be given when the PINN time input is normalized in
|
|
226
|
-
[0, 1], ie. we have performed renormalization of the differential
|
|
227
|
-
equation
|
|
228
|
-
eq_params_heterogeneity
|
|
229
|
-
Default None. A dict with the keys being the same as in eq_params
|
|
230
|
-
and the value being `time`, `space`, `both` or None which corresponds to
|
|
231
|
-
the heterogeneity of a given parameter. A value can be missing, in
|
|
232
|
-
this case there is no heterogeneity (=None). If
|
|
233
|
-
eq_params_heterogeneity is None this means there is no
|
|
234
|
-
heterogeneity for no parameters.
|
|
235
|
-
"""
|
|
236
|
-
super().__init__(Tmax, eq_params_heterogeneity)
|
|
237
|
-
self.key_main = key_main
|
|
238
|
-
self.keys_other = keys_other
|
|
239
|
-
|
|
240
|
-
def evaluate(self, t, u_dict, params_dict):
|
|
241
|
-
"""
|
|
242
|
-
Evaluate the dynamic loss at `t`.
|
|
243
|
-
For stability we implement the dynamic loss in log space.
|
|
244
|
-
|
|
245
|
-
Parameters
|
|
246
|
-
---------
|
|
247
|
-
t
|
|
248
|
-
A time point
|
|
249
|
-
u_dict
|
|
250
|
-
A dictionary of PINNS. Must have the same keys as `params_dict`
|
|
251
|
-
params_dict
|
|
252
|
-
The dictionary of dictionaries of parameters of the model. Keys at
|
|
253
|
-
top level are "nn_params" and "eq_params"
|
|
254
|
-
"""
|
|
255
|
-
params_main = _extract_nn_params(params_dict, self.key_main)
|
|
256
|
-
|
|
257
|
-
u = u_dict[self.key_main]
|
|
258
|
-
# need to index with [0] since u output is nec (1,)
|
|
259
|
-
du_dt = grad(lambda t: jnp.log(u(t, params_main)[0]), 0)(t)
|
|
260
|
-
carrying_term = params_main["eq_params"]["carrying_capacity"] * u(
|
|
261
|
-
t, params_main
|
|
262
|
-
)
|
|
263
|
-
# NOTE the following assumes interaction term with oneself is at idx 0
|
|
264
|
-
interaction_terms = params_main["eq_params"]["interactions"][0] * u(
|
|
265
|
-
t, params_main
|
|
266
|
-
)
|
|
267
|
-
|
|
268
|
-
# TODO write this for loop with tree_util functions?
|
|
269
|
-
for i, k in enumerate(self.keys_other):
|
|
270
|
-
params_k = _extract_nn_params(params_dict, k)
|
|
271
|
-
carrying_term += params_main["eq_params"]["carrying_capacity"] * u_dict[k](
|
|
272
|
-
t, params_k
|
|
273
|
-
)
|
|
274
|
-
interaction_terms += params_main["eq_params"]["interactions"][
|
|
275
|
-
i + 1
|
|
276
|
-
] * u_dict[k](t, params_k)
|
|
277
|
-
|
|
278
|
-
return du_dt + self.Tmax * (
|
|
279
|
-
-params_main["eq_params"]["growth_rate"] - interaction_terms + carrying_term
|
|
280
|
-
)
|
|
281
|
-
|
|
282
|
-
|
|
283
262
|
class FPENonStatioLoss2D(PDENonStatio):
|
|
284
263
|
r"""
|
|
285
264
|
Return the dynamic loss for a non-stationary Fokker Planck Equation in two
|
|
286
265
|
dimensions:
|
|
287
266
|
|
|
288
|
-
|
|
267
|
+
$$
|
|
289
268
|
-\sum_{i=1}^2\frac{\partial}{\partial \mathbf{x}}
|
|
290
269
|
\left[\mu(t, \mathbf{x})u(t, \mathbf{x})\right] +
|
|
291
270
|
\sum_{i=1}^2\sum_{j=1}^2\frac{\partial^2}{\partial x_i \partial x_j}
|
|
292
271
|
\left[D(t, \mathbf{x})u(t, \mathbf{x})\right]= \frac{\partial}
|
|
293
272
|
{\partial t}u(t,\mathbf{x})
|
|
294
|
-
|
|
295
|
-
where
|
|
296
|
-
term.
|
|
273
|
+
$$
|
|
274
|
+
where $\mu(t, \mathbf{x})$ is the drift term and $D(t, \mathbf{x})$ is the
|
|
275
|
+
diffusion term.
|
|
297
276
|
|
|
298
277
|
The drift and diffusion terms are not specified here, hence this class
|
|
299
278
|
is `abstract`.
|
|
300
279
|
Other classes inherit from FPENonStatioLoss2D and define the drift and
|
|
301
280
|
diffusion terms, which then defines several other dynamic losses
|
|
302
281
|
(Ornstein-Uhlenbeck, Cox-Ingersoll-Ross, ...)
|
|
303
|
-
"""
|
|
304
282
|
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
heterogeneity for no parameters.
|
|
320
|
-
"""
|
|
321
|
-
super().__init__(Tmax, eq_params_heterogeneity)
|
|
283
|
+
Parameters
|
|
284
|
+
----------
|
|
285
|
+
Tmax
|
|
286
|
+
Tmax needs to be given when the PINN time input is normalized in
|
|
287
|
+
[0, 1], ie. we have performed renormalization of the differential
|
|
288
|
+
equation
|
|
289
|
+
eq_params_heterogeneity
|
|
290
|
+
Default None. A dict with the keys being the same as in eq_params
|
|
291
|
+
and the value being `time`, `space`, `both` or None which corresponds to
|
|
292
|
+
the heterogeneity of a given parameter. A value can be missing, in
|
|
293
|
+
this case there is no heterogeneity (=None). If
|
|
294
|
+
eq_params_heterogeneity is None this means there is no
|
|
295
|
+
heterogeneity for no parameters.
|
|
296
|
+
"""
|
|
322
297
|
|
|
323
|
-
def
|
|
298
|
+
def equation(
|
|
299
|
+
self,
|
|
300
|
+
t: Float[Array, "1"],
|
|
301
|
+
x: Float[Array, "dim"],
|
|
302
|
+
u: eqx.Module,
|
|
303
|
+
params: Params,
|
|
304
|
+
) -> Float[Array, "1"]:
|
|
324
305
|
r"""
|
|
325
|
-
Evaluate the dynamic loss at
|
|
306
|
+
Evaluate the dynamic loss at $(t,\mathbf{x})$.
|
|
326
307
|
|
|
327
308
|
Parameters
|
|
328
309
|
---------
|
|
329
310
|
t
|
|
330
311
|
A time point
|
|
331
312
|
x
|
|
332
|
-
A point in
|
|
313
|
+
A point in $\Omega$
|
|
333
314
|
u
|
|
334
315
|
The PINN
|
|
335
316
|
params
|
|
336
317
|
The dictionary of parameters of the model.
|
|
337
318
|
Typically, it is a dictionary of
|
|
338
|
-
dictionaries: `eq_params` and `nn_params
|
|
319
|
+
dictionaries: `eq_params` and `nn_params`, respectively the
|
|
339
320
|
differential equation parameters and the neural network parameter
|
|
340
321
|
"""
|
|
341
322
|
if isinstance(u, PINN):
|
|
@@ -344,11 +325,13 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
344
325
|
|
|
345
326
|
order_1 = (
|
|
346
327
|
grad(
|
|
347
|
-
lambda t, x: self.drift(t, x, params
|
|
328
|
+
lambda t, x: self.drift(t, x, params.eq_params)[0] * u_(t, x),
|
|
348
329
|
1,
|
|
349
|
-
)(
|
|
330
|
+
)(
|
|
331
|
+
t, x
|
|
332
|
+
)[0:1]
|
|
350
333
|
+ grad(
|
|
351
|
-
lambda t, x: self.drift(t, x, params
|
|
334
|
+
lambda t, x: self.drift(t, x, params.eq_params)[1] * u_(t, x),
|
|
352
335
|
1,
|
|
353
336
|
)(t, x)[1:2]
|
|
354
337
|
)
|
|
@@ -357,7 +340,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
357
340
|
grad(
|
|
358
341
|
lambda t, x: grad(
|
|
359
342
|
lambda t, x: u_(t, x)
|
|
360
|
-
* self.diffusion(t, x, params
|
|
343
|
+
* self.diffusion(t, x, params.eq_params)[0, 0],
|
|
361
344
|
1,
|
|
362
345
|
)(t, x)[0],
|
|
363
346
|
1,
|
|
@@ -365,7 +348,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
365
348
|
+ grad(
|
|
366
349
|
lambda t, x: grad(
|
|
367
350
|
lambda t, x: u_(t, x)
|
|
368
|
-
* self.diffusion(t, x, params
|
|
351
|
+
* self.diffusion(t, x, params.eq_params)[1, 0],
|
|
369
352
|
1,
|
|
370
353
|
)(t, x)[1],
|
|
371
354
|
1,
|
|
@@ -373,7 +356,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
373
356
|
+ grad(
|
|
374
357
|
lambda t, x: grad(
|
|
375
358
|
lambda t, x: u_(t, x)
|
|
376
|
-
* self.diffusion(t, x, params
|
|
359
|
+
* self.diffusion(t, x, params.eq_params)[0, 1],
|
|
377
360
|
1,
|
|
378
361
|
)(t, x)[0],
|
|
379
362
|
1,
|
|
@@ -381,7 +364,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
381
364
|
+ grad(
|
|
382
365
|
lambda t, x: grad(
|
|
383
366
|
lambda t, x: u_(t, x)
|
|
384
|
-
* self.diffusion(t, x, params
|
|
367
|
+
* self.diffusion(t, x, params.eq_params)[1, 1],
|
|
385
368
|
1,
|
|
386
369
|
)(t, x)[1],
|
|
387
370
|
1,
|
|
@@ -406,24 +389,20 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
406
389
|
tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
|
|
407
390
|
tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
|
|
408
391
|
_, dau_dx1 = jax.jvp(
|
|
409
|
-
lambda x: self.drift(t, _get_grid(x), params
|
|
410
|
-
None, ..., 0:1
|
|
411
|
-
]
|
|
392
|
+
lambda x: self.drift(t, _get_grid(x), params.eq_params)[None, ..., 0:1]
|
|
412
393
|
* u(t, x, params)[..., 0:1],
|
|
413
394
|
(x,),
|
|
414
395
|
(tangent_vec_0,),
|
|
415
396
|
)
|
|
416
397
|
_, dau_dx2 = jax.jvp(
|
|
417
|
-
lambda x: self.drift(t, _get_grid(x), params
|
|
418
|
-
None, ..., 1:2
|
|
419
|
-
]
|
|
398
|
+
lambda x: self.drift(t, _get_grid(x), params.eq_params)[None, ..., 1:2]
|
|
420
399
|
* u(t, x, params)[..., 0:1],
|
|
421
400
|
(x,),
|
|
422
401
|
(tangent_vec_1,),
|
|
423
402
|
)
|
|
424
403
|
|
|
425
404
|
dsu_dx1_fun = lambda x, i, j: jax.jvp(
|
|
426
|
-
lambda x: self.diffusion(t, _get_grid(x), params
|
|
405
|
+
lambda x: self.diffusion(t, _get_grid(x), params.eq_params, i, j)[
|
|
427
406
|
None, None, None, None
|
|
428
407
|
]
|
|
429
408
|
* u(t, x, params)[..., 0:1],
|
|
@@ -431,7 +410,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
431
410
|
(tangent_vec_0,),
|
|
432
411
|
)[1]
|
|
433
412
|
dsu_dx2_fun = lambda x, i, j: jax.jvp(
|
|
434
|
-
lambda x: self.diffusion(t, _get_grid(x), params
|
|
413
|
+
lambda x: self.diffusion(t, _get_grid(x), params.eq_params, i, j)[
|
|
435
414
|
None, None, None, None
|
|
436
415
|
]
|
|
437
416
|
* u(t, x, params)[..., 0:1],
|
|
@@ -459,11 +438,11 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
459
438
|
|
|
460
439
|
def drift(self, *args, **kwargs):
|
|
461
440
|
# To be implemented in child classes
|
|
462
|
-
|
|
441
|
+
raise NotImplementedError("Drift function should be implemented")
|
|
463
442
|
|
|
464
443
|
def diffusion(self, *args, **kwargs):
|
|
465
444
|
# To be implemented in child classes
|
|
466
|
-
|
|
445
|
+
raise NotImplementedError("Diffusion function should be implemented")
|
|
467
446
|
|
|
468
447
|
|
|
469
448
|
class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
@@ -471,34 +450,30 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
471
450
|
Return the dynamic loss for a stationary Fokker Planck Equation in two
|
|
472
451
|
dimensions:
|
|
473
452
|
|
|
474
|
-
|
|
453
|
+
$$
|
|
475
454
|
-\sum_{i=1}^2\frac{\partial}{\partial \mathbf{x}}
|
|
476
455
|
\left[(\alpha(\mu - \mathbf{x}))u(t,\mathbf{x})\right] +
|
|
477
456
|
\sum_{i=1}^2\sum_{j=1}^2\frac{\partial^2}{\partial x_i \partial x_j}
|
|
478
457
|
\left[\frac{\sigma^2}{2}u(t,\mathbf{x})\right]=
|
|
479
458
|
\frac{\partial}
|
|
480
459
|
{\partial t}u(t,\mathbf{x})
|
|
481
|
-
|
|
460
|
+
$$
|
|
461
|
+
|
|
462
|
+
Parameters
|
|
463
|
+
----------
|
|
464
|
+
Tmax
|
|
465
|
+
Tmax needs to be given when the PINN time input is normalized in
|
|
466
|
+
[0, 1], ie. we have performed renormalization of the differential
|
|
467
|
+
equation
|
|
468
|
+
eq_params_heterogeneity
|
|
469
|
+
Default None. A dict with the keys being the same as in eq_params
|
|
470
|
+
and the value being `time`, `space`, `both` or None which corresponds to
|
|
471
|
+
the heterogeneity of a given parameter. A value can be missing, in
|
|
472
|
+
this case there is no heterogeneity (=None). If
|
|
473
|
+
eq_params_heterogeneity is None this means there is no
|
|
474
|
+
heterogeneity for no parameters.
|
|
482
475
|
"""
|
|
483
476
|
|
|
484
|
-
def __init__(self, Tmax=1, eq_params_heterogeneity=None):
|
|
485
|
-
"""
|
|
486
|
-
Parameters
|
|
487
|
-
----------
|
|
488
|
-
Tmax
|
|
489
|
-
Tmax needs to be given when the PINN time input is normalized in
|
|
490
|
-
[0, 1], ie. we have performed renormalization of the differential
|
|
491
|
-
equation
|
|
492
|
-
eq_params_heterogeneity
|
|
493
|
-
Default None. A dict with the keys being the same as in eq_params
|
|
494
|
-
and the value being `time`, `space`, `both` or None which corresponds to
|
|
495
|
-
the heterogeneity of a given parameter. A value can be missing, in
|
|
496
|
-
this case there is no heterogeneity (=None). If
|
|
497
|
-
eq_params_heterogeneity is None this means there is no
|
|
498
|
-
heterogeneity for no parameters.
|
|
499
|
-
"""
|
|
500
|
-
super().__init__(Tmax, eq_params_heterogeneity)
|
|
501
|
-
|
|
502
477
|
def drift(self, t, x, eq_params):
|
|
503
478
|
r"""
|
|
504
479
|
Return the drift term
|
|
@@ -508,7 +483,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
508
483
|
t
|
|
509
484
|
A time point
|
|
510
485
|
x
|
|
511
|
-
A point in
|
|
486
|
+
A point in $\Omega$
|
|
512
487
|
eq_params
|
|
513
488
|
A dictionary containing the equation parameters
|
|
514
489
|
"""
|
|
@@ -524,7 +499,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
524
499
|
t
|
|
525
500
|
A time point
|
|
526
501
|
x
|
|
527
|
-
A point in
|
|
502
|
+
A point in $\Omega$
|
|
528
503
|
eq_params
|
|
529
504
|
A dictionary containing the equation parameters
|
|
530
505
|
"""
|
|
@@ -541,7 +516,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
541
516
|
t
|
|
542
517
|
A time point
|
|
543
518
|
x
|
|
544
|
-
A point in
|
|
519
|
+
A point in $\Omega$
|
|
545
520
|
eq_params
|
|
546
521
|
A dictionary containing the equation parameters
|
|
547
522
|
"""
|
|
@@ -564,33 +539,36 @@ class MassConservation2DStatio(PDEStatio):
|
|
|
564
539
|
r"""
|
|
565
540
|
Returns the so-called mass conservation equation.
|
|
566
541
|
|
|
567
|
-
|
|
542
|
+
$$
|
|
568
543
|
\nabla \cdot \mathbf{u} = \frac{\partial}{\partial x}u(x,y) +
|
|
569
544
|
\frac{\partial}{\partial y}u(x,y) = 0,
|
|
570
|
-
|
|
571
|
-
where
|
|
572
|
-
|
|
545
|
+
$$
|
|
546
|
+
where $u$ is a stationary function, i.e., it does not depend on
|
|
547
|
+
$t$.
|
|
548
|
+
|
|
549
|
+
Parameters
|
|
550
|
+
----------
|
|
551
|
+
nn_key
|
|
552
|
+
A dictionary key which identifies, in `u_dict` the PINN that
|
|
553
|
+
appears in the mass conservation equation.
|
|
554
|
+
eq_params_heterogeneity
|
|
555
|
+
Default None. A dict with the keys being the same as in eq_params
|
|
556
|
+
and the value being `time`, `space`, `both` or None which corresponds to
|
|
557
|
+
the heterogeneity of a given parameter. A value can be missing, in
|
|
558
|
+
this case there is no heterogeneity (=None). If
|
|
559
|
+
eq_params_heterogeneity is None this means there is no
|
|
560
|
+
heterogeneity for no parameters.
|
|
573
561
|
"""
|
|
574
562
|
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
Parameters
|
|
578
|
-
----------
|
|
579
|
-
nn_key
|
|
580
|
-
A dictionary key which identifies, in `u_dict` the PINN that
|
|
581
|
-
appears in the mass conservation equation.
|
|
582
|
-
eq_params_heterogeneity
|
|
583
|
-
Default None. A dict with the keys being the same as in eq_params
|
|
584
|
-
and the value being `time`, `space`, `both` or None which corresponds to
|
|
585
|
-
the heterogeneity of a given parameter. A value can be missing, in
|
|
586
|
-
this case there is no heterogeneity (=None). If
|
|
587
|
-
eq_params_heterogeneity is None this means there is no
|
|
588
|
-
heterogeneity for no parameters.
|
|
589
|
-
"""
|
|
590
|
-
self.nn_key = nn_key
|
|
591
|
-
super().__init__(eq_params_heterogeneity)
|
|
563
|
+
# an str field should be static (not a valid JAX type)
|
|
564
|
+
nn_key: str = eqx.field(static=True)
|
|
592
565
|
|
|
593
|
-
def
|
|
566
|
+
def equation(
|
|
567
|
+
self,
|
|
568
|
+
x: Float[Array, "dim"],
|
|
569
|
+
u_dict: Dict[str, eqx.Module],
|
|
570
|
+
params_dict: ParamsDict,
|
|
571
|
+
) -> Float[Array, "1"]:
|
|
594
572
|
r"""
|
|
595
573
|
Evaluate the dynamic loss at `\mathbf{x}`.
|
|
596
574
|
For stability we implement the dynamic loss in log space.
|
|
@@ -598,17 +576,17 @@ class MassConservation2DStatio(PDEStatio):
|
|
|
598
576
|
Parameters
|
|
599
577
|
---------
|
|
600
578
|
x
|
|
601
|
-
A point in
|
|
579
|
+
A point in $\Omega\subset\mathbb{R}^2$
|
|
602
580
|
u_dict
|
|
603
581
|
A dictionary of PINNs. Must have the same keys as `params_dict`
|
|
604
582
|
params_dict
|
|
605
583
|
The dictionary of dictionaries of parameters of the model.
|
|
606
584
|
Typically, each sub-dictionary is a dictionary
|
|
607
|
-
with keys: `eq_params` and `nn_params
|
|
585
|
+
with keys: `eq_params` and `nn_params`, respectively the
|
|
608
586
|
differential equation parameters and the neural network parameter.
|
|
609
587
|
Must have the same keys as `u_dict`
|
|
610
588
|
"""
|
|
611
|
-
params =
|
|
589
|
+
params = params_dict.extract_params(self.nn_key)
|
|
612
590
|
|
|
613
591
|
if isinstance(u_dict[self.nn_key], PINN):
|
|
614
592
|
u = u_dict[self.nn_key]
|
|
@@ -627,15 +605,15 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
627
605
|
Return the dynamic loss for all the components of the stationary Navier Stokes
|
|
628
606
|
equation which is a 2D vectorial PDE.
|
|
629
607
|
|
|
630
|
-
|
|
608
|
+
$$
|
|
631
609
|
(\mathbf{u}\cdot\nabla)\mathbf{u} + \frac{1}{\rho}\nabla p - \theta
|
|
632
610
|
\nabla^2\mathbf{u}=0,
|
|
633
|
-
|
|
611
|
+
$$
|
|
634
612
|
|
|
635
613
|
or, in 2D,
|
|
636
614
|
|
|
637
615
|
|
|
638
|
-
|
|
616
|
+
$$
|
|
639
617
|
\begin{pmatrix}u_x\frac{\partial}{\partial x} u_x + u_y\frac{\partial}{\partial y} u_x \\
|
|
640
618
|
u_x\frac{\partial}{\partial x} u_y + u_y\frac{\partial}{\partial y} u_y \end{pmatrix} +
|
|
641
619
|
\frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p \\ \frac{\partial}{\partial y} p \end{pmatrix}
|
|
@@ -644,61 +622,60 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
644
622
|
\frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2} u_x \\
|
|
645
623
|
\frac{\partial^2}{\partial x^2} u_y + \frac{\partial^2}{\partial y^2} u_y
|
|
646
624
|
\end{pmatrix} = 0,
|
|
647
|
-
|
|
625
|
+
$$
|
|
648
626
|
with $\theta$ the viscosity coefficient and $\rho$ the density coefficient.
|
|
649
627
|
|
|
650
|
-
|
|
651
|
-
|
|
628
|
+
Parameters
|
|
629
|
+
----------
|
|
630
|
+
u_key
|
|
631
|
+
A dictionary key indexing the NN u in `u_dict`, i.e.
|
|
632
|
+
the PINN with the role of the velocity in the equation.
|
|
633
|
+
Its input is bidimensional (points in $\Omega\subset\mathbb{R}^2$).
|
|
634
|
+
Its output is bidimensional as it represents a velocity vector
|
|
635
|
+
field
|
|
636
|
+
p_key
|
|
637
|
+
A dictionary key indexing the NN p in `u_dict`, i.e.
|
|
638
|
+
the PINN with the role of the pressure in the equation.
|
|
639
|
+
Its input is bidimensional (points in $\Omega\subset\mathbb{R}^2$).
|
|
640
|
+
Its output is unidimensional as it represents a pressure scalar
|
|
641
|
+
field
|
|
642
|
+
eq_params_heterogeneity
|
|
643
|
+
Default None. A dict with the keys being the same as in eq_params
|
|
644
|
+
and the value being `time`, `space`, `both` or None which corresponds to
|
|
645
|
+
the heterogeneity of a given parameter. A value can be missing, in
|
|
646
|
+
this case there is no heterogeneity (=None). If
|
|
647
|
+
eq_params_heterogeneity is None this means there is no
|
|
648
|
+
heterogeneity for any parameter.
|
|
652
649
|
"""
|
|
653
650
|
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
Parameters
|
|
657
|
-
----------
|
|
658
|
-
u_key
|
|
659
|
-
A dictionary key which indices in `u_dict`
|
|
660
|
-
the PINN with the role of the velocity in the equation.
|
|
661
|
-
Its input is bimensional (points in :math:`\Omega\subset\mathbb{R}^2`).
|
|
662
|
-
Its output is bimensional as it represents a velocity vector
|
|
663
|
-
field
|
|
664
|
-
p_key
|
|
665
|
-
A dictionary key which indices in `u_dict`
|
|
666
|
-
the PINN with the role of the pressure in the equation.
|
|
667
|
-
Its input is bimensional (points in :math:`\Omega\subset\mathbb{R}^2`).
|
|
668
|
-
Its output is unidimensional as it represents a pressure scalar
|
|
669
|
-
field
|
|
670
|
-
eq_params_heterogeneity
|
|
671
|
-
Default None. A dict with the keys being the same as in eq_params
|
|
672
|
-
and the value being `time`, `space`, `both` or None which corresponds to
|
|
673
|
-
the heterogeneity of a given parameter. A value can be missing, in
|
|
674
|
-
this case there is no heterogeneity (=None). If
|
|
675
|
-
eq_params_heterogeneity is None this means there is no
|
|
676
|
-
heterogeneity for no parameters.
|
|
677
|
-
"""
|
|
678
|
-
self.u_key = u_key
|
|
679
|
-
self.p_key = p_key
|
|
680
|
-
super().__init__(eq_params_heterogeneity)
|
|
651
|
+
u_key: str = eqx.field(static=True)
|
|
652
|
+
p_key: str = eqx.field(static=True)
|
|
681
653
|
|
|
682
|
-
def
|
|
654
|
+
def equation(
|
|
655
|
+
self,
|
|
656
|
+
x: Float[Array, "dim"],
|
|
657
|
+
u_dict: Dict[str, eqx.Module],
|
|
658
|
+
params_dict: ParamsDict,
|
|
659
|
+
) -> Float[Array, "1"]:
|
|
683
660
|
r"""
|
|
684
|
-
Evaluate the dynamic loss at
|
|
661
|
+
Evaluate the dynamic loss at `x`.
|
|
685
662
|
For stability we implement the dynamic loss in log space.
|
|
686
663
|
|
|
687
664
|
Parameters
|
|
688
665
|
----------
|
|
689
666
|
x
|
|
690
|
-
A point in
|
|
667
|
+
A point in $\Omega\subset\mathbb{R}^2$
|
|
691
668
|
u_dict
|
|
692
669
|
A dictionary of PINNs. Must have the same keys as `params_dict`
|
|
693
670
|
params_dict
|
|
694
671
|
The dictionary of dictionaries of parameters of the model.
|
|
695
672
|
Typically, each sub-dictionary is a dictionary
|
|
696
|
-
with keys: `eq_params` and `nn_params
|
|
673
|
+
with keys: `eq_params` and `nn_params`, respectively the
|
|
697
674
|
differential equation parameters and the neural network parameters.
|
|
698
675
|
Must have the same keys as `u_dict`
|
|
699
676
|
"""
|
|
700
|
-
u_params =
|
|
701
|
-
p_params =
|
|
677
|
+
u_params = params_dict.extract_params(self.u_key)
|
|
678
|
+
p_params = params_dict.extract_params(self.p_key)
|
|
702
679
|
|
|
703
680
|
if isinstance(u_dict[self.u_key], PINN):
|
|
704
681
|
u = u_dict[self.u_key]
|
|
@@ -706,22 +683,22 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
706
683
|
u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(None, x, u, u_params)
|
|
707
684
|
|
|
708
685
|
p = lambda x: u_dict[self.p_key](x, p_params)
|
|
709
|
-
jac_p = jacrev(p, 0)(x) # compute the gradient
|
|
686
|
+
jac_p = jax.jacrev(p, 0)(x) # compute the gradient
|
|
710
687
|
|
|
711
688
|
vec_laplacian_u = _vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2)
|
|
712
689
|
|
|
713
690
|
# dynamic loss on x axis
|
|
714
691
|
result_x = (
|
|
715
692
|
u_dot_nabla_x_u[0]
|
|
716
|
-
+ 1 / params_dict
|
|
717
|
-
- params_dict
|
|
693
|
+
+ 1 / params_dict.eq_params["rho"] * jac_p[0, 0]
|
|
694
|
+
- params_dict.eq_params["nu"] * vec_laplacian_u[0]
|
|
718
695
|
)
|
|
719
696
|
|
|
720
697
|
# dynamic loss on y axis
|
|
721
698
|
result_y = (
|
|
722
699
|
u_dot_nabla_x_u[1]
|
|
723
|
-
+ 1 / params_dict
|
|
724
|
-
- params_dict
|
|
700
|
+
+ 1 / params_dict.eq_params["rho"] * jac_p[0, 1]
|
|
701
|
+
- params_dict.eq_params["nu"] * vec_laplacian_u[1]
|
|
725
702
|
)
|
|
726
703
|
|
|
727
704
|
# output is 2D
|
|
@@ -748,14 +725,14 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
748
725
|
# dynamic loss on x axis
|
|
749
726
|
result_x = (
|
|
750
727
|
u_dot_nabla_x_u[..., 0]
|
|
751
|
-
+ 1 / params_dict
|
|
752
|
-
- params_dict
|
|
728
|
+
+ 1 / params_dict.eq_params["rho"] * dp_dx.squeeze()
|
|
729
|
+
- params_dict.eq_params["nu"] * vec_laplacian_u[..., 0]
|
|
753
730
|
)
|
|
754
731
|
# dynamic loss on y axis
|
|
755
732
|
result_y = (
|
|
756
733
|
u_dot_nabla_x_u[..., 1]
|
|
757
|
-
+ 1 / params_dict
|
|
758
|
-
- params_dict
|
|
734
|
+
+ 1 / params_dict.eq_params["rho"] * dp_dy.squeeze()
|
|
735
|
+
- params_dict.eq_params["nu"] * vec_laplacian_u[..., 1]
|
|
759
736
|
)
|
|
760
737
|
|
|
761
738
|
# output is 2D
|