jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +534 -343
- jinns/loss/_DynamicLoss.py +152 -175
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +4 -4
- jinns/loss/_LossPDE.py +102 -74
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +150 -281
- jinns/loss/_loss_utils.py +95 -67
- jinns/loss/_operators.py +441 -186
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +47 -31
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +113 -100
- jinns/solver/_rar.py +104 -409
- jinns/solver/_solve.py +87 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +1 -4
- jinns/utils/_containers.py +3 -1
- jinns/utils/_types.py +5 -4
- jinns/utils/_utils.py +40 -12
- jinns-1.3.0.dist-info/METADATA +127 -0
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -410
- jinns/utils/_pinn.py +0 -334
- jinns/utils/_spinn.py +0 -268
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLoss.py
CHANGED
|
@@ -13,17 +13,18 @@ from jax import grad
|
|
|
13
13
|
import jax.numpy as jnp
|
|
14
14
|
import equinox as eqx
|
|
15
15
|
|
|
16
|
-
from jinns.
|
|
17
|
-
from jinns.
|
|
16
|
+
from jinns.nn._pinn import PINN
|
|
17
|
+
from jinns.nn._spinn_mlp import SPINN
|
|
18
18
|
|
|
19
|
-
from jinns.utils._utils import
|
|
19
|
+
from jinns.utils._utils import get_grid
|
|
20
20
|
from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
|
|
21
21
|
from jinns.loss._operators import (
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
22
|
+
laplacian_rev,
|
|
23
|
+
laplacian_fwd,
|
|
24
|
+
divergence_rev,
|
|
25
|
+
divergence_fwd,
|
|
26
|
+
vectorial_laplacian_rev,
|
|
27
|
+
vectorial_laplacian_fwd,
|
|
27
28
|
_u_dot_nabla_times_u_rev,
|
|
28
29
|
_u_dot_nabla_times_u_fwd,
|
|
29
30
|
)
|
|
@@ -42,24 +43,29 @@ class FisherKPP(PDENonStatio):
|
|
|
42
43
|
$$
|
|
43
44
|
\frac{\partial}{\partial t} u(t,x)=D\Delta u(t,x) + u(t,x)(r(x) - \gamma(x)u(t,x))
|
|
44
45
|
$$
|
|
46
|
+
|
|
47
|
+
Parameters
|
|
48
|
+
----------
|
|
49
|
+
dim_x : int, default=1
|
|
50
|
+
The dimension of x, the space domain. Default is 1.
|
|
45
51
|
"""
|
|
46
52
|
|
|
53
|
+
dim_x: int = eqx.field(default=1, static=True)
|
|
54
|
+
|
|
47
55
|
def equation(
|
|
48
56
|
self,
|
|
49
|
-
|
|
50
|
-
x: Float[Array, "dim"],
|
|
57
|
+
t_x: Float[Array, "1+dim"],
|
|
51
58
|
u: eqx.Module,
|
|
52
59
|
params: Params,
|
|
53
60
|
) -> Float[Array, "1"]:
|
|
54
61
|
r"""
|
|
55
|
-
Evaluate the dynamic loss at $(t,x)$.
|
|
62
|
+
Evaluate the dynamic loss at $(t, x)$.
|
|
56
63
|
|
|
57
64
|
Parameters
|
|
58
65
|
---------
|
|
59
|
-
|
|
60
|
-
A time point
|
|
61
|
-
|
|
62
|
-
A point in $\Omega$.
|
|
66
|
+
t_x
|
|
67
|
+
A jnp array containing the concatenation of a time point
|
|
68
|
+
and a point in $\Omega$
|
|
63
69
|
u
|
|
64
70
|
The PINN
|
|
65
71
|
params
|
|
@@ -70,28 +76,31 @@ class FisherKPP(PDENonStatio):
|
|
|
70
76
|
"""
|
|
71
77
|
if isinstance(u, PINN):
|
|
72
78
|
# Note that the last dim of u is nec. 1
|
|
73
|
-
u_ = lambda
|
|
79
|
+
u_ = lambda t_x: u(t_x, params)[0]
|
|
74
80
|
|
|
75
|
-
du_dt = grad(u_
|
|
81
|
+
du_dt = grad(u_)(t_x)[0]
|
|
76
82
|
|
|
77
|
-
lap =
|
|
83
|
+
lap = laplacian_rev(t_x, u, params, eq_type=u.eq_type)[..., None]
|
|
78
84
|
|
|
79
85
|
return du_dt + self.Tmax * (
|
|
80
86
|
-params.eq_params["D"] * lap
|
|
81
|
-
- u(
|
|
82
|
-
* (params.eq_params["r"] - params.eq_params["g"] * u(
|
|
87
|
+
- u(t_x, params)
|
|
88
|
+
* (params.eq_params["r"] - params.eq_params["g"] * u(t_x, params))
|
|
83
89
|
)
|
|
84
90
|
if isinstance(u, SPINN):
|
|
91
|
+
s = jnp.zeros((1, self.dim_x + 1))
|
|
92
|
+
s = s.at[0].set(1.0)
|
|
93
|
+
v0 = jnp.repeat(s, t_x.shape[0], axis=0)
|
|
85
94
|
u_tx, du_dt = jax.jvp(
|
|
86
|
-
lambda
|
|
87
|
-
(
|
|
88
|
-
(
|
|
95
|
+
lambda t_x: u(t_x, params),
|
|
96
|
+
(t_x,),
|
|
97
|
+
(v0,),
|
|
89
98
|
)
|
|
90
|
-
lap =
|
|
99
|
+
lap = laplacian_fwd(t_x, u, params, eq_type=u.eq_type)
|
|
100
|
+
|
|
91
101
|
return du_dt + self.Tmax * (
|
|
92
102
|
-params.eq_params["D"] * lap
|
|
93
|
-
- u_tx
|
|
94
|
-
* (params.eq_params["r"][..., None] - params.eq_params["g"] * u_tx)
|
|
103
|
+
- u_tx * (params.eq_params["r"] - params.eq_params["g"] * u_tx)
|
|
95
104
|
)
|
|
96
105
|
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
97
106
|
|
|
@@ -181,9 +190,9 @@ class GeneralizedLotkaVolterra(ODE):
|
|
|
181
190
|
)
|
|
182
191
|
|
|
183
192
|
|
|
184
|
-
class
|
|
193
|
+
class BurgersEquation(PDENonStatio):
|
|
185
194
|
r"""
|
|
186
|
-
Return the
|
|
195
|
+
Return the Burgers dynamic loss term (in 1 space dimension):
|
|
187
196
|
|
|
188
197
|
$$
|
|
189
198
|
\frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
|
|
@@ -207,8 +216,7 @@ class BurgerEquation(PDENonStatio):
|
|
|
207
216
|
|
|
208
217
|
def equation(
|
|
209
218
|
self,
|
|
210
|
-
|
|
211
|
-
x: Float[Array, "dim"],
|
|
219
|
+
t_x: Float[Array, "1+dim"],
|
|
212
220
|
u: eqx.Module,
|
|
213
221
|
params: Params,
|
|
214
222
|
) -> Float[Array, "1"]:
|
|
@@ -216,44 +224,51 @@ class BurgerEquation(PDENonStatio):
|
|
|
216
224
|
Evaluate the dynamic loss at :math:`(t,x)`.
|
|
217
225
|
|
|
218
226
|
Parameters
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
A time point
|
|
222
|
-
|
|
223
|
-
A point in $\Omega$
|
|
227
|
+
----------
|
|
228
|
+
t_x
|
|
229
|
+
A jnp array containing the concatenation of a time point
|
|
230
|
+
and a point in $\Omega$
|
|
224
231
|
u
|
|
225
232
|
The PINN
|
|
226
233
|
params
|
|
227
234
|
The dictionary of parameters of the model.
|
|
228
235
|
"""
|
|
229
236
|
if isinstance(u, PINN):
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
1
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
return du_dt(t, x) + self.Tmax * (
|
|
240
|
-
u(t, x, params) * du_dx(t, x) - params.eq_params["nu"] * d2u_dx2(t, x)
|
|
237
|
+
u_ = lambda t_x: jnp.squeeze(u(t_x, params)[u.slice_solution])
|
|
238
|
+
du_dtx = grad(u_)
|
|
239
|
+
d2u_dx_dtx = grad(lambda t_x: du_dtx(t_x)[1])
|
|
240
|
+
du_dtx_values = du_dtx(t_x)
|
|
241
|
+
|
|
242
|
+
return du_dtx_values[0:1] + self.Tmax * (
|
|
243
|
+
u_(t_x) * du_dtx_values[1:2]
|
|
244
|
+
- params.eq_params["nu"] * d2u_dx_dtx(t_x)[1:2]
|
|
241
245
|
)
|
|
242
246
|
|
|
243
247
|
if isinstance(u, SPINN):
|
|
244
248
|
# d=2 JVP calls are expected since we have time and x
|
|
245
249
|
# then with a batch of size B, we then have Bd JVP calls
|
|
250
|
+
v0 = jnp.repeat(jnp.array([[1.0, 0.0]]), t_x.shape[0], axis=0)
|
|
251
|
+
v1 = jnp.repeat(jnp.array([[0.0, 1.0]]), t_x.shape[0], axis=0)
|
|
246
252
|
u_tx, du_dt = jax.jvp(
|
|
247
|
-
lambda
|
|
248
|
-
(
|
|
249
|
-
(
|
|
253
|
+
lambda t_x: u(t_x, params),
|
|
254
|
+
(t_x,),
|
|
255
|
+
(v0,),
|
|
256
|
+
)
|
|
257
|
+
_, du_dx = jax.jvp(
|
|
258
|
+
lambda t_x: u(t_x, params),
|
|
259
|
+
(t_x,),
|
|
260
|
+
(v1,),
|
|
250
261
|
)
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
262
|
+
# both calls above could be condensed into the one jacfwd below
|
|
263
|
+
# u_ = lambda t_x: u(t_x, params)
|
|
264
|
+
# J = jax.jacfwd(u_)(t_x)
|
|
265
|
+
|
|
266
|
+
du_dx_fun = lambda t_x: jax.jvp(
|
|
267
|
+
lambda t_x: u(t_x, params),
|
|
268
|
+
(t_x,),
|
|
269
|
+
(v1,),
|
|
255
270
|
)[1]
|
|
256
|
-
|
|
271
|
+
_, d2u_dx2 = jax.jvp(du_dx_fun, (t_x,), (v1,))
|
|
257
272
|
# Note that ones_like(x) works because x is Bx1 !
|
|
258
273
|
return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params["nu"] * d2u_dx2)
|
|
259
274
|
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
@@ -297,8 +312,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
297
312
|
|
|
298
313
|
def equation(
|
|
299
314
|
self,
|
|
300
|
-
|
|
301
|
-
x: Float[Array, "dim"],
|
|
315
|
+
t_x: Float[Array, "1+dim"],
|
|
302
316
|
u: eqx.Module,
|
|
303
317
|
params: Params,
|
|
304
318
|
) -> Float[Array, "1"]:
|
|
@@ -307,10 +321,8 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
307
321
|
|
|
308
322
|
Parameters
|
|
309
323
|
---------
|
|
310
|
-
|
|
311
|
-
A
|
|
312
|
-
x
|
|
313
|
-
A point in $\Omega$
|
|
324
|
+
t_x
|
|
325
|
+
A collocation point in $I\times\Omega$
|
|
314
326
|
u
|
|
315
327
|
The PINN
|
|
316
328
|
params
|
|
@@ -321,114 +333,87 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
321
333
|
"""
|
|
322
334
|
if isinstance(u, PINN):
|
|
323
335
|
# Note that the last dim of u is nec. 1
|
|
324
|
-
u_ = lambda
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
lambda t, x: self.drift(t, x, params.eq_params)[0] * u_(t, x),
|
|
329
|
-
1,
|
|
330
|
-
)(
|
|
331
|
-
t, x
|
|
332
|
-
)[0:1]
|
|
333
|
-
+ grad(
|
|
334
|
-
lambda t, x: self.drift(t, x, params.eq_params)[1] * u_(t, x),
|
|
335
|
-
1,
|
|
336
|
-
)(t, x)[1:2]
|
|
337
|
-
)
|
|
336
|
+
u_ = lambda t_x: u(t_x, params)[0]
|
|
337
|
+
|
|
338
|
+
order_1_fun = lambda t_x: self.drift(t_x[1:], params.eq_params) * u_(t_x)
|
|
339
|
+
grad_order_1 = jnp.trace(jax.jacrev(order_1_fun)(t_x)[..., 1:])[None]
|
|
338
340
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
lambda t, x: u_(t, x)
|
|
351
|
-
* self.diffusion(t, x, params.eq_params)[1, 0],
|
|
352
|
-
1,
|
|
353
|
-
)(t, x)[1],
|
|
354
|
-
1,
|
|
355
|
-
)(t, x)[0:1]
|
|
356
|
-
+ grad(
|
|
357
|
-
lambda t, x: grad(
|
|
358
|
-
lambda t, x: u_(t, x)
|
|
359
|
-
* self.diffusion(t, x, params.eq_params)[0, 1],
|
|
360
|
-
1,
|
|
361
|
-
)(t, x)[0],
|
|
362
|
-
1,
|
|
363
|
-
)(t, x)[1:2]
|
|
364
|
-
+ grad(
|
|
365
|
-
lambda t, x: grad(
|
|
366
|
-
lambda t, x: u_(t, x)
|
|
367
|
-
* self.diffusion(t, x, params.eq_params)[1, 1],
|
|
368
|
-
1,
|
|
369
|
-
)(t, x)[1],
|
|
370
|
-
1,
|
|
371
|
-
)(t, x)[1:2]
|
|
341
|
+
order_2_fun = lambda t_x: self.diffusion(t_x[1:], params.eq_params) * u_(
|
|
342
|
+
t_x
|
|
343
|
+
)
|
|
344
|
+
grad_order_2_fun = lambda t_x: jax.jacrev(order_2_fun)(t_x)[..., 1:]
|
|
345
|
+
grad_grad_order_2 = (
|
|
346
|
+
jnp.trace(
|
|
347
|
+
jax.jacrev(lambda t_x: grad_order_2_fun(t_x)[0, :, 0])(t_x)[..., 1:]
|
|
348
|
+
)[None]
|
|
349
|
+
+ jnp.trace(
|
|
350
|
+
jax.jacrev(lambda t_x: grad_order_2_fun(t_x)[1, :, 1])(t_x)[..., 1:]
|
|
351
|
+
)[None]
|
|
372
352
|
)
|
|
353
|
+
# This is be a condensed form of the explicit which is less efficient
|
|
354
|
+
# since 4 jacrev are called (as compared to 2)
|
|
355
|
+
# grad_order_2_fun = lambda t_x, i, j: jax.jacrev(order_2_fun)(t_x)[i, j, 1:]
|
|
356
|
+
# grad_grad_order_2 = (
|
|
357
|
+
# jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 0, 0))(t_x)[0, 1] +
|
|
358
|
+
# jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 1, 0))(t_x)[1, 1] +
|
|
359
|
+
# jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 0, 1))(t_x)[0, 2] +
|
|
360
|
+
# jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 1, 1))(t_x)[1, 2]
|
|
361
|
+
# )[None]
|
|
373
362
|
|
|
374
|
-
du_dt = grad(u_
|
|
363
|
+
du_dt = grad(u_)(t_x)[0:1]
|
|
375
364
|
|
|
376
|
-
return -du_dt + self.Tmax * (-
|
|
365
|
+
return -du_dt + self.Tmax * (-grad_order_1 + grad_grad_order_2)
|
|
377
366
|
|
|
378
367
|
if isinstance(u, SPINN):
|
|
379
|
-
|
|
368
|
+
v0 = jnp.repeat(jnp.array([[1.0, 0.0, 0.0]]), t_x.shape[0], axis=0)
|
|
380
369
|
_, du_dt = jax.jvp(
|
|
381
|
-
lambda
|
|
382
|
-
(
|
|
383
|
-
(
|
|
370
|
+
lambda t_x: u(t_x, params),
|
|
371
|
+
(t_x,),
|
|
372
|
+
(v0,),
|
|
384
373
|
)
|
|
385
374
|
|
|
386
375
|
# in forward AD we do not have the results for all the input
|
|
387
376
|
# dimension at once (as it is the case with grad), we then write
|
|
388
377
|
# two jvp calls
|
|
389
|
-
|
|
390
|
-
|
|
378
|
+
v1 = jnp.repeat(jnp.array([[0.0, 1.0, 0.0]]), t_x.shape[0], axis=0)
|
|
379
|
+
v2 = jnp.repeat(jnp.array([[0.0, 0.0, 1.0]]), t_x.shape[0], axis=0)
|
|
391
380
|
_, dau_dx1 = jax.jvp(
|
|
392
|
-
lambda
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
(
|
|
381
|
+
lambda t_x: self.drift(get_grid(t_x[:, 1:]), params.eq_params)[
|
|
382
|
+
None, ..., 0:1
|
|
383
|
+
]
|
|
384
|
+
* u(t_x, params),
|
|
385
|
+
(t_x,),
|
|
386
|
+
(v1,),
|
|
396
387
|
)
|
|
397
388
|
_, dau_dx2 = jax.jvp(
|
|
398
|
-
lambda
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
(
|
|
389
|
+
lambda t_x: self.drift(get_grid(t_x[:, 1:]), params.eq_params)[
|
|
390
|
+
None, ..., 1:2
|
|
391
|
+
]
|
|
392
|
+
* u(t_x, params),
|
|
393
|
+
(t_x,),
|
|
394
|
+
(v2,),
|
|
402
395
|
)
|
|
403
396
|
|
|
404
|
-
dsu_dx1_fun = lambda
|
|
405
|
-
lambda
|
|
406
|
-
|
|
407
|
-
]
|
|
408
|
-
* u(
|
|
409
|
-
(
|
|
410
|
-
(
|
|
397
|
+
dsu_dx1_fun = lambda t_x, i, j: jax.jvp(
|
|
398
|
+
lambda t_x: self.diffusion(
|
|
399
|
+
get_grid(t_x[:, 1:]), params.eq_params, i, j
|
|
400
|
+
)[None, None, None, None]
|
|
401
|
+
* u(t_x, params),
|
|
402
|
+
(t_x,),
|
|
403
|
+
(v1,),
|
|
411
404
|
)[1]
|
|
412
|
-
dsu_dx2_fun = lambda
|
|
413
|
-
lambda
|
|
414
|
-
|
|
415
|
-
]
|
|
416
|
-
* u(
|
|
417
|
-
(
|
|
418
|
-
(
|
|
405
|
+
dsu_dx2_fun = lambda t_x, i, j: jax.jvp(
|
|
406
|
+
lambda t_x: self.diffusion(
|
|
407
|
+
get_grid(t_x[:, 1:]), params.eq_params, i, j
|
|
408
|
+
)[None, None, None, None]
|
|
409
|
+
* u(t_x, params),
|
|
410
|
+
(t_x,),
|
|
411
|
+
(v2,),
|
|
419
412
|
)[1]
|
|
420
|
-
_, d2su_dx12 = jax.jvp(
|
|
421
|
-
|
|
422
|
-
)
|
|
423
|
-
_,
|
|
424
|
-
lambda x: dsu_dx1_fun(x, 0, 1), (x,), (tangent_vec_1,)
|
|
425
|
-
)
|
|
426
|
-
_, d2su_dx22 = jax.jvp(
|
|
427
|
-
lambda x: dsu_dx2_fun(x, 1, 1), (x,), (tangent_vec_1,)
|
|
428
|
-
)
|
|
429
|
-
_, d2su_dx2dx1 = jax.jvp(
|
|
430
|
-
lambda x: dsu_dx2_fun(x, 1, 0), (x,), (tangent_vec_0,)
|
|
431
|
-
)
|
|
413
|
+
_, d2su_dx12 = jax.jvp(lambda t_x: dsu_dx1_fun(t_x, 0, 0), (t_x,), (v1,))
|
|
414
|
+
_, d2su_dx1dx2 = jax.jvp(lambda t_x: dsu_dx1_fun(t_x, 0, 1), (t_x,), (v2,))
|
|
415
|
+
_, d2su_dx22 = jax.jvp(lambda t_x: dsu_dx2_fun(t_x, 1, 1), (t_x,), (v2,))
|
|
416
|
+
_, d2su_dx2dx1 = jax.jvp(lambda t_x: dsu_dx2_fun(t_x, 1, 0), (t_x,), (v1,))
|
|
432
417
|
|
|
433
418
|
return -du_dt + self.Tmax * (
|
|
434
419
|
-(dau_dx1 + dau_dx2)
|
|
@@ -474,14 +459,12 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
474
459
|
heterogeneity for no parameters.
|
|
475
460
|
"""
|
|
476
461
|
|
|
477
|
-
def drift(self,
|
|
462
|
+
def drift(self, x, eq_params):
|
|
478
463
|
r"""
|
|
479
464
|
Return the drift term
|
|
480
465
|
|
|
481
466
|
Parameters
|
|
482
467
|
----------
|
|
483
|
-
t
|
|
484
|
-
A time point
|
|
485
468
|
x
|
|
486
469
|
A point in $\Omega$
|
|
487
470
|
eq_params
|
|
@@ -489,15 +472,13 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
489
472
|
"""
|
|
490
473
|
return eq_params["alpha"] * (eq_params["mu"] - x)
|
|
491
474
|
|
|
492
|
-
def sigma_mat(self,
|
|
475
|
+
def sigma_mat(self, x, eq_params):
|
|
493
476
|
r"""
|
|
494
477
|
Return the square root of the diffusion tensor in the sense of the outer
|
|
495
478
|
product used to create the diffusion tensor
|
|
496
479
|
|
|
497
480
|
Parameters
|
|
498
481
|
----------
|
|
499
|
-
t
|
|
500
|
-
A time point
|
|
501
482
|
x
|
|
502
483
|
A point in $\Omega$
|
|
503
484
|
eq_params
|
|
@@ -506,15 +487,13 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
506
487
|
|
|
507
488
|
return jnp.diag(eq_params["sigma"])
|
|
508
489
|
|
|
509
|
-
def diffusion(self,
|
|
490
|
+
def diffusion(self, x, eq_params, i=None, j=None):
|
|
510
491
|
r"""
|
|
511
492
|
Return the computation of the diffusion tensor term in 2D (or
|
|
512
493
|
higher)
|
|
513
494
|
|
|
514
495
|
Parameters
|
|
515
496
|
----------
|
|
516
|
-
t
|
|
517
|
-
A time point
|
|
518
497
|
x
|
|
519
498
|
A point in $\Omega$
|
|
520
499
|
eq_params
|
|
@@ -523,14 +502,14 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
523
502
|
if i is None or j is None:
|
|
524
503
|
return 0.5 * (
|
|
525
504
|
jnp.matmul(
|
|
526
|
-
self.sigma_mat(
|
|
527
|
-
jnp.transpose(self.sigma_mat(
|
|
505
|
+
self.sigma_mat(x, eq_params),
|
|
506
|
+
jnp.transpose(self.sigma_mat(x, eq_params)),
|
|
528
507
|
)
|
|
529
508
|
)
|
|
530
509
|
return 0.5 * (
|
|
531
510
|
jnp.matmul(
|
|
532
|
-
self.sigma_mat(
|
|
533
|
-
jnp.transpose(self.sigma_mat(
|
|
511
|
+
self.sigma_mat(x, eq_params),
|
|
512
|
+
jnp.transpose(self.sigma_mat(x, eq_params)),
|
|
534
513
|
)[i, j]
|
|
535
514
|
)
|
|
536
515
|
|
|
@@ -591,12 +570,12 @@ class MassConservation2DStatio(PDEStatio):
|
|
|
591
570
|
if isinstance(u_dict[self.nn_key], PINN):
|
|
592
571
|
u = u_dict[self.nn_key]
|
|
593
572
|
|
|
594
|
-
return
|
|
573
|
+
return divergence_rev(x, u, params)[..., None]
|
|
595
574
|
|
|
596
575
|
if isinstance(u_dict[self.nn_key], SPINN):
|
|
597
576
|
u = u_dict[self.nn_key]
|
|
598
577
|
|
|
599
|
-
return
|
|
578
|
+
return divergence_fwd(x, u, params)[..., None]
|
|
600
579
|
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
601
580
|
|
|
602
581
|
|
|
@@ -614,12 +593,14 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
614
593
|
|
|
615
594
|
|
|
616
595
|
$$
|
|
617
|
-
\begin{pmatrix}u_x\frac{\partial}{\partial x} u_x +
|
|
596
|
+
\begin{pmatrix}u_x\frac{\partial}{\partial x} u_x +
|
|
597
|
+
u_y\frac{\partial}{\partial y} u_x, \\
|
|
618
598
|
u_x\frac{\partial}{\partial x} u_y + u_y\frac{\partial}{\partial y} u_y \end{pmatrix} +
|
|
619
|
-
\frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p \\ \frac{\partial}{\partial y} p \end{pmatrix}
|
|
599
|
+
\frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p, \\ \frac{\partial}{\partial y} p \end{pmatrix}
|
|
620
600
|
- \theta
|
|
621
601
|
\begin{pmatrix}
|
|
622
|
-
\frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2}
|
|
602
|
+
\frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2}
|
|
603
|
+
u_x, \\
|
|
623
604
|
\frac{\partial^2}{\partial x^2} u_y + \frac{\partial^2}{\partial y^2} u_y
|
|
624
605
|
\end{pmatrix} = 0,
|
|
625
606
|
$$
|
|
@@ -680,12 +661,12 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
680
661
|
if isinstance(u_dict[self.u_key], PINN):
|
|
681
662
|
u = u_dict[self.u_key]
|
|
682
663
|
|
|
683
|
-
u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(
|
|
664
|
+
u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(x, u, u_params)
|
|
684
665
|
|
|
685
666
|
p = lambda x: u_dict[self.p_key](x, p_params)
|
|
686
667
|
jac_p = jax.jacrev(p, 0)(x) # compute the gradient
|
|
687
668
|
|
|
688
|
-
vec_laplacian_u =
|
|
669
|
+
vec_laplacian_u = vectorial_laplacian_rev(x, u, u_params, dim_out=2)
|
|
689
670
|
|
|
690
671
|
# dynamic loss on x axis
|
|
691
672
|
result_x = (
|
|
@@ -707,7 +688,7 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
707
688
|
if isinstance(u_dict[self.u_key], SPINN):
|
|
708
689
|
u = u_dict[self.u_key]
|
|
709
690
|
|
|
710
|
-
u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(
|
|
691
|
+
u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(x, u, u_params)
|
|
711
692
|
|
|
712
693
|
p = lambda x: u_dict[self.p_key](x, p_params)
|
|
713
694
|
|
|
@@ -716,11 +697,7 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
716
697
|
tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
|
|
717
698
|
_, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))
|
|
718
699
|
|
|
719
|
-
vec_laplacian_u =
|
|
720
|
-
_vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2),
|
|
721
|
-
source=0,
|
|
722
|
-
destination=-1,
|
|
723
|
-
)
|
|
700
|
+
vec_laplacian_u = vectorial_laplacian_fwd(x, u, u_params, dim_out=2)
|
|
724
701
|
|
|
725
702
|
# dynamic loss on x axis
|
|
726
703
|
result_x = (
|