jinns 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +292 -309
- jinns/loss/_LossPDE.py +625 -1010
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +95 -44
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/METADATA +4 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
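
The headline change in this release is visible in the `jinns/loss` and `jinns/parameters` entries above: dict-based loss weights and derivative keys are replaced by typed containers (`_loss_weights.py`, `_derivative_keys.py`, `_params.py`) and the loss classes become equinox Modules, while Sobolev regularization, the seq2seq solver and the experimental sinus/spectral PINNs are removed. A minimal migration sketch, assuming the 1.0.0 constructor signatures inferred from the `_LossPDE.py` diff below (the `u` and `dyn_loss` placeholders and the exact weight field names are assumptions, not verified API):

    from jinns.loss._LossPDE import LossPDEStatio
    from jinns.loss._loss_weights import LossWeightsPDEStatio
    from jinns.parameters._derivative_keys import DerivativeKeysPDEStatio

    u = ...         # a PINN (an eqx.Module in 1.0.0), hypothetical placeholder
    dyn_loss = ...  # a DynamicLoss with an evaluate(x, u, params) method

    # 0.9.0 (removed): weights passed as a plain dict with string keys
    # loss = LossPDEStatio(u, {"dyn_loss": 1.0, "boundary_loss": 10.0}, dyn_loss)

    # 1.0.0: kw-only eqx.field attributes; every weight defaults to 1.0
    loss = LossPDEStatio(
        u,
        dyn_loss,
        loss_weights=LossWeightsPDEStatio(boundary_loss=10.0),
        derivative_keys=DerivativeKeysPDEStatio(),  # "nn_params" for every term by default
    )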
jinns/loss/_LossPDE.py
CHANGED
@@ -1,29 +1,52 @@
+# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a PDE loss in jinns
 """
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant

+import abc
+from dataclasses import InitVar, fields
+from typing import TYPE_CHECKING, Dict, Callable
 import warnings
 import jax
 import jax.numpy as jnp
-
-from
+import equinox as eqx
+from jaxtyping import Float, Array, Key, Int
+from jinns.loss._loss_utils import (
     dynamic_loss_apply,
     boundary_condition_apply,
     normalization_loss_apply,
     observations_loss_apply,
-    sobolev_reg_apply,
     initial_condition_apply,
     constraints_system_loss_apply,
 )
-from jinns.data._DataGenerators import
-
+from jinns.data._DataGenerators import (
+    append_obs_batch,
+)
+from jinns.parameters._params import (
     _get_vmap_in_axes_params,
-    _set_derivatives,
     _update_eq_params_dict,
 )
+from jinns.parameters._derivative_keys import (
+    _set_derivatives,
+    DerivativeKeysPDEStatio,
+    DerivativeKeysPDENonStatio,
+)
+from jinns.loss._loss_weights import (
+    LossWeightsPDEStatio,
+    LossWeightsPDENonStatio,
+    LossWeightsPDEDict,
+)
+from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
 from jinns.utils._pinn import PINN
 from jinns.utils._spinn import SPINN
-from jinns.
+from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+
+
+if TYPE_CHECKING:
+    from jinns.utils._types import *

 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "dirichlet",
@@ -31,375 +54,156 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "vonneumann",
 ]

-_LOSS_WEIGHT_KEYS_PDESTATIO = [
-    "sobolev",
-    "observations",
-    "norm_loss",
-    "boundary_loss",
-    "dyn_loss",
-]
-
-_LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
-

-
-class LossPDEAbstract:
+class _LossPDEAbstract(eqx.Module):
     """
-
-
-
-
-
-
-
+    Parameters
+    ----------
+
+    loss_weights : LossWeightsPDEStatio | LossWeightsPDENonStatio, default=None
+        The loss weights for the different terms: dynamic loss,
+        initial condition (if LossWeightsPDENonStatio), boundary conditions if
+        any, normalization loss if any and observations if any.
+        All fields are set to 1.0 by default.
+    derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
+        Specify which field of `params` should be differentiated for each
+        component of the total loss. Particularly useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have a `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each component of the loss.
+    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+        The function to be matched in the border condition (can be None) or a
+        dictionary of such functions as values and keys as described
+        in `omega_boundary_condition`.
+    omega_boundary_condition : str | Dict[str, str], default=None
+        Either None (no condition, by default), or a string defining
+        the boundary condition (Dirichlet or Von Neumann),
+        or a dictionary with such strings as values. In this case,
+        the keys are the facets and must be in the following order:
+        1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax", "ymin", "ymax"].
+        Note that high order boundaries are currently not implemented.
+        A value in the dict can be None, this means we do not enforce
+        a particular boundary condition on this facet.
+        The facet called "xmin", resp. "xmax" etc., in 2D,
+        refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
+    omega_boundary_dim : slice | Dict[str, slice], default=None
+        Either None, or a slice object or a dictionary of slice objects as
+        values and keys as described in `omega_boundary_condition`.
+        `omega_boundary_dim` indicates which dimension(s) of the PINN
+        will be forced to match the boundary condition.
+        Note that it must be a slice and not an integer
+        (but a preprocessing of the user provided argument takes care of it)
+    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+        Fixed sample points in the space over which to compute the
+        normalization constant. Default is None.
+    norm_int_length : float, default=None
+        A float. Must be provided if `norm_samples` is provided. The domain area
+        (or interval length in 1D) over which we perform the numerical
+        integration. Default None
+    obs_slice : slice, default=None
+        slice object specifying the beginning/end of the PINN output
+        that is observed (this is then useful for multidim PINN). Default is None.
     """

-
-
-
-
-
-
-
-
-)
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
+    derivative_keys: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None = (
+        eqx.field(kw_only=True, default=None)
+    )
+    loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
+        kw_only=True, default=None
+    )
+    omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
+        kw_only=True, default=None
+    )
+    norm_int_length: float | None = eqx.field(kw_only=True, default=None)
+    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
+
+    def __post_init__(self):
         """
-
-
-        u
-            the PINN object
-        loss_weights
-            a dictionary with values used to ponderate each term in the loss
-            function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`
-            and `observations`
-            Note that we can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        derivative_keys
-            A dict of lists of strings. In the dict, the key must correspond to
-            the loss term keywords. Then each of the values must correspond to keys in the parameter
-            dictionary (*at top level only of the parameter dictionary*).
-            It enables selecting the set of parameters
-            with respect to which the gradients of the dynamic
-            loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
-            keywords, this is what is typically
-            done in solving forward problems, when we only estimate the
-            equation solution with a PINN. If some loss terms keywords are
-            missing we set their value to ["nn_params"] by default for the
-            same reason
-        norm_key
-            Jax random key to draw samples in for the Monte Carlo computation
-            of the normalization constant. Default is None
-        norm_borders
-            tuple of (min, max) of the boundary values of the space over which
-            to integrate in the computation of the normalization constant.
-            A list of tuples for higher dimensional problems. Default None.
-        norm_samples
-            Fixed sample points in the space over which to compute the
-            normalization constant. Default is None
-
-        Raises
-        ------
-        RuntimeError
-            When provided an invalid combination of `norm_key`, `norm_borders`
-            and `norm_samples`. See note below.
-
-        **Note:** If `norm_key` and `norm_borders` and `norm_samples` are `None`
-        then no normalization loss is enforced.
-        If `norm_borders` and `norm_samples` are given while
-        `norm_samples` is `None` then samples are drawn at each loss evaluation.
-        Otherwise, if `norm_samples` is given, those samples are used.
+        Note that neither __init__ nor __post_init__ are called when updating a
+        Module with eqx.tree_at
         """
-
-        self.u = u
-        if derivative_keys is None:
+        if self.derivative_keys is None:
             # by default we only take gradient wrt nn_params
-            derivative_keys =
-
-
-
-                "boundary_loss",
-                "norm_loss",
-                "initial_condition",
-                "observations",
-                "sobolev",
-            ]
-        }
-        if isinstance(derivative_keys, list):
-            # if the user only provided a list, this defines the gradient taken
-            # for all the loss entries
-            derivative_keys = {
-                k: derivative_keys
-                for k in [
-                    "dyn_loss",
-                    "boundary_loss",
-                    "norm_loss",
-                    "initial_condition",
-                    "observations",
-                    "sobolev",
-                ]
-            }
-
-        self.derivative_keys = derivative_keys
-        self.loss_weights = loss_weights
-        self.norm_borders = norm_borders
-        self.norm_key = norm_key
-        self.norm_samples = norm_samples
-
-        if norm_key is None and norm_borders is None and norm_samples is None:
-            # if there is None of the 3 above, that means we don't consider
-            # normalization loss
-            self.normalization_loss = None
-        elif (
-            norm_key is not None and norm_borders is not None and norm_samples is None
-        ):  # this ordering so that by default priority is given to mc_samples
-            self.norm_sample_method = "generate"
-            if not isinstance(self.norm_borders[0], tuple):
-                self.norm_borders = (self.norm_borders,)
-            self.norm_xmin, self.norm_xmax = [], []
-            for i, _ in enumerate(self.norm_borders):
-                self.norm_xmin.append(self.norm_borders[i][0])
-                self.norm_xmax.append(self.norm_borders[i][1])
-            self.int_length = jnp.prod(
-                jnp.array(
-                    [
-                        self.norm_xmax[i] - self.norm_xmin[i]
-                        for i in range(len(self.norm_borders))
-                    ]
-                )
-            )
-            self.normalization_loss = True
-        elif norm_samples is None:
-            raise RuntimeError(
-                "norm_borders should always be provided, then either"
-                " norm_samples (fixed norm_samples) or norm_key (random norm_samples)"
-                " is required."
-            )
-        else:
-            # ok, we are sure we have norm_samples given by the user
-            self.norm_sample_method = "user"
-            if not isinstance(self.norm_borders[0], tuple):
-                self.norm_borders = (self.norm_borders,)
-            self.norm_xmin, self.norm_xmax = [], []
-            for i, _ in enumerate(self.norm_borders):
-                self.norm_xmin.append(self.norm_borders[i][0])
-                self.norm_xmax.append(self.norm_borders[i][1])
-            self.int_length = jnp.prod(
-                jnp.array(
-                    [
-                        self.norm_xmax[i] - self.norm_xmin[i]
-                        for i in range(len(self.norm_borders))
-                    ]
-                )
+            self.derivative_keys = (
+                DerivativeKeysPDENonStatio()
+                if isinstance(self, LossPDENonStatio)
+                else DerivativeKeysPDEStatio()
             )
-            self.normalization_loss = True
-
-    def get_norm_samples(self):
-        """
-        Returns a batch of points in the domain for integration when the
-        normalization constraint is enforced. The batch of points is either
-        fixed (provided by the user) or regenerated at each iteration.
-        """
-        if self.norm_sample_method == "user":
-            return self.norm_samples
-        if self.norm_sample_method == "generate":
-            ## NOTE TODO CHECK the performances of this for loop
-            norm_samples = []
-            for d in range(len(self.norm_borders)):
-                self.norm_key, subkey = jax.random.split(self.norm_key)
-                norm_samples.append(
-                    jax.random.uniform(
-                        subkey,
-                        shape=(1000, 1),
-                        minval=self.norm_xmin[d],
-                        maxval=self.norm_xmax[d],
-                    )
-                )
-            self.norm_samples = jnp.concatenate(norm_samples, axis=-1)
-            return self.norm_samples
-        raise RuntimeError("Problem with the value of self.norm_sample_method")
-
-    def tree_flatten(self):
-        children = (self.norm_key, self.norm_samples, self.loss_weights)
-        aux_data = {
-            "norm_borders": self.norm_borders,
-            "derivative_keys": self.derivative_keys,
-            "u": self.u,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(self, aux_data, children):
-        (norm_key, norm_samples, loss_weights) = children
-        pls = self(
-            aux_data["u"],
-            loss_weights,
-            aux_data["derivative_keys"],
-            norm_key,
-            aux_data["norm_borders"],
-            norm_samples,
-        )
-        return pls
-
-
-@register_pytree_node_class
-class LossPDEStatio(LossPDEAbstract):
-    r"""Loss object for a stationary partial differential equation
-
-    .. math::
-        \mathcal{N}[u](x) = 0, \forall x \in \Omega
-
-    where :math:`\mathcal{N}[\cdot]` is a differential operator and the
-    boundary condition is :math:`u(x)=u_b(x)`. The additional condition of
-    integrating to 1 can be included, i.e. :math:`\int u(x)\mathrm{d}x=1`.
-

-
-
-
+        if self.loss_weights is None:
+            self.loss_weights = (
+                LossWeightsPDENonStatio()
+                if isinstance(self, LossPDENonStatio)
+                else LossWeightsPDEStatio()
+            )

-
-
-        u,
-        loss_weights,
-        dynamic_loss,
-        derivative_keys=None,
-        omega_boundary_fun=None,
-        omega_boundary_condition=None,
-        omega_boundary_dim=None,
-        norm_key=None,
-        norm_borders=None,
-        norm_samples=None,
-        sobolev_m=None,
-        obs_slice=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        u
-            the PINN object
-        loss_weights
-            a dictionary with values used to ponderate each term in the loss
-            function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
-            `observations` and `sobolev`.
-            Note that we can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        dynamic_loss
-            the stationary PDE dynamic part of the loss, basically the differential
-            operator :math:`\mathcal{N}[u](t)`. Should implement a method
-            `dynamic_loss.evaluate(t, u, params)`.
-            Can be None in order to access only some part of the evaluate call
-            results.
-        derivative_keys
-            A dict of lists of strings. In the dict, the key must correspond to
-            the loss term keywords. Then each of the values must correspond to keys in the parameter
-            dictionary (*at top level only of the parameter dictionary*).
-            It enables selecting the set of parameters
-            with respect to which the gradients of the dynamic
-            loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
-            keywords, this is what is typically
-            done in solving forward problems, when we only estimate the
-            equation solution with a PINN. If some loss terms keywords are
-            missing we set their value to ["nn_params"] by default for the same
-            reason
-        omega_boundary_fun
-            The function to be matched in the border condition (can be None)
-            or a dictionary of such functions. In this case, the keys are the
-            facets and the values are the functions. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_condition
-            Either None (no condition), or a string defining the boundary
-            condition e.g. Dirichlet or Von Neumann, or a dictionary of such
-            strings. In this case, the keys are the
-            facets and the values are the strings. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_dim
-            Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
-            the logic of omega_boundary_fun. It indicates which dimension(s) of
-            the PINN will be forced to match the boundary condition
-            Note that it must be a slice and not an integer (a preprocessing of the
-            user provided argument takes care of it)
-        norm_key
-            Jax random key to draw samples in for the Monte Carlo computation
-            of the normalization constant. Default is None
-        norm_borders
-            tuple of (min, max) of the boundary values of the space over which
-            to integrate in the computation of the normalization constant.
-            A list of tuples for higher dimensional problems. Default None.
-        norm_samples
-            Fixed sample points in the space over which to compute the
-            normalization constant. Default is None
-        sobolev_m
-            An integer. Default is None.
-            It corresponds to the Sobolev regularization order as proposed in
-            *Convergence and error analysis of PINNs*,
-            Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-        obs_slice
-            slice object specifying the beginning/end
-            slice of u output(s) that is observed (this is then useful for
-            multidim PINN). Default is None.
-
-
-        Raises
-        ------
-        ValueError
-            If conditions on omega_boundary_condition and omega_boundary_fun
-            are not respected
-        """
+        if self.obs_slice is None:
+            self.obs_slice = jnp.s_[...]

-
-
-
+        if (
+            isinstance(self.omega_boundary_fun, dict)
+            and not isinstance(self.omega_boundary_condition, dict)
+        ) or (
+            not isinstance(self.omega_boundary_fun, dict)
+            and isinstance(self.omega_boundary_condition, dict)
+        ):
+            raise ValueError(
+                "if one of self.omega_boundary_fun or "
+                "self.omega_boundary_condition is dict, the other should be too."
+            )

-        if omega_boundary_condition is None or omega_boundary_fun is None:
+        if self.omega_boundary_condition is None or self.omega_boundary_fun is None:
             warnings.warn(
                 "Missing boundary function or no boundary condition."
                 "Boundary function is thus ignored."
             )
         else:
-            if isinstance(omega_boundary_condition, dict):
-                for _, v in omega_boundary_condition.items():
+            if isinstance(self.omega_boundary_condition, dict):
+                for _, v in self.omega_boundary_condition.items():
                     if v is not None and not any(
                         v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                     ):
                         raise NotImplementedError(
-                            f"The boundary condition {omega_boundary_condition} is not"
+                            f"The boundary condition {self.omega_boundary_condition} is not"
                             f"implemented yet. Try one of :"
                             f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                         )
             else:
                 if not any(
-                    omega_boundary_condition.lower() in s
+                    self.omega_boundary_condition.lower() in s
                     for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                 ):
                     raise NotImplementedError(
-                        f"The boundary condition {omega_boundary_condition} is not"
+                        f"The boundary condition {self.omega_boundary_condition} is not"
                         f"implemented yet. Try one of :"
                         f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                     )
-            if isinstance(omega_boundary_fun, dict) and isinstance(
-                omega_boundary_condition, dict
+            if isinstance(self.omega_boundary_fun, dict) and isinstance(
+                self.omega_boundary_condition, dict
             ):
                 if (
                     not (
-                        list(omega_boundary_fun.keys()) == ["xmin", "xmax"]
-                        and list(omega_boundary_condition.keys())
+                        list(self.omega_boundary_fun.keys()) == ["xmin", "xmax"]
+                        and list(self.omega_boundary_condition.keys())
                         == ["xmin", "xmax"]
                     )
                 ) or (
                     not (
-                        list(omega_boundary_fun.keys())
+                        list(self.omega_boundary_fun.keys())
                         == ["xmin", "xmax", "ymin", "ymax"]
-                        and list(omega_boundary_condition.keys())
+                        and list(self.omega_boundary_condition.keys())
                         == ["xmin", "xmax", "ymin", "ymax"]
                     )
                 ):
@@ -408,10 +212,6 @@ class LossPDEStatio(LossPDEAbstract):
                         "boundary condition dictionaries is incorrect"
                     )

-        self.omega_boundary_fun = omega_boundary_fun
-        self.omega_boundary_condition = omega_boundary_condition
-
-        self.omega_boundary_dim = omega_boundary_dim
         if isinstance(self.omega_boundary_fun, dict):
             if self.omega_boundary_dim is None:
                 self.omega_boundary_dim = {
@@ -440,44 +240,139 @@ class LossPDEStatio(LossPDEAbstract):
                    self.omega_boundary_dim : self.omega_boundary_dim + 1
                ]
            if not isinstance(self.omega_boundary_dim, slice):
-                raise ValueError("self.omega_boundary_dim must be a jnp.s_
+                raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")

-        self.
+        if self.norm_samples is not None and self.norm_int_length is None:
+            raise ValueError("self.norm_samples and norm_int_length must be provided")

-
-
-
-
-
-
-
-        else:
-            self.sobolev_reg = None
+    @abc.abstractmethod
+    def evaluate(
+        self: eqx.Module,
+        params: Params,
+        batch: PDEStatioBatch | PDENonStatioBatch,
+    ) -> tuple[Float, dict]:
+        raise NotImplementedError

-        for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
-            if k not in self.loss_weights.keys():
-                self.loss_weights[k] = 0

-
-
-            and not isinstance(self.omega_boundary_condition, dict)
-        ) or (
-            not isinstance(self.omega_boundary_fun, dict)
-            and isinstance(self.omega_boundary_condition, dict)
-        ):
-            raise ValueError(
-                "if one of self.omega_boundary_fun or "
-                "self.omega_boundary_condition is dict, the other should be too."
-            )
+class LossPDEStatio(_LossPDEAbstract):
+    r"""Loss object for a stationary partial differential equation

-
-
-
+    $$
+        \mathcal{N}[u](x) = 0, \forall x \in \Omega
+    $$
+
+    where $\mathcal{N}[\cdot]$ is a differential operator and the
+    boundary condition is $u(x)=u_b(x)$. The additional condition of
+    integrating to 1 can be included, i.e. $\int u(x)\mathrm{d}x=1$.
+
+    Parameters
+    ----------
+    u : eqx.Module
+        the PINN
+    dynamic_loss : DynamicLoss
+        the stationary PDE dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](x)$. Should implement a method
+        `dynamic_loss.evaluate(x, u, params)`.
+        Can be None in order to access only some part of the evaluate call
+        results.
+    key : Key
+        A JAX PRNG Key for the loss class treated as an attribute. Default is
+        None. This field is provided for future developments and additional
+        losses that might need some randomness. Note that special care must be
+        taken when splitting the key because in-place updates are forbidden in
+        eqx.Modules.
+    loss_weights : LossWeightsPDEStatio, default=None
+        The loss weights for the different terms: dynamic loss,
+        boundary conditions if any, normalization loss if any and
+        observations if any.
+        All fields are set to 1.0 by default.
+    derivative_keys : DerivativeKeysPDEStatio, default=None
+        Specify which field of `params` should be differentiated for each
+        component of the total loss. Particularly useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have a `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each component of the loss.
+    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+        The function to be matched in the border condition (can be None) or a
+        dictionary of such functions as values and keys as described
+        in `omega_boundary_condition`.
+    omega_boundary_condition : str | Dict[str, str], default=None
+        Either None (no condition, by default), or a string defining
+        the boundary condition (Dirichlet or Von Neumann),
+        or a dictionary with such strings as values. In this case,
+        the keys are the facets and must be in the following order:
+        1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax", "ymin", "ymax"].
+        Note that high order boundaries are currently not implemented.
+        A value in the dict can be None, this means we do not enforce
+        a particular boundary condition on this facet.
+        The facet called "xmin", resp. "xmax" etc., in 2D,
+        refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
+    omega_boundary_dim : slice | Dict[str, slice], default=None
+        Either None, or a slice object or a dictionary of slice objects as
+        values and keys as described in `omega_boundary_condition`.
+        `omega_boundary_dim` indicates which dimension(s) of the PINN
+        will be forced to match the boundary condition.
+        Note that it must be a slice and not an integer
+        (but a preprocessing of the user provided argument takes care of it)
+    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+        Fixed sample points in the space over which to compute the
+        normalization constant. Default is None.
+    norm_int_length : float, default=None
+        A float. Must be provided if `norm_samples` is provided. The domain area
+        (or interval length in 1D) over which we perform the numerical
+        integration. Default None
+    obs_slice : slice, default=None
+        slice object specifying the beginning/end of the PINN output
+        that is observed (this is then useful for multidim PINN). Default is None.
+
+
+    Raises
+    ------
+    ValueError
+        If conditions on omega_boundary_condition and omega_boundary_fun
+        are not respected
+    """
+
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+
+    u: eqx.Module
+    dynamic_loss: DynamicLoss | None
+    key: Key | None = eqx.field(kw_only=True, default=None)
+
+    vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
+
+    def __post_init__(self):
+        """
+        Note that neither __init__ nor __post_init__ are called when updating a
+        Module with eqx.tree_at!
+        """
+        super().__post_init__()  # because __init__ or __post_init__ of Base
+        # class is not automatically called
+
+        self.vmap_in_axes = (0,)  # for x only here
+
+    def _get_dynamic_loss_batch(
+        self, batch: PDEStatioBatch
+    ) -> tuple[Float[Array, "batch_size dimension"]]:
+        return (batch.inside_batch,)
+
+    def _get_normalization_loss_batch(
+        self, _
+    ) -> Float[Array, "nb_norm_samples dimension"]:
+        return (self.norm_samples,)
+
+    def _get_observations_loss_batch(
+        self, batch: PDEStatioBatch
+    ) -> Float[Array, "batch_size obs_dim"]:
+        return (batch.obs_batch_dict["pinn_in"],)

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)

-    def evaluate(
+    def evaluate(
+        self, params: Params, batch: PDEStatioBatch
+    ) -> tuple[Float[Array, "1"], dict[str, float]]:
         """
         Evaluate the loss function at a batch of points for given parameters.

@@ -485,22 +380,14 @@ class LossPDEStatio(LossPDEAbstract):
         Parameters
         ----------
         params
-
-            Typically, it is a dictionary of
-            dictionaries: `eq_params` and `nn_params`, respectively the
-            differential equation parameters and the neural network parameters
+            Parameters at which the loss is evaluated
         batch
-
-            Such a named tuple is composed of a batch of points in the
+            Composed of a batch of points in the
             domain, a batch of points in the domain
             border and an optional additional batch of parameters (eg. for
             metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
         """
-        omega_batch, _ = batch.inside_batch, batch.border_batch
-
-        vmap_in_axes_x = (0,)
-
         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
@@ -511,44 +398,41 @@ class LossPDEStatio(LossPDEAbstract):
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)

         # dynamic part
-        params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
         if self.dynamic_loss is not None:
             mse_dyn_loss = dynamic_loss_apply(
                 self.dynamic_loss.evaluate,
                 self.u,
-                (
-
-
-                self.loss_weights
+                self._get_dynamic_loss_batch(batch),
+                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                self.vmap_in_axes + vmap_in_axes_params,
+                self.loss_weights.dyn_loss,
             )
         else:
             mse_dyn_loss = jnp.array(0.0)

         # normalization part
-
-        if self.normalization_loss is not None:
+        if self.norm_samples is not None:
             mse_norm_loss = normalization_loss_apply(
                 self.u,
-
-
-
-                self.
-                self.loss_weights
+                self._get_normalization_loss_batch(batch),
+                _set_derivatives(params, self.derivative_keys.norm_loss),
+                self.vmap_in_axes + vmap_in_axes_params,
+                self.norm_int_length,
+                self.loss_weights.norm_loss,
             )
         else:
             mse_norm_loss = jnp.array(0.0)

         # boundary part
-        params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
         if self.omega_boundary_condition is not None:
             mse_boundary_loss = boundary_condition_apply(
                 self.u,
                 batch,
-
+                _set_derivatives(params, self.derivative_keys.boundary_loss),
                 self.omega_boundary_fun,
                 self.omega_boundary_condition,
                 self.omega_boundary_dim,
-                self.loss_weights
+                self.loss_weights.boundary_loss,
             )
         else:
             mse_boundary_loss = jnp.array(0.0)
@@ -558,40 +442,21 @@ class LossPDEStatio(LossPDEAbstract):
             # update params with the batches of observed params
             params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])

-            params_ = _set_derivatives(params, "observations", self.derivative_keys)
             mse_observation_loss = observations_loss_apply(
                 self.u,
-                (batch
-
-
+                self._get_observations_loss_batch(batch),
+                _set_derivatives(params, self.derivative_keys.observations),
+                self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights
+                self.loss_weights.observations,
                 self.obs_slice,
             )
         else:
             mse_observation_loss = jnp.array(0.0)

-        # Sobolev regularization
-        params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
-        if self.sobolev_reg is not None:
-            mse_sobolev_loss = sobolev_reg_apply(
-                self.u,
-                (omega_batch,),
-                params_,
-                vmap_in_axes_x + vmap_in_axes_params,
-                self.sobolev_reg,
-                self.loss_weights["sobolev"],
-            )
-        else:
-            mse_sobolev_loss = jnp.array(0.0)
-
         # total loss
         total_loss = (
-            mse_dyn_loss
-            + mse_norm_loss
-            + mse_boundary_loss
-            + mse_observation_loss
-            + mse_sobolev_loss
+            mse_dyn_loss + mse_norm_loss + mse_boundary_loss + mse_observation_loss
         )
         return total_loss, (
             {
@@ -599,205 +464,143 @@ class LossPDEStatio(LossPDEAbstract):
                 "norm_loss": mse_norm_loss,
                 "boundary_loss": mse_boundary_loss,
                 "observations": mse_observation_loss,
-                "sobolev": mse_sobolev_loss,
                 "initial_condition": jnp.array(0.0),  # for compatibility in the
                 # tree_map of SystemLoss
             }
         )

-    def tree_flatten(self):
-        children = (self.norm_key, self.norm_samples, self.loss_weights)
-        aux_data = {
-            "u": self.u,
-            "dynamic_loss": self.dynamic_loss,
-            "derivative_keys": self.derivative_keys,
-            "omega_boundary_fun": self.omega_boundary_fun,
-            "omega_boundary_condition": self.omega_boundary_condition,
-            "omega_boundary_dim": self.omega_boundary_dim,
-            "norm_borders": self.norm_borders,
-            "sobolev_m": self.sobolev_m,
-            "obs_slice": self.obs_slice,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (norm_key, norm_samples, loss_weights) = children
-        pls = cls(
-            aux_data["u"],
-            loss_weights,
-            aux_data["dynamic_loss"],
-            aux_data["derivative_keys"],
-            aux_data["omega_boundary_fun"],
-            aux_data["omega_boundary_condition"],
-            aux_data["omega_boundary_dim"],
-            norm_key,
-            aux_data["norm_borders"],
-            norm_samples,
-            aux_data["sobolev_m"],
-            aux_data["obs_slice"],
-        )
-        return pls
-

-@register_pytree_node_class
 class LossPDENonStatio(LossPDEStatio):
-    r"""Loss object for a stationary partial differential equation
+    r"""Loss object for a non-stationary partial differential equation

-
+    $$
         \mathcal{N}[u](t, x) = 0, \forall t \in I, \forall x \in \Omega
+    $$

-    where
-    The boundary condition is
-    x\in\delta\Omega, \forall t
-    The initial condition is
+    where $\mathcal{N}[\cdot]$ is a differential operator.
+    The boundary condition is $u(t, x)=u_b(t, x),\forall
+    x\in\delta\Omega, \forall t$.
+    The initial condition is $u(0, x)=u_0(x), \forall x\in\Omega$
     The additional condition of
-    integrating to 1 can be included, i.e.,
-
+    integrating to 1 can be included, i.e., $\int u(t, x)\mathrm{d}x=1$.
+
+    Parameters
+    ----------
+    u : eqx.Module
+        the PINN
+    dynamic_loss : DynamicLoss
+        the non stationary PDE dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](t, x)$. Should implement a method
+        `dynamic_loss.evaluate(t, x, u, params)`.
+        Can be None in order to access only some part of the evaluate call
+        results.
+    key : Key
+        A JAX PRNG Key for the loss class treated as an attribute. Default is
+        None. This field is provided for future developments and additional
+        losses that might need some randomness. Note that special care must be
+        taken when splitting the key because in-place updates are forbidden in
+        eqx.Modules.
+    loss_weights : LossWeightsPDENonStatio, default=None
+        The loss weights for the different terms: dynamic loss,
+        boundary conditions if any, initial condition, normalization loss if any and
+        observations if any.
+        All fields are set to 1.0 by default.
+    derivative_keys : DerivativeKeysPDENonStatio, default=None
+        Specify which field of `params` should be differentiated for each
+        component of the total loss. Particularly useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have a `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each component of the loss.
+    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+        The function to be matched in the border condition (can be None) or a
+        dictionary of such functions as values and keys as described
+        in `omega_boundary_condition`.
+    omega_boundary_condition : str | Dict[str, str], default=None
+        Either None (no condition, by default), or a string defining
+        the boundary condition (Dirichlet or Von Neumann),
+        or a dictionary with such strings as values. In this case,
+        the keys are the facets and must be in the following order:
+        1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax", "ymin", "ymax"].
+        Note that high order boundaries are currently not implemented.
+        A value in the dict can be None, this means we do not enforce
+        a particular boundary condition on this facet.
+        The facet called "xmin", resp. "xmax" etc., in 2D,
+        refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
+    omega_boundary_dim : slice | Dict[str, slice], default=None
+        Either None, or a slice object or a dictionary of slice objects as
+        values and keys as described in `omega_boundary_condition`.
+        `omega_boundary_dim` indicates which dimension(s) of the PINN
+        will be forced to match the boundary condition.
+        Note that it must be a slice and not an integer
+        (but a preprocessing of the user provided argument takes care of it)
+    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+        Fixed sample points in the space over which to compute the
+        normalization constant. Default is None.
+    norm_int_length : float, default=None
+        A float. Must be provided if `norm_samples` is provided. The domain area
+        (or interval length in 1D) over which we perform the numerical
+        integration. Default None
+    obs_slice : slice, default=None
+        slice object specifying the beginning/end of the PINN output
+        that is observed (this is then useful for multidim PINN). Default is None.
+    initial_condition_fun : Callable, default=None
+        A function representing the temporal initial condition. If None
+        (default) then no initial condition is applied

-    **Note:** LossPDENonStatio is jittable. Hence it implements the tree_flatten() and
-    tree_unflatten methods.
     """

-
-
-
-
-
-        derivative_keys=None,
-        omega_boundary_fun=None,
-        omega_boundary_condition=None,
-        omega_boundary_dim=None,
-        initial_condition_fun=None,
-        norm_key=None,
-        norm_borders=None,
-        norm_samples=None,
-        sobolev_m=None,
-        obs_slice=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        u
-            the PINN object
-        loss_weights
-            dictionary of values for loss term ponderation
-            Note that we can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        dynamic_loss
-            A Dynamic loss object whose evaluate method corresponds to the
-            dynamic term in the loss
-            Can be None in order to access only some part of the evaluate call
-            results.
-        derivative_keys
-            A dict of lists of strings. In the dict, the key must correspond to
-            the loss term keywords. Then each of the values must correspond to keys in the parameter
-            dictionary (*at top level only of the parameter dictionary*).
-            It enables selecting the set of parameters
-            with respect to which the gradients of the dynamic
-            loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
-            keywords, this is what is typically
-            done in solving forward problems, when we only estimate the
-            equation solution with a PINN. If some loss terms keywords are
-            missing we set their value to ["nn_params"] by default for the same
-            reason
-        omega_boundary_fun
-            The function to be matched in the border condition (can be None)
-            or a dictionary of such functions. In this case, the keys are the
-            facets and the values are the functions. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_condition
-            Either None (no condition), or a string defining the boundary
-            condition e.g. Dirichlet or Von Neumann, or a dictionary of such
-            strings. In this case, the keys are the
-            facets and the values are the strings. The keys must be in the
-            following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
-            "ymin", "ymax"]. Note that high order boundaries are currently not
-            implemented. A value in the dict can be None, this means we do not
-            enforce a particular boundary condition on this facet.
-            The facet called "xmin", resp. "xmax" etc., in 2D,
-            refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
-        omega_boundary_dim
-            Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
-            the logic of omega_boundary_fun. It indicates which dimension(s) of
-            the PINN will be forced to match the boundary condition
-            Note that it must be a slice and not an integer (a preprocessing of the
-            user provided argument takes care of it)
-        initial_condition_fun
-            A function representing the temporal initial condition. If None
-            (default) then no initial condition is applied
-        norm_key
-            Jax random key to draw samples in for the Monte Carlo computation
-            of the normalization constant. Default is None
-        norm_borders
-            tuple of (min, max) of the boundary values of the space over which
-            to integrate in the computation of the normalization constant.
-            A list of tuples for higher dimensional problems. Default None.
-        norm_samples
-            Fixed sample points in the space over which to compute the
-            normalization constant. Default is None
-        sobolev_m
-            An integer. Default is None.
-            It corresponds to the Sobolev regularization order as proposed in
-            *Convergence and error analysis of PINNs*,
-            Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-        obs_slice
-            slice object specifying the beginning/end
-            slice of u output(s) that is observed (this is then useful for
-            multidim PINN). Default is None.
-
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    initial_condition_fun: Callable | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )

+    def __post_init__(self):
         """
+        Note that neither __init__ nor __post_init__ are called when updating a
+        Module with eqx.tree_at!
+        """
+        super().__post_init__()  # because __init__ or __post_init__ of Base
+        # class is not automatically called

-
-
-
-            dynamic_loss,
-            derivative_keys,
-            omega_boundary_fun,
-            omega_boundary_condition,
-            omega_boundary_dim,
-            norm_key,
-            norm_borders,
-            norm_samples,
-            sobolev_m=sobolev_m,
-            obs_slice=obs_slice,
-        )
-        if initial_condition_fun is None:
+        self.vmap_in_axes = (0, 0)  # for t and x
+
+        if self.initial_condition_fun is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g. by hardcoding it into the PINN output)."
             )
-        self.initial_condition_fun = initial_condition_fun
-
-        self.sobolev_m = sobolev_m
-        if self.sobolev_m is not None:
-            # This overwrites the wrongly initialized self.sobolev_reg with
-            # statio=True in the LossPDEStatio init
-            self.sobolev_reg = _sobolev(self.u, self.sobolev_m, statio=False)
-            # we return a function, that way
-            # the order of sobolev_m is static and the conditional in the recursive
-            # function is properly set
-        else:
-            self.sobolev_reg = None

-
-
-
+    def _get_dynamic_loss_batch(
+        self, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
+        times_batch = batch.times_x_inside_batch[:, 0:1]
+        omega_batch = batch.times_x_inside_batch[:, 1:]
+        return (times_batch, omega_batch)
+
+    def _get_normalization_loss_batch(
+        self, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
+        return (
+            batch.times_x_inside_batch[:, 0:1],
+            self.norm_samples,
+        )
+
+    def _get_observations_loss_batch(
+        self, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
+        return (
+            batch.obs_batch_dict["pinn_in"][:, 0:1],
+            batch.obs_batch_dict["pinn_in"][:, 1:],
+        )

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)

     def evaluate(
-        self,
-
-        batch,
-    ):
+        self, params: Params, batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, "1"], dict[str, float]]:
         """
         Evaluate the loss function at a batch of points for given parameters.

@@ -805,191 +608,55 @@ class LossPDENonStatio(LossPDEStatio):
         Parameters
         ---------
         params
-
-            Typically, it is a dictionary of
-            dictionaries: `eq_params` and `nn_params`, respectively the
-            differential equation parameters and the neural network parameter
+            Parameters at which the loss is evaluated
         batch
-
-            Such a named tuple is composed of a batch of points in
+            Composed of a batch of points in
             the domain, a batch of points in the domain
             border, a batch of time points and an optional additional batch
             of parameters (eg. for metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
         """

-        times_batch = batch.times_x_inside_batch[:, 0:1]
         omega_batch = batch.times_x_inside_batch[:, 1:]
-        n = omega_batch.shape[0]
-
-        vmap_in_axes_x_t = (0, 0)

         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
-            eq_params_batch_dict = batch.param_batch_dict
-
-            # feed the eq_params with the batch
-            for k in eq_params_batch_dict.keys():
-                params["eq_params"][k] = eq_params_batch_dict[k]
+            # update eq_params with the batches of generated params
+            params = _update_eq_params_dict(params, batch.param_batch_dict)

         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)

-        # dynamic part
-        params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
-        if self.dynamic_loss is not None:
-            mse_dyn_loss = dynamic_loss_apply(
-                self.dynamic_loss.evaluate,
-                self.u,
-                (times_batch, omega_batch),
-                params_,
-                vmap_in_axes_x_t + vmap_in_axes_params,
-                self.loss_weights["dyn_loss"],
-            )
-        else:
-            mse_dyn_loss = jnp.array(0.0)
-
-        # normalization part
-        params_ = _set_derivatives(params, "norm_loss", self.derivative_keys)
-        if self.normalization_loss is not None:
-            mse_norm_loss = normalization_loss_apply(
-                self.u,
-                (times_batch, self.get_norm_samples()),
-                params_,
-                vmap_in_axes_x_t + vmap_in_axes_params,
-                self.int_length,
-                self.loss_weights["norm_loss"],
-            )
-        else:
-            mse_norm_loss = jnp.array(0.0)
-
-        # boundary part
-        params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
-        if self.omega_boundary_fun is not None:
-            mse_boundary_loss = boundary_condition_apply(
-                self.u,
-                batch,
-                params_,
-                self.omega_boundary_fun,
-                self.omega_boundary_condition,
-                self.omega_boundary_dim,
-                self.loss_weights["boundary_loss"],
-            )
-        else:
-            mse_boundary_loss = jnp.array(0.0)
+        # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
+        # mse_observation_loss we use the evaluate from parent class
+        partial_mse, partial_mse_terms = super().evaluate(params, batch)

         # initial condition
-        params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
         if self.initial_condition_fun is not None:
             mse_initial_condition = initial_condition_apply(
                 self.u,
                 omega_batch,
-                params_,
+                _set_derivatives(params, self.derivative_keys.initial_condition),
                 (0,) + vmap_in_axes_params,
                 self.initial_condition_fun,
-                n,
-                self.loss_weights["initial_condition"],
+                omega_batch.shape[0],
+                self.loss_weights.initial_condition,
             )
         else:
             mse_initial_condition = jnp.array(0.0)

-        # Observation mse
-        if batch.obs_batch_dict is not None:
-            # update params with the batches of observed params
-            params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
-
-            params_ = _set_derivatives(params, "observations", self.derivative_keys)
-            mse_observation_loss = observations_loss_apply(
-                self.u,
-                (
-                    batch.obs_batch_dict["pinn_in"][:, 0:1],
-                    batch.obs_batch_dict["pinn_in"][:, 1:],
-                ),
-                params_,
-                vmap_in_axes_x_t + vmap_in_axes_params,
-                batch.obs_batch_dict["val"],
-                self.loss_weights["observations"],
-                self.obs_slice,
-            )
-        else:
-            mse_observation_loss = jnp.array(0.0)
-
-        # Sobolev regularization
-        params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
-        if self.sobolev_reg is not None:
-            mse_sobolev_loss = sobolev_reg_apply(
-                self.u,
-                (omega_batch, times_batch),
-                params_,
-                vmap_in_axes_x_t + vmap_in_axes_params,
-                self.sobolev_reg,
-                self.loss_weights["sobolev"],
-            )
-        else:
-            mse_sobolev_loss = jnp.array(0.0)
-
         # total loss
-        total_loss = (
-            mse_dyn_loss
-            + mse_norm_loss
-            + mse_boundary_loss
-            + mse_initial_condition
-            + mse_observation_loss
-            + mse_sobolev_loss
-        )
-
-        return total_loss, (
-            {
-                "dyn_loss": mse_dyn_loss,
-                "norm_loss": mse_norm_loss,
-                "boundary_loss": mse_boundary_loss,
-                "initial_condition": mse_initial_condition,
-                "observations": mse_observation_loss,
-                "sobolev": mse_sobolev_loss,
-            }
-        )
+        total_loss = partial_mse + mse_initial_condition

-    def tree_flatten(self):
-        children = (self.norm_key, self.norm_samples, self.loss_weights)
-        aux_data = {
-            "u": self.u,
-            "dynamic_loss": self.dynamic_loss,
-            "derivative_keys": self.derivative_keys,
-            "omega_boundary_fun": self.omega_boundary_fun,
-            "omega_boundary_condition": self.omega_boundary_condition,
-            "omega_boundary_dim": self.omega_boundary_dim,
-            "initial_condition_fun": self.initial_condition_fun,
-            "norm_borders": self.norm_borders,
-            "sobolev_m": self.sobolev_m,
-            "obs_slice": self.obs_slice,
+        return total_loss, {
+            **partial_mse_terms,
+            "initial_condition": mse_initial_condition,
         }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (norm_key, norm_samples, loss_weights) = children
-        pls = cls(
-            aux_data["u"],
-            loss_weights,
-            aux_data["dynamic_loss"],
-            aux_data["derivative_keys"],
-            aux_data["omega_boundary_fun"],
-            aux_data["omega_boundary_condition"],
-            aux_data["omega_boundary_dim"],
-            aux_data["initial_condition_fun"],
-            norm_key,
-            aux_data["norm_borders"],
-            norm_samples,
-            aux_data["sobolev_m"],
-            aux_data["obs_slice"],
-        )
-        return pls


-
-class SystemLossPDE:
-    """
+class SystemLossPDE(eqx.Module):
+    r"""
     Class to implement a system of PDEs.
     The goal is to give maximum freedom to the user. The class is created with
     a dict of dynamic loss, and dictionaries of all the objects that are used
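The main move in the hunk above: `LossPDENonStatio.evaluate` no longer recomputes the dynamic, normalization, boundary, and observation terms itself; it delegates them to the parent class through `super().evaluate(params, batch)` and only adds the initial-condition term on top. A minimal sketch of that delegation pattern with dummy classes and dummy MSE values (illustrative only; none of these names come from jinns):

class StatioLoss:
    def evaluate(self, params, batch):
        terms = {"dyn_loss": 1.0, "norm_loss": 0.5}  # dummy MSE terms
        return sum(terms.values()), terms

class NonStatioLoss(StatioLoss):
    def evaluate(self, params, batch):
        # reuse every term computed by the parent class
        partial_mse, partial_terms = super().evaluate(params, batch)
        mse_initial_condition = 0.25  # dummy extra term
        return partial_mse + mse_initial_condition, {
            **partial_terms,
            "initial_condition": mse_initial_condition,
        }

total, terms = NonStatioLoss().evaluate(None, None)
assert total == 1.75 and "initial_condition" in terms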
@@ -1003,190 +670,186 @@ class SystemLossPDE:
     Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
     solution.

-
-
+    Parameters
+    ----------
+    u_dict : Dict[str, eqx.Module]
+        dict of PINNs
+    loss_weights : LossWeightsPDEDict
+        A LossWeightsPDEDict of loss weights
+    derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
+        A dictionary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
+        specifying what field of `params`
+        should be used during gradient computations for each of the terms of
+        the total loss, for each of the losses in the system. Default is
+        `"nn_params"` everywhere.
+    dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
+        A dict of the dynamic parts of the loss, basically the differential
+        operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
+    key_dict : Dict[str, Key], default=None
+        A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
+        match those of u_dict. See LossPDEStatio or LossPDENonStatio for
+        more details.
+    omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
+        A dict of functions, of dicts of functions, or of None
+        (see doc for `omega_boundary_fun` in
+        LossPDEStatio or LossPDENonStatio). Default is None.
+        Must share the keys of `u_dict`.
+    omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
+        A dict of strings, of dicts of strings, or of None
+        (see doc for `omega_boundary_condition` in
+        LossPDEStatio or LossPDENonStatio). Default is None.
+        Must share the keys of `u_dict`
+    omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
+        A dict of slices, of dicts of slices, or of None
+        (see doc for `omega_boundary_dim` in
+        LossPDEStatio or LossPDENonStatio). Default is None.
+        Must share the keys of `u_dict`
+    initial_condition_fun_dict : Dict[str, Callable | None], default=None
+        A dict of functions representing the temporal initial condition (None
+        value is possible). If None
+        (default) then no temporal boundary condition is applied
+        Must share the keys of `u_dict`
+    norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None], default=None
+        A dict of fixed sample points in the space over which to compute the
+        normalization constant. Default is None
+        Must share the keys of `u_dict`
+    norm_int_length_dict : Dict[str, float | None] | None, default=None
+        A dict of floats: the domain area
+        (or interval length in 1D) upon which we perform the numerical
+        integration for each element of u_dict.
+        Default is None
+        Must share the keys of `u_dict`
+    obs_slice_dict : Dict[str, slice | None] | None, default=None
+        dict of obs_slice, with keys from `u_dict`, to designate the
+        output channel(s) that are forced to observed values, for each
+        PINN. Default is None. But if a value is given, all the entries of
+        `u_dict` must be represented here, with default value `jnp.s_[...]`
+        if no particular slice is to be given
+
     """

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        `omega_boundary_condition_dict` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    omega_boundary_dim_dict
-        A dict of dict of slices (see doc for `omega_boundary_dim` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    initial_condition_fun_dict
-        A dict of functions representing the temporal initial condition. If None
-        (default) then no temporal boundary condition is applied
-        Must share the keys of `u_dict`
-    norm_key_dict
-        A dict of Jax random keys to draw samples in for the Monte Carlo computation
-        of the normalization constant. Default is None
-        Must share the keys of `u_dict`
-    norm_borders_dict
-        A dict of tuples of (min, max) of the boundaray values of the space over which
-        to integrate in the computation of the normalization constant.
-        A list of tuple for higher dimensional problems. Default None.
-        Must share the keys of `u_dict`
-    norm_samples_dict
-        A dict of fixed sample point in the space over which to compute the
-        normalization constant. Default is None
-        Must share the keys of `u_dict`
-    sobolev_m
-        Default is None. A dictionary of integers, one per key which must
-        match `u_dict`.
-        It corresponds to the Sobolev regularization order as proposed in
-        *Convergence and error analysis of PINNs*,
-        Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-    obs_slice_dict
-        dict of obs_slice, with keys from `u_dict` to designate the
-        output(s) channels that are forced to observed values, for each
-        PINNs. Default is None. But if a value is given, all the entries of
-        `u_dict` must be represented here with default value `jnp.s_[...]`
-        if no particular slice is to be given
-
-
-    Raises
-    ------
-    ValueError
-        if initial condition is not a dict of tuple
-    ValueError
-        if the dictionaries that should share the keys of u_dict do not
-    """
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    u_dict: Dict[str, eqx.Module]
+    dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
+    key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
+    derivative_keys_dict: Dict[
+        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
+    ] = eqx.field(kw_only=True, default=None)
+    omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
+        eqx.field(kw_only=True, default=None, static=True)
+    )
+    omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
+        eqx.field(kw_only=True, default=None, static=True)
+    )
+    omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
+        eqx.field(kw_only=True, default=None, static=True)
+    )
+    initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+    norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
+        eqx.field(kw_only=True, default=None)
+    )
+    norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
+        kw_only=True, default=None
+    )
+    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
+
+    # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
+    # dictionary having keys in u_dict and / or dynamic_loss_dict)
+    loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
+        kw_only=True, default=None
+    )
+
+    # following have init=False and are set in the __post_init__
+    u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
+        init=False
+    )
+    derivative_keys_u_dict: Dict[
+        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
+    ] = eqx.field(init=False)
+    derivative_keys_dyn_loss_dict: Dict[
+        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
+    ] = eqx.field(init=False)
+    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
+    # internally the loss weights are handled with a dictionary
+    _loss_weights: Dict[str, dict] = eqx.field(init=False)
+
+    def __post_init__(self, loss_weights):
         # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in u_dict.keys()}
+        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
         # First, for all the optional dict,
         # if the user did not provide at all this optional argument,
         # we make sure there is a null weighting loss_weight and we
         # create a dummy dict with the required keys and all the values to
         # None
-        if omega_boundary_fun_dict is None:
+        if self.key_dict is None:
+            self.key_dict = self.u_dict_with_none
+        if self.omega_boundary_fun_dict is None:
             self.omega_boundary_fun_dict = self.u_dict_with_none
-        else:
-            self.omega_boundary_fun_dict = omega_boundary_fun_dict
-        if omega_boundary_condition_dict is None:
+        if self.omega_boundary_condition_dict is None:
             self.omega_boundary_condition_dict = self.u_dict_with_none
-        else:
-            self.omega_boundary_condition_dict = omega_boundary_condition_dict
-        if omega_boundary_dim_dict is None:
+        if self.omega_boundary_dim_dict is None:
             self.omega_boundary_dim_dict = self.u_dict_with_none
-        else:
-            self.omega_boundary_dim_dict = omega_boundary_dim_dict
-        if initial_condition_fun_dict is None:
+        if self.initial_condition_fun_dict is None:
             self.initial_condition_fun_dict = self.u_dict_with_none
-        else:
-            self.initial_condition_fun_dict = initial_condition_fun_dict
-        if norm_key_dict is None:
-            self.norm_key_dict = self.u_dict_with_none
-        else:
-            self.norm_key_dict = norm_key_dict
-        if norm_borders_dict is None:
-            self.norm_borders_dict = self.u_dict_with_none
-        else:
-            self.norm_borders_dict = norm_borders_dict
-        if norm_samples_dict is None:
+        if self.norm_samples_dict is None:
             self.norm_samples_dict = self.u_dict_with_none
-        else:
-            self.norm_samples_dict = norm_samples_dict
-        if sobolev_m_dict is None:
-            self.sobolev_m_dict = self.u_dict_with_none
-        else:
-            self.sobolev_m_dict = sobolev_m_dict
-        if obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
-        else:
-            self.obs_slice_dict = obs_slice_dict
-        if u_dict.keys() != obs_slice_dict.keys():
+        if self.norm_int_length_dict is None:
+            self.norm_int_length_dict = self.u_dict_with_none
+        if self.obs_slice_dict is None:
+            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
+        if self.u_dict.keys() != self.obs_slice_dict.keys():
             raise ValueError("obs_slice_dict should have same keys as u_dict")
-        if derivative_keys_dict is None:
+        if self.derivative_keys_dict is None:
             self.derivative_keys_dict = {
                 k: None
-                for k in set(list(dynamic_loss_dict.keys()) + list(u_dict.keys()))
+                for k in set(
+                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
+                )
             }
             # set() because we can have duplicate entries and in this case we
            # say it corresponds to the same derivative_keys_dict entry
-
-
+            # we need both because the constraints (all but dyn_loss) will be
+            # done by iterating on u_dict while the dyn_loss will be by
+            # iterating on dynamic_loss_dict. So each time we will require some
+            # derivative_keys_dict
+
        # but then if the user did not provide anything, we must at least have
        # a default value for the dynamic_loss_dict keys entries in
        # self.derivative_keys_dict since the computation of dynamic losses is
-        # made without create a
+        # made without creating a loss object that would provide the
        # default values
-        for k in dynamic_loss_dict.keys():
+        for k in self.dynamic_loss_dict.keys():
             if self.derivative_keys_dict[k] is None:
-
+                try:
+                    if self.u_dict[k].eq_type == "statio_PDE":
+                        self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
+                    else:
+                        self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
+                except KeyError:  # We are in a key that is not in u_dict but in
+                    # dynamic_loss_dict
+                    if isinstance(self.dynamic_loss_dict[k], PDEStatio):
+                        self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
+                    else:
+                        self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()

        # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
        if (
-            u_dict.keys() != nn_type_dict.keys()
-            or u_dict.keys() != self.omega_boundary_fun_dict.keys()
-            or u_dict.keys() != self.omega_boundary_condition_dict.keys()
-            or u_dict.keys() != self.omega_boundary_dim_dict.keys()
-            or u_dict.keys() != self.initial_condition_fun_dict.keys()
-            or u_dict.keys() != self.norm_key_dict.keys()
-            or u_dict.keys() != self.norm_borders_dict.keys()
-            or u_dict.keys() != self.norm_samples_dict.keys()
-            or u_dict.keys() != self.sobolev_m_dict.keys()
+            self.u_dict.keys() != self.key_dict.keys()
+            or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
+            or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
+            or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
+            or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
+            or self.u_dict.keys() != self.norm_samples_dict.keys()
+            or self.u_dict.keys() != self.norm_int_length_dict.keys()
        ):
            raise ValueError("All the dicts concerning the PINNs should have same keys")

-        self.dynamic_loss_dict = dynamic_loss_dict
-        self.u_dict = u_dict
-        # TODO nn_type should become a class attribute now that we have PINN
-        # class and SPINNs class
-        self.nn_type_dict = nn_type_dict
-
-        self.loss_weights = loss_weights  # This calls the setter
+        self._loss_weights = self.set_loss_weights(loss_weights)

        # Third, in order to benefit from LossPDEStatio and
        # LossPDENonStatio and in order to factorize code, we create internally
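With the field declarations and `__post_init__` above, the class is now built with keyword-only arguments instead of a long positional `__init__`. A hypothetical construction sketch following the docstring above (`u`, `v`, the dynamic losses, the initial-condition functions, and `loss_weights` are placeholders the reader must define; it also assumes `SystemLossPDE` is re-exported from `jinns.loss`):

from jinns.loss import SystemLossPDE

system_loss = SystemLossPDE(
    u_dict={"u": u, "v": v},  # two PINNs sharing the same keys everywhere
    dynamic_loss_dict={"u": u_dyn_loss, "v": v_dyn_loss},
    loss_weights=loss_weights,  # a LossWeightsPDEDict
    initial_condition_fun_dict={"u": u0_fun, "v": v0_fun},
)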
@@ -1194,52 +857,51 @@ class SystemLossPDE:
         # We will not use the dynamic loss term
         self.u_constraints_dict = {}
         for i in self.u_dict.keys():
-            if self.nn_type_dict[i] == "statio_PDE":
+            if self.u_dict[i].eq_type == "statio_PDE":
                 self.u_constraints_dict[i] = LossPDEStatio(
-                    u=u_dict[i],
-                    loss_weights={
-
-
-
-
-
-
+                    u=self.u_dict[i],
+                    loss_weights=LossWeightsPDENonStatio(
+                        dyn_loss=0.0,
+                        norm_loss=1.0,
+                        boundary_loss=1.0,
+                        observations=1.0,
+                        initial_condition=1.0,
+                    ),
                     dynamic_loss=None,
+                    key=self.key_dict[i],
                     derivative_keys=self.derivative_keys_dict[i],
                     omega_boundary_fun=self.omega_boundary_fun_dict[i],
                     omega_boundary_condition=self.omega_boundary_condition_dict[i],
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    norm_key=self.norm_key_dict[i],
-                    norm_borders=self.norm_borders_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-                    sobolev_m=self.sobolev_m_dict[i],
+                    norm_int_length=self.norm_int_length_dict[i],
                     obs_slice=self.obs_slice_dict[i],
                 )
-            elif self.nn_type_dict[i] == "nonstatio_PDE":
+            elif self.u_dict[i].eq_type == "nonstatio_PDE":
                 self.u_constraints_dict[i] = LossPDENonStatio(
-                    u=u_dict[i],
-                    loss_weights={
-
-
-
-
-
-
-                    },
+                    u=self.u_dict[i],
+                    loss_weights=LossWeightsPDENonStatio(
+                        dyn_loss=0.0,
+                        norm_loss=1.0,
+                        boundary_loss=1.0,
+                        observations=1.0,
+                        initial_condition=1.0,
+                    ),
                     dynamic_loss=None,
+                    key=self.key_dict[i],
                     derivative_keys=self.derivative_keys_dict[i],
                     omega_boundary_fun=self.omega_boundary_fun_dict[i],
                     omega_boundary_condition=self.omega_boundary_condition_dict[i],
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
                     initial_condition_fun=self.initial_condition_fun_dict[i],
-                    norm_key=self.norm_key_dict[i],
-                    norm_borders=self.norm_borders_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-                    sobolev_m=self.sobolev_m_dict[i],
+                    norm_int_length=self.norm_int_length_dict[i],
+                    obs_slice=self.obs_slice_dict[i],
                 )
             else:
                 raise ValueError(
-
+                    "Wrong value for self.u_dict[i].eq_type[i], "
+                    f"got {self.u_dict[i].eq_type[i]}"
                 )

         # for convenience in the tree_map of evaluate,
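Design note on the hunk above: each internal per-unknown loss is built with `dyn_loss=0.0` because the dynamic-loss terms are accumulated separately, by iterating over `dynamic_loss_dict` in `evaluate`; the internal `LossPDEStatio`/`LossPDENonStatio` objects only handle the constraint terms. Schematically (plain Python, illustrative only):

def system_total_loss(dyn_terms: dict, constraint_terms: dict) -> float:
    # dyn_terms: one MSE per key of dynamic_loss_dict
    # constraint_terms: one aggregated constraints MSE per key of u_dict
    return sum(dyn_terms.values()) + sum(constraint_terms.values())

assert system_total_loss({"eq": 0.5}, {"u": 0.25, "v": 0.25}) == 1.0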
@@ -1255,34 +917,38 @@ class SystemLossPDE:

         # also make sure we only have PINNs or SPINNs
         if not (
-            all(isinstance(value, PINN) for value in u_dict.values())
-            or all(isinstance(value, SPINN) for value in u_dict.values())
+            all(isinstance(value, PINN) for value in self.u_dict.values())
+            or all(isinstance(value, SPINN) for value in self.u_dict.values())
         ):
             raise ValueError(
                 "We only accept dictionary of PINNs or dictionary of SPINNs"
             )

-
-
-
-
-
-
-
-
+    def set_loss_weights(
+        self, loss_weights_init: LossWeightsPDEDict
+    ) -> dict[str, dict]:
+        """
+        This rather complex function enables the user to specify a simple
+        loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
+        whose weighting values are applied to all the equations of the
+        system... So all the transformations are handled here
+        """
+        _loss_weights = {}
+        for k in fields(loss_weights_init):
+            v = getattr(loss_weights_init, k.name)
             if isinstance(v, dict):
-                for
+                for vv in v.keys():
                     if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv,
+                        isinstance(vv, Array)
                         and ((vv.shape == (1,) or len(vv.shape) == 0))
                     ):
                         # TODO improve that
                         raise ValueError(
                             f"loss values cannot be vectorial here, got {vv}"
                         )
-                if k == "dyn_loss":
+                if k.name == "dyn_loss":
                     if v.keys() == self.dynamic_loss_dict.keys():
-
+                        _loss_weights[k.name] = v
                     else:
                         raise ValueError(
                             "Keys in nested dictionary of loss_weights"
@@ -1290,51 +956,36 @@ class SystemLossPDE:
                         )
                 else:
                     if v.keys() == self.u_dict.keys():
-
+                        _loss_weights[k.name] = v
                     else:
                         raise ValueError(
                             "Keys in nested dictionary of loss_weights"
                             " do not match u_dict keys"
                         )
+            if v is None:
+                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
             else:
                 if not isinstance(v, (int, float)) and not (
-                    isinstance(v,
-                    and ((v.shape == (1,) or len(v.shape) == 0))
+                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
                 ):
                     # TODO improve that
                     raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k == "dyn_loss":
-
+                if k.name == "dyn_loss":
+                    _loss_weights[k.name] = {
                         kk: v for kk in self.dynamic_loss_dict.keys()
                     }
                 else:
-
-
-        if all(v is None for k, v in self.sobolev_m_dict.items()):
-            self._loss_weights["sobolev"] = {k: 0 for k in self.u_dict.keys()}
-        if "observations" not in value.keys():
-            self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
-        if all(v is None for k, v in self.omega_boundary_fun_dict.items()) or all(
-            v is None for k, v in self.omega_boundary_condition_dict.items()
-        ):
-            self._loss_weights["boundary_loss"] = {k: 0 for k in self.u_dict.keys()}
-        if (
-            all(v is None for k, v in self.norm_key_dict.items())
-            or all(v is None for k, v in self.norm_borders_dict.items())
-            or all(v is None for k, v in self.norm_samples_dict.items())
-        ):
-            self._loss_weights["norm_loss"] = {k: 0 for k in self.u_dict.keys()}
-        if all(v is None for k, v in self.initial_condition_fun_dict.items()):
-            self._loss_weights["initial_condition"] = {k: 0 for k in self.u_dict.keys()}
+                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
+        return _loss_weights

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)

     def evaluate(
         self,
-        params_dict,
-        batch,
-    ):
+        params_dict: ParamsDict,
+        batch: PDEStatioBatch | PDENonStatioBatch,
+    ) -> tuple[Float[Array, "1"], dict[str, float]]:
         """
         Evaluate the loss function at a batch of points for given parameters.

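The transformation performed by `set_loss_weights` above is essentially a broadcast: a missing weight becomes a per-key dict of zeros, a scalar is replicated over the relevant keys (`dynamic_loss_dict` keys for `dyn_loss`, `u_dict` keys otherwise), and an already-nested dict is validated and kept. A simplified standalone sketch of that rule (without the shape and key checks of the real method):

from dataclasses import dataclass, fields

@dataclass
class Weights:  # stand-in for LossWeightsPDEDict
    dyn_loss: object = None
    initial_condition: object = None

def broadcast(w, dyn_keys, u_keys):
    out = {}
    for f in fields(w):
        v = getattr(w, f.name)
        keys = dyn_keys if f.name == "dyn_loss" else u_keys
        if v is None:
            out[f.name] = {k: 0 for k in keys}   # missing -> null weight
        elif isinstance(v, dict):
            out[f.name] = v                      # already per-key
        else:
            out[f.name] = {k: v for k in keys}   # scalar -> broadcast
    return out

w = broadcast(Weights(dyn_loss=1.0, initial_condition=10.0), ["eq"], ["u", "v"])
assert w == {"dyn_loss": {"eq": 1.0},
             "initial_condition": {"u": 10.0, "v": 10.0}}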
@@ -1342,12 +993,8 @@ class SystemLossPDE:
         Parameters
         ---------
         params_dict
-
-            Typically, it is a dictionary of dictionaries of
-            dictionaries: `eq_params` and `nn_params``, respectively the
-            differential equation parameters and the neural network parameter
+            Parameters at which the losses of the system are evaluated
         batch
-            A PDEStatioBatch or PDENonStatioBatch object.
             Such named tuples are composed of a batch of points in the
             domain, a batch of points in the domain
             border, (a batch of time points for a PDENonStatioBatch) and an
@@ -1355,7 +1002,7 @@ class SystemLossPDE:
         and an optional additional batch of observed
         inputs/outputs/parameters
         """
-        if self.u_dict.keys() != params_dict["nn_params"].keys():
+        if self.u_dict.keys() != params_dict.nn_params.keys():
             raise ValueError("u_dict and params_dict[nn_params] should have same keys ")

         if isinstance(batch, PDEStatioBatch):
@@ -1378,9 +1025,10 @@ class SystemLossPDE:
         if batch.param_batch_dict is not None:
             eq_params_batch_dict = batch.param_batch_dict

+            # TODO
             # feed the eq_params with the batch
             for k in eq_params_batch_dict.keys():
-                params_dict["eq_params"][k] = eq_params_batch_dict[k]
+                params_dict.eq_params[k] = eq_params_batch_dict[k]

         vmap_in_axes_params = _get_vmap_in_axes_params(
             batch.param_batch_dict, params_dict
@@ -1388,12 +1036,11 @@ class SystemLossPDE:

         def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
             """The function used in tree_map"""
-            params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
                 batches,
-                params_dict_,
+                _set_derivatives(params_dict, derivative_key.dyn_loss),
                 vmap_in_axes_x_or_x_t + vmap_in_axes_params,
                 loss_weight,
                 u_type=type(list(self.u_dict.values())[0]),
@@ -1404,6 +1051,13 @@ class SystemLossPDE:
             self.dynamic_loss_dict,
             self.derivative_keys_dyn_loss_dict,
             self._loss_weights["dyn_loss"],
+            is_leaf=lambda x: isinstance(
+                x, (PDEStatio, PDENonStatio)
+            ),  # before, when dynamic losses
+            # were plain (unregistered pytree) node classes, we could not traverse
+            # this level. Now that dynamic losses are eqx.Modules, they can be
+            # traversed by tree_map recursion. Hence we need to specify that
+            # we want to stop at this level
         )
         mse_dyn_loss = jax.tree_util.tree_reduce(
             lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
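The `is_leaf` argument added above matters because `jax.tree_util.tree_map` recurses into any registered pytree: now that dynamic losses are `eqx.Module`s (hence pytrees), the map would otherwise descend into their internal leaves instead of applying the function to each loss object. A standalone demonstration with a dummy module (not a jinns class):

import equinox as eqx
import jax

class DummyLoss(eqx.Module):  # stand-in for a PDEStatio / PDENonStatio loss
    scale: float

losses = {"eq1": DummyLoss(1.0), "eq2": DummyLoss(2.0)}

# is_leaf stops the recursion at the module level, so the lambda receives
# whole DummyLoss objects rather than their float leaves
applied = jax.tree_util.tree_map(
    lambda m: m.scale * 10.0,
    losses,
    is_leaf=lambda x: isinstance(x, DummyLoss),
)
assert applied == {"eq1": 10.0, "eq2": 20.0}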
@@ -1418,11 +1072,10 @@ class SystemLossPDE:
             "boundary_loss": "*",
             "observations": "*",
             "initial_condition": "*",
-            "sobolev": "*",
         }
         # we need to do the following for the tree_mapping to work
         if batch.obs_batch_dict is None:
-            batch = batch
+            batch = append_obs_batch(batch, self.u_dict_with_none)
         total_loss, res_dict = constraints_system_loss_apply(
             self.u_constraints_dict,
             batch,
@@ -1435,41 +1088,3 @@ class SystemLossPDE:
         total_loss += mse_dyn_loss
         res_dict["dyn_loss"] += mse_dyn_loss
         return total_loss, res_dict
-
-    def tree_flatten(self):
-        children = (
-            self.norm_key_dict,
-            self.norm_samples_dict,
-            self.initial_condition_fun_dict,
-            self._loss_weights,
-        )
-        aux_data = {
-            "u_dict": self.u_dict,
-            "dynamic_loss_dict": self.dynamic_loss_dict,
-            "norm_borders_dict": self.norm_borders_dict,
-            "omega_boundary_fun_dict": self.omega_boundary_fun_dict,
-            "omega_boundary_condition_dict": self.omega_boundary_condition_dict,
-            "nn_type_dict": self.nn_type_dict,
-            "sobolev_m_dict": self.sobolev_m_dict,
-            "derivative_keys_dict": self.derivative_keys_dict,
-            "obs_slice_dict": self.obs_slice_dict,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (
-            norm_key_dict,
-            norm_samples_dict,
-            initial_condition_fun_dict,
-            loss_weights,
-        ) = children
-        loss_ode = cls(
-            loss_weights=loss_weights,
-            norm_key_dict=norm_key_dict,
-            norm_samples_dict=norm_samples_dict,
-            initial_condition_fun_dict=initial_condition_fun_dict,
-            **aux_data,
-        )
-
-        return loss_ode