jinns-1.2.0-py3-none-any.whl → jinns-1.4.0-py3-none-any.whl
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
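
The headline change in this release is structural: the network classes leave jinns/utils/ (_pinn.py, _hyperpinn.py, _ppinn.py, _spinn.py are deleted) for the new jinns/nn/ subpackage, and the 1634-line jinns/data/_DataGenerators.py is split into one module per generator. A minimal migration sketch, assuming the 1.2.0 modules exported the same class names that the 1.4.0 internal imports visible further down use (e.g. from jinns.nn._pinn import PINN):

# jinns 1.2.0: network classes lived in the utils subpackage
# from jinns.utils._pinn import PINN
# from jinns.utils._spinn import SPINN

# jinns 1.4.0: the same classes now live in the new nn subpackage
from jinns.nn._pinn import PINN
from jinns.nn._spinn import SPINN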
jinns/loss/__init__.py
CHANGED
@@ -1,21 +1,18 @@
 from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
-from ._LossODE import LossODE
-from ._LossPDE import LossPDEStatio, LossPDENonStatio
+from ._LossODE import LossODE
+from ._LossPDE import LossPDEStatio, LossPDENonStatio
 from ._DynamicLoss import (
     GeneralizedLotkaVolterra,
     BurgersEquation,
     FPENonStatioLoss2D,
     OU_FPENonStatioLoss2D,
     FisherKPP,
-
-    NavierStokes2DStatio,
+    NavierStokesMassConservation2DStatio,
 )
 from ._loss_weights import (
     LossWeightsODE,
-    LossWeightsODEDict,
     LossWeightsPDENonStatio,
     LossWeightsPDEStatio,
-    LossWeightsPDEDict,
 )

 from ._operators import (
@@ -26,3 +23,28 @@ from ._operators import (
     vectorial_laplacian_fwd,
     vectorial_laplacian_rev,
 )
+
+__all__ = [
+    "DynamicLoss",
+    "ODE",
+    "PDEStatio",
+    "PDENonStatio",
+    "LossODE",
+    "LossPDEStatio",
+    "LossPDENonStatio",
+    "GeneralizedLotkaVolterra",
+    "BurgersEquation",
+    "FPENonStatioLoss2D",
+    "OU_FPENonStatioLoss2D",
+    "FisherKPP",
+    "NavierStokesMassConservation2DStatio",
+    "LossWeightsODE",
+    "LossWeightsPDEStatio",
+    "LossWeightsPDENonStatio",
+    "divergence_fwd",
+    "divergence_rev",
+    "laplacian_fwd",
+    "laplacian_rev",
+    "vectorial_laplacian_fwd",
+    "vectorial_laplacian_rev",
+]
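
With __all__ now defined, the public surface of jinns.loss is explicit: the 22 names above, minus the SystemLoss-era exports (LossWeightsODEDict, LossWeightsPDEDict) and the old NavierStokes2DStatio name. A short usage sketch limited to names that appear in __all__:

# explicit imports of public names enumerated in __all__
from jinns.loss import LossODE, LossWeightsODE, BurgersEquation

# wildcard imports are now restricted to the same 22 names
from jinns.loss import *  # noqa: F403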
jinns/loss/_abstract_loss.py
ADDED
@@ -0,0 +1,15 @@
+import abc
+from jaxtyping import Array
+import equinox as eqx
+
+
+class AbstractLoss(eqx.Module):
+    """
+    Basically just a way to add a __call__ to an eqx.Module.
+    The way to go for correct type hints apparently
+    https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
+    """
+
+    @abc.abstractmethod
+    def __call__(self, *_, **__) -> Array:
+        pass
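
AbstractLoss applies the pattern its docstring links to: an eqx.Module with an abstract __call__ returning an Array, so concrete losses type-check cleanly. A minimal hypothetical subclass (not part of the package) showing how it is meant to be specialized:

import jax.numpy as jnp
from jaxtyping import Array

from jinns.loss._abstract_loss import AbstractLoss


class SquaredResidualLoss(AbstractLoss):  # hypothetical, for illustration only
    # fields are declared exactly as on any other eqx.Module
    weight: float = 1.0

    def __call__(self, residuals: Array) -> Array:
        # concrete implementation of the abstract __call__, returning a scalar Array
        return self.weight * jnp.mean(residuals**2)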
jinns/loss/_boundary_conditions.py
CHANGED
@@ -7,31 +7,31 @@ from __future__ import (
 ) # https://docs.python.org/3/library/typing.html#constant

 from typing import TYPE_CHECKING, Callable
+from jaxtyping import Array, Float
 import jax
 import jax.numpy as jnp
 from jax import vmap, grad
-import equinox as eqx
 from jinns.utils._utils import get_grid, _subtract_with_check
-from jinns.data._Batchs import
-from jinns.
-from jinns.
+from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn import SPINN

 if TYPE_CHECKING:
-    from jinns.
+    from jinns.parameters._params import Params
+    from jinns.utils._types import BoundaryConditionFun
+    from jinns.nn._abstract_pinn import AbstractPINN


 def _compute_boundary_loss(
     boundary_condition_type: str,
-    f:
-        [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
-    ],
+    f: BoundaryConditionFun,
     batch: PDEStatioBatch | PDENonStatioBatch,
-    u:
-    params:
+    u: AbstractPINN,
+    params: Params[Array],
     facet: int,
     dim_to_apply: slice,
     vmap_in_axes: tuple,
-) ->
+) -> Float[Array, " "]:
     r"""A generic function that will compute the mini-batch MSE of a
     boundary condition in the stationary case, resp. non-stationary, given by:

@@ -67,7 +67,7 @@ def _compute_boundary_loss(
     u
         a PINN
     params
-        Params
+        Params
     facet
         An integer which represents the id of the facet which is currently
         considered (in the order provided by the DataGenerator which is fixed)
@@ -96,15 +96,15 @@ def _compute_boundary_loss(

 def boundary_dirichlet(
     f: Callable[
-        [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
+        [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
     ],
     batch: PDEStatioBatch | PDENonStatioBatch,
-    u:
-    params: Params
+    u: AbstractPINN,
+    params: Params[Array],
     facet: int,
     dim_to_apply: slice,
     vmap_in_axes: tuple,
-) ->
+) -> Float[Array, " "]:
     r"""
     This omega boundary condition enforces a solution that is equal to `f`
     at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
@@ -135,6 +135,7 @@ def boundary_dirichlet(
     vmap_in_axes
         A tuple object which specifies the in_axes of the vmapping
     """
+    assert batch.border_batch is not None
     batch_array = batch.border_batch
     batch_array = batch_array[..., facet]

@@ -168,15 +169,15 @@ def boundary_dirichlet(

 def boundary_neumann(
     f: Callable[
-        [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
+        [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
     ],
     batch: PDEStatioBatch | PDENonStatioBatch,
-    u:
-    params: Params
+    u: AbstractPINN,
+    params: Params[Array],
     facet: int,
     dim_to_apply: slice,
     vmap_in_axes: tuple,
-) ->
+) -> Float[Array, " "]:
     r"""
     This omega boundary condition enforces a solution where $\nabla u\cdot
     n$ is equal to `f` at the cartesian product of `time_batch` x `omega
@@ -208,6 +209,7 @@ def boundary_neumann(
     vmap_in_axes
         A tuple object which specifies the in_axes of the vmapping
     """
+    assert batch.border_batch is not None
     batch_array = batch.border_batch
     batch_array = batch_array[..., facet]

@@ -223,7 +225,6 @@
     n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])

     if isinstance(u, PINN):
-
         u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])

         if u.eq_type == "statio_PDE":
jinns/loss/_loss_utils.py
CHANGED
@@ -6,50 +6,53 @@ from __future__ import (
     annotations,
 ) # https://docs.python.org/3/library/typing.html#constant

-from typing import TYPE_CHECKING, Callable,
+from typing import TYPE_CHECKING, Callable, TypeGuard
+from types import EllipsisType
 import jax
 import jax.numpy as jnp
 from jax import vmap
-import
-from jaxtyping import Float, Array, PyTree
+from jaxtyping import Float, Array

 from jinns.loss._boundary_conditions import (
     _compute_boundary_loss,
 )
 from jinns.utils._utils import _subtract_with_check, get_grid
-from jinns.data.
+from jinns.data._utils import make_cartesian_product
 from jinns.parameters._params import _get_vmap_in_axes_params
-from jinns.
-from jinns.
-from jinns.
-from jinns.data._Batchs import
-from jinns.parameters._params import Params
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn import SPINN
+from jinns.nn._hyperpinn import HyperPINN
+from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+from jinns.parameters._params import Params

 if TYPE_CHECKING:
-    from jinns.utils._types import
+    from jinns.utils._types import BoundaryConditionFun
+    from jinns.nn._abstract_pinn import AbstractPINN


 def dynamic_loss_apply(
-    dyn_loss:
-    u:
+    dyn_loss: Callable,
+    u: AbstractPINN,
     batch: (
-        Float[Array, "batch_size 1"]
-        | Float[Array, "batch_size dim"]
-        | Float[Array, "batch_size 1+dim"]
+        Float[Array, " batch_size 1"]
+        | Float[Array, " batch_size dim"]
+        | Float[Array, " batch_size 1+dim"]
     ),
-    params: Params
-    vmap_axes: tuple[int | None
-    loss_weight: float | Float[Array, "dyn_loss_dimension"],
-    u_type: PINN |
-) ->
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
+    loss_weight: float | Float[Array, " dyn_loss_dimension"],
+    u_type: PINN | HyperPINN | None = None,
+) -> Float[Array, " "]:
     """
     Sometimes when u is a lambda function a or dict we do not have access to
     its type here, hence the last argument
     """
-    if u_type == PINN or u_type ==
+    if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
         v_dyn_loss = vmap(
             lambda batch, params: dyn_loss(
-                batch,
+                batch,
+                u,
+                params,  # we must place the params at the end
             ),
             vmap_axes,
             0,
@@ -66,36 +69,38 @@ def dynamic_loss_apply(


 def normalization_loss_apply(
-    u:
+    u: AbstractPINN,
     batches: (
-        tuple[Float[Array, "nb_norm_samples dim"]]
+        tuple[Float[Array, " nb_norm_samples dim"]]
         | tuple[
-            Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
+            Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
         ]
     ),
-    params: Params
-    vmap_axes_params: tuple[int | None
-
+    params: Params[Array],
+    vmap_axes_params: tuple[Params[int | None] | None],
+    norm_weights: Float[Array, " nb_norm_samples"],
     loss_weight: float,
-) ->
+) -> Float[Array, " "]:
     """
     Note the squeezing on each result. We expect unidimensional *PINN since
     they represent probability distributions
     """
-    if isinstance(u, (PINN,
+    if isinstance(u, (PINN, HyperPINN)):
         if len(batches) == 1:
             v_u = vmap(
-                lambda b: u(b)[u.slice_solution],
+                lambda *b: u(*b)[u.slice_solution],
                 (0,) + vmap_axes_params,
                 0,
             )
             res = v_u(*batches, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
+            # Monte-Carlo integration using importance sampling
             mse_norm_loss = loss_weight * (
-                jnp.abs(jnp.mean(res.squeeze()
+                jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
             )
         else:
             # NOTE this cartesian product is costly
-
+            batch_cart_prod = make_cartesian_product(
                 batches[0],
                 batches[1],
             ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
@@ -106,21 +111,24 @@ def normalization_loss_apply(
                 ),
                 in_axes=(0,) + vmap_axes_params,
             )
-            res = v_u(
-
+            res = v_u(batch_cart_prod, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
+            # For all times t, we perform an integration. Then we average the
+            # losses over times.
             mse_norm_loss = loss_weight * jnp.mean(
-                jnp.abs(jnp.mean(res.squeeze(), axis=-1)
+                jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
             )
     elif isinstance(u, SPINN):
         if len(batches) == 1:
             res = u(*batches, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             mse_norm_loss = (
                 loss_weight
                 * jnp.abs(
                     jnp.mean(
                         res.squeeze(),
                     )
-                    *
+                    * norm_weights
                     - 1
                 )
                 ** 2
@@ -134,14 +142,15 @@ def normalization_loss_apply(
                 ),
                 params,
             )
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             # the outer mean() below is for the times stamps
             mse_norm_loss = loss_weight * jnp.mean(
                 jnp.abs(
                     jnp.mean(
                         res.squeeze(),
-                        axis=(d + 1 for d in range(res.ndim - 2)),
+                        axis=list(d + 1 for d in range(res.ndim - 2)),
                     )
-                    *
+                    * norm_weights
                     - 1
                 )
                 ** 2
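
The norm_weights argument threaded through the hunks above turns the plain sample mean into importance-sampled Monte-Carlo integration: for samples x_i drawn from a proposal q, the mean of u(x_i) * w_i with w_i = 1/q(x_i) estimates the integral of u, and the loss penalizes its squared distance from 1 since u is expected to be a probability density. A standalone sketch of the estimator, assuming uniform sampling on [0, L]^d so that every weight equals the domain volume:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
L, d, n = 2.0, 2, 1024
# x_i ~ Uniform([0, L]^d), hence q(x_i) = 1 / L**d
x = jax.random.uniform(key, (n, d), minval=0.0, maxval=L)

u_vals = jnp.exp(-x.sum(axis=1))  # stand-in for the PINN outputs u(x_i)
norm_weights = L**d  # importance weights 1 / q(x_i)

# same form as the new normalization loss: |mean(u * w) - 1| ** 2
integral_estimate = jnp.mean(u_vals * norm_weights)
norm_loss = jnp.abs(integral_estimate - 1) ** 2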
@@ -153,18 +162,34 @@ def normalization_loss_apply(


 def boundary_condition_apply(
-    u:
+    u: AbstractPINN,
     batch: PDEStatioBatch | PDENonStatioBatch,
-    params: Params
-    omega_boundary_fun:
-    omega_boundary_condition: str,
-    omega_boundary_dim:
-    loss_weight: float | Float[Array, "boundary_cond_dim"],
-) ->
-
+    params: Params[Array],
+    omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
+    omega_boundary_condition: str | dict[str, str],
+    omega_boundary_dim: slice | dict[str, slice],
+    loss_weight: float | Float[Array, " boundary_cond_dim"],
+) -> Float[Array, " "]:
+    assert batch.border_batch is not None
     vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)

-
+    def _check_tuple_of_dict(
+        val,
+    ) -> TypeGuard[
+        tuple[
+            dict[str, BoundaryConditionFun],
+            dict[str, BoundaryConditionFun],
+            dict[str, BoundaryConditionFun],
+        ]
+    ]:
+        return all(isinstance(x, dict) for x in val)
+
+    omega_boundary_dicts = (
+        omega_boundary_condition,
+        omega_boundary_fun,
+        omega_boundary_dim,
+    )
+    if _check_tuple_of_dict(omega_boundary_dicts):
         # We must create the facet tree dictionary as we do not have the
         # enumerate from the for loop to pass the id integer
         if batch.border_batch.shape[-1] == 2:
@@ -186,10 +211,10 @@ def boundary_condition_apply(
                     )
                 )
             ),
-            omega_boundary_condition,
-            omega_boundary_fun,
+            omega_boundary_dicts[0],  # omega_boundary_condition,
+            omega_boundary_dicts[1],  # omega_boundary_fun,
             facet_tree,
-            omega_boundary_dim,
+            omega_boundary_dicts[2],  # omega_boundary_dim,
             is_leaf=lambda x: x is None,
         ) # when exploring leaves with None value (no condition) the returned
         # mse is None and we get rid of the None leaves of b_losses_by_facet
@@ -202,13 +227,13 @@ def boundary_condition_apply(
             lambda fa: jnp.mean(
                 loss_weight
                 * _compute_boundary_loss(
-
-
+                    omega_boundary_dicts[0],  # type: ignore -> need TypeIs from 3.13
+                    omega_boundary_dicts[1],  # type: ignore -> need TypeIs from 3.13
                     batch,
                     u,
                     params,
                     fa,
-
+                    omega_boundary_dicts[2],  # type: ignore -> need TypeIs from 3.13
                     vmap_in_axes,
                 )
             ),
@@ -221,22 +246,21 @@


 def observations_loss_apply(
-    u:
-
-    params: Params
-    vmap_axes: tuple[int | None
-    observed_values: Float[Array, "
-    loss_weight: float | Float[Array, "observation_dim"],
-    obs_slice: slice,
-) ->
-
-    if isinstance(u, (PINN, HYPERPINN)):
+    u: AbstractPINN,
+    batch: Float[Array, " obs_batch_size input_dim"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
+    observed_values: Float[Array, " obs_batch_size observation_dim"],
+    loss_weight: float | Float[Array, " observation_dim"],
+    obs_slice: EllipsisType | slice | None,
+) -> Float[Array, " "]:
+    if isinstance(u, (PINN, HyperPINN)):
         v_u = vmap(
             lambda *args: u(*args)[u.slice_solution],
             vmap_axes,
             0,
         )
-        val = v_u(
+        val = v_u(batch, params)[:, obs_slice]
         mse_observation_loss = jnp.mean(
             jnp.sum(
                 loss_weight
@@ -255,16 +279,17 @@


 def initial_condition_apply(
-    u:
-    omega_batch: Float[Array, "dimension"],
-    params: Params
-    vmap_axes: tuple[int | None
+    u: AbstractPINN,
+    omega_batch: Float[Array, " dimension"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
     initial_condition_fun: Callable,
-
-
+    t0: Float[Array, " 1"],
+    loss_weight: float | Float[Array, " initial_condition_dimension"],
+) -> Float[Array, " "]:
     n = omega_batch.shape[0]
-    t0_omega_batch = jnp.concatenate([jnp.
-    if isinstance(u, (PINN,
+    t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
+    if isinstance(u, (PINN, HyperPINN)):
         v_u_t0 = vmap(
             lambda t0_x, params: _subtract_with_check(
                 initial_condition_fun(t0_x[1:]),
@@ -296,103 +321,3 @@ def initial_condition_apply(
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_initial_condition
-
-
-def constraints_system_loss_apply(
-    u_constraints_dict: Dict,
-    batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
-    params_dict: ParamsDict,
-    loss_weights: Dict[str, float | Array],
-    loss_weight_struct: PyTree,
-):
-    """
-    Same function for systemlossODE and systemlossPDE!
-    """
-    # Transpose so we have each u_dict as outer structure and the
-    # associated loss_weight as inner structure
-    loss_weights_T = jax.tree_util.tree_transpose(
-        jax.tree_util.tree_structure(loss_weight_struct),
-        jax.tree_util.tree_structure(loss_weights["initial_condition"]),
-        loss_weights,
-    )
-
-    if isinstance(params_dict.nn_params, dict):
-
-        def apply_u_constraint(
-            u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
-        ):
-            res_dict_for_u = u_constraint.evaluate(
-                Params(
-                    nn_params=nn_params,
-                    eq_params=eq_params,
-                ),
-                append_obs_batch(batch, obs_batch_u),
-            )[1]
-            res_dict_ponderated = jax.tree_util.tree_map(
-                lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
-            )
-            return res_dict_ponderated
-
-        # Note in the case of multiple PINNs, batch.obs_batch_dict is a dict
-        # with keys corresponding to the PINN and value correspondinf to an
-        # original obs_batch_dict. Hence the tree mapping also interates over
-        # batch.obs_batch_dict
-        res_dict = jax.tree_util.tree_map(
-            apply_u_constraint,
-            u_constraints_dict,
-            params_dict.nn_params,
-            (
-                params_dict.eq_params
-                if params_dict.eq_params.keys() == params_dict.nn_params.keys()
-                else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
-            ),  # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
-            loss_weights_T,
-            batch.obs_batch_dict,
-            is_leaf=lambda x: (
-                not isinstance(x, dict)  # to not traverse more than the first
-                # outer dict of the pytrees passed to the function. This will
-                # work because u_constraints_dict is a dict of Losses, and it
-                # thus stops the traversing of other dict too
-            ),
-        )
-    # TODO try to get rid of this condition?
-    else:
-
-        def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
-            res_dict_for_u = u_constraint.evaluate(
-                params_dict,
-                append_obs_batch(batch, obs_batch_u),
-            )[1]
-            res_dict_ponderated = jax.tree_util.tree_map(
-                lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
-            )
-            return res_dict_ponderated
-
-        res_dict = jax.tree_util.tree_map(
-            apply_u_constraint, u_constraints_dict, loss_weights_T, batch.obs_batch_dict
-        )
-
-    # Transpose back so we have mses as outer structures and their values
-    # for each u_dict as inner structures. The tree_leaves transforms the
-    # inner structure into a list so we can catch is as leaf it the
-    # tree_map below
-    res_dict = jax.tree_util.tree_transpose(
-        jax.tree_util.tree_structure(
-            jax.tree_util.tree_leaves(loss_weights["initial_condition"])
-        ),
-        jax.tree_util.tree_structure(loss_weight_struct),
-        res_dict,
-    )
-    # For each mse, sum their values on each u_dict
-    res_dict = jax.tree_util.tree_map(
-        lambda mse: jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(mse)
-        ),
-        res_dict,
-        is_leaf=lambda x: isinstance(x, list),
-    )
-    # Total loss
-    total_loss = jax.tree_util.tree_reduce(
-        lambda x, y: x + y, jax.tree_util.tree_leaves(res_dict)
-    )
-    return total_loss, res_dict
jinns/loss/_loss_weights.py
CHANGED
@@ -2,58 +2,26 @@
 Formalize the loss weights data structure
 """

-from typing import Dict
 from jaxtyping import Array, Float
 import equinox as eqx


 class LossWeightsODE(eqx.Module):
-
-
-
-    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-
-
-class LossWeightsODEDict(eqx.Module):
-
-    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=None)
-    initial_condition: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=None
-    )
-    observations: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=None
-    )
+    dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)
+    observations: Array | Float = eqx.field(kw_only=True, default=0.0)


 class LossWeightsPDEStatio(eqx.Module):
-
-
-
-
-    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    observations: Array | Float = eqx.field(kw_only=True, default=0.0)


 class LossWeightsPDENonStatio(eqx.Module):
-
-
-
-
-
-    initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-
-
-class LossWeightsPDEDict(eqx.Module):
-    """
-    Only one type of LossWeights data structure for the SystemLossPDE:
-    Include the initial condition always for the code to be more generic
-    """
-
-    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    norm_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    boundary_loss: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=1.0
-    )
-    observations: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    initial_condition: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=1.0
-    )
+    dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    observations: Array | Float = eqx.field(kw_only=True, default=0.0)
+    initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)