jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_weights.py
CHANGED
|
@@ -2,58 +2,82 @@
|
|
|
2
2
|
Formalize the loss weights data structure
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from dataclasses import fields
|
|
7
|
+
|
|
8
|
+
from jaxtyping import Array
|
|
9
|
+
import jax.numpy as jnp
|
|
7
10
|
import equinox as eqx
|
|
8
11
|
|
|
9
12
|
|
|
10
|
-
|
|
13
|
+
def lw_converter(x):
    """
    Convert a loss weight to a JAX array, leaving ``None`` untouched.

    Used as the ``converter`` of every loss weight field so that any value
    given by the user (Python float, list, array, ...) is stored as an array,
    while an unset weight stays ``None``.
    """
    # None means "weight not set": keep it so that it can be filtered out
    # when iterating over the weights.
    return x if x is None else jnp.asarray(x)
|
|
11
18
|
|
|
12
|
-
dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
|
|
13
|
-
initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
|
|
14
|
-
observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
|
|
15
19
|
|
|
20
|
+
class AbstractLossWeights(eqx.Module):
    """
    Common base class of the loss weights containers, currently only useful
    for type hints.

    TODO in the future maybe loss weights could be subclasses of
    XDEComponentsAbstract?
    """

    def items(self):
        """
        Iterate over the weights as if the container were a dictionary.

        Returns ``(field_name, value)`` pairs for the fields whose value is not
        ``None``, in field declaration order. This is practical and keeps
        compatibility with older code in which loss components were plain
        dictionaries.
        """
        set_weights = {}
        for weight_field in fields(self):
            weight = getattr(self, weight_field.name)
            # Unset weights (None) are not reported at all.
            if weight is not None:
                set_weights[weight_field.name] = weight
        return set_weights.items()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LossWeightsODE(AbstractLossWeights):
    """
    Loss weights container for the loss components of an ODE problem.

    Every weight is keyword-only and defaults to ``None`` (weight not set).
    Non-``None`` values go through ``lw_converter`` and are therefore stored
    as ``jnp.asarray(value)``. ``None`` weights are skipped by
    ``AbstractLossWeights.items()``.
    """

    # Weight of the `dyn_loss` loss component.
    dyn_loss: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `initial_condition` loss component.
    initial_condition: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `observations` loss component.
    observations: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
|
|
26
51
|
|
|
27
52
|
|
|
28
|
-
class LossWeightsPDEStatio(
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
|
|
42
|
-
initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
|
|
43
|
-
|
|
53
|
+
class LossWeightsPDEStatio(AbstractLossWeights):
    """
    Loss weights container for the loss components of a stationary PDE
    problem.

    There is no ``initial_condition`` weight here, unlike in
    ``LossWeightsPDENonStatio``. Every weight is keyword-only and defaults to
    ``None`` (weight not set). Non-``None`` values go through
    ``lw_converter`` and are therefore stored as ``jnp.asarray(value)``.
    ``None`` weights are skipped by ``AbstractLossWeights.items()``.
    """

    # Weight of the `dyn_loss` loss component.
    dyn_loss: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `norm_loss` loss component.
    norm_loss: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `boundary_loss` loss component.
    boundary_loss: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `observations` loss component.
    observations: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
|
|
44
66
|
|
|
45
|
-
class LossWeightsPDEDict(eqx.Module):
|
|
46
|
-
"""
|
|
47
|
-
Only one type of LossWeights data structure for the SystemLossPDE:
|
|
48
|
-
Include the initial condition always for the code to be more generic
|
|
49
|
-
"""
|
|
50
67
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
68
|
+
class LossWeightsPDENonStatio(AbstractLossWeights):
    """
    Loss weights container for the loss components of a non-stationary PDE
    problem.

    Same weights as ``LossWeightsPDEStatio``, plus an ``initial_condition``
    weight. Every weight is keyword-only and defaults to ``None`` (weight not
    set). Non-``None`` values go through ``lw_converter`` and are therefore
    stored as ``jnp.asarray(value)``. ``None`` weights are skipped by
    ``AbstractLossWeights.items()``.
    """

    # Weight of the `dyn_loss` loss component.
    dyn_loss: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `norm_loss` loss component.
    norm_loss: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `boundary_loss` loss component.
    boundary_loss: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `observations` loss component.
    observations: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
    # Weight of the `initial_condition` loss component (non-stationary only).
    initial_condition: Array | float | None = eqx.field(
        kw_only=True, default=None, converter=lw_converter
    )
|
jinns/loss/_operators.py
CHANGED
|
@@ -2,22 +2,42 @@
|
|
|
2
2
|
Implements diverse operators for dynamic losses
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
from typing import Literal, cast, Callable
|
|
6
10
|
|
|
7
11
|
import jax
|
|
8
12
|
import jax.numpy as jnp
|
|
9
13
|
from jax import grad
|
|
10
|
-
import equinox as eqx
|
|
11
14
|
from jaxtyping import Float, Array
|
|
12
15
|
from jinns.parameters._params import Params
|
|
16
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _get_eq_type(
    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
) -> Literal["nonstatio_PDE", "statio_PDE"]:
    """
    Resolve the equation type that a differential operator must use.

    When `u` is an `AbstractPINN`, its own `eq_type` attribute is used and the
    `eq_type` argument is ignored. When `u` is a plain function, the `eq_type`
    argument is the only source of information and must be provided.

    ODE equation types are rejected because the operators of this module only
    work with PDEs (stationary or non-stationary) so far.

    Parameters
    ----------
    u
        The PINN object or the function whose derivatives are computed.
    eq_type
        The equation type to use when it cannot be read from `u`.

    Returns
    -------
    eq_type
        Either "nonstatio_PDE" or "statio_PDE".

    Raises
    ------
    ValueError
        If `u` is an ODE PINN, if `u` is not a PINN and `eq_type` is None, or
        if the resolved value is neither "nonstatio_PDE" nor "statio_PDE".
    """
    if isinstance(u, AbstractPINN):
        # The PINN knows its own equation type; it takes precedence over the
        # argument.
        resolved = u.eq_type
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let an ODE PINN reach the PDE operators.
        if resolved == "ODE":
            raise ValueError("Cannot compute the operator for ODE PINNs")
    elif eq_type is None:
        raise ValueError("eq_type could not be set!")
    else:
        resolved = eq_type
    # The operators test `eq_type == "nonstatio_PDE"` and treat any other value
    # as stationary, so an invalid value would silently give wrong derivatives.
    if resolved not in ("nonstatio_PDE", "statio_PDE"):
        raise ValueError(
            "eq_type must be 'nonstatio_PDE' or 'statio_PDE', got "
            f"{resolved!r}"
        )
    return resolved
|
|
13
33
|
|
|
14
34
|
|
|
15
35
|
def divergence_rev(
|
|
16
|
-
inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
17
|
-
u:
|
|
18
|
-
params: Params,
|
|
19
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
20
|
-
) ->
|
|
36
|
+
inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
37
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
38
|
+
params: Params[Array],
|
|
39
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
40
|
+
) -> Float[Array, " "]:
|
|
21
41
|
r"""
|
|
22
42
|
Compute the divergence of a vector field $\mathbf{u}$, i.e.,
|
|
23
43
|
$\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
|
|
@@ -41,13 +61,7 @@ def divergence_rev(
|
|
|
41
61
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
42
62
|
a function, we must set this attribute.
|
|
43
63
|
"""
|
|
44
|
-
|
|
45
|
-
try:
|
|
46
|
-
eq_type = u.eq_type
|
|
47
|
-
except AttributeError:
|
|
48
|
-
pass # use the value passed as argument
|
|
49
|
-
if eq_type is None:
|
|
50
|
-
raise ValueError("eq_type could not be set!")
|
|
64
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
51
65
|
|
|
52
66
|
def scan_fun(_, i):
|
|
53
67
|
if eq_type == "nonstatio_PDE":
|
|
@@ -70,11 +84,11 @@ def divergence_rev(
|
|
|
70
84
|
|
|
71
85
|
|
|
72
86
|
def divergence_fwd(
|
|
73
|
-
inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
|
|
74
|
-
u:
|
|
75
|
-
params: Params,
|
|
76
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
77
|
-
) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
|
|
87
|
+
inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
|
|
88
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
89
|
+
params: Params[Array],
|
|
90
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
91
|
+
) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
|
|
78
92
|
r"""
|
|
79
93
|
Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
|
|
80
94
|
$\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
|
|
@@ -103,13 +117,7 @@ def divergence_fwd(
|
|
|
103
117
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
104
118
|
a function, we must set this attribute.
|
|
105
119
|
"""
|
|
106
|
-
|
|
107
|
-
try:
|
|
108
|
-
eq_type = u.eq_type
|
|
109
|
-
except AttributeError:
|
|
110
|
-
pass # use the value passed as argument
|
|
111
|
-
if eq_type is None:
|
|
112
|
-
raise ValueError("eq_type could not be set!")
|
|
120
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
113
121
|
|
|
114
122
|
def scan_fun(_, i):
|
|
115
123
|
if eq_type == "nonstatio_PDE":
|
|
@@ -142,12 +150,12 @@ def divergence_fwd(
|
|
|
142
150
|
|
|
143
151
|
|
|
144
152
|
def laplacian_rev(
|
|
145
|
-
inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
146
|
-
u:
|
|
147
|
-
params: Params,
|
|
153
|
+
inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
154
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
155
|
+
params: Params[Array],
|
|
148
156
|
method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
|
|
149
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
150
|
-
) ->
|
|
157
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
158
|
+
) -> Float[Array, " "]:
|
|
151
159
|
r"""
|
|
152
160
|
Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
|
|
153
161
|
to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
|
|
@@ -180,13 +188,7 @@ def laplacian_rev(
|
|
|
180
188
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
181
189
|
a function, we must set this attribute.
|
|
182
190
|
"""
|
|
183
|
-
|
|
184
|
-
try:
|
|
185
|
-
eq_type = u.eq_type
|
|
186
|
-
except AttributeError:
|
|
187
|
-
pass # use the value passed as argument
|
|
188
|
-
if eq_type is None:
|
|
189
|
-
raise ValueError("eq_type could not be set!")
|
|
191
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
190
192
|
|
|
191
193
|
if method == "trace_hessian_x":
|
|
192
194
|
# NOTE we afford a concatenate here to avoid computing Hessian elements for
|
|
@@ -226,16 +228,12 @@ def laplacian_rev(
|
|
|
226
228
|
if eq_type == "nonstatio_PDE":
|
|
227
229
|
d2u_dxi2 = grad(
|
|
228
230
|
lambda inputs: grad(u_)(inputs)[1 + i],
|
|
229
|
-
)(
|
|
230
|
-
inputs
|
|
231
|
-
)[1 + i]
|
|
231
|
+
)(inputs)[1 + i]
|
|
232
232
|
else:
|
|
233
233
|
d2u_dxi2 = grad(
|
|
234
234
|
lambda inputs: grad(u_, 0)(inputs)[i],
|
|
235
235
|
0,
|
|
236
|
-
)(
|
|
237
|
-
inputs
|
|
238
|
-
)[i]
|
|
236
|
+
)(inputs)[i]
|
|
239
237
|
return _, d2u_dxi2
|
|
240
238
|
|
|
241
239
|
if eq_type == "nonstatio_PDE":
|
|
@@ -251,12 +249,12 @@ def laplacian_rev(
|
|
|
251
249
|
|
|
252
250
|
|
|
253
251
|
def laplacian_fwd(
|
|
254
|
-
inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
|
|
255
|
-
u:
|
|
256
|
-
params: Params,
|
|
252
|
+
inputs: Float[Array, " batch_size 1+dim"] | Float[Array, " batch_size dim"],
|
|
253
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
254
|
+
params: Params[Array],
|
|
257
255
|
method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
|
|
258
|
-
eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
|
|
259
|
-
) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
|
|
256
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
257
|
+
) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
|
|
260
258
|
r"""
|
|
261
259
|
Compute the Laplacian of a **batched** scalar field $u$
|
|
262
260
|
from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
|
|
@@ -299,13 +297,7 @@ def laplacian_fwd(
|
|
|
299
297
|
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
300
298
|
a function, we must set this attribute.
|
|
301
299
|
"""
|
|
302
|
-
|
|
303
|
-
try:
|
|
304
|
-
eq_type = u.eq_type
|
|
305
|
-
except AttributeError:
|
|
306
|
-
pass # use the value passed as argument
|
|
307
|
-
if eq_type is None:
|
|
308
|
-
raise ValueError("eq_type could not be set!")
|
|
300
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
309
301
|
|
|
310
302
|
if method == "loop":
|
|
311
303
|
|
|
@@ -398,11 +390,12 @@ def laplacian_fwd(
|
|
|
398
390
|
|
|
399
391
|
|
|
400
392
|
def vectorial_laplacian_rev(
|
|
401
|
-
inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
402
|
-
u:
|
|
403
|
-
params: Params,
|
|
404
|
-
dim_out: int = None,
|
|
405
|
-
|
|
393
|
+
inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
394
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
395
|
+
params: Params[Array],
|
|
396
|
+
dim_out: int | None = None,
|
|
397
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
398
|
+
) -> Float[Array, " dim_out"]:
|
|
406
399
|
r"""
|
|
407
400
|
Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
|
|
408
401
|
$\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
|
|
@@ -426,7 +419,12 @@ def vectorial_laplacian_rev(
|
|
|
426
419
|
dim_out
|
|
427
420
|
Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
|
|
428
421
|
provided if it is different than that of $\mathrm{inputs}$.
|
|
422
|
+
eq_type
|
|
423
|
+
whether we consider a stationary or non stationary PINN. Most often we
|
|
424
|
+
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
425
|
+
a function, we must set this attribute.
|
|
429
426
|
"""
|
|
427
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
430
428
|
if dim_out is None:
|
|
431
429
|
dim_out = inputs.shape[0]
|
|
432
430
|
|
|
@@ -435,7 +433,9 @@ def vectorial_laplacian_rev(
|
|
|
435
433
|
# each of these components
|
|
436
434
|
# Note the jnp.expand_dims call
|
|
437
435
|
uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
|
|
438
|
-
lap_on_j = laplacian_rev(
|
|
436
|
+
lap_on_j = laplacian_rev(
|
|
437
|
+
inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
|
|
438
|
+
)
|
|
439
439
|
|
|
440
440
|
return _, lap_on_j
|
|
441
441
|
|
|
@@ -444,11 +444,12 @@ def vectorial_laplacian_rev(
|
|
|
444
444
|
|
|
445
445
|
|
|
446
446
|
def vectorial_laplacian_fwd(
|
|
447
|
-
inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
|
|
448
|
-
u:
|
|
449
|
-
params: Params,
|
|
450
|
-
dim_out: int = None,
|
|
451
|
-
|
|
447
|
+
inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
|
|
448
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
449
|
+
params: Params[Array],
|
|
450
|
+
dim_out: int | None = None,
|
|
451
|
+
eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
|
|
452
|
+
) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
|
|
452
453
|
r"""
|
|
453
454
|
Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
|
|
454
455
|
`u` is a SPINN, in this case, it corresponds to a vector
|
|
@@ -474,7 +475,12 @@ def vectorial_laplacian_fwd(
|
|
|
474
475
|
dim_out
|
|
475
476
|
the value of the output dimension ($n$ in the formula above). Must be
|
|
476
477
|
set if different from $d$.
|
|
478
|
+
eq_type
|
|
479
|
+
whether we consider a stationary or non stationary PINN. Most often we
|
|
480
|
+
can know that by inspecting the `u` argument (PINN object). But if `u` is
|
|
481
|
+
a function, we must set this attribute.
|
|
477
482
|
"""
|
|
483
|
+
eq_type = _get_eq_type(u, eq_type)
|
|
478
484
|
if dim_out is None:
|
|
479
485
|
dim_out = inputs.shape[0]
|
|
480
486
|
|
|
@@ -483,7 +489,9 @@ def vectorial_laplacian_fwd(
|
|
|
483
489
|
# each of these components
|
|
484
490
|
# Note the expand_dims
|
|
485
491
|
uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
|
|
486
|
-
lap_on_j = laplacian_fwd(
|
|
492
|
+
lap_on_j = laplacian_fwd(
|
|
493
|
+
inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
|
|
494
|
+
)
|
|
487
495
|
|
|
488
496
|
return _, lap_on_j
|
|
489
497
|
|
|
@@ -492,8 +500,10 @@ def vectorial_laplacian_fwd(
|
|
|
492
500
|
|
|
493
501
|
|
|
494
502
|
def _u_dot_nabla_times_u_rev(
|
|
495
|
-
x: Float[Array, "2"],
|
|
496
|
-
|
|
503
|
+
x: Float[Array, " 2"],
|
|
504
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
505
|
+
params: Params[Array],
|
|
506
|
+
) -> Float[Array, " 2"]:
|
|
497
507
|
r"""
|
|
498
508
|
Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
|
|
499
509
|
$\mathbf{x}$ of arbitrary
|
|
@@ -522,10 +532,10 @@ def _u_dot_nabla_times_u_rev(
|
|
|
522
532
|
|
|
523
533
|
|
|
524
534
|
def _u_dot_nabla_times_u_fwd(
|
|
525
|
-
x: Float[Array, "batch_size 2"],
|
|
526
|
-
u:
|
|
527
|
-
params: Params,
|
|
528
|
-
) -> Float[Array, "batch_size batch_size 2"]:
|
|
535
|
+
x: Float[Array, " batch_size 2"],
|
|
536
|
+
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
537
|
+
params: Params[Array],
|
|
538
|
+
) -> Float[Array, " batch_size batch_size 2"]:
|
|
529
539
|
r"""
|
|
530
540
|
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
|
|
531
541
|
:math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
|
jinns/nn/__init__.py
CHANGED
|
@@ -1,7 +1,22 @@
|
|
|
1
1
|
from ._save_load import save_pinn, load_pinn
|
|
2
|
+
from ._abstract_pinn import AbstractPINN
|
|
2
3
|
from ._pinn import PINN
|
|
3
4
|
from ._spinn import SPINN
|
|
4
5
|
from ._mlp import PINN_MLP, MLP
|
|
5
6
|
from ._spinn_mlp import SPINN_MLP, SMLP
|
|
6
7
|
from ._hyperpinn import HyperPINN
|
|
7
8
|
from ._ppinn import PPINN_MLP
|
|
9
|
+
|
|
10
|
+
# Public API of `jinns.nn`: the names re-exported above, and the ones picked
# up by `from jinns.nn import *`.
__all__ = [
    "save_pinn",
    "load_pinn",
    "AbstractPINN",
    "PINN",
    "SPINN",
    "PINN_MLP",
    "MLP",
    "SPINN_MLP",
    "SMLP",
    "HyperPINN",
    "PPINN_MLP",
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Literal, Any
|
|
3
|
+
from jaxtyping import Array
|
|
4
|
+
import equinox as eqx
|
|
5
|
+
|
|
6
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
7
|
+
from jinns.parameters._params import Params
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AbstractPINN(eqx.Module):
    """
    Abstract base class declaring the interface of a jinns network.

    Basically just a way to add an abstract `__call__(inputs, params, ...)` to
    an eqx.Module, so that networks can be type-hinted as such.
    The way to go for correct type hints apparently
    https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
    """

    # Abstract field that concrete subclasses must define. It is read e.g. by
    # `_get_eq_type` in jinns.loss._operators to choose between stationary and
    # non-stationary derivative computations.
    eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]

    # NOTE(review): `abc.abstractmethod` is applied to the function returned
    # by `_PyTree_to_Params` (it is the outer decorator). This works as long as
    # that decorator returns a plain function. Keep the order unless
    # `_PyTree_to_Params` is checked. Presumably `_PyTree_to_Params` turns a
    # PyTree `params` into a `Params` object before the call; confirm in
    # jinns.nn._utils.
    @abc.abstractmethod
    @_PyTree_to_Params
    def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
        pass
|