jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -1,24 +1,23 @@
|
|
|
1
|
-
# pylint: disable=unsubscriptable-object, no-member
|
|
2
1
|
"""
|
|
3
2
|
Main module to implement a ODE loss in jinns
|
|
4
3
|
"""
|
|
4
|
+
|
|
5
5
|
from __future__ import (
|
|
6
6
|
annotations,
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
|
-
from dataclasses import InitVar
|
|
10
|
-
from typing import TYPE_CHECKING,
|
|
9
|
+
from dataclasses import InitVar
|
|
10
|
+
from typing import TYPE_CHECKING, TypedDict, Callable
|
|
11
|
+
from types import EllipsisType
|
|
11
12
|
import abc
|
|
12
13
|
import warnings
|
|
13
14
|
import jax
|
|
14
15
|
import jax.numpy as jnp
|
|
15
16
|
from jax import vmap
|
|
16
17
|
import equinox as eqx
|
|
17
|
-
from jaxtyping import Float, Array
|
|
18
|
-
from jinns.data._DataGenerators import append_obs_batch
|
|
18
|
+
from jaxtyping import Float, Array
|
|
19
19
|
from jinns.loss._loss_utils import (
|
|
20
20
|
dynamic_loss_apply,
|
|
21
|
-
constraints_system_loss_apply,
|
|
22
21
|
observations_loss_apply,
|
|
23
22
|
)
|
|
24
23
|
from jinns.parameters._params import (
|
|
@@ -26,37 +25,49 @@ from jinns.parameters._params import (
|
|
|
26
25
|
_update_eq_params_dict,
|
|
27
26
|
)
|
|
28
27
|
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
29
|
-
from jinns.loss._loss_weights import LossWeightsODE
|
|
30
|
-
from jinns.loss.
|
|
31
|
-
from jinns.
|
|
28
|
+
from jinns.loss._loss_weights import LossWeightsODE
|
|
29
|
+
from jinns.loss._abstract_loss import AbstractLoss
|
|
30
|
+
from jinns.loss._loss_components import ODEComponents
|
|
31
|
+
from jinns.parameters._params import Params
|
|
32
32
|
|
|
33
33
|
if TYPE_CHECKING:
|
|
34
|
-
|
|
34
|
+
# imports only used in type hints
|
|
35
|
+
from jinns.data._Batchs import ODEBatch
|
|
36
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
37
|
+
from jinns.loss import ODE
|
|
38
|
+
|
|
39
|
+
class LossDictODE(TypedDict):
|
|
40
|
+
dyn_loss: Float[Array, " "]
|
|
41
|
+
initial_condition: Float[Array, " "]
|
|
42
|
+
observations: Float[Array, " "]
|
|
35
43
|
|
|
36
44
|
|
|
37
|
-
class _LossODEAbstract(
|
|
45
|
+
class _LossODEAbstract(AbstractLoss):
|
|
38
46
|
"""
|
|
39
47
|
Parameters
|
|
40
48
|
----------
|
|
41
49
|
|
|
42
50
|
loss_weights : LossWeightsODE, default=None
|
|
43
51
|
The loss weights for the differents term : dynamic loss,
|
|
44
|
-
initial condition and eventually observations if any.
|
|
45
|
-
|
|
52
|
+
initial condition and eventually observations if any.
|
|
53
|
+
Can be updated according to a specific algorithm. See
|
|
54
|
+
`update_weight_method`
|
|
55
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
56
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
46
57
|
derivative_keys : DerivativeKeysODE, default=None
|
|
47
58
|
Specify which field of `params` should be differentiated for each
|
|
48
59
|
composant of the total loss. Particularily useful for inverse problems.
|
|
49
60
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
50
61
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
51
62
|
is `"nn_params"` for each composant of the loss.
|
|
52
|
-
initial_condition : tuple, default=None
|
|
63
|
+
initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
|
|
53
64
|
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
54
|
-
obs_slice :
|
|
65
|
+
obs_slice : EllipsisType | slice | None, default=None
|
|
55
66
|
Slice object specifying the begininning/ending
|
|
56
67
|
slice of u output(s) that is observed. This is useful for
|
|
57
68
|
multidimensional PINN, with partially observed outputs.
|
|
58
69
|
Default is None (whole output is observed).
|
|
59
|
-
params : InitVar[Params], default=None
|
|
70
|
+
params : InitVar[Params[Array]], default=None
|
|
60
71
|
The main Params object of the problem needed to instanciate the
|
|
61
72
|
DerivativeKeysODE if the latter is not specified.
|
|
62
73
|
"""
|
|
@@ -66,24 +77,27 @@ class _LossODEAbstract(eqx.Module):
|
|
|
66
77
|
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
67
78
|
derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
|
|
68
79
|
loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
|
|
69
|
-
initial_condition:
|
|
70
|
-
|
|
80
|
+
initial_condition: (
|
|
81
|
+
tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
|
|
82
|
+
) = eqx.field(kw_only=True, default=None)
|
|
83
|
+
obs_slice: EllipsisType | slice | None = eqx.field(
|
|
84
|
+
kw_only=True, default=None, static=True
|
|
85
|
+
)
|
|
71
86
|
|
|
72
|
-
params: InitVar[Params] = eqx.field(default=None, kw_only=True)
|
|
87
|
+
params: InitVar[Params[Array]] = eqx.field(default=None, kw_only=True)
|
|
73
88
|
|
|
74
|
-
def __post_init__(self, params=None):
|
|
89
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
75
90
|
if self.loss_weights is None:
|
|
76
91
|
self.loss_weights = LossWeightsODE()
|
|
77
92
|
|
|
78
93
|
if self.derivative_keys is None:
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
82
|
-
except ValueError as exc:
|
|
94
|
+
# by default we only take gradient wrt nn_params
|
|
95
|
+
if params is None:
|
|
83
96
|
raise ValueError(
|
|
84
97
|
"Problem at self.derivative_keys initialization "
|
|
85
98
|
f"received {self.derivative_keys=} and {params=}"
|
|
86
|
-
)
|
|
99
|
+
)
|
|
100
|
+
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
87
101
|
if self.initial_condition is None:
|
|
88
102
|
warnings.warn(
|
|
89
103
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
@@ -98,6 +112,21 @@ class _LossODEAbstract(eqx.Module):
|
|
|
98
112
|
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
99
113
|
f"{self.initial_condition} was passed."
|
|
100
114
|
)
|
|
115
|
+
# some checks/reshaping for t0
|
|
116
|
+
t0, u0 = self.initial_condition
|
|
117
|
+
if isinstance(t0, Array):
|
|
118
|
+
if not t0.shape: # e.g. user input: jnp.array(0.)
|
|
119
|
+
t0 = jnp.array([t0])
|
|
120
|
+
elif t0.shape != (1,):
|
|
121
|
+
raise ValueError(
|
|
122
|
+
f"Wrong t0 input (self.initial_condition[0]) It should be"
|
|
123
|
+
f"a float or an array of shape (1,). Got shape: {t0.shape}"
|
|
124
|
+
)
|
|
125
|
+
if isinstance(t0, float): # e.g. user input: 0.
|
|
126
|
+
t0 = jnp.array([t0])
|
|
127
|
+
if isinstance(t0, int): # e.g. user input: 0
|
|
128
|
+
t0 = jnp.array([float(t0)])
|
|
129
|
+
self.initial_condition = (t0, u0)
|
|
101
130
|
|
|
102
131
|
if self.obs_slice is None:
|
|
103
132
|
self.obs_slice = jnp.s_[...]
|
|
@@ -105,10 +134,14 @@ class _LossODEAbstract(eqx.Module):
|
|
|
105
134
|
if self.loss_weights is None:
|
|
106
135
|
self.loss_weights = LossWeightsODE()
|
|
107
136
|
|
|
137
|
+
@abc.abstractmethod
|
|
138
|
+
def __call__(self, *_, **__):
|
|
139
|
+
pass
|
|
140
|
+
|
|
108
141
|
@abc.abstractmethod
|
|
109
142
|
def evaluate(
|
|
110
|
-
self: eqx.Module, params: Params, batch: ODEBatch
|
|
111
|
-
) -> tuple[Float,
|
|
143
|
+
self: eqx.Module, params: Params[Array], batch: ODEBatch
|
|
144
|
+
) -> tuple[Float[Array, " "], LossDictODE]:
|
|
112
145
|
raise NotImplementedError
|
|
113
146
|
|
|
114
147
|
|
|
@@ -127,27 +160,30 @@ class LossODE(_LossODEAbstract):
|
|
|
127
160
|
----------
|
|
128
161
|
loss_weights : LossWeightsODE, default=None
|
|
129
162
|
The loss weights for the differents term : dynamic loss,
|
|
130
|
-
initial condition and eventually observations if any.
|
|
131
|
-
|
|
163
|
+
initial condition and eventually observations if any.
|
|
164
|
+
Can be updated according to a specific algorithm. See
|
|
165
|
+
`update_weight_method`
|
|
166
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
167
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
132
168
|
derivative_keys : DerivativeKeysODE, default=None
|
|
133
169
|
Specify which field of `params` should be differentiated for each
|
|
134
170
|
composant of the total loss. Particularily useful for inverse problems.
|
|
135
171
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
136
172
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
137
173
|
is `"nn_params"` for each composant of the loss.
|
|
138
|
-
initial_condition : tuple, default=None
|
|
174
|
+
initial_condition : tuple[float | Float[Array, " 1"]], default=None
|
|
139
175
|
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
140
|
-
obs_slice
|
|
176
|
+
obs_slice : EllipsisType | slice | None, default=None
|
|
141
177
|
Slice object specifying the begininning/ending
|
|
142
178
|
slice of u output(s) that is observed. This is useful for
|
|
143
179
|
multidimensional PINN, with partially observed outputs.
|
|
144
180
|
Default is None (whole output is observed).
|
|
145
|
-
params : InitVar[Params], default=None
|
|
181
|
+
params : InitVar[Params[Array]], default=None
|
|
146
182
|
The main Params object of the problem needed to instanciate the
|
|
147
183
|
DerivativeKeysODE if the latter is not specified.
|
|
148
184
|
u : eqx.Module
|
|
149
185
|
the PINN
|
|
150
|
-
dynamic_loss :
|
|
186
|
+
dynamic_loss : ODE
|
|
151
187
|
the ODE dynamic part of the loss, basically the differential
|
|
152
188
|
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
153
189
|
`dynamic_loss.evaluate(t, u, params)`.
|
|
@@ -161,12 +197,12 @@ class LossODE(_LossODEAbstract):
|
|
|
161
197
|
|
|
162
198
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
163
199
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
164
|
-
u:
|
|
165
|
-
dynamic_loss:
|
|
200
|
+
u: AbstractPINN
|
|
201
|
+
dynamic_loss: ODE | None
|
|
166
202
|
|
|
167
|
-
vmap_in_axes: tuple[
|
|
203
|
+
vmap_in_axes: tuple[int] = eqx.field(init=False, static=True)
|
|
168
204
|
|
|
169
|
-
def __post_init__(self, params=None):
|
|
205
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
170
206
|
super().__post_init__(
|
|
171
207
|
params=params
|
|
172
208
|
) # because __init__ or __post_init__ of Base
|
|
@@ -177,21 +213,26 @@ class LossODE(_LossODEAbstract):
|
|
|
177
213
|
def __call__(self, *args, **kwargs):
|
|
178
214
|
return self.evaluate(*args, **kwargs)
|
|
179
215
|
|
|
180
|
-
def
|
|
181
|
-
self, params: Params, batch: ODEBatch
|
|
182
|
-
) -> tuple[
|
|
216
|
+
def evaluate_by_terms(
|
|
217
|
+
self, params: Params[Array], batch: ODEBatch
|
|
218
|
+
) -> tuple[
|
|
219
|
+
ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
|
|
220
|
+
]:
|
|
183
221
|
"""
|
|
184
222
|
Evaluate the loss function at a batch of points for given parameters.
|
|
185
223
|
|
|
224
|
+
We retrieve two PyTrees with loss values and gradients for each term
|
|
186
225
|
|
|
187
226
|
Parameters
|
|
188
227
|
---------
|
|
189
228
|
params
|
|
190
229
|
Parameters at which the loss is evaluated
|
|
191
230
|
batch
|
|
192
|
-
Composed of a batch of
|
|
193
|
-
|
|
194
|
-
|
|
231
|
+
Composed of a batch of points in the
|
|
232
|
+
domain, a batch of points in the domain
|
|
233
|
+
border and an optional additional batch of parameters (eg. for
|
|
234
|
+
metamodeling) and an optional additional batch of observed
|
|
235
|
+
inputs/outputs/parameters
|
|
195
236
|
"""
|
|
196
237
|
temporal_batch = batch.temporal_batch
|
|
197
238
|
|
|
@@ -206,16 +247,15 @@ class LossODE(_LossODEAbstract):
|
|
|
206
247
|
|
|
207
248
|
## dynamic part
|
|
208
249
|
if self.dynamic_loss is not None:
|
|
209
|
-
|
|
210
|
-
self.dynamic_loss.evaluate,
|
|
250
|
+
dyn_loss_fun = lambda p: dynamic_loss_apply(
|
|
251
|
+
self.dynamic_loss.evaluate, # type: ignore
|
|
211
252
|
self.u,
|
|
212
253
|
temporal_batch,
|
|
213
|
-
_set_derivatives(
|
|
254
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
|
|
214
255
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
215
|
-
self.loss_weights.dyn_loss,
|
|
216
256
|
)
|
|
217
257
|
else:
|
|
218
|
-
|
|
258
|
+
dyn_loss_fun = None
|
|
219
259
|
|
|
220
260
|
# initial condition
|
|
221
261
|
if self.initial_condition is not None:
|
|
@@ -226,18 +266,14 @@ class LossODE(_LossODEAbstract):
|
|
|
226
266
|
v_u = self.u
|
|
227
267
|
else:
|
|
228
268
|
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
229
|
-
t0, u0 = self.initial_condition
|
|
230
|
-
t0 = jnp.array([t0])
|
|
269
|
+
t0, u0 = self.initial_condition
|
|
231
270
|
u0 = jnp.array(u0)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
* jnp.sum(
|
|
271
|
+
initial_condition_fun = lambda p: jnp.mean(
|
|
272
|
+
jnp.sum(
|
|
235
273
|
(
|
|
236
274
|
v_u(
|
|
237
275
|
t0,
|
|
238
|
-
_set_derivatives(
|
|
239
|
-
params, self.derivative_keys.initial_condition
|
|
240
|
-
),
|
|
276
|
+
_set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
|
|
241
277
|
)
|
|
242
278
|
- u0
|
|
243
279
|
)
|
|
@@ -246,323 +282,71 @@ class LossODE(_LossODEAbstract):
|
|
|
246
282
|
)
|
|
247
283
|
)
|
|
248
284
|
else:
|
|
249
|
-
|
|
285
|
+
initial_condition_fun = None
|
|
250
286
|
|
|
251
287
|
if batch.obs_batch_dict is not None:
|
|
252
288
|
# update params with the batches of observed params
|
|
253
|
-
|
|
289
|
+
params_obs = _update_eq_params_dict(
|
|
290
|
+
params, batch.obs_batch_dict["eq_params"]
|
|
291
|
+
)
|
|
254
292
|
|
|
255
293
|
# MSE loss wrt to an observed batch
|
|
256
|
-
|
|
294
|
+
obs_loss_fun = lambda po: observations_loss_apply(
|
|
257
295
|
self.u,
|
|
258
|
-
|
|
259
|
-
_set_derivatives(
|
|
296
|
+
batch.obs_batch_dict["pinn_in"],
|
|
297
|
+
_set_derivatives(po, self.derivative_keys.observations), # type: ignore
|
|
260
298
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
261
299
|
batch.obs_batch_dict["val"],
|
|
262
|
-
self.loss_weights.observations,
|
|
263
300
|
self.obs_slice,
|
|
264
301
|
)
|
|
265
302
|
else:
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
# total loss
|
|
269
|
-
total_loss = mse_dyn_loss + mse_initial_condition + mse_observation_loss
|
|
270
|
-
return total_loss, (
|
|
271
|
-
{
|
|
272
|
-
"dyn_loss": mse_dyn_loss,
|
|
273
|
-
"initial_condition": mse_initial_condition,
|
|
274
|
-
"observations": mse_observation_loss,
|
|
275
|
-
}
|
|
276
|
-
)
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
class SystemLossODE(eqx.Module):
|
|
280
|
-
r"""
|
|
281
|
-
Class to implement a system of ODEs.
|
|
282
|
-
The goal is to give maximum freedom to the user. The class is created with
|
|
283
|
-
a dict of dynamic loss and a dict of initial conditions. Then, it iterates
|
|
284
|
-
over the dynamic losses that compose the system. All PINNs are passed as
|
|
285
|
-
arguments to each dynamic loss evaluate functions, along with all the
|
|
286
|
-
parameter dictionaries. All specification is left to the responsability
|
|
287
|
-
of the user, inside the dynamic loss.
|
|
288
|
-
|
|
289
|
-
**Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
|
|
290
|
-
Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
|
|
291
|
-
solution.
|
|
303
|
+
params_obs = None
|
|
304
|
+
obs_loss_fun = None
|
|
292
305
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
Must share the keys of `u_dict`. Default is None. No initial
|
|
307
|
-
condition is permitted when the initial condition is hardcoded in
|
|
308
|
-
the PINN architecture for example
|
|
309
|
-
dynamic_loss_dict : Dict[str, ODE]
|
|
310
|
-
dict of dynamic part of the loss, basically the differential
|
|
311
|
-
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
312
|
-
`dynamic_loss.evaluate(t, u, params)`
|
|
313
|
-
obs_slice_dict : Dict[str, Slice]
|
|
314
|
-
dict of obs_slice, with keys from `u_dict` to designate the
|
|
315
|
-
output(s) channels that are observed, for each
|
|
316
|
-
PINNs. Default is None. But if a value is given, all the entries of
|
|
317
|
-
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
318
|
-
if no particular slice is to be given.
|
|
319
|
-
params_dict : InitVar[ParamsDict], default=None
|
|
320
|
-
The main Params object of the problem needed to instanciate the
|
|
321
|
-
DerivativeKeysODE if the latter is not specified.
|
|
322
|
-
|
|
323
|
-
Raises
|
|
324
|
-
------
|
|
325
|
-
ValueError
|
|
326
|
-
if initial condition is not a dict of tuple.
|
|
327
|
-
ValueError
|
|
328
|
-
if the dictionaries that should share the keys of u_dict do not.
|
|
329
|
-
"""
|
|
330
|
-
|
|
331
|
-
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
332
|
-
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
333
|
-
u_dict: Dict[str, eqx.Module]
|
|
334
|
-
dynamic_loss_dict: Dict[str, ODE]
|
|
335
|
-
derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
|
|
336
|
-
kw_only=True, default=None
|
|
337
|
-
)
|
|
338
|
-
initial_condition_dict: Dict[str, tuple] | None = eqx.field(
|
|
339
|
-
kw_only=True, default=None
|
|
340
|
-
)
|
|
341
|
-
|
|
342
|
-
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
343
|
-
kw_only=True, default=None, static=True
|
|
344
|
-
) # We are at an "leaf" attribute here (slice, not valid JAX type). Since
|
|
345
|
-
# we do not expect it to change with put a static=True here. But note that
|
|
346
|
-
# this is the only static for all the SystemLossODE attribute, since all
|
|
347
|
-
# other are composed of more complex structures ("non-leaf")
|
|
348
|
-
|
|
349
|
-
# For the user loss_weights are passed as a LossWeightsODEDict (with internal
|
|
350
|
-
# dictionary having keys in u_dict and / or dynamic_loss_dict)
|
|
351
|
-
loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
|
|
352
|
-
kw_only=True, default=None
|
|
353
|
-
)
|
|
354
|
-
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
355
|
-
|
|
356
|
-
u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
|
|
357
|
-
derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
|
|
358
|
-
|
|
359
|
-
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
360
|
-
# internally the loss weights are handled with a dictionary
|
|
361
|
-
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
362
|
-
|
|
363
|
-
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
364
|
-
# a dictionary that will be useful at different places
|
|
365
|
-
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
366
|
-
if self.initial_condition_dict is None:
|
|
367
|
-
self.initial_condition_dict = self.u_dict_with_none
|
|
368
|
-
else:
|
|
369
|
-
if self.u_dict.keys() != self.initial_condition_dict.keys():
|
|
370
|
-
raise ValueError(
|
|
371
|
-
"initial_condition_dict should have same keys as u_dict"
|
|
372
|
-
)
|
|
373
|
-
if self.obs_slice_dict is None:
|
|
374
|
-
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
375
|
-
else:
|
|
376
|
-
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
377
|
-
raise ValueError("obs_slice_dict should have same keys as u_dict")
|
|
378
|
-
|
|
379
|
-
if self.derivative_keys_dict is None:
|
|
380
|
-
self.derivative_keys_dict = {
|
|
381
|
-
k: None
|
|
382
|
-
for k in set(
|
|
383
|
-
list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
|
|
384
|
-
)
|
|
385
|
-
}
|
|
386
|
-
# set() because we can have duplicate entries and in this case we
|
|
387
|
-
# say it corresponds to the same derivative_keys_dict entry
|
|
388
|
-
# we need both because the constraints (all but dyn_loss) will be
|
|
389
|
-
# done by iterating on u_dict while the dyn_loss will be by
|
|
390
|
-
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
391
|
-
# derivative_keys_dict
|
|
392
|
-
|
|
393
|
-
# derivative keys for the u_constraints. Note that we create missing
|
|
394
|
-
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
395
|
-
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
396
|
-
for k in self.u_dict.keys():
|
|
397
|
-
if self.derivative_keys_dict[k] is None:
|
|
398
|
-
self.derivative_keys_dict[k] = DerivativeKeysODE(
|
|
399
|
-
params=params_dict.extract_params(k)
|
|
400
|
-
)
|
|
401
|
-
|
|
402
|
-
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
403
|
-
|
|
404
|
-
# The constaints on the solutions will be implemented by reusing a
|
|
405
|
-
# LossODE class without dynamic loss term
|
|
406
|
-
self.u_constraints_dict = {}
|
|
407
|
-
for i in self.u_dict.keys():
|
|
408
|
-
self.u_constraints_dict[i] = LossODE(
|
|
409
|
-
u=self.u_dict[i],
|
|
410
|
-
loss_weights=LossWeightsODE(
|
|
411
|
-
dyn_loss=0.0,
|
|
412
|
-
initial_condition=1.0,
|
|
413
|
-
observations=1.0,
|
|
414
|
-
),
|
|
415
|
-
dynamic_loss=None,
|
|
416
|
-
derivative_keys=self.derivative_keys_dict[i],
|
|
417
|
-
initial_condition=self.initial_condition_dict[i],
|
|
418
|
-
obs_slice=self.obs_slice_dict[i],
|
|
419
|
-
)
|
|
420
|
-
|
|
421
|
-
# derivative keys for the dynamic loss. Note that we create a
|
|
422
|
-
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
423
|
-
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
424
|
-
# happen inside it)
|
|
425
|
-
self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
|
|
306
|
+
# get the unweighted mses for each loss term as well as the gradients
|
|
307
|
+
all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
|
|
308
|
+
ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
|
|
309
|
+
)
|
|
310
|
+
all_params: ODEComponents[Params[Array] | None] = ODEComponents(
|
|
311
|
+
params, params, params_obs
|
|
312
|
+
)
|
|
313
|
+
mses_grads = jax.tree.map(
|
|
314
|
+
lambda fun, params: self.get_gradients(fun, params),
|
|
315
|
+
all_funs,
|
|
316
|
+
all_params,
|
|
317
|
+
is_leaf=lambda x: x is None,
|
|
318
|
+
)
|
|
426
319
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
"""
|
|
434
|
-
_loss_weights = {}
|
|
435
|
-
for k in fields(loss_weights_init):
|
|
436
|
-
v = getattr(loss_weights_init, k.name)
|
|
437
|
-
if isinstance(v, dict):
|
|
438
|
-
for vv in v.values():
|
|
439
|
-
if not isinstance(vv, (int, float)) and not (
|
|
440
|
-
isinstance(vv, Array)
|
|
441
|
-
and ((vv.shape == (1,) or len(vv.shape) == 0))
|
|
442
|
-
):
|
|
443
|
-
# TODO improve that
|
|
444
|
-
raise ValueError(
|
|
445
|
-
f"loss values cannot be vectorial here, got {vv}"
|
|
446
|
-
)
|
|
447
|
-
if k.name == "dyn_loss":
|
|
448
|
-
if v.keys() == self.dynamic_loss_dict.keys():
|
|
449
|
-
_loss_weights[k.name] = v
|
|
450
|
-
else:
|
|
451
|
-
raise ValueError(
|
|
452
|
-
"Keys in nested dictionary of loss_weights"
|
|
453
|
-
" do not match dynamic_loss_dict keys"
|
|
454
|
-
)
|
|
455
|
-
else:
|
|
456
|
-
if v.keys() == self.u_dict.keys():
|
|
457
|
-
_loss_weights[k.name] = v
|
|
458
|
-
else:
|
|
459
|
-
raise ValueError(
|
|
460
|
-
"Keys in nested dictionary of loss_weights"
|
|
461
|
-
" do not match u_dict keys"
|
|
462
|
-
)
|
|
463
|
-
elif v is None:
|
|
464
|
-
_loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
|
|
465
|
-
else:
|
|
466
|
-
if not isinstance(v, (int, float)) and not (
|
|
467
|
-
isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
|
|
468
|
-
):
|
|
469
|
-
# TODO improve that
|
|
470
|
-
raise ValueError(f"loss values cannot be vectorial here, got {v}")
|
|
471
|
-
if k.name == "dyn_loss":
|
|
472
|
-
_loss_weights[k.name] = {
|
|
473
|
-
kk: v for kk in self.dynamic_loss_dict.keys()
|
|
474
|
-
}
|
|
475
|
-
else:
|
|
476
|
-
_loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
|
|
477
|
-
|
|
478
|
-
return _loss_weights
|
|
320
|
+
mses = jax.tree.map(
|
|
321
|
+
lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
322
|
+
)
|
|
323
|
+
grads = jax.tree.map(
|
|
324
|
+
lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
325
|
+
)
|
|
479
326
|
|
|
480
|
-
|
|
481
|
-
return self.evaluate(*args, **kwargs)
|
|
327
|
+
return mses, grads
|
|
482
328
|
|
|
483
|
-
def evaluate(
|
|
329
|
+
def evaluate(
|
|
330
|
+
self, params: Params[Array], batch: ODEBatch
|
|
331
|
+
) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
|
|
484
332
|
"""
|
|
485
333
|
Evaluate the loss function at a batch of points for given parameters.
|
|
486
334
|
|
|
335
|
+
We retrieve the total value itself and a PyTree with loss values for each term
|
|
487
336
|
|
|
488
337
|
Parameters
|
|
489
338
|
---------
|
|
490
339
|
params
|
|
491
|
-
|
|
340
|
+
Parameters at which the loss is evaluated
|
|
492
341
|
batch
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
342
|
+
Composed of a batch of points in the
|
|
343
|
+
domain, a batch of points in the domain
|
|
344
|
+
border and an optional additional batch of parameters (eg. for
|
|
496
345
|
metamodeling) and an optional additional batch of observed
|
|
497
346
|
inputs/outputs/parameters
|
|
498
347
|
"""
|
|
499
|
-
|
|
500
|
-
isinstance(params_dict.nn_params, dict)
|
|
501
|
-
and self.u_dict.keys() != params_dict.nn_params.keys()
|
|
502
|
-
):
|
|
503
|
-
raise ValueError("u_dict and params_dict.nn_params should have same keys ")
|
|
504
|
-
|
|
505
|
-
temporal_batch = batch.temporal_batch
|
|
506
|
-
|
|
507
|
-
vmap_in_axes_t = (0,)
|
|
508
|
-
|
|
509
|
-
# Retrieve the optional eq_params_batch
|
|
510
|
-
# and update eq_params with the latter
|
|
511
|
-
# and update vmap_in_axes
|
|
512
|
-
if batch.param_batch_dict is not None:
|
|
513
|
-
# update params with the batches of generated params
|
|
514
|
-
params = _update_eq_params_dict(params, batch.param_batch_dict)
|
|
515
|
-
|
|
516
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
517
|
-
batch.param_batch_dict, params_dict
|
|
518
|
-
)
|
|
519
|
-
|
|
520
|
-
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
521
|
-
"""This function is used in tree_map"""
|
|
522
|
-
return dynamic_loss_apply(
|
|
523
|
-
dyn_loss.evaluate,
|
|
524
|
-
self.u_dict,
|
|
525
|
-
temporal_batch,
|
|
526
|
-
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
527
|
-
vmap_in_axes_t + vmap_in_axes_params,
|
|
528
|
-
loss_weight,
|
|
529
|
-
u_type=PINN,
|
|
530
|
-
)
|
|
348
|
+
loss_terms, _ = self.evaluate_by_terms(params, batch)
|
|
531
349
|
|
|
532
|
-
|
|
533
|
-
dyn_loss_for_one_key,
|
|
534
|
-
self.dynamic_loss_dict,
|
|
535
|
-
self._loss_weights["dyn_loss"],
|
|
536
|
-
is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
|
|
537
|
-
# where plain (unregister pytree) node classes, we could not traverse
|
|
538
|
-
# this level. Now that dynamic losses are eqx.Module they can be
|
|
539
|
-
# traversed by tree map recursion. Hence we need to specify to that
|
|
540
|
-
# we want to stop at this level
|
|
541
|
-
)
|
|
542
|
-
mse_dyn_loss = jax.tree_util.tree_reduce(
|
|
543
|
-
lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
|
|
544
|
-
)
|
|
545
|
-
|
|
546
|
-
# initial conditions and observation_loss via the internal LossODE
|
|
547
|
-
loss_weight_struct = {
|
|
548
|
-
"dyn_loss": "*",
|
|
549
|
-
"observations": "*",
|
|
550
|
-
"initial_condition": "*",
|
|
551
|
-
}
|
|
552
|
-
|
|
553
|
-
# we need to do the following for the tree_mapping to work
|
|
554
|
-
if batch.obs_batch_dict is None:
|
|
555
|
-
batch = append_obs_batch(batch, self.u_dict_with_none)
|
|
556
|
-
|
|
557
|
-
total_loss, res_dict = constraints_system_loss_apply(
|
|
558
|
-
self.u_constraints_dict,
|
|
559
|
-
batch,
|
|
560
|
-
params_dict,
|
|
561
|
-
self._loss_weights,
|
|
562
|
-
loss_weight_struct,
|
|
563
|
-
)
|
|
350
|
+
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
564
351
|
|
|
565
|
-
|
|
566
|
-
total_loss += mse_dyn_loss
|
|
567
|
-
res_dict["dyn_loss"] += mse_dyn_loss
|
|
568
|
-
return total_loss, res_dict
|
|
352
|
+
return loss_val, loss_terms
|