jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
|
1
|
-
# pylint: disable=unsubscriptable-object, no-member
|
|
2
1
|
"""
|
|
3
2
|
Main module to implement a PDE loss in jinns
|
|
4
3
|
"""
|
|
4
|
+
|
|
5
5
|
from __future__ import (
|
|
6
6
|
annotations,
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
import abc
|
|
10
|
-
from dataclasses import InitVar
|
|
11
|
-
from typing import TYPE_CHECKING,
|
|
10
|
+
from dataclasses import InitVar
|
|
11
|
+
from typing import TYPE_CHECKING, Callable, TypedDict
|
|
12
|
+
from types import EllipsisType
|
|
12
13
|
import warnings
|
|
13
|
-
import jax
|
|
14
14
|
import jax.numpy as jnp
|
|
15
15
|
import equinox as eqx
|
|
16
16
|
from jaxtyping import Float, Array, Key, Int
|
|
@@ -20,9 +20,7 @@ from jinns.loss._loss_utils import (
|
|
|
20
20
|
normalization_loss_apply,
|
|
21
21
|
observations_loss_apply,
|
|
22
22
|
initial_condition_apply,
|
|
23
|
-
constraints_system_loss_apply,
|
|
24
23
|
)
|
|
25
|
-
from jinns.data._DataGenerators import append_obs_batch
|
|
26
24
|
from jinns.parameters._params import (
|
|
27
25
|
_get_vmap_in_axes_params,
|
|
28
26
|
_update_eq_params_dict,
|
|
@@ -32,19 +30,30 @@ from jinns.parameters._derivative_keys import (
|
|
|
32
30
|
DerivativeKeysPDEStatio,
|
|
33
31
|
DerivativeKeysPDENonStatio,
|
|
34
32
|
)
|
|
33
|
+
from jinns.loss._abstract_loss import AbstractLoss
|
|
35
34
|
from jinns.loss._loss_weights import (
|
|
36
35
|
LossWeightsPDEStatio,
|
|
37
36
|
LossWeightsPDENonStatio,
|
|
38
|
-
LossWeightsPDEDict,
|
|
39
37
|
)
|
|
40
|
-
from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
|
|
41
|
-
from jinns.utils._pinn import PINN
|
|
42
|
-
from jinns.utils._spinn import SPINN
|
|
43
38
|
from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
44
39
|
|
|
45
40
|
|
|
46
41
|
if TYPE_CHECKING:
|
|
47
|
-
|
|
42
|
+
# imports for type hints only
|
|
43
|
+
from jinns.parameters._params import Params
|
|
44
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
45
|
+
from jinns.loss import PDENonStatio, PDEStatio
|
|
46
|
+
from jinns.utils._types import BoundaryConditionFun
|
|
47
|
+
|
|
48
|
+
class LossDictPDEStatio(TypedDict):
|
|
49
|
+
dyn_loss: Float[Array, " "]
|
|
50
|
+
norm_loss: Float[Array, " "]
|
|
51
|
+
boundary_loss: Float[Array, " "]
|
|
52
|
+
observations: Float[Array, " "]
|
|
53
|
+
|
|
54
|
+
class LossDictPDENonStatio(LossDictPDEStatio):
|
|
55
|
+
initial_condition: Float[Array, " "]
|
|
56
|
+
|
|
48
57
|
|
|
49
58
|
_IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
50
59
|
"dirichlet",
|
|
@@ -53,8 +62,8 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
|
53
62
|
]
|
|
54
63
|
|
|
55
64
|
|
|
56
|
-
class _LossPDEAbstract(
|
|
57
|
-
"""
|
|
65
|
+
class _LossPDEAbstract(AbstractLoss):
|
|
66
|
+
r"""
|
|
58
67
|
Parameters
|
|
59
68
|
----------
|
|
60
69
|
|
|
@@ -69,11 +78,11 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
69
78
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
70
79
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
71
80
|
is `"nn_params"` for each composant of the loss.
|
|
72
|
-
omega_boundary_fun :
|
|
81
|
+
omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
|
|
73
82
|
The function to be matched in the border condition (can be None) or a
|
|
74
83
|
dictionary of such functions as values and keys as described
|
|
75
84
|
in `omega_boundary_condition`.
|
|
76
|
-
omega_boundary_condition : str |
|
|
85
|
+
omega_boundary_condition : str | dict[str, str], default=None
|
|
77
86
|
Either None (no condition, by default), or a string defining
|
|
78
87
|
the boundary condition (Dirichlet or Von Neumann),
|
|
79
88
|
or a dictionary with such strings as values. In this case,
|
|
@@ -84,24 +93,29 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
84
93
|
a particular boundary condition on this facet.
|
|
85
94
|
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
86
95
|
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
87
|
-
omega_boundary_dim : slice |
|
|
96
|
+
omega_boundary_dim : slice | dict[str, slice], default=None
|
|
88
97
|
Either None, or a slice object or a dictionary of slice objects as
|
|
89
98
|
values and keys as described in `omega_boundary_condition`.
|
|
90
99
|
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
91
100
|
will be forced to match the boundary condition.
|
|
92
101
|
Note that it must be a slice and not an integer
|
|
93
102
|
(but a preprocessing of the user provided argument takes care of it)
|
|
94
|
-
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
95
|
-
|
|
103
|
+
norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
|
|
104
|
+
Monte-Carlo sample points for computing the
|
|
96
105
|
normalization constant. Default is None.
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
106
|
+
norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
|
|
107
|
+
The importance sampling weights for Monte-Carlo integration of the
|
|
108
|
+
normalization constant. Must be provided if `norm_samples` is provided.
|
|
109
|
+
`norm_weights` should be broadcastble to
|
|
110
|
+
`norm_samples`.
|
|
111
|
+
Alternatively, the user can pass a float or an integer that will be
|
|
112
|
+
made broadcastable to `norm_samples`.
|
|
113
|
+
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
114
|
+
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
115
|
+
obs_slice : EllipsisType | slice, default=None
|
|
102
116
|
slice object specifying the begininning/ending of the PINN output
|
|
103
117
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
104
|
-
params : InitVar[Params], default=None
|
|
118
|
+
params : InitVar[Params[Array]], default=None
|
|
105
119
|
The main Params object of the problem needed to instanciate the
|
|
106
120
|
DerivativeKeysODE if the latter is not specified.
|
|
107
121
|
"""
|
|
@@ -115,24 +129,28 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
115
129
|
loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
|
|
116
130
|
kw_only=True, default=None
|
|
117
131
|
)
|
|
118
|
-
omega_boundary_fun:
|
|
132
|
+
omega_boundary_fun: (
|
|
133
|
+
BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
|
|
134
|
+
) = eqx.field(kw_only=True, default=None, static=True)
|
|
135
|
+
omega_boundary_condition: str | dict[str, str] | None = eqx.field(
|
|
119
136
|
kw_only=True, default=None, static=True
|
|
120
137
|
)
|
|
121
|
-
|
|
138
|
+
omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
|
|
122
139
|
kw_only=True, default=None, static=True
|
|
123
140
|
)
|
|
124
|
-
|
|
125
|
-
kw_only=True, default=None
|
|
141
|
+
norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
|
|
142
|
+
kw_only=True, default=None
|
|
126
143
|
)
|
|
127
|
-
|
|
144
|
+
norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
|
|
128
145
|
kw_only=True, default=None
|
|
129
146
|
)
|
|
130
|
-
|
|
131
|
-
|
|
147
|
+
obs_slice: EllipsisType | slice | None = eqx.field(
|
|
148
|
+
kw_only=True, default=None, static=True
|
|
149
|
+
)
|
|
132
150
|
|
|
133
|
-
params: InitVar[Params] = eqx.field(kw_only=True, default=None)
|
|
151
|
+
params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)
|
|
134
152
|
|
|
135
|
-
def __post_init__(self, params=None):
|
|
153
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
136
154
|
"""
|
|
137
155
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
138
156
|
Module with eqx.tree_at
|
|
@@ -222,6 +240,11 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
222
240
|
)
|
|
223
241
|
|
|
224
242
|
if isinstance(self.omega_boundary_fun, dict):
|
|
243
|
+
if not isinstance(self.omega_boundary_dim, dict):
|
|
244
|
+
raise ValueError(
|
|
245
|
+
"If omega_boundary_fun is a dict then"
|
|
246
|
+
" omega_boundary_dim should also be a dict"
|
|
247
|
+
)
|
|
225
248
|
if self.omega_boundary_dim is None:
|
|
226
249
|
self.omega_boundary_dim = {
|
|
227
250
|
k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
|
|
@@ -251,15 +274,34 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
251
274
|
if not isinstance(self.omega_boundary_dim, slice):
|
|
252
275
|
raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
|
|
253
276
|
|
|
254
|
-
if self.norm_samples is not None
|
|
255
|
-
|
|
277
|
+
if self.norm_samples is not None:
|
|
278
|
+
if self.norm_weights is None:
|
|
279
|
+
raise ValueError(
|
|
280
|
+
"`norm_weights` must be provided when `norm_samples` is used!"
|
|
281
|
+
)
|
|
282
|
+
if isinstance(self.norm_weights, (int, float)):
|
|
283
|
+
self.norm_weights = self.norm_weights * jnp.ones(
|
|
284
|
+
(self.norm_samples.shape[0],)
|
|
285
|
+
)
|
|
286
|
+
if isinstance(self.norm_weights, Array):
|
|
287
|
+
if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
|
|
288
|
+
raise ValueError(
|
|
289
|
+
"self.norm_weights and "
|
|
290
|
+
"self.norm_samples must have the same leading dimension"
|
|
291
|
+
)
|
|
292
|
+
else:
|
|
293
|
+
raise ValueError("Wrong type for self.norm_weights")
|
|
294
|
+
|
|
295
|
+
@abc.abstractmethod
|
|
296
|
+
def __call__(self, *_, **__):
|
|
297
|
+
pass
|
|
256
298
|
|
|
257
299
|
@abc.abstractmethod
|
|
258
300
|
def evaluate(
|
|
259
301
|
self: eqx.Module,
|
|
260
|
-
params: Params,
|
|
302
|
+
params: Params[Array],
|
|
261
303
|
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
262
|
-
) -> tuple[Float,
|
|
304
|
+
) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
|
|
263
305
|
raise NotImplementedError
|
|
264
306
|
|
|
265
307
|
|
|
@@ -276,9 +318,9 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
276
318
|
|
|
277
319
|
Parameters
|
|
278
320
|
----------
|
|
279
|
-
u :
|
|
321
|
+
u : AbstractPINN
|
|
280
322
|
the PINN
|
|
281
|
-
dynamic_loss :
|
|
323
|
+
dynamic_loss : PDEStatio
|
|
282
324
|
the stationary PDE dynamic part of the loss, basically the differential
|
|
283
325
|
operator $\mathcal{N}[u](x)$. Should implement a method
|
|
284
326
|
`dynamic_loss.evaluate(x, u, params)`.
|
|
@@ -301,11 +343,11 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
301
343
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
302
344
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
303
345
|
is `"nn_params"` for each composant of the loss.
|
|
304
|
-
omega_boundary_fun :
|
|
346
|
+
omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
|
|
305
347
|
The function to be matched in the border condition (can be None) or a
|
|
306
348
|
dictionary of such functions as values and keys as described
|
|
307
349
|
in `omega_boundary_condition`.
|
|
308
|
-
omega_boundary_condition : str |
|
|
350
|
+
omega_boundary_condition : str | dict[str, str], default=None
|
|
309
351
|
Either None (no condition, by default), or a string defining
|
|
310
352
|
the boundary condition (Dirichlet or Von Neumann),
|
|
311
353
|
or a dictionary with such strings as values. In this case,
|
|
@@ -316,24 +358,28 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
316
358
|
a particular boundary condition on this facet.
|
|
317
359
|
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
318
360
|
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
319
|
-
omega_boundary_dim : slice |
|
|
361
|
+
omega_boundary_dim : slice | dict[str, slice], default=None
|
|
320
362
|
Either None, or a slice object or a dictionary of slice objects as
|
|
321
363
|
values and keys as described in `omega_boundary_condition`.
|
|
322
364
|
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
323
365
|
will be forced to match the boundary condition.
|
|
324
366
|
Note that it must be a slice and not an integer
|
|
325
367
|
(but a preprocessing of the user provided argument takes care of it)
|
|
326
|
-
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
327
|
-
|
|
368
|
+
norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
|
|
369
|
+
Monte-Carlo sample points for computing the
|
|
328
370
|
normalization constant. Default is None.
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
371
|
+
norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
|
|
372
|
+
The importance sampling weights for Monte-Carlo integration of the
|
|
373
|
+
normalization constant. Must be provided if `norm_samples` is provided.
|
|
374
|
+
`norm_weights` should have the same leading dimension as
|
|
375
|
+
`norm_samples`.
|
|
376
|
+
Alternatively, the user can pass a float or an integer.
|
|
377
|
+
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
378
|
+
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
333
379
|
obs_slice : slice, default=None
|
|
334
380
|
slice object specifying the begininning/ending of the PINN output
|
|
335
381
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
336
|
-
params : InitVar[Params], default=None
|
|
382
|
+
params : InitVar[Params[Array]], default=None
|
|
337
383
|
The main Params object of the problem needed to instanciate the
|
|
338
384
|
DerivativeKeysODE if the latter is not specified.
|
|
339
385
|
|
|
@@ -348,13 +394,13 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
348
394
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
349
395
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
350
396
|
|
|
351
|
-
u:
|
|
352
|
-
dynamic_loss:
|
|
397
|
+
u: AbstractPINN
|
|
398
|
+
dynamic_loss: PDEStatio | None
|
|
353
399
|
key: Key | None = eqx.field(kw_only=True, default=None)
|
|
354
400
|
|
|
355
401
|
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
356
402
|
|
|
357
|
-
def __post_init__(self, params=None):
|
|
403
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
358
404
|
"""
|
|
359
405
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
360
406
|
Module with eqx.tree_at!
|
|
@@ -368,25 +414,27 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
368
414
|
|
|
369
415
|
def _get_dynamic_loss_batch(
|
|
370
416
|
self, batch: PDEStatioBatch
|
|
371
|
-
) -> Float[Array, "batch_size dimension"]:
|
|
417
|
+
) -> Float[Array, " batch_size dimension"]:
|
|
372
418
|
return batch.domain_batch
|
|
373
419
|
|
|
374
420
|
def _get_normalization_loss_batch(
|
|
375
421
|
self, _
|
|
376
|
-
) -> Float[Array, "nb_norm_samples dimension"]:
|
|
377
|
-
return (self.norm_samples,)
|
|
422
|
+
) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
|
|
423
|
+
return (self.norm_samples,) # type: ignore -> cannot narrow a class attr
|
|
424
|
+
|
|
425
|
+
# we could have used typing.cast though
|
|
378
426
|
|
|
379
427
|
def _get_observations_loss_batch(
|
|
380
428
|
self, batch: PDEStatioBatch
|
|
381
|
-
) -> Float[Array, "batch_size obs_dim"]:
|
|
382
|
-
return
|
|
429
|
+
) -> Float[Array, " batch_size obs_dim"]:
|
|
430
|
+
return batch.obs_batch_dict["pinn_in"]
|
|
383
431
|
|
|
384
432
|
def __call__(self, *args, **kwargs):
|
|
385
433
|
return self.evaluate(*args, **kwargs)
|
|
386
434
|
|
|
387
435
|
def evaluate(
|
|
388
|
-
self, params: Params, batch: PDEStatioBatch
|
|
389
|
-
) -> tuple[Float[Array, "
|
|
436
|
+
self, params: Params[Array], batch: PDEStatioBatch
|
|
437
|
+
) -> tuple[Float[Array, " "], LossDictPDEStatio]:
|
|
390
438
|
"""
|
|
391
439
|
Evaluate the loss function at a batch of points for given parameters.
|
|
392
440
|
|
|
@@ -417,9 +465,9 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
417
465
|
self.dynamic_loss.evaluate,
|
|
418
466
|
self.u,
|
|
419
467
|
self._get_dynamic_loss_batch(batch),
|
|
420
|
-
_set_derivatives(params, self.derivative_keys.dyn_loss),
|
|
468
|
+
_set_derivatives(params, self.derivative_keys.dyn_loss), # type: ignore
|
|
421
469
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
422
|
-
self.loss_weights.dyn_loss,
|
|
470
|
+
self.loss_weights.dyn_loss, # type: ignore
|
|
423
471
|
)
|
|
424
472
|
else:
|
|
425
473
|
mse_dyn_loss = jnp.array(0.0)
|
|
@@ -429,24 +477,28 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
429
477
|
mse_norm_loss = normalization_loss_apply(
|
|
430
478
|
self.u,
|
|
431
479
|
self._get_normalization_loss_batch(batch),
|
|
432
|
-
_set_derivatives(params, self.derivative_keys.norm_loss),
|
|
480
|
+
_set_derivatives(params, self.derivative_keys.norm_loss), # type: ignore
|
|
433
481
|
vmap_in_axes_params,
|
|
434
|
-
self.
|
|
435
|
-
self.loss_weights.norm_loss,
|
|
482
|
+
self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
|
|
483
|
+
self.loss_weights.norm_loss, # type: ignore
|
|
436
484
|
)
|
|
437
485
|
else:
|
|
438
486
|
mse_norm_loss = jnp.array(0.0)
|
|
439
487
|
|
|
440
488
|
# boundary part
|
|
441
|
-
if
|
|
489
|
+
if (
|
|
490
|
+
self.omega_boundary_condition is not None
|
|
491
|
+
and self.omega_boundary_dim is not None
|
|
492
|
+
and self.omega_boundary_fun is not None
|
|
493
|
+
): # pyright cannot narrow down the three None otherwise as it is class attribute
|
|
442
494
|
mse_boundary_loss = boundary_condition_apply(
|
|
443
495
|
self.u,
|
|
444
496
|
batch,
|
|
445
|
-
_set_derivatives(params, self.derivative_keys.boundary_loss),
|
|
497
|
+
_set_derivatives(params, self.derivative_keys.boundary_loss), # type: ignore
|
|
446
498
|
self.omega_boundary_fun,
|
|
447
499
|
self.omega_boundary_condition,
|
|
448
500
|
self.omega_boundary_dim,
|
|
449
|
-
self.loss_weights.boundary_loss,
|
|
501
|
+
self.loss_weights.boundary_loss, # type: ignore
|
|
450
502
|
)
|
|
451
503
|
else:
|
|
452
504
|
mse_boundary_loss = jnp.array(0.0)
|
|
@@ -459,10 +511,10 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
459
511
|
mse_observation_loss = observations_loss_apply(
|
|
460
512
|
self.u,
|
|
461
513
|
self._get_observations_loss_batch(batch),
|
|
462
|
-
_set_derivatives(params, self.derivative_keys.observations),
|
|
514
|
+
_set_derivatives(params, self.derivative_keys.observations), # type: ignore
|
|
463
515
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
464
516
|
batch.obs_batch_dict["val"],
|
|
465
|
-
self.loss_weights.observations,
|
|
517
|
+
self.loss_weights.observations, # type: ignore
|
|
466
518
|
self.obs_slice,
|
|
467
519
|
)
|
|
468
520
|
else:
|
|
@@ -478,8 +530,6 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
478
530
|
"norm_loss": mse_norm_loss,
|
|
479
531
|
"boundary_loss": mse_boundary_loss,
|
|
480
532
|
"observations": mse_observation_loss,
|
|
481
|
-
"initial_condition": jnp.array(0.0), # for compatibility in the
|
|
482
|
-
# tree_map of SystemLoss
|
|
483
533
|
}
|
|
484
534
|
)
|
|
485
535
|
|
|
@@ -500,9 +550,9 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
500
550
|
|
|
501
551
|
Parameters
|
|
502
552
|
----------
|
|
503
|
-
u :
|
|
553
|
+
u : AbstractPINN
|
|
504
554
|
the PINN
|
|
505
|
-
dynamic_loss :
|
|
555
|
+
dynamic_loss : PDENonStatio
|
|
506
556
|
the non stationary PDE dynamic part of the loss, basically the differential
|
|
507
557
|
operator $\mathcal{N}[u](t, x)$. Should implement a method
|
|
508
558
|
`dynamic_loss.evaluate(t, x, u, params)`.
|
|
@@ -526,11 +576,11 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
526
576
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
527
577
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
528
578
|
is `"nn_params"` for each composant of the loss.
|
|
529
|
-
omega_boundary_fun :
|
|
579
|
+
omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
|
|
530
580
|
The function to be matched in the border condition (can be None) or a
|
|
531
581
|
dictionary of such functions as values and keys as described
|
|
532
582
|
in `omega_boundary_condition`.
|
|
533
|
-
omega_boundary_condition : str |
|
|
583
|
+
omega_boundary_condition : str | dict[str, str], default=None
|
|
534
584
|
Either None (no condition, by default), or a string defining
|
|
535
585
|
the boundary condition (Dirichlet or Von Neumann),
|
|
536
586
|
or a dictionary with such strings as values. In this case,
|
|
@@ -541,37 +591,46 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
541
591
|
a particular boundary condition on this facet.
|
|
542
592
|
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
543
593
|
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
544
|
-
omega_boundary_dim : slice |
|
|
594
|
+
omega_boundary_dim : slice | dict[str, slice], default=None
|
|
545
595
|
Either None, or a slice object or a dictionary of slice objects as
|
|
546
596
|
values and keys as described in `omega_boundary_condition`.
|
|
547
597
|
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
548
598
|
will be forced to match the boundary condition.
|
|
549
599
|
Note that it must be a slice and not an integer
|
|
550
600
|
(but a preprocessing of the user provided argument takes care of it)
|
|
551
|
-
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
552
|
-
|
|
601
|
+
norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
|
|
602
|
+
Monte-Carlo sample points for computing the
|
|
553
603
|
normalization constant. Default is None.
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
604
|
+
norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
|
|
605
|
+
The importance sampling weights for Monte-Carlo integration of the
|
|
606
|
+
normalization constant. Must be provided if `norm_samples` is provided.
|
|
607
|
+
`norm_weights` should have the same leading dimension as
|
|
608
|
+
`norm_samples`.
|
|
609
|
+
Alternatively, the user can pass a float or an integer.
|
|
610
|
+
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
611
|
+
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
558
612
|
obs_slice : slice, default=None
|
|
559
613
|
slice object specifying the begininning/ending of the PINN output
|
|
560
614
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
615
|
+
t0 : float | Float[Array, " 1"], default=None
|
|
616
|
+
The time at which to apply the initial condition. If None, the time
|
|
617
|
+
is set to `0` by default.
|
|
561
618
|
initial_condition_fun : Callable, default=None
|
|
562
|
-
A function representing the
|
|
563
|
-
(default) then no initial condition is applied
|
|
564
|
-
params : InitVar[Params], default=None
|
|
565
|
-
The main Params object of the problem needed to instanciate the
|
|
566
|
-
DerivativeKeysODE if the latter is not specified.
|
|
619
|
+
A function representing the initial condition at `t0`. If None
|
|
620
|
+
(default) then no initial condition is applied.
|
|
621
|
+
params : InitVar[Params[Array]], default=None
|
|
622
|
+
The main `Params` object of the problem needed to instanciate the
|
|
623
|
+
`DerivativeKeysODE` if the latter is not specified.
|
|
567
624
|
|
|
568
625
|
"""
|
|
569
626
|
|
|
627
|
+
dynamic_loss: PDENonStatio | None
|
|
570
628
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
571
629
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
572
630
|
initial_condition_fun: Callable | None = eqx.field(
|
|
573
631
|
kw_only=True, default=None, static=True
|
|
574
632
|
)
|
|
633
|
+
t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)
|
|
575
634
|
|
|
576
635
|
_max_norm_samples_omega: Int = eqx.field(init=False, static=True)
|
|
577
636
|
_max_norm_time_slices: Int = eqx.field(init=False, static=True)
|
|
@@ -593,6 +652,21 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
593
652
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
594
653
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
595
654
|
)
|
|
655
|
+
# some checks for t0
|
|
656
|
+
if isinstance(self.t0, Array):
|
|
657
|
+
if not self.t0.shape: # e.g. user input: jnp.array(0.)
|
|
658
|
+
self.t0 = jnp.array([self.t0])
|
|
659
|
+
elif self.t0.shape != (1,):
|
|
660
|
+
raise ValueError(
|
|
661
|
+
f"Wrong self.t0 input. It should be"
|
|
662
|
+
f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
|
|
663
|
+
)
|
|
664
|
+
elif isinstance(self.t0, float): # e.g. user input: 0
|
|
665
|
+
self.t0 = jnp.array([self.t0])
|
|
666
|
+
elif self.t0 is None:
|
|
667
|
+
self.t0 = jnp.array([0])
|
|
668
|
+
else:
|
|
669
|
+
raise ValueError("Wrong value for t0")
|
|
596
670
|
|
|
597
671
|
# witht the variables below we avoid memory overflow since a cartesian
|
|
598
672
|
# product is taken
|
|
@@ -601,28 +675,30 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
601
675
|
|
|
602
676
|
def _get_dynamic_loss_batch(
|
|
603
677
|
self, batch: PDENonStatioBatch
|
|
604
|
-
) -> Float[Array, "batch_size 1+dimension"]:
|
|
678
|
+
) -> Float[Array, " batch_size 1+dimension"]:
|
|
605
679
|
return batch.domain_batch
|
|
606
680
|
|
|
607
681
|
def _get_normalization_loss_batch(
|
|
608
682
|
self, batch: PDENonStatioBatch
|
|
609
|
-
) ->
|
|
683
|
+
) -> tuple[
|
|
684
|
+
Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
|
|
685
|
+
]:
|
|
610
686
|
return (
|
|
611
687
|
batch.domain_batch[: self._max_norm_time_slices, 0:1],
|
|
612
|
-
self.norm_samples[: self._max_norm_samples_omega],
|
|
688
|
+
self.norm_samples[: self._max_norm_samples_omega], # type: ignore -> cannot narrow a class attr
|
|
613
689
|
)
|
|
614
690
|
|
|
615
691
|
def _get_observations_loss_batch(
|
|
616
692
|
self, batch: PDENonStatioBatch
|
|
617
|
-
) ->
|
|
618
|
-
return
|
|
693
|
+
) -> Float[Array, " batch_size 1+dim"]:
|
|
694
|
+
return batch.obs_batch_dict["pinn_in"]
|
|
619
695
|
|
|
620
696
|
def __call__(self, *args, **kwargs):
|
|
621
697
|
return self.evaluate(*args, **kwargs)
|
|
622
698
|
|
|
623
699
|
def evaluate(
|
|
624
|
-
self, params: Params, batch: PDENonStatioBatch
|
|
625
|
-
) -> tuple[Float[Array, "
|
|
700
|
+
self, params: Params[Array], batch: PDENonStatioBatch
|
|
701
|
+
) -> tuple[Float[Array, " "], LossDictPDENonStatio]:
|
|
626
702
|
"""
|
|
627
703
|
Evaluate the loss function at a batch of points for given parameters.
|
|
628
704
|
|
|
@@ -639,6 +715,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
639
715
|
inputs/outputs/parameters
|
|
640
716
|
"""
|
|
641
717
|
omega_batch = batch.initial_batch
|
|
718
|
+
assert omega_batch is not None
|
|
642
719
|
|
|
643
720
|
# Retrieve the optional eq_params_batch
|
|
644
721
|
# and update eq_params with the latter
|
|
@@ -651,17 +728,19 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
651
728
|
|
|
652
729
|
# For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
|
|
653
730
|
# mse_observation_loss we use the evaluate from parent class
|
|
654
|
-
partial_mse, partial_mse_terms = super().evaluate(params, batch)
|
|
731
|
+
partial_mse, partial_mse_terms = super().evaluate(params, batch) # type: ignore
|
|
732
|
+
# ignore because batch is not PDEStatioBatch. We could use typing.cast though
|
|
655
733
|
|
|
656
734
|
# initial condition
|
|
657
735
|
if self.initial_condition_fun is not None:
|
|
658
736
|
mse_initial_condition = initial_condition_apply(
|
|
659
737
|
self.u,
|
|
660
738
|
omega_batch,
|
|
661
|
-
_set_derivatives(params, self.derivative_keys.initial_condition),
|
|
739
|
+
_set_derivatives(params, self.derivative_keys.initial_condition), # type: ignore
|
|
662
740
|
(0,) + vmap_in_axes_params,
|
|
663
741
|
self.initial_condition_fun,
|
|
664
|
-
self.
|
|
742
|
+
self.t0, # type: ignore can't get the narrowing in __post_init__
|
|
743
|
+
self.loss_weights.initial_condition, # type: ignore
|
|
665
744
|
)
|
|
666
745
|
else:
|
|
667
746
|
mse_initial_condition = jnp.array(0.0)
|
|
@@ -673,419 +752,3 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
673
752
|
**partial_mse_terms,
|
|
674
753
|
"initial_condition": mse_initial_condition,
|
|
675
754
|
}
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
class SystemLossPDE(eqx.Module):
|
|
679
|
-
r"""
|
|
680
|
-
Class to implement a system of PDEs.
|
|
681
|
-
The goal is to give maximum freedom to the user. The class is created with
|
|
682
|
-
a dict of dynamic loss, and dictionaries of all the objects that are used
|
|
683
|
-
in LossPDENonStatio and LossPDEStatio. When then iterate
|
|
684
|
-
over the dynamic losses that compose the system. All the PINNs with all the
|
|
685
|
-
parameter dictionaries are passed as arguments to each dynamic loss
|
|
686
|
-
evaluate functions; it is inside the dynamic loss that specification are
|
|
687
|
-
performed.
|
|
688
|
-
|
|
689
|
-
**Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
|
|
690
|
-
Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
|
|
691
|
-
solution.
|
|
692
|
-
|
|
693
|
-
Parameters
|
|
694
|
-
----------
|
|
695
|
-
u_dict : Dict[str, eqx.Module]
|
|
696
|
-
dict of PINNs
|
|
697
|
-
loss_weights : LossWeightsPDEDict
|
|
698
|
-
A dictionary of LossWeightsODE
|
|
699
|
-
derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
|
|
700
|
-
A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
|
|
701
|
-
specifying what field of `params`
|
|
702
|
-
should be used during gradient computations for each of the terms of
|
|
703
|
-
the total loss, for each of the loss in the system. Default is
|
|
704
|
-
`"nn_params`" everywhere.
|
|
705
|
-
dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
|
|
706
|
-
A dict of dynamic part of the loss, basically the differential
|
|
707
|
-
operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
|
|
708
|
-
key_dict : Dict[str, Key], default=None
|
|
709
|
-
A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
|
|
710
|
-
match that of u_dict. See LossPDEStatio or LossPDENonStatio for
|
|
711
|
-
more details.
|
|
712
|
-
omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
|
|
713
|
-
A dict of of function or of dict of functions or of None
|
|
714
|
-
(see doc for `omega_boundary_fun` in
|
|
715
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
716
|
-
Must share the keys of `u_dict`.
|
|
717
|
-
omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
|
|
718
|
-
A dict of strings or of dict of strings or of None
|
|
719
|
-
(see doc for `omega_boundary_condition_dict` in
|
|
720
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
721
|
-
Must share the keys of `u_dict`
|
|
722
|
-
omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
|
|
723
|
-
A dict of slices or of dict of slices or of None
|
|
724
|
-
(see doc for `omega_boundary_dim` in
|
|
725
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
726
|
-
Must share the keys of `u_dict`
|
|
727
|
-
initial_condition_fun_dict : Dict[str, Callable | None], default=None
|
|
728
|
-
A dict of functions representing the temporal initial condition (None
|
|
729
|
-
value is possible). If None
|
|
730
|
-
(default) then no temporal boundary condition is applied
|
|
731
|
-
Must share the keys of `u_dict`
|
|
732
|
-
norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
|
|
733
|
-
A dict of fixed sample point in the space over which to compute the
|
|
734
|
-
normalization constant. Default is None
|
|
735
|
-
Must share the keys of `u_dict`
|
|
736
|
-
norm_int_length_dict : Dict[str, float | None] | None, default=None
|
|
737
|
-
A dict of Float. The domain area
|
|
738
|
-
(or interval length in 1D) upon which we perform the numerical
|
|
739
|
-
integration for each element of u_dict.
|
|
740
|
-
Default is None
|
|
741
|
-
Must share the keys of `u_dict`
|
|
742
|
-
obs_slice_dict : Dict[str, slice | None] | None, default=None
|
|
743
|
-
dict of obs_slice, with keys from `u_dict` to designate the
|
|
744
|
-
output(s) channels that are forced to observed values, for each
|
|
745
|
-
PINNs. Default is None. But if a value is given, all the entries of
|
|
746
|
-
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
747
|
-
if no particular slice is to be given
|
|
748
|
-
params : InitVar[ParamsDict], default=None
|
|
749
|
-
The main Params object of the problem needed to instanciate the
|
|
750
|
-
DerivativeKeysODE if the latter is not specified.
|
|
751
|
-
|
|
752
|
-
"""
|
|
753
|
-
|
|
754
|
-
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
755
|
-
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
756
|
-
u_dict: Dict[str, eqx.Module]
|
|
757
|
-
dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
|
|
758
|
-
key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
|
|
759
|
-
derivative_keys_dict: Dict[
|
|
760
|
-
str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
|
|
761
|
-
] = eqx.field(kw_only=True, default=None)
|
|
762
|
-
omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
|
|
763
|
-
eqx.field(kw_only=True, default=None, static=True)
|
|
764
|
-
)
|
|
765
|
-
omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
|
|
766
|
-
eqx.field(kw_only=True, default=None, static=True)
|
|
767
|
-
)
|
|
768
|
-
omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
|
|
769
|
-
eqx.field(kw_only=True, default=None, static=True)
|
|
770
|
-
)
|
|
771
|
-
initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
|
|
772
|
-
kw_only=True, default=None, static=True
|
|
773
|
-
)
|
|
774
|
-
norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
|
|
775
|
-
eqx.field(kw_only=True, default=None)
|
|
776
|
-
)
|
|
777
|
-
norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
|
|
778
|
-
kw_only=True, default=None
|
|
779
|
-
)
|
|
780
|
-
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
781
|
-
kw_only=True, default=None, static=True
|
|
782
|
-
)
|
|
783
|
-
|
|
784
|
-
# For the user loss_weights are passed as a LossWeightsPDEDict (with internal
|
|
785
|
-
# dictionary having keys in u_dict and / or dynamic_loss_dict)
|
|
786
|
-
loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
|
|
787
|
-
kw_only=True, default=None
|
|
788
|
-
)
|
|
789
|
-
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
790
|
-
|
|
791
|
-
# following have init=False and are set in the __post_init__
|
|
792
|
-
u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
|
|
793
|
-
init=False
|
|
794
|
-
)
|
|
795
|
-
derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
|
|
796
|
-
eqx.field(init=False)
|
|
797
|
-
)
|
|
798
|
-
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
799
|
-
# internally the loss weights are handled with a dictionary
|
|
800
|
-
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
801
|
-
|
|
802
|
-
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
803
|
-
# a dictionary that will be useful at different places
|
|
804
|
-
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
805
|
-
# First, for all the optional dict,
|
|
806
|
-
# if the user did not provide at all this optional argument,
|
|
807
|
-
# we make sure there is a null ponderating loss_weight and we
|
|
808
|
-
# create a dummy dict with the required keys and all the values to
|
|
809
|
-
# None
|
|
810
|
-
if self.key_dict is None:
|
|
811
|
-
self.key_dict = self.u_dict_with_none
|
|
812
|
-
if self.omega_boundary_fun_dict is None:
|
|
813
|
-
self.omega_boundary_fun_dict = self.u_dict_with_none
|
|
814
|
-
if self.omega_boundary_condition_dict is None:
|
|
815
|
-
self.omega_boundary_condition_dict = self.u_dict_with_none
|
|
816
|
-
if self.omega_boundary_dim_dict is None:
|
|
817
|
-
self.omega_boundary_dim_dict = self.u_dict_with_none
|
|
818
|
-
if self.initial_condition_fun_dict is None:
|
|
819
|
-
self.initial_condition_fun_dict = self.u_dict_with_none
|
|
820
|
-
if self.norm_samples_dict is None:
|
|
821
|
-
self.norm_samples_dict = self.u_dict_with_none
|
|
822
|
-
if self.norm_int_length_dict is None:
|
|
823
|
-
self.norm_int_length_dict = self.u_dict_with_none
|
|
824
|
-
if self.obs_slice_dict is None:
|
|
825
|
-
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
826
|
-
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
827
|
-
raise ValueError("obs_slice_dict should have same keys as u_dict")
|
|
828
|
-
if self.derivative_keys_dict is None:
|
|
829
|
-
self.derivative_keys_dict = {
|
|
830
|
-
k: None
|
|
831
|
-
for k in set(
|
|
832
|
-
list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
|
|
833
|
-
)
|
|
834
|
-
}
|
|
835
|
-
# set() because we can have duplicate entries and in this case we
|
|
836
|
-
# say it corresponds to the same derivative_keys_dict entry
|
|
837
|
-
# we need both because the constraints (all but dyn_loss) will be
|
|
838
|
-
# done by iterating on u_dict while the dyn_loss will be by
|
|
839
|
-
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
840
|
-
# derivative_keys_dict
|
|
841
|
-
|
|
842
|
-
# derivative keys for the u_constraints. Note that we create missing
|
|
843
|
-
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
844
|
-
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
845
|
-
for k in self.u_dict.keys():
|
|
846
|
-
if self.derivative_keys_dict[k] is None:
|
|
847
|
-
if self.u_dict[k].eq_type == "statio_PDE":
|
|
848
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
|
|
849
|
-
params=params_dict.extract_params(k)
|
|
850
|
-
)
|
|
851
|
-
else:
|
|
852
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
|
|
853
|
-
params=params_dict.extract_params(k)
|
|
854
|
-
)
|
|
855
|
-
|
|
856
|
-
# Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
|
|
857
|
-
if (
|
|
858
|
-
self.u_dict.keys() != self.key_dict.keys()
|
|
859
|
-
or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
|
|
860
|
-
or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
|
|
861
|
-
or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
|
|
862
|
-
or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
|
|
863
|
-
or self.u_dict.keys() != self.norm_samples_dict.keys()
|
|
864
|
-
or self.u_dict.keys() != self.norm_int_length_dict.keys()
|
|
865
|
-
):
|
|
866
|
-
raise ValueError("All the dicts concerning the PINNs should have same keys")
|
|
867
|
-
|
|
868
|
-
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
869
|
-
|
|
870
|
-
# Third, in order not to benefit from LossPDEStatio and
|
|
871
|
-
# LossPDENonStatio and in order to factorize code, we create internally
|
|
872
|
-
# some losses object to implement the constraints on the solutions.
|
|
873
|
-
# We will not use the dynamic loss term
|
|
874
|
-
self.u_constraints_dict = {}
|
|
875
|
-
for i in self.u_dict.keys():
|
|
876
|
-
if self.u_dict[i].eq_type == "statio_PDE":
|
|
877
|
-
self.u_constraints_dict[i] = LossPDEStatio(
|
|
878
|
-
u=self.u_dict[i],
|
|
879
|
-
loss_weights=LossWeightsPDENonStatio(
|
|
880
|
-
dyn_loss=0.0,
|
|
881
|
-
norm_loss=1.0,
|
|
882
|
-
boundary_loss=1.0,
|
|
883
|
-
observations=1.0,
|
|
884
|
-
initial_condition=1.0,
|
|
885
|
-
),
|
|
886
|
-
dynamic_loss=None,
|
|
887
|
-
key=self.key_dict[i],
|
|
888
|
-
derivative_keys=self.derivative_keys_dict[i],
|
|
889
|
-
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
890
|
-
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
891
|
-
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
892
|
-
norm_samples=self.norm_samples_dict[i],
|
|
893
|
-
norm_int_length=self.norm_int_length_dict[i],
|
|
894
|
-
obs_slice=self.obs_slice_dict[i],
|
|
895
|
-
)
|
|
896
|
-
elif self.u_dict[i].eq_type == "nonstatio_PDE":
|
|
897
|
-
self.u_constraints_dict[i] = LossPDENonStatio(
|
|
898
|
-
u=self.u_dict[i],
|
|
899
|
-
loss_weights=LossWeightsPDENonStatio(
|
|
900
|
-
dyn_loss=0.0,
|
|
901
|
-
norm_loss=1.0,
|
|
902
|
-
boundary_loss=1.0,
|
|
903
|
-
observations=1.0,
|
|
904
|
-
initial_condition=1.0,
|
|
905
|
-
),
|
|
906
|
-
dynamic_loss=None,
|
|
907
|
-
key=self.key_dict[i],
|
|
908
|
-
derivative_keys=self.derivative_keys_dict[i],
|
|
909
|
-
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
910
|
-
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
911
|
-
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
912
|
-
initial_condition_fun=self.initial_condition_fun_dict[i],
|
|
913
|
-
norm_samples=self.norm_samples_dict[i],
|
|
914
|
-
norm_int_length=self.norm_int_length_dict[i],
|
|
915
|
-
obs_slice=self.obs_slice_dict[i],
|
|
916
|
-
)
|
|
917
|
-
else:
|
|
918
|
-
raise ValueError(
|
|
919
|
-
"Wrong value for self.u_dict[i].eq_type[i], "
|
|
920
|
-
f"got {self.u_dict[i].eq_type[i]}"
|
|
921
|
-
)
|
|
922
|
-
|
|
923
|
-
# derivative keys for the dynamic loss. Note that we create a
|
|
924
|
-
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
925
|
-
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
926
|
-
# happen inside it)
|
|
927
|
-
self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
|
|
928
|
-
|
|
929
|
-
# also make sure we only have PINNs or SPINNs
|
|
930
|
-
if not (
|
|
931
|
-
all(isinstance(value, PINN) for value in self.u_dict.values())
|
|
932
|
-
or all(isinstance(value, SPINN) for value in self.u_dict.values())
|
|
933
|
-
):
|
|
934
|
-
raise ValueError(
|
|
935
|
-
"We only accept dictionary of PINNs or dictionary of SPINNs"
|
|
936
|
-
)
|
|
937
|
-
|
|
938
|
-
def set_loss_weights(
|
|
939
|
-
self, loss_weights_init: LossWeightsPDEDict
|
|
940
|
-
) -> dict[str, dict]:
|
|
941
|
-
"""
|
|
942
|
-
This rather complex function enables the user to specify a simple
|
|
943
|
-
loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
|
|
944
|
-
for ponderating values being applied to all the equations of the
|
|
945
|
-
system... So all the transformations are handled here
|
|
946
|
-
"""
|
|
947
|
-
_loss_weights = {}
|
|
948
|
-
for k in fields(loss_weights_init):
|
|
949
|
-
v = getattr(loss_weights_init, k.name)
|
|
950
|
-
if isinstance(v, dict):
|
|
951
|
-
for vv in v.keys():
|
|
952
|
-
if not isinstance(vv, (int, float)) and not (
|
|
953
|
-
isinstance(vv, Array)
|
|
954
|
-
and ((vv.shape == (1,) or len(vv.shape) == 0))
|
|
955
|
-
):
|
|
956
|
-
# TODO improve that
|
|
957
|
-
raise ValueError(
|
|
958
|
-
f"loss values cannot be vectorial here, got {vv}"
|
|
959
|
-
)
|
|
960
|
-
if k.name == "dyn_loss":
|
|
961
|
-
if v.keys() == self.dynamic_loss_dict.keys():
|
|
962
|
-
_loss_weights[k.name] = v
|
|
963
|
-
else:
|
|
964
|
-
raise ValueError(
|
|
965
|
-
"Keys in nested dictionary of loss_weights"
|
|
966
|
-
" do not match dynamic_loss_dict keys"
|
|
967
|
-
)
|
|
968
|
-
else:
|
|
969
|
-
if v.keys() == self.u_dict.keys():
|
|
970
|
-
_loss_weights[k.name] = v
|
|
971
|
-
else:
|
|
972
|
-
raise ValueError(
|
|
973
|
-
"Keys in nested dictionary of loss_weights"
|
|
974
|
-
" do not match u_dict keys"
|
|
975
|
-
)
|
|
976
|
-
if v is None:
|
|
977
|
-
_loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
|
|
978
|
-
else:
|
|
979
|
-
if not isinstance(v, (int, float)) and not (
|
|
980
|
-
isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
|
|
981
|
-
):
|
|
982
|
-
# TODO improve that
|
|
983
|
-
raise ValueError(f"loss values cannot be vectorial here, got {v}")
|
|
984
|
-
if k.name == "dyn_loss":
|
|
985
|
-
_loss_weights[k.name] = {
|
|
986
|
-
kk: v for kk in self.dynamic_loss_dict.keys()
|
|
987
|
-
}
|
|
988
|
-
else:
|
|
989
|
-
_loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
|
|
990
|
-
return _loss_weights
|
|
991
|
-
|
|
992
|
-
def __call__(self, *args, **kwargs):
|
|
993
|
-
return self.evaluate(*args, **kwargs)
|
|
994
|
-
|
|
995
|
-
def evaluate(
|
|
996
|
-
self,
|
|
997
|
-
params_dict: ParamsDict,
|
|
998
|
-
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
999
|
-
) -> tuple[Float[Array, "1"], dict[str, float]]:
|
|
1000
|
-
"""
|
|
1001
|
-
Evaluate the loss function at a batch of points for given parameters.
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
Parameters
|
|
1005
|
-
---------
|
|
1006
|
-
params_dict
|
|
1007
|
-
Parameters at which the losses of the system are evaluated
|
|
1008
|
-
batch
|
|
1009
|
-
Such named tuples are composed of batch of points in the
|
|
1010
|
-
domain, a batch of points in the domain
|
|
1011
|
-
border, (a batch of time points a for PDENonStatioBatch) and an
|
|
1012
|
-
optional additional batch of parameters (eg. for metamodeling)
|
|
1013
|
-
and an optional additional batch of observed
|
|
1014
|
-
inputs/outputs/parameters
|
|
1015
|
-
"""
|
|
1016
|
-
if self.u_dict.keys() != params_dict.nn_params.keys():
|
|
1017
|
-
raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
|
|
1018
|
-
|
|
1019
|
-
vmap_in_axes = (0,)
|
|
1020
|
-
|
|
1021
|
-
# Retrieve the optional eq_params_batch
|
|
1022
|
-
# and update eq_params with the latter
|
|
1023
|
-
# and update vmap_in_axes
|
|
1024
|
-
if batch.param_batch_dict is not None:
|
|
1025
|
-
eq_params_batch_dict = batch.param_batch_dict
|
|
1026
|
-
|
|
1027
|
-
# feed the eq_params with the batch
|
|
1028
|
-
for k in eq_params_batch_dict.keys():
|
|
1029
|
-
params_dict.eq_params[k] = eq_params_batch_dict[k]
|
|
1030
|
-
|
|
1031
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
1032
|
-
batch.param_batch_dict, params_dict
|
|
1033
|
-
)
|
|
1034
|
-
|
|
1035
|
-
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
1036
|
-
"""The function used in tree_map"""
|
|
1037
|
-
return dynamic_loss_apply(
|
|
1038
|
-
dyn_loss.evaluate,
|
|
1039
|
-
self.u_dict,
|
|
1040
|
-
(
|
|
1041
|
-
batch.domain_batch
|
|
1042
|
-
if isinstance(batch, PDEStatioBatch)
|
|
1043
|
-
else batch.domain_batch
|
|
1044
|
-
),
|
|
1045
|
-
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
1046
|
-
vmap_in_axes + vmap_in_axes_params,
|
|
1047
|
-
loss_weight,
|
|
1048
|
-
u_type=type(list(self.u_dict.values())[0]),
|
|
1049
|
-
)
|
|
1050
|
-
|
|
1051
|
-
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
1052
|
-
dyn_loss_for_one_key,
|
|
1053
|
-
self.dynamic_loss_dict,
|
|
1054
|
-
self._loss_weights["dyn_loss"],
|
|
1055
|
-
is_leaf=lambda x: isinstance(
|
|
1056
|
-
x, (PDEStatio, PDENonStatio)
|
|
1057
|
-
), # before when dynamic losses
|
|
1058
|
-
# where plain (unregister pytree) node classes, we could not traverse
|
|
1059
|
-
# this level. Now that dynamic losses are eqx.Module they can be
|
|
1060
|
-
# traversed by tree map recursion. Hence we need to specify to that
|
|
1061
|
-
# we want to stop at this level
|
|
1062
|
-
)
|
|
1063
|
-
mse_dyn_loss = jax.tree_util.tree_reduce(
|
|
1064
|
-
lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
|
|
1065
|
-
)
|
|
1066
|
-
|
|
1067
|
-
# boundary conditions, normalization conditions, observation_loss,
|
|
1068
|
-
# initial condition... loss this is done via the internal
|
|
1069
|
-
# LossPDEStatio and NonStatio
|
|
1070
|
-
loss_weight_struct = {
|
|
1071
|
-
"dyn_loss": "*",
|
|
1072
|
-
"norm_loss": "*",
|
|
1073
|
-
"boundary_loss": "*",
|
|
1074
|
-
"observations": "*",
|
|
1075
|
-
"initial_condition": "*",
|
|
1076
|
-
}
|
|
1077
|
-
# we need to do the following for the tree_mapping to work
|
|
1078
|
-
if batch.obs_batch_dict is None:
|
|
1079
|
-
batch = append_obs_batch(batch, self.u_dict_with_none)
|
|
1080
|
-
total_loss, res_dict = constraints_system_loss_apply(
|
|
1081
|
-
self.u_constraints_dict,
|
|
1082
|
-
batch,
|
|
1083
|
-
params_dict,
|
|
1084
|
-
self._loss_weights,
|
|
1085
|
-
loss_weight_struct,
|
|
1086
|
-
)
|
|
1087
|
-
|
|
1088
|
-
# Add the mse_dyn_loss from the previous computations
|
|
1089
|
-
total_loss += mse_dyn_loss
|
|
1090
|
-
res_dict["dyn_loss"] += mse_dyn_loss
|
|
1091
|
-
return total_loss, res_dict
|