jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/__init__.py
CHANGED
|
@@ -1,22 +1,20 @@
|
|
|
1
1
|
from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
|
|
2
|
-
from ._LossODE import LossODE
|
|
3
|
-
from ._LossPDE import LossPDEStatio, LossPDENonStatio
|
|
2
|
+
from ._LossODE import LossODE
|
|
3
|
+
from ._LossPDE import LossPDEStatio, LossPDENonStatio
|
|
4
4
|
from ._DynamicLoss import (
|
|
5
5
|
GeneralizedLotkaVolterra,
|
|
6
6
|
BurgersEquation,
|
|
7
7
|
FPENonStatioLoss2D,
|
|
8
8
|
OU_FPENonStatioLoss2D,
|
|
9
9
|
FisherKPP,
|
|
10
|
-
|
|
11
|
-
NavierStokes2DStatio,
|
|
10
|
+
NavierStokesMassConservation2DStatio,
|
|
12
11
|
)
|
|
13
12
|
from ._loss_weights import (
|
|
14
13
|
LossWeightsODE,
|
|
15
|
-
LossWeightsODEDict,
|
|
16
14
|
LossWeightsPDENonStatio,
|
|
17
15
|
LossWeightsPDEStatio,
|
|
18
|
-
LossWeightsPDEDict,
|
|
19
16
|
)
|
|
17
|
+
from ._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
|
|
20
18
|
|
|
21
19
|
from ._operators import (
|
|
22
20
|
divergence_fwd,
|
|
@@ -26,3 +24,31 @@ from ._operators import (
|
|
|
26
24
|
vectorial_laplacian_fwd,
|
|
27
25
|
vectorial_laplacian_rev,
|
|
28
26
|
)
|
|
27
|
+
|
|
28
|
+
__all__ = [
    # abstract dynamic loss classes
    "DynamicLoss",
    "ODE",
    "PDEStatio",
    "PDENonStatio",
    # loss objects
    "LossODE",
    "LossPDEStatio",
    "LossPDENonStatio",
    # ready-made dynamic losses
    "GeneralizedLotkaVolterra",
    "BurgersEquation",
    "FPENonStatioLoss2D",
    "OU_FPENonStatioLoss2D",
    "FisherKPP",
    "NavierStokesMassConservation2DStatio",
    # loss weights containers
    "LossWeightsODE",
    "LossWeightsPDEStatio",
    "LossWeightsPDENonStatio",
    # differential operators
    "divergence_fwd",
    "divergence_rev",
    "laplacian_fwd",
    "laplacian_rev",
    "vectorial_laplacian_fwd",
    "vectorial_laplacian_rev",
    # loss weight update schemes
    "soft_adapt",
    "lr_annealing",
    "ReLoBRaLo",
]
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import TYPE_CHECKING, Self, Literal, Callable
|
|
5
|
+
from jaxtyping import Array, PyTree, Key
|
|
6
|
+
import equinox as eqx
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import optax
|
|
10
|
+
from jinns.loss._loss_weights import AbstractLossWeights
|
|
11
|
+
from jinns.parameters._params import Params
|
|
12
|
+
from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from jinns.utils._types import AnyLossComponents, AnyBatch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AbstractLoss(eqx.Module):
    """
    Base class of the jinns loss objects.

    Subclasses must implement ``__call__`` and ``evaluate_by_terms``. This
    class provides the machinery shared by the losses: combining the
    individual loss terms (and their gradients) with ``loss_weights``, and
    updating those weights with the scheme named by ``update_weight_method``.

    About the call:
    https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/

    Attributes
    ----------
    loss_weights
        The weights applied to the individual loss terms.
    update_weight_method
        Keyword-only and static. One of ``"soft_adapt"``, ``"lr_annealing"``,
        ``"ReLoBRaLo"`` or ``None`` (the default). Selects the scheme used by
        ``update_weights``.
    """

    loss_weights: AbstractLossWeights
    update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
        eqx.field(kw_only=True, default=None, static=True)
    )

    @abc.abstractmethod
    def __call__(self, *_, **__) -> Array:
        """Evaluate the loss; the signature is defined by each subclass."""

    @abc.abstractmethod
    def evaluate_by_terms(
        self, params: Params[Array], batch: AnyBatch
    ) -> tuple[AnyLossComponents, AnyLossComponents]:
        """
        Evaluate the loss term by term on ``batch``.

        Returns two loss-component containers. NOTE(review): presumably the
        per-term loss values and the per-term gradients (cf.
        ``get_gradients``); confirm in the subclasses.
        """

    def get_gradients(
        self, fun: Callable[[Params[Array]], Array] | None, params: Params[Array]
    ) -> tuple[Array, Params[Array]] | tuple[None, None]:
        """
        Evaluate ``fun`` and its gradient with respect to ``params``.

        ``params`` is expected to be already filtered with the derivative keys.

        Parameters
        ----------
        fun
            A scalar function of ``params``, or ``None`` for an unused loss
            term.
        params
            The parameters at which ``fun`` and its gradient are evaluated.

        Returns
        -------
        ``(value, gradient)`` as given by ``jax.value_and_grad`` (the gradient
        has the structure of ``params``), or ``(None, None)`` if ``fun`` is
        ``None``.
        """
        if fun is None:
            # Unused loss term: propagate None so it is later dropped by
            # jax.tree.leaves in the ponderate_and_sum_* methods.
            return None, None
        loss_val, grads = jax.value_and_grad(fun)(params)
        return loss_val, grads

    def ponderate_and_sum_loss(self, terms):
        """
        Get the total loss from the individual loss terms and the weights.

        ``jax.tree.leaves`` is needed to get rid of ``None`` from non-used
        loss terms (``None`` is an empty pytree and yields no leaf). Weights
        and terms are then paired by their position in leaf order.

        Parameters
        ----------
        terms
            A pytree of the individual loss values.

        Returns
        -------
        ``sum(weight * term)`` over the paired leaves.

        Raises
        ------
        ValueError
            If the number of weight leaves differs from the number of loss
            term leaves.
        """
        weights = jax.tree.leaves(self.loss_weights, is_leaf=eqx.is_inexact_array)
        terms = jax.tree.leaves(terms, is_leaf=eqx.is_inexact_array)
        if len(weights) != len(terms):
            raise ValueError(
                "The numbers of declared loss weights and "
                "declared loss terms do not concord, "
                f"got {len(weights)} and {len(terms)}"
            )
        return jnp.sum(jnp.array(weights) * jnp.array(terms))

    def ponderate_and_sum_gradient(self, terms):
        """
        Get the total gradient from the individual loss gradients and weights,
        for each parameter.

        ``jax.tree.leaves`` is needed to get rid of ``None`` from non-used
        loss terms. Weights and gradients are paired by their position in
        leaf order, as in ``ponderate_and_sum_loss``.

        Parameters
        ----------
        terms
            A pytree whose ``Params`` nodes are the gradients of the
            individual loss terms.

        Returns
        -------
        A ``Params`` structure holding, for each parameter leaf, the sum over
        the loss terms of ``weight * gradient``.

        Raises
        ------
        ValueError
            If the number of weight leaves differs from the number of
            gradient ``Params`` structures.
        """
        weights = jax.tree.leaves(self.loss_weights, is_leaf=eqx.is_inexact_array)
        # One Params structure of gradients per (non None) loss term
        grads = jax.tree.leaves(terms, is_leaf=lambda x: isinstance(x, Params))
        if len(weights) != len(grads):
            raise ValueError(
                "The numbers of declared loss weights and "
                "declared loss gradient terms do not concord, "
                f"got {len(weights)} and {len(grads)}"
            )
        # We need one Params structure per loss term, filled with its weight
        weights_pytree = jax.tree.map(
            lambda w: optax.tree_utils.tree_full_like(grads[0], w), weights
        )
        # Now we can multiply leaf by leaf
        weighted_grads = jax.tree.map(
            lambda w, g: w * g, weights_pytree, grads, is_leaf=eqx.is_inexact_array
        )
        # Sum, for each parameter leaf, over the loss terms
        return jax.tree.map(
            lambda *per_term: jnp.sum(jnp.array(per_term), axis=0),
            *weighted_grads,
            is_leaf=eqx.is_inexact_array,
        )

    def update_weights(
        self: Self,
        iteration_nb: int,
        loss_terms: PyTree,
        stored_loss_terms: PyTree,
        grad_terms: PyTree,
        key: Key,
    ) -> Self:
        """
        Update the loss weights according to a predefined scheme.

        The scheme is selected by ``self.update_weight_method``:
        ``soft_adapt`` and ``ReLoBRaLo`` use the current and stored loss
        terms (``ReLoBRaLo`` also uses ``key``), and ``lr_annealing`` uses
        ``grad_terms``.

        Returns
        -------
        A copy of ``self`` whose non ``None`` loss weights are replaced by the
        new ones.

        Raises
        ------
        ValueError
            If ``update_weight_method`` is ``None`` or not one of the
            supported schemes.
        """
        if self.update_weight_method == "soft_adapt":
            new_weights = soft_adapt(
                self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
            )
        elif self.update_weight_method == "lr_annealing":
            new_weights = lr_annealing(self.loss_weights, grad_terms)
        elif self.update_weight_method == "ReLoBRaLo":
            new_weights = ReLoBRaLo(
                self.loss_weights, iteration_nb, loss_terms, stored_loss_terms, key
            )
        else:
            raise ValueError("update_weight_method for loss weights not implemented")

        # Below we update the non None entries in the PyTree self.loss_weights.
        # We directly get the non None entries because None is not treated as
        # a leaf.
        return eqx.tree_at(
            lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
        )
|
|
@@ -7,31 +7,31 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
from typing import TYPE_CHECKING, Callable
|
|
10
|
+
from jaxtyping import Array, Float
|
|
10
11
|
import jax
|
|
11
12
|
import jax.numpy as jnp
|
|
12
13
|
from jax import vmap, grad
|
|
13
|
-
import equinox as eqx
|
|
14
14
|
from jinns.utils._utils import get_grid, _subtract_with_check
|
|
15
|
-
from jinns.data._Batchs import
|
|
15
|
+
from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
16
16
|
from jinns.nn._pinn import PINN
|
|
17
17
|
from jinns.nn._spinn import SPINN
|
|
18
18
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
|
-
from jinns.
|
|
20
|
+
from jinns.parameters._params import Params
|
|
21
|
+
from jinns.utils._types import BoundaryConditionFun
|
|
22
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
def _compute_boundary_loss(
|
|
24
26
|
boundary_condition_type: str,
|
|
25
|
-
f:
|
|
26
|
-
[Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
|
|
27
|
-
],
|
|
27
|
+
f: BoundaryConditionFun,
|
|
28
28
|
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
29
|
-
u:
|
|
30
|
-
params:
|
|
29
|
+
u: AbstractPINN,
|
|
30
|
+
params: Params[Array],
|
|
31
31
|
facet: int,
|
|
32
32
|
dim_to_apply: slice,
|
|
33
33
|
vmap_in_axes: tuple,
|
|
34
|
-
) ->
|
|
34
|
+
) -> Float[Array, " "]:
|
|
35
35
|
r"""A generic function that will compute the mini-batch MSE of a
|
|
36
36
|
boundary condition in the stationary case, resp. non-stationary, given by:
|
|
37
37
|
|
|
@@ -67,7 +67,7 @@ def _compute_boundary_loss(
|
|
|
67
67
|
u
|
|
68
68
|
a PINN
|
|
69
69
|
params
|
|
70
|
-
Params
|
|
70
|
+
Params
|
|
71
71
|
facet
|
|
72
72
|
An integer which represents the id of the facet which is currently
|
|
73
73
|
considered (in the order provided by the DataGenerator which is fixed)
|
|
@@ -96,15 +96,15 @@ def _compute_boundary_loss(
|
|
|
96
96
|
|
|
97
97
|
def boundary_dirichlet(
|
|
98
98
|
f: Callable[
|
|
99
|
-
[Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
|
|
99
|
+
[Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
|
|
100
100
|
],
|
|
101
101
|
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
102
|
-
u:
|
|
103
|
-
params: Params
|
|
102
|
+
u: AbstractPINN,
|
|
103
|
+
params: Params[Array],
|
|
104
104
|
facet: int,
|
|
105
105
|
dim_to_apply: slice,
|
|
106
106
|
vmap_in_axes: tuple,
|
|
107
|
-
) ->
|
|
107
|
+
) -> Float[Array, " "]:
|
|
108
108
|
r"""
|
|
109
109
|
This omega boundary condition enforces a solution that is equal to `f`
|
|
110
110
|
at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
|
|
@@ -135,6 +135,7 @@ def boundary_dirichlet(
|
|
|
135
135
|
vmap_in_axes
|
|
136
136
|
A tuple object which specifies the in_axes of the vmapping
|
|
137
137
|
"""
|
|
138
|
+
assert batch.border_batch is not None
|
|
138
139
|
batch_array = batch.border_batch
|
|
139
140
|
batch_array = batch_array[..., facet]
|
|
140
141
|
|
|
@@ -168,15 +169,15 @@ def boundary_dirichlet(
|
|
|
168
169
|
|
|
169
170
|
def boundary_neumann(
|
|
170
171
|
f: Callable[
|
|
171
|
-
[Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
|
|
172
|
+
[Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
|
|
172
173
|
],
|
|
173
174
|
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
174
|
-
u:
|
|
175
|
-
params: Params
|
|
175
|
+
u: AbstractPINN,
|
|
176
|
+
params: Params[Array],
|
|
176
177
|
facet: int,
|
|
177
178
|
dim_to_apply: slice,
|
|
178
179
|
vmap_in_axes: tuple,
|
|
179
|
-
) ->
|
|
180
|
+
) -> Float[Array, " "]:
|
|
180
181
|
r"""
|
|
181
182
|
This omega boundary condition enforces a solution where $\nabla u\cdot
|
|
182
183
|
n$ is equal to `f` at the cartesian product of `time_batch` x `omega
|
|
@@ -208,6 +209,7 @@ def boundary_neumann(
|
|
|
208
209
|
vmap_in_axes
|
|
209
210
|
A tuple object which specifies the in_axes of the vmapping
|
|
210
211
|
"""
|
|
212
|
+
assert batch.border_batch is not None
|
|
211
213
|
batch_array = batch.border_batch
|
|
212
214
|
batch_array = batch_array[..., facet]
|
|
213
215
|
|
|
@@ -223,7 +225,6 @@ def boundary_neumann(
|
|
|
223
225
|
n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
|
|
224
226
|
|
|
225
227
|
if isinstance(u, PINN):
|
|
226
|
-
|
|
227
228
|
u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])
|
|
228
229
|
|
|
229
230
|
if u.eq_type == "statio_PDE":
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import TypeVar, Generic
|
|
2
|
+
from dataclasses import fields
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
|
|
5
|
+
# Generic type of the value stored in each field of the loss component
# containers below.
T = TypeVar("T")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class XDEComponentsAbstract(eqx.Module, Generic[T]):
    """
    Provides a template for ODE and PDE loss components with generic types.

    One can inherit to specialize and add methods and attributes; each field
    declared by a subclass holds one component of type ``T``.
    We do not enforce keyword-only fields to avoid being too verbose: the
    subclasses can then be instantiated positionally, like a tuple.
    """

    def items(self):
        """
        Iterate over the components like a dictionary.

        Practical and retrocompatible with old code when loss components were
        dictionaries.

        Returns
        -------
        A ``dict_items`` view of ``(field name, value)`` pairs, in field
        declaration order, skipping the fields whose value is ``None``.
        """
        # Read each field once, then drop the unused (None) components
        pairs = ((f.name, getattr(self, f.name)) for f in fields(self))
        return {name: value for name, value in pairs if value is not None}.items()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ODEComponents(XDEComponentsAbstract[T]):
    """
    Container holding one ``T`` entry per ODE loss term.

    Positional construction follows the field order:
    ``ODEComponents(dyn_loss, initial_condition, observations)``.
    """

    dyn_loss: T
    initial_condition: T
    observations: T
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PDEStatioComponents(XDEComponentsAbstract[T]):
    """
    Container holding one ``T`` entry per stationary PDE loss term.

    Positional construction follows the field order:
    ``PDEStatioComponents(dyn_loss, norm_loss, boundary_loss, observations)``.
    """

    dyn_loss: T
    norm_loss: T
    boundary_loss: T
    observations: T
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PDENonStatioComponents(PDEStatioComponents[T]):
    """
    Container holding one ``T`` entry per non-stationary PDE loss term.

    Extends the stationary components with ``initial_condition``. Inherited
    fields come first, so positional construction is
    ``PDENonStatioComponents(dyn_loss, norm_loss, boundary_loss,
    observations, initial_condition)``.
    """

    initial_condition: T
|