jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +138 -73
- jinns/loss/_LossPDE.py +208 -104
- jinns/loss/_abstract_loss.py +97 -14
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +102 -367
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +520 -11
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/METADATA +26 -14
- jinns-1.7.1.dist-info/RECORD +58 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- jinns-1.6.1.dist-info/RECORD +0 -57
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
jinns/loss/_abstract_loss.py
CHANGED
|
@@ -1,15 +1,27 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
|
-
|
|
5
|
-
from
|
|
4
|
+
import warnings
|
|
5
|
+
from typing import Self, Literal, Callable, TypeVar, Generic, Any, get_args
|
|
6
|
+
from dataclasses import InitVar
|
|
7
|
+
from jaxtyping import Array, PyTree, Float, PRNGKeyArray
|
|
6
8
|
import equinox as eqx
|
|
7
9
|
import jax
|
|
8
10
|
import jax.numpy as jnp
|
|
9
11
|
import optax
|
|
10
12
|
from jinns.parameters._params import Params
|
|
11
|
-
from jinns.loss._loss_weight_updates import
|
|
12
|
-
|
|
13
|
+
from jinns.loss._loss_weight_updates import (
|
|
14
|
+
soft_adapt,
|
|
15
|
+
lr_annealing,
|
|
16
|
+
ReLoBRaLo,
|
|
17
|
+
prior_loss,
|
|
18
|
+
)
|
|
19
|
+
from jinns.utils._types import (
|
|
20
|
+
AnyLossComponents,
|
|
21
|
+
AnyBatch,
|
|
22
|
+
AnyLossWeights,
|
|
23
|
+
AnyDerivativeKeys,
|
|
24
|
+
)
|
|
13
25
|
|
|
14
26
|
L = TypeVar(
|
|
15
27
|
"L", bound=AnyLossWeights
|
|
@@ -25,31 +37,85 @@ C = TypeVar(
|
|
|
25
37
|
"C", bound=AnyLossComponents[Array | None]
|
|
26
38
|
) # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)
|
|
27
39
|
|
|
40
|
+
DK = TypeVar("DK", bound=AnyDerivativeKeys)
|
|
41
|
+
|
|
28
42
|
# In the cases above, without the bound, we could not have covariance on
|
|
29
43
|
# the type because it would break LSP. Note that covariance on the return type
|
|
30
44
|
# is authorized in LSP hence we do not need the same TypeVar instruction for
|
|
31
45
|
# the return types of evaluate_by_terms for example!
|
|
32
46
|
|
|
33
47
|
|
|
34
|
-
|
|
48
|
+
AvailableUpdateWeightMethods = Literal[
|
|
49
|
+
"softadapt", "soft_adapt", "prior_loss", "lr_annealing", "ReLoBRaLo"
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
|
|
35
54
|
"""
|
|
36
55
|
About the call:
|
|
37
56
|
https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
|
|
38
57
|
"""
|
|
39
58
|
|
|
59
|
+
derivative_keys: eqx.AbstractVar[DK]
|
|
40
60
|
loss_weights: eqx.AbstractVar[L]
|
|
41
|
-
|
|
42
|
-
|
|
61
|
+
loss_weight_scales: L = eqx.field(init=False)
|
|
62
|
+
update_weight_method: AvailableUpdateWeightMethods | None = eqx.field(
|
|
63
|
+
kw_only=True, default=None, static=True
|
|
64
|
+
)
|
|
65
|
+
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
66
|
+
keep_initial_loss_weight_scales: InitVar[bool] = eqx.field(
|
|
67
|
+
default=True, kw_only=True
|
|
43
68
|
)
|
|
44
69
|
|
|
70
|
+
def __init__(
|
|
71
|
+
self,
|
|
72
|
+
*,
|
|
73
|
+
loss_weights,
|
|
74
|
+
derivative_keys,
|
|
75
|
+
vmap_in_axes,
|
|
76
|
+
update_weight_method=None,
|
|
77
|
+
keep_initial_loss_weight_scales: bool = False,
|
|
78
|
+
):
|
|
79
|
+
if update_weight_method is not None and update_weight_method not in get_args(
|
|
80
|
+
AvailableUpdateWeightMethods
|
|
81
|
+
):
|
|
82
|
+
raise ValueError(f"{update_weight_method=} is not a valid method")
|
|
83
|
+
self.update_weight_method = update_weight_method
|
|
84
|
+
self.loss_weights = loss_weights
|
|
85
|
+
self.derivative_keys = derivative_keys
|
|
86
|
+
self.vmap_in_axes = vmap_in_axes
|
|
87
|
+
if keep_initial_loss_weight_scales:
|
|
88
|
+
self.loss_weight_scales = self.loss_weights
|
|
89
|
+
if self.update_weight_method is not None:
|
|
90
|
+
warnings.warn(
|
|
91
|
+
"Loss weights out from update_weight_method will still be"
|
|
92
|
+
" multiplied by the initial input loss_weights"
|
|
93
|
+
)
|
|
94
|
+
else:
|
|
95
|
+
self.loss_weight_scales = optax.tree_utils.tree_ones_like(self.loss_weights)
|
|
96
|
+
# self.loss_weight_scales will contain None where self.loss_weights
|
|
97
|
+
# has None
|
|
98
|
+
|
|
45
99
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
46
100
|
return self.evaluate(*args, **kwargs)
|
|
47
101
|
|
|
48
102
|
@abc.abstractmethod
|
|
49
|
-
def evaluate_by_terms(
|
|
103
|
+
def evaluate_by_terms(
|
|
104
|
+
self,
|
|
105
|
+
opt_params: Params[Array],
|
|
106
|
+
batch: B,
|
|
107
|
+
*,
|
|
108
|
+
non_opt_params: Params[Array] | None = None,
|
|
109
|
+
) -> tuple[C, C]:
|
|
50
110
|
pass
|
|
51
111
|
|
|
52
|
-
def evaluate(
|
|
112
|
+
def evaluate(
|
|
113
|
+
self,
|
|
114
|
+
opt_params: Params[Array],
|
|
115
|
+
batch: B,
|
|
116
|
+
*,
|
|
117
|
+
non_opt_params: Params[Array] | None = None,
|
|
118
|
+
) -> tuple[Float[Array, " "], C]:
|
|
53
119
|
"""
|
|
54
120
|
Evaluate the loss function at a batch of points for given parameters.
|
|
55
121
|
|
|
@@ -57,16 +123,20 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
|
57
123
|
|
|
58
124
|
Parameters
|
|
59
125
|
---------
|
|
60
|
-
|
|
61
|
-
Parameters at which the loss is evaluated
|
|
126
|
+
opt_params
|
|
127
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
62
128
|
batch
|
|
63
129
|
Composed of a batch of points in the
|
|
64
130
|
domain, a batch of points in the domain
|
|
65
131
|
border and an optional additional batch of parameters (eg. for
|
|
66
132
|
metamodeling) and an optional additional batch of observed
|
|
67
133
|
inputs/outputs/parameters
|
|
134
|
+
non_opt_params
|
|
135
|
+
Parameters, which are non optimized, at which the loss is evaluated
|
|
68
136
|
"""
|
|
69
|
-
loss_terms, _ = self.evaluate_by_terms(
|
|
137
|
+
loss_terms, _ = self.evaluate_by_terms(
|
|
138
|
+
opt_params, batch, non_opt_params=non_opt_params
|
|
139
|
+
)
|
|
70
140
|
|
|
71
141
|
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
72
142
|
|
|
@@ -102,10 +172,14 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
|
102
172
|
raise ValueError(
|
|
103
173
|
"The numbers of declared loss weights and "
|
|
104
174
|
"declared loss terms do not concord "
|
|
105
|
-
f" got {len(weights)} and {len(terms_list)}"
|
|
175
|
+
f" got {len(weights)} and {len(terms_list)}. "
|
|
176
|
+
"If you passed tuple of dyn_loss, make sure to pass "
|
|
177
|
+
"tuple of loss weights at LossWeights.dyn_loss."
|
|
178
|
+
"If you passed tuple of obs datasets, make sure to pass "
|
|
179
|
+
"tuple of loss weights at LossWeights.observations."
|
|
106
180
|
)
|
|
107
181
|
|
|
108
|
-
def ponderate_and_sum_gradient(self, terms: C) ->
|
|
182
|
+
def ponderate_and_sum_gradient(self, terms: C) -> Params[Array | None]:
|
|
109
183
|
"""
|
|
110
184
|
Get total gradients from individual loss gradients and weights
|
|
111
185
|
for each parameter
|
|
@@ -146,6 +220,8 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
|
146
220
|
new_weights = soft_adapt(
|
|
147
221
|
self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
|
|
148
222
|
)
|
|
223
|
+
elif self.update_weight_method == "prior_loss":
|
|
224
|
+
new_weights = prior_loss(self.loss_weights, iteration_nb, stored_loss_terms)
|
|
149
225
|
elif self.update_weight_method == "lr_annealing":
|
|
150
226
|
new_weights = lr_annealing(self.loss_weights, grad_terms)
|
|
151
227
|
elif self.update_weight_method == "ReLoBRaLo":
|
|
@@ -158,6 +234,13 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
|
158
234
|
# Below we update the non None entry in the PyTree self.loss_weights
|
|
159
235
|
# we directly get the non None entries because None is not treated as a
|
|
160
236
|
# leaf
|
|
237
|
+
|
|
238
|
+
new_weights = jax.lax.cond(
|
|
239
|
+
iteration_nb == 0,
|
|
240
|
+
lambda nw: nw,
|
|
241
|
+
lambda nw: jnp.array(jax.tree.leaves(self.loss_weight_scales)) * nw,
|
|
242
|
+
new_weights,
|
|
243
|
+
)
|
|
161
244
|
return eqx.tree_at(
|
|
162
245
|
lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
|
|
163
246
|
)
|
|
@@ -227,7 +227,7 @@ def boundary_neumann(
|
|
|
227
227
|
if isinstance(u, PINN):
|
|
228
228
|
u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])
|
|
229
229
|
|
|
230
|
-
if u.eq_type == "
|
|
230
|
+
if u.eq_type == "PDEStatio":
|
|
231
231
|
v_neumann = vmap(
|
|
232
232
|
lambda inputs, params: _subtract_with_check(
|
|
233
233
|
f(inputs),
|
|
@@ -240,7 +240,7 @@ def boundary_neumann(
|
|
|
240
240
|
vmap_in_axes,
|
|
241
241
|
0,
|
|
242
242
|
)
|
|
243
|
-
elif u.eq_type == "
|
|
243
|
+
elif u.eq_type == "PDENonStatio":
|
|
244
244
|
v_neumann = vmap(
|
|
245
245
|
lambda inputs, params: _subtract_with_check(
|
|
246
246
|
f(inputs),
|
|
@@ -274,14 +274,14 @@ def boundary_neumann(
|
|
|
274
274
|
if (batch_array.shape[0] == 1 and isinstance(batch, PDEStatioBatch)) or (
|
|
275
275
|
batch_array.shape[-1] == 2 and isinstance(batch, PDENonStatioBatch)
|
|
276
276
|
):
|
|
277
|
-
if u.eq_type == "
|
|
277
|
+
if u.eq_type == "PDEStatio":
|
|
278
278
|
_, du_dx = jax.jvp(
|
|
279
279
|
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
280
280
|
(batch_array,),
|
|
281
281
|
(jnp.ones_like(batch_array),),
|
|
282
282
|
)
|
|
283
283
|
values = du_dx * n[facet]
|
|
284
|
-
if u.eq_type == "
|
|
284
|
+
if u.eq_type == "PDENonStatio":
|
|
285
285
|
_, du_dx = jax.jvp(
|
|
286
286
|
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
287
287
|
(batch_array,),
|
|
@@ -291,7 +291,7 @@ def boundary_neumann(
|
|
|
291
291
|
elif (batch_array.shape[-1] == 2 and isinstance(batch, PDEStatioBatch)) or (
|
|
292
292
|
batch_array.shape[-1] == 3 and isinstance(batch, PDENonStatioBatch)
|
|
293
293
|
):
|
|
294
|
-
if u.eq_type == "
|
|
294
|
+
if u.eq_type == "PDEStatio":
|
|
295
295
|
tangent_vec_0 = jnp.repeat(
|
|
296
296
|
jnp.array([1.0, 0.0])[None], batch_array.shape[0], axis=0
|
|
297
297
|
)
|
|
@@ -309,7 +309,7 @@ def boundary_neumann(
|
|
|
309
309
|
(tangent_vec_1,),
|
|
310
310
|
)
|
|
311
311
|
values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
|
|
312
|
-
if u.eq_type == "
|
|
312
|
+
if u.eq_type == "PDENonStatio":
|
|
313
313
|
tangent_vec_0 = jnp.repeat(
|
|
314
314
|
jnp.array([0.0, 1.0, 0.0])[None], batch_array.shape[0], axis=0
|
|
315
315
|
)
|
jinns/loss/_loss_utils.py
CHANGED
|
@@ -26,12 +26,12 @@ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
|
26
26
|
from jinns.parameters._params import Params
|
|
27
27
|
|
|
28
28
|
if TYPE_CHECKING:
|
|
29
|
-
from jinns.utils._types import BoundaryConditionFun
|
|
29
|
+
from jinns.utils._types import BoundaryConditionFun, AnyBatch
|
|
30
30
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def dynamic_loss_apply(
|
|
34
|
-
dyn_loss: Callable,
|
|
34
|
+
dyn_loss: Callable[[AnyBatch, AbstractPINN, Params[Array]], Array],
|
|
35
35
|
u: AbstractPINN,
|
|
36
36
|
batch: (
|
|
37
37
|
Float[Array, " batch_size 1"]
|
|
@@ -13,6 +13,36 @@ if TYPE_CHECKING:
|
|
|
13
13
|
from jinns.utils._types import AnyLossComponents, AnyLossWeights
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
def prior_loss(
|
|
17
|
+
loss_weights: AnyLossWeights,
|
|
18
|
+
iteration_nb: int,
|
|
19
|
+
stored_loss_terms: AnyLossComponents,
|
|
20
|
+
) -> Array:
|
|
21
|
+
"""
|
|
22
|
+
Simple adaptative weights according to the prior loss idea:
|
|
23
|
+
the ponderation in front of a loss term is given by the inverse of the
|
|
24
|
+
value of that loss term at the previous iteration
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def do_nothing(loss_weights, _):
|
|
28
|
+
return jnp.array(
|
|
29
|
+
jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
def _prior_loss(_, stored_loss_terms):
|
|
33
|
+
new_weights = jax.tree.map(
|
|
34
|
+
lambda slt: 1 / (slt[iteration_nb - 1] + 1e-6), stored_loss_terms
|
|
35
|
+
)
|
|
36
|
+
return jnp.array(jax.tree.leaves(new_weights), dtype=float)
|
|
37
|
+
|
|
38
|
+
return jax.lax.cond(
|
|
39
|
+
iteration_nb == 0,
|
|
40
|
+
lambda op: do_nothing(*op),
|
|
41
|
+
lambda op: _prior_loss(*op),
|
|
42
|
+
(loss_weights, stored_loss_terms),
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
16
46
|
def soft_adapt(
|
|
17
47
|
loss_weights: AnyLossWeights,
|
|
18
48
|
iteration_nb: int,
|
jinns/loss/_loss_weights.py
CHANGED
|
@@ -18,6 +18,10 @@ from jinns.loss._loss_components import (
|
|
|
18
18
|
def lw_converter(x: Array | None) -> Array | None:
|
|
19
19
|
if x is None:
|
|
20
20
|
return x
|
|
21
|
+
elif isinstance(x, tuple):
|
|
22
|
+
# user might input tuple of scalar loss weights to account for cases
|
|
23
|
+
# when dyn loss is also a tuple of (possibly 1D) dyn_loss
|
|
24
|
+
return tuple(jnp.asarray(x_) for x_ in x)
|
|
21
25
|
else:
|
|
22
26
|
return jnp.asarray(x)
|
|
23
27
|
|
jinns/loss/_operators.py
CHANGED
|
@@ -18,8 +18,8 @@ from jinns.nn._abstract_pinn import AbstractPINN
|
|
|
18
18
|
|
|
19
19
|
def _get_eq_type(
|
|
20
20
|
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
21
|
-
eq_type: Literal["
|
|
22
|
-
) -> Literal["
|
|
21
|
+
eq_type: Literal["PDENonStatio", "PDEStatio"] | None,
|
|
22
|
+
) -> Literal["PDENonStatio", "PDEStatio"]:
|
|
23
23
|
"""
|
|
24
24
|
But we filter out ODE from eq_type because we only have operators that does
|
|
25
25
|
not work with ODEs so far
|
|
@@ -36,7 +36,7 @@ def divergence_rev(
|
|
|
36
36
|
inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
37
37
|
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
38
38
|
params: Params[Array],
|
|
39
|
-
eq_type: Literal["
|
|
39
|
+
eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
|
|
40
40
|
) -> Float[Array, " "]:
|
|
41
41
|
r"""
|
|
42
42
|
Compute the divergence of a vector field $\mathbf{u}$, i.e.,
|
|
@@ -64,7 +64,7 @@ def divergence_rev(
|
|
|
64
64
|
eq_type = _get_eq_type(u, eq_type)
|
|
65
65
|
|
|
66
66
|
def scan_fun(_, i):
|
|
67
|
-
if eq_type == "
|
|
67
|
+
if eq_type == "PDENonStatio":
|
|
68
68
|
du_dxi = grad(lambda inputs, params: u(inputs, params)[1 + i])(
|
|
69
69
|
inputs, params
|
|
70
70
|
)[1 + i]
|
|
@@ -74,9 +74,9 @@ def divergence_rev(
|
|
|
74
74
|
]
|
|
75
75
|
return _, du_dxi
|
|
76
76
|
|
|
77
|
-
if eq_type == "
|
|
77
|
+
if eq_type == "PDENonStatio":
|
|
78
78
|
_, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0] - 1))
|
|
79
|
-
elif eq_type == "
|
|
79
|
+
elif eq_type == "PDEStatio":
|
|
80
80
|
_, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
|
|
81
81
|
else:
|
|
82
82
|
raise ValueError("Unexpected u.eq_type!")
|
|
@@ -87,7 +87,7 @@ def divergence_fwd(
|
|
|
87
87
|
inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
|
|
88
88
|
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
89
89
|
params: Params[Array],
|
|
90
|
-
eq_type: Literal["
|
|
90
|
+
eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
|
|
91
91
|
) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
|
|
92
92
|
r"""
|
|
93
93
|
Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
|
|
@@ -120,7 +120,7 @@ def divergence_fwd(
|
|
|
120
120
|
eq_type = _get_eq_type(u, eq_type)
|
|
121
121
|
|
|
122
122
|
def scan_fun(_, i):
|
|
123
|
-
if eq_type == "
|
|
123
|
+
if eq_type == "PDENonStatio":
|
|
124
124
|
tangent_vec = jnp.repeat(
|
|
125
125
|
jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
|
|
126
126
|
inputs.shape[0],
|
|
@@ -140,9 +140,9 @@ def divergence_fwd(
|
|
|
140
140
|
)
|
|
141
141
|
return _, du_dxi
|
|
142
142
|
|
|
143
|
-
if eq_type == "
|
|
143
|
+
if eq_type == "PDENonStatio":
|
|
144
144
|
_, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1] - 1))
|
|
145
|
-
elif eq_type == "
|
|
145
|
+
elif eq_type == "PDEStatio":
|
|
146
146
|
_, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1]))
|
|
147
147
|
else:
|
|
148
148
|
raise ValueError("Unexpected u.eq_type!")
|
|
@@ -154,7 +154,7 @@ def laplacian_rev(
|
|
|
154
154
|
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
155
155
|
params: Params[Array],
|
|
156
156
|
method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
|
|
157
|
-
eq_type: Literal["
|
|
157
|
+
eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
|
|
158
158
|
) -> Float[Array, " "]:
|
|
159
159
|
r"""
|
|
160
160
|
Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
|
|
@@ -196,22 +196,22 @@ def laplacian_rev(
|
|
|
196
196
|
# computation and then discarding elements but for higher order derivatives
|
|
197
197
|
# it might not be worth it. See other options below for computating the
|
|
198
198
|
# Laplacian
|
|
199
|
-
if eq_type == "
|
|
199
|
+
if eq_type == "PDENonStatio":
|
|
200
200
|
u_ = lambda x: jnp.squeeze(
|
|
201
201
|
u(jnp.concatenate([inputs[:1], x], axis=0), params)
|
|
202
202
|
)
|
|
203
203
|
return jnp.sum(jnp.diag(jax.hessian(u_)(inputs[1:])))
|
|
204
|
-
if eq_type == "
|
|
204
|
+
if eq_type == "PDEStatio":
|
|
205
205
|
u_ = lambda inputs: jnp.squeeze(u(inputs, params))
|
|
206
206
|
return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
|
|
207
207
|
raise ValueError("Unexpected eq_type!")
|
|
208
208
|
if method == "trace_hessian_t_x":
|
|
209
209
|
# NOTE that it is unclear whether it is better to vectorially compute the
|
|
210
210
|
# Hessian (despite a useless time dimension) as below
|
|
211
|
-
if eq_type == "
|
|
211
|
+
if eq_type == "PDENonStatio":
|
|
212
212
|
u_ = lambda inputs: jnp.squeeze(u(inputs, params))
|
|
213
213
|
return jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:])
|
|
214
|
-
if eq_type == "
|
|
214
|
+
if eq_type == "PDEStatio":
|
|
215
215
|
u_ = lambda inputs: jnp.squeeze(u(inputs, params))
|
|
216
216
|
return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
|
|
217
217
|
raise ValueError("Unexpected eq_type!")
|
|
@@ -225,7 +225,7 @@ def laplacian_rev(
|
|
|
225
225
|
u_ = lambda inputs: u(inputs, params).squeeze()
|
|
226
226
|
|
|
227
227
|
def scan_fun(_, i):
|
|
228
|
-
if eq_type == "
|
|
228
|
+
if eq_type == "PDENonStatio":
|
|
229
229
|
d2u_dxi2 = grad(
|
|
230
230
|
lambda inputs: grad(u_)(inputs)[1 + i],
|
|
231
231
|
)(inputs)[1 + i]
|
|
@@ -236,11 +236,11 @@ def laplacian_rev(
|
|
|
236
236
|
)(inputs)[i]
|
|
237
237
|
return _, d2u_dxi2
|
|
238
238
|
|
|
239
|
-
if eq_type == "
|
|
239
|
+
if eq_type == "PDENonStatio":
|
|
240
240
|
_, trace_hessian = jax.lax.scan(
|
|
241
241
|
scan_fun, {}, jnp.arange(inputs.shape[0] - 1)
|
|
242
242
|
)
|
|
243
|
-
elif eq_type == "
|
|
243
|
+
elif eq_type == "PDEStatio":
|
|
244
244
|
_, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
|
|
245
245
|
else:
|
|
246
246
|
raise ValueError("Unexpected eq_type!")
|
|
@@ -253,7 +253,7 @@ def laplacian_fwd(
|
|
|
253
253
|
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
254
254
|
params: Params[Array],
|
|
255
255
|
method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
|
|
256
|
-
eq_type: Literal["
|
|
256
|
+
eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
|
|
257
257
|
) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
|
|
258
258
|
r"""
|
|
259
259
|
Compute the Laplacian of a **batched** scalar field $u$
|
|
@@ -302,7 +302,7 @@ def laplacian_fwd(
|
|
|
302
302
|
if method == "loop":
|
|
303
303
|
|
|
304
304
|
def scan_fun(_, i):
|
|
305
|
-
if eq_type == "
|
|
305
|
+
if eq_type == "PDENonStatio":
|
|
306
306
|
tangent_vec = jnp.repeat(
|
|
307
307
|
jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
|
|
308
308
|
inputs.shape[0],
|
|
@@ -323,17 +323,17 @@ def laplacian_fwd(
|
|
|
323
323
|
__, d2u_dxi2 = jax.jvp(du_dxi_fun, (inputs,), (tangent_vec,))
|
|
324
324
|
return _, d2u_dxi2
|
|
325
325
|
|
|
326
|
-
if eq_type == "
|
|
326
|
+
if eq_type == "PDENonStatio":
|
|
327
327
|
_, trace_hessian = jax.lax.scan(
|
|
328
328
|
scan_fun, {}, jnp.arange(inputs.shape[-1] - 1)
|
|
329
329
|
)
|
|
330
|
-
elif eq_type == "
|
|
330
|
+
elif eq_type == "PDEStatio":
|
|
331
331
|
_, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[-1]))
|
|
332
332
|
else:
|
|
333
333
|
raise ValueError("Unexpected eq_type!")
|
|
334
334
|
return jnp.sum(trace_hessian, axis=0)
|
|
335
335
|
if method == "trace_hessian_t_x":
|
|
336
|
-
if eq_type == "
|
|
336
|
+
if eq_type == "PDENonStatio":
|
|
337
337
|
# compute the Hessian including the batch dimension, get rid of the
|
|
338
338
|
# (..,1,..) axis that is here because of the scalar output
|
|
339
339
|
# if inputs.shape==(10,3) (1 for time, 2 for x_dim)
|
|
@@ -351,7 +351,7 @@ def laplacian_fwd(
|
|
|
351
351
|
res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
|
|
352
352
|
lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
|
|
353
353
|
return lap[..., None]
|
|
354
|
-
if eq_type == "
|
|
354
|
+
if eq_type == "PDEStatio":
|
|
355
355
|
# compute the Hessian including the batch dimension, get rid of the
|
|
356
356
|
# (..,1,..) axis that is here because of the scalar output
|
|
357
357
|
# if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
|
|
@@ -369,7 +369,7 @@ def laplacian_fwd(
|
|
|
369
369
|
return lap[..., None]
|
|
370
370
|
raise ValueError("Unexpected eq_type!")
|
|
371
371
|
if method == "trace_hessian_x":
|
|
372
|
-
if eq_type == "
|
|
372
|
+
if eq_type == "PDEStatio":
|
|
373
373
|
# compute the Hessian including the batch dimension, get rid of the
|
|
374
374
|
# (..,1,..) axis that is here because of the scalar output
|
|
375
375
|
# if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
|
|
@@ -394,7 +394,7 @@ def vectorial_laplacian_rev(
|
|
|
394
394
|
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
395
395
|
params: Params[Array],
|
|
396
396
|
dim_out: int | None = None,
|
|
397
|
-
eq_type: Literal["
|
|
397
|
+
eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
|
|
398
398
|
) -> Float[Array, " dim_out"]:
|
|
399
399
|
r"""
|
|
400
400
|
Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
|
|
@@ -448,7 +448,7 @@ def vectorial_laplacian_fwd(
|
|
|
448
448
|
u: AbstractPINN | Callable[[Array, Params[Array]], Array],
|
|
449
449
|
params: Params[Array],
|
|
450
450
|
dim_out: int | None = None,
|
|
451
|
-
eq_type: Literal["
|
|
451
|
+
eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
|
|
452
452
|
) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
|
|
453
453
|
r"""
|
|
454
454
|
Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
|
jinns/nn/_abstract_pinn.py
CHANGED
|
@@ -13,7 +13,7 @@ class AbstractPINN(eqx.Module):
|
|
|
13
13
|
https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
|
|
14
14
|
"""
|
|
15
15
|
|
|
16
|
-
eq_type: eqx.AbstractVar[Literal["ODE", "
|
|
16
|
+
eq_type: eqx.AbstractVar[Literal["ODE", "PDEStatio", "PDENonStatio"]]
|
|
17
17
|
|
|
18
18
|
@abc.abstractmethod
|
|
19
19
|
def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
|
jinns/nn/_hyperpinn.py
CHANGED
|
@@ -67,9 +67,9 @@ class HyperPINN(PINN):
|
|
|
67
67
|
eq_type : str
|
|
68
68
|
A string with three possibilities.
|
|
69
69
|
"ODE": the HyperPINN is called with one input `t`.
|
|
70
|
-
"
|
|
70
|
+
"PDEStatio": the HyperPINN is called with one input `x`, `x`
|
|
71
71
|
can be high dimensional.
|
|
72
|
-
"
|
|
72
|
+
"PDENonStatio": the HyperPINN is called with two inputs `t` and `x`, `x`
|
|
73
73
|
can be high dimensional.
|
|
74
74
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
75
75
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
@@ -192,7 +192,7 @@ class HyperPINN(PINN):
|
|
|
192
192
|
hyper = eqx.combine(params.nn_params, self.static_hyper)
|
|
193
193
|
|
|
194
194
|
eq_params_batch = jnp.concatenate(
|
|
195
|
-
[getattr(params.eq_params, k).flatten() for k in self.hyperparams],
|
|
195
|
+
[getattr(params.eq_params, k).flatten() for k in self.hyperparams], # pylint: disable=E1133
|
|
196
196
|
axis=0,
|
|
197
197
|
)
|
|
198
198
|
|
|
@@ -214,7 +214,7 @@ class HyperPINN(PINN):
|
|
|
214
214
|
def create(
|
|
215
215
|
cls,
|
|
216
216
|
*,
|
|
217
|
-
eq_type: Literal["ODE", "
|
|
217
|
+
eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
|
|
218
218
|
hyperparams: list[str],
|
|
219
219
|
hypernet_input_size: int,
|
|
220
220
|
key: PRNGKeyArray | None = None,
|
|
@@ -257,9 +257,9 @@ class HyperPINN(PINN):
|
|
|
257
257
|
eq_type
|
|
258
258
|
A string with three possibilities.
|
|
259
259
|
"ODE": the HyperPINN is called with one input `t`.
|
|
260
|
-
"
|
|
260
|
+
"PDEStatio": the HyperPINN is called with one input `x`, `x`
|
|
261
261
|
can be high dimensional.
|
|
262
|
-
"
|
|
262
|
+
"PDENonStatio": the HyperPINN is called with two inputs `t` and `x`, `x`
|
|
263
263
|
can be high dimensional.
|
|
264
264
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
265
265
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
jinns/nn/_mlp.py
CHANGED
|
@@ -95,7 +95,7 @@ class PINN_MLP(PINN):
|
|
|
95
95
|
def create(
|
|
96
96
|
cls,
|
|
97
97
|
*,
|
|
98
|
-
eq_type: Literal["ODE", "
|
|
98
|
+
eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
|
|
99
99
|
key: PRNGKeyArray | None = None,
|
|
100
100
|
eqx_network: eqx.nn.MLP | MLP | None = None,
|
|
101
101
|
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
|
|
@@ -130,9 +130,9 @@ class PINN_MLP(PINN):
|
|
|
130
130
|
eq_type
|
|
131
131
|
A string with three possibilities.
|
|
132
132
|
"ODE": the MLP is called with one input `t`.
|
|
133
|
-
"
|
|
133
|
+
"PDEStatio": the MLP is called with one input `x`, `x`
|
|
134
134
|
can be high dimensional.
|
|
135
|
-
"
|
|
135
|
+
"PDENonStatio": the MLP is called with two inputs `t` and `x`, `x`
|
|
136
136
|
can be high dimensional.
|
|
137
137
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
138
138
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
jinns/nn/_pinn.py
CHANGED
|
@@ -50,12 +50,12 @@ class PINN(AbstractPINN):
|
|
|
50
50
|
when the PINN is also used to output equation parameters for example
|
|
51
51
|
Note that it must be a slice and not an integer (a preprocessing of the
|
|
52
52
|
user provided argument takes care of it).
|
|
53
|
-
eq_type : Literal["ODE", "
|
|
53
|
+
eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
|
|
54
54
|
A string with three possibilities.
|
|
55
55
|
"ODE": the PINN is called with one input `t`.
|
|
56
|
-
"
|
|
56
|
+
"PDEStatio": the PINN is called with one input `x`, `x`
|
|
57
57
|
can be high dimensional.
|
|
58
|
-
"
|
|
58
|
+
"PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
|
|
59
59
|
can be high dimensional.
|
|
60
60
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
61
61
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
@@ -83,11 +83,11 @@ class PINN(AbstractPINN):
|
|
|
83
83
|
Raises
|
|
84
84
|
------
|
|
85
85
|
RuntimeError
|
|
86
|
-
If the parameter value for eq_type is not in `["ODE", "
|
|
87
|
-
"
|
|
86
|
+
If the parameter value for eq_type is not in `["ODE", "PDEStatio",
|
|
87
|
+
"PDENonStatio"]`
|
|
88
88
|
"""
|
|
89
89
|
|
|
90
|
-
eq_type: Literal["ODE", "
|
|
90
|
+
eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"] = eqx.field(
|
|
91
91
|
static=True, kw_only=True
|
|
92
92
|
)
|
|
93
93
|
slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
|
|
@@ -108,7 +108,7 @@ class PINN(AbstractPINN):
|
|
|
108
108
|
static: PINN = eqx.field(init=False, static=True)
|
|
109
109
|
|
|
110
110
|
def __post_init__(self, eqx_network):
|
|
111
|
-
if self.eq_type not in ["ODE", "
|
|
111
|
+
if self.eq_type not in ["ODE", "PDEStatio", "PDENonStatio"]:
|
|
112
112
|
raise RuntimeError("Wrong parameter value for eq_type")
|
|
113
113
|
# saving the static part of the model and initial parameters
|
|
114
114
|
|
jinns/nn/_ppinn.py
CHANGED
|
@@ -31,12 +31,12 @@ class PPINN_MLP(PINN):
|
|
|
31
31
|
when the PINN is also used to output equation parameters for example
|
|
32
32
|
Note that it must be a slice and not an integer (a preprocessing of the
|
|
33
33
|
user provided argument takes care of it).
|
|
34
|
-
eq_type : Literal["ODE", "
|
|
34
|
+
eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
|
|
35
35
|
A string with three possibilities.
|
|
36
36
|
"ODE": the PPINN is called with one input `t`.
|
|
37
|
-
"
|
|
37
|
+
"PDEStatio": the PPINN is called with one input `x`, `x`
|
|
38
38
|
can be high dimensional.
|
|
39
|
-
"
|
|
39
|
+
"PDENonStatio": the PPINN is called with two inputs `t` and `x`, `x`
|
|
40
40
|
can be high dimensional.
|
|
41
41
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
42
42
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
@@ -125,7 +125,7 @@ class PPINN_MLP(PINN):
|
|
|
125
125
|
cls,
|
|
126
126
|
*,
|
|
127
127
|
key: PRNGKeyArray | None = None,
|
|
128
|
-
eq_type: Literal["ODE", "
|
|
128
|
+
eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
|
|
129
129
|
eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
|
|
130
130
|
eqx_list_list: (
|
|
131
131
|
list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
|
|
@@ -158,9 +158,9 @@ class PPINN_MLP(PINN):
|
|
|
158
158
|
eq_type
|
|
159
159
|
A string with three possibilities.
|
|
160
160
|
"ODE": the PPINN MLP is called with one input `t`.
|
|
161
|
-
"
|
|
161
|
+
"PDEStatio": the PPINN MLP is called with one input `x`, `x`
|
|
162
162
|
can be high dimensional.
|
|
163
|
-
"
|
|
163
|
+
"PDENonStatio": the PPINN MLP is called with two inputs `t` and `x`, `x`
|
|
164
164
|
can be high dimensional.
|
|
165
165
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
166
166
|
of the dimension of `t` + the dimension of `x` or the output dimension
|