jinns 1.5.1-py3-none-any.whl → 1.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +55 -34
- jinns/data/_CubicMeshPDEStatio.py +63 -35
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +139 -184
- jinns/loss/_LossPDE.py +440 -358
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +50 -25
- jinns/solver/_solve.py +3 -3
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.1.dist-info/RECORD +0 -55
- {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
@@ -8,13 +8,13 @@ from __future__ import (
 
 import abc
 from dataclasses import InitVar
-from typing import TYPE_CHECKING, Callable, TypedDict
+from typing import TYPE_CHECKING, Callable, cast, Any, TypeVar, Generic
 from types import EllipsisType
 import warnings
 import jax
 import jax.numpy as jnp
 import equinox as eqx
-from jaxtyping import Float, Array
+from jaxtyping import PRNGKeyArray, Float, Array
 from jinns.loss._loss_utils import (
     dynamic_loss_apply,
     boundary_condition_apply,
@@ -25,7 +25,7 @@ from jinns.loss._loss_utils import (
 )
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
-
+    update_eq_params,
 )
 from jinns.parameters._derivative_keys import (
     _set_derivatives,
@@ -40,24 +40,14 @@ from jinns.loss._loss_weights import (
 )
 from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
 from jinns.parameters._params import Params
+from jinns.loss import PDENonStatio, PDEStatio
 
 
 if TYPE_CHECKING:
     # imports for type hints only
     from jinns.nn._abstract_pinn import AbstractPINN
-    from jinns.loss import PDENonStatio, PDEStatio
     from jinns.utils._types import BoundaryConditionFun
 
-class LossDictPDEStatio(TypedDict):
-    dyn_loss: Float[Array, " "]
-    norm_loss: Float[Array, " "]
-    boundary_loss: Float[Array, " "]
-    observations: Float[Array, " "]
-
-class LossDictPDENonStatio(LossDictPDEStatio):
-    initial_condition: Float[Array, " "]
-
-
 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "dirichlet",
     "von neumann",
@@ -65,7 +55,24 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
 ]
 
 
-
+# For the same reason that we have the TypeVar in _abstract_loss.py, we have them
+# here, because _LossPDEAbstract is abtract and we cannot decide for several
+# types between their statio and non-statio version.
+# Assigning the type where it can be decide seems a better practice than
+# assigning a type at a higher level depending on a child class type. This is
+# why we now assign LossWeights and DerivativeKeys in the child class where
+# they really can be decided.
+
+L = TypeVar("L", bound=LossWeightsPDEStatio | LossWeightsPDENonStatio)
+B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
+C = TypeVar(
+    "C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
+)
+D = TypeVar("D", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
+Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)
+
+
+class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
     r"""
     Parameters
     ----------
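
Note on the generics above: each concrete subclass pins all five type parameters at once, so a type checker can tie weights, batches, components, derivative keys and dynamic loss to the same statio or non-statio flavor. A minimal sketch of the pattern with made-up names (not the actual jinns classes):

    from typing import Generic, TypeVar

    import equinox as eqx

    W = TypeVar("W")  # loss-weights type
    B = TypeVar("B")  # batch type


    class AbstractLossSketch(eqx.Module, Generic[W, B]):
        loss_weights: W

        def evaluate(self, batch: B) -> None: ...


    class StatioWeights(eqx.Module):
        pass


    class StatioBatch(eqx.Module):
        pass


    # The subclass fixes every parameter, so feeding a non-statio batch to a
    # statio loss becomes a type error instead of a runtime surprise.
    class StatioLossSketch(AbstractLossSketch[StatioWeights, StatioBatch]):
        loss_weights: StatioWeights
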
@@ -78,17 +85,11 @@ class _LossPDEAbstract(AbstractLoss):
         `update_weight_method`
     update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
         Default is None meaning no update for loss weights. Otherwise a string
-    derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
-        Specify which field of `params` should be differentiated for each
-        composant of the total loss. Particularily useful for inverse problems.
-        Fields can be "nn_params", "eq_params" or "both". Those that should not
-        be updated will have a `jax.lax.stop_gradient` called on them. Default
-        is `"nn_params"` for each composant of the loss.
     omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str | dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str | None], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -99,7 +100,7 @@ class _LossPDEAbstract(AbstractLoss):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | dict[str, slice], default=None
+    omega_boundary_dim : int | slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
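
The widened `int | slice` type is normalized in the constructor (next hunk): an integer dimension `d` is rewritten as the slice `d:d+1` so that indexing keeps the axis instead of collapsing it. A quick illustration with made-up shapes:

    import jax.numpy as jnp

    out = jnp.ones((8, 3))    # e.g. (batch, PINN output dimension)
    print(out[:, 1].shape)    # (8,)   integer indexing drops the axis
    print(out[:, 1:2].shape)  # (8, 1) the equivalent slice keeps it
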
@@ -121,197 +122,286 @@ class _LossPDEAbstract(AbstractLoss):
     obs_slice : EllipsisType | slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-
-
-
+    key : Key | None
+        A JAX PRNG Key for the loss class treated as an attribute. Default is
+        None. This field is provided for future developments and additional
+        losses that might need some randomness. Note that special care must be
+        taken when splitting the key because in-place updates are forbidden in
+        eqx.Modules.
     """
 
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
-
-
-        eqx.field(kw_only=True, default=None)
-    )
-    loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
-        kw_only=True, default=None
-    )
+    u: eqx.AbstractVar[AbstractPINN]
+    dynamic_loss: eqx.AbstractVar[Y]
     omega_boundary_fun: (
         BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
-    ) = eqx.field(
-    omega_boundary_condition: str | dict[str, str] | None = eqx.field(
-
+    ) = eqx.field(static=True)
+    omega_boundary_condition: str | dict[str, str | None] | None = eqx.field(
+        static=True
     )
-    omega_boundary_dim: slice | dict[str, slice]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.derivative_keys = (
-                DerivativeKeysPDENonStatio(params=params)
-                if isinstance(self, LossPDENonStatio)
-                else DerivativeKeysPDEStatio(params=params)
-            )
-        except ValueError as exc:
-            raise ValueError(
-                "Problem at self.derivative_keys initialization "
-                f"received {self.derivative_keys=} and {params=}"
-            ) from exc
-
-        if self.loss_weights is None:
-            self.loss_weights = (
-                LossWeightsPDENonStatio()
-                if isinstance(self, LossPDENonStatio)
-                else LossWeightsPDEStatio()
-            )
-
-        if self.obs_slice is None:
+    omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
+    norm_samples: Float[Array, " nb_norm_samples dimension"] | None
+    norm_weights: Float[Array, " nb_norm_samples"] | None
+    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    key: PRNGKeyArray | None
+
+    def __init__(
+        self,
+        *,
+        omega_boundary_fun: BoundaryConditionFun
+        | dict[str, BoundaryConditionFun]
+        | None = None,
+        omega_boundary_condition: str | dict[str, str | None] | None = None,
+        omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
+        norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
+        norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
+        obs_slice: EllipsisType | slice | None = None,
+        key: PRNGKeyArray | None = None,
+        **kwargs: Any,  # for arguments for super()
+    ):
+        super().__init__(loss_weights=self.loss_weights, **kwargs)
+
+        if obs_slice is None:
             self.obs_slice = jnp.s_[...]
+        else:
+            self.obs_slice = obs_slice
 
         if (
-            isinstance(
-            and not isinstance(
+            isinstance(omega_boundary_fun, dict)
+            and not isinstance(omega_boundary_condition, dict)
         ) or (
-            not isinstance(
-            and isinstance(
+            not isinstance(omega_boundary_fun, dict)
+            and isinstance(omega_boundary_condition, dict)
         ):
             raise ValueError(
-                "if one of
-                "
+                "if one of omega_boundary_fun or "
+                "omega_boundary_condition is dict, the other should be too."
             )
 
-        if
+        if omega_boundary_condition is None or omega_boundary_fun is None:
             warnings.warn(
                 "Missing boundary function or no boundary condition."
                 "Boundary function is thus ignored."
             )
         else:
-            if isinstance(
-                for _, v in
+            if isinstance(omega_boundary_condition, dict):
+                for _, v in omega_boundary_condition.items():
                     if v is not None and not any(
                         v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                     ):
                         raise NotImplementedError(
-                            f"The boundary condition {
+                            f"The boundary condition {omega_boundary_condition} is not"
                             f"implemented yet. Try one of :"
                             f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                         )
             else:
                 if not any(
-
+                    omega_boundary_condition.lower() in s
                     for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
                 ):
                     raise NotImplementedError(
-                        f"The boundary condition {
+                        f"The boundary condition {omega_boundary_condition} is not"
                         f"implemented yet. Try one of :"
                         f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
                     )
-
-
+            if isinstance(omega_boundary_fun, dict) and isinstance(
+                omega_boundary_condition, dict
+            ):
+                keys_omega_boundary_fun = cast(str, omega_boundary_fun.keys())
+                if (
+                    not (
+                        list(keys_omega_boundary_fun) == ["xmin", "xmax"]
+                        and list(omega_boundary_condition.keys()) == ["xmin", "xmax"]
+                    )
+                ) and (
+                    not (
+                        list(keys_omega_boundary_fun)
+                        == ["xmin", "xmax", "ymin", "ymax"]
+                        and list(omega_boundary_condition.keys())
+                        == ["xmin", "xmax", "ymin", "ymax"]
+                    )
                 ):
-
-
-
-
-                        == ["xmin", "xmax"]
-                    )
-                ) or (
-                    not (
-                        list(self.omega_boundary_fun.keys())
-                        == ["xmin", "xmax", "ymin", "ymax"]
-                        and list(self.omega_boundary_condition.keys())
-                        == ["xmin", "xmax", "ymin", "ymax"]
-                    )
-                ):
-                    raise ValueError(
-                        "The key order (facet order) in the "
-                        "boundary condition dictionaries is incorrect"
-                    )
+                    raise ValueError(
+                        "The key order (facet order) in the "
+                        "boundary condition dictionaries is incorrect"
+                    )
 
-
-
+        self.omega_boundary_fun = omega_boundary_fun
+        self.omega_boundary_condition = omega_boundary_condition
+
+        if isinstance(omega_boundary_fun, dict):
+            keys_omega_boundary_fun: str = cast(str, omega_boundary_fun.keys())
+            if omega_boundary_dim is None:
+                self.omega_boundary_dim = {
+                    k: jnp.s_[::] for k in keys_omega_boundary_fun
+                }
+            if not isinstance(omega_boundary_dim, dict):
                 raise ValueError(
                     "If omega_boundary_fun is a dict then"
                     " omega_boundary_dim should also be a dict"
                 )
-            if
-                self.omega_boundary_dim = {
-                    k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
-                }
-            if list(self.omega_boundary_dim.keys()) != list(
-                self.omega_boundary_fun.keys()
-            ):
+            if list(omega_boundary_dim.keys()) != list(keys_omega_boundary_fun):
                 raise ValueError(
                     "If omega_boundary_fun is a dict,"
                     " omega_boundary_dim should be a dict with the same keys"
                 )
-
+            self.omega_boundary_dim = {}
+            for k, v in omega_boundary_dim.items():
                 if isinstance(v, int):
                     # rewrite it as a slice to ensure that axis does not disappear when
                     # indexing
                     self.omega_boundary_dim[k] = jnp.s_[v : v + 1]
+                else:
+                    self.omega_boundary_dim[k] = v
 
         else:
-
+            assert not isinstance(omega_boundary_dim, dict)
+            if omega_boundary_dim is None:
                 self.omega_boundary_dim = jnp.s_[::]
-
+            elif isinstance(omega_boundary_dim, int):
                 # rewrite it as a slice to ensure that axis does not disappear when
                 # indexing
                 self.omega_boundary_dim = jnp.s_[
-
+                    omega_boundary_dim : omega_boundary_dim + 1
                 ]
-
-
+            else:
+                assert isinstance(omega_boundary_dim, slice)
+                self.omega_boundary_dim = omega_boundary_dim
 
-        if
-
+        if norm_samples is not None:
+            self.norm_samples = norm_samples
+            if norm_weights is None:
                 raise ValueError(
                     "`norm_weights` must be provided when `norm_samples` is used!"
                 )
-            if isinstance(
-                self.norm_weights =
+            if isinstance(norm_weights, (int, float)):
+                self.norm_weights = norm_weights * jnp.ones(
                     (self.norm_samples.shape[0],)
                 )
-
-
+            else:
+                assert isinstance(norm_weights, Array)
+                if not (norm_weights.shape[0] == norm_samples.shape[0]):
                     raise ValueError(
-                        "
-                        "
+                        "norm_weights and "
+                        "norm_samples must have the same leading dimension"
                     )
-
-
+                self.norm_weights = norm_weights
+        else:
+            self.norm_samples = norm_samples
+            self.norm_weights = None
+
+        self.key = key
 
     @abc.abstractmethod
-    def
+    def _get_dynamic_loss_batch(self, batch: B) -> Array:
         pass
 
     @abc.abstractmethod
-    def
-
+    def _get_normalization_loss_batch(self, batch: B) -> tuple[Array | None, ...]:
+        pass
+
+    def _get_dyn_loss_fun(
+        self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
+    ) -> Callable[[Params[Array]], Array] | None:
+        if self.dynamic_loss is not None:
+            dyn_loss_eval = self.dynamic_loss.evaluate
+            dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: dynamic_loss_apply(
+                    dyn_loss_eval,
+                    self.u,
+                    self._get_dynamic_loss_batch(batch),
+                    _set_derivatives(p, self.derivative_keys.dyn_loss),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                )
+            )
+        else:
+            dyn_loss_fun = None
+
+        return dyn_loss_fun
+
+    def _get_norm_loss_fun(
+        self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
+    ) -> Callable[[Params[Array]], Array] | None:
+        if self.norm_samples is not None:
+            norm_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: normalization_loss_apply(
+                    self.u,
+                    cast(
+                        tuple[Array, Array], self._get_normalization_loss_batch(batch)
+                    ),
+                    _set_derivatives(p, self.derivative_keys.norm_loss),
+                    vmap_in_axes_params,
+                    self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
+                )
+            )
+        else:
+            norm_loss_fun = None
+        return norm_loss_fun
+
+    def _get_boundary_loss_fun(
+        self, batch: B
+    ) -> Callable[[Params[Array]], Array] | None:
+        if (
+            self.omega_boundary_condition is not None
+            and self.omega_boundary_fun is not None
+        ):
+            boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: boundary_condition_apply(
+                    self.u,
+                    batch,
+                    _set_derivatives(p, self.derivative_keys.boundary_loss),
+                    self.omega_boundary_fun,  # type: ignore (we are in lambda)
+                    self.omega_boundary_condition,  # type: ignore
+                    self.omega_boundary_dim,  # type: ignore
+                )
+            )
+        else:
+            boundary_loss_fun = None
+
+        return boundary_loss_fun
+
+    def _get_obs_params_and_obs_loss_fun(
+        self,
+        batch: B,
+        vmap_in_axes_params: tuple[Params[int | None] | None],
         params: Params[Array],
-
-
-
+    ) -> tuple[Params[Array] | None, Callable[[Params[Array]], Array] | None]:
+        if batch.obs_batch_dict is not None:
+            # update params with the batches of observed params
+            params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
+
+            pinn_in, val = (
+                batch.obs_batch_dict["pinn_in"],
+                batch.obs_batch_dict["val"],
+            )
 
+            obs_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda po: observations_loss_apply(
+                    self.u,
+                    pinn_in,
+                    _set_derivatives(po, self.derivative_keys.observations),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                    val,
+                    self.obs_slice,
+                )
+            )
+        else:
+            params_obs = None
+            obs_loss_fun = None
+
+        return params_obs, obs_loss_fun
 
-
+
+class LossPDEStatio(
+    _LossPDEAbstract[
+        LossWeightsPDEStatio,
+        PDEStatioBatch,
+        PDEStatioComponents[Array | None],
+        DerivativeKeysPDEStatio,
+        PDEStatio | None,
+    ]
+):
     r"""Loss object for a stationary partial differential equation
 
     $$
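
The docstring of the new `key` attribute warns about splitting keys because eqx.Modules are immutable. The usual functional-update idiom (a sketch, not jinns code) replaces the stored key rather than mutating it:

    import equinox as eqx
    import jax


    class ToyLoss(eqx.Module):
        key: jax.Array


    loss = ToyLoss(key=jax.random.key(0))
    new_key, subkey = jax.random.split(loss.key)
    # in-place updates are forbidden, so build a new instance carrying new_key
    loss = eqx.tree_at(lambda m: m.key, loss, new_key)
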
@@ -326,13 +416,13 @@ class LossPDEStatio(_LossPDEAbstract):
     ----------
     u : AbstractPINN
         the PINN
-    dynamic_loss : PDEStatio
+    dynamic_loss : PDEStatio | None
         the stationary PDE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](x)$. Should implement a method
         `dynamic_loss.evaluate(x, u, params)`.
         Can be None in order to access only some part of the evaluate call
         results.
-    key :
+    key : PRNGKeyArray
         A JAX PRNG Key for the loss class treated as an attribute. Default is
         None. This field is provided for future developments and additional
         losses that might need some randomness. Note that special care must be
@@ -344,14 +434,18 @@ class LossPDEStatio(_LossPDEAbstract):
         observations if any.
         Can be updated according to a specific algorithm. See
         `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
-        Default is None meaning no update for loss weights. Otherwise a string
     derivative_keys : DerivativeKeysPDEStatio, default=None
         Specify which field of `params` should be differentiated for each
         composant of the total loss. Particularily useful for inverse problems.
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
+    params : InitVar[Params[Array]], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.
+
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
     omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
@@ -367,7 +461,7 @@ class LossPDEStatio(_LossPDEAbstract):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | dict[str, slice], default=None
+    omega_boundary_dim : int | slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
@@ -388,10 +482,6 @@ class LossPDEStatio(_LossPDEAbstract):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params[Array]], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-
 
     Raises
     ------
@@ -405,21 +495,46 @@ class LossPDEStatio(_LossPDEAbstract):
 
     u: AbstractPINN
     dynamic_loss: PDEStatio | None
-
+    loss_weights: LossWeightsPDEStatio
+    derivative_keys: DerivativeKeysPDEStatio
+    vmap_in_axes: tuple[int] = eqx.field(static=True)
+
+    params: InitVar[Params[Array] | None]
+
+    def __init__(
+        self,
+        *,
+        u: AbstractPINN,
+        dynamic_loss: PDEStatio | None,
+        loss_weights: LossWeightsPDEStatio | None = None,
+        derivative_keys: DerivativeKeysPDEStatio | None = None,
+        params: Params[Array] | None = None,
+        **kwargs: Any,
+    ):
+        self.u = u
+        if loss_weights is None:
+            self.loss_weights = LossWeightsPDEStatio()
+        else:
+            self.loss_weights = loss_weights
+        self.dynamic_loss = dynamic_loss
 
-
+        super().__init__(
+            **kwargs,
+        )
 
-
-
-
-
-
-
-
-
-
+        if derivative_keys is None:
+            # be default we only take gradient wrt nn_params
+            try:
+                self.derivative_keys = DerivativeKeysPDEStatio(params=params)
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at derivative_keys initialization "
+                    f"received {derivative_keys=} and {params=}"
+                ) from exc
+        else:
+            self.derivative_keys = derivative_keys
 
-        self.vmap_in_axes = (0,)
+        self.vmap_in_axes = (0,)
 
     def _get_dynamic_loss_batch(
         self, batch: PDEStatioBatch
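
`params` is declared as an `InitVar`: it is consumed at construction time (here to build `derivative_keys`) and is not stored as a field. A plain-dataclass sketch of the mechanism:

    from dataclasses import InitVar, dataclass, field


    @dataclass
    class Sketch:
        derived: int = field(init=False)
        params: InitVar[int | None] = None

        def __post_init__(self, params: int | None):
            # params is visible here but never becomes an attribute
            self.derived = 0 if params is None else 2 * params


    print(Sketch(params=3).derived)  # 6
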
@@ -427,23 +542,18 @@ class LossPDEStatio(_LossPDEAbstract):
         return batch.domain_batch
 
     def _get_normalization_loss_batch(
-        self, _
-    ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
-        return (self.norm_samples,)  # type: ignore -> cannot narrow a class attr
-
-    # we could have used typing.cast though
-
-    def _get_observations_loss_batch(
         self, batch: PDEStatioBatch
-    ) -> Float[Array, "
-        return
+    ) -> tuple[Float[Array, " nb_norm_samples dimension"] | None,]:
+        return (self.norm_samples,)
 
-
-        return self.evaluate(*args, **kwargs)
+    # we could have used typing.cast though
 
     def evaluate_by_terms(
         self, params: Params[Array], batch: PDEStatioBatch
-    ) -> tuple[
+    ) -> tuple[
+        PDEStatioComponents[Float[Array, ""] | None],
+        PDEStatioComponents[Float[Array, ""] | None],
+    ]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -465,69 +575,23 @@ class LossPDEStatio(_LossPDEAbstract):
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
             # update eq_params with the batches of generated params
-            params =
+            params = update_eq_params(params, batch.param_batch_dict)
 
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
         # dynamic part
-
-            dyn_loss_fun = lambda p: dynamic_loss_apply(
-                self.dynamic_loss.evaluate,  # type: ignore
-                self.u,
-                self._get_dynamic_loss_batch(batch),
-                _set_derivatives(p, self.derivative_keys.dyn_loss),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
-            )
-        else:
-            dyn_loss_fun = None
+        dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
 
         # normalization part
-
-            norm_loss_fun = lambda p: normalization_loss_apply(
-                self.u,
-                self._get_normalization_loss_batch(batch),
-                _set_derivatives(p, self.derivative_keys.norm_loss),  # type: ignore
-                vmap_in_axes_params,
-                self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
-            )
-        else:
-            norm_loss_fun = None
+        norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
 
         # boundary part
-
-            self.omega_boundary_condition is not None
-            and self.omega_boundary_dim is not None
-            and self.omega_boundary_fun is not None
-        ):  # pyright cannot narrow down the three None otherwise as it is class attribute
-            boundary_loss_fun = lambda p: boundary_condition_apply(
-                self.u,
-                batch,
-                _set_derivatives(p, self.derivative_keys.boundary_loss),  # type: ignore
-                self.omega_boundary_fun,  # type: ignore
-                self.omega_boundary_condition,  # type: ignore
-                self.omega_boundary_dim,  # type: ignore
-            )
-        else:
-            boundary_loss_fun = None
+        boundary_loss_fun = self._get_boundary_loss_fun(batch)
 
         # Observation mse
-
-
-
-                params, batch.obs_batch_dict["eq_params"]
-            )
-
-            obs_loss_fun = lambda po: observations_loss_apply(
-                self.u,
-                self._get_observations_loss_batch(batch),
-                _set_derivatives(po, self.derivative_keys.observations),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
-                batch.obs_batch_dict["val"],
-                self.obs_slice,
-            )
-        else:
-            params_obs = None
-            obs_loss_fun = None
+        params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
+            batch, vmap_in_axes_params, params
+        )
 
         # get the unweighted mses for each loss term as well as the gradients
         all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
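
`evaluate_by_terms` now delegates each term to a `_get_*_fun` helper that returns either a per-term closure or `None` when the term is absent. A toy version of that contract (names are illustrative, not the jinns API):

    from typing import Callable


    def make_term(active: bool) -> Callable[[float], float] | None:
        # factory: a closure when the term exists, None otherwise
        if active:
            return lambda p: p**2
        return None


    funs = [make_term(True), make_term(False)]
    print([f(3.0) if f is not None else None for f in funs])  # [9.0, None]
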
@@ -539,47 +603,34 @@ class LossPDEStatio(_LossPDEAbstract):
             params, params, params, params_obs
         )
         mses_grads = jax.tree.map(
-
+            self.get_gradients,
             all_funs,
             all_params,
             is_leaf=lambda x: x is None,
         )
         mses = jax.tree.map(
-            lambda leaf: leaf[0],
+            lambda leaf: leaf[0],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
         )
         grads = jax.tree.map(
-            lambda leaf: leaf[1],
+            lambda leaf: leaf[1],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
         )
 
         return mses, grads
 
-    def evaluate(
-        self, params: Params[Array], batch: PDEStatioBatch
-    ) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-
-        We retrieve the total value itself and a PyTree with loss values for each term
-
-        Parameters
-        ---------
-        params
-            Parameters at which the loss is evaluated
-        batch
-            Composed of a batch of points in the
-            domain, a batch of points in the domain
-            border and an optional additional batch of parameters (eg. for
-            metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        loss_terms, _ = self.evaluate_by_terms(params, batch)
-
-        loss_val = self.ponderate_and_sum_loss(loss_terms)
-
-        return loss_val, loss_terms
 
-
-
+class LossPDENonStatio(
+    _LossPDEAbstract[
+        LossWeightsPDENonStatio,
+        PDENonStatioBatch,
+        PDENonStatioComponents[Array | None],
+        DerivativeKeysPDENonStatio,
+        PDENonStatio | None,
+    ]
+):
     r"""Loss object for a stationary partial differential equation
 
     $$
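
The added `is_leaf=lambda x: isinstance(x, tuple)` arguments make `jax.tree.map` treat each `(mse, grad)` pair produced by `get_gradients` as a single leaf, so the two maps split one pytree of pairs into a pytree of mses and a pytree of grads. A self-contained mimic:

    import jax

    # leaves are (mse, grad) tuples; None marks an absent loss term and is
    # skipped by jax.tree.map by default
    mses_grads = {"dyn_loss": (0.5, {"w": 1.0}), "norm_loss": None}
    mses = jax.tree.map(
        lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
    )
    grads = jax.tree.map(
        lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
    )
    print(mses)   # {'dyn_loss': 0.5, 'norm_loss': None}
    print(grads)  # {'dyn_loss': {'w': 1.0}, 'norm_loss': None}
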
@@ -603,7 +654,7 @@ class LossPDENonStatio(LossPDEStatio):
         `dynamic_loss.evaluate(t, x, u, params)`.
         Can be None in order to access only some part of the evaluate call
         results.
-    key :
+    key : PRNGKeyArray
         A JAX PRNG Key for the loss class treated as an attribute. Default is
         None. This field is provided for future developments and additional
         losses that might need some randomness. Note that special care must be
@@ -616,14 +667,31 @@ class LossPDENonStatio(LossPDEStatio):
         observations if any.
         Can be updated according to a specific algorithm. See
         `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
-        Default is None meaning no update for loss weights. Otherwise a string
     derivative_keys : DerivativeKeysPDENonStatio, default=None
         Specify which field of `params` should be differentiated for each
         composant of the total loss. Particularily useful for inverse problems.
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
+    initial_condition_fun : Callable, default=None
+        A function representing the initial condition at `t0`. If None
+        (default) then no initial condition is applied.
+    t0 : float | Float[Array, " 1"], default=None
+        The time at which to apply the initial condition. If None, the time
+        is set to `0` by default.
+    max_norm_time_slices : int, default=100
+        The maximum number of time points in the Cartesian product with the
+        omega points to create the set of collocation points upon which the
+        normalization constant is computed.
+    max_norm_samples_omega : int, default=1000
+        The maximum number of omega points in the Cartesian product with the
+        time points to create the set of collocation points upon which the
+        normalization constant is computed.
+    params : InitVar[Params[Array]], default=None
+        The main `Params` object of the problem needed to instanciate the
+        `DerivativeKeysODE` if the latter is not specified.
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
     omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
@@ -660,58 +728,81 @@ class LossPDENonStatio(LossPDEStatio):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    t0 : float | Float[Array, " 1"], default=None
-        The time at which to apply the initial condition. If None, the time
-        is set to `0` by default.
-    initial_condition_fun : Callable, default=None
-        A function representing the initial condition at `t0`. If None
-        (default) then no initial condition is applied.
-    params : InitVar[Params[Array]], default=None
-        The main `Params` object of the problem needed to instanciate the
-        `DerivativeKeysODE` if the latter is not specified.
 
     """
 
+    u: AbstractPINN
     dynamic_loss: PDENonStatio | None
-
-
-
-
+    loss_weights: LossWeightsPDENonStatio
+    derivative_keys: DerivativeKeysPDENonStatio
+    params: InitVar[Params[Array] | None]
+    t0: Float[Array, " "]
+    initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
+        eqx.field(static=True)
     )
-
+    vmap_in_axes: tuple[int] = eqx.field(static=True)
+    max_norm_samples_omega: int = eqx.field(static=True)
+    max_norm_time_slices: int = eqx.field(static=True)
+
+    params: InitVar[Params[Array] | None]
+
+    def __init__(
+        self,
+        *,
+        u: AbstractPINN,
+        dynamic_loss: PDENonStatio | None,
+        loss_weights: LossWeightsPDENonStatio | None = None,
+        derivative_keys: DerivativeKeysPDENonStatio | None = None,
+        initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
+        | None = None,
+        t0: int | float | Float[Array, " "] | None = None,
+        max_norm_time_slices: int = 100,
+        max_norm_samples_omega: int = 1000,
+        params: Params[Array] | None = None,
+        **kwargs: Any,
+    ):
+        self.u = u
+        if loss_weights is None:
+            self.loss_weights = LossWeightsPDENonStatio()
+        else:
+            self.loss_weights = loss_weights
+        self.dynamic_loss = dynamic_loss
 
-
-
+        super().__init__(
+            **kwargs,
+        )
 
-
-
-
-
-
-
-
-
-
+        if derivative_keys is None:
+            # be default we only take gradient wrt nn_params
+            try:
+                self.derivative_keys = DerivativeKeysPDENonStatio(params=params)
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at derivative_keys initialization "
+                    f"received {derivative_keys=} and {params=}"
+                ) from exc
+        else:
+            self.derivative_keys = derivative_keys
 
         self.vmap_in_axes = (0,)  # for t_x
 
-        if
+        if initial_condition_fun is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g by. hardcoding it into the PINN output)."
             )
         # some checks for t0
-        t0 = self.t0
         if t0 is None:
-            t0 = jnp.array([0])
+            self.t0 = jnp.array([0])
         else:
-            t0 = initial_condition_check(t0, dim_size=1)
-
+            self.t0 = initial_condition_check(t0, dim_size=1)
+
+        self.initial_condition_fun = initial_condition_fun
 
-        #
+        # with the variables below we avoid memory overflow since a cartesian
         # product is taken
-        self.
-        self.
+        self.max_norm_time_slices = max_norm_time_slices
+        self.max_norm_samples_omega = max_norm_samples_omega
 
     def _get_dynamic_loss_batch(
         self, batch: PDENonStatioBatch
@@ -724,18 +815,10 @@ class LossPDENonStatio(LossPDEStatio):
         Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
     ]:
         return (
-            batch.domain_batch[: self.
-            self.norm_samples[: self.
+            batch.domain_batch[: self.max_norm_time_slices, 0:1],
+            self.norm_samples[: self.max_norm_samples_omega],  # type: ignore -> cannot narrow a class attr
         )
 
-    def _get_observations_loss_batch(
-        self, batch: PDENonStatioBatch
-    ) -> Float[Array, " batch_size 1+dim"]:
-        return batch.obs_batch_dict["pinn_in"]
-
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
-
     def evaluate_by_terms(
         self, params: Params[Array], batch: PDENonStatioBatch
     ) -> tuple[
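
`_get_normalization_loss_batch` caps the grid used for the normalization constant: at most `max_norm_time_slices` time points (column 0, kept two-dimensional by the `0:1` slice) crossed with at most `max_norm_samples_omega` space points. A sketch with made-up shapes:

    import jax.numpy as jnp

    domain_batch = jnp.ones((5000, 3))     # columns: (t, x, y)
    time_slices = domain_batch[:100, 0:1]  # at most 100 rows, axis kept
    print(time_slices.shape)               # (100, 1)
    # crossed with at most 1000 omega points, the normalization grid is
    # bounded by 100 * 1000 = 100_000 points, whatever the batch sizes
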
@@ -757,76 +840,75 @@ class LossPDENonStatio(LossPDEStatio):
             metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
         """
-
-        assert
+        omega_initial_batch = batch.initial_batch
+        assert omega_initial_batch is not None
 
         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
             # update eq_params with the batches of generated params
-            params =
+            params = update_eq_params(params, batch.param_batch_dict)
 
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
-        #
-
-
-
-
+        # dynamic part
+        dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
+
+        # normalization part
+        norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
+
+        # boundary part
+        boundary_loss_fun = self._get_boundary_loss_fun(batch)
+
+        # Observation mse
+        params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
+            batch, vmap_in_axes_params, params
+        )
 
         # initial condition
         if self.initial_condition_fun is not None:
-            mse_initial_condition_fun
-
-
-
-
-
-
-
-                mse_initial_condition_fun, params
+            mse_initial_condition_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: initial_condition_apply(
+                    self.u,
+                    omega_initial_batch,
+                    _set_derivatives(p, self.derivative_keys.initial_condition),
+                    (0,) + vmap_in_axes_params,
+                    self.initial_condition_fun,  # type: ignore
+                    self.t0,
+                )
             )
         else:
-
-            grad_initial_condition = None
-
-        mses = PDENonStatioComponents(
-            partial_mses.dyn_loss,
-            partial_mses.norm_loss,
-            partial_mses.boundary_loss,
-            partial_mses.observations,
-            mse_initial_condition,
-        )
+            mse_initial_condition_fun = None
 
-
-
-
-
-
-
+        # get the unweighted mses for each loss term as well as the gradients
+        all_funs: PDENonStatioComponents[Callable[[Params[Array]], Array] | None] = (
+            PDENonStatioComponents(
+                dyn_loss_fun,
+                norm_loss_fun,
+                boundary_loss_fun,
+                obs_loss_fun,
+                mse_initial_condition_fun,
+            )
+        )
+        all_params: PDENonStatioComponents[Params[Array] | None] = (
+            PDENonStatioComponents(params, params, params, params_obs, params)
+        )
+        mses_grads = jax.tree.map(
+            self.get_gradients,
+            all_funs,
+            all_params,
+            is_leaf=lambda x: x is None,
+        )
+        mses = jax.tree.map(
+            lambda leaf: leaf[0],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
+        )
+        grads = jax.tree.map(
+            lambda leaf: leaf[1],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
         )
 
         return mses, grads
-
-    def evaluate(
-        self, params: Params[Array], batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-        We retrieve the total value itself and a PyTree with loss values for each term
-
-
-        Parameters
-        ---------
-        params
-            Parameters at which the loss is evaluated
-        batch
-            Composed of a batch of points in
-            the domain, a batch of points in the domain
-            border, a batch of time points and an optional additional batch
-            of parameters (eg. for metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        return super().evaluate(params, batch)  # type: ignore