jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_operators.py
CHANGED
|
@@ -2,24 +2,42 @@
|
|
|
2
2
|
Implements diverse operators for dynamic losses
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
from typing import Literal, cast, Callable
|
|
6
10
|
|
|
7
11
|
import jax
|
|
8
12
|
import jax.numpy as jnp
|
|
9
13
|
from jax import grad
|
|
10
|
-
import equinox as eqx
|
|
11
14
|
from jaxtyping import Float, Array
|
|
12
|
-
from jinns.utils._pinn import PINN
|
|
13
|
-
from jinns.utils._spinn import SPINN
|
|
14
15
|
from jinns.parameters._params import Params
|
|
16
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _get_eq_type(
|
|
20
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
21
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
|
|
22
|
+
) -> Literal["nonstatio_PDE", "statio_PDE"]:
|
|
23
|
+
"""
|
|
24
|
+
But we filter out ODE from eq_type because we only have operators that do
|
|
25
|
+
not work with ODEs so far
|
|
26
|
+
"""
|
|
27
|
+
if isinstance(u, AbstractPINN):
|
|
28
|
+
assert u.eq_type != "ODE", "Cannot compute the operator for ODE PINNs"
|
|
29
|
+
return u.eq_type
|
|
30
|
+
if eq_type is None:
|
|
31
|
+
raise ValueError("eq_type could not be set!")
|
|
32
|
+
return eq_type
|
|
15
33
|
|
|
16
34
|
|
|
17
35
|
def divergence_rev(
|
|
18
|
-
inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
19
|
-
u:
|
|
20
|
-
params: Params,
|
|
21
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
22
|
-
) ->
|
|
36
|
+
inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
37
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
38
|
+
params: Params[Array],
|
|
39
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
40
|
+
) -> Float[Array, " "]:
|
|
23
41
|
r"""
|
|
24
42
|
Compute the divergence of a vector field $\mathbf{u}$, i.e.,
|
|
25
43
|
$\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
|
|
@@ -43,13 +61,7 @@ def divergence_rev(
|
|
|
43
61
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
44
62
|
a function, we must set this attribute.
|
|
45
63
|
"""
|
|
46
|
-
|
|
47
|
-
try:
|
|
48
|
-
eq_type = u.eq_type
|
|
49
|
-
except AttributeError:
|
|
50
|
-
pass # use the value passed as argument
|
|
51
|
-
if eq_type is None:
|
|
52
|
-
raise ValueError("eq_type could not be set!")
|
|
64
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
53
65
|
|
|
54
66
|
def scan_fun(_, i):
|
|
55
67
|
if eq_type == "nonstatio_PDE":
|
|
@@ -72,11 +84,11 @@ def divergence_rev(
|
|
|
72
84
|
|
|
73
85
|
|
|
74
86
|
def divergence_fwd(
|
|
75
|
-
inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
|
|
76
|
-
u:
|
|
77
|
-
params: Params,
|
|
78
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
79
|
-
) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
|
|
87
|
+
inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
|
|
88
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
89
|
+
params: Params[Array],
|
|
90
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
91
|
+
) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
|
|
80
92
|
r"""
|
|
81
93
|
Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
|
|
82
94
|
$\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
|
|
@@ -105,13 +117,7 @@ def divergence_fwd(
|
|
|
105
117
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
106
118
|
a function, we must set this attribute.
|
|
107
119
|
"""
|
|
108
|
-
|
|
109
|
-
try:
|
|
110
|
-
eq_type = u.eq_type
|
|
111
|
-
except AttributeError:
|
|
112
|
-
pass # use the value passed as argument
|
|
113
|
-
if eq_type is None:
|
|
114
|
-
raise ValueError("eq_type could not be set!")
|
|
120
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
115
121
|
|
|
116
122
|
def scan_fun(_, i):
|
|
117
123
|
if eq_type == "nonstatio_PDE":
|
|
@@ -144,12 +150,12 @@ def divergence_fwd(
|
|
|
144
150
|
|
|
145
151
|
|
|
146
152
|
def laplacian_rev(
|
|
147
|
-
inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
148
|
-
u:
|
|
149
|
-
params: Params,
|
|
153
|
+
inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
154
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
155
|
+
params: Params[Array],
|
|
150
156
|
method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
|
|
151
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
152
|
-
) ->
|
|
157
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
158
|
+
) -> Float[Array, " "]:
|
|
153
159
|
r"""
|
|
154
160
|
Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
|
|
155
161
|
to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
|
|
@@ -182,13 +188,7 @@ def laplacian_rev(
|
|
|
182
188
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
183
189
|
a function, we must set this attribute.
|
|
184
190
|
"""
|
|
185
|
-
|
|
186
|
-
try:
|
|
187
|
-
eq_type = u.eq_type
|
|
188
|
-
except AttributeError:
|
|
189
|
-
pass # use the value passed as argument
|
|
190
|
-
if eq_type is None:
|
|
191
|
-
raise ValueError("eq_type could not be set!")
|
|
191
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
192
192
|
|
|
193
193
|
if method == "trace_hessian_x":
|
|
194
194
|
# NOTE we afford a concatenate here to avoid computing Hessian elements for
|
|
@@ -228,16 +228,12 @@ def laplacian_rev(
|
|
|
228
228
|
if eq_type == "nonstatio_PDE":
|
|
229
229
|
d2u_dxi2 = grad(
|
|
230
230
|
lambda inputs: grad(u_)(inputs)[1 + i],
|
|
231
|
-
)(
|
|
232
|
-
inputs
|
|
233
|
-
)[1 + i]
|
|
231
|
+
)(inputs)[1 + i]
|
|
234
232
|
else:
|
|
235
233
|
d2u_dxi2 = grad(
|
|
236
234
|
lambda inputs: grad(u_, 0)(inputs)[i],
|
|
237
235
|
0,
|
|
238
|
-
)(
|
|
239
|
-
inputs
|
|
240
|
-
)[i]
|
|
236
|
+
)(inputs)[i]
|
|
241
237
|
return _, d2u_dxi2
|
|
242
238
|
|
|
243
239
|
if eq_type == "nonstatio_PDE":
|
|
@@ -253,12 +249,12 @@ def laplacian_rev(
|
|
|
253
249
|
|
|
254
250
|
|
|
255
251
|
def laplacian_fwd(
|
|
256
|
-
inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
|
|
257
|
-
u:
|
|
258
|
-
params: Params,
|
|
252
|
+
inputs: Float[Array, " batch_size 1+dim"] | Float[Array, " batch_size dim"],
|
|
253
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
254
|
+
params: Params[Array],
|
|
259
255
|
method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
|
|
260
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
261
|
-
) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
|
|
256
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
257
|
+
) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
|
|
262
258
|
r"""
|
|
263
259
|
Compute the Laplacian of a **batched** scalar field $u$
|
|
264
260
|
from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
|
|
@@ -301,13 +297,7 @@ def laplacian_fwd(
|
|
|
301
297
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
302
298
|
a function, we must set this attribute.
|
|
303
299
|
"""
|
|
304
|
-
|
|
305
|
-
try:
|
|
306
|
-
eq_type = u.eq_type
|
|
307
|
-
except AttributeError:
|
|
308
|
-
pass # use the value passed as argument
|
|
309
|
-
if eq_type is None:
|
|
310
|
-
raise ValueError("eq_type could not be set!")
|
|
300
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
311
301
|
|
|
312
302
|
if method == "loop":
|
|
313
303
|
|
|
@@ -400,11 +390,12 @@ def laplacian_fwd(
|
|
|
400
390
|
|
|
401
391
|
|
|
402
392
|
def vectorial_laplacian_rev(
|
|
403
|
-
inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
404
|
-
u:
|
|
405
|
-
params: Params,
|
|
406
|
-
dim_out: int = None,
|
|
407
|
-
|
|
393
|
+
inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
394
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
395
|
+
params: Params[Array],
|
|
396
|
+
dim_out: int | None = None,
|
|
397
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
398
|
+
) -> Float[Array, " dim_out"]:
|
|
408
399
|
r"""
|
|
409
400
|
Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
|
|
410
401
|
$\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
|
|
@@ -428,7 +419,12 @@ def vectorial_laplacian_rev(
|
|
|
428
419
|
dim_out
|
|
429
420
|
Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
|
|
430
421
|
provided if it is different than that of $\mathrm{inputs}$.
|
|
422
|
+
eq_type
|
|
423
|
+
whether we consider a stationary or non stationary PINN. Most often we
|
|
424
|
+
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
425
|
+
a function, we must set this attribute.
|
|
431
426
|
"""
|
|
427
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
432
428
|
if dim_out is None:
|
|
433
429
|
dim_out = inputs.shape[0]
|
|
434
430
|
|
|
@@ -437,7 +433,9 @@ def vectorial_laplacian_rev(
|
|
|
437
433
|
# each of these components
|
|
438
434
|
# Note the jnp.expand_dims call
|
|
439
435
|
uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
|
|
440
|
-
lap_on_j = laplacian_rev(
|
|
436
|
+
lap_on_j = laplacian_rev(
|
|
437
|
+
inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
|
|
438
|
+
)
|
|
441
439
|
|
|
442
440
|
return _, lap_on_j
|
|
443
441
|
|
|
@@ -446,11 +444,12 @@ def vectorial_laplacian_rev(
|
|
|
446
444
|
|
|
447
445
|
|
|
448
446
|
def vectorial_laplacian_fwd(
|
|
449
|
-
inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
|
|
450
|
-
u:
|
|
451
|
-
params: Params,
|
|
452
|
-
dim_out: int = None,
|
|
453
|
-
|
|
447
|
+
inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
|
|
448
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
449
|
+
params: Params[Array],
|
|
450
|
+
dim_out: int | None = None,
|
|
451
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
452
|
+
) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
|
|
454
453
|
r"""
|
|
455
454
|
Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
|
|
456
455
|
`u` is a SPINN, in this case, it corresponds to a vector
|
|
@@ -476,7 +475,12 @@ def vectorial_laplacian_fwd(
|
|
|
476
475
|
dim_out
|
|
477
476
|
the value of the output dimension ($n$ in the formula above). Must be
|
|
478
477
|
set if different from $d$.
|
|
478
|
+
eq_type
|
|
479
|
+
whether we consider a stationary or non stationary PINN. Most often we
|
|
480
|
+
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
481
|
+
a function, we must set this attribute.
|
|
479
482
|
"""
|
|
483
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
480
484
|
if dim_out is None:
|
|
481
485
|
dim_out = inputs.shape[0]
|
|
482
486
|
|
|
@@ -485,7 +489,9 @@ def vectorial_laplacian_fwd(
|
|
|
485
489
|
# each of these components
|
|
486
490
|
# Note the expand_dims
|
|
487
491
|
uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
|
|
488
|
-
lap_on_j = laplacian_fwd(
|
|
492
|
+
lap_on_j = laplacian_fwd(
|
|
493
|
+
inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
|
|
494
|
+
)
|
|
489
495
|
|
|
490
496
|
return _, lap_on_j
|
|
491
497
|
|
|
@@ -494,8 +500,10 @@ def vectorial_laplacian_fwd(
|
|
|
494
500
|
|
|
495
501
|
|
|
496
502
|
def _u_dot_nabla_times_u_rev(
|
|
497
|
-
x: Float[Array, "2"],
|
|
498
|
-
|
|
503
|
+
x: Float[Array, " 2"],
|
|
504
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
505
|
+
params: Params[Array],
|
|
506
|
+
) -> Float[Array, " 2"]:
|
|
499
507
|
r"""
|
|
500
508
|
Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
|
|
501
509
|
$\mathbf{x}$ of arbitrary
|
|
@@ -524,10 +532,10 @@ def _u_dot_nabla_times_u_rev(
|
|
|
524
532
|
|
|
525
533
|
|
|
526
534
|
def _u_dot_nabla_times_u_fwd(
|
|
527
|
-
x: Float[Array, "batch_size 2"],
|
|
528
|
-
u:
|
|
529
|
-
params: Params,
|
|
530
|
-
) -> Float[Array, "batch_size batch_size 2"]:
|
|
535
|
+
x: Float[Array, " batch_size 2"],
|
|
536
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
537
|
+
params: Params[Array],
|
|
538
|
+
) -> Float[Array, " batch_size batch_size 2"]:
|
|
531
539
|
r"""
|
|
532
540
|
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
|
|
533
541
|
:math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
|
jinns/nn/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from ._save_load import save_pinn, load_pinn
|
|
2
|
+
from ._abstract_pinn import AbstractPINN
|
|
3
|
+
from ._pinn import PINN
|
|
4
|
+
from ._spinn import SPINN
|
|
5
|
+
from ._mlp import PINN_MLP, MLP
|
|
6
|
+
from ._spinn_mlp import SPINN_MLP, SMLP
|
|
7
|
+
from ._hyperpinn import HyperPINN
|
|
8
|
+
from ._ppinn import PPINN_MLP
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"save_pinn",
|
|
12
|
+
"load_pinn",
|
|
13
|
+
"AbstractPINN",
|
|
14
|
+
"PINN",
|
|
15
|
+
"SPINN",
|
|
16
|
+
"PINN_MLP",
|
|
17
|
+
"MLP",
|
|
18
|
+
"SPINN_MLP",
|
|
19
|
+
"SMLP",
|
|
20
|
+
"HyperPINN",
|
|
21
|
+
"PPINN_MLP",
|
|
22
|
+
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Literal, Any
|
|
3
|
+
from jaxtyping import Array
|
|
4
|
+
import equinox as eqx
|
|
5
|
+
|
|
6
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
7
|
+
from jinns.parameters._params import Params
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AbstractPINN(eqx.Module):
|
|
11
|
+
"""
|
|
12
|
+
Basically just a way to add a __call__ to an eqx.Module.
|
|
13
|
+
The way to go for correct type hints apparently
|
|
14
|
+
https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
|
|
18
|
+
|
|
19
|
+
@abc.abstractmethod
|
|
20
|
+
@_PyTree_to_Params
|
|
21
|
+
def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
|
|
22
|
+
pass
|