jinns-1.3.0-py3-none-any.whl → jinns-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
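The monolithic `jinns/data/_DataGenerators.py` (−1634 lines) is split into per-class modules (`_DataGeneratorODE.py`, `_CubicMeshPDEStatio.py`, etc.). A minimal sketch of what this should mean for user code, assuming the updated `jinns/data/__init__.py` listed above keeps re-exporting the generators under their public names (not verified against the wheel):

```python
# Hypothetical: public imports should be unaffected by the internal split.
from jinns.data import DataGeneratorODE        # lived in _DataGenerators.py in 1.3.0
from jinns.data import CubicMeshPDEStatio      # lives in _CubicMeshPDEStatio.py in 1.4.0
from jinns.data import CubicMeshPDENonStatio   # lives in _CubicMeshPDENonStatio.py in 1.4.0
```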
jinns/loss/_LossPDE.py
CHANGED
@@ -1,16 +1,16 @@
-# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a PDE loss in jinns
 """
+
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant

 import abc
-from dataclasses import InitVar
-from typing import TYPE_CHECKING,
+from dataclasses import InitVar
+from typing import TYPE_CHECKING, Callable, TypedDict
+from types import EllipsisType
 import warnings
-import jax
 import jax.numpy as jnp
 import equinox as eqx
 from jaxtyping import Float, Array, Key, Int
@@ -20,9 +20,7 @@ from jinns.loss._loss_utils import (
     normalization_loss_apply,
     observations_loss_apply,
     initial_condition_apply,
-    constraints_system_loss_apply,
 )
-from jinns.data._DataGenerators import append_obs_batch
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
     _update_eq_params_dict,
@@ -32,19 +30,30 @@ from jinns.parameters._derivative_keys import (
     DerivativeKeysPDEStatio,
     DerivativeKeysPDENonStatio,
 )
+from jinns.loss._abstract_loss import AbstractLoss
 from jinns.loss._loss_weights import (
     LossWeightsPDEStatio,
     LossWeightsPDENonStatio,
-    LossWeightsPDEDict,
 )
-from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
-from jinns.nn._pinn import PINN
-from jinns.nn._spinn import SPINN
 from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch


 if TYPE_CHECKING:
-
+    # imports for type hints only
+    from jinns.parameters._params import Params
+    from jinns.nn._abstract_pinn import AbstractPINN
+    from jinns.loss import PDENonStatio, PDEStatio
+    from jinns.utils._types import BoundaryConditionFun
+
+
+class LossDictPDEStatio(TypedDict):
+    dyn_loss: Float[Array, " "]
+    norm_loss: Float[Array, " "]
+    boundary_loss: Float[Array, " "]
+    observations: Float[Array, " "]
+
+
+class LossDictPDENonStatio(LossDictPDEStatio):
+    initial_condition: Float[Array, " "]
+

 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "dirichlet",
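The two `TypedDict`s above give the per-term loss dictionary a checkable shape. A small sketch of what this buys a static type checker (the `report` helper is illustrative, not part of the package):

```python
def report(terms: LossDictPDEStatio) -> None:
    # a checker now knows exactly which keys exist on a stationary loss dict
    print(float(terms["dyn_loss"]), float(terms["boundary_loss"]))
    # terms["initial_condition"] would be flagged here: that key is only
    # declared on LossDictPDENonStatio
```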
@@ -53,8 +62,8 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
 ]


-class _LossPDEAbstract(eqx.Module):
-    """
+class _LossPDEAbstract(AbstractLoss):
+    r"""
     Parameters
     ----------

@@ -69,11 +78,11 @@ class _LossPDEAbstract(eqx.Module):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun :
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str |
+    omega_boundary_condition : str | dict[str, str], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -84,28 +93,29 @@ class _LossPDEAbstract(eqx.Module):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice |
+    omega_boundary_dim : slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
         will be forced to match the boundary condition.
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
         Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
         The importance sampling weights for Monte-Carlo integration of the
         normalization constant. Must be provided if `norm_samples` is provided.
-        `norm_weights` should
+        `norm_weights` should be broadcastble to
         `norm_samples`.
-        Alternatively, the user can pass a float or an integer
+        Alternatively, the user can pass a float or an integer that will be
+        made broadcastable to `norm_samples`.
         These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
         $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : slice, default=None
+    obs_slice : EllipsisType | slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
     """
@@ -119,26 +129,28 @@ class _LossPDEAbstract(eqx.Module):
     loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
         kw_only=True, default=None
     )
-    omega_boundary_fun:
-
-    )
-    omega_boundary_condition: str |
+    omega_boundary_fun: (
+        BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
+    ) = eqx.field(kw_only=True, default=None, static=True)
+    omega_boundary_condition: str | dict[str, str] | None = eqx.field(
         kw_only=True, default=None, static=True
     )
-    omega_boundary_dim: slice |
+    omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
         kw_only=True, default=None, static=True
     )
-    norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
+    norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
         kw_only=True, default=None
     )
-    norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
+    norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
         kw_only=True, default=None
     )
-    obs_slice: slice | None = eqx.field(
+    obs_slice: EllipsisType | slice | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )

-    params: InitVar[Params] = eqx.field(kw_only=True, default=None)
+    params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)

-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
         Module with eqx.tree_at
@@ -228,6 +240,11 @@ class _LossPDEAbstract(eqx.Module):
             )

         if isinstance(self.omega_boundary_fun, dict):
+            if not isinstance(self.omega_boundary_dim, dict):
+                raise ValueError(
+                    "If omega_boundary_fun is a dict then"
+                    " omega_boundary_dim should also be a dict"
+                )
             if self.omega_boundary_dim is None:
                 self.omega_boundary_dim = {
                     k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
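The new check makes the dict-valued boundary arguments consistent with each other. A hedged sketch of a conforming 2D setup (facet keys follow the `omega_boundary_condition` docstring above; the constant lambdas are placeholders for real boundary functions):

```python
import jax.numpy as jnp

omega_boundary_fun = {
    "xmin": lambda x: 0.0,
    "xmax": lambda x: 0.0,
    "ymin": lambda x: 0.0,
    "ymax": lambda x: 0.0,
}
omega_boundary_condition = {k: "dirichlet" for k in omega_boundary_fun}
# 1.4.0: since omega_boundary_fun is a dict, omega_boundary_dim must now be
# a dict as well (here, all output dimensions on every facet)
omega_boundary_dim = {k: jnp.s_[::] for k in omega_boundary_fun}
```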
@@ -262,27 +279,29 @@ class _LossPDEAbstract(eqx.Module):
                 raise ValueError(
                     "`norm_weights` must be provided when `norm_samples` is used!"
                 )
-
-
-
-
-
-
-                )
-            else:
+            if isinstance(self.norm_weights, (int, float)):
+                self.norm_weights = self.norm_weights * jnp.ones(
+                    (self.norm_samples.shape[0],)
+                )
+            if isinstance(self.norm_weights, Array):
+                if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
                     raise ValueError(
-                        "
-                        "
-                        f" got shape {self.norm_weights.shape} and"
-                        f" shape {self.norm_samples.shape}."
+                        "self.norm_weights and "
+                        "self.norm_samples must have the same leading dimension"
                     )
+            else:
+                raise ValueError("Wrong type for self.norm_weights")
+
+    @abc.abstractmethod
+    def __call__(self, *_, **__):
+        pass

     @abc.abstractmethod
     def evaluate(
         self: eqx.Module,
-        params: Params,
+        params: Params[Array],
         batch: PDEStatioBatch | PDENonStatioBatch,
-    ) -> tuple[Float,
+    ) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
         raise NotImplementedError

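The scalar branch above simply tiles the weight across the Monte-Carlo sample so that the leading dimensions agree. The same arithmetic, restated as a standalone sketch (the shapes and the uniform-proposal weight are illustrative assumptions):

```python
import jax.numpy as jnp

norm_samples = jnp.zeros((128, 2))  # 128 Monte-Carlo points in a 2D domain
norm_weights = 4.0                  # w_k = 1/q(x_k), e.g. a uniform proposal on [-1, 1]^2

# what __post_init__ now does with an int/float input:
norm_weights = norm_weights * jnp.ones((norm_samples.shape[0],))
assert norm_weights.shape == (128,)  # same leading dimension as norm_samples
```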
@@ -299,9 +318,9 @@ class LossPDEStatio(_LossPDEAbstract):

     Parameters
     ----------
-    u :
+    u : AbstractPINN
         the PINN
-    dynamic_loss :
+    dynamic_loss : PDEStatio
         the stationary PDE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](x)$. Should implement a method
         `dynamic_loss.evaluate(x, u, params)`.
@@ -324,11 +343,11 @@ class LossPDEStatio(_LossPDEAbstract):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun :
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str |
+    omega_boundary_condition : str | dict[str, str], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -339,17 +358,17 @@ class LossPDEStatio(_LossPDEAbstract):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice |
+    omega_boundary_dim : slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
         will be forced to match the boundary condition.
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
         Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
         The importance sampling weights for Monte-Carlo integration of the
         normalization constant. Must be provided if `norm_samples` is provided.
         `norm_weights` should have the same leading dimension as
@@ -360,7 +379,7 @@ class LossPDEStatio(_LossPDEAbstract):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.

@@ -375,13 +394,13 @@ class LossPDEStatio(_LossPDEAbstract):
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change

-    u:
-    dynamic_loss:
+    u: AbstractPINN
+    dynamic_loss: PDEStatio | None
     key: Key | None = eqx.field(kw_only=True, default=None)

     vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)

-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
         Module with eqx.tree_at!
@@ -395,25 +414,27 @@ class LossPDEStatio(_LossPDEAbstract):

     def _get_dynamic_loss_batch(
         self, batch: PDEStatioBatch
-    ) -> Float[Array, "batch_size dimension"]:
+    ) -> Float[Array, " batch_size dimension"]:
         return batch.domain_batch

     def _get_normalization_loss_batch(
         self, _
-    ) -> Float[Array, "nb_norm_samples dimension"]:
-        return (self.norm_samples,)
+    ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
+        return (self.norm_samples,)  # type: ignore -> cannot narrow a class attr
+
+    # we could have used typing.cast though

     def _get_observations_loss_batch(
         self, batch: PDEStatioBatch
-    ) -> Float[Array, "batch_size obs_dim"]:
-        return
+    ) -> Float[Array, " batch_size obs_dim"]:
+        return batch.obs_batch_dict["pinn_in"]

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)

     def evaluate(
-        self, params: Params, batch: PDEStatioBatch
-    ) -> tuple[Float[Array, "
+        self, params: Params[Array], batch: PDEStatioBatch
+    ) -> tuple[Float[Array, " "], LossDictPDEStatio]:
         """
         Evaluate the loss function at a batch of points for given parameters.

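`evaluate` now advertises a scalar total plus the typed per-term dictionary. A sketch of the intended call pattern (`loss`, `params` and `batch` are placeholders for objects built elsewhere, hence the commented invocation):

```python
import jax

def total_loss_fn(loss, params, batch):
    # evaluate() returns (scalar_total, per_term_dict); only the scalar
    # is differentiated, the dict is for logging
    total, terms = loss(params, batch)
    return total

# grad = jax.grad(lambda p: total_loss_fn(loss, p, batch))(params)
```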
@@ -444,9 +465,9 @@ class LossPDEStatio(_LossPDEAbstract):
                 self.dynamic_loss.evaluate,
                 self.u,
                 self._get_dynamic_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                _set_derivatives(params, self.derivative_keys.dyn_loss),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
-                self.loss_weights.dyn_loss,
+                self.loss_weights.dyn_loss,  # type: ignore
             )
         else:
             mse_dyn_loss = jnp.array(0.0)
@@ -456,24 +477,28 @@ class LossPDEStatio(_LossPDEAbstract):
             mse_norm_loss = normalization_loss_apply(
                 self.u,
                 self._get_normalization_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.norm_loss),
+                _set_derivatives(params, self.derivative_keys.norm_loss),  # type: ignore
                 vmap_in_axes_params,
-                self.norm_weights,
-                self.loss_weights.norm_loss,
+                self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
+                self.loss_weights.norm_loss,  # type: ignore
             )
         else:
             mse_norm_loss = jnp.array(0.0)

         # boundary part
-        if
+        if (
+            self.omega_boundary_condition is not None
+            and self.omega_boundary_dim is not None
+            and self.omega_boundary_fun is not None
+        ):  # pyright cannot narrow down the three None otherwise as it is class attribute
             mse_boundary_loss = boundary_condition_apply(
                 self.u,
                 batch,
-                _set_derivatives(params, self.derivative_keys.boundary_loss),
+                _set_derivatives(params, self.derivative_keys.boundary_loss),  # type: ignore
                 self.omega_boundary_fun,
                 self.omega_boundary_condition,
                 self.omega_boundary_dim,
-                self.loss_weights.boundary_loss,
+                self.loss_weights.boundary_loss,  # type: ignore
             )
         else:
             mse_boundary_loss = jnp.array(0.0)
@@ -486,10 +511,10 @@ class LossPDEStatio(_LossPDEAbstract):
             mse_observation_loss = observations_loss_apply(
                 self.u,
                 self._get_observations_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.observations),
+                _set_derivatives(params, self.derivative_keys.observations),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights.observations,
+                self.loss_weights.observations,  # type: ignore
                 self.obs_slice,
             )
         else:
@@ -505,8 +530,6 @@ class LossPDEStatio(_LossPDEAbstract):
                 "norm_loss": mse_norm_loss,
                 "boundary_loss": mse_boundary_loss,
                 "observations": mse_observation_loss,
-                "initial_condition": jnp.array(0.0),  # for compatibility in the
-                # tree_map of SystemLoss
             }
         )

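Stationary losses no longer pad their result with a zero `"initial_condition"` entry (per the removed comment above, it only existed for the `SystemLossPDE` tree_map, itself removed in the last hunk). A hedged sketch of how downstream aggregation could stay robust to the missing key (the helper name is hypothetical):

```python
import jax.numpy as jnp

def initial_condition_term(terms: dict) -> jnp.ndarray:
    # present only on non-stationary loss dicts in 1.4.0
    return terms.get("initial_condition", jnp.array(0.0))
```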
@@ -527,9 +550,9 @@ class LossPDENonStatio(LossPDEStatio):

     Parameters
     ----------
-    u :
+    u : AbstractPINN
         the PINN
-    dynamic_loss :
+    dynamic_loss : PDENonStatio
         the non stationary PDE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](t, x)$. Should implement a method
         `dynamic_loss.evaluate(t, x, u, params)`.
@@ -553,11 +576,11 @@ class LossPDENonStatio(LossPDEStatio):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun :
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str |
+    omega_boundary_condition : str | dict[str, str], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -568,17 +591,17 @@ class LossPDENonStatio(LossPDEStatio):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice |
+    omega_boundary_dim : slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
         will be forced to match the boundary condition.
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
         Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
         The importance sampling weights for Monte-Carlo integration of the
         normalization constant. Must be provided if `norm_samples` is provided.
         `norm_weights` should have the same leading dimension as
@@ -589,20 +612,25 @@ class LossPDENonStatio(LossPDEStatio):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
+    t0 : float | Float[Array, " 1"], default=None
+        The time at which to apply the initial condition. If None, the time
+        is set to `0` by default.
     initial_condition_fun : Callable, default=None
-        A function representing the
-        (default) then no initial condition is applied
-    params : InitVar[Params], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
+        A function representing the initial condition at `t0`. If None
+        (default) then no initial condition is applied.
+    params : InitVar[Params[Array]], default=None
+        The main `Params` object of the problem needed to instanciate the
+        `DerivativeKeysODE` if the latter is not specified.

     """

+    dynamic_loss: PDENonStatio | None
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
     initial_condition_fun: Callable | None = eqx.field(
         kw_only=True, default=None, static=True
     )
+    t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)

     _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
     _max_norm_time_slices: Int = eqx.field(init=False, static=True)
@@ -624,6 +652,21 @@ class LossPDENonStatio(LossPDEStatio):
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g by. hardcoding it into the PINN output)."
             )
+        # some checks for t0
+        if isinstance(self.t0, Array):
+            if not self.t0.shape:  # e.g. user input: jnp.array(0.)
+                self.t0 = jnp.array([self.t0])
+            elif self.t0.shape != (1,):
+                raise ValueError(
+                    f"Wrong self.t0 input. It should be"
+                    f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
+                )
+        elif isinstance(self.t0, float):  # e.g. user input: 0
+            self.t0 = jnp.array([self.t0])
+        elif self.t0 is None:
+            self.t0 = jnp.array([0])
+        else:
+            raise ValueError("Wrong value for t0")

         # witht the variables below we avoid memory overflow since a cartesian
         # product is taken
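All accepted `t0` inputs normalize to a shape-`(1,)` array. The branch logic restated as a standalone, runnable check (note the `float` test: an integer such as `t0=0` would fall through to the final `ValueError`, so `0.0` should be passed instead):

```python
import jax.numpy as jnp

for t0 in (0.5, jnp.array(0.5), jnp.array([0.5]), None):
    if isinstance(t0, jnp.ndarray):
        # 0-d arrays get a leading axis; shape-(1,) arrays pass through
        t0 = jnp.array([t0]) if not t0.shape else t0
    elif isinstance(t0, float):
        t0 = jnp.array([t0])
    elif t0 is None:
        t0 = jnp.array([0])
    assert t0.shape == (1,)
```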
@@ -632,28 +675,30 @@ class LossPDENonStatio(LossPDEStatio):

     def _get_dynamic_loss_batch(
         self, batch: PDENonStatioBatch
-    ) -> Float[Array, "batch_size 1+dimension"]:
+    ) -> Float[Array, " batch_size 1+dimension"]:
         return batch.domain_batch

     def _get_normalization_loss_batch(
         self, batch: PDENonStatioBatch
-    ) ->
+    ) -> tuple[
+        Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
+    ]:
         return (
             batch.domain_batch[: self._max_norm_time_slices, 0:1],
-            self.norm_samples[: self._max_norm_samples_omega],
+            self.norm_samples[: self._max_norm_samples_omega],  # type: ignore -> cannot narrow a class attr
         )

     def _get_observations_loss_batch(
         self, batch: PDENonStatioBatch
-    ) ->
-        return
+    ) -> Float[Array, " batch_size 1+dim"]:
+        return batch.obs_batch_dict["pinn_in"]

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)

     def evaluate(
-        self, params: Params, batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, "
+        self, params: Params[Array], batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, " "], LossDictPDENonStatio]:
         """
         Evaluate the loss function at a batch of points for given parameters.

@@ -670,6 +715,7 @@ class LossPDENonStatio(LossPDEStatio):
         inputs/outputs/parameters
         """
         omega_batch = batch.initial_batch
+        assert omega_batch is not None

         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
@@ -682,17 +728,19 @@ class LossPDENonStatio(LossPDEStatio):

         # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
         # mse_observation_loss we use the evaluate from parent class
-        partial_mse, partial_mse_terms = super().evaluate(params, batch)
+        partial_mse, partial_mse_terms = super().evaluate(params, batch)  # type: ignore
+        # ignore because batch is not PDEStatioBatch. We could use typing.cast though

         # initial condition
         if self.initial_condition_fun is not None:
             mse_initial_condition = initial_condition_apply(
                 self.u,
                 omega_batch,
-                _set_derivatives(params, self.derivative_keys.initial_condition),
+                _set_derivatives(params, self.derivative_keys.initial_condition),  # type: ignore
                 (0,) + vmap_in_axes_params,
                 self.initial_condition_fun,
-                self.
+                self.t0,  # type: ignore can't get the narrowing in __post_init__
+                self.loss_weights.initial_condition,  # type: ignore
             )
         else:
             mse_initial_condition = jnp.array(0.0)
@@ -704,425 +752,3 @@ class LossPDENonStatio(LossPDEStatio):
             **partial_mse_terms,
             "initial_condition": mse_initial_condition,
         }
-
-
-class SystemLossPDE(eqx.Module):
-    r"""
-    Class to implement a system of PDEs.
-    The goal is to give maximum freedom to the user. The class is created with
-    a dict of dynamic loss, and dictionaries of all the objects that are used
-    in LossPDENonStatio and LossPDEStatio. When then iterate
-    over the dynamic losses that compose the system. All the PINNs with all the
-    parameter dictionaries are passed as arguments to each dynamic loss
-    evaluate functions; it is inside the dynamic loss that specification are
-    performed.
-
-    **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
-    Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
-    solution.
-
-    Parameters
-    ----------
-    u_dict : Dict[str, eqx.Module]
-        dict of PINNs
-    loss_weights : LossWeightsPDEDict
-        A dictionary of LossWeightsODE
-    derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
-        A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
-        specifying what field of `params`
-        should be used during gradient computations for each of the terms of
-        the total loss, for each of the loss in the system. Default is
-        `"nn_params`" everywhere.
-    dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
-        A dict of dynamic part of the loss, basically the differential
-        operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
-    key_dict : Dict[str, Key], default=None
-        A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
-        match that of u_dict. See LossPDEStatio or LossPDENonStatio for
-        more details.
-    omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
-        A dict of of function or of dict of functions or of None
-        (see doc for `omega_boundary_fun` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`.
-    omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
-        A dict of strings or of dict of strings or of None
-        (see doc for `omega_boundary_condition_dict` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
-        A dict of slices or of dict of slices or of None
-        (see doc for `omega_boundary_dim` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    initial_condition_fun_dict : Dict[str, Callable | None], default=None
-        A dict of functions representing the temporal initial condition (None
-        value is possible). If None
-        (default) then no temporal boundary condition is applied
-        Must share the keys of `u_dict`
-    norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
-        A dict of Monte-Carlo sample points for computing the
-        normalization constant. Default is None.
-        Must share the keys of `u_dict`
-    norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
-        A dict of jnp.array with the same keys as `u_dict`. The importance
-        sampling weights for Monte-Carlo integration of the
-        normalization constant for each element of u_dict. Must be provided if
-        `norm_samples_dict` is provided.
-        `norm_weights_dict[key]` should have the same leading dimension as
-        `norm_samples_dict[key]` for each `key`.
-        Alternatively, the user can pass a float or an integer.
-        For each key, an array of similar shape to `norm_samples_dict[key]`
-        or shape `(1,)` is expected. These corresponds to the weights $w_k =
-        \frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
-        the Monte-Carlo samples. Default is None
-    obs_slice_dict : Dict[str, slice | None] | None, default=None
-        dict of obs_slice, with keys from `u_dict` to designate the
-        output(s) channels that are forced to observed values, for each
-        PINNs. Default is None. But if a value is given, all the entries of
-        `u_dict` must be represented here with default value `jnp.s_[...]`
-        if no particular slice is to be given
-    params : InitVar[ParamsDict], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-
-    """
-
-    # NOTE static=True only for leaf attributes that are not valid JAX types
-    # (ie. jax.Array cannot be static) and that we do not expect to change
-    u_dict: Dict[str, eqx.Module]
-    dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
-    key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
-    derivative_keys_dict: Dict[
-        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
-    ] = eqx.field(kw_only=True, default=None)
-    omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-    norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
-        eqx.field(kw_only=True, default=None)
-    )
-    norm_weights_dict: (
-        Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
-    ) = eqx.field(kw_only=True, default=None)
-    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-
-    # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
-    # dictionary having keys in u_dict and / or dynamic_loss_dict)
-    loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
-        kw_only=True, default=None
-    )
-    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
-
-    # following have init=False and are set in the __post_init__
-    u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
-        init=False
-    )
-    derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
-        eqx.field(init=False)
-    )
-    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
-    # internally the loss weights are handled with a dictionary
-    _loss_weights: Dict[str, dict] = eqx.field(init=False)
-
-    def __post_init__(self, loss_weights=None, params_dict=None):
-        # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
-        # First, for all the optional dict,
-        # if the user did not provide at all this optional argument,
-        # we make sure there is a null ponderating loss_weight and we
-        # create a dummy dict with the required keys and all the values to
-        # None
-        if self.key_dict is None:
-            self.key_dict = self.u_dict_with_none
-        if self.omega_boundary_fun_dict is None:
-            self.omega_boundary_fun_dict = self.u_dict_with_none
-        if self.omega_boundary_condition_dict is None:
-            self.omega_boundary_condition_dict = self.u_dict_with_none
-        if self.omega_boundary_dim_dict is None:
-            self.omega_boundary_dim_dict = self.u_dict_with_none
-        if self.initial_condition_fun_dict is None:
-            self.initial_condition_fun_dict = self.u_dict_with_none
-        if self.norm_samples_dict is None:
-            self.norm_samples_dict = self.u_dict_with_none
-        if self.norm_weights_dict is None:
-            self.norm_weights_dict = self.u_dict_with_none
-        if self.obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
-        if self.u_dict.keys() != self.obs_slice_dict.keys():
-            raise ValueError("obs_slice_dict should have same keys as u_dict")
-        if self.derivative_keys_dict is None:
-            self.derivative_keys_dict = {
-                k: None
-                for k in set(
-                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
-                )
-            }
-            # set() because we can have duplicate entries and in this case we
-            # say it corresponds to the same derivative_keys_dict entry
-            # we need both because the constraints (all but dyn_loss) will be
-            # done by iterating on u_dict while the dyn_loss will be by
-            # iterating on dynamic_loss_dict. So each time we will require dome
-            # derivative_keys_dict
-
-        # derivative keys for the u_constraints. Note that we create missing
-        # DerivativeKeysODE around a Params object and not ParamsDict
-        # this works because u_dict.keys == params_dict.nn_params.keys()
-        for k in self.u_dict.keys():
-            if self.derivative_keys_dict[k] is None:
-                if self.u_dict[k].eq_type == "statio_PDE":
-                    self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
-                        params=params_dict.extract_params(k)
-                    )
-                else:
-                    self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
-                        params=params_dict.extract_params(k)
-                    )
-
-        # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
-        if (
-            self.u_dict.keys() != self.key_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
-            or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
-            or self.u_dict.keys() != self.norm_samples_dict.keys()
-            or self.u_dict.keys() != self.norm_weights_dict.keys()
-        ):
-            raise ValueError("All the dicts concerning the PINNs should have same keys")
-
-        self._loss_weights = self.set_loss_weights(loss_weights)
-
-        # Third, in order not to benefit from LossPDEStatio and
-        # LossPDENonStatio and in order to factorize code, we create internally
-        # some losses object to implement the constraints on the solutions.
-        # We will not use the dynamic loss term
-        self.u_constraints_dict = {}
-        for i in self.u_dict.keys():
-            if self.u_dict[i].eq_type == "statio_PDE":
-                self.u_constraints_dict[i] = LossPDEStatio(
-                    u=self.u_dict[i],
-                    loss_weights=LossWeightsPDENonStatio(
-                        dyn_loss=0.0,
-                        norm_loss=1.0,
-                        boundary_loss=1.0,
-                        observations=1.0,
-                        initial_condition=1.0,
-                    ),
-                    dynamic_loss=None,
-                    key=self.key_dict[i],
-                    derivative_keys=self.derivative_keys_dict[i],
-                    omega_boundary_fun=self.omega_boundary_fun_dict[i],
-                    omega_boundary_condition=self.omega_boundary_condition_dict[i],
-                    omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    norm_samples=self.norm_samples_dict[i],
-                    norm_weights=self.norm_weights_dict[i],
-                    obs_slice=self.obs_slice_dict[i],
-                )
-            elif self.u_dict[i].eq_type == "nonstatio_PDE":
-                self.u_constraints_dict[i] = LossPDENonStatio(
-                    u=self.u_dict[i],
-                    loss_weights=LossWeightsPDENonStatio(
-                        dyn_loss=0.0,
-                        norm_loss=1.0,
-                        boundary_loss=1.0,
-                        observations=1.0,
-                        initial_condition=1.0,
-                    ),
-                    dynamic_loss=None,
-                    key=self.key_dict[i],
-                    derivative_keys=self.derivative_keys_dict[i],
-                    omega_boundary_fun=self.omega_boundary_fun_dict[i],
-                    omega_boundary_condition=self.omega_boundary_condition_dict[i],
-                    omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    initial_condition_fun=self.initial_condition_fun_dict[i],
-                    norm_samples=self.norm_samples_dict[i],
-                    norm_weights=self.norm_weights_dict[i],
-                    obs_slice=self.obs_slice_dict[i],
-                )
-            else:
-                raise ValueError(
-                    "Wrong value for self.u_dict[i].eq_type[i], "
-                    f"got {self.u_dict[i].eq_type[i]}"
-                )
-
-        # derivative keys for the dynamic loss. Note that we create a
-        # DerivativeKeysODE around a ParamsDict object because a whole
-        # params_dict is feed to DynamicLoss.evaluate functions (extract_params
-        # happen inside it)
-        self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
-
-        # also make sure we only have PINNs or SPINNs
-        if not (
-            all(isinstance(value, PINN) for value in self.u_dict.values())
-            or all(isinstance(value, SPINN) for value in self.u_dict.values())
-        ):
-            raise ValueError(
-                "We only accept dictionary of PINNs or dictionary of SPINNs"
-            )
-
-    def set_loss_weights(
-        self, loss_weights_init: LossWeightsPDEDict
-    ) -> dict[str, dict]:
-        """
-        This rather complex function enables the user to specify a simple
-        loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
-        for ponderating values being applied to all the equations of the
-        system... So all the transformations are handled here
-        """
-        _loss_weights = {}
-        for k in fields(loss_weights_init):
-            v = getattr(loss_weights_init, k.name)
-            if isinstance(v, dict):
-                for vv in v.keys():
-                    if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv, Array)
-                        and ((vv.shape == (1,) or len(vv.shape) == 0))
-                    ):
-                        # TODO improve that
-                        raise ValueError(
-                            f"loss values cannot be vectorial here, got {vv}"
-                        )
-                if k.name == "dyn_loss":
-                    if v.keys() == self.dynamic_loss_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match dynamic_loss_dict keys"
-                        )
-                else:
-                    if v.keys() == self.u_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match u_dict keys"
-                        )
-            if v is None:
-                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
-            else:
-                if not isinstance(v, (int, float)) and not (
-                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
-                ):
-                    # TODO improve that
-                    raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k.name == "dyn_loss":
-                    _loss_weights[k.name] = {
-                        kk: v for kk in self.dynamic_loss_dict.keys()
-                    }
-                else:
-                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
-        return _loss_weights
-
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
-
-    def evaluate(
-        self,
-        params_dict: ParamsDict,
-        batch: PDEStatioBatch | PDENonStatioBatch,
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-
-
-        Parameters
-        ---------
-        params_dict
-            Parameters at which the losses of the system are evaluated
-        batch
-            Such named tuples are composed of batch of points in the
-            domain, a batch of points in the domain
-            border, (a batch of time points a for PDENonStatioBatch) and an
-            optional additional batch of parameters (eg. for metamodeling)
-            and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        if self.u_dict.keys() != params_dict.nn_params.keys():
-            raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
-
-        vmap_in_axes = (0,)
-
-        # Retrieve the optional eq_params_batch
-        # and update eq_params with the latter
-        # and update vmap_in_axes
-        if batch.param_batch_dict is not None:
-            eq_params_batch_dict = batch.param_batch_dict
-
-            # feed the eq_params with the batch
-            for k in eq_params_batch_dict.keys():
-                params_dict.eq_params[k] = eq_params_batch_dict[k]
-
-            vmap_in_axes_params = _get_vmap_in_axes_params(
-                batch.param_batch_dict, params_dict
-            )
-
-        def dyn_loss_for_one_key(dyn_loss, loss_weight):
-            """The function used in tree_map"""
-            return dynamic_loss_apply(
-                dyn_loss.evaluate,
-                self.u_dict,
-                (
-                    batch.domain_batch
-                    if isinstance(batch, PDEStatioBatch)
-                    else batch.domain_batch
-                ),
-                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
-                vmap_in_axes + vmap_in_axes_params,
-                loss_weight,
-                u_type=list(self.u_dict.values())[0].__class__.__base__,
-            )
-
-        dyn_loss_mse_dict = jax.tree_util.tree_map(
-            dyn_loss_for_one_key,
-            self.dynamic_loss_dict,
-            self._loss_weights["dyn_loss"],
-            is_leaf=lambda x: isinstance(
-                x, (PDEStatio, PDENonStatio)
-            ),  # before when dynamic losses
-            # where plain (unregister pytree) node classes, we could not traverse
-            # this level. Now that dynamic losses are eqx.Module they can be
-            # traversed by tree map recursion. Hence we need to specify to that
-            # we want to stop at this level
-        )
-        mse_dyn_loss = jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
-        )
-
-        # boundary conditions, normalization conditions, observation_loss,
-        # initial condition... loss this is done via the internal
-        # LossPDEStatio and NonStatio
-        loss_weight_struct = {
-            "dyn_loss": "*",
-            "norm_loss": "*",
-            "boundary_loss": "*",
-            "observations": "*",
-            "initial_condition": "*",
-        }
-        # we need to do the following for the tree_mapping to work
-        if batch.obs_batch_dict is None:
-            batch = append_obs_batch(batch, self.u_dict_with_none)
-        total_loss, res_dict = constraints_system_loss_apply(
-            self.u_constraints_dict,
-            batch,
-            params_dict,
-            self._loss_weights,
-            loss_weight_struct,
-        )
-
-        # Add the mse_dyn_loss from the previous computations
-        total_loss += mse_dyn_loss
-        res_dict["dyn_loss"] += mse_dyn_loss
-        return total_loss, res_dict
|