jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -1,143 +1,185 @@
|
|
|
1
|
+
# pylint: disable=unsubscriptable-object, no-member
|
|
1
2
|
"""
|
|
2
3
|
Main module to implement a ODE loss in jinns
|
|
3
4
|
"""
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
4
8
|
|
|
9
|
+
from dataclasses import InitVar, fields
|
|
10
|
+
from typing import TYPE_CHECKING, Dict
|
|
11
|
+
import abc
|
|
5
12
|
import warnings
|
|
6
13
|
import jax
|
|
7
14
|
import jax.numpy as jnp
|
|
8
15
|
from jax import vmap
|
|
9
|
-
|
|
10
|
-
from
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
_update_eq_params_dict,
|
|
14
|
-
)
|
|
15
|
-
from jinns.loss._Losses import (
|
|
16
|
+
import equinox as eqx
|
|
17
|
+
from jaxtyping import Float, Array, Int
|
|
18
|
+
from jinns.data._DataGenerators import append_obs_batch
|
|
19
|
+
from jinns.loss._loss_utils import (
|
|
16
20
|
dynamic_loss_apply,
|
|
17
21
|
constraints_system_loss_apply,
|
|
18
22
|
observations_loss_apply,
|
|
19
23
|
)
|
|
24
|
+
from jinns.parameters._params import (
|
|
25
|
+
_get_vmap_in_axes_params,
|
|
26
|
+
_update_eq_params_dict,
|
|
27
|
+
)
|
|
28
|
+
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
29
|
+
from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
|
|
30
|
+
from jinns.loss._DynamicLossAbstract import ODE
|
|
20
31
|
from jinns.utils._pinn import PINN
|
|
21
32
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@register_pytree_node_class
|
|
26
|
-
class LossODE:
|
|
27
|
-
r"""Loss object for an ordinary differential equation
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from jinns.utils._types import *
|
|
28
35
|
|
|
29
|
-
.. math::
|
|
30
|
-
\mathcal{N}[u](t) = 0, \forall t \in I
|
|
31
36
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
37
|
+
class _LossODEAbstract(eqx.Module):
|
|
38
|
+
"""
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
|
|
42
|
+
loss_weights : LossWeightsODE, default=None
|
|
43
|
+
The loss weights for the differents term : dynamic loss,
|
|
44
|
+
initial condition and eventually observations if any. All fields are
|
|
45
|
+
set to 1.0 by default.
|
|
46
|
+
derivative_keys : DerivativeKeysODE, default=None
|
|
47
|
+
Specify which field of `params` should be differentiated for each
|
|
48
|
+
composant of the total loss. Particularily useful for inverse problems.
|
|
49
|
+
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
50
|
+
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
51
|
+
is `"nn_params"` for each composant of the loss.
|
|
52
|
+
initial_condition : tuple, default=None
|
|
53
|
+
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
54
|
+
obs_slice : Slice, default=None
|
|
55
|
+
Slice object specifying the begininning/ending
|
|
56
|
+
slice of u output(s) that is observed. This is useful for
|
|
57
|
+
multidimensional PINN, with partially observed outputs.
|
|
58
|
+
Default is None (whole output is observed).
|
|
59
|
+
params : InitVar[Params], default=None
|
|
60
|
+
The main Params object of the problem needed to instanciate the
|
|
61
|
+
DerivativeKeysODE if the latter is not specified.
|
|
38
62
|
"""
|
|
39
63
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
access only some part of the evaluate call results.
|
|
65
|
-
derivative_keys
|
|
66
|
-
A dict of lists of strings. In the dict, the key must correspond to
|
|
67
|
-
the loss term keywords. Then each of the values must correspond to keys in the parameter
|
|
68
|
-
dictionary (*at top level only of the parameter dictionary*).
|
|
69
|
-
It enables selecting the set of parameters
|
|
70
|
-
with respect to which the gradients of the dynamic
|
|
71
|
-
loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
|
|
72
|
-
keywords, this is what is typically
|
|
73
|
-
done in solving forward problems, when we only estimate the
|
|
74
|
-
equation solution with a PINN. If some loss terms keywords are
|
|
75
|
-
missing we set their value to ["nn_params"] by default for the same
|
|
76
|
-
reason
|
|
77
|
-
initial_condition :
|
|
78
|
-
tuple of length 2 with initial condition :math:`(t0, u0)`.
|
|
79
|
-
Can be None in order to
|
|
80
|
-
access only some part of the evaluate call results.
|
|
81
|
-
obs_slice:
|
|
82
|
-
slice object specifying the begininning/ending
|
|
83
|
-
slice of u output(s) that is observed (this is then useful for
|
|
84
|
-
multidim PINN). Default is None.
|
|
85
|
-
|
|
86
|
-
Raises
|
|
87
|
-
------
|
|
88
|
-
ValueError
|
|
89
|
-
if initial condition is not a tuple.
|
|
90
|
-
"""
|
|
91
|
-
self.dynamic_loss = dynamic_loss
|
|
92
|
-
self.u = u
|
|
93
|
-
if derivative_keys is None:
|
|
94
|
-
# be default we only take gradient wrt nn_params
|
|
95
|
-
derivative_keys = {
|
|
96
|
-
k: ["nn_params"]
|
|
97
|
-
for k in [
|
|
98
|
-
"dyn_loss",
|
|
99
|
-
"initial_condition",
|
|
100
|
-
"observations",
|
|
101
|
-
]
|
|
102
|
-
}
|
|
103
|
-
if isinstance(derivative_keys, list):
|
|
104
|
-
# if the user only provided a list, this defines the gradient taken
|
|
105
|
-
# for all the loss entries
|
|
106
|
-
derivative_keys = {
|
|
107
|
-
k: derivative_keys
|
|
108
|
-
for k in [
|
|
109
|
-
"dyn_loss",
|
|
110
|
-
"initial_condition",
|
|
111
|
-
"observations",
|
|
112
|
-
]
|
|
113
|
-
}
|
|
114
|
-
|
|
115
|
-
self.derivative_keys = derivative_keys
|
|
116
|
-
|
|
117
|
-
if initial_condition is None:
|
|
64
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
65
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
66
|
+
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
67
|
+
derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
|
|
68
|
+
loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
|
|
69
|
+
initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
|
|
70
|
+
obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
|
|
71
|
+
|
|
72
|
+
params: InitVar[Params] = eqx.field(default=None, kw_only=True)
|
|
73
|
+
|
|
74
|
+
def __post_init__(self, params=None):
|
|
75
|
+
if self.loss_weights is None:
|
|
76
|
+
self.loss_weights = LossWeightsODE()
|
|
77
|
+
|
|
78
|
+
if self.derivative_keys is None:
|
|
79
|
+
try:
|
|
80
|
+
# be default we only take gradient wrt nn_params
|
|
81
|
+
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
82
|
+
except ValueError as exc:
|
|
83
|
+
raise ValueError(
|
|
84
|
+
"Problem at self.derivative_keys initialization "
|
|
85
|
+
f"received {self.derivative_keys=} and {params=}"
|
|
86
|
+
) from exc
|
|
87
|
+
if self.initial_condition is None:
|
|
118
88
|
warnings.warn(
|
|
119
89
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
120
90
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
121
91
|
)
|
|
122
92
|
else:
|
|
123
|
-
if
|
|
93
|
+
if (
|
|
94
|
+
not isinstance(self.initial_condition, tuple)
|
|
95
|
+
or len(self.initial_condition) != 2
|
|
96
|
+
):
|
|
124
97
|
raise ValueError(
|
|
125
|
-
|
|
98
|
+
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
99
|
+
f"{self.initial_condition} was passed."
|
|
126
100
|
)
|
|
127
|
-
|
|
128
|
-
self.loss_weights = loss_weights
|
|
129
|
-
self.obs_slice = obs_slice
|
|
101
|
+
|
|
130
102
|
if self.obs_slice is None:
|
|
131
103
|
self.obs_slice = jnp.s_[...]
|
|
132
104
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
105
|
+
if self.loss_weights is None:
|
|
106
|
+
self.loss_weights = LossWeightsODE()
|
|
107
|
+
|
|
108
|
+
@abc.abstractmethod
|
|
109
|
+
def evaluate(
|
|
110
|
+
self: eqx.Module, params: Params, batch: ODEBatch
|
|
111
|
+
) -> tuple[Float, dict]:
|
|
112
|
+
raise NotImplementedError
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class LossODE(_LossODEAbstract):
|
|
116
|
+
r"""Loss object for an ordinary differential equation
|
|
117
|
+
|
|
118
|
+
$$
|
|
119
|
+
\mathcal{N}[u](t) = 0, \forall t \in I
|
|
120
|
+
$$
|
|
121
|
+
|
|
122
|
+
where $\mathcal{N}[\cdot]$ is a differential operator and the
|
|
123
|
+
initial condition is $u(t_0)=u_0$.
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
Parameters
|
|
127
|
+
----------
|
|
128
|
+
loss_weights : LossWeightsODE, default=None
|
|
129
|
+
The loss weights for the differents term : dynamic loss,
|
|
130
|
+
initial condition and eventually observations if any. All fields are
|
|
131
|
+
set to 1.0 by default.
|
|
132
|
+
derivative_keys : DerivativeKeysODE, default=None
|
|
133
|
+
Specify which field of `params` should be differentiated for each
|
|
134
|
+
composant of the total loss. Particularily useful for inverse problems.
|
|
135
|
+
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
136
|
+
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
137
|
+
is `"nn_params"` for each composant of the loss.
|
|
138
|
+
initial_condition : tuple, default=None
|
|
139
|
+
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
140
|
+
obs_slice Slice, default=None
|
|
141
|
+
Slice object specifying the begininning/ending
|
|
142
|
+
slice of u output(s) that is observed. This is useful for
|
|
143
|
+
multidimensional PINN, with partially observed outputs.
|
|
144
|
+
Default is None (whole output is observed).
|
|
145
|
+
params : InitVar[Params], default=None
|
|
146
|
+
The main Params object of the problem needed to instanciate the
|
|
147
|
+
DerivativeKeysODE if the latter is not specified.
|
|
148
|
+
u : eqx.Module
|
|
149
|
+
the PINN
|
|
150
|
+
dynamic_loss : DynamicLoss
|
|
151
|
+
the ODE dynamic part of the loss, basically the differential
|
|
152
|
+
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
153
|
+
`dynamic_loss.evaluate(t, u, params)`.
|
|
154
|
+
Can be None in order to access only some part of the evaluate call.
|
|
155
|
+
|
|
156
|
+
Raises
|
|
157
|
+
------
|
|
158
|
+
ValueError
|
|
159
|
+
if initial condition is not a tuple.
|
|
160
|
+
"""
|
|
161
|
+
|
|
162
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
163
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
164
|
+
u: eqx.Module
|
|
165
|
+
dynamic_loss: DynamicLoss | None
|
|
166
|
+
|
|
167
|
+
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
168
|
+
|
|
169
|
+
def __post_init__(self, params=None):
|
|
170
|
+
super().__post_init__(
|
|
171
|
+
params=params
|
|
172
|
+
) # because __init__ or __post_init__ of Base
|
|
173
|
+
# class is not automatically called
|
|
174
|
+
|
|
175
|
+
self.vmap_in_axes = (0,)
|
|
136
176
|
|
|
137
177
|
def __call__(self, *args, **kwargs):
|
|
138
178
|
return self.evaluate(*args, **kwargs)
|
|
139
179
|
|
|
140
|
-
def evaluate(
|
|
180
|
+
def evaluate(
|
|
181
|
+
self, params: Params, batch: ODEBatch
|
|
182
|
+
) -> tuple[Float[Array, "1"], dict[str, float]]:
|
|
141
183
|
"""
|
|
142
184
|
Evaluate the loss function at a batch of points for given parameters.
|
|
143
185
|
|
|
@@ -145,21 +187,14 @@ class LossODE:
|
|
|
145
187
|
Parameters
|
|
146
188
|
---------
|
|
147
189
|
params
|
|
148
|
-
|
|
149
|
-
Typically, it is a dictionary of
|
|
150
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
151
|
-
differential equation parameters and the neural network parameter
|
|
190
|
+
Parameters at which the loss is evaluated
|
|
152
191
|
batch
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
metamodeling) and an optional additional batch of observed
|
|
157
|
-
inputs/outputs/parameters
|
|
192
|
+
Composed of a batch of time points
|
|
193
|
+
at which to evaluate the differential operator. An optional additional batch of parameters (eg. for metamodeling) and an optional additional batch of observed inputs/outputs/parameters can
|
|
194
|
+
be supplied.
|
|
158
195
|
"""
|
|
159
196
|
temporal_batch = batch.temporal_batch
|
|
160
197
|
|
|
161
|
-
vmap_in_axes_t = (0,)
|
|
162
|
-
|
|
163
198
|
# Retrieve the optional eq_params_batch
|
|
164
199
|
# and update eq_params with the latter
|
|
165
200
|
# and update vmap_in_axes
|
|
@@ -170,21 +205,19 @@ class LossODE:
|
|
|
170
205
|
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
171
206
|
|
|
172
207
|
## dynamic part
|
|
173
|
-
params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
|
|
174
208
|
if self.dynamic_loss is not None:
|
|
175
209
|
mse_dyn_loss = dynamic_loss_apply(
|
|
176
210
|
self.dynamic_loss.evaluate,
|
|
177
211
|
self.u,
|
|
178
212
|
(temporal_batch,),
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
self.loss_weights
|
|
213
|
+
_set_derivatives(params, self.derivative_keys.dyn_loss),
|
|
214
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
215
|
+
self.loss_weights.dyn_loss,
|
|
182
216
|
)
|
|
183
217
|
else:
|
|
184
218
|
mse_dyn_loss = jnp.array(0.0)
|
|
185
219
|
|
|
186
220
|
# initial condition
|
|
187
|
-
params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
|
|
188
221
|
if self.initial_condition is not None:
|
|
189
222
|
vmap_in_axes = (None,) + vmap_in_axes_params
|
|
190
223
|
if not jax.tree_util.tree_leaves(vmap_in_axes):
|
|
@@ -193,12 +226,24 @@ class LossODE:
|
|
|
193
226
|
v_u = self.u
|
|
194
227
|
else:
|
|
195
228
|
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
196
|
-
t0, u0 = self.initial_condition
|
|
229
|
+
t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
|
|
197
230
|
t0 = jnp.array(t0)
|
|
198
231
|
u0 = jnp.array(u0)
|
|
199
232
|
mse_initial_condition = jnp.mean(
|
|
200
|
-
self.loss_weights
|
|
201
|
-
* jnp.sum(
|
|
233
|
+
self.loss_weights.initial_condition
|
|
234
|
+
* jnp.sum(
|
|
235
|
+
(
|
|
236
|
+
v_u(
|
|
237
|
+
t0,
|
|
238
|
+
_set_derivatives(
|
|
239
|
+
params, self.derivative_keys.initial_condition
|
|
240
|
+
),
|
|
241
|
+
)
|
|
242
|
+
- u0
|
|
243
|
+
)
|
|
244
|
+
** 2,
|
|
245
|
+
axis=-1,
|
|
246
|
+
)
|
|
202
247
|
)
|
|
203
248
|
else:
|
|
204
249
|
mse_initial_condition = jnp.array(0.0)
|
|
@@ -208,14 +253,13 @@ class LossODE:
|
|
|
208
253
|
params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
|
|
209
254
|
|
|
210
255
|
# MSE loss wrt to an observed batch
|
|
211
|
-
params_ = _set_derivatives(params, "observations", self.derivative_keys)
|
|
212
256
|
mse_observation_loss = observations_loss_apply(
|
|
213
257
|
self.u,
|
|
214
258
|
(batch.obs_batch_dict["pinn_in"],),
|
|
215
|
-
|
|
216
|
-
|
|
259
|
+
_set_derivatives(params, self.derivative_keys.observations),
|
|
260
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
217
261
|
batch.obs_batch_dict["val"],
|
|
218
|
-
self.loss_weights
|
|
262
|
+
self.loss_weights.observations,
|
|
219
263
|
self.obs_slice,
|
|
220
264
|
)
|
|
221
265
|
else:
|
|
@@ -231,193 +275,178 @@ class LossODE:
|
|
|
231
275
|
}
|
|
232
276
|
)
|
|
233
277
|
|
|
234
|
-
def tree_flatten(self):
|
|
235
|
-
children = (self.initial_condition, self.loss_weights)
|
|
236
|
-
aux_data = {
|
|
237
|
-
"u": self.u,
|
|
238
|
-
"dynamic_loss": self.dynamic_loss,
|
|
239
|
-
"obs_slice": self.obs_slice,
|
|
240
|
-
"derivative_keys": self.derivative_keys,
|
|
241
|
-
}
|
|
242
|
-
return (children, aux_data)
|
|
243
|
-
|
|
244
|
-
@classmethod
|
|
245
|
-
def tree_unflatten(cls, aux_data, children):
|
|
246
|
-
(initial_condition, loss_weights) = children
|
|
247
|
-
loss_ode = cls(
|
|
248
|
-
loss_weights=loss_weights,
|
|
249
|
-
initial_condition=initial_condition,
|
|
250
|
-
**aux_data,
|
|
251
|
-
)
|
|
252
|
-
return loss_ode
|
|
253
278
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
class SystemLossODE:
|
|
257
|
-
"""
|
|
279
|
+
class SystemLossODE(eqx.Module):
|
|
280
|
+
r"""
|
|
258
281
|
Class to implement a system of ODEs.
|
|
259
282
|
The goal is to give maximum freedom to the user. The class is created with
|
|
260
|
-
a dict of dynamic loss and a dict of initial conditions.
|
|
261
|
-
over the dynamic losses that compose the system. All
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
283
|
+
a dict of dynamic loss and a dict of initial conditions. Then, it iterates
|
|
284
|
+
over the dynamic losses that compose the system. All PINNs are passed as
|
|
285
|
+
arguments to each dynamic loss evaluate functions, along with all the
|
|
286
|
+
parameter dictionaries. All specification is left to the responsability
|
|
287
|
+
of the user, inside the dynamic loss.
|
|
265
288
|
|
|
266
289
|
**Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
|
|
267
290
|
Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
|
|
268
291
|
solution.
|
|
269
292
|
|
|
270
|
-
|
|
271
|
-
|
|
293
|
+
Parameters
|
|
294
|
+
----------
|
|
295
|
+
u_dict : Dict[str, eqx.Module]
|
|
296
|
+
dict of PINNs
|
|
297
|
+
loss_weights : LossWeightsODEDict
|
|
298
|
+
A dictionary of LossWeightsODE
|
|
299
|
+
derivative_keys_dict : Dict[str, DerivativeKeysODE], default=None
|
|
300
|
+
A dictionnary of DerivativeKeysODE specifying what field of `params`
|
|
301
|
+
should be used during gradient computations for each of the terms of
|
|
302
|
+
the total loss, for each of the loss in the system. Default is
|
|
303
|
+
`"nn_params`" everywhere.
|
|
304
|
+
initial_condition_dict : Dict[str, tuple], default=None
|
|
305
|
+
dict of tuple of length 2 with initial condition $(t_0, u_0)$
|
|
306
|
+
Must share the keys of `u_dict`. Default is None. No initial
|
|
307
|
+
condition is permitted when the initial condition is hardcoded in
|
|
308
|
+
the PINN architecture for example
|
|
309
|
+
dynamic_loss_dict : Dict[str, ODE]
|
|
310
|
+
dict of dynamic part of the loss, basically the differential
|
|
311
|
+
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
312
|
+
`dynamic_loss.evaluate(t, u, params)`
|
|
313
|
+
obs_slice_dict : Dict[str, Slice]
|
|
314
|
+
dict of obs_slice, with keys from `u_dict` to designate the
|
|
315
|
+
output(s) channels that are observed, for each
|
|
316
|
+
PINNs. Default is None. But if a value is given, all the entries of
|
|
317
|
+
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
318
|
+
if no particular slice is to be given.
|
|
319
|
+
params_dict : InitVar[ParamsDict], default=None
|
|
320
|
+
The main Params object of the problem needed to instanciate the
|
|
321
|
+
DerivativeKeysODE if the latter is not specified.
|
|
322
|
+
|
|
323
|
+
Raises
|
|
324
|
+
------
|
|
325
|
+
ValueError
|
|
326
|
+
if initial condition is not a dict of tuple.
|
|
327
|
+
ValueError
|
|
328
|
+
if the dictionaries that should share the keys of u_dict do not.
|
|
272
329
|
"""
|
|
273
330
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
Must share the keys of `u_dict`. Default is None. No initial
|
|
308
|
-
condition is permitted when the initial condition is hardcoded in
|
|
309
|
-
the PINN architecture for example
|
|
310
|
-
dynamic_loss_dict
|
|
311
|
-
dict of dynamic part of the loss, basically the differential
|
|
312
|
-
operator :math:`\mathcal{N}[u](t)`. Should implement a method
|
|
313
|
-
`dynamic_loss.evaluate(t, u, params)`
|
|
314
|
-
obs_slice_dict
|
|
315
|
-
dict of obs_slice, with keys from `u_dict` to designate the
|
|
316
|
-
output(s) channels that are forced to observed values, for each
|
|
317
|
-
PINNs. Default is None. But if a value is given, all the entries of
|
|
318
|
-
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
319
|
-
if no particular slice is to be given
|
|
320
|
-
|
|
321
|
-
Raises
|
|
322
|
-
------
|
|
323
|
-
ValueError
|
|
324
|
-
if initial condition is not a dict of tuple
|
|
325
|
-
ValueError
|
|
326
|
-
if the dictionaries that should share the keys of u_dict do not
|
|
327
|
-
"""
|
|
328
|
-
|
|
331
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
332
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
333
|
+
u_dict: Dict[str, eqx.Module]
|
|
334
|
+
dynamic_loss_dict: Dict[str, ODE]
|
|
335
|
+
derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
|
|
336
|
+
kw_only=True, default=None
|
|
337
|
+
)
|
|
338
|
+
initial_condition_dict: Dict[str, tuple] | None = eqx.field(
|
|
339
|
+
kw_only=True, default=None
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
343
|
+
kw_only=True, default=None, static=True
|
|
344
|
+
) # We are at an "leaf" attribute here (slice, not valid JAX type). Since
|
|
345
|
+
# we do not expect it to change with put a static=True here. But note that
|
|
346
|
+
# this is the only static for all the SystemLossODE attribute, since all
|
|
347
|
+
# other are composed of more complex structures ("non-leaf")
|
|
348
|
+
|
|
349
|
+
# For the user loss_weights are passed as a LossWeightsODEDict (with internal
|
|
350
|
+
# dictionary having keys in u_dict and / or dynamic_loss_dict)
|
|
351
|
+
loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
|
|
352
|
+
kw_only=True, default=None
|
|
353
|
+
)
|
|
354
|
+
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
355
|
+
|
|
356
|
+
u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
|
|
357
|
+
derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
|
|
358
|
+
|
|
359
|
+
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
360
|
+
# internally the loss weights are handled with a dictionary
|
|
361
|
+
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
362
|
+
|
|
363
|
+
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
329
364
|
# a dictionary that will be useful at different places
|
|
330
|
-
self.u_dict_with_none = {k: None for k in u_dict.keys()}
|
|
331
|
-
if initial_condition_dict is None:
|
|
365
|
+
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
366
|
+
if self.initial_condition_dict is None:
|
|
332
367
|
self.initial_condition_dict = self.u_dict_with_none
|
|
333
368
|
else:
|
|
334
|
-
self.
|
|
335
|
-
if u_dict.keys() != initial_condition_dict.keys():
|
|
369
|
+
if self.u_dict.keys() != self.initial_condition_dict.keys():
|
|
336
370
|
raise ValueError(
|
|
337
371
|
"initial_condition_dict should have same keys as u_dict"
|
|
338
372
|
)
|
|
339
|
-
if obs_slice_dict is None:
|
|
340
|
-
self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
|
|
373
|
+
if self.obs_slice_dict is None:
|
|
374
|
+
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
341
375
|
else:
|
|
342
|
-
self.
|
|
343
|
-
if u_dict.keys() != obs_slice_dict.keys():
|
|
376
|
+
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
344
377
|
raise ValueError("obs_slice_dict should have same keys as u_dict")
|
|
345
378
|
|
|
346
|
-
if derivative_keys_dict is None:
|
|
379
|
+
if self.derivative_keys_dict is None:
|
|
347
380
|
self.derivative_keys_dict = {
|
|
348
381
|
k: None
|
|
349
|
-
for k in set(
|
|
382
|
+
for k in set(
|
|
383
|
+
list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
|
|
384
|
+
)
|
|
350
385
|
}
|
|
351
386
|
# set() because we can have duplicate entries and in this case we
|
|
352
387
|
# say it corresponds to the same derivative_keys_dict entry
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
#
|
|
359
|
-
#
|
|
360
|
-
#
|
|
361
|
-
for k in
|
|
388
|
+
# we need both because the constraints (all but dyn_loss) will be
|
|
389
|
+
# done by iterating on u_dict while the dyn_loss will be by
|
|
390
|
+
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
391
|
+
# derivative_keys_dict
|
|
392
|
+
|
|
393
|
+
# derivative keys for the u_constraints. Note that we create missing
|
|
394
|
+
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
395
|
+
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
396
|
+
for k in self.u_dict.keys():
|
|
362
397
|
if self.derivative_keys_dict[k] is None:
|
|
363
|
-
self.derivative_keys_dict[k] =
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
self.u_dict = u_dict
|
|
398
|
+
self.derivative_keys_dict[k] = DerivativeKeysODE(
|
|
399
|
+
params=params_dict.extract_params(k)
|
|
400
|
+
)
|
|
367
401
|
|
|
368
|
-
self.
|
|
369
|
-
# note that self.initial_condition_dict must be
|
|
370
|
-
# initialized beforehand
|
|
402
|
+
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
371
403
|
|
|
372
404
|
# The constaints on the solutions will be implemented by reusing a
|
|
373
405
|
# LossODE class without dynamic loss term
|
|
374
406
|
self.u_constraints_dict = {}
|
|
375
407
|
for i in self.u_dict.keys():
|
|
376
408
|
self.u_constraints_dict[i] = LossODE(
|
|
377
|
-
u=u_dict[i],
|
|
378
|
-
loss_weights=
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
409
|
+
u=self.u_dict[i],
|
|
410
|
+
loss_weights=LossWeightsODE(
|
|
411
|
+
dyn_loss=0.0,
|
|
412
|
+
initial_condition=1.0,
|
|
413
|
+
observations=1.0,
|
|
414
|
+
),
|
|
383
415
|
dynamic_loss=None,
|
|
384
416
|
derivative_keys=self.derivative_keys_dict[i],
|
|
385
417
|
initial_condition=self.initial_condition_dict[i],
|
|
386
418
|
obs_slice=self.obs_slice_dict[i],
|
|
387
419
|
)
|
|
388
420
|
|
|
389
|
-
#
|
|
390
|
-
#
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
}
|
|
395
|
-
self.derivative_keys_u_dict = {
|
|
396
|
-
k: self.derivative_keys_dict[k]
|
|
397
|
-
for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
|
|
398
|
-
}
|
|
399
|
-
|
|
400
|
-
@property
|
|
401
|
-
def loss_weights(self):
|
|
402
|
-
return self._loss_weights
|
|
421
|
+
# derivative keys for the dynamic loss. Note that we create a
|
|
422
|
+
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
423
|
+
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
424
|
+
# happen inside it)
|
|
425
|
+
self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
|
|
403
426
|
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
427
|
+
def set_loss_weights(self, loss_weights_init):
|
|
428
|
+
"""
|
|
429
|
+
This rather complex function enables the user to specify a simple
|
|
430
|
+
loss_weights=LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)
|
|
431
|
+
for ponderating values being applied to all the equations of the
|
|
432
|
+
system... So all the transformations are handled here
|
|
433
|
+
"""
|
|
434
|
+
_loss_weights = {}
|
|
435
|
+
for k in fields(loss_weights_init):
|
|
436
|
+
v = getattr(loss_weights_init, k.name)
|
|
408
437
|
if isinstance(v, dict):
|
|
409
|
-
for
|
|
438
|
+
for vv in v.values():
|
|
410
439
|
if not isinstance(vv, (int, float)) and not (
|
|
411
|
-
isinstance(vv,
|
|
440
|
+
isinstance(vv, Array)
|
|
412
441
|
and ((vv.shape == (1,) or len(vv.shape) == 0))
|
|
413
442
|
):
|
|
414
443
|
# TODO improve that
|
|
415
444
|
raise ValueError(
|
|
416
445
|
f"loss values cannot be vectorial here, got {vv}"
|
|
417
446
|
)
|
|
418
|
-
if k == "dyn_loss":
|
|
447
|
+
if k.name == "dyn_loss":
|
|
419
448
|
if v.keys() == self.dynamic_loss_dict.keys():
|
|
420
|
-
|
|
449
|
+
_loss_weights[k.name] = v
|
|
421
450
|
else:
|
|
422
451
|
raise ValueError(
|
|
423
452
|
"Keys in nested dictionary of loss_weights"
|
|
@@ -425,48 +454,41 @@ class SystemLossODE:
|
|
|
425
454
|
)
|
|
426
455
|
else:
|
|
427
456
|
if v.keys() == self.u_dict.keys():
|
|
428
|
-
|
|
457
|
+
_loss_weights[k.name] = v
|
|
429
458
|
else:
|
|
430
459
|
raise ValueError(
|
|
431
460
|
"Keys in nested dictionary of loss_weights"
|
|
432
461
|
" do not match u_dict keys"
|
|
433
462
|
)
|
|
463
|
+
elif v is None:
|
|
464
|
+
_loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
|
|
434
465
|
else:
|
|
435
466
|
if not isinstance(v, (int, float)) and not (
|
|
436
|
-
isinstance(v,
|
|
437
|
-
and ((v.shape == (1,) or len(v.shape) == 0))
|
|
467
|
+
isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
|
|
438
468
|
):
|
|
439
469
|
# TODO improve that
|
|
440
470
|
raise ValueError(f"loss values cannot be vectorial here, got {v}")
|
|
441
|
-
if k == "dyn_loss":
|
|
442
|
-
|
|
471
|
+
if k.name == "dyn_loss":
|
|
472
|
+
_loss_weights[k.name] = {
|
|
443
473
|
kk: v for kk in self.dynamic_loss_dict.keys()
|
|
444
474
|
}
|
|
445
475
|
else:
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
if "observations" not in value.keys():
|
|
450
|
-
self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
|
|
476
|
+
_loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
|
|
477
|
+
|
|
478
|
+
return _loss_weights
|
|
451
479
|
|
|
452
480
|
def __call__(self, *args, **kwargs):
|
|
453
481
|
return self.evaluate(*args, **kwargs)
|
|
454
482
|
|
|
455
|
-
def evaluate(self, params_dict, batch):
|
|
483
|
+
def evaluate(self, params_dict: ParamsDict, batch: ODEBatch) -> Float[Array, "1"]:
|
|
456
484
|
"""
|
|
457
485
|
Evaluate the loss function at a batch of points for given parameters.
|
|
458
486
|
|
|
459
487
|
|
|
460
488
|
Parameters
|
|
461
489
|
---------
|
|
462
|
-
|
|
463
|
-
A
|
|
464
|
-
Typically, it is a dictionary of dictionaries of
|
|
465
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
466
|
-
differential equation parameters and the neural network parameter.
|
|
467
|
-
Note that params_dict["nn_params"] need not be a dictionary anymore
|
|
468
|
-
but can directly be the parameters. It is useful when working with
|
|
469
|
-
neural networks sharing the same parameters
|
|
490
|
+
params
|
|
491
|
+
A ParamsDict object
|
|
470
492
|
batch
|
|
471
493
|
A ODEBatch object.
|
|
472
494
|
Such a named tuple is composed of a batch of time points
|
|
@@ -475,10 +497,10 @@ class SystemLossODE:
|
|
|
475
497
|
inputs/outputs/parameters
|
|
476
498
|
"""
|
|
477
499
|
if (
|
|
478
|
-
isinstance(params_dict
|
|
479
|
-
and self.u_dict.keys() != params_dict
|
|
500
|
+
isinstance(params_dict.nn_params, dict)
|
|
501
|
+
and self.u_dict.keys() != params_dict.nn_params.keys()
|
|
480
502
|
):
|
|
481
|
-
raise ValueError("u_dict and params_dict
|
|
503
|
+
raise ValueError("u_dict and params_dict.nn_params should have same keys ")
|
|
482
504
|
|
|
483
505
|
temporal_batch = batch.temporal_batch
|
|
484
506
|
|
|
@@ -489,20 +511,19 @@ class SystemLossODE:
|
|
|
489
511
|
# and update vmap_in_axes
|
|
490
512
|
if batch.param_batch_dict is not None:
|
|
491
513
|
# update params with the batches of generated params
|
|
492
|
-
|
|
514
|
+
params = _update_eq_params_dict(params, batch.param_batch_dict)
|
|
493
515
|
|
|
494
516
|
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
495
517
|
batch.param_batch_dict, params_dict
|
|
496
518
|
)
|
|
497
519
|
|
|
498
|
-
def dyn_loss_for_one_key(dyn_loss,
|
|
520
|
+
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
499
521
|
"""This function is used in tree_map"""
|
|
500
|
-
params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
|
|
501
522
|
return dynamic_loss_apply(
|
|
502
523
|
dyn_loss.evaluate,
|
|
503
524
|
self.u_dict,
|
|
504
525
|
(temporal_batch,),
|
|
505
|
-
|
|
526
|
+
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
506
527
|
vmap_in_axes_t + vmap_in_axes_params,
|
|
507
528
|
loss_weight,
|
|
508
529
|
u_type=PINN,
|
|
@@ -511,8 +532,12 @@ class SystemLossODE:
|
|
|
511
532
|
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
512
533
|
dyn_loss_for_one_key,
|
|
513
534
|
self.dynamic_loss_dict,
|
|
514
|
-
self.derivative_keys_dyn_loss_dict,
|
|
515
535
|
self._loss_weights["dyn_loss"],
|
|
536
|
+
is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
|
|
537
|
+
# where plain (unregister pytree) node classes, we could not traverse
|
|
538
|
+
# this level. Now that dynamic losses are eqx.Module they can be
|
|
539
|
+
# traversed by tree map recursion. Hence we need to specify to that
|
|
540
|
+
# we want to stop at this level
|
|
516
541
|
)
|
|
517
542
|
mse_dyn_loss = jax.tree_util.tree_reduce(
|
|
518
543
|
lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
|
|
@@ -527,7 +552,8 @@ class SystemLossODE:
|
|
|
527
552
|
|
|
528
553
|
# we need to do the following for the tree_mapping to work
|
|
529
554
|
if batch.obs_batch_dict is None:
|
|
530
|
-
batch = batch
|
|
555
|
+
batch = append_obs_batch(batch, self.u_dict_with_none)
|
|
556
|
+
|
|
531
557
|
total_loss, res_dict = constraints_system_loss_apply(
|
|
532
558
|
self.u_constraints_dict,
|
|
533
559
|
batch,
|
|
@@ -540,27 +566,3 @@ class SystemLossODE:
|
|
|
540
566
|
total_loss += mse_dyn_loss
|
|
541
567
|
res_dict["dyn_loss"] += mse_dyn_loss
|
|
542
568
|
return total_loss, res_dict
|
|
543
|
-
|
|
544
|
-
def tree_flatten(self):
|
|
545
|
-
children = (
|
|
546
|
-
self.initial_condition_dict,
|
|
547
|
-
self._loss_weights,
|
|
548
|
-
)
|
|
549
|
-
aux_data = {
|
|
550
|
-
"u_dict": self.u_dict,
|
|
551
|
-
"dynamic_loss_dict": self.dynamic_loss_dict,
|
|
552
|
-
"derivative_keys_dict": self.derivative_keys_dict,
|
|
553
|
-
"obs_slice_dict": self.obs_slice_dict,
|
|
554
|
-
}
|
|
555
|
-
return (children, aux_data)
|
|
556
|
-
|
|
557
|
-
@classmethod
|
|
558
|
-
def tree_unflatten(cls, aux_data, children):
|
|
559
|
-
(initial_condition_dict, loss_weights) = children
|
|
560
|
-
loss_ode = cls(
|
|
561
|
-
loss_weights=loss_weights,
|
|
562
|
-
initial_condition_dict=initial_condition_dict,
|
|
563
|
-
**aux_data,
|
|
564
|
-
)
|
|
565
|
-
|
|
566
|
-
return loss_ode
|