jinns-1.5.0-py3-none-any.whl → jinns-1.6.0-py3-none-any.whl
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
@@ -8,23 +8,24 @@ from __future__ import (
 
 import abc
 from dataclasses import InitVar
-from typing import TYPE_CHECKING, Callable, …
+from typing import TYPE_CHECKING, Callable, cast, Any, TypeVar, Generic
 from types import EllipsisType
 import warnings
 import jax
 import jax.numpy as jnp
 import equinox as eqx
-from jaxtyping import Float, Array
+from jaxtyping import PRNGKeyArray, Float, Array
 from jinns.loss._loss_utils import (
     dynamic_loss_apply,
     boundary_condition_apply,
     normalization_loss_apply,
     observations_loss_apply,
     initial_condition_apply,
+    initial_condition_check,
 )
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
-…
+    update_eq_params,
 )
 from jinns.parameters._derivative_keys import (
     _set_derivatives,
@@ -39,24 +40,14 @@ from jinns.loss._loss_weights import (
 )
 from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
 from jinns.parameters._params import Params
+from jinns.loss import PDENonStatio, PDEStatio
 
 
 if TYPE_CHECKING:
     # imports for type hints only
     from jinns.nn._abstract_pinn import AbstractPINN
-    from jinns.loss import PDENonStatio, PDEStatio
     from jinns.utils._types import BoundaryConditionFun
 
-class LossDictPDEStatio(TypedDict):
-    dyn_loss: Float[Array, " "]
-    norm_loss: Float[Array, " "]
-    boundary_loss: Float[Array, " "]
-    observations: Float[Array, " "]
-
-class LossDictPDENonStatio(LossDictPDEStatio):
-    initial_condition: Float[Array, " "]
-
-
 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "dirichlet",
     "von neumann",
@@ -64,7 +55,24 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
 ]
 
 
-
+# For the same reason that we have the TypeVar in _abstract_loss.py, we have them
+# here, because _LossPDEAbstract is abtract and we cannot decide for several
+# types between their statio and non-statio version.
+# Assigning the type where it can be decide seems a better practice than
+# assigning a type at a higher level depending on a child class type. This is
+# why we now assign LossWeights and DerivativeKeys in the child class where
+# they really can be decided.
+
+L = TypeVar("L", bound=LossWeightsPDEStatio | LossWeightsPDENonStatio)
+B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
+C = TypeVar(
+    "C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
+)
+D = TypeVar("D", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
+Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)
+
+
+class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
     r"""
     Parameters
     ----------
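The comment block added above motivates the new module-level TypeVars. A minimal standalone sketch of the same pattern (illustrative names only, not jinns code): the abstract class leaves the generic parameter open, and each child pins it, so a type checker sees the exact weights type instead of a union.

import jax  # noqa: F401  (only to stress the sketch lives in a JAX codebase)
from typing import Generic, TypeVar

class WeightsStatio: ...
class WeightsNonStatio: ...

W = TypeVar("W", bound=WeightsStatio | WeightsNonStatio)

class AbstractLossSketch(Generic[W]):
    loss_weights: W  # the abstract class leaves the exact type open

class StatioLossSketch(AbstractLossSketch[WeightsStatio]):
    def __init__(self) -> None:
        # here loss_weights is typed precisely as WeightsStatio
        self.loss_weights = WeightsStatio()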
@@ -77,17 +85,11 @@ class _LossPDEAbstract(AbstractLoss):
         `update_weight_method`
     update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
         Default is None meaning no update for loss weights. Otherwise a string
-    derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
-        Specify which field of `params` should be differentiated for each
-        composant of the total loss. Particularily useful for inverse problems.
-        Fields can be "nn_params", "eq_params" or "both". Those that should not
-        be updated will have a `jax.lax.stop_gradient` called on them. Default
-        is `"nn_params"` for each composant of the loss.
     omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str | dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str | None], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -98,7 +100,7 @@ class _LossPDEAbstract(AbstractLoss):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | dict[str, slice], default=None
+    omega_boundary_dim : int | slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
@@ -120,197 +122,286 @@ class _LossPDEAbstract(AbstractLoss):
     obs_slice : EllipsisType | slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-…
+    key : Key | None
+        A JAX PRNG Key for the loss class treated as an attribute. Default is
+        None. This field is provided for future developments and additional
+        losses that might need some randomness. Note that special care must be
+        taken when splitting the key because in-place updates are forbidden in
+        eqx.Modules.
     """
 
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
-…
-        eqx.field(kw_only=True, default=None)
-    )
-    loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
-        kw_only=True, default=None
-    )
+    u: eqx.AbstractVar[AbstractPINN]
+    dynamic_loss: eqx.AbstractVar[Y]
     omega_boundary_fun: (
         BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
-    ) = eqx.field(…
-    omega_boundary_condition: str | dict[str, str] | None = eqx.field(…
+    ) = eqx.field(static=True)
+    omega_boundary_condition: str | dict[str, str | None] | None = eqx.field(
+        static=True
     )
-    omega_boundary_dim: slice | dict[str, slice]…
-…
-            self.derivative_keys = (
-                DerivativeKeysPDENonStatio(params=params)
-                if isinstance(self, LossPDENonStatio)
-                else DerivativeKeysPDEStatio(params=params)
-            )
-        except ValueError as exc:
-            raise ValueError(
-                "Problem at self.derivative_keys initialization "
-                f"received {self.derivative_keys=} and {params=}"
-            ) from exc
-
-        if self.loss_weights is None:
-            self.loss_weights = (
-                LossWeightsPDENonStatio()
-                if isinstance(self, LossPDENonStatio)
-                else LossWeightsPDEStatio()
-            )
-
-        if self.obs_slice is None:
+    omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
+    norm_samples: Float[Array, " nb_norm_samples dimension"] | None
+    norm_weights: Float[Array, " nb_norm_samples"] | None
+    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    key: PRNGKeyArray | None
+
+    def __init__(
+        self,
+        *,
+        omega_boundary_fun: BoundaryConditionFun
+        | dict[str, BoundaryConditionFun]
+        | None = None,
+        omega_boundary_condition: str | dict[str, str | None] | None = None,
+        omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
+        norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
+        norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
+        obs_slice: EllipsisType | slice | None = None,
+        key: PRNGKeyArray | None = None,
+        **kwargs: Any,  # for arguments for super()
+    ):
+        super().__init__(loss_weights=self.loss_weights, **kwargs)
+
+        if obs_slice is None:
             self.obs_slice = jnp.s_[...]
+        else:
+            self.obs_slice = obs_slice
 
         if (
-            isinstance(…
-            and not isinstance(…
+            isinstance(omega_boundary_fun, dict)
+            and not isinstance(omega_boundary_condition, dict)
         ) or (
-            not isinstance(…
-            and isinstance(…
+            not isinstance(omega_boundary_fun, dict)
+            and isinstance(omega_boundary_condition, dict)
         ):
             raise ValueError(
-                "if one of…
-                "…
+                "if one of omega_boundary_fun or "
+                "omega_boundary_condition is dict, the other should be too."
             )
 
-        if…
+        if omega_boundary_condition is None or omega_boundary_fun is None:
             warnings.warn(
                 "Missing boundary function or no boundary condition."
                 "Boundary function is thus ignored."
             )
         else:
-            if isinstance(…
-                for _, v in…
+            if isinstance(omega_boundary_condition, dict):
+                for _, v in omega_boundary_condition.items():
                     if v is not None and not any(
                         v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                     ):
                         raise NotImplementedError(
-                            f"The boundary condition {…
+                            f"The boundary condition {omega_boundary_condition} is not"
                             f"implemented yet. Try one of :"
                             f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                         )
             else:
                 if not any(
-…
+                    omega_boundary_condition.lower() in s
                     for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                 ):
                     raise NotImplementedError(
-                        f"The boundary condition {…
+                        f"The boundary condition {omega_boundary_condition} is not"
                         f"implemented yet. Try one of :"
                         f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                     )
-
+            if isinstance(omega_boundary_fun, dict) and isinstance(
+                omega_boundary_condition, dict
+            ):
+                keys_omega_boundary_fun = cast(str, omega_boundary_fun.keys())
+                if (
+                    not (
+                        list(keys_omega_boundary_fun) == ["xmin", "xmax"]
+                        and list(omega_boundary_condition.keys()) == ["xmin", "xmax"]
+                    )
+                ) and (
+                    not (
+                        list(keys_omega_boundary_fun)
+                        == ["xmin", "xmax", "ymin", "ymax"]
+                        and list(omega_boundary_condition.keys())
+                        == ["xmin", "xmax", "ymin", "ymax"]
+                    )
                 ):
-…
-                    not (
-                        list(self.omega_boundary_fun.keys())
-                        == ["xmin", "xmax", "ymin", "ymax"]
-                        and list(self.omega_boundary_condition.keys())
-                        == ["xmin", "xmax", "ymin", "ymax"]
-                    )
-                ):
-                    raise ValueError(
-                        "The key order (facet order) in the "
-                        "boundary condition dictionaries is incorrect"
-                    )
+                    raise ValueError(
+                        "The key order (facet order) in the "
+                        "boundary condition dictionaries is incorrect"
+                    )
+
+        self.omega_boundary_fun = omega_boundary_fun
+        self.omega_boundary_condition = omega_boundary_condition
 
-        if isinstance(…
-…
+        if isinstance(omega_boundary_fun, dict):
+            keys_omega_boundary_fun: str = cast(str, omega_boundary_fun.keys())
+            if omega_boundary_dim is None:
+                self.omega_boundary_dim = {
+                    k: jnp.s_[::] for k in keys_omega_boundary_fun
+                }
+            if not isinstance(omega_boundary_dim, dict):
                 raise ValueError(
                     "If omega_boundary_fun is a dict then"
                     " omega_boundary_dim should also be a dict"
                 )
-            if…
-                self.omega_boundary_dim = {
-                    k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
-                }
-            if list(self.omega_boundary_dim.keys()) != list(
-                self.omega_boundary_fun.keys()
-            ):
+            if list(omega_boundary_dim.keys()) != list(keys_omega_boundary_fun):
                 raise ValueError(
                     "If omega_boundary_fun is a dict,"
                     " omega_boundary_dim should be a dict with the same keys"
                 )
-…
+            self.omega_boundary_dim = {}
+            for k, v in omega_boundary_dim.items():
                 if isinstance(v, int):
                     # rewrite it as a slice to ensure that axis does not disappear when
                     # indexing
                     self.omega_boundary_dim[k] = jnp.s_[v : v + 1]
+                else:
+                    self.omega_boundary_dim[k] = v
 
         else:
-…
+            assert not isinstance(omega_boundary_dim, dict)
+            if omega_boundary_dim is None:
                 self.omega_boundary_dim = jnp.s_[::]
-…
+            elif isinstance(omega_boundary_dim, int):
                 # rewrite it as a slice to ensure that axis does not disappear when
                 # indexing
                 self.omega_boundary_dim = jnp.s_[
-…
+                    omega_boundary_dim : omega_boundary_dim + 1
                 ]
-…
+            else:
+                assert isinstance(omega_boundary_dim, slice)
+                self.omega_boundary_dim = omega_boundary_dim
 
-        if…
+        if norm_samples is not None:
+            self.norm_samples = norm_samples
+            if norm_weights is None:
                 raise ValueError(
                     "`norm_weights` must be provided when `norm_samples` is used!"
                 )
-            if isinstance(…
-                self.norm_weights =…
+            if isinstance(norm_weights, (int, float)):
+                self.norm_weights = norm_weights * jnp.ones(
                     (self.norm_samples.shape[0],)
                 )
-…
+            else:
+                assert isinstance(norm_weights, Array)
+                if not (norm_weights.shape[0] == norm_samples.shape[0]):
                     raise ValueError(
-                        "…
+                        "norm_weights and "
+                        "norm_samples must have the same leading dimension"
                     )
-…
+                self.norm_weights = norm_weights
+        else:
+            self.norm_samples = norm_samples
+            self.norm_weights = None
+
+        self.key = key
 
     @abc.abstractmethod
-    def…
+    def _get_dynamic_loss_batch(self, batch: B) -> Array:
         pass
 
     @abc.abstractmethod
-    def…
+    def _get_normalization_loss_batch(self, batch: B) -> tuple[Array | None, ...]:
+        pass
+
+    def _get_dyn_loss_fun(
+        self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
+    ) -> Callable[[Params[Array]], Array] | None:
+        if self.dynamic_loss is not None:
+            dyn_loss_eval = self.dynamic_loss.evaluate
+            dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: dynamic_loss_apply(
+                    dyn_loss_eval,
+                    self.u,
+                    self._get_dynamic_loss_batch(batch),
+                    _set_derivatives(p, self.derivative_keys.dyn_loss),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                )
+            )
+        else:
+            dyn_loss_fun = None
+
+        return dyn_loss_fun
+
+    def _get_norm_loss_fun(
+        self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
+    ) -> Callable[[Params[Array]], Array] | None:
+        if self.norm_samples is not None:
+            norm_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: normalization_loss_apply(
+                    self.u,
+                    cast(
+                        tuple[Array, Array], self._get_normalization_loss_batch(batch)
+                    ),
+                    _set_derivatives(p, self.derivative_keys.norm_loss),
+                    vmap_in_axes_params,
+                    self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
+                )
+            )
+        else:
+            norm_loss_fun = None
+        return norm_loss_fun
+
+    def _get_boundary_loss_fun(
+        self, batch: B
+    ) -> Callable[[Params[Array]], Array] | None:
+        if (
+            self.omega_boundary_condition is not None
+            and self.omega_boundary_fun is not None
+        ):
+            boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: boundary_condition_apply(
+                    self.u,
+                    batch,
+                    _set_derivatives(p, self.derivative_keys.boundary_loss),
+                    self.omega_boundary_fun,  # type: ignore (we are in lambda)
+                    self.omega_boundary_condition,  # type: ignore
+                    self.omega_boundary_dim,  # type: ignore
+                )
+            )
+        else:
+            boundary_loss_fun = None
+
+        return boundary_loss_fun
+
+    def _get_obs_params_and_obs_loss_fun(
+        self,
+        batch: B,
+        vmap_in_axes_params: tuple[Params[int | None] | None],
         params: Params[Array],
-…
+    ) -> tuple[Params[Array] | None, Callable[[Params[Array]], Array] | None]:
+        if batch.obs_batch_dict is not None:
+            # update params with the batches of observed params
+            params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
+
+            pinn_in, val = (
+                batch.obs_batch_dict["pinn_in"],
+                batch.obs_batch_dict["val"],
+            )
+
+            obs_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda po: observations_loss_apply(
+                    self.u,
+                    pinn_in,
+                    _set_derivatives(po, self.derivative_keys.observations),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                    val,
+                    self.obs_slice,
+                )
+            )
+        else:
+            params_obs = None
+            obs_loss_fun = None
+
+        return params_obs, obs_loss_fun
 
 
-class LossPDEStatio(_LossPDEAbstract):
+class LossPDEStatio(
+    _LossPDEAbstract[
+        LossWeightsPDEStatio,
+        PDEStatioBatch,
+        PDEStatioComponents[Array | None],
+        DerivativeKeysPDEStatio,
+        PDEStatio | None,
+    ]
+):
     r"""Loss object for a stationary partial differential equation
 
     $$
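One detail in the constructor above deserves a note: an integer `omega_boundary_dim` is rewritten as `jnp.s_[v : v + 1]`. A standalone sketch of why (hypothetical array, not jinns code): integer indexing drops the indexed axis while an equivalent length-1 slice keeps it, which matters when downstream code expects the boundary output to keep its dimension.

import jax.numpy as jnp

x = jnp.ones((8, 3))
print(x[:, 1].shape)            # (8,)   -> the axis disappears
print(x[:, jnp.s_[1:2]].shape)  # (8, 1) -> the axis is preserved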
@@ -325,13 +416,13 @@ class LossPDEStatio(_LossPDEAbstract):
     ----------
     u : AbstractPINN
         the PINN
-    dynamic_loss : PDEStatio
+    dynamic_loss : PDEStatio | None
        the stationary PDE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](x)$. Should implement a method
        `dynamic_loss.evaluate(x, u, params)`.
        Can be None in order to access only some part of the evaluate call
        results.
-    key :…
+    key : PRNGKeyArray
         A JAX PRNG Key for the loss class treated as an attribute. Default is
         None. This field is provided for future developments and additional
         losses that might need some randomness. Note that special care must be
@@ -343,14 +434,18 @@ class LossPDEStatio(_LossPDEAbstract):
         observations if any.
         Can be updated according to a specific algorithm. See
         `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
-        Default is None meaning no update for loss weights. Otherwise a string
     derivative_keys : DerivativeKeysPDEStatio, default=None
         Specify which field of `params` should be differentiated for each
         composant of the total loss. Particularily useful for inverse problems.
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
+    params : InitVar[Params[Array]], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.
+
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
     omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
@@ -366,7 +461,7 @@ class LossPDEStatio(_LossPDEAbstract):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | dict[str, slice], default=None
+    omega_boundary_dim : int | slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
@@ -387,10 +482,6 @@ class LossPDEStatio(_LossPDEAbstract):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params[Array]], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-
 
     Raises
     ------
@@ -404,21 +495,46 @@ class LossPDEStatio(_LossPDEAbstract):
 
     u: AbstractPINN
     dynamic_loss: PDEStatio | None
-…
+    loss_weights: LossWeightsPDEStatio
+    derivative_keys: DerivativeKeysPDEStatio
+    vmap_in_axes: tuple[int] = eqx.field(static=True)
+
+    params: InitVar[Params[Array] | None]
+
+    def __init__(
+        self,
+        *,
+        u: AbstractPINN,
+        dynamic_loss: PDEStatio | None,
+        loss_weights: LossWeightsPDEStatio | None = None,
+        derivative_keys: DerivativeKeysPDEStatio | None = None,
+        params: Params[Array] | None = None,
+        **kwargs: Any,
+    ):
+        self.u = u
+        if loss_weights is None:
+            self.loss_weights = LossWeightsPDEStatio()
+        else:
+            self.loss_weights = loss_weights
+        self.dynamic_loss = dynamic_loss
 
-…
+        super().__init__(
+            **kwargs,
+        )
 
-…
+        if derivative_keys is None:
+            # be default we only take gradient wrt nn_params
+            try:
+                self.derivative_keys = DerivativeKeysPDEStatio(params=params)
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at derivative_keys initialization "
+                    f"received {derivative_keys=} and {params=}"
+                ) from exc
+        else:
+            self.derivative_keys = derivative_keys
 
-        self.vmap_in_axes = (0,)
+        self.vmap_in_axes = (0,)
 
     def _get_dynamic_loss_batch(
         self, batch: PDEStatioBatch
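Given the new keyword-only `__init__` above, constructing the loss presumably looks like the sketch below. The names `u`, `my_dynamic_loss` and `init_params` are placeholders for a PINN, a `PDEStatio` dynamic loss and a `Params` object built elsewhere, so treat this as an assumption about usage rather than documented API.

from jinns.loss import LossPDEStatio

loss = LossPDEStatio(
    u=u,                           # placeholder: an AbstractPINN
    dynamic_loss=my_dynamic_loss,  # placeholder: a PDEStatio, or None
    params=init_params,            # lets the default DerivativeKeysPDEStatio be built
)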
@@ -426,23 +542,18 @@ class LossPDEStatio(_LossPDEAbstract):
         return batch.domain_batch
 
     def _get_normalization_loss_batch(
-        self, _
-    ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
-        return (self.norm_samples,)  # type: ignore -> cannot narrow a class attr
-
-    # we could have used typing.cast though
-
-    def _get_observations_loss_batch(
         self, batch: PDEStatioBatch
-    ) -> Float[Array, "…
-        return…
+    ) -> tuple[Float[Array, " nb_norm_samples dimension"] | None,]:
+        return (self.norm_samples,)
 
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
+    # we could have used typing.cast though
 
     def evaluate_by_terms(
         self, params: Params[Array], batch: PDEStatioBatch
-    ) -> tuple[…
+    ) -> tuple[
+        PDEStatioComponents[Float[Array, ""] | None],
+        PDEStatioComponents[Float[Array, ""] | None],
+    ]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -464,69 +575,23 @@ class LossPDEStatio(_LossPDEAbstract):
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
             # update eq_params with the batches of generated params
-            params =…
+            params = update_eq_params(params, batch.param_batch_dict)
 
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
         # dynamic part
-        if self.dynamic_loss is not None:
-            dyn_loss_fun = lambda p: dynamic_loss_apply(
-                self.dynamic_loss.evaluate,  # type: ignore
-                self.u,
-                self._get_dynamic_loss_batch(batch),
-                _set_derivatives(p, self.derivative_keys.dyn_loss),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
-            )
-        else:
-            dyn_loss_fun = None
+        dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
 
         # normalization part
-        if self.norm_samples is not None:
-            norm_loss_fun = lambda p: normalization_loss_apply(
-                self.u,
-                self._get_normalization_loss_batch(batch),
-                _set_derivatives(p, self.derivative_keys.norm_loss),  # type: ignore
-                vmap_in_axes_params,
-                self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
-            )
-        else:
-            norm_loss_fun = None
+        norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
 
         # boundary part
-        if (
-            self.omega_boundary_condition is not None
-            and self.omega_boundary_dim is not None
-            and self.omega_boundary_fun is not None
-        ):  # pyright cannot narrow down the three None otherwise as it is class attribute
-            boundary_loss_fun = lambda p: boundary_condition_apply(
-                self.u,
-                batch,
-                _set_derivatives(p, self.derivative_keys.boundary_loss),  # type: ignore
-                self.omega_boundary_fun,  # type: ignore
-                self.omega_boundary_condition,  # type: ignore
-                self.omega_boundary_dim,  # type: ignore
-            )
-        else:
-            boundary_loss_fun = None
+        boundary_loss_fun = self._get_boundary_loss_fun(batch)
 
         # Observation mse
-        if batch.obs_batch_dict is not None:
-…
-                params, batch.obs_batch_dict["eq_params"]
-            )
-
-            obs_loss_fun = lambda po: observations_loss_apply(
-                self.u,
-                self._get_observations_loss_batch(batch),
-                _set_derivatives(po, self.derivative_keys.observations),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
-                batch.obs_batch_dict["val"],
-                self.obs_slice,
-            )
-        else:
-            params_obs = None
-            obs_loss_fun = None
+        params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
+            batch, vmap_in_axes_params, params
+        )
 
         # get the unweighted mses for each loss term as well as the gradients
         all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
@@ -538,47 +603,34 @@ class LossPDEStatio(_LossPDEAbstract):
             params, params, params, params_obs
         )
         mses_grads = jax.tree.map(
-…
+            self.get_gradients,
             all_funs,
             all_params,
             is_leaf=lambda x: x is None,
         )
         mses = jax.tree.map(
-            lambda leaf: leaf[0],…
+            lambda leaf: leaf[0],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
         )
         grads = jax.tree.map(
-            lambda leaf: leaf[1],…
+            lambda leaf: leaf[1],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
         )
 
         return mses, grads
 
-    def evaluate(
-        self, params: Params[Array], batch: PDEStatioBatch
-    ) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-
-        We retrieve the total value itself and a PyTree with loss values for each term
-
-        Parameters
-        ---------
-        params
-            Parameters at which the loss is evaluated
-        batch
-            Composed of a batch of points in the
-            domain, a batch of points in the domain
-            border and an optional additional batch of parameters (eg. for
-            metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        loss_terms, _ = self.evaluate_by_terms(params, batch)
-
-        loss_val = self.ponderate_and_sum_loss(loss_terms)
-
-        return loss_val, loss_terms
-
 
-class LossPDENonStatio(LossPDEStatio):
+class LossPDENonStatio(
+    _LossPDEAbstract[
+        LossWeightsPDENonStatio,
+        PDENonStatioBatch,
+        PDENonStatioComponents[Array | None],
+        DerivativeKeysPDENonStatio,
+        PDENonStatio | None,
+    ]
+):
     r"""Loss object for a stationary partial differential equation
 
     $$
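The `jax.tree.map` calls above rely on an idiom worth spelling out: map a value-and-gradient function over a tree whose leaves are callables or None, then split the resulting `(value, grad)` tuples by declaring tuples themselves to be leaves. A self-contained sketch with made-up data (not jinns code):

import jax
import jax.numpy as jnp

funs = {"dyn": lambda p: (p**2).sum(), "obs": None}  # None marks an absent term
params = {"dyn": jnp.arange(3.0), "obs": None}

def value_and_grad_or_none(f, p):
    # mirrors the role of self.get_gradients in the diff above
    return jax.value_and_grad(f)(p) if f is not None else (None, None)

vals_grads = jax.tree.map(
    value_and_grad_or_none, funs, params, is_leaf=lambda x: x is None
)
vals = jax.tree.map(
    lambda leaf: leaf[0], vals_grads, is_leaf=lambda x: isinstance(x, tuple)
)
grads = jax.tree.map(
    lambda leaf: leaf[1], vals_grads, is_leaf=lambda x: isinstance(x, tuple)
)
print(vals)  # {'dyn': Array(5., ...), 'obs': None}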
@@ -602,7 +654,7 @@ class LossPDENonStatio(LossPDEStatio):
         `dynamic_loss.evaluate(t, x, u, params)`.
         Can be None in order to access only some part of the evaluate call
         results.
-    key :…
+    key : PRNGKeyArray
         A JAX PRNG Key for the loss class treated as an attribute. Default is
         None. This field is provided for future developments and additional
         losses that might need some randomness. Note that special care must be
@@ -615,14 +667,31 @@ class LossPDENonStatio(LossPDEStatio):
         observations if any.
         Can be updated according to a specific algorithm. See
         `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
-        Default is None meaning no update for loss weights. Otherwise a string
     derivative_keys : DerivativeKeysPDENonStatio, default=None
         Specify which field of `params` should be differentiated for each
         composant of the total loss. Particularily useful for inverse problems.
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
+    initial_condition_fun : Callable, default=None
+        A function representing the initial condition at `t0`. If None
+        (default) then no initial condition is applied.
+    t0 : float | Float[Array, " 1"], default=None
+        The time at which to apply the initial condition. If None, the time
+        is set to `0` by default.
+    max_norm_time_slices : int, default=100
+        The maximum number of time points in the Cartesian product with the
+        omega points to create the set of collocation points upon which the
+        normalization constant is computed.
+    max_norm_samples_omega : int, default=1000
+        The maximum number of omega points in the Cartesian product with the
+        time points to create the set of collocation points upon which the
+        normalization constant is computed.
+    params : InitVar[Params[Array]], default=None
+        The main `Params` object of the problem needed to instanciate the
+        `DerivativeKeysODE` if the latter is not specified.
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
     omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
@@ -659,68 +728,81 @@ class LossPDENonStatio(LossPDEStatio):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    t0 : float | Float[Array, " 1"], default=None
-        The time at which to apply the initial condition. If None, the time
-        is set to `0` by default.
-    initial_condition_fun : Callable, default=None
-        A function representing the initial condition at `t0`. If None
-        (default) then no initial condition is applied.
-    params : InitVar[Params[Array]], default=None
-        The main `Params` object of the problem needed to instanciate the
-        `DerivativeKeysODE` if the latter is not specified.
 
     """
 
+    u: AbstractPINN
     dynamic_loss: PDENonStatio | None
-…
+    loss_weights: LossWeightsPDENonStatio
+    derivative_keys: DerivativeKeysPDENonStatio
+    params: InitVar[Params[Array] | None]
+    t0: Float[Array, " "]
+    initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
+        eqx.field(static=True)
     )
-…
+    vmap_in_axes: tuple[int] = eqx.field(static=True)
+    max_norm_samples_omega: int = eqx.field(static=True)
+    max_norm_time_slices: int = eqx.field(static=True)
+
+    params: InitVar[Params[Array] | None]
+
+    def __init__(
+        self,
+        *,
+        u: AbstractPINN,
+        dynamic_loss: PDENonStatio | None,
+        loss_weights: LossWeightsPDENonStatio | None = None,
+        derivative_keys: DerivativeKeysPDENonStatio | None = None,
+        initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
+        | None = None,
+        t0: int | float | Float[Array, " "] | None = None,
+        max_norm_time_slices: int = 100,
+        max_norm_samples_omega: int = 1000,
+        params: Params[Array] | None = None,
+        **kwargs: Any,
+    ):
+        self.u = u
+        if loss_weights is None:
+            self.loss_weights = LossWeightsPDENonStatio()
+        else:
+            self.loss_weights = loss_weights
+        self.dynamic_loss = dynamic_loss
 
-…
+        super().__init__(
+            **kwargs,
+        )
 
-…
+        if derivative_keys is None:
+            # be default we only take gradient wrt nn_params
+            try:
+                self.derivative_keys = DerivativeKeysPDENonStatio(params=params)
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at derivative_keys initialization "
+                    f"received {derivative_keys=} and {params=}"
+                ) from exc
+        else:
+            self.derivative_keys = derivative_keys
 
         self.vmap_in_axes = (0,)  # for t_x
 
-        if…
+        if initial_condition_fun is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g by. hardcoding it into the PINN output)."
             )
         # some checks for t0
-        if…
-            if not self.t0.shape:  # e.g. user input: jnp.array(0.)
-                self.t0 = jnp.array([self.t0])
-            elif self.t0.shape != (1,):
-                raise ValueError(
-                    f"Wrong self.t0 input. It should be"
-                    f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
-                )
-        elif isinstance(self.t0, float):  # e.g. user input: 0.
-            self.t0 = jnp.array([self.t0])
-        elif isinstance(self.t0, int):  # e.g. user input: 0
-            self.t0 = jnp.array([float(self.t0)])
-        elif self.t0 is None:
+        if t0 is None:
             self.t0 = jnp.array([0])
         else:
-…
+            self.t0 = initial_condition_check(t0, dim_size=1)
 
-…
+        self.initial_condition_fun = initial_condition_fun
+
+        # with the variables below we avoid memory overflow since a cartesian
         # product is taken
-        self.…
-        self.…
+        self.max_norm_time_slices = max_norm_time_slices
+        self.max_norm_samples_omega = max_norm_samples_omega
 
     def _get_dynamic_loss_batch(
         self, batch: PDENonStatioBatch
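The removed `t0` branches above were replaced by a single `initial_condition_check(t0, dim_size=1)` call (imported from `jinns.loss._loss_utils` in the first hunk). A rough standalone equivalent of the old normalization logic, using a hypothetical helper name; the real `initial_condition_check` may behave differently:

import jax.numpy as jnp

def as_shape_one_array(t0):
    # accept 0, 0.0, jnp.array(0.0) or an array of shape (1,)
    t0 = jnp.asarray(t0, dtype=float)
    if t0.ndim == 0:
        return t0[None]  # promote scalar to shape (1,)
    if t0.shape == (1,):
        return t0
    raise ValueError(f"t0 should be a float or an array of shape (1,), got {t0.shape}")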
@@ -733,18 +815,10 @@ class LossPDENonStatio(LossPDEStatio):
         Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
     ]:
         return (
-            batch.domain_batch[: self.…
-            self.norm_samples[: self.…
+            batch.domain_batch[: self.max_norm_time_slices, 0:1],
+            self.norm_samples[: self.max_norm_samples_omega],  # type: ignore -> cannot narrow a class attr
         )
 
-    def _get_observations_loss_batch(
-        self, batch: PDENonStatioBatch
-    ) -> Float[Array, " batch_size 1+dim"]:
-        return batch.obs_batch_dict["pinn_in"]
-
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
-
     def evaluate_by_terms(
         self, params: Params[Array], batch: PDENonStatioBatch
     ) -> tuple[
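`_get_normalization_loss_batch` above caps the time and omega batches before they are combined, and the docstring explains why: the normalization constant is computed on the Cartesian product of the two sets, so the evaluation count is their product. A sketch with assumed shapes (not jinns internals):

import jax.numpy as jnp

times = jnp.linspace(0.0, 1.0, 100)[:, None]  # capped by max_norm_time_slices
omega = jnp.zeros((1000, 2))                  # capped by max_norm_samples_omega

# Cartesian product: 100 * 1000 = 100_000 points to evaluate the PINN on,
# so uncapped batches could easily exhaust memory.
t_grid = jnp.repeat(times, omega.shape[0], axis=0)  # (100000, 1)
x_grid = jnp.tile(omega, (times.shape[0], 1))       # (100000, 2)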
@@ -766,76 +840,75 @@ class LossPDENonStatio(LossPDEStatio):
             metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
         """
-…
-        assert…
+        omega_initial_batch = batch.initial_batch
+        assert omega_initial_batch is not None
 
         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
             # update eq_params with the batches of generated params
-            params =…
+            params = update_eq_params(params, batch.param_batch_dict)
 
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
-        #…
+        # dynamic part
+        dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
+
+        # normalization part
+        norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
+
+        # boundary part
+        boundary_loss_fun = self._get_boundary_loss_fun(batch)
+
+        # Observation mse
+        params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
+            batch, vmap_in_axes_params, params
+        )
 
         # initial condition
         if self.initial_condition_fun is not None:
-            mse_initial_condition_fun…
-…
-                mse_initial_condition_fun, params
+            mse_initial_condition_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: initial_condition_apply(
+                    self.u,
+                    omega_initial_batch,
+                    _set_derivatives(p, self.derivative_keys.initial_condition),
+                    (0,) + vmap_in_axes_params,
+                    self.initial_condition_fun,  # type: ignore
+                    self.t0,
+                )
             )
         else:
-…
-            grad_initial_condition = None
-
-        mses = PDENonStatioComponents(
-            partial_mses.dyn_loss,
-            partial_mses.norm_loss,
-            partial_mses.boundary_loss,
-            partial_mses.observations,
-            mse_initial_condition,
-        )
+            mse_initial_condition_fun = None
 
-…
+        # get the unweighted mses for each loss term as well as the gradients
+        all_funs: PDENonStatioComponents[Callable[[Params[Array]], Array] | None] = (
+            PDENonStatioComponents(
+                dyn_loss_fun,
+                norm_loss_fun,
+                boundary_loss_fun,
+                obs_loss_fun,
+                mse_initial_condition_fun,
+            )
+        )
+        all_params: PDENonStatioComponents[Params[Array] | None] = (
+            PDENonStatioComponents(params, params, params, params_obs, params)
+        )
+        mses_grads = jax.tree.map(
+            self.get_gradients,
+            all_funs,
+            all_params,
+            is_leaf=lambda x: x is None,
+        )
+        mses = jax.tree.map(
+            lambda leaf: leaf[0],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
+        )
+        grads = jax.tree.map(
+            lambda leaf: leaf[1],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
        )
 
         return mses, grads
-
-    def evaluate(
-        self, params: Params[Array], batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-        We retrieve the total value itself and a PyTree with loss values for each term
-
-        Parameters
-        ---------
-        params
-            Parameters at which the loss is evaluated
-        batch
-            Composed of a batch of points in
-            the domain, a batch of points in the domain
-            border, a batch of time points and an optional additional batch
-            of parameters (eg. for metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        return super().evaluate(params, batch)  # type: ignore
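With the per-class `evaluate` and `__call__` overrides removed, both loss classes now expose `evaluate_by_terms`, and the weighted total is presumably assembled by the shared machinery in `_abstract_loss.py` (see its entry in the file list above). Based solely on the signature shown in this diff, a call would look like the following hypothetical sketch, where `loss`, `params` and `batch` are built elsewhere:

# mses: one unweighted MSE (or None) per loss component
# grads: the matching per-component gradients with respect to params
mses, grads = loss.evaluate_by_terms(params, batch)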