jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
@@ -1,29 +1,52 @@
+# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a PDE loss in jinns
 """
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant

+import abc
+from dataclasses import InitVar, fields
+from typing import TYPE_CHECKING, Dict, Callable
 import warnings
 import jax
 import jax.numpy as jnp
-
-from
+import equinox as eqx
+from jaxtyping import Float, Array, Key, Int
+from jinns.loss._loss_utils import (
     dynamic_loss_apply,
     boundary_condition_apply,
     normalization_loss_apply,
     observations_loss_apply,
-    sobolev_reg_apply,
     initial_condition_apply,
     constraints_system_loss_apply,
 )
-from jinns.data._DataGenerators import
-
+from jinns.data._DataGenerators import (
+    append_obs_batch,
+)
+from jinns.parameters._params import (
     _get_vmap_in_axes_params,
-    _set_derivatives,
     _update_eq_params_dict,
 )
+from jinns.parameters._derivative_keys import (
+    _set_derivatives,
+    DerivativeKeysPDEStatio,
+    DerivativeKeysPDENonStatio,
+)
+from jinns.loss._loss_weights import (
+    LossWeightsPDEStatio,
+    LossWeightsPDENonStatio,
+    LossWeightsPDEDict,
+)
+from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
 from jinns.utils._pinn import PINN
 from jinns.utils._spinn import SPINN
-from jinns.
+from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+
+
+if TYPE_CHECKING:
+    from jinns.utils._types import *

 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "dirichlet",
@@ -31,375 +54,167 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "vonneumann",
 ]

-_LOSS_WEIGHT_KEYS_PDESTATIO = [
-    "sobolev",
-    "observations",
-    "norm_loss",
-    "boundary_loss",
-    "dyn_loss",
-]
-
-_LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]

-
-@register_pytree_node_class
-class LossPDEAbstract:
+class _LossPDEAbstract(eqx.Module):
     """
-
-
-
-
-
-
-
+    Parameters
+    ----------
+
+    loss_weights : LossWeightsPDEStatio | LossWeightsPDENonStatio, default=None
+        The loss weights for the differents term : dynamic loss,
+        initial condition (if LossWeightsPDENonStatio), boundary conditions if
+        any, normalization loss if any and observations if any.
+        All fields are set to 1.0 by default.
+    derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
+        Specify which field of `params` should be differentiated for each
+        composant of the total loss. Particularily useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have a `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each composant of the loss.
+    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+        The function to be matched in the border condition (can be None) or a
+        dictionary of such functions as values and keys as described
+        in `omega_boundary_condition`.
+    omega_boundary_condition : str | Dict[str, str], default=None
+        Either None (no condition, by default), or a string defining
+        the boundary condition (Dirichlet or Von Neumann),
+        or a dictionary with such strings as values. In this case,
+        the keys are the facets and must be in the following order:
+        1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax", "ymin", "ymax"].
+        Note that high order boundaries are currently not implemented.
+        A value in the dict can be None, this means we do not enforce
+        a particular boundary condition on this facet.
+        The facet called "xmin", resp. "xmax" etc., in 2D,
+        refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
+    omega_boundary_dim : slice | Dict[str, slice], default=None
+        Either None, or a slice object or a dictionary of slice objects as
+        values and keys as described in `omega_boundary_condition`.
+        `omega_boundary_dim` indicates which dimension(s) of the PINN
+        will be forced to match the boundary condition.
+        Note that it must be a slice and not an integer
+        (but a preprocessing of the user provided argument takes care of it)
+    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+        Fixed sample point in the space over which to compute the
+        normalization constant. Default is None.
+    norm_int_length : float, default=None
+        A float. Must be provided if `norm_samples` is provided. The domain area
+        (or interval length in 1D) upon which we perform the numerical
+        integration. Default None
+    obs_slice : slice, default=None
+        slice object specifying the begininning/ending of the PINN output
+        that is observed (this is then useful for multidim PINN). Default is None.
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.
     """

-
-
-
-
-
-
-
-
-    )
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
+    derivative_keys: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None = (
+        eqx.field(kw_only=True, default=None)
+    )
+    loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
+        kw_only=True, default=None
+    )
+    omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
+        kw_only=True, default=None
+    )
+    norm_int_length: float | None = eqx.field(kw_only=True, default=None)
+    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
+
+    params: InitVar[Params] = eqx.field(kw_only=True, default=None)
+
+    def __post_init__(self, params=None):
         """
-
-
-        u
-            the PINN object
-        loss_weights
-            a dictionary with values used to ponderate each term in the loss
-            function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`
-            and `observations`
-            Note that we can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        derivative_keys
-            A dict of lists of strings. In the dict, the key must correspond to
-            the loss term keywords. Then each of the values must correspond to keys in the parameter
-            dictionary (*at top level only of the parameter dictionary*).
-            It enables selecting the set of parameters
-            with respect to which the gradients of the dynamic
-            loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
-            keywords, this is what is typically
-            done in solving forward problems, when we only estimate the
-            equation solution with a PINN. If some loss terms keywords are
-            missing we set their value to ["nn_params"] by default for the
-            same reason
-        norm_key
-            Jax random key to draw samples in for the Monte Carlo computation
-            of the normalization constant. Default is None
-        norm_borders
-            tuple of (min, max) of the boundaray values of the space over which
-            to integrate in the computation of the normalization constant.
-            A list of tuple for higher dimensional problems. Default None.
-        norm_samples
-            Fixed sample point in the space over which to compute the
-            normalization constant. Default is None
-
-        Raises
-        ------
-        RuntimeError
-            When provided an invalid combination of `norm_key`, `norm_borders`
-            and `norm_samples`. See note below.
-
-        **Note:** If `norm_key` and `norm_borders` and `norm_samples` are `None`
-        then no normalization loss in enforced.
-        If `norm_borders` and `norm_samples` are given while
-        `norm_samples` is `None` then samples are drawn at each loss evaluation.
-        Otherwise, if `norm_samples` is given, those samples are used.
+        Note that neither __init__ or __post_init__ are called when udating a
+        Module with eqx.tree_at
         """
-
-        self.u = u
-        if derivative_keys is None:
+        if self.derivative_keys is None:
             # be default we only take gradient wrt nn_params
-
-
-
-
-
-                    "norm_loss",
-                    "initial_condition",
-                    "observations",
-                    "sobolev",
-                ]
-            }
-        if isinstance(derivative_keys, list):
-            # if the user only provided a list, this defines the gradient taken
-            # for all the loss entries
-            derivative_keys = {
-                k: derivative_keys
-                for k in [
-                    "dyn_loss",
-                    "boundary_loss",
-                    "norm_loss",
-                    "initial_condition",
-                    "observations",
-                    "sobolev",
-                ]
-            }
-
-        self.derivative_keys = derivative_keys
-        self.loss_weights = loss_weights
-        self.norm_borders = norm_borders
-        self.norm_key = norm_key
-        self.norm_samples = norm_samples
-
-        if norm_key is None and norm_borders is None and norm_samples is None:
-            # if there is None of the 3 above, that means we don't consider
-            # normalization loss
-            self.normalization_loss = None
-        elif (
-            norm_key is not None and norm_borders is not None and norm_samples is None
-        ):  # this ordering so that by default priority is to given mc_samples
-            self.norm_sample_method = "generate"
-            if not isinstance(self.norm_borders[0], tuple):
-                self.norm_borders = (self.norm_borders,)
-            self.norm_xmin, self.norm_xmax = [], []
-            for i, _ in enumerate(self.norm_borders):
-                self.norm_xmin.append(self.norm_borders[i][0])
-                self.norm_xmax.append(self.norm_borders[i][1])
-            self.int_length = jnp.prod(
-                jnp.array(
-                    [
-                        self.norm_xmax[i] - self.norm_xmin[i]
-                        for i in range(len(self.norm_borders))
-                    ]
-                )
-            )
-            self.normalization_loss = True
-        elif norm_samples is None:
-            raise RuntimeError(
-                "norm_borders should always provided then either"
-                " norm_samples (fixed norm_samples) or norm_key (random norm_samples)"
-                " is required."
-            )
-        else:
-            # ok, we are sure we have norm_samples given by the user
-            self.norm_sample_method = "user"
-            if not isinstance(self.norm_borders[0], tuple):
-                self.norm_borders = (self.norm_borders,)
-            self.norm_xmin, self.norm_xmax = [], []
-            for i, _ in enumerate(self.norm_borders):
-                self.norm_xmin.append(self.norm_borders[i][0])
-                self.norm_xmax.append(self.norm_borders[i][1])
-            self.int_length = jnp.prod(
-                jnp.array(
-                    [
-                        self.norm_xmax[i] - self.norm_xmin[i]
-                        for i in range(len(self.norm_borders))
-                    ]
+            try:
+                self.derivative_keys = (
+                    DerivativeKeysPDENonStatio(params=params)
+                    if isinstance(self, LossPDENonStatio)
+                    else DerivativeKeysPDEStatio(params=params)
                 )
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at self.derivative_keys initialization "
+                    f"received {self.derivative_keys=} and {params=}"
+                ) from exc
+
+        if self.loss_weights is None:
+            self.loss_weights = (
+                LossWeightsPDENonStatio()
+                if isinstance(self, LossPDENonStatio)
+                else LossWeightsPDEStatio()
             )
-            self.normalization_loss = True

-
-
-        Returns a batch of points in the domain for integration when the
-        normalization constraint is enforced. The batch of points is either
-        fixed (provided by the user) or regenerated at each iteration.
-        """
-        if self.norm_sample_method == "user":
-            return self.norm_samples
-        if self.norm_sample_method == "generate":
-            ## NOTE TODO CHECK the performances of this for loop
-            norm_samples = []
-            for d in range(len(self.norm_borders)):
-                self.norm_key, subkey = jax.random.split(self.norm_key)
-                norm_samples.append(
-                    jax.random.uniform(
-                        subkey,
-                        shape=(1000, 1),
-                        minval=self.norm_xmin[d],
-                        maxval=self.norm_xmax[d],
-                    )
-                )
-            self.norm_samples = jnp.concatenate(norm_samples, axis=-1)
-            return self.norm_samples
-        raise RuntimeError("Problem with the value of self.norm_sample_method")
-
-    def tree_flatten(self):
-        children = (self.norm_key, self.norm_samples, self.loss_weights)
-        aux_data = {
-            "norm_borders": self.norm_borders,
-            "derivative_keys": self.derivative_keys,
-            "u": self.u,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(self, aux_data, children):
-        (norm_key, norm_samples, loss_weights) = children
-        pls = self(
-            aux_data["u"],
-            loss_weights,
-            aux_data["derivative_keys"],
-            norm_key,
-            aux_data["norm_borders"],
-            norm_samples,
-        )
-        return pls
-
-
-@register_pytree_node_class
-class LossPDEStatio(LossPDEAbstract):
-    r"""Loss object for a stationary partial differential equation
-
-    .. math::
-        \mathcal{N}[u](x) = 0, \forall x \in \Omega
-
-    where :math:`\mathcal{N}[\cdot]` is a differential operator and the
-    boundary condition is :math:`u(x)=u_b(x)` The additional condition of
-    integrating to 1 can be included, i.e. :math:`\int u(x)\mathrm{d}x=1`.
-
-
-    **Note:** LossPDEStatio is jittable. Hence it implements the tree_flatten() and
-    tree_unflatten methods.
-    """
-
-    def __init__(
-        self,
-        u,
-        loss_weights,
-        dynamic_loss,
-        derivative_keys=None,
-        omega_boundary_fun=None,
-        omega_boundary_condition=None,
-        omega_boundary_dim=None,
-        norm_key=None,
-        norm_borders=None,
-        norm_samples=None,
-        sobolev_m=None,
-        obs_slice=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        u
-            the PINN object
-        loss_weights
-            a dictionary with values used to ponderate each term in the loss
-            function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
-            `observations` and `sobolev`.
-            Note that we can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        dynamic_loss
-            the stationary PDE dynamic part of the loss, basically the differential
-            operator :math:` \mathcal{N}[u](t)`. Should implement a method
-            `dynamic_loss.evaluate(t, u, params)`.
-            Can be None in order to access only some part of the evaluate call
-            results.
-        derivative_keys
-            A dict of lists of strings. In the dict, the key must correspond to
-            the loss term keywords. Then each of the values must correspond to keys in the parameter
-            dictionary (*at top level only of the parameter dictionary*).
-            It enables selecting the set of parameters
-            with respect to which the gradients of the dynamic
-            loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
-            keywords, this is what is typically
-            done in solving forward problems, when we only estimate the
-            equation solution with a PINN. If some loss terms keywords are
-            missing we set their value to ["nn_params"] by default for the same
-            reason
-        omega_boundary_fun
-            The function to be matched in the border condition (can be None)
-            or a dictionary of such function. In this case, the keys are the
-            facets and the values are the functions. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_condition
-            Either None (no condition), or a string defining the boundary
-            condition e.g. Dirichlet or Von Neumann, or a dictionary of such
-            strings. In this case, the keys are the
-            facets and the values are the strings. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_dim
-            Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
-            the logic of omega_boundary_fun. It indicates which dimension(s) of
-            the PINN will be forced to match the boundary condition
-            Note that it must be a slice and not an integer (a preprocessing of the
-            user provided argument takes care of it)
-        norm_key
-            Jax random key to draw samples in for the Monte Carlo computation
-            of the normalization constant. Default is None
-        norm_borders
-            tuple of (min, max) of the boundaray values of the space over which
-            to integrate in the computation of the normalization constant.
-            A list of tuple for higher dimensional problems. Default None.
-        norm_samples
-            Fixed sample point in the space over which to compute the
-            normalization constant. Default is None
-        sobolev_m
-            An integer. Default is None.
-            It corresponds to the Sobolev regularization order as proposed in
-            *Convergence and error analysis of PINNs*,
-            Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-        obs_slice
-            slice object specifying the begininning/ending
-            slice of u output(s) that is observed (this is then useful for
-            multidim PINN). Default is None.
-
-
-        Raises
-        ------
-        ValueError
-            If conditions on omega_boundary_condition and omega_boundary_fun
-            are not respected
-        """
+        if self.obs_slice is None:
+            self.obs_slice = jnp.s_[...]

-
-
-
+        if (
+            isinstance(self.omega_boundary_fun, dict)
+            and not isinstance(self.omega_boundary_condition, dict)
+        ) or (
+            not isinstance(self.omega_boundary_fun, dict)
+            and isinstance(self.omega_boundary_condition, dict)
+        ):
+            raise ValueError(
+                "if one of self.omega_boundary_fun or "
+                "self.omega_boundary_condition is dict, the other should be too."
+            )

-        if omega_boundary_condition is None or omega_boundary_fun is None:
+        if self.omega_boundary_condition is None or self.omega_boundary_fun is None:
             warnings.warn(
                 "Missing boundary function or no boundary condition."
                 "Boundary function is thus ignored."
             )
         else:
-            if isinstance(omega_boundary_condition, dict):
-                for _, v in omega_boundary_condition.items():
+            if isinstance(self.omega_boundary_condition, dict):
+                for _, v in self.omega_boundary_condition.items():
                     if v is not None and not any(
                         v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                     ):
                         raise NotImplementedError(
-                            f"The boundary condition {omega_boundary_condition} is not"
+                            f"The boundary condition {self.omega_boundary_condition} is not"
                             f"implemented yet. Try one of :"
                             f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                         )
             else:
                 if not any(
-                    omega_boundary_condition.lower() in s
+                    self.omega_boundary_condition.lower() in s
                     for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                 ):
                     raise NotImplementedError(
-                        f"The boundary condition {omega_boundary_condition} is not"
+                        f"The boundary condition {self.omega_boundary_condition} is not"
                         f"implemented yet. Try one of :"
                         f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                     )
-            if isinstance(omega_boundary_fun, dict) and isinstance(
-                omega_boundary_condition, dict
+            if isinstance(self.omega_boundary_fun, dict) and isinstance(
+                self.omega_boundary_condition, dict
             ):
                 if (
                     not (
-                        list(omega_boundary_fun.keys()) == ["xmin", "xmax"]
-                        and list(omega_boundary_condition.keys())
+                        list(self.omega_boundary_fun.keys()) == ["xmin", "xmax"]
+                        and list(self.omega_boundary_condition.keys())
                         == ["xmin", "xmax"]
                     )
                 ) or (
                     not (
-                        list(omega_boundary_fun.keys())
+                        list(self.omega_boundary_fun.keys())
                         == ["xmin", "xmax", "ymin", "ymax"]
-                        and list(omega_boundary_condition.keys())
+                        and list(self.omega_boundary_condition.keys())
                         == ["xmin", "xmax", "ymin", "ymax"]
                     )
                 ):
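
The structural heart of the hunk above is the switch from `@register_pytree_node_class` with hand-written `tree_flatten`/`tree_unflatten` to `eqx.Module` dataclass-style fields plus a `__post_init__` that fills in defaults. Below is a minimal self-contained sketch of that equinox pattern with toy names — not jinns code — assuming Python 3.10+ for `kw_only` fields:

```python
from dataclasses import InitVar

import equinox as eqx
import jax.numpy as jnp


class ToyLossWeights(eqx.Module):
    dyn_loss: float = 1.0
    boundary_loss: float = 1.0


class ToyLoss(eqx.Module):
    # array-valued leaves stay dynamic; non-JAX types are marked static
    norm_samples: jnp.ndarray | None = eqx.field(kw_only=True, default=None)
    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
    loss_weights: ToyLossWeights | None = eqx.field(kw_only=True, default=None)
    params: InitVar[dict] = eqx.field(kw_only=True, default=None)

    def __post_init__(self, params=None):
        # defaults are materialized here, mirroring _LossPDEAbstract.__post_init__;
        # note that eqx.tree_at updates bypass __init__/__post_init__ entirely
        if self.loss_weights is None:
            self.loss_weights = ToyLossWeights()
        if self.obs_slice is None:
            self.obs_slice = jnp.s_[...]


loss = ToyLoss(norm_samples=jnp.ones((10, 2)))
assert loss.loss_weights.dyn_loss == 1.0
```

Since an `eqx.Module` is itself a pytree, it can cross `jax.jit` and `jax.grad` boundaries directly, which is what lets the diff delete the old flatten/unflatten boilerplate.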
@@ -408,10 +223,6 @@ class LossPDEStatio(LossPDEAbstract):
                         "boundary condition dictionaries is incorrect"
                     )

-        self.omega_boundary_fun = omega_boundary_fun
-        self.omega_boundary_condition = omega_boundary_condition
-
-        self.omega_boundary_dim = omega_boundary_dim
         if isinstance(self.omega_boundary_fun, dict):
             if self.omega_boundary_dim is None:
                 self.omega_boundary_dim = {
@@ -440,44 +251,144 @@ class LossPDEStatio(LossPDEAbstract):
                 self.omega_boundary_dim : self.omega_boundary_dim + 1
             ]
         if not isinstance(self.omega_boundary_dim, slice):
-            raise ValueError("self.omega_boundary_dim must be a jnp.s_
+            raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")

-        self.
+        if self.norm_samples is not None and self.norm_int_length is None:
+            raise ValueError("self.norm_samples and norm_int_length must be provided")

-
-
-
-
-
-
-
-        else:
-            self.sobolev_reg = None
+    @abc.abstractmethod
+    def evaluate(
+        self: eqx.Module,
+        params: Params,
+        batch: PDEStatioBatch | PDENonStatioBatch,
+    ) -> tuple[Float, dict]:
+        raise NotImplementedError

-        for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
-            if k not in self.loss_weights.keys():
-                self.loss_weights[k] = 0

-
-
-            and not isinstance(self.omega_boundary_condition, dict)
-        ) or (
-            not isinstance(self.omega_boundary_fun, dict)
-            and isinstance(self.omega_boundary_condition, dict)
-        ):
-            raise ValueError(
-                "if one of self.omega_boundary_fun or "
-                "self.omega_boundary_condition is dict, the other should be too."
-            )
+class LossPDEStatio(_LossPDEAbstract):
+    r"""Loss object for a stationary partial differential equation

-
-
-
+    $$
+    \mathcal{N}[u](x) = 0, \forall x \in \Omega
+    $$
+
+    where $\mathcal{N}[\cdot]$ is a differential operator and the
+    boundary condition is $u(x)=u_b(x)$ The additional condition of
+    integrating to 1 can be included, i.e. $\int u(x)\mathrm{d}x=1$.
+
+    Parameters
+    ----------
+    u : eqx.Module
+        the PINN
+    dynamic_loss : DynamicLoss
+        the stationary PDE dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](x)$. Should implement a method
+        `dynamic_loss.evaluate(x, u, params)`.
+        Can be None in order to access only some part of the evaluate call
+        results.
+    key : Key
+        A JAX PRNG Key for the loss class treated as an attribute. Default is
+        None. This field is provided for future developments and additional
+        losses that might need some randomness. Note that special care must be
+        taken when splitting the key because in-place updates are forbidden in
+        eqx.Modules.
+    loss_weights : LossWeightsPDEStatio, default=None
+        The loss weights for the differents term : dynamic loss,
+        boundary conditions if any, normalization loss if any and
+        observations if any.
+        All fields are set to 1.0 by default.
+    derivative_keys : DerivativeKeysPDEStatio, default=None
+        Specify which field of `params` should be differentiated for each
+        composant of the total loss. Particularily useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have a `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each composant of the loss.
+    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+        The function to be matched in the border condition (can be None) or a
+        dictionary of such functions as values and keys as described
+        in `omega_boundary_condition`.
+    omega_boundary_condition : str | Dict[str, str], default=None
+        Either None (no condition, by default), or a string defining
+        the boundary condition (Dirichlet or Von Neumann),
+        or a dictionary with such strings as values. In this case,
+        the keys are the facets and must be in the following order:
+        1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax", "ymin", "ymax"].
+        Note that high order boundaries are currently not implemented.
+        A value in the dict can be None, this means we do not enforce
+        a particular boundary condition on this facet.
+        The facet called "xmin", resp. "xmax" etc., in 2D,
+        refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
+    omega_boundary_dim : slice | Dict[str, slice], default=None
+        Either None, or a slice object or a dictionary of slice objects as
+        values and keys as described in `omega_boundary_condition`.
+        `omega_boundary_dim` indicates which dimension(s) of the PINN
+        will be forced to match the boundary condition.
+        Note that it must be a slice and not an integer
+        (but a preprocessing of the user provided argument takes care of it)
+    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+        Fixed sample point in the space over which to compute the
+        normalization constant. Default is None.
+    norm_int_length : float, default=None
+        A float. Must be provided if `norm_samples` is provided. The domain area
+        (or interval length in 1D) upon which we perform the numerical
+        integration. Default None
+    obs_slice : slice, default=None
+        slice object specifying the begininning/ending of the PINN output
+        that is observed (this is then useful for multidim PINN). Default is None.
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.
+
+
+    Raises
+    ------
+    ValueError
+        If conditions on omega_boundary_condition and omega_boundary_fun
+        are not respected
+    """
+
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+
+    u: eqx.Module
+    dynamic_loss: DynamicLoss | None
+    key: Key | None = eqx.field(kw_only=True, default=None)
+
+    vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
+
+    def __post_init__(self, params=None):
+        """
+        Note that neither __init__ or __post_init__ are called when udating a
+        Module with eqx.tree_at!
+        """
+        super().__post_init__(
+            params=params
+        )  # because __init__ or __post_init__ of Base
+        # class is not automatically called
+
+        self.vmap_in_axes = (0,)  # for x only here
+
+    def _get_dynamic_loss_batch(
+        self, batch: PDEStatioBatch
+    ) -> tuple[Float[Array, "batch_size dimension"]]:
+        return (batch.inside_batch,)
+
+    def _get_normalization_loss_batch(
+        self, _
+    ) -> Float[Array, "nb_norm_samples dimension"]:
+        return (self.norm_samples,)
+
+    def _get_observations_loss_batch(
+        self, batch: PDEStatioBatch
+    ) -> Float[Array, "batch_size obs_dim"]:
+        return (batch.obs_batch_dict["pinn_in"],)

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)

-    def evaluate(
+    def evaluate(
+        self, params: Params, batch: PDEStatioBatch
+    ) -> tuple[Float[Array, "1"], dict[str, float]]:
         """
         Evaluate the loss function at a batch of points for given parameters.

@@ -485,22 +396,14 @@ class LossPDEStatio(LossPDEAbstract):
         Parameters
         ---------
         params
-
-            Typically, it is a dictionary of
-            dictionaries: `eq_params` and `nn_params``, respectively the
-            differential equation parameters and the neural network parameter
+            Parameters at which the loss is evaluated
         batch
-
-            Such a named tuple is composed of a batch of points in the
+            Composed of a batch of points in the
             domain, a batch of points in the domain
             border and an optional additional batch of parameters (eg. for
             metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
         """
-        omega_batch, _ = batch.inside_batch, batch.border_batch
-
-        vmap_in_axes_x = (0,)
-
         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
@@ -511,44 +414,41 @@ class LossPDEStatio(LossPDEAbstract):
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)

         # dynamic part
-        params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
         if self.dynamic_loss is not None:
             mse_dyn_loss = dynamic_loss_apply(
                 self.dynamic_loss.evaluate,
                 self.u,
-                (
-
-
-                self.loss_weights
+                self._get_dynamic_loss_batch(batch),
+                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                self.vmap_in_axes + vmap_in_axes_params,
+                self.loss_weights.dyn_loss,
             )
         else:
             mse_dyn_loss = jnp.array(0.0)

         # normalization part
-
-        if self.normalization_loss is not None:
+        if self.norm_samples is not None:
             mse_norm_loss = normalization_loss_apply(
                 self.u,
-
-
-
-                self.
-                self.loss_weights
+                self._get_normalization_loss_batch(batch),
+                _set_derivatives(params, self.derivative_keys.norm_loss),
+                self.vmap_in_axes + vmap_in_axes_params,
+                self.norm_int_length,
+                self.loss_weights.norm_loss,
             )
         else:
             mse_norm_loss = jnp.array(0.0)

         # boundary part
-        params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
         if self.omega_boundary_condition is not None:
             mse_boundary_loss = boundary_condition_apply(
                 self.u,
                 batch,
-
+                _set_derivatives(params, self.derivative_keys.boundary_loss),
                 self.omega_boundary_fun,
                 self.omega_boundary_condition,
                 self.omega_boundary_dim,
-                self.loss_weights
+                self.loss_weights.boundary_loss,
             )
         else:
             mse_boundary_loss = jnp.array(0.0)
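
In the hunk above, each loss term now receives `_set_derivatives(params, self.derivative_keys.<term>)` instead of a shared `params_`. The mechanism, per the docstrings, is a selective `jax.lax.stop_gradient`. The sketch below uses a hypothetical stand-in helper to illustrate the idea — it is not the jinns implementation:

```python
import jax
import jax.numpy as jnp


def set_derivatives(params, diff_key):
    # hypothetical stand-in: keep gradients for the selected top-level key,
    # stop them everywhere else
    return {
        k: v if diff_key in ("both", k) else jax.lax.stop_gradient(v)
        for k, v in params.items()
    }


def loss_fn(params):
    p = set_derivatives(params, "nn_params")  # e.g. derivative_keys.dyn_loss
    return jnp.sum(p["nn_params"] ** 2) + jnp.sum(p["eq_params"] ** 2)


params = {"nn_params": jnp.ones(3), "eq_params": jnp.ones(2)}
grads = jax.grad(loss_fn)(params)
# grads["nn_params"] is 2.0 everywhere; grads["eq_params"] is 0.0 (excluded)
```

For a forward problem only "nn_params" is differentiated; an inverse problem would select "eq_params" or "both" for the relevant term.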
@@ -558,40 +458,21 @@ class LossPDEStatio(LossPDEAbstract):
             # update params with the batches of observed params
             params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])

-            params_ = _set_derivatives(params, "observations", self.derivative_keys)
             mse_observation_loss = observations_loss_apply(
                 self.u,
-                (batch
-
-
+                self._get_observations_loss_batch(batch),
+                _set_derivatives(params, self.derivative_keys.observations),
+                self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights
+                self.loss_weights.observations,
                 self.obs_slice,
             )
         else:
             mse_observation_loss = jnp.array(0.0)

-        # Sobolev regularization
-        params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
-        if self.sobolev_reg is not None:
-            mse_sobolev_loss = sobolev_reg_apply(
-                self.u,
-                (omega_batch,),
-                params_,
-                vmap_in_axes_x + vmap_in_axes_params,
-                self.sobolev_reg,
-                self.loss_weights["sobolev"],
-            )
-        else:
-            mse_sobolev_loss = jnp.array(0.0)
-
         # total loss
         total_loss = (
-            mse_dyn_loss
-            + mse_norm_loss
-            + mse_boundary_loss
-            + mse_observation_loss
-            + mse_sobolev_loss
+            mse_dyn_loss + mse_norm_loss + mse_boundary_loss + mse_observation_loss
         )
         return total_loss, (
             {
@@ -599,205 +480,148 @@ class LossPDEStatio(LossPDEAbstract):
                 "norm_loss": mse_norm_loss,
                 "boundary_loss": mse_boundary_loss,
                 "observations": mse_observation_loss,
-                "sobolev": mse_sobolev_loss,
                 "initial_condition": jnp.array(0.0),  # for compatibility in the
                 # tree_map of SystemLoss
             }
         )

-    def tree_flatten(self):
-        children = (self.norm_key, self.norm_samples, self.loss_weights)
-        aux_data = {
-            "u": self.u,
-            "dynamic_loss": self.dynamic_loss,
-            "derivative_keys": self.derivative_keys,
-            "omega_boundary_fun": self.omega_boundary_fun,
-            "omega_boundary_condition": self.omega_boundary_condition,
-            "omega_boundary_dim": self.omega_boundary_dim,
-            "norm_borders": self.norm_borders,
-            "sobolev_m": self.sobolev_m,
-            "obs_slice": self.obs_slice,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (norm_key, norm_samples, loss_weights) = children
-        pls = cls(
-            aux_data["u"],
-            loss_weights,
-            aux_data["dynamic_loss"],
-            aux_data["derivative_keys"],
-            aux_data["omega_boundary_fun"],
-            aux_data["omega_boundary_condition"],
-            aux_data["omega_boundary_dim"],
-            norm_key,
-            aux_data["norm_borders"],
-            norm_samples,
-            aux_data["sobolev_m"],
-            aux_data["obs_slice"],
-        )
-        return pls

-
-@register_pytree_node_class
 class LossPDENonStatio(LossPDEStatio):
     r"""Loss object for a stationary partial differential equation

-
+    $$
     \mathcal{N}[u](t, x) = 0, \forall t \in I, \forall x \in \Omega
+    $$

-    where
-    The boundary condition is
-    x\in\delta\Omega, \forall t
-    The initial condition is
+    where $\mathcal{N}[\cdot]$ is a differential operator.
+    The boundary condition is $u(t, x)=u_b(t, x),\forall
+    x\in\delta\Omega, \forall t$.
+    The initial condition is $u(0, x)=u_0(x), \forall x\in\Omega$
     The additional condition of
-    integrating to 1 can be included, i.e.,
-
+    integrating to 1 can be included, i.e., $\int u(t, x)\mathrm{d}x=1$.
+
+    Parameters
+    ----------
+    u : eqx.Module
+        the PINN
+    dynamic_loss : DynamicLoss
+        the non stationary PDE dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](t, x)$. Should implement a method
+        `dynamic_loss.evaluate(t, x, u, params)`.
+        Can be None in order to access only some part of the evaluate call
+        results.
+    key : Key
+        A JAX PRNG Key for the loss class treated as an attribute. Default is
+        None. This field is provided for future developments and additional
+        losses that might need some randomness. Note that special care must be
+        taken when splitting the key because in-place updates are forbidden in
+        eqx.Modules.
+        reason
+    loss_weights : LossWeightsPDENonStatio, default=None
+        The loss weights for the differents term : dynamic loss,
+        boundary conditions if any, initial condition, normalization loss if any and
+        observations if any.
+        All fields are set to 1.0 by default.
+    derivative_keys : DerivativeKeysPDENonStatio, default=None
+        Specify which field of `params` should be differentiated for each
+        composant of the total loss. Particularily useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have a `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each composant of the loss.
+    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+        The function to be matched in the border condition (can be None) or a
+        dictionary of such functions as values and keys as described
+        in `omega_boundary_condition`.
+    omega_boundary_condition : str | Dict[str, str], default=None
+        Either None (no condition, by default), or a string defining
+        the boundary condition (Dirichlet or Von Neumann),
+        or a dictionary with such strings as values. In this case,
+        the keys are the facets and must be in the following order:
+        1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax", "ymin", "ymax"].
+        Note that high order boundaries are currently not implemented.
+        A value in the dict can be None, this means we do not enforce
+        a particular boundary condition on this facet.
+        The facet called "xmin", resp. "xmax" etc., in 2D,
+        refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
+    omega_boundary_dim : slice | Dict[str, slice], default=None
+        Either None, or a slice object or a dictionary of slice objects as
+        values and keys as described in `omega_boundary_condition`.
+        `omega_boundary_dim` indicates which dimension(s) of the PINN
+        will be forced to match the boundary condition.
+        Note that it must be a slice and not an integer
+        (but a preprocessing of the user provided argument takes care of it)
+    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+        Fixed sample point in the space over which to compute the
+        normalization constant. Default is None.
+    norm_int_length : float, default=None
+        A float. Must be provided if `norm_samples` is provided. The domain area
+        (or interval length in 1D) upon which we perform the numerical
+        integration. Default None
+    obs_slice : slice, default=None
+        slice object specifying the begininning/ending of the PINN output
+        that is observed (this is then useful for multidim PINN). Default is None.
+    initial_condition_fun : Callable, default=None
+        A function representing the temporal initial condition. If None
+        (default) then no initial condition is applied
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.

-    **Note:** LossPDENonStatio is jittable. Hence it implements the tree_flatten() and
-    tree_unflatten methods.
     """

-
-
-
-
-
-        derivative_keys=None,
-        omega_boundary_fun=None,
-        omega_boundary_condition=None,
-        omega_boundary_dim=None,
-        initial_condition_fun=None,
-        norm_key=None,
-        norm_borders=None,
-        norm_samples=None,
-        sobolev_m=None,
-        obs_slice=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        u
-            the PINN object
-        loss_weights
-            dictionary of values for loss term ponderation
-            Note that we can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        dynamic_loss
-            A Dynamic loss object whose evaluate method corresponds to the
-            dynamic term in the loss
-            Can be None in order to access only some part of the evaluate call
-            results.
-        derivative_keys
-            A dict of lists of strings. In the dict, the key must correspond to
-            the loss term keywords. Then each of the values must correspond to keys in the parameter
-            dictionary (*at top level only of the parameter dictionary*).
-            It enables selecting the set of parameters
-            with respect to which the gradients of the dynamic
-            loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
-            keywords, this is what is typically
-            done in solving forward problems, when we only estimate the
-            equation solution with a PINN. If some loss terms keywords are
-            missing we set their value to ["nn_params"] by default for the same
-            reason
-        omega_boundary_fun
-            The function to be matched in the border condition (can be None)
-            or a dictionary of such function. In this case, the keys are the
-            facets and the values are the functions. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_condition
-            Either None (no condition), or a string defining the boundary
-            condition e.g. Dirichlet or Von Neumann, or a dictionary of such
-            strings. In this case, the keys are the
-            facets and the values are the strings. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_dim
-            Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
-            the logic of omega_boundary_fun. It indicates which dimension(s) of
-            the PINN will be forced to match the boundary condition
-            Note that it must be a slice and not an integer (a preprocessing of the
-            user provided argument takes care of it)
-        initial_condition_fun
-            A function representing the temporal initial condition. If None
-            (default) then no initial condition is applied
-        norm_key
-            Jax random key to draw samples in for the Monte Carlo computation
-            of the normalization constant. Default is None
-        norm_borders
-            tuple of (min, max) of the boundaray values of the space over which
-            to integrate in the computation of the normalization constant.
-            A list of tuple for higher dimensional problems. Default None.
-        norm_samples
-            Fixed sample point in the space over which to compute the
-            normalization constant. Default is None
-        sobolev_m
-            An integer. Default is None.
-            It corresponds to the Sobolev regularization order as proposed in
-            *Convergence and error analysis of PINNs*,
-            Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-        obs_slice
-            slice object specifying the begininning/ending
-            slice of u output(s) that is observed (this is then useful for
-            multidim PINN). Default is None.
-
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    initial_condition_fun: Callable | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )

+    def __post_init__(self, params=None):
+        """
+        Note that neither __init__ or __post_init__ are called when udating a
+        Module with eqx.tree_at!
         """
+        super().__post_init__(
+            params=params
+        )  # because __init__ or __post_init__ of Base
+        # class is not automatically called

-
-
-
-        dynamic_loss,
-        derivative_keys,
-        omega_boundary_fun,
-        omega_boundary_condition,
-        omega_boundary_dim,
-        norm_key,
-        norm_borders,
-        norm_samples,
-        sobolev_m=sobolev_m,
-        obs_slice=obs_slice,
-        )
-        if initial_condition_fun is None:
+        self.vmap_in_axes = (0, 0)  # for t and x
+
+        if self.initial_condition_fun is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g by. hardcoding it into the PINN output)."
             )
-        self.initial_condition_fun = initial_condition_fun
-
-        self.sobolev_m = sobolev_m
-        if self.sobolev_m is not None:
-            # This overwrite the wrongly initialized self.sobolev_reg with
-            # statio=True in the LossPDEStatio init
-            self.sobolev_reg = _sobolev(self.u, self.sobolev_m, statio=False)
-            # we return a function, that way
-            # the order of sobolev_m is static and the conditional in the recursive
-            # function is properly set
-        else:
-            self.sobolev_reg = None

-
-
-
+    def _get_dynamic_loss_batch(
+        self, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
+        times_batch = batch.times_x_inside_batch[:, 0:1]
+        omega_batch = batch.times_x_inside_batch[:, 1:]
+        return (times_batch, omega_batch)
+
+    def _get_normalization_loss_batch(
+        self, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
+        return (
+            batch.times_x_inside_batch[:, 0:1],
+            self.norm_samples,
+        )
+
+    def _get_observations_loss_batch(
+        self, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
+        return (
+            batch.obs_batch_dict["pinn_in"][:, 0:1],
+            batch.obs_batch_dict["pinn_in"][:, 1:],
+        )

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)

     def evaluate(
-        self,
-
-        batch,
-    ):
+        self, params: Params, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "1"], dict[str, float]]:
         """
         Evaluate the loss function at a batch of points for given parameters.

@@ -805,191 +629,54 @@ class LossPDENonStatio(LossPDEStatio):
         Parameters
         ---------
         params
-
-            Typically, it is a dictionary of
-            dictionaries: `eq_params` and `nn_params`, respectively the
-            differential equation parameters and the neural network parameter
+            Parameters at which the loss is evaluated
         batch
-
-            Such a named tuple is composed of a batch of points in
+            Composed of a batch of points in
             the domain, a batch of points in the domain
             border, a batch of time points and an optional additional batch
             of parameters (eg. for metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
         """
-
-        times_batch = batch.times_x_inside_batch[:, 0:1]
         omega_batch = batch.times_x_inside_batch[:, 1:]
-        n = omega_batch.shape[0]
-
-        vmap_in_axes_x_t = (0, 0)

         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
-
-
-            # feed the eq_params with the batch
-            for k in eq_params_batch_dict.keys():
-                params["eq_params"][k] = eq_params_batch_dict[k]
+            # update eq_params with the batches of generated params
+            params = _update_eq_params_dict(params, batch.param_batch_dict)

         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)

-        #
-
-
-            mse_dyn_loss = dynamic_loss_apply(
-                self.dynamic_loss.evaluate,
-                self.u,
-                (times_batch, omega_batch),
-                params_,
-                vmap_in_axes_x_t + vmap_in_axes_params,
-                self.loss_weights["dyn_loss"],
-            )
-        else:
-            mse_dyn_loss = jnp.array(0.0)
-
-        # normalization part
-        params_ = _set_derivatives(params, "norm_loss", self.derivative_keys)
-        if self.normalization_loss is not None:
-            mse_norm_loss = normalization_loss_apply(
-                self.u,
-                (times_batch, self.get_norm_samples()),
-                params_,
-                vmap_in_axes_x_t + vmap_in_axes_params,
-                self.int_length,
-                self.loss_weights["norm_loss"],
-            )
-        else:
-            mse_norm_loss = jnp.array(0.0)
-
-        # boundary part
-        params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
-        if self.omega_boundary_fun is not None:
-            mse_boundary_loss = boundary_condition_apply(
-                self.u,
-                batch,
-                params_,
-                self.omega_boundary_fun,
-                self.omega_boundary_condition,
-                self.omega_boundary_dim,
-                self.loss_weights["boundary_loss"],
-            )
-        else:
-            mse_boundary_loss = jnp.array(0.0)
+        # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
+        # mse_observation_loss we use the evaluate from parent class
+        partial_mse, partial_mse_terms = super().evaluate(params, batch)

         # initial condition
-        params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
         if self.initial_condition_fun is not None:
             mse_initial_condition = initial_condition_apply(
                 self.u,
                 omega_batch,
-
+                _set_derivatives(params, self.derivative_keys.initial_condition),
                 (0,) + vmap_in_axes_params,
                 self.initial_condition_fun,
-
-                self.loss_weights
+                omega_batch.shape[0],
+                self.loss_weights.initial_condition,
             )
         else:
             mse_initial_condition = jnp.array(0.0)

-        # Observation mse
-        if batch.obs_batch_dict is not None:
-            # update params with the batches of observed params
-            params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
-
-            params_ = _set_derivatives(params, "observations", self.derivative_keys)
-            mse_observation_loss = observations_loss_apply(
-                self.u,
-                (
-                    batch.obs_batch_dict["pinn_in"][:, 0:1],
-                    batch.obs_batch_dict["pinn_in"][:, 1:],
-                ),
-                params_,
-                vmap_in_axes_x_t + vmap_in_axes_params,
-                batch.obs_batch_dict["val"],
-                self.loss_weights["observations"],
-                self.obs_slice,
-            )
-        else:
-            mse_observation_loss = jnp.array(0.0)
-
-        # Sobolev regularization
|
919
|
-
params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
|
|
920
|
-
if self.sobolev_reg is not None:
|
|
921
|
-
mse_sobolev_loss = sobolev_reg_apply(
|
|
922
|
-
self.u,
|
|
923
|
-
(omega_batch, times_batch),
|
|
924
|
-
params_,
|
|
925
|
-
vmap_in_axes_x_t + vmap_in_axes_params,
|
|
926
|
-
self.sobolev_reg,
|
|
927
|
-
self.loss_weights["sobolev"],
|
|
928
|
-
)
|
|
929
|
-
else:
|
|
930
|
-
mse_sobolev_loss = jnp.array(0.0)
|
|
931
|
-
|
|
932
669
|
# total loss
|
|
933
|
-
total_loss =
|
|
934
|
-
mse_dyn_loss
|
|
935
|
-
+ mse_norm_loss
|
|
936
|
-
+ mse_boundary_loss
|
|
937
|
-
+ mse_initial_condition
|
|
938
|
-
+ mse_observation_loss
|
|
939
|
-
+ mse_sobolev_loss
|
|
940
|
-
)
|
|
941
|
-
|
|
942
|
-
return total_loss, (
|
|
943
|
-
{
|
|
944
|
-
"dyn_loss": mse_dyn_loss,
|
|
945
|
-
"norm_loss": mse_norm_loss,
|
|
946
|
-
"boundary_loss": mse_boundary_loss,
|
|
947
|
-
"initial_condition": mse_initial_condition,
|
|
948
|
-
"observations": mse_observation_loss,
|
|
949
|
-
"sobolev": mse_sobolev_loss,
|
|
950
|
-
}
|
|
951
|
-
)
|
|
670
|
+
total_loss = partial_mse + mse_initial_condition
|
|
952
671
|
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
"u": self.u,
|
|
957
|
-
"dynamic_loss": self.dynamic_loss,
|
|
958
|
-
"derivative_keys": self.derivative_keys,
|
|
959
|
-
"omega_boundary_fun": self.omega_boundary_fun,
|
|
960
|
-
"omega_boundary_condition": self.omega_boundary_condition,
|
|
961
|
-
"omega_boundary_dim": self.omega_boundary_dim,
|
|
962
|
-
"initial_condition_fun": self.initial_condition_fun,
|
|
963
|
-
"norm_borders": self.norm_borders,
|
|
964
|
-
"sobolev_m": self.sobolev_m,
|
|
965
|
-
"obs_slice": self.obs_slice,
|
|
672
|
+
return total_loss, {
|
|
673
|
+
**partial_mse_terms,
|
|
674
|
+
"initial_condition": mse_initial_condition,
|
|
966
675
|
}
|
|
967
|
-
return (children, aux_data)
|
|
968
|
-
|
|
969
|
-
@classmethod
|
|
970
|
-
def tree_unflatten(cls, aux_data, children):
|
|
971
|
-
(norm_key, norm_samples, loss_weights) = children
|
|
972
|
-
pls = cls(
|
|
973
|
-
aux_data["u"],
|
|
974
|
-
loss_weights,
|
|
975
|
-
aux_data["dynamic_loss"],
|
|
976
|
-
aux_data["derivative_keys"],
|
|
977
|
-
aux_data["omega_boundary_fun"],
|
|
978
|
-
aux_data["omega_boundary_condition"],
|
|
979
|
-
aux_data["omega_boundary_dim"],
|
|
980
|
-
aux_data["initial_condition_fun"],
|
|
981
|
-
norm_key,
|
|
982
|
-
aux_data["norm_borders"],
|
|
983
|
-
norm_samples,
|
|
984
|
-
aux_data["sobolev_m"],
|
|
985
|
-
aux_data["obs_slice"],
|
|
986
|
-
)
|
|
987
|
-
return pls
|
|
988
676
|
|
|
989
677
|
|
|
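The refactor above removes the hand-written pytree plumbing (`tree_flatten`/`tree_unflatten`) and the duplicated loss terms: the non-stationary `evaluate` now calls the parent class and only adds its own initial-condition term. A self-contained sketch of that composition pattern, with placeholder loss terms standing in for the jinns ones:

    import equinox as eqx
    import jax.numpy as jnp

    class ToyStatioLoss(eqx.Module):
        def evaluate(self, params, batch):
            mse_dyn = jnp.mean(batch**2)  # stands in for dyn/norm/boundary/obs terms
            return mse_dyn, {"dyn_loss": mse_dyn}

    class ToyNonStatioLoss(ToyStatioLoss):
        def evaluate(self, params, batch):
            partial_mse, partial_terms = super().evaluate(params, batch)
            mse_init = jnp.array(0.5)  # stands in for the initial-condition term
            return partial_mse + mse_init, {
                **partial_terms,
                "initial_condition": mse_init,
            }

    total, terms = ToyNonStatioLoss().evaluate(None, jnp.ones((4, 2)))
    # total == 1.5, terms == {"dyn_loss": 1.0, "initial_condition": 0.5}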
-
-class SystemLossPDE:
-    """
+class SystemLossPDE(eqx.Module):
+    r"""
     Class to implement a system of PDEs.
     The goal is to give maximum freedom to the user. The class is created with
     a dict of dynamic loss, and dictionaries of all the objects that are used
@@ -1003,190 +690,182 @@ class SystemLossPDE:
     Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
     solution.
 
-
-
+    Parameters
+    ----------
+    u_dict : Dict[str, eqx.Module]
+        dict of PINNs
+    loss_weights : LossWeightsPDEDict
+        A dictionary of LossWeightsODE
+    derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
+        A dictionary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
+        specifying what field of `params`
+        should be used during gradient computations for each of the terms of
+        the total loss, for each of the losses in the system. Default is
+        `"nn_params"` everywhere.
+    dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
+        A dict of dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
+    key_dict : Dict[str, Key], default=None
+        A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
+        match that of u_dict. See LossPDEStatio or LossPDENonStatio for
+        more details.
+    omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
+        A dict of functions or of dict of functions or of None
+        (see doc for `omega_boundary_fun` in
+        LossPDEStatio or LossPDENonStatio). Default is None.
+        Must share the keys of `u_dict`.
+    omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
+        A dict of strings or of dict of strings or of None
+        (see doc for `omega_boundary_condition_dict` in
+        LossPDEStatio or LossPDENonStatio). Default is None.
+        Must share the keys of `u_dict`
+    omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
+        A dict of slices or of dict of slices or of None
+        (see doc for `omega_boundary_dim` in
+        LossPDEStatio or LossPDENonStatio). Default is None.
+        Must share the keys of `u_dict`
+    initial_condition_fun_dict : Dict[str, Callable | None], default=None
+        A dict of functions representing the temporal initial condition (None
+        value is possible). If None
+        (default) then no temporal boundary condition is applied
+        Must share the keys of `u_dict`
+    norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None], default=None
+        A dict of fixed sample points in the space over which to compute the
+        normalization constant. Default is None
+        Must share the keys of `u_dict`
+    norm_int_length_dict : Dict[str, float | None] | None, default=None
+        A dict of Float. The domain area
+        (or interval length in 1D) upon which we perform the numerical
+        integration for each element of u_dict.
+        Default is None
+        Must share the keys of `u_dict`
+    obs_slice_dict : Dict[str, slice | None] | None, default=None
+        dict of obs_slice, with keys from `u_dict` to designate the
+        output(s) channels that are forced to observed values, for each
+        PINNs. Default is None. But if a value is given, all the entries of
+        `u_dict` must be represented here with default value `jnp.s_[...]`
+        if no particular slice is to be given
+    params : InitVar[ParamsDict], default=None
+        The main Params object of the problem needed to instantiate the
+        DerivativeKeysODE if the latter is not specified.
+
     """
 
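Read together with the field declarations below, the docstring above implies a call shape along the following lines. This is a hedged sketch rather than a runnable snippet: `u`, `my_dynamic_loss`, `my_initial_condition_fun`, `params_dict` and `batch` are placeholders for objects built elsewhere with jinns' constructors:

    system_loss = SystemLossPDE(
        u_dict={"u": u},
        dynamic_loss_dict={"u": my_dynamic_loss},
        loss_weights=LossWeightsPDEDict(dyn_loss=1.0, initial_condition=10.0),
        initial_condition_fun_dict={"u": my_initial_condition_fun},
        params_dict=params_dict,
    )
    total_loss, loss_terms = system_loss(params_dict, batch)  # __call__ -> evaluate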
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    omega_boundary_condition_dict
-        A dict of dict of strings (see doc for
-        `omega_boundary_condition_dict` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    omega_boundary_dim_dict
-        A dict of dict of slices (see doc for `omega_boundary_dim` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    initial_condition_fun_dict
-        A dict of functions representing the temporal initial condition. If None
-        (default) then no temporal boundary condition is applied
-        Must share the keys of `u_dict`
-    norm_key_dict
-        A dict of Jax random keys to draw samples in for the Monte Carlo computation
-        of the normalization constant. Default is None
-        Must share the keys of `u_dict`
-    norm_borders_dict
-        A dict of tuples of (min, max) of the boundaray values of the space over which
-        to integrate in the computation of the normalization constant.
-        A list of tuple for higher dimensional problems. Default None.
-        Must share the keys of `u_dict`
-    norm_samples_dict
-        A dict of fixed sample point in the space over which to compute the
-        normalization constant. Default is None
-        Must share the keys of `u_dict`
-    sobolev_m
-        Default is None. A dictionary of integers, one per key which must
-        match `u_dict`.
-        It corresponds to the Sobolev regularization order as proposed in
-        *Convergence and error analysis of PINNs*,
-        Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-    obs_slice_dict
-        dict of obs_slice, with keys from `u_dict` to designate the
-        output(s) channels that are forced to observed values, for each
-        PINNs. Default is None. But if a value is given, all the entries of
-        `u_dict` must be represented here with default value `jnp.s_[...]`
-        if no particular slice is to be given
-
-
-    Raises
-    ------
-    ValueError
-        if initial condition is not a dict of tuple
-    ValueError
-        if the dictionaries that should share the keys of u_dict do not
-    """
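The normalization machinery documented above (the old `norm_key_dict`/`norm_borders_dict` pair, replaced by fixed `norm_samples_dict` plus `norm_int_length_dict`) amounts to estimating the integral of the network output by the sample mean scaled by the domain measure. A standalone sketch of that estimate with a toy integrand (not the jinns API):

    import jax.numpy as jnp

    norm_samples = jnp.linspace(0.0, 2.0, 200).reshape(-1, 1)  # fixed points on [0, 2]
    norm_int_length = 2.0                                      # measure of the domain
    u = lambda x: x**2
    integral_estimate = norm_int_length * jnp.mean(u(norm_samples))  # close to 8/3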
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    u_dict: Dict[str, eqx.Module]
+    dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
+    key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
+    derivative_keys_dict: Dict[
+        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
+    ] = eqx.field(kw_only=True, default=None)
+    omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
+        eqx.field(kw_only=True, default=None, static=True)
+    )
+    omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
+        eqx.field(kw_only=True, default=None, static=True)
+    )
+    omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
+        eqx.field(kw_only=True, default=None, static=True)
+    )
+    initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
+        eqx.field(kw_only=True, default=None)
+    )
+    norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
+        kw_only=True, default=None
+    )
+    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+
+    # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
+    # dictionary having keys in u_dict and / or dynamic_loss_dict)
+    loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
+        kw_only=True, default=None
+    )
+    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
+
+    # following have init=False and are set in the __post_init__
+    u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
+        init=False
+    )
+    derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
+        eqx.field(init=False)
+    )
+    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
+    # internally the loss weights are handled with a dictionary
+    _loss_weights: Dict[str, dict] = eqx.field(init=False)
+
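Two equinox/dataclass features carry the field list above: `eqx.field(static=True)` keeps non-array metadata out of the pytree leaves (so it is not traced or differentiated), and `InitVar` fields are consumed by `__post_init__` without being stored on the module. A minimal runnable illustration (toy class, not the jinns API):

    from dataclasses import InitVar

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class Scaled(eqx.Module):
        weight: jax.Array
        name: str = eqx.field(static=True, default="scaled")  # metadata, not a leaf
        scale: InitVar[float] = 1.0                           # init-only argument

        def __post_init__(self, scale):
            self.weight = self.weight * scale  # allowed during initialisation

    m = Scaled(weight=jnp.ones(3), scale=2.0)
    assert len(jax.tree_util.tree_leaves(m)) == 1  # only `weight` is a leaf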
+    def __post_init__(self, loss_weights=None, params_dict=None):
         # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in u_dict.keys()}
+        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
         # First, for all the optional dict,
         # if the user did not provide at all this optional argument,
         # we make sure there is a null weighting loss_weight and we
         # create a dummy dict with the required keys and all the values to
         # None
-        if omega_boundary_fun_dict is None:
+        if self.key_dict is None:
+            self.key_dict = self.u_dict_with_none
+        if self.omega_boundary_fun_dict is None:
             self.omega_boundary_fun_dict = self.u_dict_with_none
-        else:
-            self.omega_boundary_fun_dict = omega_boundary_fun_dict
-        if omega_boundary_condition_dict is None:
+        if self.omega_boundary_condition_dict is None:
             self.omega_boundary_condition_dict = self.u_dict_with_none
-        else:
-            self.omega_boundary_condition_dict = omega_boundary_condition_dict
-        if omega_boundary_dim_dict is None:
+        if self.omega_boundary_dim_dict is None:
             self.omega_boundary_dim_dict = self.u_dict_with_none
-        else:
-            self.omega_boundary_dim_dict = omega_boundary_dim_dict
-        if initial_condition_fun_dict is None:
+        if self.initial_condition_fun_dict is None:
             self.initial_condition_fun_dict = self.u_dict_with_none
-        else:
-            self.initial_condition_fun_dict = initial_condition_fun_dict
-        if norm_key_dict is None:
-            self.norm_key_dict = self.u_dict_with_none
-        else:
-            self.norm_key_dict = norm_key_dict
-        if norm_borders_dict is None:
-            self.norm_borders_dict = self.u_dict_with_none
-        else:
-            self.norm_borders_dict = norm_borders_dict
-        if norm_samples_dict is None:
+        if self.norm_samples_dict is None:
             self.norm_samples_dict = self.u_dict_with_none
-        else:
-            self.norm_samples_dict = norm_samples_dict
-        if sobolev_m_dict is None:
-            self.sobolev_m_dict = self.u_dict_with_none
-        else:
-            self.sobolev_m_dict = sobolev_m_dict
-        if obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
-        else:
-            self.obs_slice_dict = obs_slice_dict
-        if u_dict.keys() != obs_slice_dict.keys():
+        if self.norm_int_length_dict is None:
+            self.norm_int_length_dict = self.u_dict_with_none
+        if self.obs_slice_dict is None:
+            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
+        if self.u_dict.keys() != self.obs_slice_dict.keys():
             raise ValueError("obs_slice_dict should have same keys as u_dict")
-        if derivative_keys_dict is None:
+        if self.derivative_keys_dict is None:
             self.derivative_keys_dict = {
                 k: None
-                for k in set(
+                for k in set(
+                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
+                )
             }
             # set() because we can have duplicate entries and in this case we
             # say it corresponds to the same derivative_keys_dict entry
-
-
-
-
-
-        #
-        #
-
+        # we need both because the constraints (all but dyn_loss) will be
+        # done by iterating on u_dict while the dyn_loss will be by
+        # iterating on dynamic_loss_dict. So each time we will require some
+        # derivative_keys_dict
+
+        # derivative keys for the u_constraints. Note that we create missing
+        # DerivativeKeysODE around a Params object and not ParamsDict
+        # this works because u_dict.keys == params_dict.nn_params.keys()
+        for k in self.u_dict.keys():
             if self.derivative_keys_dict[k] is None:
-                self.
+                if self.u_dict[k].eq_type == "statio_PDE":
+                    self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
+                        params=params_dict.extract_params(k)
+                    )
+                else:
+                    self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
+                        params=params_dict.extract_params(k)
+                    )
 
         # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
         if (
-            u_dict.keys() !=
-            or u_dict.keys() != self.omega_boundary_fun_dict.keys()
-            or u_dict.keys() != self.omega_boundary_condition_dict.keys()
-            or u_dict.keys() != self.omega_boundary_dim_dict.keys()
-            or u_dict.keys() != self.initial_condition_fun_dict.keys()
-            or u_dict.keys() != self.
-            or u_dict.keys() != self.
-            or u_dict.keys() != self.norm_samples_dict.keys()
-            or u_dict.keys() != self.sobolev_m_dict.keys()
+            self.u_dict.keys() != self.key_dict.keys()
+            or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
+            or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
+            or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
+            or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
+            or self.u_dict.keys() != self.norm_samples_dict.keys()
+            or self.u_dict.keys() != self.norm_int_length_dict.keys()
         ):
             raise ValueError("All the dicts concerning the PINNs should have same keys")
 
-        self.
-        self.u_dict = u_dict
-        # TODO nn_type should become a class attribute now that we have PINN
-        # class and SPINNs class
-        self.nn_type_dict = nn_type_dict
-
-        self.loss_weights = loss_weights  # This calls the setter
+        self._loss_weights = self.set_loss_weights(loss_weights)
 
         # Third, in order to benefit from LossPDEStatio and
         # LossPDENonStatio and in order to factorize code, we create internally
@@ -1194,95 +873,93 @@ class SystemLossPDE:
         # We will not use the dynamic loss term
         self.u_constraints_dict = {}
         for i in self.u_dict.keys():
-            if self.
+            if self.u_dict[i].eq_type == "statio_PDE":
                 self.u_constraints_dict[i] = LossPDEStatio(
-                    u=u_dict[i],
-                    loss_weights=
-
-
-
-
-
-
+                    u=self.u_dict[i],
+                    loss_weights=LossWeightsPDENonStatio(
+                        dyn_loss=0.0,
+                        norm_loss=1.0,
+                        boundary_loss=1.0,
+                        observations=1.0,
+                        initial_condition=1.0,
+                    ),
                     dynamic_loss=None,
+                    key=self.key_dict[i],
                     derivative_keys=self.derivative_keys_dict[i],
                     omega_boundary_fun=self.omega_boundary_fun_dict[i],
                     omega_boundary_condition=self.omega_boundary_condition_dict[i],
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    norm_key=self.norm_key_dict[i],
-                    norm_borders=self.norm_borders_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-
+                    norm_int_length=self.norm_int_length_dict[i],
                     obs_slice=self.obs_slice_dict[i],
                 )
-            elif self.
+            elif self.u_dict[i].eq_type == "nonstatio_PDE":
                 self.u_constraints_dict[i] = LossPDENonStatio(
-                    u=u_dict[i],
-                    loss_weights=
-
-
-
-
-
-
-                    },
+                    u=self.u_dict[i],
+                    loss_weights=LossWeightsPDENonStatio(
+                        dyn_loss=0.0,
+                        norm_loss=1.0,
+                        boundary_loss=1.0,
+                        observations=1.0,
+                        initial_condition=1.0,
+                    ),
                     dynamic_loss=None,
+                    key=self.key_dict[i],
                     derivative_keys=self.derivative_keys_dict[i],
                     omega_boundary_fun=self.omega_boundary_fun_dict[i],
                     omega_boundary_condition=self.omega_boundary_condition_dict[i],
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
                     initial_condition_fun=self.initial_condition_fun_dict[i],
-                    norm_key=self.norm_key_dict[i],
-                    norm_borders=self.norm_borders_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-
+                    norm_int_length=self.norm_int_length_dict[i],
+                    obs_slice=self.obs_slice_dict[i],
                 )
             else:
                 raise ValueError(
-
+                    "Wrong value for self.u_dict[i].eq_type[i], "
+                    f"got {self.u_dict[i].eq_type[i]}"
                 )
 
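The branch above dispatches on the `eq_type` tag carried by each PINN/SPINN and fails loudly on anything else. Stripped down to its skeleton (with a stand-in object, since the real dispatch builds LossPDEStatio/LossPDENonStatio instances):

    from types import SimpleNamespace

    def constraint_kind(u):
        if u.eq_type == "statio_PDE":
            return "stationary constraint"
        if u.eq_type == "nonstatio_PDE":
            return "non-stationary constraint"
        raise ValueError(f"Wrong value for eq_type, got {u.eq_type}")

    print(constraint_kind(SimpleNamespace(eq_type="statio_PDE")))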
-        #
-        #
-
-
-
-        }
-        self.derivative_keys_u_dict = {
-            k: self.derivative_keys_dict[k]
-            for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
-        }
+        # derivative keys for the dynamic loss. Note that we create a
+        # DerivativeKeysODE around a ParamsDict object because a whole
+        # params_dict is fed to DynamicLoss.evaluate functions (extract_params
+        # happens inside it)
+        self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
 
         # also make sure we only have PINNs or SPINNs
         if not (
-            all(isinstance(value, PINN) for value in u_dict.values())
-            or all(isinstance(value, SPINN) for value in u_dict.values())
+            all(isinstance(value, PINN) for value in self.u_dict.values())
+            or all(isinstance(value, SPINN) for value in self.u_dict.values())
         ):
             raise ValueError(
                 "We only accept dictionary of PINNs or dictionary of SPINNs"
             )
 
-
-
-
-
-
-
-
-
+    def set_loss_weights(
+        self, loss_weights_init: LossWeightsPDEDict
+    ) -> dict[str, dict]:
+        """
+        This rather complex function enables the user to specify a simple
+        loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
+        for weighting values being applied to all the equations of the
+        system... So all the transformations are handled here
+        """
+        _loss_weights = {}
+        for k in fields(loss_weights_init):
+            v = getattr(loss_weights_init, k.name)
            if isinstance(v, dict):
-                for
+                for vv in v.keys():
                    if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv,
+                        isinstance(vv, Array)
                        and ((vv.shape == (1,) or len(vv.shape) == 0))
                    ):
                        # TODO improve that
                        raise ValueError(
                            f"loss values cannot be vectorial here, got {vv}"
                        )
-                if k == "dyn_loss":
+                if k.name == "dyn_loss":
                    if v.keys() == self.dynamic_loss_dict.keys():
-
+                        _loss_weights[k.name] = v
                    else:
                        raise ValueError(
                            "Keys in nested dictionary of loss_weights"
@@ -1290,51 +967,36 @@ class SystemLossPDE:
                        )
                else:
                    if v.keys() == self.u_dict.keys():
-
+                        _loss_weights[k.name] = v
                    else:
                        raise ValueError(
                            "Keys in nested dictionary of loss_weights"
                            " do not match u_dict keys"
                        )
+            if v is None:
+                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
            else:
                if not isinstance(v, (int, float)) and not (
-                    isinstance(v,
-                    and ((v.shape == (1,) or len(v.shape) == 0))
+                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
                ):
                    # TODO improve that
                    raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k == "dyn_loss":
-
+                if k.name == "dyn_loss":
+                    _loss_weights[k.name] = {
                        kk: v for kk in self.dynamic_loss_dict.keys()
                    }
                else:
-
-
-        if all(v is None for k, v in self.sobolev_m_dict.items()):
-            self._loss_weights["sobolev"] = {k: 0 for k in self.u_dict.keys()}
-        if "observations" not in value.keys():
-            self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
-        if all(v is None for k, v in self.omega_boundary_fun_dict.items()) or all(
-            v is None for k, v in self.omega_boundary_condition_dict.items()
-        ):
-            self._loss_weights["boundary_loss"] = {k: 0 for k in self.u_dict.keys()}
-        if (
-            all(v is None for k, v in self.norm_key_dict.items())
-            or all(v is None for k, v in self.norm_borders_dict.items())
-            or all(v is None for k, v in self.norm_samples_dict.items())
-        ):
-            self._loss_weights["norm_loss"] = {k: 0 for k in self.u_dict.keys()}
-        if all(v is None for k, v in self.initial_condition_fun_dict.items()):
-            self._loss_weights["initial_condition"] = {k: 0 for k in self.u_dict.keys()}
+                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
+        return _loss_weights
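The broadcasting rule implemented by `set_loss_weights` can be tried in isolation: a scalar entry of the dataclass is expanded to one weight per key, a missing (None) entry becomes a zero weight, and a user-supplied dict is kept as-is. A simplified standalone sketch (jinns additionally distinguishes `dynamic_loss_dict` keys from `u_dict` keys for `dyn_loss`):

    from dataclasses import dataclass, fields

    @dataclass
    class ToyLossWeights:
        dyn_loss: float | dict = 1.0
        initial_condition: float | dict | None = None

    def broadcast(lw, u_keys):
        out = {}
        for f in fields(lw):
            v = getattr(lw, f.name)
            if isinstance(v, dict):
                out[f.name] = v                       # already one weight per key
            elif v is None:
                out[f.name] = {k: 0 for k in u_keys}  # missing term: zero weight
            else:
                out[f.name] = {k: v for k in u_keys}  # scalar broadcast to all keys
        return out

    print(broadcast(ToyLossWeights(), ("u", "v")))
    # {'dyn_loss': {'u': 1.0, 'v': 1.0}, 'initial_condition': {'u': 0, 'v': 0}}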
 
     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
 
     def evaluate(
         self,
-        params_dict,
-        batch,
-    ):
+        params_dict: ParamsDict,
+        batch: PDEStatioBatch | PDENonStatioBatch,
+    ) -> tuple[Float[Array, "1"], dict[str, float]]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -1342,12 +1004,8 @@ class SystemLossPDE:
         Parameters
         ---------
         params_dict
-
-            Typically, it is a dictionary of dictionaries of
-            dictionaries: `eq_params` and `nn_params``, respectively the
-            differential equation parameters and the neural network parameter
+            Parameters at which the losses of the system are evaluated
         batch
-            A PDEStatioBatch or PDENonStatioBatch object.
             Such named tuples are composed of a batch of points in the
             domain, a batch of points in the domain
             border, (a batch of time points for PDENonStatioBatch) and an
@@ -1355,7 +1013,7 @@ class SystemLossPDE:
             and an optional additional batch of observed
             inputs/outputs/parameters
         """
-        if self.u_dict.keys() != params_dict["nn_params"].keys():
+        if self.u_dict.keys() != params_dict.nn_params.keys():
             raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
 
         if isinstance(batch, PDEStatioBatch):
@@ -1378,22 +1036,22 @@ class SystemLossPDE:
         if batch.param_batch_dict is not None:
             eq_params_batch_dict = batch.param_batch_dict
 
+            # TODO
             # feed the eq_params with the batch
             for k in eq_params_batch_dict.keys():
-                params_dict["eq_params"][k] = eq_params_batch_dict[k]
+                params_dict.eq_params[k] = eq_params_batch_dict[k]
 
         vmap_in_axes_params = _get_vmap_in_axes_params(
             batch.param_batch_dict, params_dict
         )
 
-        def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
+        def dyn_loss_for_one_key(dyn_loss, loss_weight):
             """The function used in tree_map"""
-            params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
                 batches,
-                params_dict_,
+                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
                 vmap_in_axes_x_or_x_t + vmap_in_axes_params,
                 loss_weight,
                 u_type=type(list(self.u_dict.values())[0]),
@@ -1402,8 +1060,14 @@ class SystemLossPDE:
         dyn_loss_mse_dict = jax.tree_util.tree_map(
             dyn_loss_for_one_key,
             self.dynamic_loss_dict,
-            self.derivative_keys_dyn_loss_dict,
             self._loss_weights["dyn_loss"],
+            is_leaf=lambda x: isinstance(
+                x, (PDEStatio, PDENonStatio)
+            ),  # before, when dynamic losses
+            # were plain (unregistered pytree) node classes, we could not traverse
+            # this level. Now that dynamic losses are eqx.Module they can be
+            # traversed by tree map recursion. Hence we need to specify that
+            # we want to stop at this level
         )
         mse_dyn_loss = jax.tree_util.tree_reduce(
             lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
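A runnable illustration of why the `is_leaf` argument above is needed: once the dynamic losses are `eqx.Module`s they are themselves traversable pytrees, so without `is_leaf` the tree_map would recurse into each module instead of pairing the dict values with their weights (toy modules, not the jinns classes):

    import equinox as eqx
    import jax

    class ToyDynLoss(eqx.Module):
        coeff: float = eqx.field(static=True, default=1.0)

    losses = {"heat": ToyDynLoss(1.0), "burgers": ToyDynLoss(2.0)}
    weights = {"heat": 1.0, "burgers": 0.5}

    weighted = jax.tree_util.tree_map(
        lambda dl, w: dl.coeff * w,
        losses,
        weights,
        is_leaf=lambda x: isinstance(x, ToyDynLoss),  # stop recursion at the modules
    )
    print(weighted)  # {'burgers': 1.0, 'heat': 1.0}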
@@ -1418,11 +1082,10 @@ class SystemLossPDE:
             "boundary_loss": "*",
             "observations": "*",
             "initial_condition": "*",
-            "sobolev": "*",
         }
         # we need to do the following for the tree_mapping to work
         if batch.obs_batch_dict is None:
-            batch = batch
+            batch = append_obs_batch(batch, self.u_dict_with_none)
         total_loss, res_dict = constraints_system_loss_apply(
             self.u_constraints_dict,
             batch,
@@ -1435,41 +1098,3 @@ class SystemLossPDE:
         total_loss += mse_dyn_loss
         res_dict["dyn_loss"] += mse_dyn_loss
         return total_loss, res_dict
-
-    def tree_flatten(self):
-        children = (
-            self.norm_key_dict,
-            self.norm_samples_dict,
-            self.initial_condition_fun_dict,
-            self._loss_weights,
-        )
-        aux_data = {
-            "u_dict": self.u_dict,
-            "dynamic_loss_dict": self.dynamic_loss_dict,
-            "norm_borders_dict": self.norm_borders_dict,
-            "omega_boundary_fun_dict": self.omega_boundary_fun_dict,
-            "omega_boundary_condition_dict": self.omega_boundary_condition_dict,
-            "nn_type_dict": self.nn_type_dict,
-            "sobolev_m_dict": self.sobolev_m_dict,
-            "derivative_keys_dict": self.derivative_keys_dict,
-            "obs_slice_dict": self.obs_slice_dict,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (
-            norm_key_dict,
-            norm_samples_dict,
-            initial_condition_fun_dict,
-            loss_weights,
-        ) = children
-        loss_ode = cls(
-            loss_weights=loss_weights,
-            norm_key_dict=norm_key_dict,
-            norm_samples_dict=norm_samples_dict,
-            initial_condition_fun_dict=initial_condition_fun_dict,
-            **aux_data,
-        )
-
-        return loss_ode