jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
|
-
# pylint: disable=unsubscriptable-object, no-member
|
|
2
1
|
"""
|
|
3
2
|
Main module to implement a PDE loss in jinns
|
|
4
3
|
"""
|
|
4
|
+
|
|
5
5
|
from __future__ import (
|
|
6
6
|
annotations,
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
import abc
|
|
10
|
-
from dataclasses import InitVar
|
|
11
|
-
from typing import TYPE_CHECKING,
|
|
10
|
+
from dataclasses import InitVar
|
|
11
|
+
from typing import TYPE_CHECKING, Callable, TypedDict
|
|
12
|
+
from types import EllipsisType
|
|
12
13
|
import warnings
|
|
13
14
|
import jax
|
|
14
15
|
import jax.numpy as jnp
|
|
@@ -20,9 +21,7 @@ from jinns.loss._loss_utils import (
|
|
|
20
21
|
normalization_loss_apply,
|
|
21
22
|
observations_loss_apply,
|
|
22
23
|
initial_condition_apply,
|
|
23
|
-
constraints_system_loss_apply,
|
|
24
24
|
)
|
|
25
|
-
from jinns.data._DataGenerators import append_obs_batch
|
|
26
25
|
from jinns.parameters._params import (
|
|
27
26
|
_get_vmap_in_axes_params,
|
|
28
27
|
_update_eq_params_dict,
|
|
@@ -32,19 +31,31 @@ from jinns.parameters._derivative_keys import (
|
|
|
32
31
|
DerivativeKeysPDEStatio,
|
|
33
32
|
DerivativeKeysPDENonStatio,
|
|
34
33
|
)
|
|
34
|
+
from jinns.loss._abstract_loss import AbstractLoss
|
|
35
|
+
from jinns.loss._loss_components import PDEStatioComponents, PDENonStatioComponents
|
|
35
36
|
from jinns.loss._loss_weights import (
|
|
36
37
|
LossWeightsPDEStatio,
|
|
37
38
|
LossWeightsPDENonStatio,
|
|
38
|
-
LossWeightsPDEDict,
|
|
39
39
|
)
|
|
40
|
-
from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
|
|
41
|
-
from jinns.nn._pinn import PINN
|
|
42
|
-
from jinns.nn._spinn import SPINN
|
|
43
40
|
from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
41
|
+
from jinns.parameters._params import Params
|
|
44
42
|
|
|
45
43
|
|
|
46
44
|
if TYPE_CHECKING:
|
|
47
|
-
|
|
45
|
+
# imports for type hints only
|
|
46
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
47
|
+
from jinns.loss import PDENonStatio, PDEStatio
|
|
48
|
+
from jinns.utils._types import BoundaryConditionFun
|
|
49
|
+
|
|
50
|
+
class LossDictPDEStatio(TypedDict):
|
|
51
|
+
dyn_loss: Float[Array, " "]
|
|
52
|
+
norm_loss: Float[Array, " "]
|
|
53
|
+
boundary_loss: Float[Array, " "]
|
|
54
|
+
observations: Float[Array, " "]
|
|
55
|
+
|
|
56
|
+
class LossDictPDENonStatio(LossDictPDEStatio):
|
|
57
|
+
initial_condition: Float[Array, " "]
|
|
58
|
+
|
|
48
59
|
|
|
49
60
|
_IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
50
61
|
"dirichlet",
|
|
@@ -53,8 +64,8 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
|
53
64
|
]
|
|
54
65
|
|
|
55
66
|
|
|
56
|
-
class _LossPDEAbstract(
|
|
57
|
-
"""
|
|
67
|
+
class _LossPDEAbstract(AbstractLoss):
|
|
68
|
+
r"""
|
|
58
69
|
Parameters
|
|
59
70
|
----------
|
|
60
71
|
|
|
@@ -62,18 +73,21 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
62
73
|
The loss weights for the differents term : dynamic loss,
|
|
63
74
|
initial condition (if LossWeightsPDENonStatio), boundary conditions if
|
|
64
75
|
any, normalization loss if any and observations if any.
|
|
65
|
-
|
|
76
|
+
Can be updated according to a specific algorithm. See
|
|
77
|
+
`update_weight_method`
|
|
78
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
79
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
66
80
|
derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
|
|
67
81
|
Specify which field of `params` should be differentiated for each
|
|
68
82
|
composant of the total loss. Particularily useful for inverse problems.
|
|
69
83
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
70
84
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
71
85
|
is `"nn_params"` for each composant of the loss.
|
|
72
|
-
omega_boundary_fun :
|
|
86
|
+
omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
|
|
73
87
|
The function to be matched in the border condition (can be None) or a
|
|
74
88
|
dictionary of such functions as values and keys as described
|
|
75
89
|
in `omega_boundary_condition`.
|
|
76
|
-
omega_boundary_condition : str |
|
|
90
|
+
omega_boundary_condition : str | dict[str, str], default=None
|
|
77
91
|
Either None (no condition, by default), or a string defining
|
|
78
92
|
the boundary condition (Dirichlet or Von Neumann),
|
|
79
93
|
or a dictionary with such strings as values. In this case,
|
|
@@ -84,28 +98,29 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
84
98
|
a particular boundary condition on this facet.
|
|
85
99
|
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
86
100
|
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
87
|
-
omega_boundary_dim : slice |
|
|
101
|
+
omega_boundary_dim : slice | dict[str, slice], default=None
|
|
88
102
|
Either None, or a slice object or a dictionary of slice objects as
|
|
89
103
|
values and keys as described in `omega_boundary_condition`.
|
|
90
104
|
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
91
105
|
will be forced to match the boundary condition.
|
|
92
106
|
Note that it must be a slice and not an integer
|
|
93
107
|
(but a preprocessing of the user provided argument takes care of it)
|
|
94
|
-
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
108
|
+
norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
|
|
95
109
|
Monte-Carlo sample points for computing the
|
|
96
110
|
normalization constant. Default is None.
|
|
97
|
-
norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
|
|
111
|
+
norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
|
|
98
112
|
The importance sampling weights for Monte-Carlo integration of the
|
|
99
113
|
normalization constant. Must be provided if `norm_samples` is provided.
|
|
100
|
-
`norm_weights` should
|
|
114
|
+
`norm_weights` should be broadcastble to
|
|
101
115
|
`norm_samples`.
|
|
102
|
-
Alternatively, the user can pass a float or an integer
|
|
116
|
+
Alternatively, the user can pass a float or an integer that will be
|
|
117
|
+
made broadcastable to `norm_samples`.
|
|
103
118
|
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
104
119
|
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
105
|
-
obs_slice : slice, default=None
|
|
120
|
+
obs_slice : EllipsisType | slice, default=None
|
|
106
121
|
slice object specifying the begininning/ending of the PINN output
|
|
107
122
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
108
|
-
params : InitVar[Params], default=None
|
|
123
|
+
params : InitVar[Params[Array]], default=None
|
|
109
124
|
The main Params object of the problem needed to instanciate the
|
|
110
125
|
DerivativeKeysODE if the latter is not specified.
|
|
111
126
|
"""
|
|
@@ -119,26 +134,28 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
119
134
|
loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
|
|
120
135
|
kw_only=True, default=None
|
|
121
136
|
)
|
|
122
|
-
omega_boundary_fun:
|
|
123
|
-
|
|
124
|
-
)
|
|
125
|
-
omega_boundary_condition: str |
|
|
137
|
+
omega_boundary_fun: (
|
|
138
|
+
BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
|
|
139
|
+
) = eqx.field(kw_only=True, default=None, static=True)
|
|
140
|
+
omega_boundary_condition: str | dict[str, str] | None = eqx.field(
|
|
126
141
|
kw_only=True, default=None, static=True
|
|
127
142
|
)
|
|
128
|
-
omega_boundary_dim: slice |
|
|
143
|
+
omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
|
|
129
144
|
kw_only=True, default=None, static=True
|
|
130
145
|
)
|
|
131
|
-
norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
|
|
146
|
+
norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
|
|
132
147
|
kw_only=True, default=None
|
|
133
148
|
)
|
|
134
|
-
norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
|
|
149
|
+
norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
|
|
135
150
|
kw_only=True, default=None
|
|
136
151
|
)
|
|
137
|
-
obs_slice: slice | None = eqx.field(
|
|
152
|
+
obs_slice: EllipsisType | slice | None = eqx.field(
|
|
153
|
+
kw_only=True, default=None, static=True
|
|
154
|
+
)
|
|
138
155
|
|
|
139
|
-
params: InitVar[Params] = eqx.field(kw_only=True, default=None)
|
|
156
|
+
params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)
|
|
140
157
|
|
|
141
|
-
def __post_init__(self, params=None):
|
|
158
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
142
159
|
"""
|
|
143
160
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
144
161
|
Module with eqx.tree_at
|
|
@@ -228,6 +245,11 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
228
245
|
)
|
|
229
246
|
|
|
230
247
|
if isinstance(self.omega_boundary_fun, dict):
|
|
248
|
+
if not isinstance(self.omega_boundary_dim, dict):
|
|
249
|
+
raise ValueError(
|
|
250
|
+
"If omega_boundary_fun is a dict then"
|
|
251
|
+
" omega_boundary_dim should also be a dict"
|
|
252
|
+
)
|
|
231
253
|
if self.omega_boundary_dim is None:
|
|
232
254
|
self.omega_boundary_dim = {
|
|
233
255
|
k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
|
|
@@ -262,27 +284,29 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
262
284
|
raise ValueError(
|
|
263
285
|
"`norm_weights` must be provided when `norm_samples` is used!"
|
|
264
286
|
)
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
)
|
|
272
|
-
else:
|
|
287
|
+
if isinstance(self.norm_weights, (int, float)):
|
|
288
|
+
self.norm_weights = self.norm_weights * jnp.ones(
|
|
289
|
+
(self.norm_samples.shape[0],)
|
|
290
|
+
)
|
|
291
|
+
if isinstance(self.norm_weights, Array):
|
|
292
|
+
if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
|
|
273
293
|
raise ValueError(
|
|
274
|
-
"
|
|
275
|
-
"
|
|
276
|
-
f" got shape {self.norm_weights.shape} and"
|
|
277
|
-
f" shape {self.norm_samples.shape}."
|
|
294
|
+
"self.norm_weights and "
|
|
295
|
+
"self.norm_samples must have the same leading dimension"
|
|
278
296
|
)
|
|
297
|
+
else:
|
|
298
|
+
raise ValueError("Wrong type for self.norm_weights")
|
|
299
|
+
|
|
300
|
+
@abc.abstractmethod
|
|
301
|
+
def __call__(self, *_, **__):
|
|
302
|
+
pass
|
|
279
303
|
|
|
280
304
|
@abc.abstractmethod
|
|
281
305
|
def evaluate(
|
|
282
306
|
self: eqx.Module,
|
|
283
|
-
params: Params,
|
|
307
|
+
params: Params[Array],
|
|
284
308
|
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
285
|
-
) -> tuple[Float,
|
|
309
|
+
) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
|
|
286
310
|
raise NotImplementedError
|
|
287
311
|
|
|
288
312
|
|
|
@@ -299,9 +323,9 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
299
323
|
|
|
300
324
|
Parameters
|
|
301
325
|
----------
|
|
302
|
-
u :
|
|
326
|
+
u : AbstractPINN
|
|
303
327
|
the PINN
|
|
304
|
-
dynamic_loss :
|
|
328
|
+
dynamic_loss : PDEStatio
|
|
305
329
|
the stationary PDE dynamic part of the loss, basically the differential
|
|
306
330
|
operator $\mathcal{N}[u](x)$. Should implement a method
|
|
307
331
|
`dynamic_loss.evaluate(x, u, params)`.
|
|
@@ -317,18 +341,21 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
317
341
|
The loss weights for the differents term : dynamic loss,
|
|
318
342
|
boundary conditions if any, normalization loss if any and
|
|
319
343
|
observations if any.
|
|
320
|
-
|
|
344
|
+
Can be updated according to a specific algorithm. See
|
|
345
|
+
`update_weight_method`
|
|
346
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
347
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
321
348
|
derivative_keys : DerivativeKeysPDEStatio, default=None
|
|
322
349
|
Specify which field of `params` should be differentiated for each
|
|
323
350
|
composant of the total loss. Particularily useful for inverse problems.
|
|
324
351
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
325
352
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
326
353
|
is `"nn_params"` for each composant of the loss.
|
|
327
|
-
omega_boundary_fun :
|
|
354
|
+
omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
|
|
328
355
|
The function to be matched in the border condition (can be None) or a
|
|
329
356
|
dictionary of such functions as values and keys as described
|
|
330
357
|
in `omega_boundary_condition`.
|
|
331
|
-
omega_boundary_condition : str |
|
|
358
|
+
omega_boundary_condition : str | dict[str, str], default=None
|
|
332
359
|
Either None (no condition, by default), or a string defining
|
|
333
360
|
the boundary condition (Dirichlet or Von Neumann),
|
|
334
361
|
or a dictionary with such strings as values. In this case,
|
|
@@ -339,17 +366,17 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
339
366
|
a particular boundary condition on this facet.
|
|
340
367
|
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
341
368
|
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
342
|
-
omega_boundary_dim : slice |
|
|
369
|
+
omega_boundary_dim : slice | dict[str, slice], default=None
|
|
343
370
|
Either None, or a slice object or a dictionary of slice objects as
|
|
344
371
|
values and keys as described in `omega_boundary_condition`.
|
|
345
372
|
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
346
373
|
will be forced to match the boundary condition.
|
|
347
374
|
Note that it must be a slice and not an integer
|
|
348
375
|
(but a preprocessing of the user provided argument takes care of it)
|
|
349
|
-
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
376
|
+
norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
|
|
350
377
|
Monte-Carlo sample points for computing the
|
|
351
378
|
normalization constant. Default is None.
|
|
352
|
-
norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
|
|
379
|
+
norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
|
|
353
380
|
The importance sampling weights for Monte-Carlo integration of the
|
|
354
381
|
normalization constant. Must be provided if `norm_samples` is provided.
|
|
355
382
|
`norm_weights` should have the same leading dimension as
|
|
@@ -360,7 +387,7 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
360
387
|
obs_slice : slice, default=None
|
|
361
388
|
slice object specifying the begininning/ending of the PINN output
|
|
362
389
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
363
|
-
params : InitVar[Params], default=None
|
|
390
|
+
params : InitVar[Params[Array]], default=None
|
|
364
391
|
The main Params object of the problem needed to instanciate the
|
|
365
392
|
DerivativeKeysODE if the latter is not specified.
|
|
366
393
|
|
|
@@ -375,13 +402,13 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
375
402
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
376
403
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
377
404
|
|
|
378
|
-
u:
|
|
379
|
-
dynamic_loss:
|
|
405
|
+
u: AbstractPINN
|
|
406
|
+
dynamic_loss: PDEStatio | None
|
|
380
407
|
key: Key | None = eqx.field(kw_only=True, default=None)
|
|
381
408
|
|
|
382
409
|
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
383
410
|
|
|
384
|
-
def __post_init__(self, params=None):
|
|
411
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
385
412
|
"""
|
|
386
413
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
387
414
|
Module with eqx.tree_at!
|
|
@@ -395,28 +422,31 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
395
422
|
|
|
396
423
|
def _get_dynamic_loss_batch(
|
|
397
424
|
self, batch: PDEStatioBatch
|
|
398
|
-
) -> Float[Array, "batch_size dimension"]:
|
|
425
|
+
) -> Float[Array, " batch_size dimension"]:
|
|
399
426
|
return batch.domain_batch
|
|
400
427
|
|
|
401
428
|
def _get_normalization_loss_batch(
|
|
402
429
|
self, _
|
|
403
|
-
) -> Float[Array, "nb_norm_samples dimension"]:
|
|
404
|
-
return (self.norm_samples,)
|
|
430
|
+
) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
|
|
431
|
+
return (self.norm_samples,) # type: ignore -> cannot narrow a class attr
|
|
432
|
+
|
|
433
|
+
# we could have used typing.cast though
|
|
405
434
|
|
|
406
435
|
def _get_observations_loss_batch(
|
|
407
436
|
self, batch: PDEStatioBatch
|
|
408
|
-
) -> Float[Array, "batch_size obs_dim"]:
|
|
409
|
-
return
|
|
437
|
+
) -> Float[Array, " batch_size obs_dim"]:
|
|
438
|
+
return batch.obs_batch_dict["pinn_in"]
|
|
410
439
|
|
|
411
440
|
def __call__(self, *args, **kwargs):
|
|
412
441
|
return self.evaluate(*args, **kwargs)
|
|
413
442
|
|
|
414
|
-
def
|
|
415
|
-
self, params: Params, batch: PDEStatioBatch
|
|
416
|
-
) -> tuple[
|
|
443
|
+
def evaluate_by_terms(
|
|
444
|
+
self, params: Params[Array], batch: PDEStatioBatch
|
|
445
|
+
) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
|
|
417
446
|
"""
|
|
418
447
|
Evaluate the loss function at a batch of points for given parameters.
|
|
419
448
|
|
|
449
|
+
We retrieve two PyTrees with loss values and gradients for each term
|
|
420
450
|
|
|
421
451
|
Parameters
|
|
422
452
|
---------
|
|
@@ -440,75 +470,112 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
440
470
|
|
|
441
471
|
# dynamic part
|
|
442
472
|
if self.dynamic_loss is not None:
|
|
443
|
-
|
|
444
|
-
self.dynamic_loss.evaluate,
|
|
473
|
+
dyn_loss_fun = lambda p: dynamic_loss_apply(
|
|
474
|
+
self.dynamic_loss.evaluate, # type: ignore
|
|
445
475
|
self.u,
|
|
446
476
|
self._get_dynamic_loss_batch(batch),
|
|
447
|
-
_set_derivatives(
|
|
477
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
|
|
448
478
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
449
|
-
self.loss_weights.dyn_loss,
|
|
450
479
|
)
|
|
451
480
|
else:
|
|
452
|
-
|
|
481
|
+
dyn_loss_fun = None
|
|
453
482
|
|
|
454
483
|
# normalization part
|
|
455
484
|
if self.norm_samples is not None:
|
|
456
|
-
|
|
485
|
+
norm_loss_fun = lambda p: normalization_loss_apply(
|
|
457
486
|
self.u,
|
|
458
487
|
self._get_normalization_loss_batch(batch),
|
|
459
|
-
_set_derivatives(
|
|
488
|
+
_set_derivatives(p, self.derivative_keys.norm_loss), # type: ignore
|
|
460
489
|
vmap_in_axes_params,
|
|
461
|
-
self.norm_weights,
|
|
462
|
-
self.loss_weights.norm_loss,
|
|
490
|
+
self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
|
|
463
491
|
)
|
|
464
492
|
else:
|
|
465
|
-
|
|
493
|
+
norm_loss_fun = None
|
|
466
494
|
|
|
467
495
|
# boundary part
|
|
468
|
-
if
|
|
469
|
-
|
|
496
|
+
if (
|
|
497
|
+
self.omega_boundary_condition is not None
|
|
498
|
+
and self.omega_boundary_dim is not None
|
|
499
|
+
and self.omega_boundary_fun is not None
|
|
500
|
+
): # pyright cannot narrow down the three None otherwise as it is class attribute
|
|
501
|
+
boundary_loss_fun = lambda p: boundary_condition_apply(
|
|
470
502
|
self.u,
|
|
471
503
|
batch,
|
|
472
|
-
_set_derivatives(
|
|
473
|
-
self.omega_boundary_fun,
|
|
474
|
-
self.omega_boundary_condition,
|
|
475
|
-
self.omega_boundary_dim,
|
|
476
|
-
self.loss_weights.boundary_loss,
|
|
504
|
+
_set_derivatives(p, self.derivative_keys.boundary_loss), # type: ignore
|
|
505
|
+
self.omega_boundary_fun, # type: ignore
|
|
506
|
+
self.omega_boundary_condition, # type: ignore
|
|
507
|
+
self.omega_boundary_dim, # type: ignore
|
|
477
508
|
)
|
|
478
509
|
else:
|
|
479
|
-
|
|
510
|
+
boundary_loss_fun = None
|
|
480
511
|
|
|
481
512
|
# Observation mse
|
|
482
513
|
if batch.obs_batch_dict is not None:
|
|
483
514
|
# update params with the batches of observed params
|
|
484
|
-
|
|
515
|
+
params_obs = _update_eq_params_dict(
|
|
516
|
+
params, batch.obs_batch_dict["eq_params"]
|
|
517
|
+
)
|
|
485
518
|
|
|
486
|
-
|
|
519
|
+
obs_loss_fun = lambda po: observations_loss_apply(
|
|
487
520
|
self.u,
|
|
488
521
|
self._get_observations_loss_batch(batch),
|
|
489
|
-
_set_derivatives(
|
|
522
|
+
_set_derivatives(po, self.derivative_keys.observations), # type: ignore
|
|
490
523
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
491
524
|
batch.obs_batch_dict["val"],
|
|
492
|
-
self.loss_weights.observations,
|
|
493
525
|
self.obs_slice,
|
|
494
526
|
)
|
|
495
527
|
else:
|
|
496
|
-
|
|
528
|
+
params_obs = None
|
|
529
|
+
obs_loss_fun = None
|
|
497
530
|
|
|
498
|
-
#
|
|
499
|
-
|
|
500
|
-
|
|
531
|
+
# get the unweighted mses for each loss term as well as the gradients
|
|
532
|
+
all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
|
|
533
|
+
PDEStatioComponents(
|
|
534
|
+
dyn_loss_fun, norm_loss_fun, boundary_loss_fun, obs_loss_fun
|
|
535
|
+
)
|
|
536
|
+
)
|
|
537
|
+
all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
|
|
538
|
+
params, params, params, params_obs
|
|
539
|
+
)
|
|
540
|
+
mses_grads = jax.tree.map(
|
|
541
|
+
lambda fun, params: self.get_gradients(fun, params),
|
|
542
|
+
all_funs,
|
|
543
|
+
all_params,
|
|
544
|
+
is_leaf=lambda x: x is None,
|
|
501
545
|
)
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
"dyn_loss": mse_dyn_loss,
|
|
505
|
-
"norm_loss": mse_norm_loss,
|
|
506
|
-
"boundary_loss": mse_boundary_loss,
|
|
507
|
-
"observations": mse_observation_loss,
|
|
508
|
-
"initial_condition": jnp.array(0.0), # for compatibility in the
|
|
509
|
-
# tree_map of SystemLoss
|
|
510
|
-
}
|
|
546
|
+
mses = jax.tree.map(
|
|
547
|
+
lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
511
548
|
)
|
|
549
|
+
grads = jax.tree.map(
|
|
550
|
+
lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
551
|
+
)
|
|
552
|
+
|
|
553
|
+
return mses, grads
|
|
554
|
+
|
|
555
|
+
def evaluate(
|
|
556
|
+
self, params: Params[Array], batch: PDEStatioBatch
|
|
557
|
+
) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
|
|
558
|
+
"""
|
|
559
|
+
Evaluate the loss function at a batch of points for given parameters.
|
|
560
|
+
|
|
561
|
+
We retrieve the total value itself and a PyTree with loss values for each term
|
|
562
|
+
|
|
563
|
+
Parameters
|
|
564
|
+
---------
|
|
565
|
+
params
|
|
566
|
+
Parameters at which the loss is evaluated
|
|
567
|
+
batch
|
|
568
|
+
Composed of a batch of points in the
|
|
569
|
+
domain, a batch of points in the domain
|
|
570
|
+
border and an optional additional batch of parameters (eg. for
|
|
571
|
+
metamodeling) and an optional additional batch of observed
|
|
572
|
+
inputs/outputs/parameters
|
|
573
|
+
"""
|
|
574
|
+
loss_terms, _ = self.evaluate_by_terms(params, batch)
|
|
575
|
+
|
|
576
|
+
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
577
|
+
|
|
578
|
+
return loss_val, loss_terms
|
|
512
579
|
|
|
513
580
|
|
|
514
581
|
class LossPDENonStatio(LossPDEStatio):
|
|
@@ -527,9 +594,9 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
527
594
|
|
|
528
595
|
Parameters
|
|
529
596
|
----------
|
|
530
|
-
u :
|
|
597
|
+
u : AbstractPINN
|
|
531
598
|
the PINN
|
|
532
|
-
dynamic_loss :
|
|
599
|
+
dynamic_loss : PDENonStatio
|
|
533
600
|
the non stationary PDE dynamic part of the loss, basically the differential
|
|
534
601
|
operator $\mathcal{N}[u](t, x)$. Should implement a method
|
|
535
602
|
`dynamic_loss.evaluate(t, x, u, params)`.
|
|
@@ -546,18 +613,21 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
546
613
|
The loss weights for the differents term : dynamic loss,
|
|
547
614
|
boundary conditions if any, initial condition, normalization loss if any and
|
|
548
615
|
observations if any.
|
|
549
|
-
|
|
616
|
+
Can be updated according to a specific algorithm. See
|
|
617
|
+
`update_weight_method`
|
|
618
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
619
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
550
620
|
derivative_keys : DerivativeKeysPDENonStatio, default=None
|
|
551
621
|
Specify which field of `params` should be differentiated for each
|
|
552
622
|
composant of the total loss. Particularily useful for inverse problems.
|
|
553
623
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
554
624
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
555
625
|
is `"nn_params"` for each composant of the loss.
|
|
556
|
-
omega_boundary_fun :
|
|
626
|
+
omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
|
|
557
627
|
The function to be matched in the border condition (can be None) or a
|
|
558
628
|
dictionary of such functions as values and keys as described
|
|
559
629
|
in `omega_boundary_condition`.
|
|
560
|
-
omega_boundary_condition : str |
|
|
630
|
+
omega_boundary_condition : str | dict[str, str], default=None
|
|
561
631
|
Either None (no condition, by default), or a string defining
|
|
562
632
|
the boundary condition (Dirichlet or Von Neumann),
|
|
563
633
|
or a dictionary with such strings as values. In this case,
|
|
@@ -568,17 +638,17 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
568
638
|
a particular boundary condition on this facet.
|
|
569
639
|
The facet called “xmin”, resp. “xmax” etc., in 2D,
|
|
570
640
|
refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
|
|
571
|
-
omega_boundary_dim : slice |
|
|
641
|
+
omega_boundary_dim : slice | dict[str, slice], default=None
|
|
572
642
|
Either None, or a slice object or a dictionary of slice objects as
|
|
573
643
|
values and keys as described in `omega_boundary_condition`.
|
|
574
644
|
`omega_boundary_dim` indicates which dimension(s) of the PINN
|
|
575
645
|
will be forced to match the boundary condition.
|
|
576
646
|
Note that it must be a slice and not an integer
|
|
577
647
|
(but a preprocessing of the user provided argument takes care of it)
|
|
578
|
-
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
648
|
+
norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
|
|
579
649
|
Monte-Carlo sample points for computing the
|
|
580
650
|
normalization constant. Default is None.
|
|
581
|
-
norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
|
|
651
|
+
norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
|
|
582
652
|
The importance sampling weights for Monte-Carlo integration of the
|
|
583
653
|
normalization constant. Must be provided if `norm_samples` is provided.
|
|
584
654
|
`norm_weights` should have the same leading dimension as
|
|
@@ -589,20 +659,25 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
589
659
|
obs_slice : slice, default=None
|
|
590
660
|
slice object specifying the begininning/ending of the PINN output
|
|
591
661
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
662
|
+
t0 : float | Float[Array, " 1"], default=None
|
|
663
|
+
The time at which to apply the initial condition. If None, the time
|
|
664
|
+
is set to `0` by default.
|
|
592
665
|
initial_condition_fun : Callable, default=None
|
|
593
|
-
A function representing the
|
|
594
|
-
(default) then no initial condition is applied
|
|
595
|
-
params : InitVar[Params], default=None
|
|
596
|
-
The main Params object of the problem needed to instanciate the
|
|
597
|
-
DerivativeKeysODE if the latter is not specified.
|
|
666
|
+
A function representing the initial condition at `t0`. If None
|
|
667
|
+
(default) then no initial condition is applied.
|
|
668
|
+
params : InitVar[Params[Array]], default=None
|
|
669
|
+
The main `Params` object of the problem needed to instanciate the
|
|
670
|
+
`DerivativeKeysODE` if the latter is not specified.
|
|
598
671
|
|
|
599
672
|
"""
|
|
600
673
|
|
|
674
|
+
dynamic_loss: PDENonStatio | None
|
|
601
675
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
602
676
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
603
677
|
initial_condition_fun: Callable | None = eqx.field(
|
|
604
678
|
kw_only=True, default=None, static=True
|
|
605
679
|
)
|
|
680
|
+
t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)
|
|
606
681
|
|
|
607
682
|
_max_norm_samples_omega: Int = eqx.field(init=False, static=True)
|
|
608
683
|
_max_norm_time_slices: Int = eqx.field(init=False, static=True)
|
|
@@ -624,6 +699,23 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
624
699
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
625
700
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
626
701
|
)
|
|
702
|
+
# some checks for t0
|
|
703
|
+
if isinstance(self.t0, Array):
|
|
704
|
+
if not self.t0.shape: # e.g. user input: jnp.array(0.)
|
|
705
|
+
self.t0 = jnp.array([self.t0])
|
|
706
|
+
elif self.t0.shape != (1,):
|
|
707
|
+
raise ValueError(
|
|
708
|
+
f"Wrong self.t0 input. It should be"
|
|
709
|
+
f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
|
|
710
|
+
)
|
|
711
|
+
elif isinstance(self.t0, float): # e.g. user input: 0.
|
|
712
|
+
self.t0 = jnp.array([self.t0])
|
|
713
|
+
elif isinstance(self.t0, int): # e.g. user input: 0
|
|
714
|
+
self.t0 = jnp.array([float(self.t0)])
|
|
715
|
+
elif self.t0 is None:
|
|
716
|
+
self.t0 = jnp.array([0])
|
|
717
|
+
else:
|
|
718
|
+
raise ValueError("Wrong value for t0")
|
|
627
719
|
|
|
628
720
|
# witht the variables below we avoid memory overflow since a cartesian
|
|
629
721
|
# product is taken
|
|
@@ -632,44 +724,50 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
632
724
|
|
|
633
725
|
def _get_dynamic_loss_batch(
|
|
634
726
|
self, batch: PDENonStatioBatch
|
|
635
|
-
) -> Float[Array, "batch_size 1+dimension"]:
|
|
727
|
+
) -> Float[Array, " batch_size 1+dimension"]:
|
|
636
728
|
return batch.domain_batch
|
|
637
729
|
|
|
638
730
|
def _get_normalization_loss_batch(
|
|
639
731
|
self, batch: PDENonStatioBatch
|
|
640
|
-
) ->
|
|
732
|
+
) -> tuple[
|
|
733
|
+
Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
|
|
734
|
+
]:
|
|
641
735
|
return (
|
|
642
736
|
batch.domain_batch[: self._max_norm_time_slices, 0:1],
|
|
643
|
-
self.norm_samples[: self._max_norm_samples_omega],
|
|
737
|
+
self.norm_samples[: self._max_norm_samples_omega], # type: ignore -> cannot narrow a class attr
|
|
644
738
|
)
|
|
645
739
|
|
|
646
740
|
def _get_observations_loss_batch(
|
|
647
741
|
self, batch: PDENonStatioBatch
|
|
648
|
-
) ->
|
|
649
|
-
return
|
|
742
|
+
) -> Float[Array, " batch_size 1+dim"]:
|
|
743
|
+
return batch.obs_batch_dict["pinn_in"]
|
|
650
744
|
|
|
651
745
|
def __call__(self, *args, **kwargs):
|
|
652
746
|
return self.evaluate(*args, **kwargs)
|
|
653
747
|
|
|
654
|
-
def
|
|
655
|
-
self, params: Params, batch: PDENonStatioBatch
|
|
656
|
-
) -> tuple[
|
|
748
|
+
def evaluate_by_terms(
|
|
749
|
+
self, params: Params[Array], batch: PDENonStatioBatch
|
|
750
|
+
) -> tuple[
|
|
751
|
+
PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
|
|
752
|
+
]:
|
|
657
753
|
"""
|
|
658
754
|
Evaluate the loss function at a batch of points for given parameters.
|
|
659
755
|
|
|
756
|
+
We retrieve two PyTrees with loss values and gradients for each term
|
|
660
757
|
|
|
661
758
|
Parameters
|
|
662
759
|
---------
|
|
663
760
|
params
|
|
664
761
|
Parameters at which the loss is evaluated
|
|
665
762
|
batch
|
|
666
|
-
Composed of a batch of points in
|
|
667
|
-
|
|
668
|
-
border
|
|
669
|
-
|
|
763
|
+
Composed of a batch of points in the
|
|
764
|
+
domain, a batch of points in the domain
|
|
765
|
+
border and an optional additional batch of parameters (eg. for
|
|
766
|
+
metamodeling) and an optional additional batch of observed
|
|
670
767
|
inputs/outputs/parameters
|
|
671
768
|
"""
|
|
672
769
|
omega_batch = batch.initial_batch
|
|
770
|
+
assert omega_batch is not None
|
|
673
771
|
|
|
674
772
|
# Retrieve the optional eq_params_batch
|
|
675
773
|
# and update eq_params with the latter
|
|
@@ -682,447 +780,62 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
682
780
|
|
|
683
781
|
# For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
|
|
684
782
|
# mse_observation_loss we use the evaluate from parent class
|
|
685
|
-
|
|
783
|
+
# As well as for their gradients
|
|
784
|
+
partial_mses, partial_grads = super().evaluate_by_terms(params, batch) # type: ignore
|
|
785
|
+
# ignore because batch is not PDEStatioBatch. We could use typing.cast though
|
|
686
786
|
|
|
687
787
|
# initial condition
|
|
688
788
|
if self.initial_condition_fun is not None:
|
|
689
|
-
|
|
789
|
+
mse_initial_condition_fun = lambda p: initial_condition_apply(
|
|
690
790
|
self.u,
|
|
691
791
|
omega_batch,
|
|
692
|
-
_set_derivatives(
|
|
792
|
+
_set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
|
|
693
793
|
(0,) + vmap_in_axes_params,
|
|
694
|
-
self.initial_condition_fun,
|
|
695
|
-
self.
|
|
794
|
+
self.initial_condition_fun, # type: ignore
|
|
795
|
+
self.t0, # type: ignore can't get the narrowing in __post_init__
|
|
696
796
|
)
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
# total loss
|
|
701
|
-
total_loss = partial_mse + mse_initial_condition
|
|
702
|
-
|
|
703
|
-
return total_loss, {
|
|
704
|
-
**partial_mse_terms,
|
|
705
|
-
"initial_condition": mse_initial_condition,
|
|
706
|
-
}
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
class SystemLossPDE(eqx.Module):
|
|
710
|
-
r"""
|
|
711
|
-
Class to implement a system of PDEs.
|
|
712
|
-
The goal is to give maximum freedom to the user. The class is created with
|
|
713
|
-
a dict of dynamic loss, and dictionaries of all the objects that are used
|
|
714
|
-
in LossPDENonStatio and LossPDEStatio. When then iterate
|
|
715
|
-
over the dynamic losses that compose the system. All the PINNs with all the
|
|
716
|
-
parameter dictionaries are passed as arguments to each dynamic loss
|
|
717
|
-
evaluate functions; it is inside the dynamic loss that specification are
|
|
718
|
-
performed.
|
|
719
|
-
|
|
720
|
-
**Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
|
|
721
|
-
Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
|
|
722
|
-
solution.
|
|
723
|
-
|
|
724
|
-
Parameters
|
|
725
|
-
----------
|
|
726
|
-
u_dict : Dict[str, eqx.Module]
|
|
727
|
-
dict of PINNs
|
|
728
|
-
loss_weights : LossWeightsPDEDict
|
|
729
|
-
A dictionary of LossWeightsODE
|
|
730
|
-
derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
|
|
731
|
-
A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
|
|
732
|
-
specifying what field of `params`
|
|
733
|
-
should be used during gradient computations for each of the terms of
|
|
734
|
-
the total loss, for each of the loss in the system. Default is
|
|
735
|
-
`"nn_params`" everywhere.
|
|
736
|
-
dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
|
|
737
|
-
A dict of dynamic part of the loss, basically the differential
|
|
738
|
-
operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
|
|
739
|
-
key_dict : Dict[str, Key], default=None
|
|
740
|
-
A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
|
|
741
|
-
match that of u_dict. See LossPDEStatio or LossPDENonStatio for
|
|
742
|
-
more details.
|
|
743
|
-
omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
|
|
744
|
-
A dict of of function or of dict of functions or of None
|
|
745
|
-
(see doc for `omega_boundary_fun` in
|
|
746
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
747
|
-
Must share the keys of `u_dict`.
|
|
748
|
-
omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
|
|
749
|
-
A dict of strings or of dict of strings or of None
|
|
750
|
-
(see doc for `omega_boundary_condition_dict` in
|
|
751
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
752
|
-
Must share the keys of `u_dict`
|
|
753
|
-
omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
|
|
754
|
-
A dict of slices or of dict of slices or of None
|
|
755
|
-
(see doc for `omega_boundary_dim` in
|
|
756
|
-
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
757
|
-
Must share the keys of `u_dict`
|
|
758
|
-
initial_condition_fun_dict : Dict[str, Callable | None], default=None
|
|
759
|
-
A dict of functions representing the temporal initial condition (None
|
|
760
|
-
value is possible). If None
|
|
761
|
-
(default) then no temporal boundary condition is applied
|
|
762
|
-
Must share the keys of `u_dict`
|
|
763
|
-
norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
|
|
764
|
-
A dict of Monte-Carlo sample points for computing the
|
|
765
|
-
normalization constant. Default is None.
|
|
766
|
-
Must share the keys of `u_dict`
|
|
767
|
-
norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
|
|
768
|
-
A dict of jnp.array with the same keys as `u_dict`. The importance
|
|
769
|
-
sampling weights for Monte-Carlo integration of the
|
|
770
|
-
normalization constant for each element of u_dict. Must be provided if
|
|
771
|
-
`norm_samples_dict` is provided.
|
|
772
|
-
`norm_weights_dict[key]` should have the same leading dimension as
|
|
773
|
-
`norm_samples_dict[key]` for each `key`.
|
|
774
|
-
Alternatively, the user can pass a float or an integer.
|
|
775
|
-
For each key, an array of similar shape to `norm_samples_dict[key]`
|
|
776
|
-
or shape `(1,)` is expected. These corresponds to the weights $w_k =
|
|
777
|
-
\frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
|
|
778
|
-
the Monte-Carlo samples. Default is None
|
|
779
|
-
obs_slice_dict : Dict[str, slice | None] | None, default=None
|
|
780
|
-
dict of obs_slice, with keys from `u_dict` to designate the
|
|
781
|
-
output(s) channels that are forced to observed values, for each
|
|
782
|
-
PINNs. Default is None. But if a value is given, all the entries of
|
|
783
|
-
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
784
|
-
if no particular slice is to be given
|
|
785
|
-
params : InitVar[ParamsDict], default=None
|
|
786
|
-
The main Params object of the problem needed to instanciate the
|
|
787
|
-
DerivativeKeysODE if the latter is not specified.
|
|
788
|
-
|
|
789
|
-
"""
|
|
790
|
-
|
|
791
|
-
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
792
|
-
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
793
|
-
u_dict: Dict[str, eqx.Module]
|
|
794
|
-
dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
|
|
795
|
-
key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
|
|
796
|
-
derivative_keys_dict: Dict[
|
|
797
|
-
str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
|
|
798
|
-
] = eqx.field(kw_only=True, default=None)
|
|
799
|
-
omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
|
|
800
|
-
eqx.field(kw_only=True, default=None, static=True)
|
|
801
|
-
)
|
|
802
|
-
omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
|
|
803
|
-
eqx.field(kw_only=True, default=None, static=True)
|
|
804
|
-
)
|
|
805
|
-
omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
|
|
806
|
-
eqx.field(kw_only=True, default=None, static=True)
|
|
807
|
-
)
|
|
808
|
-
initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
|
|
809
|
-
kw_only=True, default=None, static=True
|
|
810
|
-
)
|
|
811
|
-
norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
|
|
812
|
-
eqx.field(kw_only=True, default=None)
|
|
813
|
-
)
|
|
814
|
-
norm_weights_dict: (
|
|
815
|
-
Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
|
|
816
|
-
) = eqx.field(kw_only=True, default=None)
|
|
817
|
-
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
818
|
-
kw_only=True, default=None, static=True
|
|
819
|
-
)
|
|
820
|
-
|
|
821
|
-
# For the user loss_weights are passed as a LossWeightsPDEDict (with internal
|
|
822
|
-
# dictionary having keys in u_dict and / or dynamic_loss_dict)
|
|
823
|
-
loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
|
|
824
|
-
kw_only=True, default=None
|
|
825
|
-
)
|
|
826
|
-
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
827
|
-
|
|
828
|
-
# following have init=False and are set in the __post_init__
|
|
829
|
-
u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
|
|
830
|
-
init=False
|
|
831
|
-
)
|
|
832
|
-
derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
|
|
833
|
-
eqx.field(init=False)
|
|
834
|
-
)
|
|
835
|
-
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
836
|
-
# internally the loss weights are handled with a dictionary
|
|
837
|
-
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
838
|
-
|
|
839
|
-
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
840
|
-
# a dictionary that will be useful at different places
|
|
841
|
-
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
842
|
-
# First, for all the optional dict,
|
|
843
|
-
# if the user did not provide at all this optional argument,
|
|
844
|
-
# we make sure there is a null ponderating loss_weight and we
|
|
845
|
-
# create a dummy dict with the required keys and all the values to
|
|
846
|
-
# None
|
|
847
|
-
if self.key_dict is None:
|
|
848
|
-
self.key_dict = self.u_dict_with_none
|
|
849
|
-
if self.omega_boundary_fun_dict is None:
|
|
850
|
-
self.omega_boundary_fun_dict = self.u_dict_with_none
|
|
851
|
-
if self.omega_boundary_condition_dict is None:
|
|
852
|
-
self.omega_boundary_condition_dict = self.u_dict_with_none
|
|
853
|
-
if self.omega_boundary_dim_dict is None:
|
|
854
|
-
self.omega_boundary_dim_dict = self.u_dict_with_none
|
|
855
|
-
if self.initial_condition_fun_dict is None:
|
|
856
|
-
self.initial_condition_fun_dict = self.u_dict_with_none
|
|
857
|
-
if self.norm_samples_dict is None:
|
|
858
|
-
self.norm_samples_dict = self.u_dict_with_none
|
|
859
|
-
if self.norm_weights_dict is None:
|
|
860
|
-
self.norm_weights_dict = self.u_dict_with_none
|
|
861
|
-
if self.obs_slice_dict is None:
|
|
862
|
-
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
863
|
-
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
864
|
-
raise ValueError("obs_slice_dict should have same keys as u_dict")
|
|
865
|
-
if self.derivative_keys_dict is None:
|
|
866
|
-
self.derivative_keys_dict = {
|
|
867
|
-
k: None
|
|
868
|
-
for k in set(
|
|
869
|
-
list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
|
|
870
|
-
)
|
|
871
|
-
}
|
|
872
|
-
# set() because we can have duplicate entries and in this case we
|
|
873
|
-
# say it corresponds to the same derivative_keys_dict entry
|
|
874
|
-
# we need both because the constraints (all but dyn_loss) will be
|
|
875
|
-
# done by iterating on u_dict while the dyn_loss will be by
|
|
876
|
-
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
877
|
-
# derivative_keys_dict
|
|
878
|
-
|
|
879
|
-
# derivative keys for the u_constraints. Note that we create missing
|
|
880
|
-
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
881
|
-
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
882
|
-
for k in self.u_dict.keys():
|
|
883
|
-
if self.derivative_keys_dict[k] is None:
|
|
884
|
-
if self.u_dict[k].eq_type == "statio_PDE":
|
|
885
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
|
|
886
|
-
params=params_dict.extract_params(k)
|
|
887
|
-
)
|
|
888
|
-
else:
|
|
889
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
|
|
890
|
-
params=params_dict.extract_params(k)
|
|
891
|
-
)
|
|
892
|
-
|
|
893
|
-
# Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
|
|
894
|
-
if (
|
|
895
|
-
self.u_dict.keys() != self.key_dict.keys()
|
|
896
|
-
or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
|
|
897
|
-
or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
|
|
898
|
-
or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
|
|
899
|
-
or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
|
|
900
|
-
or self.u_dict.keys() != self.norm_samples_dict.keys()
|
|
901
|
-
or self.u_dict.keys() != self.norm_weights_dict.keys()
|
|
902
|
-
):
|
|
903
|
-
raise ValueError("All the dicts concerning the PINNs should have same keys")
|
|
904
|
-
|
|
905
|
-
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
906
|
-
|
|
907
|
-
# Third, in order not to benefit from LossPDEStatio and
|
|
908
|
-
# LossPDENonStatio and in order to factorize code, we create internally
|
|
909
|
-
# some losses object to implement the constraints on the solutions.
|
|
910
|
-
# We will not use the dynamic loss term
|
|
911
|
-
self.u_constraints_dict = {}
|
|
912
|
-
for i in self.u_dict.keys():
|
|
913
|
-
if self.u_dict[i].eq_type == "statio_PDE":
|
|
914
|
-
self.u_constraints_dict[i] = LossPDEStatio(
|
|
915
|
-
u=self.u_dict[i],
|
|
916
|
-
loss_weights=LossWeightsPDENonStatio(
|
|
917
|
-
dyn_loss=0.0,
|
|
918
|
-
norm_loss=1.0,
|
|
919
|
-
boundary_loss=1.0,
|
|
920
|
-
observations=1.0,
|
|
921
|
-
initial_condition=1.0,
|
|
922
|
-
),
|
|
923
|
-
dynamic_loss=None,
|
|
924
|
-
key=self.key_dict[i],
|
|
925
|
-
derivative_keys=self.derivative_keys_dict[i],
|
|
926
|
-
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
927
|
-
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
928
|
-
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
929
|
-
norm_samples=self.norm_samples_dict[i],
|
|
930
|
-
norm_weights=self.norm_weights_dict[i],
|
|
931
|
-
obs_slice=self.obs_slice_dict[i],
|
|
932
|
-
)
|
|
933
|
-
elif self.u_dict[i].eq_type == "nonstatio_PDE":
|
|
934
|
-
self.u_constraints_dict[i] = LossPDENonStatio(
|
|
935
|
-
u=self.u_dict[i],
|
|
936
|
-
loss_weights=LossWeightsPDENonStatio(
|
|
937
|
-
dyn_loss=0.0,
|
|
938
|
-
norm_loss=1.0,
|
|
939
|
-
boundary_loss=1.0,
|
|
940
|
-
observations=1.0,
|
|
941
|
-
initial_condition=1.0,
|
|
942
|
-
),
|
|
943
|
-
dynamic_loss=None,
|
|
944
|
-
key=self.key_dict[i],
|
|
945
|
-
derivative_keys=self.derivative_keys_dict[i],
|
|
946
|
-
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
947
|
-
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
948
|
-
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
949
|
-
initial_condition_fun=self.initial_condition_fun_dict[i],
|
|
950
|
-
norm_samples=self.norm_samples_dict[i],
|
|
951
|
-
norm_weights=self.norm_weights_dict[i],
|
|
952
|
-
obs_slice=self.obs_slice_dict[i],
|
|
953
|
-
)
|
|
954
|
-
else:
|
|
955
|
-
raise ValueError(
|
|
956
|
-
"Wrong value for self.u_dict[i].eq_type[i], "
|
|
957
|
-
f"got {self.u_dict[i].eq_type[i]}"
|
|
958
|
-
)
|
|
959
|
-
|
|
960
|
-
# derivative keys for the dynamic loss. Note that we create a
|
|
961
|
-
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
962
|
-
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
963
|
-
# happen inside it)
|
|
964
|
-
self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
|
|
965
|
-
|
|
966
|
-
# also make sure we only have PINNs or SPINNs
|
|
967
|
-
if not (
|
|
968
|
-
all(isinstance(value, PINN) for value in self.u_dict.values())
|
|
969
|
-
or all(isinstance(value, SPINN) for value in self.u_dict.values())
|
|
970
|
-
):
|
|
971
|
-
raise ValueError(
|
|
972
|
-
"We only accept dictionary of PINNs or dictionary of SPINNs"
|
|
797
|
+
mse_initial_condition, grad_initial_condition = self.get_gradients(
|
|
798
|
+
mse_initial_condition_fun, params
|
|
973
799
|
)
|
|
800
|
+
else:
|
|
801
|
+
mse_initial_condition = None
|
|
802
|
+
grad_initial_condition = None
|
|
803
|
+
|
|
804
|
+
mses = PDENonStatioComponents(
|
|
805
|
+
partial_mses.dyn_loss,
|
|
806
|
+
partial_mses.norm_loss,
|
|
807
|
+
partial_mses.boundary_loss,
|
|
808
|
+
partial_mses.observations,
|
|
809
|
+
mse_initial_condition,
|
|
810
|
+
)
|
|
974
811
|
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
system... So all the transformations are handled here
|
|
983
|
-
"""
|
|
984
|
-
_loss_weights = {}
|
|
985
|
-
for k in fields(loss_weights_init):
|
|
986
|
-
v = getattr(loss_weights_init, k.name)
|
|
987
|
-
if isinstance(v, dict):
|
|
988
|
-
for vv in v.keys():
|
|
989
|
-
if not isinstance(vv, (int, float)) and not (
|
|
990
|
-
isinstance(vv, Array)
|
|
991
|
-
and ((vv.shape == (1,) or len(vv.shape) == 0))
|
|
992
|
-
):
|
|
993
|
-
# TODO improve that
|
|
994
|
-
raise ValueError(
|
|
995
|
-
f"loss values cannot be vectorial here, got {vv}"
|
|
996
|
-
)
|
|
997
|
-
if k.name == "dyn_loss":
|
|
998
|
-
if v.keys() == self.dynamic_loss_dict.keys():
|
|
999
|
-
_loss_weights[k.name] = v
|
|
1000
|
-
else:
|
|
1001
|
-
raise ValueError(
|
|
1002
|
-
"Keys in nested dictionary of loss_weights"
|
|
1003
|
-
" do not match dynamic_loss_dict keys"
|
|
1004
|
-
)
|
|
1005
|
-
else:
|
|
1006
|
-
if v.keys() == self.u_dict.keys():
|
|
1007
|
-
_loss_weights[k.name] = v
|
|
1008
|
-
else:
|
|
1009
|
-
raise ValueError(
|
|
1010
|
-
"Keys in nested dictionary of loss_weights"
|
|
1011
|
-
" do not match u_dict keys"
|
|
1012
|
-
)
|
|
1013
|
-
if v is None:
|
|
1014
|
-
_loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
|
|
1015
|
-
else:
|
|
1016
|
-
if not isinstance(v, (int, float)) and not (
|
|
1017
|
-
isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
|
|
1018
|
-
):
|
|
1019
|
-
# TODO improve that
|
|
1020
|
-
raise ValueError(f"loss values cannot be vectorial here, got {v}")
|
|
1021
|
-
if k.name == "dyn_loss":
|
|
1022
|
-
_loss_weights[k.name] = {
|
|
1023
|
-
kk: v for kk in self.dynamic_loss_dict.keys()
|
|
1024
|
-
}
|
|
1025
|
-
else:
|
|
1026
|
-
_loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
|
|
1027
|
-
return _loss_weights
|
|
812
|
+
grads = PDENonStatioComponents(
|
|
813
|
+
partial_grads.dyn_loss,
|
|
814
|
+
partial_grads.norm_loss,
|
|
815
|
+
partial_grads.boundary_loss,
|
|
816
|
+
partial_grads.observations,
|
|
817
|
+
grad_initial_condition,
|
|
818
|
+
)
|
|
1028
819
|
|
|
1029
|
-
|
|
1030
|
-
return self.evaluate(*args, **kwargs)
|
|
820
|
+
return mses, grads
|
|
1031
821
|
|
|
1032
822
|
def evaluate(
|
|
1033
|
-
self,
|
|
1034
|
-
|
|
1035
|
-
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
1036
|
-
) -> tuple[Float[Array, "1"], dict[str, float]]:
|
|
823
|
+
self, params: Params[Array], batch: PDENonStatioBatch
|
|
824
|
+
) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
|
|
1037
825
|
"""
|
|
1038
826
|
Evaluate the loss function at a batch of points for given parameters.
|
|
827
|
+
We retrieve the total value itself and a PyTree with loss values for each term
|
|
1039
828
|
|
|
1040
829
|
|
|
1041
830
|
Parameters
|
|
1042
831
|
---------
|
|
1043
|
-
|
|
1044
|
-
Parameters at which the
|
|
832
|
+
params
|
|
833
|
+
Parameters at which the loss is evaluated
|
|
1045
834
|
batch
|
|
1046
|
-
|
|
1047
|
-
domain, a batch of points in the domain
|
|
1048
|
-
border,
|
|
1049
|
-
|
|
1050
|
-
and an optional additional batch of observed
|
|
835
|
+
Composed of a batch of points in
|
|
836
|
+
the domain, a batch of points in the domain
|
|
837
|
+
border, a batch of time points and an optional additional batch
|
|
838
|
+
of parameters (eg. for metamodeling) and an optional additional batch of observed
|
|
1051
839
|
inputs/outputs/parameters
|
|
1052
840
|
"""
|
|
1053
|
-
|
|
1054
|
-
raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
|
|
1055
|
-
|
|
1056
|
-
vmap_in_axes = (0,)
|
|
1057
|
-
|
|
1058
|
-
# Retrieve the optional eq_params_batch
|
|
1059
|
-
# and update eq_params with the latter
|
|
1060
|
-
# and update vmap_in_axes
|
|
1061
|
-
if batch.param_batch_dict is not None:
|
|
1062
|
-
eq_params_batch_dict = batch.param_batch_dict
|
|
1063
|
-
|
|
1064
|
-
# feed the eq_params with the batch
|
|
1065
|
-
for k in eq_params_batch_dict.keys():
|
|
1066
|
-
params_dict.eq_params[k] = eq_params_batch_dict[k]
|
|
1067
|
-
|
|
1068
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
1069
|
-
batch.param_batch_dict, params_dict
|
|
1070
|
-
)
|
|
1071
|
-
|
|
1072
|
-
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
1073
|
-
"""The function used in tree_map"""
|
|
1074
|
-
return dynamic_loss_apply(
|
|
1075
|
-
dyn_loss.evaluate,
|
|
1076
|
-
self.u_dict,
|
|
1077
|
-
(
|
|
1078
|
-
batch.domain_batch
|
|
1079
|
-
if isinstance(batch, PDEStatioBatch)
|
|
1080
|
-
else batch.domain_batch
|
|
1081
|
-
),
|
|
1082
|
-
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
1083
|
-
vmap_in_axes + vmap_in_axes_params,
|
|
1084
|
-
loss_weight,
|
|
1085
|
-
u_type=list(self.u_dict.values())[0].__class__.__base__,
|
|
1086
|
-
)
|
|
1087
|
-
|
|
1088
|
-
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
1089
|
-
dyn_loss_for_one_key,
|
|
1090
|
-
self.dynamic_loss_dict,
|
|
1091
|
-
self._loss_weights["dyn_loss"],
|
|
1092
|
-
is_leaf=lambda x: isinstance(
|
|
1093
|
-
x, (PDEStatio, PDENonStatio)
|
|
1094
|
-
), # before when dynamic losses
|
|
1095
|
-
# where plain (unregister pytree) node classes, we could not traverse
|
|
1096
|
-
# this level. Now that dynamic losses are eqx.Module they can be
|
|
1097
|
-
# traversed by tree map recursion. Hence we need to specify to that
|
|
1098
|
-
# we want to stop at this level
|
|
1099
|
-
)
|
|
1100
|
-
mse_dyn_loss = jax.tree_util.tree_reduce(
|
|
1101
|
-
lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
|
|
1102
|
-
)
|
|
1103
|
-
|
|
1104
|
-
# boundary conditions, normalization conditions, observation_loss,
|
|
1105
|
-
# initial condition... loss this is done via the internal
|
|
1106
|
-
# LossPDEStatio and NonStatio
|
|
1107
|
-
loss_weight_struct = {
|
|
1108
|
-
"dyn_loss": "*",
|
|
1109
|
-
"norm_loss": "*",
|
|
1110
|
-
"boundary_loss": "*",
|
|
1111
|
-
"observations": "*",
|
|
1112
|
-
"initial_condition": "*",
|
|
1113
|
-
}
|
|
1114
|
-
# we need to do the following for the tree_mapping to work
|
|
1115
|
-
if batch.obs_batch_dict is None:
|
|
1116
|
-
batch = append_obs_batch(batch, self.u_dict_with_none)
|
|
1117
|
-
total_loss, res_dict = constraints_system_loss_apply(
|
|
1118
|
-
self.u_constraints_dict,
|
|
1119
|
-
batch,
|
|
1120
|
-
params_dict,
|
|
1121
|
-
self._loss_weights,
|
|
1122
|
-
loss_weight_struct,
|
|
1123
|
-
)
|
|
1124
|
-
|
|
1125
|
-
# Add the mse_dyn_loss from the previous computations
|
|
1126
|
-
total_loss += mse_dyn_loss
|
|
1127
|
-
res_dict["dyn_loss"] += mse_dyn_loss
|
|
1128
|
-
return total_loss, res_dict
|
|
841
|
+
return super().evaluate(params, batch) # type: ignore
|