jinns-1.3.0-py3-none-any.whl → jinns-1.4.0-py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_operators.py
CHANGED
@@ -2,22 +2,42 @@
 Implements diverse operators for dynamic losses
 """
 
-from
+from __future__ import (
+    annotations,
+)
+
+from typing import Literal, cast, Callable
 
 import jax
 import jax.numpy as jnp
 from jax import grad
-import equinox as eqx
 from jaxtyping import Float, Array
 from jinns.parameters._params import Params
+from jinns.nn._abstract_pinn import AbstractPINN
+
+
+def _get_eq_type(
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
+) -> Literal["nonstatio_PDE", "statio_PDE"]:
+    """
+    But we filter out ODE from eq_type because we only have operators that does
+    not work with ODEs so far
+    """
+    if isinstance(u, AbstractPINN):
+        assert u.eq_type != "ODE", "Cannot compute the operator for ODE PINNs"
+        return u.eq_type
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+    return eq_type
 
 
 def divergence_rev(
-    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
-    u:
-    params: Params,
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) ->
+    inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " "]:
     r"""
     Compute the divergence of a vector field $\mathbf{u}$, i.e.,
     $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
@@ -41,13 +61,7 @@ def divergence_rev(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     def scan_fun(_, i):
         if eq_type == "nonstatio_PDE":
@@ -70,11 +84,11 @@ def divergence_fwd(
 
 
 def divergence_fwd(
-    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
-    u:
-    params: Params,
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
+    inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
     r"""
     Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
     $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
@@ -103,13 +117,7 @@ def divergence_fwd(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     def scan_fun(_, i):
         if eq_type == "nonstatio_PDE":
@@ -142,12 +150,12 @@ def divergence_fwd(
 
 
 def laplacian_rev(
-    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
-    u:
-    params: Params,
+    inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
     method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) ->
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " "]:
     r"""
     Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
     to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
@@ -180,13 +188,7 @@ def laplacian_rev(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     if method == "trace_hessian_x":
         # NOTE we afford a concatenate here to avoid computing Hessian elements for
@@ -226,16 +228,12 @@ def laplacian_rev(
         if eq_type == "nonstatio_PDE":
             d2u_dxi2 = grad(
                 lambda inputs: grad(u_)(inputs)[1 + i],
-            )(
-                inputs
-            )[1 + i]
+            )(inputs)[1 + i]
         else:
             d2u_dxi2 = grad(
                 lambda inputs: grad(u_, 0)(inputs)[i],
                 0,
-            )(
-                inputs
-            )[i]
+            )(inputs)[i]
         return _, d2u_dxi2
 
     if eq_type == "nonstatio_PDE":
@@ -251,12 +249,12 @@ def laplacian_rev(
 
 
 def laplacian_fwd(
-    inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
-    u:
-    params: Params,
+    inputs: Float[Array, " batch_size 1+dim"] | Float[Array, " batch_size dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
     method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
     r"""
     Compute the Laplacian of a **batched** scalar field $u$
     from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
@@ -299,13 +297,7 @@ def laplacian_fwd(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     if method == "loop":
 
@@ -398,11 +390,12 @@ def laplacian_fwd(
 
 
 def vectorial_laplacian_rev(
-    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
-    u:
-    params: Params,
-    dim_out: int = None,
-
+    inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    dim_out: int | None = None,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " dim_out"]:
     r"""
     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
     $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
@@ -426,7 +419,12 @@ def vectorial_laplacian_rev(
     dim_out
         Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
         provided if it is different than that of $\mathrm{inputs}$.
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
+    eq_type = _get_eq_type(u, eq_type)
     if dim_out is None:
         dim_out = inputs.shape[0]
 
@@ -435,7 +433,9 @@ def vectorial_laplacian_rev(
         # each of these components
         # Note the jnp.expand_dims call
         uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
-        lap_on_j = laplacian_rev(
+        lap_on_j = laplacian_rev(
+            inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
+        )
 
         return _, lap_on_j
 
@@ -444,11 +444,12 @@ def vectorial_laplacian_rev(
 
 
 def vectorial_laplacian_fwd(
-    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
-    u:
-    params: Params,
-    dim_out: int = None,
-
+    inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    dim_out: int | None = None,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
     r"""
     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
     `u` is a SPINN, in this case, it corresponds to a vector
@@ -474,7 +475,12 @@ def vectorial_laplacian_fwd(
     dim_out
        the value of the output dimension ($n$ in the formula above). Must be
        set if different from $d$.
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
+    eq_type = _get_eq_type(u, eq_type)
    if dim_out is None:
        dim_out = inputs.shape[0]
 
@@ -483,7 +489,9 @@ def vectorial_laplacian_fwd(
         # each of these components
         # Note the expand_dims
         uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
-        lap_on_j = laplacian_fwd(
+        lap_on_j = laplacian_fwd(
+            inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
+        )
 
         return _, lap_on_j
 
@@ -492,8 +500,10 @@ def vectorial_laplacian_fwd(
 
 
 def _u_dot_nabla_times_u_rev(
-    x: Float[Array, "2"],
-
+    x: Float[Array, " 2"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+) -> Float[Array, " 2"]:
     r"""
     Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
     $\mathbf{x}$ of arbitrary
@@ -522,10 +532,10 @@ def _u_dot_nabla_times_u_rev(
 
 
 def _u_dot_nabla_times_u_fwd(
-    x: Float[Array, "batch_size 2"],
-    u:
-    params: Params,
-) -> Float[Array, "batch_size batch_size 2"]:
+    x: Float[Array, " batch_size 2"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+) -> Float[Array, " batch_size batch_size 2"]:
     r"""
     Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
     :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
jinns/nn/__init__.py
CHANGED
@@ -1,7 +1,22 @@
 from ._save_load import save_pinn, load_pinn
+from ._abstract_pinn import AbstractPINN
 from ._pinn import PINN
 from ._spinn import SPINN
 from ._mlp import PINN_MLP, MLP
 from ._spinn_mlp import SPINN_MLP, SMLP
 from ._hyperpinn import HyperPINN
 from ._ppinn import PPINN_MLP
+
+__all__ = [
+    "save_pinn",
+    "load_pinn",
+    "AbstractPINN",
+    "PINN",
+    "SPINN",
+    "PINN_MLP",
+    "MLP",
+    "SPINN_MLP",
+    "SMLP",
+    "HyperPINN",
+    "PPINN_MLP",
+]
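
With `__all__` spelled out, the public surface of `jinns.nn` is explicit: star-imports and re-export checkers only see the listed names, and `AbstractPINN` becomes importable from the subpackage root. A minimal usage line, assuming jinns 1.4.0 is installed:

# AbstractPINN is newly re-exported alongside the concrete network classes
from jinns.nn import AbstractPINN, PINN_MLP, HyperPINN
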
jinns/nn/_abstract_pinn.py
ADDED
@@ -0,0 +1,22 @@
+import abc
+from typing import Literal, Any
+from jaxtyping import Array
+import equinox as eqx
+
+from jinns.nn._utils import _PyTree_to_Params
+from jinns.parameters._params import Params
+
+
+class AbstractPINN(eqx.Module):
+    """
+    Basically just a way to add a __call__ to an eqx.Module.
+    The way to go for correct type hints apparently
+    https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
+    """
+
+    eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
+
+    @abc.abstractmethod
+    @_PyTree_to_Params
+    def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
+        pass
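
A sketch of how this new abstract base is meant to be filled in, assuming jinns 1.4.0; `ZeroPINN` is a hypothetical toy, not part of jinns. The `eqx.AbstractVar` annotation forces every concrete subclass to declare `eq_type`, which is exactly the attribute `_get_eq_type` in `jinns/loss/_operators.py` inspects:

from typing import Literal
from jaxtyping import Array
from jinns.nn import AbstractPINN
from jinns.parameters import Params

class ZeroPINN(AbstractPINN):
    # concrete value satisfies the eqx.AbstractVar declared on AbstractPINN
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = "statio_PDE"

    def __call__(self, inputs: Array, params: Params[Array], *args, **kwargs) -> Array:
        # a toy field that is zero everywhere; a real PINN evaluates its layers here
        return inputs * 0.0
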
jinns/nn/_hyperpinn.py
CHANGED
@@ -3,9 +3,11 @@ Implements utility function to create HyperPINNs
 https://arxiv.org/pdf/2111.01008.pdf
 """
 
+from __future__ import annotations
+
 import warnings
 from dataclasses import InitVar
-from typing import Callable, Literal, Self, Union, Any
+from typing import Callable, Literal, Self, Union, Any, cast, overload
 from math import prod
 import jax
 import jax.numpy as jnp
@@ -15,12 +17,13 @@ import numpy as onp
 
 from jinns.nn._pinn import PINN
 from jinns.nn._mlp import MLP
-from jinns.parameters._params import Params
+from jinns.parameters._params import Params
+from jinns.nn._utils import _PyTree_to_Params
 
 
 def _get_param_nb(
-    params:
-) -> tuple[int, list]:
+    params: PyTree[Array],
+) -> tuple[int, list[int]]:
     """Returns the number of parameters in a Params object and also
     the cumulative sum when parsing the object.
 
@@ -48,7 +51,7 @@ class HyperPINN(PINN):
 
     Parameters
     ----------
-    hyperparams: list = eqx.field(static=True)
+    hyperparams: list[str] = eqx.field(static=True)
         A list of keys from Params.eq_params that will be considered as
         hyperparameters for metamodeling.
     hypernet_input_size: int
@@ -72,12 +75,12 @@ class HyperPINN(PINN):
     **Note**: the input dimension as given in eqx_list has to match the sum
     of the dimension of `t` + the dimension of `x` or the output dimension
     after the `input_transform` function
-    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+    input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
         A function that will be called before entering the PINN. Its output(s)
         must match the PINN inputs (except for the parameters).
         Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
         and the parameters. Default is no operation.
-    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+    output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
         A function with arguments begin the same input as the PINN, the PINN
         output and the parameter. This function will be called after exiting the PINN.
         Default is no operation.
@@ -100,10 +103,10 @@ class HyperPINN(PINN):
     eqx_hyper_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
 
     pinn_params_sum: int = eqx.field(init=False, static=True)
-    pinn_params_cumsum: list = eqx.field(init=False, static=True)
+    pinn_params_cumsum: list[int] = eqx.field(init=False, static=True)
 
-    init_params_hyper:
-    static_hyper:
+    init_params_hyper: HyperPINN = eqx.field(init=False)
+    static_hyper: HyperPINN = eqx.field(init=False, static=True)
 
     def __post_init__(self, eqx_network, eqx_hyper_network):
         super().__post_init__(
@@ -115,7 +118,7 @@ class HyperPINN(PINN):
         )
         self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.init_params)
 
-    def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) ->
+    def _hyper_to_pinn(self, hyper_output: Float[Array, " output_dim"]) -> PINN:
         """
         From the output of the hypernetwork, transform to a well formed
         parameters for the pinn network (i.e. with the same PyTree structure as
@@ -142,15 +145,29 @@ class HyperPINN(PINN):
             is_leaf=lambda x: isinstance(x, jnp.ndarray),
         )
 
+    @overload
+    @_PyTree_to_Params
     def __call__(
         self,
-        inputs: Float[Array, "input_dim"],
-        params:
+        inputs: Float[Array, " input_dim"],
+        params: PyTree,
         *args,
         **kwargs,
-    ) -> Float[Array, "output_dim"]:
+    ) -> Float[Array, " output_dim"]: ...
+
+    @_PyTree_to_Params
+    def __call__(
+        self,
+        inputs: Float[Array, " input_dim"],
+        params: Params[Array],
+        *args,
+        **kwargs,
+    ) -> Float[Array, " output_dim"]:
         """
         Evaluate the HyperPINN on some inputs with some params.
+
+        Note that that thanks to the decorator, params can also directly be the
+        PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
         """
         if len(inputs.shape) == 0:
             # This can happen often when the user directly provides some
@@ -158,16 +175,17 @@ class HyperPINN(PINN):
             # DataGenerators)
             inputs = inputs[None]
 
-        try:
-
-        except (KeyError, AttributeError, TypeError) as e:  # give more flexibility
-
+        # try:
+        hyper = eqx.combine(params.nn_params, self.static_hyper)
+        # except (KeyError, AttributeError, TypeError) as e: # give more flexibility
+        #     hyper = eqx.combine(params, self.static_hyper)
 
         eq_params_batch = jnp.concatenate(
-            [params.eq_params[k].flatten() for k in self.hyperparams],
+            [params.eq_params[k].flatten() for k in self.hyperparams],
+            axis=0,
         )
 
-        hyper_output = hyper(eq_params_batch)
+        hyper_output = hyper(eq_params_batch)  # type: ignore
 
         pinn_params = self._hyper_to_pinn(hyper_output)
 
@@ -187,21 +205,34 @@ class HyperPINN(PINN):
         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
         hyperparams: list[str],
         hypernet_input_size: int,
-        eqx_network: eqx.nn.MLP = None,
-        eqx_hyper_network: eqx.nn.MLP = None,
+        eqx_network: eqx.nn.MLP | MLP | None = None,
+        eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
         key: Key = None,
-        eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
-        eqx_list_hyper:
-
-
-
-
-
-
-
-
+        eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
+        eqx_list_hyper: (
+            tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
+        ) = None,
+        input_transform: (
+            Callable[
+                [Float[Array, " input_dim"], Params[Array]],
+                Float[Array, " output_dim"],
+            ]
+            | None
+        ) = None,
+        output_transform: (
+            Callable[
+                [
+                    Float[Array, " input_dim"],
+                    Float[Array, " output_dim"],
+                    Params[Array],
+                ],
+                Float[Array, " output_dim"],
+            ]
+            | None
+        ) = None,
+        slice_solution: slice | None = None,
         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
-    ) -> tuple[Self,
+    ) -> tuple[Self, HyperPINN]:
         r"""
         Utility function to create a standard PINN neural network with the equinox
         library.
@@ -250,11 +281,11 @@ class HyperPINN(PINN):
         The `key` argument need not be given.
         Thus typical example is `eqx_list=
         ((eqx.nn.Linear, 2, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 1)
         )`.
         eqx_list_hyper
@@ -268,11 +299,11 @@ class HyperPINN(PINN):
         The `key` argument need not be given.
         Thus typical example is `eqx_list=
         ((eqx.nn.Linear, 2, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 1)
         )`.
         input_transform
@@ -343,10 +374,13 @@ class HyperPINN(PINN):
                 (eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
             )
         else:
-            eqx_list_hyper = (
-
-
-
+            eqx_list_hyper = cast(
+                tuple[tuple[Callable, int, int] | tuple[Callable], ...],
+                (
+                    eqx_list_hyper[:-2]
+                    + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
+                    + eqx_list_hyper[-1]
+                ),
             )
         if len(eqx_list_hyper[0]) > 1:
             eqx_list_hyper = (
@@ -357,21 +391,24 @@ class HyperPINN(PINN):
                 ),
             ) + eqx_list_hyper[1:]
         else:
-            eqx_list_hyper = (
-
-
-
-
-
-
-
-
+            eqx_list_hyper = cast(
+                tuple[tuple[Callable, int, int] | tuple[Callable], ...],
+                (
+                    eqx_list_hyper[0]
+                    + (
+                        (
+                            (eqx_list_hyper[1][0],)
+                            + (hypernet_input_size,)
+                            + (eqx_list_hyper[1][2],)  # type: ignore because we suppose that the second element of tuple is nec.of length > 1 since we expect smth like eqx.nn.Linear
+                        ),
+                    )
+                    + eqx_list_hyper[2:]
+                ),
             )
         key, subkey = jax.random.split(key, 2)
         # with warnings.catch_warnings():
         #     warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
-        eqx_hyper_network = MLP(key=subkey, eqx_list=eqx_list_hyper)
+        eqx_hyper_network = cast(MLP, MLP(key=subkey, eqx_list=eqx_list_hyper))
 
         ### End of finetuning the hypernetwork architecture
 
@@ -386,10 +423,10 @@ class HyperPINN(PINN):
         hyperpinn = cls(
             eqx_network=eqx_network,
             eqx_hyper_network=eqx_hyper_network,
-            slice_solution=slice_solution,
+            slice_solution=slice_solution,  # type: ignore
             eq_type=eq_type,
-            input_transform=input_transform,
-            output_transform=output_transform,
+            input_transform=input_transform,  # type: ignore
+            output_transform=output_transform,  # type: ignore
            hyperparams=hyperparams,
            hypernet_input_size=hypernet_input_size,
            filter_spec=filter_spec,