jinns-1.3.0-py3-none-any.whl → jinns-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
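The headline change in `jinns/data`: the 1634-line `_DataGenerators.py` module is removed and split into the per-generator modules listed above (`_DataGeneratorODE.py`, `_CubicMeshPDEStatio.py`, `_CubicMeshPDENonStatio.py`, `_DataGeneratorObservations.py`, `_DataGeneratorParameter.py`), with `jinns/data/__init__.py` updated to re-export them. A hedged sketch of what downstream imports presumably look like after the split; the class names are inferred from the module names and are not shown in this diff:

```python
# Assumed re-exports after the _DataGenerators.py split; class names are
# inferred from the new module names, not confirmed by this diff.
from jinns.data import (
    DataGeneratorODE,        # was in jinns/data/_DataGenerators.py
    CubicMeshPDEStatio,      # now jinns/data/_CubicMeshPDEStatio.py
    CubicMeshPDENonStatio,   # now jinns/data/_CubicMeshPDENonStatio.py
)
```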
jinns/loss/__init__.py
CHANGED

```diff
@@ -1,21 +1,18 @@
 from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
-from ._LossODE import LossODE
-from ._LossPDE import LossPDEStatio, LossPDENonStatio
+from ._LossODE import LossODE
+from ._LossPDE import LossPDEStatio, LossPDENonStatio
 from ._DynamicLoss import (
     GeneralizedLotkaVolterra,
     BurgersEquation,
     FPENonStatioLoss2D,
     OU_FPENonStatioLoss2D,
     FisherKPP,
-
-    NavierStokes2DStatio,
+    NavierStokesMassConservation2DStatio,
 )
 from ._loss_weights import (
     LossWeightsODE,
-    LossWeightsODEDict,
     LossWeightsPDENonStatio,
     LossWeightsPDEStatio,
-    LossWeightsPDEDict,
 )

 from ._operators import (
@@ -26,3 +23,28 @@ from ._operators import (
     vectorial_laplacian_fwd,
     vectorial_laplacian_rev,
 )
+
+__all__ = [
+    "DynamicLoss",
+    "ODE",
+    "PDEStatio",
+    "PDENonStatio",
+    "LossODE",
+    "LossPDEStatio",
+    "LossPDENonStatio",
+    "GeneralizedLotkaVolterra",
+    "BurgersEquation",
+    "FPENonStatioLoss2D",
+    "OU_FPENonStatioLoss2D",
+    "FisherKPP",
+    "NavierStokesMassConservation2DStatio",
+    "LossWeightsODE",
+    "LossWeightsPDEStatio",
+    "LossWeightsPDENonStatio",
+    "divergence_fwd",
+    "divergence_rev",
+    "laplacian_fwd",
+    "laplacian_rev",
+    "vectorial_laplacian_fwd",
+    "vectorial_laplacian_rev",
+]
```
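With the new `__all__`, `jinns.loss` now declares an explicit public API. Two removals stand out: `NavierStokes2DStatio` is gone in favor of `NavierStokesMassConservation2DStatio`, and the `LossWeightsODEDict`/`LossWeightsPDEDict` containers disappear. Imports against the 1.4.0 surface would look like:

```python
# Everything below is listed in the new __all__ above.
from jinns.loss import (
    BurgersEquation,
    LossWeightsPDENonStatio,
    NavierStokesMassConservation2DStatio,  # replaces NavierStokes2DStatio
)
```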
jinns/loss/_abstract_loss.py
ADDED

```diff
@@ -0,0 +1,15 @@
+import abc
+from jaxtyping import Array
+import equinox as eqx
+
+
+class AbstractLoss(eqx.Module):
+    """
+    Basically just a way to add a __call__ to an eqx.Module.
+    The way to go for correct type hints apparently
+    https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
+    """
+
+    @abc.abstractmethod
+    def __call__(self, *_, **__) -> Array:
+        pass
```
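`AbstractLoss` applies the abstract/final design pattern from the Equinox docs linked in its docstring: an `eqx.Module` base whose `__call__` is abstract so subclasses get precise type hints. A minimal sketch of a concrete subclass; `MyLoss` and its field are illustrative, not part of jinns:

```python
import abc

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array


class AbstractLoss(eqx.Module):
    # Same pattern as jinns/loss/_abstract_loss.py above.
    @abc.abstractmethod
    def __call__(self, *_, **__) -> Array:
        pass


class MyLoss(AbstractLoss):  # hypothetical subclass for illustration
    weight: float = 1.0

    def __call__(self, residuals: Array) -> Array:
        # Concrete subclasses return a scalar array, matching the base hint.
        return self.weight * jnp.mean(residuals**2)


print(MyLoss(weight=2.0)(jnp.arange(3.0)))  # 2 * mean([0, 1, 4]) = 3.333...
```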
jinns/loss/_boundary_conditions.py
CHANGED

```diff
@@ -7,31 +7,31 @@ from __future__ import (
 )  # https://docs.python.org/3/library/typing.html#constant

 from typing import TYPE_CHECKING, Callable
+from jaxtyping import Array, Float
 import jax
 import jax.numpy as jnp
 from jax import vmap, grad
-import equinox as eqx
 from jinns.utils._utils import get_grid, _subtract_with_check
-from jinns.data._Batchs import
+from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
 from jinns.nn._pinn import PINN
 from jinns.nn._spinn import SPINN

 if TYPE_CHECKING:
-    from jinns.
+    from jinns.parameters._params import Params
+    from jinns.utils._types import BoundaryConditionFun
+    from jinns.nn._abstract_pinn import AbstractPINN


 def _compute_boundary_loss(
     boundary_condition_type: str,
-    f:
-        [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
-    ],
+    f: BoundaryConditionFun,
     batch: PDEStatioBatch | PDENonStatioBatch,
-    u:
-    params:
+    u: AbstractPINN,
+    params: Params[Array],
     facet: int,
     dim_to_apply: slice,
     vmap_in_axes: tuple,
-) ->
+) -> Float[Array, " "]:
     r"""A generic function that will compute the mini-batch MSE of a
     boundary condition in the stationary case, resp. non-stationary, given by:

@@ -67,7 +67,7 @@ def _compute_boundary_loss(
     u
         a PINN
     params
-        Params
+        Params
     facet
         An integer which represents the id of the facet which is currently
         considered (in the order provided by the DataGenerator which is fixed)
@@ -96,15 +96,15 @@

 def boundary_dirichlet(
     f: Callable[
-        [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
+        [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
     ],
     batch: PDEStatioBatch | PDENonStatioBatch,
-    u:
-    params: Params
+    u: AbstractPINN,
+    params: Params[Array],
     facet: int,
     dim_to_apply: slice,
     vmap_in_axes: tuple,
-) ->
+) -> Float[Array, " "]:
     r"""
     This omega boundary condition enforces a solution that is equal to `f`
     at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
@@ -135,6 +135,7 @@ def boundary_dirichlet(
     vmap_in_axes
         A tuple object which specifies the in_axes of the vmapping
     """
+    assert batch.border_batch is not None
     batch_array = batch.border_batch
     batch_array = batch_array[..., facet]

@@ -168,15 +169,15 @@ def boundary_dirichlet(

 def boundary_neumann(
     f: Callable[
-        [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
+        [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
     ],
     batch: PDEStatioBatch | PDENonStatioBatch,
-    u:
-    params: Params
+    u: AbstractPINN,
+    params: Params[Array],
     facet: int,
     dim_to_apply: slice,
     vmap_in_axes: tuple,
-) ->
+) -> Float[Array, " "]:
     r"""
     This omega boundary condition enforces a solution where $\nabla u\cdot
     n$ is equal to `f` at the cartesian product of `time_batch` x `omega
@@ -208,6 +209,7 @@ def boundary_neumann(
     vmap_in_axes
         A tuple object which specifies the in_axes of the vmapping
     """
+    assert batch.border_batch is not None
     batch_array = batch.border_batch
     batch_array = batch_array[..., facet]

@@ -223,7 +225,6 @@ def boundary_neumann(
     n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])

     if isinstance(u, PINN):
-
         u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])

         if u.eq_type == "statio_PDE":
```

(Several deleted lines above, e.g. `from jinns.data._Batchs import` and `f:`, were truncated by the diff-viewer extraction and are reproduced as captured.)
jinns/loss/_loss_utils.py
CHANGED

```diff
@@ -6,42 +6,43 @@ from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant

-from typing import TYPE_CHECKING, Callable,
+from typing import TYPE_CHECKING, Callable, TypeGuard
+from types import EllipsisType
 import jax
 import jax.numpy as jnp
 from jax import vmap
-import
-from jaxtyping import Float, Array, PyTree
+from jaxtyping import Float, Array

 from jinns.loss._boundary_conditions import (
     _compute_boundary_loss,
 )
 from jinns.utils._utils import _subtract_with_check, get_grid
-from jinns.data.
+from jinns.data._utils import make_cartesian_product
 from jinns.parameters._params import _get_vmap_in_axes_params
 from jinns.nn._pinn import PINN
 from jinns.nn._spinn import SPINN
 from jinns.nn._hyperpinn import HyperPINN
-from jinns.data._Batchs import
-from jinns.parameters._params import Params
+from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+from jinns.parameters._params import Params

 if TYPE_CHECKING:
-    from jinns.utils._types import
+    from jinns.utils._types import BoundaryConditionFun
+    from jinns.nn._abstract_pinn import AbstractPINN


 def dynamic_loss_apply(
-    dyn_loss:
-    u:
+    dyn_loss: Callable,
+    u: AbstractPINN,
     batch: (
-        Float[Array, "batch_size 1"]
-        | Float[Array, "batch_size dim"]
-        | Float[Array, "batch_size 1+dim"]
+        Float[Array, " batch_size 1"]
+        | Float[Array, " batch_size dim"]
+        | Float[Array, " batch_size 1+dim"]
     ),
-    params: Params
-    vmap_axes: tuple[int | None
-    loss_weight: float | Float[Array, "dyn_loss_dimension"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
+    loss_weight: float | Float[Array, " dyn_loss_dimension"],
     u_type: PINN | HyperPINN | None = None,
-) ->
+) -> Float[Array, " "]:
     """
     Sometimes when u is a lambda function a or dict we do not have access to
     its type here, hence the last argument
@@ -49,7 +50,9 @@ def dynamic_loss_apply(
     if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
         v_dyn_loss = vmap(
             lambda batch, params: dyn_loss(
-                batch,
+                batch,
+                u,
+                params,  # we must place the params at the end
             ),
             vmap_axes,
             0,
```
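The hunk above rearranges the vmapped call so the dynamic loss receives `(batch, u, params)` with `params` last, while `vmap_axes` maps over collocation points and broadcasts the parameters. A self-contained sketch of the pattern; the quadratic `dyn_loss` and toy `u` are stand-ins, not jinns code:

```python
import jax.numpy as jnp
from jax import vmap

# Stand-in residual: in jinns this would be a PDE/ODE residual evaluated at
# one collocation point.
def dyn_loss(batch, u, params):
    return u(batch, params) ** 2

u = lambda x, params: params["a"] * x.sum()   # toy "network"
params = {"a": jnp.array(2.0)}
batch = jnp.ones((5, 3))                      # batch_size x dim points

v_dyn_loss = vmap(
    lambda batch, params: dyn_loss(batch, u, params),
    (0, None),  # map over points, broadcast the params pytree
    0,
)
print(jnp.mean(v_dyn_loss(batch, params)))  # (2 * 3)^2 = 36.0
```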
```diff
@@ -66,18 +69,18 @@ def dynamic_loss_apply(


 def normalization_loss_apply(
-    u:
+    u: AbstractPINN,
     batches: (
-        tuple[Float[Array, "nb_norm_samples dim"]]
+        tuple[Float[Array, " nb_norm_samples dim"]]
         | tuple[
-            Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
+            Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
         ]
     ),
-    params: Params
-    vmap_axes_params: tuple[int | None
-    norm_weights: Float[Array, "nb_norm_samples"],
+    params: Params[Array],
+    vmap_axes_params: tuple[Params[int | None] | None],
+    norm_weights: Float[Array, " nb_norm_samples"],
     loss_weight: float,
-) ->
+) -> Float[Array, " "]:
     """
     Note the squeezing on each result. We expect unidimensional *PINN since
     they represent probability distributions
@@ -97,7 +100,7 @@ def normalization_loss_apply(
         )
     else:
         # NOTE this cartesian product is costly
-
+        batch_cart_prod = make_cartesian_product(
             batches[0],
             batches[1],
         ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
@@ -108,7 +111,7 @@ def normalization_loss_apply(
             ),
             in_axes=(0,) + vmap_axes_params,
         )
-        res = v_u(
+        res = v_u(batch_cart_prod, params)
         assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
         # For all times t, we perform an integration. Then we average the
         # losses over times.
@@ -145,7 +148,7 @@ def normalization_loss_apply(
                 jnp.abs(
                     jnp.mean(
                         res.squeeze(),
-                        axis=(d + 1 for d in range(res.ndim - 2)),
+                        axis=list(d + 1 for d in range(res.ndim - 2)),
                     )
                     * norm_weights
                     - 1
```
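The `axis=` change in the last hunk is a real fix, not style: NumPy-semantics reductions accept an int or a tuple/list of ints for `axis`, and `np.mean` raises a `TypeError` when handed a generator expression. Materializing the axes with `list(...)` makes the multi-axis reduction well defined. A standalone reproduction of the working form:

```python
import numpy as np

res = np.ones((4, 5, 6, 1))  # (time, grid..., 1), as in the norm loss above
axes = list(d + 1 for d in range(res.ndim - 2))  # [1, 2]
print(np.mean(res, axis=tuple(axes)).shape)      # (4, 1)

# np.mean(res, axis=(d + 1 for d in range(res.ndim - 2)))
# raises TypeError: a generator cannot be interpreted as an axis
```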
```diff
@@ -159,18 +162,34 @@
 
 
 def boundary_condition_apply(
-    u:
+    u: AbstractPINN,
     batch: PDEStatioBatch | PDENonStatioBatch,
-    params: Params
-    omega_boundary_fun:
-    omega_boundary_condition: str,
-    omega_boundary_dim:
-    loss_weight: float | Float[Array, "boundary_cond_dim"],
-) ->
-
+    params: Params[Array],
+    omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
+    omega_boundary_condition: str | dict[str, str],
+    omega_boundary_dim: slice | dict[str, slice],
+    loss_weight: float | Float[Array, " boundary_cond_dim"],
+) -> Float[Array, " "]:
+    assert batch.border_batch is not None
     vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)

-
+    def _check_tuple_of_dict(
+        val,
+    ) -> TypeGuard[
+        tuple[
+            dict[str, BoundaryConditionFun],
+            dict[str, BoundaryConditionFun],
+            dict[str, BoundaryConditionFun],
+        ]
+    ]:
+        return all(isinstance(x, dict) for x in val)
+
+    omega_boundary_dicts = (
+        omega_boundary_condition,
+        omega_boundary_fun,
+        omega_boundary_dim,
+    )
+    if _check_tuple_of_dict(omega_boundary_dicts):
         # We must create the facet tree dictionary as we do not have the
         # enumerate from the for loop to pass the id integer
         if batch.border_batch.shape[-1] == 2:
@@ -192,10 +211,10 @@
                 )
             )
         ),
-        omega_boundary_condition,
-        omega_boundary_fun,
+        omega_boundary_dicts[0],  # omega_boundary_condition,
+        omega_boundary_dicts[1],  # omega_boundary_fun,
         facet_tree,
-        omega_boundary_dim,
+        omega_boundary_dicts[2],  # omega_boundary_dim,
         is_leaf=lambda x: x is None,
     )  # when exploring leaves with None value (no condition) the returned
     # mse is None and we get rid of the None leaves of b_losses_by_facet
@@ -208,13 +227,13 @@
             lambda fa: jnp.mean(
                 loss_weight
                 * _compute_boundary_loss(
-
-
+                    omega_boundary_dicts[0],  # type: ignore -> need TypeIs from 3.13
+                    omega_boundary_dicts[1],  # type: ignore -> need TypeIs from 3.13
                     batch,
                     u,
                     params,
                     fa,
-
+                    omega_boundary_dicts[2],  # type: ignore -> need TypeIs from 3.13
                     vmap_in_axes,
                 )
             ),
```
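`_check_tuple_of_dict` returns a `typing.TypeGuard`, so inside the `if` branch static checkers narrow the tuple elements to dicts; the trailing comments note that `TypeIs` (Python 3.13) would make the remaining `# type: ignore`s unnecessary. A minimal standalone version of the narrowing pattern:

```python
from typing import TypeGuard


def all_dicts(val: tuple) -> TypeGuard[tuple[dict, ...]]:
    # A True return tells static checkers that val is a tuple of dicts.
    return all(isinstance(x, dict) for x in val)


conds = ({"xmin": "dirichlet"}, {"xmin": lambda t_x: 0.0})
if all_dicts(conds):
    # Checkers now treat conds as tuple[dict, ...]; .keys() is safe.
    print(sorted(conds[0].keys()))  # ['xmin']
```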
```diff
@@ -227,22 +246,21 @@
 
 
 def observations_loss_apply(
-    u:
-
-    params: Params
-    vmap_axes: tuple[int | None
-    observed_values: Float[Array, "
-    loss_weight: float | Float[Array, "observation_dim"],
-    obs_slice: slice,
-) ->
-    # TODO implement for SPINN
+    u: AbstractPINN,
+    batch: Float[Array, " obs_batch_size input_dim"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
+    observed_values: Float[Array, " obs_batch_size observation_dim"],
+    loss_weight: float | Float[Array, " observation_dim"],
+    obs_slice: EllipsisType | slice | None,
+) -> Float[Array, " "]:
     if isinstance(u, (PINN, HyperPINN)):
         v_u = vmap(
             lambda *args: u(*args)[u.slice_solution],
             vmap_axes,
             0,
         )
-        val = v_u(
+        val = v_u(batch, params)[:, obs_slice]
         mse_observation_loss = jnp.mean(
             jnp.sum(
                 loss_weight
```
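`obs_slice` widens from `slice` to `EllipsisType | slice | None` (hence the new `from types import EllipsisType` import): `...` keeps every output dimension, while a `slice` restricts the observation loss to part of the network output. Plain array indexing shows the difference:

```python
import jax.numpy as jnp

val = jnp.arange(6.0).reshape(3, 2)  # (obs_batch_size, output_dim)
print(val[:, ...].shape)             # (3, 2): Ellipsis keeps all outputs
print(val[:, slice(0, 1)].shape)     # (3, 1): observe only the first output
```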
```diff
@@ -261,15 +279,16 @@
 
 
 def initial_condition_apply(
-    u:
-    omega_batch: Float[Array, "dimension"],
-    params: Params
-    vmap_axes: tuple[int | None
+    u: AbstractPINN,
+    omega_batch: Float[Array, " dimension"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
     initial_condition_fun: Callable,
-
-
+    t0: Float[Array, " 1"],
+    loss_weight: float | Float[Array, " initial_condition_dimension"],
+) -> Float[Array, " "]:
     n = omega_batch.shape[0]
-    t0_omega_batch = jnp.concatenate([jnp.
+    t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
     if isinstance(u, (PINN, HyperPINN)):
         v_u_t0 = vmap(
             lambda t0_x, params: _subtract_with_check(
```
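The initial-condition term now takes `t0` as an argument instead of hard-coding time zero, pairing each spatial point with `t0` to build `(t0, x)` network inputs. A standalone sketch of the reconstructed line:

```python
import jax.numpy as jnp

t0 = jnp.array([0.5])                        # initial time, now a parameter
omega_batch = jnp.arange(6.0).reshape(3, 2)  # (n, spatial_dim) points
n = omega_batch.shape[0]

t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
print(t0_omega_batch)
# [[0.5 0.  1. ]
#  [0.5 2.  3. ]
#  [0.5 4.  5. ]]
```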
```diff
@@ -302,103 +321,3 @@ def initial_condition_apply(
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_initial_condition
-
-
-def constraints_system_loss_apply(
-    u_constraints_dict: Dict,
-    batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
-    params_dict: ParamsDict,
-    loss_weights: Dict[str, float | Array],
-    loss_weight_struct: PyTree,
-):
-    """
-    Same function for systemlossODE and systemlossPDE!
-    """
-    # Transpose so we have each u_dict as outer structure and the
-    # associated loss_weight as inner structure
-    loss_weights_T = jax.tree_util.tree_transpose(
-        jax.tree_util.tree_structure(loss_weight_struct),
-        jax.tree_util.tree_structure(loss_weights["initial_condition"]),
-        loss_weights,
-    )
-
-    if isinstance(params_dict.nn_params, dict):
-
-        def apply_u_constraint(
-            u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
-        ):
-            res_dict_for_u = u_constraint.evaluate(
-                Params(
-                    nn_params=nn_params,
-                    eq_params=eq_params,
-                ),
-                append_obs_batch(batch, obs_batch_u),
-            )[1]
-            res_dict_ponderated = jax.tree_util.tree_map(
-                lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
-            )
-            return res_dict_ponderated
-
-        # Note in the case of multiple PINNs, batch.obs_batch_dict is a dict
-        # with keys corresponding to the PINN and value correspondinf to an
-        # original obs_batch_dict. Hence the tree mapping also interates over
-        # batch.obs_batch_dict
-        res_dict = jax.tree_util.tree_map(
-            apply_u_constraint,
-            u_constraints_dict,
-            params_dict.nn_params,
-            (
-                params_dict.eq_params
-                if params_dict.eq_params.keys() == params_dict.nn_params.keys()
-                else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
-            ),  # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
-            loss_weights_T,
-            batch.obs_batch_dict,
-            is_leaf=lambda x: (
-                not isinstance(x, dict)  # to not traverse more than the first
-                # outer dict of the pytrees passed to the function. This will
-                # work because u_constraints_dict is a dict of Losses, and it
-                # thus stops the traversing of other dict too
-            ),
-        )
-        # TODO try to get rid of this condition?
-    else:
-
-        def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
-            res_dict_for_u = u_constraint.evaluate(
-                params_dict,
-                append_obs_batch(batch, obs_batch_u),
-            )[1]
-            res_dict_ponderated = jax.tree_util.tree_map(
-                lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
-            )
-            return res_dict_ponderated
-
-        res_dict = jax.tree_util.tree_map(
-            apply_u_constraint, u_constraints_dict, loss_weights_T, batch.obs_batch_dict
-        )
-
-    # Transpose back so we have mses as outer structures and their values
-    # for each u_dict as inner structures. The tree_leaves transforms the
-    # inner structure into a list so we can catch is as leaf it the
-    # tree_map below
-    res_dict = jax.tree_util.tree_transpose(
-        jax.tree_util.tree_structure(
-            jax.tree_util.tree_leaves(loss_weights["initial_condition"])
-        ),
-        jax.tree_util.tree_structure(loss_weight_struct),
-        res_dict,
-    )
-    # For each mse, sum their values on each u_dict
-    res_dict = jax.tree_util.tree_map(
-        lambda mse: jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(mse)
-        ),
-        res_dict,
-        is_leaf=lambda x: isinstance(x, list),
-    )
-    # Total loss
-    total_loss = jax.tree_util.tree_reduce(
-        lambda x, y: x + y, jax.tree_util.tree_leaves(res_dict)
-    )
-    return total_loss, res_dict
```
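The deleted `constraints_system_loss_apply` served the `SystemLossODE`/`SystemLossPDE` machinery that 1.4.0 apparently drops (see also the removed `LossWeights*Dict` classes below). Its core trick was `jax.tree_util.tree_transpose`, flipping a per-network dict of per-term weights into a per-term dict of per-network weights. A standalone demo of that helper, with illustrative keys:

```python
import jax

weights = {
    "u1": {"dyn_loss": 1.0, "observations": 0.5},
    "u2": {"dyn_loss": 2.0, "observations": 0.1},
}
outer = jax.tree_util.tree_structure({"u1": 0, "u2": 0})
inner = jax.tree_util.tree_structure({"dyn_loss": 0, "observations": 0})

print(jax.tree_util.tree_transpose(outer, inner, weights))
# {'dyn_loss': {'u1': 1.0, 'u2': 2.0}, 'observations': {'u1': 0.5, 'u2': 0.1}}
```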
jinns/loss/_loss_weights.py
CHANGED

```diff
@@ -2,58 +2,26 @@
 Formalize the loss weights data structure
 """

-from typing import Dict
 from jaxtyping import Array, Float
 import equinox as eqx


 class LossWeightsODE(eqx.Module):
-
-
-
-    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-
-
-class LossWeightsODEDict(eqx.Module):
-
-    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=None)
-    initial_condition: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=None
-    )
-    observations: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=None
-    )
+    dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)
+    observations: Array | Float = eqx.field(kw_only=True, default=0.0)


 class LossWeightsPDEStatio(eqx.Module):
-
-
-
-
-    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    observations: Array | Float = eqx.field(kw_only=True, default=0.0)


 class LossWeightsPDENonStatio(eqx.Module):
-
-
-
-
-
-    initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-
-
-class LossWeightsPDEDict(eqx.Module):
-    """
-    Only one type of LossWeights data structure for the SystemLossPDE:
-    Include the initial condition always for the code to be more generic
-    """
-
-    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    norm_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    boundary_loss: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=1.0
-    )
-    observations: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    initial_condition: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=1.0
-    )
+    dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
+    observations: Array | Float = eqx.field(kw_only=True, default=0.0)
+    initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)
```

(Several deleted field lines in the old classes were not captured by the extraction and are shown as bare `-` lines.)
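Two API consequences are visible in this file: the `*Dict` weight containers are gone, and the remaining classes now default every term to `0.0` instead of `1.0`, so weights presumably must be set explicitly for each loss term you want active. A hedged construction sketch with illustrative values:

```python
from jinns.loss import LossWeightsPDENonStatio

# Fields are keyword-only (eqx.field(kw_only=True)) and default to 0.0 in
# 1.4.0, so any term left unset carries zero weight.
weights = LossWeightsPDENonStatio(
    dyn_loss=1.0,
    boundary_loss=1.0,
    initial_condition=10.0,  # norm_loss and observations stay at 0.0
)
```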
|