jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/loss/__init__.py
CHANGED
```diff
@@ -1,11 +1,27 @@
-from ._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
+from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
+from ._LossODE import LossODE, SystemLossODE
+from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
 from ._DynamicLoss import (
-    FisherKPP,
-    BurgerEquation,
     GeneralizedLotkaVolterra,
+    BurgerEquation,
+    FPENonStatioLoss2D,
     OU_FPENonStatioLoss2D,
+    FisherKPP,
     MassConservation2DStatio,
     NavierStokes2DStatio,
 )
-from .
-
+from ._loss_weights import (
+    LossWeightsODE,
+    LossWeightsODEDict,
+    LossWeightsPDENonStatio,
+    LossWeightsPDEStatio,
+    LossWeightsPDEDict,
+)
+
+from ._operators import (
+    _div_fwd,
+    _div_rev,
+    _laplacian_fwd,
+    _laplacian_rev,
+    _vectorial_laplacian,
+)
```
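In practice this hunk widens the public surface of `jinns.loss`: the `Loss*` classes, the new `LossWeights*` containers, and several differential operators become importable from the subpackage directly. A minimal sketch of the new imports (the constructor call anticipates the `_loss_weights.py` hunk at the end of this diff; the numeric weight values are illustrative, not from the diff):

```python
from jinns.loss import LossODE, LossWeightsODE

# Keyword-only fields, each defaulting to 1.0 (see _loss_weights.py below);
# the chosen values here are made up for illustration.
weights = LossWeightsODE(dyn_loss=1.0, initial_condition=10.0, observations=0.1)
```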
jinns/loss/_boundary_conditions.py
CHANGED

```diff
@@ -2,38 +2,53 @@
 Implements the main boundary conditions for all kinds of losses in jinns
 """
 
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, Callable
 import jax
 import jax.numpy as jnp
 from jax import vmap, grad
+import equinox as eqx
 from jinns.utils._utils import (
     _get_grid,
     _check_user_func_return,
-    _get_vmap_in_axes_params,
 )
-from jinns.
+from jinns.parameters._params import _get_vmap_in_axes_params
+from jinns.data._Batchs import *
 from jinns.utils._pinn import PINN
 from jinns.utils._spinn import SPINN
 
+if TYPE_CHECKING:
+    from jinns.utils._types import *
+
 
 def _compute_boundary_loss(
-    boundary_condition_type
-
+    boundary_condition_type: str,
+    f: Callable,
+    batch: PDEStatioBatch | PDENonStatioBatch,
+    u: eqx.Module,
+    params: Params | ParamsDict,
+    facet: int,
+    dim_to_apply: slice,
+) -> float:
     r"""A generic function that will compute the mini-batch MSE of a
     boundary condition in the stationary case, resp. non-stationary, given by:
 
-
+    $$
     D[u](\partial x) = f(\partial x), \forall \partial x \in \partial \Omega
-
+    $$
     resp.,
 
-
+    $$
     D[u](t, \partial x) = f(\partial x), \forall t \in I, \forall \partial
     x \in \partial \Omega
+    $$
 
+    Where $D[\cdot]$ is a differential operator, possibly identity.
 
-
-
-    __Note__: if using a batch.param_batch_dict, we need to resolve the
+    **Note**: if using a batch.param_batch_dict, we need to resolve the
     vmapping axes in the boundary functions, however params["eq_params"]
     has already been fed with the batch in the `evaluate()` of `LossPDEStatio`,
     resp. `LossPDENonStatio`.
@@ -41,27 +56,24 @@ def _compute_boundary_loss(
     Parameters
     ----------
     boundary_condition_type
-        a string defining the differential operator
-        Currently implements one of "Dirichlet" (
-        Neuman (
-        unitary outgoing vector normal to
+        a string defining the differential operator $D[\cdot]$.
+        Currently implements one of "Dirichlet" ($D = Id$) and Von
+        Neuman ($D[u] = \nabla u \cdot n$) where $n$ is the
+        unitary outgoing vector normal to $\partial\Omega$
     f
         the function to be matched in the boundary condition. It should have
-        one
+        one or two arguments only (other are ignored).
     batch
-        a PDEStatioBatch
+        a PDEStatioBatch or PDENonStatioBatch
     u
         a PINN
     params
-
-
-        dictionaries: `eq_params` and `nn_params``, respectively the
-        differential equation parameters and the neural network parameter
-    facet:
+        Params or ParamsDict
+    facet
         An integer which represents the id of the facet which is currently
         considered (in the order provided by the DataGenerator which is fixed)
     dim_to_apply
-        A jnp.
+        A `jnp.s_` object which indicates which dimension(s) of u will be forced
         to match the boundary condition
 
     Returns
```
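To make the documented contract concrete, here is a hedged sketch of a Dirichlet call; `zero_dirichlet` is a made-up constraint function, and `batch`, `u`, `params` stand for objects built elsewhere with jinns:

```python
import jax.numpy as jnp

def zero_dirichlet(dx):
    # Constraint data f: u must vanish on the boundary. Per the docstring,
    # f takes one or two arguments; extra arguments are ignored.
    return jnp.zeros((1,))

# Hypothetical call mirroring the new signature (batch, u, params built elsewhere):
# mse = _compute_boundary_loss(
#     "Dirichlet",     # D = Id
#     zero_dirichlet,  # f
#     batch,           # PDEStatioBatch | PDENonStatioBatch
#     u,               # the PINN (an eqx.Module)
#     params,          # Params | ParamsDict
#     0,               # facet id, order fixed by the DataGenerator
#     jnp.s_[0:1],     # dimension(s) of u forced to match f
# )
```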
```diff
@@ -91,7 +103,14 @@ def _compute_boundary_loss(
     return mse
 
 
-def boundary_dirichlet_statio(f, batch, u, params, facet, dim_to_apply):
+def boundary_dirichlet_statio(
+    f: Callable,
+    batch: PDEStatioBatch,
+    u: eqx.Module,
+    params: Params | ParamsDict,
+    facet: int,
+    dim_to_apply: slice,
+) -> float:
     r"""
     This omega boundary condition enforces a solution that is equal to f on
     border batch.
@@ -102,17 +121,14 @@ def boundary_dirichlet_statio(f, batch, u, params, facet, dim_to_apply):
 
     Parameters
     ----------
-    f
+    f
         the constraint function
     batch
         A PDEStatioBatch object.
     u
         The PINN
     params
-
-        Typically, it is a dictionary of
-        dictionaries: `eq_params` and `nn_params``, respectively the
-        differential equation parameters and the neural network parameter
+        Params or ParamsDict
     dim_to_apply
         A jnp.s\_ object. The dimension of u on which to apply the boundary condition
     """
@@ -139,14 +155,23 @@ def boundary_dirichlet_statio(f, batch, u, params, facet, dim_to_apply):
             res**2,
             axis=-1,
         )
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_u_boundary
 
 
-def boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply):
+def boundary_neumann_statio(
+    f: Callable,
+    batch: PDEStatioBatch,
+    u: eqx.Module,
+    params: Params | ParamsDict,
+    facet: int,
+    dim_to_apply: slice,
+) -> float:
     r"""
-    This omega boundary condition enforces a solution where
-    n
-    outgoing vector normal at border
+    This omega boundary condition enforces a solution where $\nabla u\cdot
+    n$ is equal to `f` on omega borders. $n$ is the unitary
+    outgoing vector normal at border $\partial\Omega$.
 
     __Note__: if using a batch.param_batch_dict, we need to resolve the
     vmapping axes here however params["eq_params"] has already been fed with
```
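The Neumann condition the docstring describes, $\nabla u \cdot n = f$ on $\partial\Omega$, is easy to check on a toy field; a self-contained sketch (the scalar field and the boundary point are invented for illustration):

```python
import jax
import jax.numpy as jnp

def u_toy(x):
    # Toy scalar field standing in for the PINN output: u(x) = ||x||^2
    return jnp.sum(x**2)

x_b = jnp.array([1.0, 0.0])                       # a point on the unit circle
n = x_b                                           # outward unit normal there
neumann_value = jnp.dot(jax.grad(u_toy)(x_b), n)  # grad u = 2x, so this is 2.0
# The boundary MSE then penalizes (neumann_value - f(x_b)) ** 2.
```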
```diff
@@ -165,7 +190,7 @@ def boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply):
         Typically, it is a dictionary of
         dictionaries: `eq_params` and `nn_params``, respectively the
         differential equation parameters and the neural network parameter
-    facet
+    facet
         An integer which represents the id of the facet which is currently
         considered (in the order provided wy the DataGenerator which is fixed)
     dim_to_apply
@@ -248,13 +273,22 @@ def boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply):
         boundaries = _check_user_func_return(f(x_grid), values.shape)
         res = values - boundaries
         mse_u_boundary = jnp.sum(res**2, axis=-1)
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_u_boundary
 
 
-def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
+def boundary_dirichlet_nonstatio(
+    f: Callable,
+    batch: PDENonStatioBatch,
+    u: eqx.Module,
+    params: Params | ParamsDict,
+    facet: int,
+    dim_to_apply: slice,
+) -> float:
     r"""
-    This omega boundary condition enforces a solution that is equal to f
-    at times_batch x omega borders
+    This omega boundary condition enforces a solution that is equal to `f`
+    at `times_batch` x `omega borders`
 
     __Note__: if using a batch.param_batch_dict, we need to resolve the
     vmapping axes here however params["eq_params"] has already been fed with
```
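Because `f` may also take a time argument in the non-stationary case, the constraint can vary along `times_batch`; a small illustrative example (the exponential decay is arbitrary):

```python
import jax.numpy as jnp

def decaying_dirichlet(t, dx):
    # Time-dependent Dirichlet data: the boundary value decays to 0 over time.
    return jnp.exp(-t) * jnp.ones((1,))
```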
```diff
@@ -271,7 +305,7 @@ def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
     params
         The dictionary of parameters of the model.
         Typically, it is a dictionary of
-        dictionaries: `eq_params` and `nn_params
+        dictionaries: `eq_params` and `nn_params`, respectively the
         differential equation parameters and the neural network parameter
     facet:
         An integer which represents the id of the facet which is currently
@@ -309,14 +343,24 @@ def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
         )
         res = values - boundaries
         mse_u_boundary = jnp.sum(res**2, axis=-1)
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_u_boundary
 
 
-def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
+def boundary_neumann_nonstatio(
+    f: Callable,
+    batch: PDENonStatioBatch,
+    u: eqx.Module,
+    params: Params | ParamsDict,
+    facet: int,
+    dim_to_apply: slice,
+) -> float:
     r"""
-    This omega boundary condition enforces a solution where
-    n
-    outgoing vector normal at border
+    This omega boundary condition enforces a solution where $\nabla u\cdot
+    n$ is equal to `f` at the cartesian product of `time_batch` x `omega
+    borders`. $n$ is the unitary outgoing vector normal at border
+    $\partial\Omega$.
 
     __Note__: if using a batch.param_batch_dict, we need to resolve the
     vmapping axes here however params["eq_params"] has already been fed with
@@ -424,4 +468,6 @@ def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
             res**2,
             axis=-1,
         )
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_u_boundary
```
jinns/loss/{_Losses.py → _loss_utils.py}
RENAMED

```diff
@@ -2,22 +2,43 @@
 Interface for diverse loss functions to factorize code
 """
 
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, Callable, Dict
 import jax
 import jax.numpy as jnp
 from jax import vmap
+import equinox as eqx
+from jaxtyping import Float, Array, PyTree
 
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
-from jinns.utils._hyperpinn import HYPERPINN
 from jinns.loss._boundary_conditions import (
     _compute_boundary_loss,
 )
 from jinns.utils._utils import _check_user_func_return, _get_grid
+from jinns.data._DataGenerators import (
+    append_obs_batch,
+)
+from jinns.utils._pinn import PINN
+from jinns.utils._spinn import SPINN
+from jinns.utils._hyperpinn import HYPERPINN
+from jinns.data._Batchs import *
+from jinns.parameters._params import Params, ParamsDict
+
+if TYPE_CHECKING:
+    from jinns.utils._types import *
 
 
 def dynamic_loss_apply(
-    dyn_loss
-
+    dyn_loss: DynamicLoss,
+    u: eqx.Module,
+    batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
+    params: Params | ParamsDict,
+    vmap_axes: tuple[int | None, ...],
+    loss_weight: float | Float[Array, "dyn_loss_dimension"],
+    u_type: PINN | HYPERPINN | None = None,
+) -> float:
     """
     Sometimes when u is a lambda function a or dict we do not have access to
     its type here, hence the last argument
@@ -35,11 +56,20 @@ def dynamic_loss_apply(
     elif u_type == SPINN or isinstance(u, SPINN):
         residuals = dyn_loss(*batches, u, params)
         mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
 
     return mse_dyn_loss
 
 
-def normalization_loss_apply(
+def normalization_loss_apply(
+    u: eqx.Module,
+    batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
+    params: Params | ParamsDict,
+    vmap_axes: tuple[int | None, ...],
+    int_length: int,
+    loss_weight: float,
+) -> float:
     # TODO merge stationary and non stationary cases
     if isinstance(u, (PINN, HYPERPINN)):
         if len(batches) == 1:
```
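The SPINN branch above reduces the dynamic-loss residuals to a scalar with a weighted MSE: a sum over the equation axis followed by a mean over the batch. A standalone numerical illustration with fabricated residuals:

```python
import jax.numpy as jnp

residuals = jnp.array([[0.1, -0.2], [0.0, 0.3]])  # (batch, equation_dim), made up
loss_weight = jnp.array([1.0, 0.5])               # per-equation weights, made up
mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
# weighted sum over equations, then mean over the batch -> a scalar loss
```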
```diff
@@ -95,26 +125,38 @@ def normalization_loss_apply(u, batches, params, vmap_axes, int_length, loss_wei
             )
             ** 2
         )
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
 
     return mse_norm_loss
 
 
 def boundary_condition_apply(
-    u,
-    batch,
-    params,
-    omega_boundary_fun,
-    omega_boundary_condition,
-    omega_boundary_dim,
-    loss_weight,
-):
+    u: eqx.Module,
+    batch: PDEStatioBatch | PDENonStatioBatch,
+    params: Params | ParamsDict,
+    omega_boundary_fun: Callable,
+    omega_boundary_condition: str,
+    omega_boundary_dim: int,
+    loss_weight: float | Float[Array, "boundary_cond_dim"],
+) -> float:
     if isinstance(omega_boundary_fun, dict):
         # We must create the facet tree dictionary as we do not have the
         # enumerate from the for loop to pass the id integer
-        if
+        if (
+            isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 2
+        ) or (
+            isinstance(batch, PDENonStatioBatch)
+            and batch.times_x_border_batch.shape[-1] == 2
+        ):
             # 1D
             facet_tree = {"xmin": 0, "xmax": 1}
-        elif
+        elif (
+            isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 4
+        ) or (
+            isinstance(batch, PDENonStatioBatch)
+            and batch.times_x_border_batch.shape[-1] == 4
+        ):
             # 2D
             facet_tree = {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
         else:
```
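The new condition reads the geometry off the trailing axis of the border batch: a trailing size of 2 means the two facets of a 1D interval, 4 means the four facets of a 2D rectangle. A standalone sketch of the same dispatch with an invented batch shape:

```python
import jax.numpy as jnp

border_batch = jnp.zeros((32, 2, 4))  # (batch, spatial_dim, n_facets): a 2D domain

if border_batch.shape[-1] == 2:       # 1D interval
    facet_tree = {"xmin": 0, "xmax": 1}
elif border_batch.shape[-1] == 4:     # 2D rectangle
    facet_tree = {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
```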
```diff
@@ -138,7 +180,10 @@ def boundary_condition_apply(
     # Note that to keep the behaviour given in the comment above we neede
     # to specify is_leaf according to the note in the release of 0.4.29
     else:
-
+        if isinstance(batch, PDEStatioBatch):
+            facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
+        else:
+            facet_tuple = tuple(f for f in range(batch.times_x_border_batch.shape[-1]))
         b_losses_by_facet = jax.tree_util.tree_map(
             lambda fa: jnp.mean(
                 loss_weight
@@ -161,8 +206,14 @@
 
 
 def observations_loss_apply(
-    u
-
+    u: eqx.Module,
+    batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
+    params: Params | ParamsDict,
+    vmap_axes: tuple[int | None, ...],
+    observed_values: Float[Array, "batch_size observation_dim"],
+    loss_weight: float | Float[Array, "observation_dim"],
+    obs_slice: slice,
+) -> float:
     # TODO implement for SPINN
     if isinstance(u, (PINN, HYPERPINN)):
         v_u = vmap(
@@ -181,12 +232,20 @@ def observations_loss_apply(
         )
     elif isinstance(u, SPINN):
         raise RuntimeError("observation loss term not yet implemented for SPINNs")
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_observation_loss
 
 
 def initial_condition_apply(
-    u
-
+    u: eqx.Module,
+    omega_batch: Float[Array, "dimension"],
+    params: Params | ParamsDict,
+    vmap_axes: tuple[int | None, ...],
+    initial_condition_fun: Callable,
+    n: int,
+    loss_weight: float | Float[Array, "initial_condition_dimension"],
+) -> float:
     if isinstance(u, (PINN, HYPERPINN)):
         v_u_t0 = vmap(
             lambda x, params: initial_condition_fun(x) - u(jnp.zeros((1,)), x, params),
@@ -212,25 +271,17 @@ def initial_condition_apply(
         )
         res = ini - v_ini
         mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
+    else:
+        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_initial_condition
 
 
-def sobolev_reg_apply(u, batches, params, vmap_axes, sobolev_reg, loss_weight):
-    # TODO implement for SPINN
-    if isinstance(u, (PINN, HYPERPINN)):
-        v_sob_reg = vmap(
-            lambda *args: sobolev_reg(*args),  # pylint: disable=E1121
-            vmap_axes,
-            0,
-        )
-        mse_sobolev_loss = loss_weight * jnp.mean(v_sob_reg(*batches, params))
-    elif isinstance(u, SPINN):
-        raise RuntimeError("Sobolev loss term not yet implemented for SPINNs")
-    return mse_sobolev_loss
-
-
 def constraints_system_loss_apply(
-    u_constraints_dict
+    u_constraints_dict: Dict,
+    batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
+    params_dict: ParamsDict,
+    loss_weights: Dict[str, float | Array],
+    loss_weight_struct: PyTree,
 ):
     """
     Same function for systemlossODE and systemlossPDE!
@@ -243,17 +294,17 @@ def constraints_system_loss_apply(
         loss_weights,
     )
 
-    if isinstance(params_dict
+    if isinstance(params_dict.nn_params, dict):
 
         def apply_u_constraint(
-            u_constraint, nn_params, loss_weights_for_u, obs_batch_u
+            u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
         ):
             res_dict_for_u = u_constraint.evaluate(
-
-
-
-
-                batch
+                Params(
+                    nn_params=nn_params,
+                    eq_params=eq_params,
+                ),
+                append_obs_batch(batch, obs_batch_u),
             )[1]
             res_dict_ponderated = jax.tree_util.tree_map(
                 lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
@@ -267,7 +318,12 @@ def constraints_system_loss_apply(
         res_dict = jax.tree_util.tree_map(
             apply_u_constraint,
             u_constraints_dict,
-            params_dict
+            params_dict.nn_params,
+            (
+                params_dict.eq_params
+                if params_dict.eq_params.keys() == params_dict.nn_params.keys()
+                else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
+            ),  # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
             loss_weights_T,
             batch.obs_batch_dict,
             is_leaf=lambda x: (
@@ -283,7 +339,7 @@ def constraints_system_loss_apply(
         def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
             res_dict_for_u = u_constraint.evaluate(
                 params_dict,
-                batch
+                append_obs_batch(batch, obs_batch_u),
             )[1]
             res_dict_ponderated = jax.tree_util.tree_map(
                 lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
```
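The conditional passed to `jax.tree_util.tree_map` above broadcasts one shared `eq_params` to every unknown function whenever its keys do not mirror `nn_params`. The same manipulation on plain dicts, for illustration only:

```python
nn_params = {"u": "nn_params_u", "v": "nn_params_v"}  # one entry per unknown function
eq_params = {"theta": 1.0}                            # shared physical parameters

# Keys match? pass eq_params through; otherwise replicate it per function.
eq_params_per_u = (
    eq_params
    if eq_params.keys() == nn_params.keys()
    else {k: eq_params for k in nn_params.keys()}
)
# -> {"u": {"theta": 1.0}, "v": {"theta": 1.0}}
```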
jinns/loss/_loss_weights.py
ADDED

```diff
@@ -0,0 +1,59 @@
+"""
+Formalize the loss weights data structure
+"""
+
+from typing import Dict
+from jaxtyping import Array, Float
+import equinox as eqx
+
+
+class LossWeightsODE(eqx.Module):
+
+    dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+
+
+class LossWeightsODEDict(eqx.Module):
+
+    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=None)
+    initial_condition: Dict[str, Array | Float | None] = eqx.field(
+        kw_only=True, default=None
+    )
+    observations: Dict[str, Array | Float | None] = eqx.field(
+        kw_only=True, default=None
+    )
+
+
+class LossWeightsPDEStatio(eqx.Module):
+
+    dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+
+
+class LossWeightsPDENonStatio(eqx.Module):
+
+    dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+    initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+
+
+class LossWeightsPDEDict(eqx.Module):
+    """
+    Only one type of LossWeights data structure for the SystemLossPDE:
+    Include the initial condition always for the code to be more generic
+    """
+
+    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
+    norm_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
+    boundary_loss: Dict[str, Array | Float | None] = eqx.field(
+        kw_only=True, default=1.0
+    )
+    observations: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
+    initial_condition: Dict[str, Array | Float | None] = eqx.field(
+        kw_only=True, default=1.0
+    )
```
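These containers are `eqx.Module`s with keyword-only fields, so construction is by keyword and any omitted field keeps its default (1.0 here, None for the ODEDict variant). A usage sketch consistent with the definitions above; the values are illustrative:

```python
from jinns.loss import LossWeightsPDENonStatio

weights = LossWeightsPDENonStatio(
    norm_loss=0.1,            # down-weight the normalization term
    initial_condition=100.0,  # emphasize the initial condition
)
# dyn_loss, boundary_loss and observations keep their default weight of 1.0
```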