jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
|
@@ -1,29 +1,52 @@
|
|
|
1
|
+
# pylint: disable=unsubscriptable-object, no-member
|
|
1
2
|
"""
|
|
2
3
|
Main module to implement a PDE loss in jinns
|
|
3
4
|
"""
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
4
8
|
|
|
9
|
+
import abc
|
|
10
|
+
from dataclasses import InitVar, fields
|
|
11
|
+
from typing import TYPE_CHECKING, Dict, Callable
|
|
5
12
|
import warnings
|
|
6
13
|
import jax
|
|
7
14
|
import jax.numpy as jnp
|
|
8
|
-
|
|
9
|
-
from
|
|
15
|
+
import equinox as eqx
|
|
16
|
+
from jaxtyping import Float, Array, Key, Int
|
|
17
|
+
from jinns.loss._loss_utils import (
|
|
10
18
|
dynamic_loss_apply,
|
|
11
19
|
boundary_condition_apply,
|
|
12
20
|
normalization_loss_apply,
|
|
13
21
|
observations_loss_apply,
|
|
14
|
-
sobolev_reg_apply,
|
|
15
22
|
initial_condition_apply,
|
|
16
23
|
constraints_system_loss_apply,
|
|
17
24
|
)
|
|
18
|
-
from jinns.data._DataGenerators import
|
|
19
|
-
|
|
25
|
+
from jinns.data._DataGenerators import (
|
|
26
|
+
append_obs_batch,
|
|
27
|
+
)
|
|
28
|
+
from jinns.parameters._params import (
|
|
20
29
|
_get_vmap_in_axes_params,
|
|
21
|
-
_set_derivatives,
|
|
22
30
|
_update_eq_params_dict,
|
|
23
31
|
)
|
|
32
|
+
from jinns.parameters._derivative_keys import (
|
|
33
|
+
_set_derivatives,
|
|
34
|
+
DerivativeKeysPDEStatio,
|
|
35
|
+
DerivativeKeysPDENonStatio,
|
|
36
|
+
)
|
|
37
|
+
from jinns.loss._loss_weights import (
|
|
38
|
+
LossWeightsPDEStatio,
|
|
39
|
+
LossWeightsPDENonStatio,
|
|
40
|
+
LossWeightsPDEDict,
|
|
41
|
+
)
|
|
42
|
+
from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
|
|
24
43
|
from jinns.utils._pinn import PINN
|
|
25
44
|
from jinns.utils._spinn import SPINN
|
|
26
|
-
from jinns.
|
|
45
|
+
from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from jinns.utils._types import *
|
|
27
50
|
|
|
28
51
|
_IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
29
52
|
"dirichlet",
|
|
@@ -31,375 +54,156 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
|
31
54
|
"vonneumann",
|
|
32
55
|
]
|
|
33
56
|
|
|
34
|
-
_LOSS_WEIGHT_KEYS_PDESTATIO = [
|
|
35
|
-
"sobolev",
|
|
36
|
-
"observations",
|
|
37
|
-
"norm_loss",
|
|
38
|
-
"boundary_loss",
|
|
39
|
-
"dyn_loss",
|
|
40
|
-
]
|
|
41
|
-
|
|
42
|
-
_LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
|
|
43
57
|
|
|
44
|
-
|
|
45
|
-
@register_pytree_node_class
|
|
46
|
-
class LossPDEAbstract:
|
|
58
|
+
class _LossPDEAbstract(eqx.Module):
|
|
47
59
|
"""
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
|
|
63
|
+
loss_weights : LossWeightsPDEStatio | LossWeightsPDENonStatio, default=None
|
|
64
|
+
The loss weights for the differents term : dynamic loss,
|
|
65
|
+
initial condition (if LossWeightsPDENonStatio), boundary conditions if
|
|
66
|
+
any, normalization loss if any and observations if any.
|
|
67
|
+
All fields are set to 1.0 by default.
|
|
68
|
+
derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
|
|
69
|
+
Specify which field of `params` should be differentiated for each
|
|
70
|
+
composant of the total loss. Particularily useful for inverse problems.
|
|
71
|
+
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
72
|
+
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
73
|
+
is `"nn_params"` for each composant of the loss.
|
|
74
|
+
omega_boundary_fun : Callable | Dict[str, Callable], default=None
|
|
75
|
+
The function to be matched in the border condition (can be None) or a
|
|
76
|
+
dictionary of such functions as values and keys as described
|
|
77
|
+
in `omega_boundary_condition`.
|
|
78
|
+
omega_boundary_condition : str | Dict[str, str], default=None
|
|
79
|
+
Either None (no condition, by default), or a string defining
|
|
80
|
+
the boundary condition (Dirichlet or Von Neumann),
|
|
81
|
+
or a dictionary with such strings as values. In this case,
|
|
82
|
+
the keys are the facets and must be in the following order:
|
|
83
|
+
1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
|
|
84
|
+
Note that high order boundaries are currently not implemented.
|
|
85
|
+
A value in the dict can be None, this means we do not enforce
|
|
86
|
+
a particular boundary condition on this facet.
|
|
87
|
+
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
88
|
+
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
89
|
+
omega_boundary_dim : slice | Dict[str, slice], default=None
|
|
90
|
+
Either None, or a slice object or a dictionary of slice objects as
|
|
91
|
+
values and keys as described in `omega_boundary_condition`.
|
|
92
|
+
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
93
|
+
will be forced to match the boundary condition.
|
|
94
|
+
Note that it must be a slice and not an integer
|
|
95
|
+
(but a preprocessing of the user provided argument takes care of it)
|
|
96
|
+
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
97
|
+
Fixed sample point in the space over which to compute the
|
|
98
|
+
normalization constant. Default is None.
|
|
99
|
+
norm_int_length : float, default=None
|
|
100
|
+
A float. Must be provided if `norm_samples` is provided. The domain area
|
|
101
|
+
(or interval length in 1D) upon which we perform the numerical
|
|
102
|
+
integration. Default None
|
|
103
|
+
obs_slice : slice, default=None
|
|
104
|
+
slice object specifying the begininning/ending of the PINN output
|
|
105
|
+
that is observed (this is then useful for multidim PINN). Default is None.
|
|
55
106
|
"""
|
|
56
107
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
)
|
|
108
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
109
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
110
|
+
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
111
|
+
derivative_keys: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None = (
|
|
112
|
+
eqx.field(kw_only=True, default=None)
|
|
113
|
+
)
|
|
114
|
+
loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
|
|
115
|
+
kw_only=True, default=None
|
|
116
|
+
)
|
|
117
|
+
omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
|
|
118
|
+
kw_only=True, default=None, static=True
|
|
119
|
+
)
|
|
120
|
+
omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
|
|
121
|
+
kw_only=True, default=None, static=True
|
|
122
|
+
)
|
|
123
|
+
omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
|
|
124
|
+
kw_only=True, default=None, static=True
|
|
125
|
+
)
|
|
126
|
+
norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
|
|
127
|
+
kw_only=True, default=None
|
|
128
|
+
)
|
|
129
|
+
norm_int_length: float | None = eqx.field(kw_only=True, default=None)
|
|
130
|
+
obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
|
|
131
|
+
|
|
132
|
+
def __post_init__(self):
|
|
66
133
|
"""
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
u
|
|
70
|
-
the PINN object
|
|
71
|
-
loss_weights
|
|
72
|
-
a dictionary with values used to ponderate each term in the loss
|
|
73
|
-
function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`
|
|
74
|
-
and `observations`
|
|
75
|
-
Note that we can have jnp.arrays with the same dimension of
|
|
76
|
-
`u` which then ponderates each output of `u`
|
|
77
|
-
derivative_keys
|
|
78
|
-
A dict of lists of strings. In the dict, the key must correspond to
|
|
79
|
-
the loss term keywords. Then each of the values must correspond to keys in the parameter
|
|
80
|
-
dictionary (*at top level only of the parameter dictionary*).
|
|
81
|
-
It enables selecting the set of parameters
|
|
82
|
-
with respect to which the gradients of the dynamic
|
|
83
|
-
loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
|
|
84
|
-
keywords, this is what is typically
|
|
85
|
-
done in solving forward problems, when we only estimate the
|
|
86
|
-
equation solution with a PINN. If some loss terms keywords are
|
|
87
|
-
missing we set their value to ["nn_params"] by default for the
|
|
88
|
-
same reason
|
|
89
|
-
norm_key
|
|
90
|
-
Jax random key to draw samples in for the Monte Carlo computation
|
|
91
|
-
of the normalization constant. Default is None
|
|
92
|
-
norm_borders
|
|
93
|
-
tuple of (min, max) of the boundaray values of the space over which
|
|
94
|
-
to integrate in the computation of the normalization constant.
|
|
95
|
-
A list of tuple for higher dimensional problems. Default None.
|
|
96
|
-
norm_samples
|
|
97
|
-
Fixed sample point in the space over which to compute the
|
|
98
|
-
normalization constant. Default is None
|
|
99
|
-
|
|
100
|
-
Raises
|
|
101
|
-
------
|
|
102
|
-
RuntimeError
|
|
103
|
-
When provided an invalid combination of `norm_key`, `norm_borders`
|
|
104
|
-
and `norm_samples`. See note below.
|
|
105
|
-
|
|
106
|
-
**Note:** If `norm_key` and `norm_borders` and `norm_samples` are `None`
|
|
107
|
-
then no normalization loss in enforced.
|
|
108
|
-
If `norm_borders` and `norm_samples` are given while
|
|
109
|
-
`norm_samples` is `None` then samples are drawn at each loss evaluation.
|
|
110
|
-
Otherwise, if `norm_samples` is given, those samples are used.
|
|
134
|
+
Note that neither __init__ or __post_init__ are called when udating a
|
|
135
|
+
Module with eqx.tree_at
|
|
111
136
|
"""
|
|
112
|
-
|
|
113
|
-
self.u = u
|
|
114
|
-
if derivative_keys is None:
|
|
137
|
+
if self.derivative_keys is None:
|
|
115
138
|
# be default we only take gradient wrt nn_params
|
|
116
|
-
derivative_keys =
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
"boundary_loss",
|
|
121
|
-
"norm_loss",
|
|
122
|
-
"initial_condition",
|
|
123
|
-
"observations",
|
|
124
|
-
"sobolev",
|
|
125
|
-
]
|
|
126
|
-
}
|
|
127
|
-
if isinstance(derivative_keys, list):
|
|
128
|
-
# if the user only provided a list, this defines the gradient taken
|
|
129
|
-
# for all the loss entries
|
|
130
|
-
derivative_keys = {
|
|
131
|
-
k: derivative_keys
|
|
132
|
-
for k in [
|
|
133
|
-
"dyn_loss",
|
|
134
|
-
"boundary_loss",
|
|
135
|
-
"norm_loss",
|
|
136
|
-
"initial_condition",
|
|
137
|
-
"observations",
|
|
138
|
-
"sobolev",
|
|
139
|
-
]
|
|
140
|
-
}
|
|
141
|
-
|
|
142
|
-
self.derivative_keys = derivative_keys
|
|
143
|
-
self.loss_weights = loss_weights
|
|
144
|
-
self.norm_borders = norm_borders
|
|
145
|
-
self.norm_key = norm_key
|
|
146
|
-
self.norm_samples = norm_samples
|
|
147
|
-
|
|
148
|
-
if norm_key is None and norm_borders is None and norm_samples is None:
|
|
149
|
-
# if there is None of the 3 above, that means we don't consider
|
|
150
|
-
# normalization loss
|
|
151
|
-
self.normalization_loss = None
|
|
152
|
-
elif (
|
|
153
|
-
norm_key is not None and norm_borders is not None and norm_samples is None
|
|
154
|
-
): # this ordering so that by default priority is to given mc_samples
|
|
155
|
-
self.norm_sample_method = "generate"
|
|
156
|
-
if not isinstance(self.norm_borders[0], tuple):
|
|
157
|
-
self.norm_borders = (self.norm_borders,)
|
|
158
|
-
self.norm_xmin, self.norm_xmax = [], []
|
|
159
|
-
for i, _ in enumerate(self.norm_borders):
|
|
160
|
-
self.norm_xmin.append(self.norm_borders[i][0])
|
|
161
|
-
self.norm_xmax.append(self.norm_borders[i][1])
|
|
162
|
-
self.int_length = jnp.prod(
|
|
163
|
-
jnp.array(
|
|
164
|
-
[
|
|
165
|
-
self.norm_xmax[i] - self.norm_xmin[i]
|
|
166
|
-
for i in range(len(self.norm_borders))
|
|
167
|
-
]
|
|
168
|
-
)
|
|
139
|
+
self.derivative_keys = (
|
|
140
|
+
DerivativeKeysPDENonStatio()
|
|
141
|
+
if isinstance(self, LossPDENonStatio)
|
|
142
|
+
else DerivativeKeysPDEStatio()
|
|
169
143
|
)
|
|
170
|
-
self.normalization_loss = True
|
|
171
|
-
elif norm_samples is None:
|
|
172
|
-
raise RuntimeError(
|
|
173
|
-
"norm_borders should always provided then either"
|
|
174
|
-
" norm_samples (fixed norm_samples) or norm_key (random norm_samples)"
|
|
175
|
-
" is required."
|
|
176
|
-
)
|
|
177
|
-
else:
|
|
178
|
-
# ok, we are sure we have norm_samples given by the user
|
|
179
|
-
self.norm_sample_method = "user"
|
|
180
|
-
if not isinstance(self.norm_borders[0], tuple):
|
|
181
|
-
self.norm_borders = (self.norm_borders,)
|
|
182
|
-
self.norm_xmin, self.norm_xmax = [], []
|
|
183
|
-
for i, _ in enumerate(self.norm_borders):
|
|
184
|
-
self.norm_xmin.append(self.norm_borders[i][0])
|
|
185
|
-
self.norm_xmax.append(self.norm_borders[i][1])
|
|
186
|
-
self.int_length = jnp.prod(
|
|
187
|
-
jnp.array(
|
|
188
|
-
[
|
|
189
|
-
self.norm_xmax[i] - self.norm_xmin[i]
|
|
190
|
-
for i in range(len(self.norm_borders))
|
|
191
|
-
]
|
|
192
|
-
)
|
|
193
|
-
)
|
|
194
|
-
self.normalization_loss = True
|
|
195
144
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
if self.norm_sample_method == "user":
|
|
203
|
-
return self.norm_samples
|
|
204
|
-
if self.norm_sample_method == "generate":
|
|
205
|
-
## NOTE TODO CHECK the performances of this for loop
|
|
206
|
-
norm_samples = []
|
|
207
|
-
for d in range(len(self.norm_borders)):
|
|
208
|
-
self.norm_key, subkey = jax.random.split(self.norm_key)
|
|
209
|
-
norm_samples.append(
|
|
210
|
-
jax.random.uniform(
|
|
211
|
-
subkey,
|
|
212
|
-
shape=(1000, 1),
|
|
213
|
-
minval=self.norm_xmin[d],
|
|
214
|
-
maxval=self.norm_xmax[d],
|
|
215
|
-
)
|
|
216
|
-
)
|
|
217
|
-
self.norm_samples = jnp.concatenate(norm_samples, axis=-1)
|
|
218
|
-
return self.norm_samples
|
|
219
|
-
raise RuntimeError("Problem with the value of self.norm_sample_method")
|
|
220
|
-
|
|
221
|
-
def tree_flatten(self):
|
|
222
|
-
children = (self.norm_key, self.norm_samples, self.loss_weights)
|
|
223
|
-
aux_data = {
|
|
224
|
-
"norm_borders": self.norm_borders,
|
|
225
|
-
"derivative_keys": self.derivative_keys,
|
|
226
|
-
"u": self.u,
|
|
227
|
-
}
|
|
228
|
-
return (children, aux_data)
|
|
229
|
-
|
|
230
|
-
@classmethod
|
|
231
|
-
def tree_unflatten(self, aux_data, children):
|
|
232
|
-
(norm_key, norm_samples, loss_weights) = children
|
|
233
|
-
pls = self(
|
|
234
|
-
aux_data["u"],
|
|
235
|
-
loss_weights,
|
|
236
|
-
aux_data["derivative_keys"],
|
|
237
|
-
norm_key,
|
|
238
|
-
aux_data["norm_borders"],
|
|
239
|
-
norm_samples,
|
|
240
|
-
)
|
|
241
|
-
return pls
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
@register_pytree_node_class
|
|
245
|
-
class LossPDEStatio(LossPDEAbstract):
|
|
246
|
-
r"""Loss object for a stationary partial differential equation
|
|
247
|
-
|
|
248
|
-
.. math::
|
|
249
|
-
\mathcal{N}[u](x) = 0, \forall x \in \Omega
|
|
250
|
-
|
|
251
|
-
where :math:`\mathcal{N}[\cdot]` is a differential operator and the
|
|
252
|
-
boundary condition is :math:`u(x)=u_b(x)` The additional condition of
|
|
253
|
-
integrating to 1 can be included, i.e. :math:`\int u(x)\mathrm{d}x=1`.
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
**Note:** LossPDEStatio is jittable. Hence it implements the tree_flatten() and
|
|
257
|
-
tree_unflatten methods.
|
|
258
|
-
"""
|
|
145
|
+
if self.loss_weights is None:
|
|
146
|
+
self.loss_weights = (
|
|
147
|
+
LossWeightsPDENonStatio()
|
|
148
|
+
if isinstance(self, LossPDENonStatio)
|
|
149
|
+
else LossWeightsPDEStatio()
|
|
150
|
+
)
|
|
259
151
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
u,
|
|
263
|
-
loss_weights,
|
|
264
|
-
dynamic_loss,
|
|
265
|
-
derivative_keys=None,
|
|
266
|
-
omega_boundary_fun=None,
|
|
267
|
-
omega_boundary_condition=None,
|
|
268
|
-
omega_boundary_dim=None,
|
|
269
|
-
norm_key=None,
|
|
270
|
-
norm_borders=None,
|
|
271
|
-
norm_samples=None,
|
|
272
|
-
sobolev_m=None,
|
|
273
|
-
obs_slice=None,
|
|
274
|
-
):
|
|
275
|
-
r"""
|
|
276
|
-
Parameters
|
|
277
|
-
----------
|
|
278
|
-
u
|
|
279
|
-
the PINN object
|
|
280
|
-
loss_weights
|
|
281
|
-
a dictionary with values used to ponderate each term in the loss
|
|
282
|
-
function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
|
|
283
|
-
`observations` and `sobolev`.
|
|
284
|
-
Note that we can have jnp.arrays with the same dimension of
|
|
285
|
-
`u` which then ponderates each output of `u`
|
|
286
|
-
dynamic_loss
|
|
287
|
-
the stationary PDE dynamic part of the loss, basically the differential
|
|
288
|
-
operator :math:` \mathcal{N}[u](t)`. Should implement a method
|
|
289
|
-
`dynamic_loss.evaluate(t, u, params)`.
|
|
290
|
-
Can be None in order to access only some part of the evaluate call
|
|
291
|
-
results.
|
|
292
|
-
derivative_keys
|
|
293
|
-
A dict of lists of strings. In the dict, the key must correspond to
|
|
294
|
-
the loss term keywords. Then each of the values must correspond to keys in the parameter
|
|
295
|
-
dictionary (*at top level only of the parameter dictionary*).
|
|
296
|
-
It enables selecting the set of parameters
|
|
297
|
-
with respect to which the gradients of the dynamic
|
|
298
|
-
loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
|
|
299
|
-
keywords, this is what is typically
|
|
300
|
-
done in solving forward problems, when we only estimate the
|
|
301
|
-
equation solution with a PINN. If some loss terms keywords are
|
|
302
|
-
missing we set their value to ["nn_params"] by default for the same
|
|
303
|
-
reason
|
|
304
|
-
omega_boundary_fun
|
|
305
|
-
The function to be matched in the border condition (can be None)
|
|
306
|
-
or a dictionary of such function. In this case, the keys are the
|
|
307
|
-
facets and the values are the functions. The keys must be in the
|
|
308
|
-
following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
|
|
309
|
-
"ymin", "ymax"]. Note that high order boundaries are currently not
|
|
310
|
-
implemented. A value in the dict can be None, this means we do not
|
|
311
|
-
enforce a particular boundary condition on this facet.
|
|
312
|
-
The facet called "xmin", resp. "xmax" etc., in 2D,
|
|
313
|
-
refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
|
|
314
|
-
omega_boundary_condition
|
|
315
|
-
Either None (no condition), or a string defining the boundary
|
|
316
|
-
condition e.g. Dirichlet or Von Neumann, or a dictionary of such
|
|
317
|
-
strings. In this case, the keys are the
|
|
318
|
-
facets and the values are the strings. The keys must be in the
|
|
319
|
-
following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
|
|
320
|
-
"ymin", "ymax"]. Note that high order boundaries are currently not
|
|
321
|
-
implemented. A value in the dict can be None, this means we do not
|
|
322
|
-
enforce a particular boundary condition on this facet.
|
|
323
|
-
The facet called "xmin", resp. "xmax" etc., in 2D,
|
|
324
|
-
refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
|
|
325
|
-
omega_boundary_dim
|
|
326
|
-
Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
|
|
327
|
-
the logic of omega_boundary_fun. It indicates which dimension(s) of
|
|
328
|
-
the PINN will be forced to match the boundary condition
|
|
329
|
-
Note that it must be a slice and not an integer (a preprocessing of the
|
|
330
|
-
user provided argument takes care of it)
|
|
331
|
-
norm_key
|
|
332
|
-
Jax random key to draw samples in for the Monte Carlo computation
|
|
333
|
-
of the normalization constant. Default is None
|
|
334
|
-
norm_borders
|
|
335
|
-
tuple of (min, max) of the boundaray values of the space over which
|
|
336
|
-
to integrate in the computation of the normalization constant.
|
|
337
|
-
A list of tuple for higher dimensional problems. Default None.
|
|
338
|
-
norm_samples
|
|
339
|
-
Fixed sample point in the space over which to compute the
|
|
340
|
-
normalization constant. Default is None
|
|
341
|
-
sobolev_m
|
|
342
|
-
An integer. Default is None.
|
|
343
|
-
It corresponds to the Sobolev regularization order as proposed in
|
|
344
|
-
*Convergence and error analysis of PINNs*,
|
|
345
|
-
Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
|
|
346
|
-
obs_slice
|
|
347
|
-
slice object specifying the begininning/ending
|
|
348
|
-
slice of u output(s) that is observed (this is then useful for
|
|
349
|
-
multidim PINN). Default is None.
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
Raises
|
|
353
|
-
------
|
|
354
|
-
ValueError
|
|
355
|
-
If conditions on omega_boundary_condition and omega_boundary_fun
|
|
356
|
-
are not respected
|
|
357
|
-
"""
|
|
152
|
+
if self.obs_slice is None:
|
|
153
|
+
self.obs_slice = jnp.s_[...]
|
|
358
154
|
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
155
|
+
if (
|
|
156
|
+
isinstance(self.omega_boundary_fun, dict)
|
|
157
|
+
and not isinstance(self.omega_boundary_condition, dict)
|
|
158
|
+
) or (
|
|
159
|
+
not isinstance(self.omega_boundary_fun, dict)
|
|
160
|
+
and isinstance(self.omega_boundary_condition, dict)
|
|
161
|
+
):
|
|
162
|
+
raise ValueError(
|
|
163
|
+
"if one of self.omega_boundary_fun or "
|
|
164
|
+
"self.omega_boundary_condition is dict, the other should be too."
|
|
165
|
+
)
|
|
362
166
|
|
|
363
|
-
if omega_boundary_condition is None or omega_boundary_fun is None:
|
|
167
|
+
if self.omega_boundary_condition is None or self.omega_boundary_fun is None:
|
|
364
168
|
warnings.warn(
|
|
365
169
|
"Missing boundary function or no boundary condition."
|
|
366
170
|
"Boundary function is thus ignored."
|
|
367
171
|
)
|
|
368
172
|
else:
|
|
369
|
-
if isinstance(omega_boundary_condition, dict):
|
|
370
|
-
for _, v in omega_boundary_condition.items():
|
|
173
|
+
if isinstance(self.omega_boundary_condition, dict):
|
|
174
|
+
for _, v in self.omega_boundary_condition.items():
|
|
371
175
|
if v is not None and not any(
|
|
372
176
|
v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
|
|
373
177
|
):
|
|
374
178
|
raise NotImplementedError(
|
|
375
|
-
f"The boundary condition {omega_boundary_condition} is not"
|
|
179
|
+
f"The boundary condition {self.omega_boundary_condition} is not"
|
|
376
180
|
f"implemented yet. Try one of :"
|
|
377
181
|
f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
|
|
378
182
|
)
|
|
379
183
|
else:
|
|
380
184
|
if not any(
|
|
381
|
-
omega_boundary_condition.lower() in s
|
|
185
|
+
self.omega_boundary_condition.lower() in s
|
|
382
186
|
for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
|
|
383
187
|
):
|
|
384
188
|
raise NotImplementedError(
|
|
385
|
-
f"The boundary condition {omega_boundary_condition} is not"
|
|
189
|
+
f"The boundary condition {self.omega_boundary_condition} is not"
|
|
386
190
|
f"implemented yet. Try one of :"
|
|
387
191
|
f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
|
|
388
192
|
)
|
|
389
|
-
if isinstance(omega_boundary_fun, dict) and isinstance(
|
|
390
|
-
omega_boundary_condition, dict
|
|
193
|
+
if isinstance(self.omega_boundary_fun, dict) and isinstance(
|
|
194
|
+
self.omega_boundary_condition, dict
|
|
391
195
|
):
|
|
392
196
|
if (
|
|
393
197
|
not (
|
|
394
|
-
list(omega_boundary_fun.keys()) == ["xmin", "xmax"]
|
|
395
|
-
and list(omega_boundary_condition.keys())
|
|
198
|
+
list(self.omega_boundary_fun.keys()) == ["xmin", "xmax"]
|
|
199
|
+
and list(self.omega_boundary_condition.keys())
|
|
396
200
|
== ["xmin", "xmax"]
|
|
397
201
|
)
|
|
398
202
|
) or (
|
|
399
203
|
not (
|
|
400
|
-
list(omega_boundary_fun.keys())
|
|
204
|
+
list(self.omega_boundary_fun.keys())
|
|
401
205
|
== ["xmin", "xmax", "ymin", "ymax"]
|
|
402
|
-
and list(omega_boundary_condition.keys())
|
|
206
|
+
and list(self.omega_boundary_condition.keys())
|
|
403
207
|
== ["xmin", "xmax", "ymin", "ymax"]
|
|
404
208
|
)
|
|
405
209
|
):
|
|
@@ -408,10 +212,6 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
408
212
|
"boundary condition dictionaries is incorrect"
|
|
409
213
|
)
|
|
410
214
|
|
|
411
|
-
self.omega_boundary_fun = omega_boundary_fun
|
|
412
|
-
self.omega_boundary_condition = omega_boundary_condition
|
|
413
|
-
|
|
414
|
-
self.omega_boundary_dim = omega_boundary_dim
|
|
415
215
|
if isinstance(self.omega_boundary_fun, dict):
|
|
416
216
|
if self.omega_boundary_dim is None:
|
|
417
217
|
self.omega_boundary_dim = {
|
|
@@ -440,44 +240,139 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
440
240
|
self.omega_boundary_dim : self.omega_boundary_dim + 1
|
|
441
241
|
]
|
|
442
242
|
if not isinstance(self.omega_boundary_dim, slice):
|
|
443
|
-
raise ValueError("self.omega_boundary_dim must be a jnp.s_
|
|
243
|
+
raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
|
|
444
244
|
|
|
445
|
-
self.
|
|
245
|
+
if self.norm_samples is not None and self.norm_int_length is None:
|
|
246
|
+
raise ValueError("self.norm_samples and norm_int_length must be provided")
|
|
446
247
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
else:
|
|
455
|
-
self.sobolev_reg = None
|
|
248
|
+
@abc.abstractmethod
|
|
249
|
+
def evaluate(
|
|
250
|
+
self: eqx.Module,
|
|
251
|
+
params: Params,
|
|
252
|
+
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
253
|
+
) -> tuple[Float, dict]:
|
|
254
|
+
raise NotImplementedError
|
|
456
255
|
|
|
457
|
-
for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
|
|
458
|
-
if k not in self.loss_weights.keys():
|
|
459
|
-
self.loss_weights[k] = 0
|
|
460
256
|
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
and not isinstance(self.omega_boundary_condition, dict)
|
|
464
|
-
) or (
|
|
465
|
-
not isinstance(self.omega_boundary_fun, dict)
|
|
466
|
-
and isinstance(self.omega_boundary_condition, dict)
|
|
467
|
-
):
|
|
468
|
-
raise ValueError(
|
|
469
|
-
"if one of self.omega_boundary_fun or "
|
|
470
|
-
"self.omega_boundary_condition is dict, the other should be too."
|
|
471
|
-
)
|
|
257
|
+
class LossPDEStatio(_LossPDEAbstract):
|
|
258
|
+
r"""Loss object for a stationary partial differential equation
|
|
472
259
|
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
260
|
+
$$
|
|
261
|
+
\mathcal{N}[u](x) = 0, \forall x \in \Omega
|
|
262
|
+
$$
|
|
263
|
+
|
|
264
|
+
where $\mathcal{N}[\cdot]$ is a differential operator and the
|
|
265
|
+
boundary condition is $u(x)=u_b(x)$ The additional condition of
|
|
266
|
+
integrating to 1 can be included, i.e. $\int u(x)\mathrm{d}x=1$.
|
|
267
|
+
|
|
268
|
+
Parameters
|
|
269
|
+
----------
|
|
270
|
+
u : eqx.Module
|
|
271
|
+
the PINN
|
|
272
|
+
dynamic_loss : DynamicLoss
|
|
273
|
+
the stationary PDE dynamic part of the loss, basically the differential
|
|
274
|
+
operator $\mathcal{N}[u](x)$. Should implement a method
|
|
275
|
+
`dynamic_loss.evaluate(x, u, params)`.
|
|
276
|
+
Can be None in order to access only some part of the evaluate call
|
|
277
|
+
results.
|
|
278
|
+
key : Key
|
|
279
|
+
A JAX PRNG Key for the loss class treated as an attribute. Default is
|
|
280
|
+
None. This field is provided for future developments and additional
|
|
281
|
+
losses that might need some randomness. Note that special care must be
|
|
282
|
+
taken when splitting the key because in-place updates are forbidden in
|
|
283
|
+
eqx.Modules.
|
|
284
|
+
loss_weights : LossWeightsPDEStatio, default=None
|
|
285
|
+
The loss weights for the differents term : dynamic loss,
|
|
286
|
+
boundary conditions if any, normalization loss if any and
|
|
287
|
+
observations if any.
|
|
288
|
+
All fields are set to 1.0 by default.
|
|
289
|
+
derivative_keys : DerivativeKeysPDEStatio, default=None
|
|
290
|
+
Specify which field of `params` should be differentiated for each
|
|
291
|
+
composant of the total loss. Particularily useful for inverse problems.
|
|
292
|
+
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
293
|
+
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
294
|
+
is `"nn_params"` for each composant of the loss.
|
|
295
|
+
omega_boundary_fun : Callable | Dict[str, Callable], default=None
|
|
296
|
+
The function to be matched in the border condition (can be None) or a
|
|
297
|
+
dictionary of such functions as values and keys as described
|
|
298
|
+
in `omega_boundary_condition`.
|
|
299
|
+
omega_boundary_condition : str | Dict[str, str], default=None
|
|
300
|
+
Either None (no condition, by default), or a string defining
|
|
301
|
+
the boundary condition (Dirichlet or Von Neumann),
|
|
302
|
+
or a dictionary with such strings as values. In this case,
|
|
303
|
+
the keys are the facets and must be in the following order:
|
|
304
|
+
1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
|
|
305
|
+
Note that high order boundaries are currently not implemented.
|
|
306
|
+
A value in the dict can be None, this means we do not enforce
|
|
307
|
+
a particular boundary condition on this facet.
|
|
308
|
+
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
309
|
+
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
310
|
+
omega_boundary_dim : slice | Dict[str, slice], default=None
|
|
311
|
+
Either None, or a slice object or a dictionary of slice objects as
|
|
312
|
+
values and keys as described in `omega_boundary_condition`.
|
|
313
|
+
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
314
|
+
will be forced to match the boundary condition.
|
|
315
|
+
Note that it must be a slice and not an integer
|
|
316
|
+
(but a preprocessing of the user provided argument takes care of it)
|
|
317
|
+
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
318
|
+
Fixed sample point in the space over which to compute the
|
|
319
|
+
normalization constant. Default is None.
|
|
320
|
+
norm_int_length : float, default=None
|
|
321
|
+
A float. Must be provided if `norm_samples` is provided. The domain area
|
|
322
|
+
(or interval length in 1D) upon which we perform the numerical
|
|
323
|
+
integration. Default None
|
|
324
|
+
obs_slice : slice, default=None
|
|
325
|
+
slice object specifying the begininning/ending of the PINN output
|
|
326
|
+
that is observed (this is then useful for multidim PINN). Default is None.
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
Raises
|
|
330
|
+
------
|
|
331
|
+
ValueError
|
|
332
|
+
If conditions on omega_boundary_condition and omega_boundary_fun
|
|
333
|
+
are not respected
|
|
334
|
+
"""
|
|
335
|
+
|
|
336
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
337
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
338
|
+
|
|
339
|
+
u: eqx.Module
|
|
340
|
+
dynamic_loss: DynamicLoss | None
|
|
341
|
+
key: Key | None = eqx.field(kw_only=True, default=None)
|
|
342
|
+
|
|
343
|
+
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
344
|
+
|
|
345
|
+
def __post_init__(self):
|
|
346
|
+
"""
|
|
347
|
+
Note that neither __init__ or __post_init__ are called when udating a
|
|
348
|
+
Module with eqx.tree_at!
|
|
349
|
+
"""
|
|
350
|
+
super().__post_init__() # because __init__ or __post_init__ of Base
|
|
351
|
+
# class is not automatically called
|
|
352
|
+
|
|
353
|
+
self.vmap_in_axes = (0,) # for x only here
|
|
354
|
+
|
|
355
|
+
def _get_dynamic_loss_batch(
|
|
356
|
+
self, batch: PDEStatioBatch
|
|
357
|
+
) -> tuple[Float[Array, "batch_size dimension"]]:
|
|
358
|
+
return (batch.inside_batch,)
|
|
359
|
+
|
|
360
|
+
def _get_normalization_loss_batch(
|
|
361
|
+
self, _
|
|
362
|
+
) -> Float[Array, "nb_norm_samples dimension"]:
|
|
363
|
+
return (self.norm_samples,)
|
|
364
|
+
|
|
365
|
+
def _get_observations_loss_batch(
|
|
366
|
+
self, batch: PDEStatioBatch
|
|
367
|
+
) -> Float[Array, "batch_size obs_dim"]:
|
|
368
|
+
return (batch.obs_batch_dict["pinn_in"],)
|
|
476
369
|
|
|
477
370
|
def __call__(self, *args, **kwargs):
|
|
478
371
|
return self.evaluate(*args, **kwargs)
|
|
479
372
|
|
|
480
|
-
def evaluate(
|
|
373
|
+
def evaluate(
|
|
374
|
+
self, params: Params, batch: PDEStatioBatch
|
|
375
|
+
) -> tuple[Float[Array, "1"], dict[str, float]]:
|
|
481
376
|
"""
|
|
482
377
|
Evaluate the loss function at a batch of points for given parameters.
|
|
483
378
|
|
|
@@ -485,22 +380,14 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
485
380
|
Parameters
|
|
486
381
|
---------
|
|
487
382
|
params
|
|
488
|
-
|
|
489
|
-
Typically, it is a dictionary of
|
|
490
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
491
|
-
differential equation parameters and the neural network parameter
|
|
383
|
+
Parameters at which the loss is evaluated
|
|
492
384
|
batch
|
|
493
|
-
|
|
494
|
-
Such a named tuple is composed of a batch of points in the
|
|
385
|
+
Composed of a batch of points in the
|
|
495
386
|
domain, a batch of points in the domain
|
|
496
387
|
border and an optional additional batch of parameters (eg. for
|
|
497
388
|
metamodeling) and an optional additional batch of observed
|
|
498
389
|
inputs/outputs/parameters
|
|
499
390
|
"""
|
|
500
|
-
omega_batch, _ = batch.inside_batch, batch.border_batch
|
|
501
|
-
|
|
502
|
-
vmap_in_axes_x = (0,)
|
|
503
|
-
|
|
504
391
|
# Retrieve the optional eq_params_batch
|
|
505
392
|
# and update eq_params with the latter
|
|
506
393
|
# and update vmap_in_axes
|
|
@@ -511,44 +398,41 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
511
398
|
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
512
399
|
|
|
513
400
|
# dynamic part
|
|
514
|
-
params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
|
|
515
401
|
if self.dynamic_loss is not None:
|
|
516
402
|
mse_dyn_loss = dynamic_loss_apply(
|
|
517
403
|
self.dynamic_loss.evaluate,
|
|
518
404
|
self.u,
|
|
519
|
-
(
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
self.loss_weights
|
|
405
|
+
self._get_dynamic_loss_batch(batch),
|
|
406
|
+
_set_derivatives(params, self.derivative_keys.dyn_loss),
|
|
407
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
408
|
+
self.loss_weights.dyn_loss,
|
|
523
409
|
)
|
|
524
410
|
else:
|
|
525
411
|
mse_dyn_loss = jnp.array(0.0)
|
|
526
412
|
|
|
527
413
|
# normalization part
|
|
528
|
-
|
|
529
|
-
if self.normalization_loss is not None:
|
|
414
|
+
if self.norm_samples is not None:
|
|
530
415
|
mse_norm_loss = normalization_loss_apply(
|
|
531
416
|
self.u,
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
self.
|
|
536
|
-
self.loss_weights
|
|
417
|
+
self._get_normalization_loss_batch(batch),
|
|
418
|
+
_set_derivatives(params, self.derivative_keys.norm_loss),
|
|
419
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
420
|
+
self.norm_int_length,
|
|
421
|
+
self.loss_weights.norm_loss,
|
|
537
422
|
)
|
|
538
423
|
else:
|
|
539
424
|
mse_norm_loss = jnp.array(0.0)
|
|
540
425
|
|
|
541
426
|
# boundary part
|
|
542
|
-
params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
|
|
543
427
|
if self.omega_boundary_condition is not None:
|
|
544
428
|
mse_boundary_loss = boundary_condition_apply(
|
|
545
429
|
self.u,
|
|
546
430
|
batch,
|
|
547
|
-
|
|
431
|
+
_set_derivatives(params, self.derivative_keys.boundary_loss),
|
|
548
432
|
self.omega_boundary_fun,
|
|
549
433
|
self.omega_boundary_condition,
|
|
550
434
|
self.omega_boundary_dim,
|
|
551
|
-
self.loss_weights
|
|
435
|
+
self.loss_weights.boundary_loss,
|
|
552
436
|
)
|
|
553
437
|
else:
|
|
554
438
|
mse_boundary_loss = jnp.array(0.0)
|
|
@@ -558,40 +442,21 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
558
442
|
# update params with the batches of observed params
|
|
559
443
|
params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
|
|
560
444
|
|
|
561
|
-
params_ = _set_derivatives(params, "observations", self.derivative_keys)
|
|
562
445
|
mse_observation_loss = observations_loss_apply(
|
|
563
446
|
self.u,
|
|
564
|
-
(batch
|
|
565
|
-
|
|
566
|
-
|
|
447
|
+
self._get_observations_loss_batch(batch),
|
|
448
|
+
_set_derivatives(params, self.derivative_keys.observations),
|
|
449
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
567
450
|
batch.obs_batch_dict["val"],
|
|
568
|
-
self.loss_weights
|
|
451
|
+
self.loss_weights.observations,
|
|
569
452
|
self.obs_slice,
|
|
570
453
|
)
|
|
571
454
|
else:
|
|
572
455
|
mse_observation_loss = jnp.array(0.0)
|
|
573
456
|
|
|
574
|
-
# Sobolev regularization
|
|
575
|
-
params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
|
|
576
|
-
if self.sobolev_reg is not None:
|
|
577
|
-
mse_sobolev_loss = sobolev_reg_apply(
|
|
578
|
-
self.u,
|
|
579
|
-
(omega_batch,),
|
|
580
|
-
params_,
|
|
581
|
-
vmap_in_axes_x + vmap_in_axes_params,
|
|
582
|
-
self.sobolev_reg,
|
|
583
|
-
self.loss_weights["sobolev"],
|
|
584
|
-
)
|
|
585
|
-
else:
|
|
586
|
-
mse_sobolev_loss = jnp.array(0.0)
|
|
587
|
-
|
|
588
457
|
# total loss
|
|
589
458
|
total_loss = (
|
|
590
|
-
mse_dyn_loss
|
|
591
|
-
+ mse_norm_loss
|
|
592
|
-
+ mse_boundary_loss
|
|
593
|
-
+ mse_observation_loss
|
|
594
|
-
+ mse_sobolev_loss
|
|
459
|
+
mse_dyn_loss + mse_norm_loss + mse_boundary_loss + mse_observation_loss
|
|
595
460
|
)
|
|
596
461
|
return total_loss, (
|
|
597
462
|
{
|
|
@@ -599,205 +464,143 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
599
464
|
"norm_loss": mse_norm_loss,
|
|
600
465
|
"boundary_loss": mse_boundary_loss,
|
|
601
466
|
"observations": mse_observation_loss,
|
|
602
|
-
"sobolev": mse_sobolev_loss,
|
|
603
467
|
"initial_condition": jnp.array(0.0), # for compatibility in the
|
|
604
468
|
# tree_map of SystemLoss
|
|
605
469
|
}
|
|
606
470
|
)
|
|
607
471
|
|
|
608
|
-
def tree_flatten(self):
|
|
609
|
-
children = (self.norm_key, self.norm_samples, self.loss_weights)
|
|
610
|
-
aux_data = {
|
|
611
|
-
"u": self.u,
|
|
612
|
-
"dynamic_loss": self.dynamic_loss,
|
|
613
|
-
"derivative_keys": self.derivative_keys,
|
|
614
|
-
"omega_boundary_fun": self.omega_boundary_fun,
|
|
615
|
-
"omega_boundary_condition": self.omega_boundary_condition,
|
|
616
|
-
"omega_boundary_dim": self.omega_boundary_dim,
|
|
617
|
-
"norm_borders": self.norm_borders,
|
|
618
|
-
"sobolev_m": self.sobolev_m,
|
|
619
|
-
"obs_slice": self.obs_slice,
|
|
620
|
-
}
|
|
621
|
-
return (children, aux_data)
|
|
622
|
-
|
|
623
|
-
@classmethod
|
|
624
|
-
def tree_unflatten(cls, aux_data, children):
|
|
625
|
-
(norm_key, norm_samples, loss_weights) = children
|
|
626
|
-
pls = cls(
|
|
627
|
-
aux_data["u"],
|
|
628
|
-
loss_weights,
|
|
629
|
-
aux_data["dynamic_loss"],
|
|
630
|
-
aux_data["derivative_keys"],
|
|
631
|
-
aux_data["omega_boundary_fun"],
|
|
632
|
-
aux_data["omega_boundary_condition"],
|
|
633
|
-
aux_data["omega_boundary_dim"],
|
|
634
|
-
norm_key,
|
|
635
|
-
aux_data["norm_borders"],
|
|
636
|
-
norm_samples,
|
|
637
|
-
aux_data["sobolev_m"],
|
|
638
|
-
aux_data["obs_slice"],
|
|
639
|
-
)
|
|
640
|
-
return pls
|
|
641
472
|
|
|
642
|
-
|
|
643
|
-
@register_pytree_node_class
|
|
644
473
|
class LossPDENonStatio(LossPDEStatio):
|
|
645
474
|
r"""Loss object for a stationary partial differential equation
|
|
646
475
|
|
|
647
|
-
|
|
476
|
+
$$
|
|
648
477
|
\mathcal{N}[u](t, x) = 0, \forall t \in I, \forall x \in \Omega
|
|
478
|
+
$$
|
|
649
479
|
|
|
650
|
-
where
|
|
651
|
-
The boundary condition is
|
|
652
|
-
x\in\delta\Omega, \forall t
|
|
653
|
-
The initial condition is
|
|
480
|
+
where $\mathcal{N}[\cdot]$ is a differential operator.
|
|
481
|
+
The boundary condition is $u(t, x)=u_b(t, x),\forall
|
|
482
|
+
x\in\delta\Omega, \forall t$.
|
|
483
|
+
The initial condition is $u(0, x)=u_0(x), \forall x\in\Omega$
|
|
654
484
|
The additional condition of
|
|
655
|
-
integrating to 1 can be included, i.e.,
|
|
656
|
-
|
|
485
|
+
integrating to 1 can be included, i.e., $\int u(t, x)\mathrm{d}x=1$.
|
|
486
|
+
|
|
487
|
+
Parameters
|
|
488
|
+
----------
|
|
489
|
+
u : eqx.Module
|
|
490
|
+
the PINN
|
|
491
|
+
dynamic_loss : DynamicLoss
|
|
492
|
+
the non stationary PDE dynamic part of the loss, basically the differential
|
|
493
|
+
operator $\mathcal{N}[u](t, x)$. Should implement a method
|
|
494
|
+
`dynamic_loss.evaluate(t, x, u, params)`.
|
|
495
|
+
Can be None in order to access only some part of the evaluate call
|
|
496
|
+
results.
|
|
497
|
+
key : Key
|
|
498
|
+
A JAX PRNG Key for the loss class treated as an attribute. Default is
|
|
499
|
+
None. This field is provided for future developments and additional
|
|
500
|
+
losses that might need some randomness. Note that special care must be
|
|
501
|
+
taken when splitting the key because in-place updates are forbidden in
|
|
502
|
+
eqx.Modules.
|
|
503
|
+
reason
|
|
504
|
+
loss_weights : LossWeightsPDENonStatio, default=None
|
|
505
|
+
The loss weights for the differents term : dynamic loss,
|
|
506
|
+
boundary conditions if any, initial condition, normalization loss if any and
|
|
507
|
+
observations if any.
|
|
508
|
+
All fields are set to 1.0 by default.
|
|
509
|
+
derivative_keys : DerivativeKeysPDENonStatio, default=None
|
|
510
|
+
Specify which field of `params` should be differentiated for each
|
|
511
|
+
composant of the total loss. Particularily useful for inverse problems.
|
|
512
|
+
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
513
|
+
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
514
|
+
is `"nn_params"` for each composant of the loss.
|
|
515
|
+
omega_boundary_fun : Callable | Dict[str, Callable], default=None
|
|
516
|
+
The function to be matched in the border condition (can be None) or a
|
|
517
|
+
dictionary of such functions as values and keys as described
|
|
518
|
+
in `omega_boundary_condition`.
|
|
519
|
+
omega_boundary_condition : str | Dict[str, str], default=None
|
|
520
|
+
Either None (no condition, by default), or a string defining
|
|
521
|
+
the boundary condition (Dirichlet or Von Neumann),
|
|
522
|
+
or a dictionary with such strings as values. In this case,
|
|
523
|
+
the keys are the facets and must be in the following order:
|
|
524
|
+
1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
|
|
525
|
+
Note that high order boundaries are currently not implemented.
|
|
526
|
+
A value in the dict can be None, this means we do not enforce
|
|
527
|
+
a particular boundary condition on this facet.
|
|
528
|
+
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
529
|
+
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
530
|
+
omega_boundary_dim : slice | Dict[str, slice], default=None
|
|
531
|
+
Either None, or a slice object or a dictionary of slice objects as
|
|
532
|
+
values and keys as described in `omega_boundary_condition`.
|
|
533
|
+
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
534
|
+
will be forced to match the boundary condition.
|
|
535
|
+
Note that it must be a slice and not an integer
|
|
536
|
+
(but a preprocessing of the user provided argument takes care of it)
|
|
537
|
+
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
538
|
+
Fixed sample point in the space over which to compute the
|
|
539
|
+
normalization constant. Default is None.
|
|
540
|
+
norm_int_length : float, default=None
|
|
541
|
+
A float. Must be provided if `norm_samples` is provided. The domain area
|
|
542
|
+
(or interval length in 1D) upon which we perform the numerical
|
|
543
|
+
integration. Default None
|
|
544
|
+
obs_slice : slice, default=None
|
|
545
|
+
slice object specifying the begininning/ending of the PINN output
|
|
546
|
+
that is observed (this is then useful for multidim PINN). Default is None.
|
|
547
|
+
initial_condition_fun : Callable, default=None
|
|
548
|
+
A function representing the temporal initial condition. If None
|
|
549
|
+
(default) then no initial condition is applied
|
|
657
550
|
|
|
658
|
-
**Note:** LossPDENonStatio is jittable. Hence it implements the tree_flatten() and
|
|
659
|
-
tree_unflatten methods.
|
|
660
551
|
"""
|
|
661
552
|
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
derivative_keys=None,
|
|
668
|
-
omega_boundary_fun=None,
|
|
669
|
-
omega_boundary_condition=None,
|
|
670
|
-
omega_boundary_dim=None,
|
|
671
|
-
initial_condition_fun=None,
|
|
672
|
-
norm_key=None,
|
|
673
|
-
norm_borders=None,
|
|
674
|
-
norm_samples=None,
|
|
675
|
-
sobolev_m=None,
|
|
676
|
-
obs_slice=None,
|
|
677
|
-
):
|
|
678
|
-
r"""
|
|
679
|
-
Parameters
|
|
680
|
-
----------
|
|
681
|
-
u
|
|
682
|
-
the PINN object
|
|
683
|
-
loss_weights
|
|
684
|
-
dictionary of values for loss term ponderation
|
|
685
|
-
Note that we can have jnp.arrays with the same dimension of
|
|
686
|
-
`u` which then ponderates each output of `u`
|
|
687
|
-
dynamic_loss
|
|
688
|
-
A Dynamic loss object whose evaluate method corresponds to the
|
|
689
|
-
dynamic term in the loss
|
|
690
|
-
Can be None in order to access only some part of the evaluate call
|
|
691
|
-
results.
|
|
692
|
-
derivative_keys
|
|
693
|
-
A dict of lists of strings. In the dict, the key must correspond to
|
|
694
|
-
the loss term keywords. Then each of the values must correspond to keys in the parameter
|
|
695
|
-
dictionary (*at top level only of the parameter dictionary*).
|
|
696
|
-
It enables selecting the set of parameters
|
|
697
|
-
with respect to which the gradients of the dynamic
|
|
698
|
-
loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
|
|
699
|
-
keywords, this is what is typically
|
|
700
|
-
done in solving forward problems, when we only estimate the
|
|
701
|
-
equation solution with a PINN. If some loss terms keywords are
|
|
702
|
-
missing we set their value to ["nn_params"] by default for the same
|
|
703
|
-
reason
|
|
704
|
-
omega_boundary_fun
|
|
705
|
-
The function to be matched in the border condition (can be None)
|
|
706
|
-
or a dictionary of such function. In this case, the keys are the
|
|
707
|
-
facets and the values are the functions. The keys must be in the
|
|
708
|
-
following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
|
|
709
|
-
"ymin", "ymax"]. Note that high order boundaries are currently not
|
|
710
|
-
implemented. A value in the dict can be None, this means we do not
|
|
711
|
-
enforce a particular boundary condition on this facet.
|
|
712
|
-
The facet called "xmin", resp. "xmax" etc., in 2D,
|
|
713
|
-
refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
|
|
714
|
-
omega_boundary_condition
|
|
715
|
-
Either None (no condition), or a string defining the boundary
|
|
716
|
-
condition e.g. Dirichlet or Von Neumann, or a dictionary of such
|
|
717
|
-
strings. In this case, the keys are the
|
|
718
|
-
facets and the values are the strings. The keys must be in the
|
|
719
|
-
following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
|
|
720
|
-
"ymin", "ymax"]. Note that high order boundaries are currently not
|
|
721
|
-
implemented. A value in the dict can be None, this means we do not
|
|
722
|
-
enforce a particular boundary condition on this facet.
|
|
723
|
-
The facet called "xmin", resp. "xmax" etc., in 2D,
|
|
724
|
-
refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
|
|
725
|
-
omega_boundary_dim
|
|
726
|
-
Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
|
|
727
|
-
the logic of omega_boundary_fun. It indicates which dimension(s) of
|
|
728
|
-
the PINN will be forced to match the boundary condition
|
|
729
|
-
Note that it must be a slice and not an integer (a preprocessing of the
|
|
730
|
-
user provided argument takes care of it)
|
|
731
|
-
initial_condition_fun
|
|
732
|
-
A function representing the temporal initial condition. If None
|
|
733
|
-
(default) then no initial condition is applied
|
|
734
|
-
norm_key
|
|
735
|
-
Jax random key to draw samples in for the Monte Carlo computation
|
|
736
|
-
of the normalization constant. Default is None
|
|
737
|
-
norm_borders
|
|
738
|
-
tuple of (min, max) of the boundaray values of the space over which
|
|
739
|
-
to integrate in the computation of the normalization constant.
|
|
740
|
-
A list of tuple for higher dimensional problems. Default None.
|
|
741
|
-
norm_samples
|
|
742
|
-
Fixed sample point in the space over which to compute the
|
|
743
|
-
normalization constant. Default is None
|
|
744
|
-
sobolev_m
|
|
745
|
-
An integer. Default is None.
|
|
746
|
-
It corresponds to the Sobolev regularization order as proposed in
|
|
747
|
-
*Convergence and error analysis of PINNs*,
|
|
748
|
-
Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
|
|
749
|
-
obs_slice
|
|
750
|
-
slice object specifying the begininning/ending
|
|
751
|
-
slice of u output(s) that is observed (this is then useful for
|
|
752
|
-
multidim PINN). Default is None.
|
|
753
|
-
|
|
553
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
554
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
555
|
+
initial_condition_fun: Callable | None = eqx.field(
|
|
556
|
+
kw_only=True, default=None, static=True
|
|
557
|
+
)
|
|
754
558
|
|
|
559
|
+
def __post_init__(self):
|
|
560
|
+
"""
|
|
561
|
+
Note that neither __init__ or __post_init__ are called when udating a
|
|
562
|
+
Module with eqx.tree_at!
|
|
755
563
|
"""
|
|
564
|
+
super().__post_init__() # because __init__ or __post_init__ of Base
|
|
565
|
+
# class is not automatically called
|
|
756
566
|
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
dynamic_loss,
|
|
761
|
-
derivative_keys,
|
|
762
|
-
omega_boundary_fun,
|
|
763
|
-
omega_boundary_condition,
|
|
764
|
-
omega_boundary_dim,
|
|
765
|
-
norm_key,
|
|
766
|
-
norm_borders,
|
|
767
|
-
norm_samples,
|
|
768
|
-
sobolev_m=sobolev_m,
|
|
769
|
-
obs_slice=obs_slice,
|
|
770
|
-
)
|
|
771
|
-
if initial_condition_fun is None:
|
|
567
|
+
self.vmap_in_axes = (0, 0) # for t and x
|
|
568
|
+
|
|
569
|
+
if self.initial_condition_fun is None:
|
|
772
570
|
warnings.warn(
|
|
773
571
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
774
572
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
775
573
|
)
|
|
776
|
-
self.initial_condition_fun = initial_condition_fun
|
|
777
|
-
|
|
778
|
-
self.sobolev_m = sobolev_m
|
|
779
|
-
if self.sobolev_m is not None:
|
|
780
|
-
# This overwrite the wrongly initialized self.sobolev_reg with
|
|
781
|
-
# statio=True in the LossPDEStatio init
|
|
782
|
-
self.sobolev_reg = _sobolev(self.u, self.sobolev_m, statio=False)
|
|
783
|
-
# we return a function, that way
|
|
784
|
-
# the order of sobolev_m is static and the conditional in the recursive
|
|
785
|
-
# function is properly set
|
|
786
|
-
else:
|
|
787
|
-
self.sobolev_reg = None
|
|
788
574
|
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
575
|
+
def _get_dynamic_loss_batch(
|
|
576
|
+
self, batch: PDENonStatioBatch
|
|
577
|
+
) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
|
|
578
|
+
times_batch = batch.times_x_inside_batch[:, 0:1]
|
|
579
|
+
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
580
|
+
return (times_batch, omega_batch)
|
|
581
|
+
|
|
582
|
+
def _get_normalization_loss_batch(
|
|
583
|
+
self, batch: PDENonStatioBatch
|
|
584
|
+
) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
|
|
585
|
+
return (
|
|
586
|
+
batch.times_x_inside_batch[:, 0:1],
|
|
587
|
+
self.norm_samples,
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
def _get_observations_loss_batch(
|
|
591
|
+
self, batch: PDENonStatioBatch
|
|
592
|
+
) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
|
|
593
|
+
return (
|
|
594
|
+
batch.obs_batch_dict["pinn_in"][:, 0:1],
|
|
595
|
+
batch.obs_batch_dict["pinn_in"][:, 1:],
|
|
596
|
+
)
|
|
792
597
|
|
|
793
598
|
def __call__(self, *args, **kwargs):
|
|
794
599
|
return self.evaluate(*args, **kwargs)
|
|
795
600
|
|
|
796
601
|
def evaluate(
|
|
797
|
-
self,
|
|
798
|
-
|
|
799
|
-
batch,
|
|
800
|
-
):
|
|
602
|
+
self, params: Params, batch: PDENonStatioBatch
|
|
603
|
+
) -> tuple[Float[Array, "1"], dict[str, float]]:
|
|
801
604
|
"""
|
|
802
605
|
Evaluate the loss function at a batch of points for given parameters.
|
|
803
606
|
|
|
@@ -805,203 +608,55 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
805
608
|
Parameters
|
|
806
609
|
---------
|
|
807
610
|
params
|
|
808
|
-
|
|
809
|
-
Typically, it is a dictionary of
|
|
810
|
-
dictionaries: `eq_params` and `nn_params`, respectively the
|
|
811
|
-
differential equation parameters and the neural network parameter
|
|
611
|
+
Parameters at which the loss is evaluated
|
|
812
612
|
batch
|
|
813
|
-
|
|
814
|
-
Such a named tuple is composed of a batch of points in
|
|
613
|
+
Composed of a batch of points in
|
|
815
614
|
the domain, a batch of points in the domain
|
|
816
615
|
border, a batch of time points and an optional additional batch
|
|
817
616
|
of parameters (eg. for metamodeling) and an optional additional batch of observed
|
|
818
617
|
inputs/outputs/parameters
|
|
819
618
|
"""
|
|
820
619
|
|
|
821
|
-
omega_batch
|
|
822
|
-
batch.inside_batch,
|
|
823
|
-
batch.border_batch,
|
|
824
|
-
batch.temporal_batch,
|
|
825
|
-
)
|
|
826
|
-
n = omega_batch.shape[0]
|
|
827
|
-
nt = times_batch.shape[0]
|
|
828
|
-
times_batch = times_batch.reshape(nt, 1)
|
|
829
|
-
|
|
830
|
-
def rep_times(k):
|
|
831
|
-
return jnp.repeat(times_batch, k, axis=0)
|
|
832
|
-
|
|
833
|
-
vmap_in_axes_x_t = (0, 0)
|
|
620
|
+
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
834
621
|
|
|
835
622
|
# Retrieve the optional eq_params_batch
|
|
836
623
|
# and update eq_params with the latter
|
|
837
624
|
# and update vmap_in_axes
|
|
838
625
|
if batch.param_batch_dict is not None:
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
# feed the eq_params with the batch
|
|
842
|
-
for k in eq_params_batch_dict.keys():
|
|
843
|
-
params["eq_params"][k] = eq_params_batch_dict[k]
|
|
626
|
+
# update eq_params with the batches of generated params
|
|
627
|
+
params = _update_eq_params_dict(params, batch.param_batch_dict)
|
|
844
628
|
|
|
845
629
|
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
846
630
|
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
# dynamic part
|
|
852
|
-
params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
|
|
853
|
-
if self.dynamic_loss is not None:
|
|
854
|
-
mse_dyn_loss = dynamic_loss_apply(
|
|
855
|
-
self.dynamic_loss.evaluate,
|
|
856
|
-
self.u,
|
|
857
|
-
(times_batch, omega_batch),
|
|
858
|
-
params_,
|
|
859
|
-
vmap_in_axes_x_t + vmap_in_axes_params,
|
|
860
|
-
self.loss_weights["dyn_loss"],
|
|
861
|
-
)
|
|
862
|
-
else:
|
|
863
|
-
mse_dyn_loss = jnp.array(0.0)
|
|
864
|
-
|
|
865
|
-
# normalization part
|
|
866
|
-
params_ = _set_derivatives(params, "norm_loss", self.derivative_keys)
|
|
867
|
-
if self.normalization_loss is not None:
|
|
868
|
-
mse_norm_loss = normalization_loss_apply(
|
|
869
|
-
self.u,
|
|
870
|
-
(times_batch, self.get_norm_samples()),
|
|
871
|
-
params_,
|
|
872
|
-
vmap_in_axes_x_t + vmap_in_axes_params,
|
|
873
|
-
self.int_length,
|
|
874
|
-
self.loss_weights["norm_loss"],
|
|
875
|
-
)
|
|
876
|
-
else:
|
|
877
|
-
mse_norm_loss = jnp.array(0.0)
|
|
878
|
-
|
|
879
|
-
# boundary part
|
|
880
|
-
params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
|
|
881
|
-
if self.omega_boundary_fun is not None:
|
|
882
|
-
mse_boundary_loss = boundary_condition_apply(
|
|
883
|
-
self.u,
|
|
884
|
-
batch,
|
|
885
|
-
params_,
|
|
886
|
-
self.omega_boundary_fun,
|
|
887
|
-
self.omega_boundary_condition,
|
|
888
|
-
self.omega_boundary_dim,
|
|
889
|
-
self.loss_weights["boundary_loss"],
|
|
890
|
-
)
|
|
891
|
-
else:
|
|
892
|
-
mse_boundary_loss = jnp.array(0.0)
|
|
631
|
+
# For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
|
|
632
|
+
# mse_observation_loss we use the evaluate from parent class
|
|
633
|
+
partial_mse, partial_mse_terms = super().evaluate(params, batch)
|
|
893
634
|
|
|
894
635
|
# initial condition
|
|
895
|
-
params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
|
|
896
636
|
if self.initial_condition_fun is not None:
|
|
897
637
|
mse_initial_condition = initial_condition_apply(
|
|
898
638
|
self.u,
|
|
899
639
|
omega_batch,
|
|
900
|
-
|
|
640
|
+
_set_derivatives(params, self.derivative_keys.initial_condition),
|
|
901
641
|
(0,) + vmap_in_axes_params,
|
|
902
642
|
self.initial_condition_fun,
|
|
903
|
-
|
|
904
|
-
self.loss_weights
|
|
643
|
+
omega_batch.shape[0],
|
|
644
|
+
self.loss_weights.initial_condition,
|
|
905
645
|
)
|
|
906
646
|
else:
|
|
907
647
|
mse_initial_condition = jnp.array(0.0)
|
|
908
648
|
|
|
909
|
-
# Observation mse
|
|
910
|
-
if batch.obs_batch_dict is not None:
|
|
911
|
-
# update params with the batches of observed params
|
|
912
|
-
params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
|
|
913
|
-
|
|
914
|
-
params_ = _set_derivatives(params, "observations", self.derivative_keys)
|
|
915
|
-
mse_observation_loss = observations_loss_apply(
|
|
916
|
-
self.u,
|
|
917
|
-
(
|
|
918
|
-
batch.obs_batch_dict["pinn_in"][:, 0:1],
|
|
919
|
-
batch.obs_batch_dict["pinn_in"][:, 1:],
|
|
920
|
-
),
|
|
921
|
-
params_,
|
|
922
|
-
vmap_in_axes_x_t + vmap_in_axes_params,
|
|
923
|
-
batch.obs_batch_dict["val"],
|
|
924
|
-
self.loss_weights["observations"],
|
|
925
|
-
self.obs_slice,
|
|
926
|
-
)
|
|
927
|
-
else:
|
|
928
|
-
mse_observation_loss = jnp.array(0.0)
|
|
929
|
-
|
|
930
|
-
# Sobolev regularization
|
|
931
|
-
params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
|
|
932
|
-
if self.sobolev_reg is not None:
|
|
933
|
-
mse_sobolev_loss = sobolev_reg_apply(
|
|
934
|
-
self.u,
|
|
935
|
-
(omega_batch, times_batch),
|
|
936
|
-
params_,
|
|
937
|
-
vmap_in_axes_x_t + vmap_in_axes_params,
|
|
938
|
-
self.sobolev_reg,
|
|
939
|
-
self.loss_weights["sobolev"],
|
|
940
|
-
)
|
|
941
|
-
else:
|
|
942
|
-
mse_sobolev_loss = jnp.array(0.0)
|
|
943
|
-
|
|
944
649
|
# total loss
|
|
945
|
-
total_loss =
|
|
946
|
-
mse_dyn_loss
|
|
947
|
-
+ mse_norm_loss
|
|
948
|
-
+ mse_boundary_loss
|
|
949
|
-
+ mse_initial_condition
|
|
950
|
-
+ mse_observation_loss
|
|
951
|
-
+ mse_sobolev_loss
|
|
952
|
-
)
|
|
953
|
-
|
|
954
|
-
return total_loss, (
|
|
955
|
-
{
|
|
956
|
-
"dyn_loss": mse_dyn_loss,
|
|
957
|
-
"norm_loss": mse_norm_loss,
|
|
958
|
-
"boundary_loss": mse_boundary_loss,
|
|
959
|
-
"initial_condition": mse_initial_condition,
|
|
960
|
-
"observations": mse_observation_loss,
|
|
961
|
-
"sobolev": mse_sobolev_loss,
|
|
962
|
-
}
|
|
963
|
-
)
|
|
650
|
+
total_loss = partial_mse + mse_initial_condition
|
|
964
651
|
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
"u": self.u,
|
|
969
|
-
"dynamic_loss": self.dynamic_loss,
|
|
970
|
-
"derivative_keys": self.derivative_keys,
|
|
971
|
-
"omega_boundary_fun": self.omega_boundary_fun,
|
|
972
|
-
"omega_boundary_condition": self.omega_boundary_condition,
|
|
973
|
-
"omega_boundary_dim": self.omega_boundary_dim,
|
|
974
|
-
"initial_condition_fun": self.initial_condition_fun,
|
|
975
|
-
"norm_borders": self.norm_borders,
|
|
976
|
-
"sobolev_m": self.sobolev_m,
|
|
977
|
-
"obs_slice": self.obs_slice,
|
|
652
|
+
return total_loss, {
|
|
653
|
+
**partial_mse_terms,
|
|
654
|
+
"initial_condition": mse_initial_condition,
|
|
978
655
|
}
|
|
979
|
-
return (children, aux_data)
|
|
980
|
-
|
|
981
|
-
@classmethod
|
|
982
|
-
def tree_unflatten(cls, aux_data, children):
|
|
983
|
-
(norm_key, norm_samples, loss_weights) = children
|
|
984
|
-
pls = cls(
|
|
985
|
-
aux_data["u"],
|
|
986
|
-
loss_weights,
|
|
987
|
-
aux_data["dynamic_loss"],
|
|
988
|
-
aux_data["derivative_keys"],
|
|
989
|
-
aux_data["omega_boundary_fun"],
|
|
990
|
-
aux_data["omega_boundary_condition"],
|
|
991
|
-
aux_data["omega_boundary_dim"],
|
|
992
|
-
aux_data["initial_condition_fun"],
|
|
993
|
-
norm_key,
|
|
994
|
-
aux_data["norm_borders"],
|
|
995
|
-
norm_samples,
|
|
996
|
-
aux_data["sobolev_m"],
|
|
997
|
-
aux_data["obs_slice"],
|
|
998
|
-
)
|
|
999
|
-
return pls
|
|
1000
656
|
|
|
1001
657
|
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
"""
|
|
658
|
+
class SystemLossPDE(eqx.Module):
|
|
659
|
+
r"""
|
|
1005
660
|
Class to implement a system of PDEs.
|
|
1006
661
|
The goal is to give maximum freedom to the user. The class is created with
|
|
1007
662
|
a dict of dynamic loss, and dictionaries of all the objects that are used
|
|
@@ -1015,190 +670,186 @@ class SystemLossPDE:
|
|
|
1015
670
|
Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
|
|
1016
671
|
solution.
|
|
1017
672
|
|
|
1018
|
-
|
|
1019
|
-
|
|
673
|
+
Parameters
|
|
674
|
+
----------
|
|
675
|
+
u_dict : Dict[str, eqx.Module]
|
|
676
|
+
dict of PINNs
|
|
677
|
+
loss_weights : LossWeightsPDEDict
|
|
678
|
+
A dictionary of LossWeightsODE
|
|
679
|
+
derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
|
|
680
|
+
A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
|
|
681
|
+
specifying what field of `params`
|
|
682
|
+
should be used during gradient computations for each of the terms of
|
|
683
|
+
the total loss, for each of the loss in the system. Default is
|
|
684
|
+
`"nn_params`" everywhere.
|
|
685
|
+
dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
|
|
686
|
+
A dict of dynamic part of the loss, basically the differential
|
|
687
|
+
operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
|
|
688
|
+
key_dict : Dict[str, Key], default=None
|
|
689
|
+
A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
|
|
690
|
+
match that of u_dict. See LossPDEStatio or LossPDENonStatio for
|
|
691
|
+
more details.
|
|
692
|
+
omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
|
|
693
|
+
A dict of of function or of dict of functions or of None
|
|
694
|
+
(see doc for `omega_boundary_fun` in
|
|
695
|
+
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
696
|
+
Must share the keys of `u_dict`.
|
|
697
|
+
omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
|
|
698
|
+
A dict of strings or of dict of strings or of None
|
|
699
|
+
(see doc for `omega_boundary_condition_dict` in
|
|
700
|
+
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
701
|
+
Must share the keys of `u_dict`
|
|
702
|
+
omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
|
|
703
|
+
A dict of slices or of dict of slices or of None
|
|
704
|
+
(see doc for `omega_boundary_dim` in
|
|
705
|
+
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
706
|
+
Must share the keys of `u_dict`
|
|
707
|
+
initial_condition_fun_dict : Dict[str, Callable | None], default=None
|
|
708
|
+
A dict of functions representing the temporal initial condition (None
|
|
709
|
+
value is possible). If None
|
|
710
|
+
(default) then no temporal boundary condition is applied
|
|
711
|
+
Must share the keys of `u_dict`
|
|
712
|
+
norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
|
|
713
|
+
A dict of fixed sample point in the space over which to compute the
|
|
714
|
+
normalization constant. Default is None
|
|
715
|
+
Must share the keys of `u_dict`
|
|
716
|
+
norm_int_length_dict : Dict[str, float | None] | None, default=None
|
|
717
|
+
A dict of Float. The domain area
|
|
718
|
+
(or interval length in 1D) upon which we perform the numerical
|
|
719
|
+
integration for each element of u_dict.
|
|
720
|
+
Default is None
|
|
721
|
+
Must share the keys of `u_dict`
|
|
722
|
+
obs_slice_dict : Dict[str, slice | None] | None, default=None
|
|
723
|
+
dict of obs_slice, with keys from `u_dict` to designate the
|
|
724
|
+
output(s) channels that are forced to observed values, for each
|
|
725
|
+
PINNs. Default is None. But if a value is given, all the entries of
|
|
726
|
+
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
727
|
+
if no particular slice is to be given
|
|
728
|
+
|
|
1020
729
|
"""
|
|
1021
730
|
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
)
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
`omega_boundary_condition_dict` in
|
|
1074
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
1075
|
-
Must share the keys of `u_dict`
|
|
1076
|
-
omega_boundary_dim_dict
|
|
1077
|
-
A dict of dict of slices (see doc for `omega_boundary_dim` in
|
|
1078
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
1079
|
-
Must share the keys of `u_dict`
|
|
1080
|
-
initial_condition_fun_dict
|
|
1081
|
-
A dict of functions representing the temporal initial condition. If None
|
|
1082
|
-
(default) then no temporal boundary condition is applied
|
|
1083
|
-
Must share the keys of `u_dict`
|
|
1084
|
-
norm_key_dict
|
|
1085
|
-
A dict of Jax random keys to draw samples in for the Monte Carlo computation
|
|
1086
|
-
of the normalization constant. Default is None
|
|
1087
|
-
Must share the keys of `u_dict`
|
|
1088
|
-
norm_borders_dict
|
|
1089
|
-
A dict of tuples of (min, max) of the boundaray values of the space over which
|
|
1090
|
-
to integrate in the computation of the normalization constant.
|
|
1091
|
-
A list of tuple for higher dimensional problems. Default None.
|
|
1092
|
-
Must share the keys of `u_dict`
|
|
1093
|
-
norm_samples_dict
|
|
1094
|
-
A dict of fixed sample point in the space over which to compute the
|
|
1095
|
-
normalization constant. Default is None
|
|
1096
|
-
Must share the keys of `u_dict`
|
|
1097
|
-
sobolev_m
|
|
1098
|
-
Default is None. A dictionary of integers, one per key which must
|
|
1099
|
-
match `u_dict`.
|
|
1100
|
-
It corresponds to the Sobolev regularization order as proposed in
|
|
1101
|
-
*Convergence and error analysis of PINNs*,
|
|
1102
|
-
Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
|
|
1103
|
-
obs_slice_dict
|
|
1104
|
-
dict of obs_slice, with keys from `u_dict` to designate the
|
|
1105
|
-
output(s) channels that are forced to observed values, for each
|
|
1106
|
-
PINNs. Default is None. But if a value is given, all the entries of
|
|
1107
|
-
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
1108
|
-
if no particular slice is to be given
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
Raises
|
|
1112
|
-
------
|
|
1113
|
-
ValueError
|
|
1114
|
-
if initial condition is not a dict of tuple
|
|
1115
|
-
ValueError
|
|
1116
|
-
if the dictionaries that should share the keys of u_dict do not
|
|
1117
|
-
"""
|
|
731
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
732
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
733
|
+
u_dict: Dict[str, eqx.Module]
|
|
734
|
+
dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
|
|
735
|
+
key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
|
|
736
|
+
derivative_keys_dict: Dict[
|
|
737
|
+
str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
|
|
738
|
+
] = eqx.field(kw_only=True, default=None)
|
|
739
|
+
omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
|
|
740
|
+
eqx.field(kw_only=True, default=None, static=True)
|
|
741
|
+
)
|
|
742
|
+
omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
|
|
743
|
+
eqx.field(kw_only=True, default=None, static=True)
|
|
744
|
+
)
|
|
745
|
+
omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
|
|
746
|
+
eqx.field(kw_only=True, default=None, static=True)
|
|
747
|
+
)
|
|
748
|
+
initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
|
|
749
|
+
kw_only=True, default=None, static=True
|
|
750
|
+
)
|
|
751
|
+
norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
|
|
752
|
+
eqx.field(kw_only=True, default=None)
|
|
753
|
+
)
|
|
754
|
+
norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
|
|
755
|
+
kw_only=True, default=None
|
|
756
|
+
)
|
|
757
|
+
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
758
|
+
kw_only=True, default=None, static=True
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
# For the user loss_weights are passed as a LossWeightsPDEDict (with internal
|
|
762
|
+
# dictionary having keys in u_dict and / or dynamic_loss_dict)
|
|
763
|
+
loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
|
|
764
|
+
kw_only=True, default=None
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
# following have init=False and are set in the __post_init__
|
|
768
|
+
u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
|
|
769
|
+
init=False
|
|
770
|
+
)
|
|
771
|
+
derivative_keys_u_dict: Dict[
|
|
772
|
+
str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
|
|
773
|
+
] = eqx.field(init=False)
|
|
774
|
+
derivative_keys_dyn_loss_dict: Dict[
|
|
775
|
+
str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
|
|
776
|
+
] = eqx.field(init=False)
|
|
777
|
+
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
778
|
+
# internally the loss weights are handled with a dictionary
|
|
779
|
+
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
780
|
+
|
|
781
|
+
def __post_init__(self, loss_weights):
|
|
1118
782
|
# a dictionary that will be useful at different places
|
|
1119
|
-
self.u_dict_with_none = {k: None for k in u_dict.keys()}
|
|
783
|
+
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
1120
784
|
# First, for all the optional dict,
|
|
1121
785
|
# if the user did not provide at all this optional argument,
|
|
1122
786
|
# we make sure there is a null ponderating loss_weight and we
|
|
1123
787
|
# create a dummy dict with the required keys and all the values to
|
|
1124
788
|
# None
|
|
1125
|
-
if
|
|
789
|
+
if self.key_dict is None:
|
|
790
|
+
self.key_dict = self.u_dict_with_none
|
|
791
|
+
if self.omega_boundary_fun_dict is None:
|
|
1126
792
|
self.omega_boundary_fun_dict = self.u_dict_with_none
|
|
1127
|
-
|
|
1128
|
-
self.omega_boundary_fun_dict = omega_boundary_fun_dict
|
|
1129
|
-
if omega_boundary_condition_dict is None:
|
|
793
|
+
if self.omega_boundary_condition_dict is None:
|
|
1130
794
|
self.omega_boundary_condition_dict = self.u_dict_with_none
|
|
1131
|
-
|
|
1132
|
-
self.omega_boundary_condition_dict = omega_boundary_condition_dict
|
|
1133
|
-
if omega_boundary_dim_dict is None:
|
|
795
|
+
if self.omega_boundary_dim_dict is None:
|
|
1134
796
|
self.omega_boundary_dim_dict = self.u_dict_with_none
|
|
1135
|
-
|
|
1136
|
-
self.omega_boundary_dim_dict = omega_boundary_dim_dict
|
|
1137
|
-
if initial_condition_fun_dict is None:
|
|
797
|
+
if self.initial_condition_fun_dict is None:
|
|
1138
798
|
self.initial_condition_fun_dict = self.u_dict_with_none
|
|
1139
|
-
|
|
1140
|
-
self.initial_condition_fun_dict = initial_condition_fun_dict
|
|
1141
|
-
if norm_key_dict is None:
|
|
1142
|
-
self.norm_key_dict = self.u_dict_with_none
|
|
1143
|
-
else:
|
|
1144
|
-
self.norm_key_dict = norm_key_dict
|
|
1145
|
-
if norm_borders_dict is None:
|
|
1146
|
-
self.norm_borders_dict = self.u_dict_with_none
|
|
1147
|
-
else:
|
|
1148
|
-
self.norm_borders_dict = norm_borders_dict
|
|
1149
|
-
if norm_samples_dict is None:
|
|
799
|
+
if self.norm_samples_dict is None:
|
|
1150
800
|
self.norm_samples_dict = self.u_dict_with_none
|
|
1151
|
-
|
|
1152
|
-
self.
|
|
1153
|
-
if
|
|
1154
|
-
self.
|
|
1155
|
-
|
|
1156
|
-
self.sobolev_m_dict = sobolev_m_dict
|
|
1157
|
-
if obs_slice_dict is None:
|
|
1158
|
-
self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
|
|
1159
|
-
else:
|
|
1160
|
-
self.obs_slice_dict = obs_slice_dict
|
|
1161
|
-
if u_dict.keys() != obs_slice_dict.keys():
|
|
801
|
+
if self.norm_int_length_dict is None:
|
|
802
|
+
self.norm_int_length_dict = self.u_dict_with_none
|
|
803
|
+
if self.obs_slice_dict is None:
|
|
804
|
+
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
805
|
+
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
1162
806
|
raise ValueError("obs_slice_dict should have same keys as u_dict")
|
|
1163
|
-
if derivative_keys_dict is None:
|
|
807
|
+
if self.derivative_keys_dict is None:
|
|
1164
808
|
self.derivative_keys_dict = {
|
|
1165
809
|
k: None
|
|
1166
|
-
for k in set(
|
|
810
|
+
for k in set(
|
|
811
|
+
list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
|
|
812
|
+
)
|
|
1167
813
|
}
|
|
1168
814
|
# set() because we can have duplicate entries and in this case we
|
|
1169
815
|
# say it corresponds to the same derivative_keys_dict entry
|
|
1170
|
-
|
|
1171
|
-
|
|
816
|
+
# we need both because the constraints (all but dyn_loss) will be
|
|
817
|
+
# done by iterating on u_dict while the dyn_loss will be by
|
|
818
|
+
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
819
|
+
# derivative_keys_dict
|
|
820
|
+
|
|
1172
821
|
# but then if the user did not provide anything, we must at least have
|
|
1173
822
|
# a default value for the dynamic_loss_dict keys entries in
|
|
1174
823
|
# self.derivative_keys_dict since the computation of dynamic losses is
|
|
1175
|
-
# made without create a
|
|
824
|
+
# made without create a loss object that would provide the
|
|
1176
825
|
# default values
|
|
1177
|
-
for k in dynamic_loss_dict.keys():
|
|
826
|
+
for k in self.dynamic_loss_dict.keys():
|
|
1178
827
|
if self.derivative_keys_dict[k] is None:
|
|
1179
|
-
|
|
828
|
+
try:
|
|
829
|
+
if self.u_dict[k].eq_type == "statio_PDE":
|
|
830
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
|
|
831
|
+
else:
|
|
832
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
|
|
833
|
+
except KeyError: # We are in a key that is not in u_dict but in
|
|
834
|
+
# dynamic_loss_dict
|
|
835
|
+
if isinstance(self.dynamic_loss_dict[k], PDEStatio):
|
|
836
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
|
|
837
|
+
else:
|
|
838
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
|
|
1180
839
|
|
|
1181
840
|
# Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
|
|
1182
841
|
if (
|
|
1183
|
-
u_dict.keys() !=
|
|
1184
|
-
or u_dict.keys() != self.omega_boundary_fun_dict.keys()
|
|
1185
|
-
or u_dict.keys() != self.omega_boundary_condition_dict.keys()
|
|
1186
|
-
or u_dict.keys() != self.omega_boundary_dim_dict.keys()
|
|
1187
|
-
or u_dict.keys() != self.initial_condition_fun_dict.keys()
|
|
1188
|
-
or u_dict.keys() != self.
|
|
1189
|
-
or u_dict.keys() != self.
|
|
1190
|
-
or u_dict.keys() != self.norm_samples_dict.keys()
|
|
1191
|
-
or u_dict.keys() != self.sobolev_m_dict.keys()
|
|
842
|
+
self.u_dict.keys() != self.key_dict.keys()
|
|
843
|
+
or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
|
|
844
|
+
or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
|
|
845
|
+
or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
|
|
846
|
+
or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
|
|
847
|
+
or self.u_dict.keys() != self.norm_samples_dict.keys()
|
|
848
|
+
or self.u_dict.keys() != self.norm_int_length_dict.keys()
|
|
1192
849
|
):
|
|
1193
850
|
raise ValueError("All the dicts concerning the PINNs should have same keys")
|
|
1194
851
|
|
|
1195
|
-
self.
|
|
1196
|
-
self.u_dict = u_dict
|
|
1197
|
-
# TODO nn_type should become a class attribute now that we have PINN
|
|
1198
|
-
# class and SPINNs class
|
|
1199
|
-
self.nn_type_dict = nn_type_dict
|
|
1200
|
-
|
|
1201
|
-
self.loss_weights = loss_weights # This calls the setter
|
|
852
|
+
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
1202
853
|
|
|
1203
854
|
# Third, in order not to benefit from LossPDEStatio and
|
|
1204
855
|
# LossPDENonStatio and in order to factorize code, we create internally
|
|
@@ -1206,52 +857,51 @@ class SystemLossPDE:
|
|
|
1206
857
|
# We will not use the dynamic loss term
|
|
1207
858
|
self.u_constraints_dict = {}
|
|
1208
859
|
for i in self.u_dict.keys():
|
|
1209
|
-
if self.
|
|
860
|
+
if self.u_dict[i].eq_type == "statio_PDE":
|
|
1210
861
|
self.u_constraints_dict[i] = LossPDEStatio(
|
|
1211
|
-
u=u_dict[i],
|
|
1212
|
-
loss_weights=
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
862
|
+
u=self.u_dict[i],
|
|
863
|
+
loss_weights=LossWeightsPDENonStatio(
|
|
864
|
+
dyn_loss=0.0,
|
|
865
|
+
norm_loss=1.0,
|
|
866
|
+
boundary_loss=1.0,
|
|
867
|
+
observations=1.0,
|
|
868
|
+
initial_condition=1.0,
|
|
869
|
+
),
|
|
1219
870
|
dynamic_loss=None,
|
|
871
|
+
key=self.key_dict[i],
|
|
1220
872
|
derivative_keys=self.derivative_keys_dict[i],
|
|
1221
873
|
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
1222
874
|
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
1223
875
|
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
1224
|
-
norm_key=self.norm_key_dict[i],
|
|
1225
|
-
norm_borders=self.norm_borders_dict[i],
|
|
1226
876
|
norm_samples=self.norm_samples_dict[i],
|
|
1227
|
-
|
|
877
|
+
norm_int_length=self.norm_int_length_dict[i],
|
|
1228
878
|
obs_slice=self.obs_slice_dict[i],
|
|
1229
879
|
)
|
|
1230
|
-
elif self.
|
|
880
|
+
elif self.u_dict[i].eq_type == "nonstatio_PDE":
|
|
1231
881
|
self.u_constraints_dict[i] = LossPDENonStatio(
|
|
1232
|
-
u=u_dict[i],
|
|
1233
|
-
loss_weights=
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
},
|
|
882
|
+
u=self.u_dict[i],
|
|
883
|
+
loss_weights=LossWeightsPDENonStatio(
|
|
884
|
+
dyn_loss=0.0,
|
|
885
|
+
norm_loss=1.0,
|
|
886
|
+
boundary_loss=1.0,
|
|
887
|
+
observations=1.0,
|
|
888
|
+
initial_condition=1.0,
|
|
889
|
+
),
|
|
1241
890
|
dynamic_loss=None,
|
|
891
|
+
key=self.key_dict[i],
|
|
1242
892
|
derivative_keys=self.derivative_keys_dict[i],
|
|
1243
893
|
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
1244
894
|
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
1245
895
|
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
1246
896
|
initial_condition_fun=self.initial_condition_fun_dict[i],
|
|
1247
|
-
norm_key=self.norm_key_dict[i],
|
|
1248
|
-
norm_borders=self.norm_borders_dict[i],
|
|
1249
897
|
norm_samples=self.norm_samples_dict[i],
|
|
1250
|
-
|
|
898
|
+
norm_int_length=self.norm_int_length_dict[i],
|
|
899
|
+
obs_slice=self.obs_slice_dict[i],
|
|
1251
900
|
)
|
|
1252
901
|
else:
|
|
1253
902
|
raise ValueError(
|
|
1254
|
-
|
|
903
|
+
"Wrong value for self.u_dict[i].eq_type[i], "
|
|
904
|
+
f"got {self.u_dict[i].eq_type[i]}"
|
|
1255
905
|
)
|
|
1256
906
|
|
|
1257
907
|
# for convenience in the tree_map of evaluate,
|
|
@@ -1267,34 +917,38 @@ class SystemLossPDE:
|
|
|
1267
917
|
|
|
1268
918
|
# also make sure we only have PINNs or SPINNs
|
|
1269
919
|
if not (
|
|
1270
|
-
all(isinstance(value, PINN) for value in u_dict.values())
|
|
1271
|
-
or all(isinstance(value, SPINN) for value in u_dict.values())
|
|
920
|
+
all(isinstance(value, PINN) for value in self.u_dict.values())
|
|
921
|
+
or all(isinstance(value, SPINN) for value in self.u_dict.values())
|
|
1272
922
|
):
|
|
1273
923
|
raise ValueError(
|
|
1274
924
|
"We only accept dictionary of PINNs or dictionary of SPINNs"
|
|
1275
925
|
)
|
|
1276
926
|
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
927
|
+
def set_loss_weights(
|
|
928
|
+
self, loss_weights_init: LossWeightsPDEDict
|
|
929
|
+
) -> dict[str, dict]:
|
|
930
|
+
"""
|
|
931
|
+
This rather complex function enables the user to specify a simple
|
|
932
|
+
loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
|
|
933
|
+
for ponderating values being applied to all the equations of the
|
|
934
|
+
system... So all the transformations are handled here
|
|
935
|
+
"""
|
|
936
|
+
_loss_weights = {}
|
|
937
|
+
for k in fields(loss_weights_init):
|
|
938
|
+
v = getattr(loss_weights_init, k.name)
|
|
1285
939
|
if isinstance(v, dict):
|
|
1286
|
-
for
|
|
940
|
+
for vv in v.keys():
|
|
1287
941
|
if not isinstance(vv, (int, float)) and not (
|
|
1288
|
-
isinstance(vv,
|
|
942
|
+
isinstance(vv, Array)
|
|
1289
943
|
and ((vv.shape == (1,) or len(vv.shape) == 0))
|
|
1290
944
|
):
|
|
1291
945
|
# TODO improve that
|
|
1292
946
|
raise ValueError(
|
|
1293
947
|
f"loss values cannot be vectorial here, got {vv}"
|
|
1294
948
|
)
|
|
1295
|
-
if k == "dyn_loss":
|
|
949
|
+
if k.name == "dyn_loss":
|
|
1296
950
|
if v.keys() == self.dynamic_loss_dict.keys():
|
|
1297
|
-
|
|
951
|
+
_loss_weights[k.name] = v
|
|
1298
952
|
else:
|
|
1299
953
|
raise ValueError(
|
|
1300
954
|
"Keys in nested dictionary of loss_weights"
|
|
@@ -1302,51 +956,36 @@ class SystemLossPDE:
|
|
|
1302
956
|
)
|
|
1303
957
|
else:
|
|
1304
958
|
if v.keys() == self.u_dict.keys():
|
|
1305
|
-
|
|
959
|
+
_loss_weights[k.name] = v
|
|
1306
960
|
else:
|
|
1307
961
|
raise ValueError(
|
|
1308
962
|
"Keys in nested dictionary of loss_weights"
|
|
1309
963
|
" do not match u_dict keys"
|
|
1310
964
|
)
|
|
965
|
+
if v is None:
|
|
966
|
+
_loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
|
|
1311
967
|
else:
|
|
1312
968
|
if not isinstance(v, (int, float)) and not (
|
|
1313
|
-
isinstance(v,
|
|
1314
|
-
and ((v.shape == (1,) or len(v.shape) == 0))
|
|
969
|
+
isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
|
|
1315
970
|
):
|
|
1316
971
|
# TODO improve that
|
|
1317
972
|
raise ValueError(f"loss values cannot be vectorial here, got {v}")
|
|
1318
|
-
if k == "dyn_loss":
|
|
1319
|
-
|
|
973
|
+
if k.name == "dyn_loss":
|
|
974
|
+
_loss_weights[k.name] = {
|
|
1320
975
|
kk: v for kk in self.dynamic_loss_dict.keys()
|
|
1321
976
|
}
|
|
1322
977
|
else:
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
if all(v is None for k, v in self.sobolev_m_dict.items()):
|
|
1326
|
-
self._loss_weights["sobolev"] = {k: 0 for k in self.u_dict.keys()}
|
|
1327
|
-
if "observations" not in value.keys():
|
|
1328
|
-
self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
|
|
1329
|
-
if all(v is None for k, v in self.omega_boundary_fun_dict.items()) or all(
|
|
1330
|
-
v is None for k, v in self.omega_boundary_condition_dict.items()
|
|
1331
|
-
):
|
|
1332
|
-
self._loss_weights["boundary_loss"] = {k: 0 for k in self.u_dict.keys()}
|
|
1333
|
-
if (
|
|
1334
|
-
all(v is None for k, v in self.norm_key_dict.items())
|
|
1335
|
-
or all(v is None for k, v in self.norm_borders_dict.items())
|
|
1336
|
-
or all(v is None for k, v in self.norm_samples_dict.items())
|
|
1337
|
-
):
|
|
1338
|
-
self._loss_weights["norm_loss"] = {k: 0 for k in self.u_dict.keys()}
|
|
1339
|
-
if all(v is None for k, v in self.initial_condition_fun_dict.items()):
|
|
1340
|
-
self._loss_weights["initial_condition"] = {k: 0 for k in self.u_dict.keys()}
|
|
978
|
+
_loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
|
|
979
|
+
return _loss_weights
|
|
1341
980
|
|
|
1342
981
|
def __call__(self, *args, **kwargs):
|
|
1343
982
|
return self.evaluate(*args, **kwargs)
|
|
1344
983
|
|
|
1345
984
|
def evaluate(
|
|
1346
985
|
self,
|
|
1347
|
-
params_dict,
|
|
1348
|
-
batch,
|
|
1349
|
-
):
|
|
986
|
+
params_dict: ParamsDict,
|
|
987
|
+
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
988
|
+
) -> tuple[Float[Array, "1"], dict[str, float]]:
|
|
1350
989
|
"""
|
|
1351
990
|
Evaluate the loss function at a batch of points for given parameters.
|
|
1352
991
|
|
|
@@ -1354,12 +993,8 @@ class SystemLossPDE:
|
|
|
1354
993
|
Parameters
|
|
1355
994
|
---------
|
|
1356
995
|
params_dict
|
|
1357
|
-
|
|
1358
|
-
Typically, it is a dictionary of dictionaries of
|
|
1359
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
1360
|
-
differential equation parameters and the neural network parameter
|
|
996
|
+
Parameters at which the losses of the system are evaluated
|
|
1361
997
|
batch
|
|
1362
|
-
A PDEStatioBatch or PDENonStatioBatch object.
|
|
1363
998
|
Such named tuples are composed of batch of points in the
|
|
1364
999
|
domain, a batch of points in the domain
|
|
1365
1000
|
border, (a batch of time points a for PDENonStatioBatch) and an
|
|
@@ -1367,32 +1002,17 @@ class SystemLossPDE:
|
|
|
1367
1002
|
and an optional additional batch of observed
|
|
1368
1003
|
inputs/outputs/parameters
|
|
1369
1004
|
"""
|
|
1370
|
-
if self.u_dict.keys() != params_dict
|
|
1005
|
+
if self.u_dict.keys() != params_dict.nn_params.keys():
|
|
1371
1006
|
raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
|
|
1372
1007
|
|
|
1373
1008
|
if isinstance(batch, PDEStatioBatch):
|
|
1374
1009
|
omega_batch, _ = batch.inside_batch, batch.border_batch
|
|
1375
|
-
n = omega_batch.shape[0]
|
|
1376
1010
|
vmap_in_axes_x_or_x_t = (0,)
|
|
1377
1011
|
|
|
1378
1012
|
batches = (omega_batch,)
|
|
1379
1013
|
elif isinstance(batch, PDENonStatioBatch):
|
|
1380
|
-
|
|
1381
|
-
|
|
1382
|
-
batch.border_batch,
|
|
1383
|
-
batch.temporal_batch,
|
|
1384
|
-
)
|
|
1385
|
-
n = omega_batch.shape[0]
|
|
1386
|
-
nt = times_batch.shape[0]
|
|
1387
|
-
times_batch = times_batch.reshape(nt, 1)
|
|
1388
|
-
|
|
1389
|
-
def rep_times(k):
|
|
1390
|
-
return jnp.repeat(times_batch, k, axis=0)
|
|
1391
|
-
|
|
1392
|
-
# Moreover...
|
|
1393
|
-
if isinstance(list(self.u_dict.values())[0], PINN):
|
|
1394
|
-
omega_batch = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
|
|
1395
|
-
times_batch = rep_times(n) # it is repeated
|
|
1014
|
+
times_batch = batch.times_x_inside_batch[:, 0:1]
|
|
1015
|
+
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
1396
1016
|
|
|
1397
1017
|
batches = (omega_batch, times_batch)
|
|
1398
1018
|
vmap_in_axes_x_or_x_t = (0, 0)
|
|
@@ -1405,9 +1025,10 @@ class SystemLossPDE:
|
|
|
1405
1025
|
if batch.param_batch_dict is not None:
|
|
1406
1026
|
eq_params_batch_dict = batch.param_batch_dict
|
|
1407
1027
|
|
|
1028
|
+
# TODO
|
|
1408
1029
|
# feed the eq_params with the batch
|
|
1409
1030
|
for k in eq_params_batch_dict.keys():
|
|
1410
|
-
params_dict
|
|
1031
|
+
params_dict.eq_params[k] = eq_params_batch_dict[k]
|
|
1411
1032
|
|
|
1412
1033
|
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
1413
1034
|
batch.param_batch_dict, params_dict
|
|
@@ -1415,12 +1036,11 @@ class SystemLossPDE:
|
|
|
1415
1036
|
|
|
1416
1037
|
def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
|
|
1417
1038
|
"""The function used in tree_map"""
|
|
1418
|
-
params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
|
|
1419
1039
|
return dynamic_loss_apply(
|
|
1420
1040
|
dyn_loss.evaluate,
|
|
1421
1041
|
self.u_dict,
|
|
1422
1042
|
batches,
|
|
1423
|
-
|
|
1043
|
+
_set_derivatives(params_dict, derivative_key.dyn_loss),
|
|
1424
1044
|
vmap_in_axes_x_or_x_t + vmap_in_axes_params,
|
|
1425
1045
|
loss_weight,
|
|
1426
1046
|
u_type=type(list(self.u_dict.values())[0]),
|
|
@@ -1431,6 +1051,13 @@ class SystemLossPDE:
|
|
|
1431
1051
|
self.dynamic_loss_dict,
|
|
1432
1052
|
self.derivative_keys_dyn_loss_dict,
|
|
1433
1053
|
self._loss_weights["dyn_loss"],
|
|
1054
|
+
is_leaf=lambda x: isinstance(
|
|
1055
|
+
x, (PDEStatio, PDENonStatio)
|
|
1056
|
+
), # before when dynamic losses
|
|
1057
|
+
# where plain (unregister pytree) node classes, we could not traverse
|
|
1058
|
+
# this level. Now that dynamic losses are eqx.Module they can be
|
|
1059
|
+
# traversed by tree map recursion. Hence we need to specify to that
|
|
1060
|
+
# we want to stop at this level
|
|
1434
1061
|
)
|
|
1435
1062
|
mse_dyn_loss = jax.tree_util.tree_reduce(
|
|
1436
1063
|
lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
|
|
@@ -1445,11 +1072,10 @@ class SystemLossPDE:
|
|
|
1445
1072
|
"boundary_loss": "*",
|
|
1446
1073
|
"observations": "*",
|
|
1447
1074
|
"initial_condition": "*",
|
|
1448
|
-
"sobolev": "*",
|
|
1449
1075
|
}
|
|
1450
1076
|
# we need to do the following for the tree_mapping to work
|
|
1451
1077
|
if batch.obs_batch_dict is None:
|
|
1452
|
-
batch = batch
|
|
1078
|
+
batch = append_obs_batch(batch, self.u_dict_with_none)
|
|
1453
1079
|
total_loss, res_dict = constraints_system_loss_apply(
|
|
1454
1080
|
self.u_constraints_dict,
|
|
1455
1081
|
batch,
|
|
@@ -1462,41 +1088,3 @@ class SystemLossPDE:
|
|
|
1462
1088
|
total_loss += mse_dyn_loss
|
|
1463
1089
|
res_dict["dyn_loss"] += mse_dyn_loss
|
|
1464
1090
|
return total_loss, res_dict
|
|
1465
|
-
|
|
1466
|
-
def tree_flatten(self):
|
|
1467
|
-
children = (
|
|
1468
|
-
self.norm_key_dict,
|
|
1469
|
-
self.norm_samples_dict,
|
|
1470
|
-
self.initial_condition_fun_dict,
|
|
1471
|
-
self._loss_weights,
|
|
1472
|
-
)
|
|
1473
|
-
aux_data = {
|
|
1474
|
-
"u_dict": self.u_dict,
|
|
1475
|
-
"dynamic_loss_dict": self.dynamic_loss_dict,
|
|
1476
|
-
"norm_borders_dict": self.norm_borders_dict,
|
|
1477
|
-
"omega_boundary_fun_dict": self.omega_boundary_fun_dict,
|
|
1478
|
-
"omega_boundary_condition_dict": self.omega_boundary_condition_dict,
|
|
1479
|
-
"nn_type_dict": self.nn_type_dict,
|
|
1480
|
-
"sobolev_m_dict": self.sobolev_m_dict,
|
|
1481
|
-
"derivative_keys_dict": self.derivative_keys_dict,
|
|
1482
|
-
"obs_slice_dict": self.obs_slice_dict,
|
|
1483
|
-
}
|
|
1484
|
-
return (children, aux_data)
|
|
1485
|
-
|
|
1486
|
-
@classmethod
|
|
1487
|
-
def tree_unflatten(cls, aux_data, children):
|
|
1488
|
-
(
|
|
1489
|
-
norm_key_dict,
|
|
1490
|
-
norm_samples_dict,
|
|
1491
|
-
initial_condition_fun_dict,
|
|
1492
|
-
loss_weights,
|
|
1493
|
-
) = children
|
|
1494
|
-
loss_ode = cls(
|
|
1495
|
-
loss_weights=loss_weights,
|
|
1496
|
-
norm_key_dict=norm_key_dict,
|
|
1497
|
-
norm_samples_dict=norm_samples_dict,
|
|
1498
|
-
initial_condition_fun_dict=initial_condition_fun_dict,
|
|
1499
|
-
**aux_data,
|
|
1500
|
-
)
|
|
1501
|
-
|
|
1502
|
-
return loss_ode
|