jinns 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +292 -309
- jinns/loss/_LossPDE.py +625 -1010
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +95 -44
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/METADATA +4 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -1,143 +1,169 @@
|
|
|
1
|
+
# pylint: disable=unsubscriptable-object, no-member
|
|
1
2
|
"""
|
|
2
3
|
Main module to implement a ODE loss in jinns
|
|
3
4
|
"""
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
4
8
|
|
|
9
|
+
from dataclasses import InitVar, fields
|
|
10
|
+
from typing import TYPE_CHECKING, Dict
|
|
11
|
+
import abc
|
|
5
12
|
import warnings
|
|
6
13
|
import jax
|
|
7
14
|
import jax.numpy as jnp
|
|
8
15
|
from jax import vmap
|
|
9
|
-
|
|
10
|
-
from
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
_update_eq_params_dict,
|
|
14
|
-
)
|
|
15
|
-
from jinns.loss._Losses import (
|
|
16
|
+
import equinox as eqx
|
|
17
|
+
from jaxtyping import Float, Array, Int
|
|
18
|
+
from jinns.data._DataGenerators import append_obs_batch
|
|
19
|
+
from jinns.loss._loss_utils import (
|
|
16
20
|
dynamic_loss_apply,
|
|
17
21
|
constraints_system_loss_apply,
|
|
18
22
|
observations_loss_apply,
|
|
19
23
|
)
|
|
24
|
+
from jinns.parameters._params import (
|
|
25
|
+
_get_vmap_in_axes_params,
|
|
26
|
+
_update_eq_params_dict,
|
|
27
|
+
)
|
|
28
|
+
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
29
|
+
from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
|
|
30
|
+
from jinns.loss._DynamicLossAbstract import ODE
|
|
20
31
|
from jinns.utils._pinn import PINN
|
|
21
32
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@register_pytree_node_class
|
|
26
|
-
class LossODE:
|
|
27
|
-
r"""Loss object for an ordinary differential equation
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from jinns.utils._types import *
|
|
28
35
|
|
|
29
|
-
.. math::
|
|
30
|
-
\mathcal{N}[u](t) = 0, \forall t \in I
|
|
31
36
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
37
|
+
class _LossODEAbstract(eqx.Module):
|
|
38
|
+
"""
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
|
|
42
|
+
loss_weights : LossWeightsODE, default=None
|
|
43
|
+
The loss weights for the differents term : dynamic loss,
|
|
44
|
+
initial condition and eventually observations if any. All fields are
|
|
45
|
+
set to 1.0 by default.
|
|
46
|
+
derivative_keys : DerivativeKeysODE, default=None
|
|
47
|
+
Specify which field of `params` should be differentiated for each
|
|
48
|
+
composant of the total loss. Particularily useful for inverse problems.
|
|
49
|
+
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
50
|
+
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
51
|
+
is `"nn_params"` for each composant of the loss.
|
|
52
|
+
initial_condition : tuple, default=None
|
|
53
|
+
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
54
|
+
obs_slice : Slice, default=None
|
|
55
|
+
Slice object specifying the begininning/ending
|
|
56
|
+
slice of u output(s) that is observed. This is useful for
|
|
57
|
+
multidimensional PINN, with partially observed outputs.
|
|
58
|
+
Default is None (whole output is observed).
|
|
38
59
|
"""
|
|
39
60
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
obs_slice=None,
|
|
48
|
-
):
|
|
49
|
-
r"""
|
|
50
|
-
Parameters
|
|
51
|
-
----------
|
|
52
|
-
u :
|
|
53
|
-
the PINN
|
|
54
|
-
loss_weights :
|
|
55
|
-
a dictionary with values used to ponderate each term in the loss
|
|
56
|
-
function. Valid keys are `dyn_loss`, `initial_condition` and `observations`
|
|
57
|
-
Note that we can have jnp.arrays with the same dimension of
|
|
58
|
-
`u` which then ponderates each output of `u`
|
|
59
|
-
dynamic_loss :
|
|
60
|
-
the ODE dynamic part of the loss, basically the differential
|
|
61
|
-
operator :math:`\mathcal{N}[u](t)`. Should implement a method
|
|
62
|
-
`dynamic_loss.evaluate(t, u, params)`.
|
|
63
|
-
Can be None in order to
|
|
64
|
-
access only some part of the evaluate call results.
|
|
65
|
-
derivative_keys
|
|
66
|
-
A dict of lists of strings. In the dict, the key must correspond to
|
|
67
|
-
the loss term keywords. Then each of the values must correspond to keys in the parameter
|
|
68
|
-
dictionary (*at top level only of the parameter dictionary*).
|
|
69
|
-
It enables selecting the set of parameters
|
|
70
|
-
with respect to which the gradients of the dynamic
|
|
71
|
-
loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
|
|
72
|
-
keywords, this is what is typically
|
|
73
|
-
done in solving forward problems, when we only estimate the
|
|
74
|
-
equation solution with a PINN. If some loss terms keywords are
|
|
75
|
-
missing we set their value to ["nn_params"] by default for the same
|
|
76
|
-
reason
|
|
77
|
-
initial_condition :
|
|
78
|
-
tuple of length 2 with initial condition :math:`(t0, u0)`.
|
|
79
|
-
Can be None in order to
|
|
80
|
-
access only some part of the evaluate call results.
|
|
81
|
-
obs_slice:
|
|
82
|
-
slice object specifying the begininning/ending
|
|
83
|
-
slice of u output(s) that is observed (this is then useful for
|
|
84
|
-
multidim PINN). Default is None.
|
|
85
|
-
|
|
86
|
-
Raises
|
|
87
|
-
------
|
|
88
|
-
ValueError
|
|
89
|
-
if initial condition is not a tuple.
|
|
90
|
-
"""
|
|
91
|
-
self.dynamic_loss = dynamic_loss
|
|
92
|
-
self.u = u
|
|
93
|
-
if derivative_keys is None:
|
|
94
|
-
# be default we only take gradient wrt nn_params
|
|
95
|
-
derivative_keys = {
|
|
96
|
-
k: ["nn_params"]
|
|
97
|
-
for k in [
|
|
98
|
-
"dyn_loss",
|
|
99
|
-
"initial_condition",
|
|
100
|
-
"observations",
|
|
101
|
-
]
|
|
102
|
-
}
|
|
103
|
-
if isinstance(derivative_keys, list):
|
|
104
|
-
# if the user only provided a list, this defines the gradient taken
|
|
105
|
-
# for all the loss entries
|
|
106
|
-
derivative_keys = {
|
|
107
|
-
k: derivative_keys
|
|
108
|
-
for k in [
|
|
109
|
-
"dyn_loss",
|
|
110
|
-
"initial_condition",
|
|
111
|
-
"observations",
|
|
112
|
-
]
|
|
113
|
-
}
|
|
61
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
62
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
63
|
+
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
64
|
+
derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
|
|
65
|
+
loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
|
|
66
|
+
initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
|
|
67
|
+
obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
|
|
114
68
|
|
|
115
|
-
|
|
69
|
+
def __post_init__(self):
|
|
70
|
+
if self.loss_weights is None:
|
|
71
|
+
self.loss_weights = LossWeightsODE()
|
|
116
72
|
|
|
117
|
-
if
|
|
73
|
+
if self.derivative_keys is None:
|
|
74
|
+
# be default we only take gradient wrt nn_params
|
|
75
|
+
self.derivative_keys = DerivativeKeysODE()
|
|
76
|
+
if self.initial_condition is None:
|
|
118
77
|
warnings.warn(
|
|
119
78
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
120
79
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
121
80
|
)
|
|
122
81
|
else:
|
|
123
|
-
if
|
|
82
|
+
if (
|
|
83
|
+
not isinstance(self.initial_condition, tuple)
|
|
84
|
+
or len(self.initial_condition) != 2
|
|
85
|
+
):
|
|
124
86
|
raise ValueError(
|
|
125
|
-
|
|
87
|
+
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
88
|
+
f"{self.initial_condition} was passed."
|
|
126
89
|
)
|
|
127
|
-
|
|
128
|
-
self.loss_weights = loss_weights
|
|
129
|
-
self.obs_slice = obs_slice
|
|
90
|
+
|
|
130
91
|
if self.obs_slice is None:
|
|
131
92
|
self.obs_slice = jnp.s_[...]
|
|
132
93
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
94
|
+
if self.loss_weights is None:
|
|
95
|
+
self.loss_weights = LossWeightsODE()
|
|
96
|
+
|
|
97
|
+
@abc.abstractmethod
|
|
98
|
+
def evaluate(
|
|
99
|
+
self: eqx.Module, params: Params, batch: ODEBatch
|
|
100
|
+
) -> tuple[Float, dict]:
|
|
101
|
+
raise NotImplementedError
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class LossODE(_LossODEAbstract):
|
|
105
|
+
r"""Loss object for an ordinary differential equation
|
|
106
|
+
|
|
107
|
+
$$
|
|
108
|
+
\mathcal{N}[u](t) = 0, \forall t \in I
|
|
109
|
+
$$
|
|
110
|
+
|
|
111
|
+
where $\mathcal{N}[\cdot]$ is a differential operator and the
|
|
112
|
+
initial condition is $u(t_0)=u_0$.
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
Parameters
|
|
116
|
+
----------
|
|
117
|
+
loss_weights : LossWeightsODE, default=None
|
|
118
|
+
The loss weights for the differents term : dynamic loss,
|
|
119
|
+
initial condition and eventually observations if any. All fields are
|
|
120
|
+
set to 1.0 by default.
|
|
121
|
+
derivative_keys : DerivativeKeysODE, default=None
|
|
122
|
+
Specify which field of `params` should be differentiated for each
|
|
123
|
+
composant of the total loss. Particularily useful for inverse problems.
|
|
124
|
+
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
125
|
+
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
126
|
+
is `"nn_params"` for each composant of the loss.
|
|
127
|
+
initial_condition : tuple, default=None
|
|
128
|
+
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
129
|
+
obs_slice Slice, default=None
|
|
130
|
+
Slice object specifying the begininning/ending
|
|
131
|
+
slice of u output(s) that is observed. This is useful for
|
|
132
|
+
multidimensional PINN, with partially observed outputs.
|
|
133
|
+
Default is None (whole output is observed).
|
|
134
|
+
u : eqx.Module
|
|
135
|
+
the PINN
|
|
136
|
+
dynamic_loss : DynamicLoss
|
|
137
|
+
the ODE dynamic part of the loss, basically the differential
|
|
138
|
+
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
139
|
+
`dynamic_loss.evaluate(t, u, params)`.
|
|
140
|
+
Can be None in order to access only some part of the evaluate call.
|
|
141
|
+
|
|
142
|
+
Raises
|
|
143
|
+
------
|
|
144
|
+
ValueError
|
|
145
|
+
if initial condition is not a tuple.
|
|
146
|
+
"""
|
|
147
|
+
|
|
148
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
149
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
150
|
+
u: eqx.Module
|
|
151
|
+
dynamic_loss: DynamicLoss | None
|
|
152
|
+
|
|
153
|
+
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
154
|
+
|
|
155
|
+
def __post_init__(self):
|
|
156
|
+
super().__post_init__() # because __init__ or __post_init__ of Base
|
|
157
|
+
# class is not automatically called
|
|
158
|
+
|
|
159
|
+
self.vmap_in_axes = (0,)
|
|
136
160
|
|
|
137
161
|
def __call__(self, *args, **kwargs):
|
|
138
162
|
return self.evaluate(*args, **kwargs)
|
|
139
163
|
|
|
140
|
-
def evaluate(
|
|
164
|
+
def evaluate(
|
|
165
|
+
self, params: Params, batch: ODEBatch
|
|
166
|
+
) -> tuple[Float[Array, "1"], dict[str, float]]:
|
|
141
167
|
"""
|
|
142
168
|
Evaluate the loss function at a batch of points for given parameters.
|
|
143
169
|
|
|
@@ -145,21 +171,14 @@ class LossODE:
|
|
|
145
171
|
Parameters
|
|
146
172
|
---------
|
|
147
173
|
params
|
|
148
|
-
|
|
149
|
-
Typically, it is a dictionary of
|
|
150
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
151
|
-
differential equation parameters and the neural network parameter
|
|
174
|
+
Parameters at which the loss is evaluated
|
|
152
175
|
batch
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
metamodeling) and an optional additional batch of observed
|
|
157
|
-
inputs/outputs/parameters
|
|
176
|
+
Composed of a batch of time points
|
|
177
|
+
at which to evaluate the differential operator. An optional additional batch of parameters (eg. for metamodeling) and an optional additional batch of observed inputs/outputs/parameters can
|
|
178
|
+
be supplied.
|
|
158
179
|
"""
|
|
159
180
|
temporal_batch = batch.temporal_batch
|
|
160
181
|
|
|
161
|
-
vmap_in_axes_t = (0,)
|
|
162
|
-
|
|
163
182
|
# Retrieve the optional eq_params_batch
|
|
164
183
|
# and update eq_params with the latter
|
|
165
184
|
# and update vmap_in_axes
|
|
@@ -170,21 +189,19 @@ class LossODE:
|
|
|
170
189
|
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
171
190
|
|
|
172
191
|
## dynamic part
|
|
173
|
-
params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
|
|
174
192
|
if self.dynamic_loss is not None:
|
|
175
193
|
mse_dyn_loss = dynamic_loss_apply(
|
|
176
194
|
self.dynamic_loss.evaluate,
|
|
177
195
|
self.u,
|
|
178
196
|
(temporal_batch,),
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
self.loss_weights
|
|
197
|
+
_set_derivatives(params, self.derivative_keys.dyn_loss),
|
|
198
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
199
|
+
self.loss_weights.dyn_loss,
|
|
182
200
|
)
|
|
183
201
|
else:
|
|
184
202
|
mse_dyn_loss = jnp.array(0.0)
|
|
185
203
|
|
|
186
204
|
# initial condition
|
|
187
|
-
params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
|
|
188
205
|
if self.initial_condition is not None:
|
|
189
206
|
vmap_in_axes = (None,) + vmap_in_axes_params
|
|
190
207
|
if not jax.tree_util.tree_leaves(vmap_in_axes):
|
|
@@ -193,12 +210,24 @@ class LossODE:
|
|
|
193
210
|
v_u = self.u
|
|
194
211
|
else:
|
|
195
212
|
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
196
|
-
t0, u0 = self.initial_condition
|
|
213
|
+
t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
|
|
197
214
|
t0 = jnp.array(t0)
|
|
198
215
|
u0 = jnp.array(u0)
|
|
199
216
|
mse_initial_condition = jnp.mean(
|
|
200
|
-
self.loss_weights
|
|
201
|
-
* jnp.sum(
|
|
217
|
+
self.loss_weights.initial_condition
|
|
218
|
+
* jnp.sum(
|
|
219
|
+
(
|
|
220
|
+
v_u(
|
|
221
|
+
t0,
|
|
222
|
+
_set_derivatives(
|
|
223
|
+
params, self.derivative_keys.initial_condition
|
|
224
|
+
),
|
|
225
|
+
)
|
|
226
|
+
- u0
|
|
227
|
+
)
|
|
228
|
+
** 2,
|
|
229
|
+
axis=-1,
|
|
230
|
+
)
|
|
202
231
|
)
|
|
203
232
|
else:
|
|
204
233
|
mse_initial_condition = jnp.array(0.0)
|
|
@@ -208,14 +237,13 @@ class LossODE:
|
|
|
208
237
|
params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
|
|
209
238
|
|
|
210
239
|
# MSE loss wrt to an observed batch
|
|
211
|
-
params_ = _set_derivatives(params, "observations", self.derivative_keys)
|
|
212
240
|
mse_observation_loss = observations_loss_apply(
|
|
213
241
|
self.u,
|
|
214
242
|
(batch.obs_batch_dict["pinn_in"],),
|
|
215
|
-
|
|
216
|
-
|
|
243
|
+
_set_derivatives(params, self.derivative_keys.observations),
|
|
244
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
217
245
|
batch.obs_batch_dict["val"],
|
|
218
|
-
self.loss_weights
|
|
246
|
+
self.loss_weights.observations,
|
|
219
247
|
self.obs_slice,
|
|
220
248
|
)
|
|
221
249
|
else:
|
|
@@ -231,193 +259,174 @@ class LossODE:
|
|
|
231
259
|
}
|
|
232
260
|
)
|
|
233
261
|
|
|
234
|
-
def tree_flatten(self):
|
|
235
|
-
children = (self.initial_condition, self.loss_weights)
|
|
236
|
-
aux_data = {
|
|
237
|
-
"u": self.u,
|
|
238
|
-
"dynamic_loss": self.dynamic_loss,
|
|
239
|
-
"obs_slice": self.obs_slice,
|
|
240
|
-
"derivative_keys": self.derivative_keys,
|
|
241
|
-
}
|
|
242
|
-
return (children, aux_data)
|
|
243
|
-
|
|
244
|
-
@classmethod
|
|
245
|
-
def tree_unflatten(cls, aux_data, children):
|
|
246
|
-
(initial_condition, loss_weights) = children
|
|
247
|
-
loss_ode = cls(
|
|
248
|
-
loss_weights=loss_weights,
|
|
249
|
-
initial_condition=initial_condition,
|
|
250
|
-
**aux_data,
|
|
251
|
-
)
|
|
252
|
-
return loss_ode
|
|
253
|
-
|
|
254
262
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
"""
|
|
263
|
+
class SystemLossODE(eqx.Module):
|
|
264
|
+
r"""
|
|
258
265
|
Class to implement a system of ODEs.
|
|
259
266
|
The goal is to give maximum freedom to the user. The class is created with
|
|
260
|
-
a dict of dynamic loss and a dict of initial conditions.
|
|
261
|
-
over the dynamic losses that compose the system. All
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
267
|
+
a dict of dynamic loss and a dict of initial conditions. Then, it iterates
|
|
268
|
+
over the dynamic losses that compose the system. All PINNs are passed as
|
|
269
|
+
arguments to each dynamic loss evaluate functions, along with all the
|
|
270
|
+
parameter dictionaries. All specification is left to the responsability
|
|
271
|
+
of the user, inside the dynamic loss.
|
|
265
272
|
|
|
266
273
|
**Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
|
|
267
274
|
Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
|
|
268
275
|
solution.
|
|
269
276
|
|
|
270
|
-
|
|
271
|
-
|
|
277
|
+
Parameters
|
|
278
|
+
----------
|
|
279
|
+
u_dict : Dict[str, eqx.Module]
|
|
280
|
+
dict of PINNs
|
|
281
|
+
loss_weights : LossWeightsODEDict
|
|
282
|
+
A dictionary of LossWeightsODE
|
|
283
|
+
derivative_keys_dict : Dict[str, DerivativeKeysODE], default=None
|
|
284
|
+
A dictionnary of DerivativeKeysODE specifying what field of `params`
|
|
285
|
+
should be used during gradient computations for each of the terms of
|
|
286
|
+
the total loss, for each of the loss in the system. Default is
|
|
287
|
+
`"nn_params`" everywhere.
|
|
288
|
+
initial_condition_dict : Dict[str, tuple], default=None
|
|
289
|
+
dict of tuple of length 2 with initial condition $(t_0, u_0)$
|
|
290
|
+
Must share the keys of `u_dict`. Default is None. No initial
|
|
291
|
+
condition is permitted when the initial condition is hardcoded in
|
|
292
|
+
the PINN architecture for example
|
|
293
|
+
dynamic_loss_dict : Dict[str, ODE]
|
|
294
|
+
dict of dynamic part of the loss, basically the differential
|
|
295
|
+
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
296
|
+
`dynamic_loss.evaluate(t, u, params)`
|
|
297
|
+
obs_slice_dict : Dict[str, Slice]
|
|
298
|
+
dict of obs_slice, with keys from `u_dict` to designate the
|
|
299
|
+
output(s) channels that are observed, for each
|
|
300
|
+
PINNs. Default is None. But if a value is given, all the entries of
|
|
301
|
+
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
302
|
+
if no particular slice is to be given.
|
|
303
|
+
|
|
304
|
+
Raises
|
|
305
|
+
------
|
|
306
|
+
ValueError
|
|
307
|
+
if initial condition is not a dict of tuple.
|
|
308
|
+
ValueError
|
|
309
|
+
if the dictionaries that should share the keys of u_dict do not.
|
|
272
310
|
"""
|
|
273
311
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
initial_condition_dict
|
|
306
|
-
dict of tuple of length 2 with initial condition :math:`(t_0, u_0)`
|
|
307
|
-
Must share the keys of `u_dict`. Default is None. No initial
|
|
308
|
-
condition is permitted when the initial condition is hardcoded in
|
|
309
|
-
the PINN architecture for example
|
|
310
|
-
dynamic_loss_dict
|
|
311
|
-
dict of dynamic part of the loss, basically the differential
|
|
312
|
-
operator :math:`\mathcal{N}[u](t)`. Should implement a method
|
|
313
|
-
`dynamic_loss.evaluate(t, u, params)`
|
|
314
|
-
obs_slice_dict
|
|
315
|
-
dict of obs_slice, with keys from `u_dict` to designate the
|
|
316
|
-
output(s) channels that are forced to observed values, for each
|
|
317
|
-
PINNs. Default is None. But if a value is given, all the entries of
|
|
318
|
-
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
319
|
-
if no particular slice is to be given
|
|
320
|
-
|
|
321
|
-
Raises
|
|
322
|
-
------
|
|
323
|
-
ValueError
|
|
324
|
-
if initial condition is not a dict of tuple
|
|
325
|
-
ValueError
|
|
326
|
-
if the dictionaries that should share the keys of u_dict do not
|
|
327
|
-
"""
|
|
328
|
-
|
|
312
|
+
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
313
|
+
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
314
|
+
u_dict: Dict[str, eqx.Module]
|
|
315
|
+
dynamic_loss_dict: Dict[str, ODE]
|
|
316
|
+
derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
|
|
317
|
+
kw_only=True, default=None
|
|
318
|
+
)
|
|
319
|
+
initial_condition_dict: Dict[str, tuple] | None = eqx.field(
|
|
320
|
+
kw_only=True, default=None
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
324
|
+
kw_only=True, default=None, static=True
|
|
325
|
+
) # We are at an "leaf" attribute here (slice, not valid JAX type). Since
|
|
326
|
+
# we do not expect it to change with put a static=True here. But note that
|
|
327
|
+
# this is the only static for all the SystemLossODE attribute, since all
|
|
328
|
+
# other are composed of more complex structures ("non-leaf")
|
|
329
|
+
|
|
330
|
+
# For the user loss_weights are passed as a LossWeightsODEDict (with internal
|
|
331
|
+
# dictionary having keys in u_dict and / or dynamic_loss_dict)
|
|
332
|
+
loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
|
|
333
|
+
kw_only=True, default=None
|
|
334
|
+
)
|
|
335
|
+
u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
|
|
336
|
+
derivative_keys_dyn_loss_dict: Dict[str, DerivativeKeysODE] = eqx.field(init=False)
|
|
337
|
+
|
|
338
|
+
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
339
|
+
# internally the loss weights are handled with a dictionary
|
|
340
|
+
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
341
|
+
|
|
342
|
+
def __post_init__(self, loss_weights):
|
|
329
343
|
# a dictionary that will be useful at different places
|
|
330
|
-
self.u_dict_with_none = {k: None for k in u_dict.keys()}
|
|
331
|
-
if initial_condition_dict is None:
|
|
344
|
+
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
345
|
+
if self.initial_condition_dict is None:
|
|
332
346
|
self.initial_condition_dict = self.u_dict_with_none
|
|
333
347
|
else:
|
|
334
|
-
self.
|
|
335
|
-
if u_dict.keys() != initial_condition_dict.keys():
|
|
348
|
+
if self.u_dict.keys() != self.initial_condition_dict.keys():
|
|
336
349
|
raise ValueError(
|
|
337
350
|
"initial_condition_dict should have same keys as u_dict"
|
|
338
351
|
)
|
|
339
|
-
if obs_slice_dict is None:
|
|
340
|
-
self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
|
|
352
|
+
if self.obs_slice_dict is None:
|
|
353
|
+
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
341
354
|
else:
|
|
342
|
-
self.
|
|
343
|
-
if u_dict.keys() != obs_slice_dict.keys():
|
|
355
|
+
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
344
356
|
raise ValueError("obs_slice_dict should have same keys as u_dict")
|
|
345
357
|
|
|
346
|
-
if derivative_keys_dict is None:
|
|
358
|
+
if self.derivative_keys_dict is None:
|
|
347
359
|
self.derivative_keys_dict = {
|
|
348
360
|
k: None
|
|
349
|
-
for k in set(
|
|
361
|
+
for k in set(
|
|
362
|
+
list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
|
|
363
|
+
)
|
|
350
364
|
}
|
|
351
365
|
# set() because we can have duplicate entries and in this case we
|
|
352
366
|
# say it corresponds to the same derivative_keys_dict entry
|
|
353
|
-
|
|
354
|
-
|
|
367
|
+
# we need both because the constraints (all but dyn_loss) will be
|
|
368
|
+
# done by iterating on u_dict while the dyn_loss will be by
|
|
369
|
+
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
370
|
+
# derivative_keys_dict
|
|
355
371
|
|
|
356
372
|
# but then if the user did not provide anything, we must at least have
|
|
357
373
|
# a default value for the dynamic_loss_dict keys entries in
|
|
358
374
|
# self.derivative_keys_dict since the computation of dynamic losses is
|
|
359
375
|
# made without create a lossODE object that would provide the
|
|
360
376
|
# default values
|
|
361
|
-
for k in dynamic_loss_dict.keys():
|
|
377
|
+
for k in self.dynamic_loss_dict.keys():
|
|
362
378
|
if self.derivative_keys_dict[k] is None:
|
|
363
|
-
self.derivative_keys_dict[k] =
|
|
379
|
+
self.derivative_keys_dict[k] = DerivativeKeysODE()
|
|
364
380
|
|
|
365
|
-
self.
|
|
366
|
-
self.u_dict = u_dict
|
|
367
|
-
|
|
368
|
-
self.loss_weights = loss_weights # We call the setter
|
|
369
|
-
# note that self.initial_condition_dict must be
|
|
370
|
-
# initialized beforehand
|
|
381
|
+
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
371
382
|
|
|
372
383
|
# The constaints on the solutions will be implemented by reusing a
|
|
373
384
|
# LossODE class without dynamic loss term
|
|
374
385
|
self.u_constraints_dict = {}
|
|
375
386
|
for i in self.u_dict.keys():
|
|
376
387
|
self.u_constraints_dict[i] = LossODE(
|
|
377
|
-
u=u_dict[i],
|
|
378
|
-
loss_weights=
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
388
|
+
u=self.u_dict[i],
|
|
389
|
+
loss_weights=LossWeightsODE(
|
|
390
|
+
dyn_loss=0.0,
|
|
391
|
+
initial_condition=1.0,
|
|
392
|
+
observations=1.0,
|
|
393
|
+
),
|
|
383
394
|
dynamic_loss=None,
|
|
384
395
|
derivative_keys=self.derivative_keys_dict[i],
|
|
385
396
|
initial_condition=self.initial_condition_dict[i],
|
|
386
397
|
obs_slice=self.obs_slice_dict[i],
|
|
387
398
|
)
|
|
388
399
|
|
|
389
|
-
# for convenience in the tree_map of evaluate
|
|
390
|
-
# we separate the two derivative keys dict
|
|
400
|
+
# for convenience in the tree_map of evaluate
|
|
391
401
|
self.derivative_keys_dyn_loss_dict = {
|
|
392
402
|
k: self.derivative_keys_dict[k]
|
|
393
|
-
for k in self.dynamic_loss_dict.keys() & self.derivative_keys_dict.keys()
|
|
394
|
-
|
|
395
|
-
self.derivative_keys_u_dict = {
|
|
396
|
-
k: self.derivative_keys_dict[k]
|
|
397
|
-
for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
|
|
403
|
+
for k in self.dynamic_loss_dict.keys() # & self.derivative_keys_dict.keys()
|
|
404
|
+
# comment because intersection is neceserily fulfilled right?
|
|
398
405
|
}
|
|
399
406
|
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
407
|
+
def set_loss_weights(self, loss_weights_init):
|
|
408
|
+
"""
|
|
409
|
+
This rather complex function enables the user to specify a simple
|
|
410
|
+
loss_weights=LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)
|
|
411
|
+
for ponderating values being applied to all the equations of the
|
|
412
|
+
system... So all the transformations are handled here
|
|
413
|
+
"""
|
|
414
|
+
_loss_weights = {}
|
|
415
|
+
for k in fields(loss_weights_init):
|
|
416
|
+
v = getattr(loss_weights_init, k.name)
|
|
408
417
|
if isinstance(v, dict):
|
|
409
|
-
for
|
|
418
|
+
for vv in v.values():
|
|
410
419
|
if not isinstance(vv, (int, float)) and not (
|
|
411
|
-
isinstance(vv,
|
|
420
|
+
isinstance(vv, Array)
|
|
412
421
|
and ((vv.shape == (1,) or len(vv.shape) == 0))
|
|
413
422
|
):
|
|
414
423
|
# TODO improve that
|
|
415
424
|
raise ValueError(
|
|
416
425
|
f"loss values cannot be vectorial here, got {vv}"
|
|
417
426
|
)
|
|
418
|
-
if k == "dyn_loss":
|
|
427
|
+
if k.name == "dyn_loss":
|
|
419
428
|
if v.keys() == self.dynamic_loss_dict.keys():
|
|
420
|
-
|
|
429
|
+
_loss_weights[k.name] = v
|
|
421
430
|
else:
|
|
422
431
|
raise ValueError(
|
|
423
432
|
"Keys in nested dictionary of loss_weights"
|
|
@@ -425,48 +434,41 @@ class SystemLossODE:
|
|
|
425
434
|
)
|
|
426
435
|
else:
|
|
427
436
|
if v.keys() == self.u_dict.keys():
|
|
428
|
-
|
|
437
|
+
_loss_weights[k.name] = v
|
|
429
438
|
else:
|
|
430
439
|
raise ValueError(
|
|
431
440
|
"Keys in nested dictionary of loss_weights"
|
|
432
441
|
" do not match u_dict keys"
|
|
433
442
|
)
|
|
443
|
+
elif v is None:
|
|
444
|
+
_loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
|
|
434
445
|
else:
|
|
435
446
|
if not isinstance(v, (int, float)) and not (
|
|
436
|
-
isinstance(v,
|
|
437
|
-
and ((v.shape == (1,) or len(v.shape) == 0))
|
|
447
|
+
isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
|
|
438
448
|
):
|
|
439
449
|
# TODO improve that
|
|
440
450
|
raise ValueError(f"loss values cannot be vectorial here, got {v}")
|
|
441
|
-
if k == "dyn_loss":
|
|
442
|
-
|
|
451
|
+
if k.name == "dyn_loss":
|
|
452
|
+
_loss_weights[k.name] = {
|
|
443
453
|
kk: v for kk in self.dynamic_loss_dict.keys()
|
|
444
454
|
}
|
|
445
455
|
else:
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
if "observations" not in value.keys():
|
|
450
|
-
self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
|
|
456
|
+
_loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
|
|
457
|
+
|
|
458
|
+
return _loss_weights
|
|
451
459
|
|
|
452
460
|
def __call__(self, *args, **kwargs):
|
|
453
461
|
return self.evaluate(*args, **kwargs)
|
|
454
462
|
|
|
455
|
-
def evaluate(self, params_dict, batch):
|
|
463
|
+
def evaluate(self, params_dict: ParamsDict, batch: ODEBatch) -> Float[Array, "1"]:
|
|
456
464
|
"""
|
|
457
465
|
Evaluate the loss function at a batch of points for given parameters.
|
|
458
466
|
|
|
459
467
|
|
|
460
468
|
Parameters
|
|
461
469
|
---------
|
|
462
|
-
|
|
463
|
-
A
|
|
464
|
-
Typically, it is a dictionary of dictionaries of
|
|
465
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
466
|
-
differential equation parameters and the neural network parameter.
|
|
467
|
-
Note that params_dict["nn_params"] need not be a dictionary anymore
|
|
468
|
-
but can directly be the parameters. It is useful when working with
|
|
469
|
-
neural networks sharing the same parameters
|
|
470
|
+
params
|
|
471
|
+
A ParamsDict object
|
|
470
472
|
batch
|
|
471
473
|
A ODEBatch object.
|
|
472
474
|
Such a named tuple is composed of a batch of time points
|
|
@@ -475,10 +477,10 @@ class SystemLossODE:
|
|
|
475
477
|
inputs/outputs/parameters
|
|
476
478
|
"""
|
|
477
479
|
if (
|
|
478
|
-
isinstance(params_dict
|
|
479
|
-
and self.u_dict.keys() != params_dict
|
|
480
|
+
isinstance(params_dict.nn_params, dict)
|
|
481
|
+
and self.u_dict.keys() != params_dict.nn_params.keys()
|
|
480
482
|
):
|
|
481
|
-
raise ValueError("u_dict and params_dict
|
|
483
|
+
raise ValueError("u_dict and params_dict.nn_params should have same keys ")
|
|
482
484
|
|
|
483
485
|
temporal_batch = batch.temporal_batch
|
|
484
486
|
|
|
@@ -489,7 +491,7 @@ class SystemLossODE:
|
|
|
489
491
|
# and update vmap_in_axes
|
|
490
492
|
if batch.param_batch_dict is not None:
|
|
491
493
|
# update params with the batches of generated params
|
|
492
|
-
|
|
494
|
+
params = _update_eq_params_dict(params, batch.param_batch_dict)
|
|
493
495
|
|
|
494
496
|
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
495
497
|
batch.param_batch_dict, params_dict
|
|
@@ -497,12 +499,11 @@ class SystemLossODE:
|
|
|
497
499
|
|
|
498
500
|
def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
|
|
499
501
|
"""This function is used in tree_map"""
|
|
500
|
-
params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
|
|
501
502
|
return dynamic_loss_apply(
|
|
502
503
|
dyn_loss.evaluate,
|
|
503
504
|
self.u_dict,
|
|
504
505
|
(temporal_batch,),
|
|
505
|
-
|
|
506
|
+
_set_derivatives(params_dict, derivative_key.dyn_loss),
|
|
506
507
|
vmap_in_axes_t + vmap_in_axes_params,
|
|
507
508
|
loss_weight,
|
|
508
509
|
u_type=PINN,
|
|
@@ -513,6 +514,11 @@ class SystemLossODE:
|
|
|
513
514
|
self.dynamic_loss_dict,
|
|
514
515
|
self.derivative_keys_dyn_loss_dict,
|
|
515
516
|
self._loss_weights["dyn_loss"],
|
|
517
|
+
is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
|
|
518
|
+
# where plain (unregister pytree) node classes, we could not traverse
|
|
519
|
+
# this level. Now that dynamic losses are eqx.Module they can be
|
|
520
|
+
# traversed by tree map recursion. Hence we need to specify to that
|
|
521
|
+
# we want to stop at this level
|
|
516
522
|
)
|
|
517
523
|
mse_dyn_loss = jax.tree_util.tree_reduce(
|
|
518
524
|
lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
|
|
@@ -527,7 +533,8 @@ class SystemLossODE:
|
|
|
527
533
|
|
|
528
534
|
# we need to do the following for the tree_mapping to work
|
|
529
535
|
if batch.obs_batch_dict is None:
|
|
530
|
-
batch = batch
|
|
536
|
+
batch = append_obs_batch(batch, self.u_dict_with_none)
|
|
537
|
+
|
|
531
538
|
total_loss, res_dict = constraints_system_loss_apply(
|
|
532
539
|
self.u_constraints_dict,
|
|
533
540
|
batch,
|
|
@@ -540,27 +547,3 @@ class SystemLossODE:
|
|
|
540
547
|
total_loss += mse_dyn_loss
|
|
541
548
|
res_dict["dyn_loss"] += mse_dyn_loss
|
|
542
549
|
return total_loss, res_dict
|
|
543
|
-
|
|
544
|
-
def tree_flatten(self):
|
|
545
|
-
children = (
|
|
546
|
-
self.initial_condition_dict,
|
|
547
|
-
self._loss_weights,
|
|
548
|
-
)
|
|
549
|
-
aux_data = {
|
|
550
|
-
"u_dict": self.u_dict,
|
|
551
|
-
"dynamic_loss_dict": self.dynamic_loss_dict,
|
|
552
|
-
"derivative_keys_dict": self.derivative_keys_dict,
|
|
553
|
-
"obs_slice_dict": self.obs_slice_dict,
|
|
554
|
-
}
|
|
555
|
-
return (children, aux_data)
|
|
556
|
-
|
|
557
|
-
@classmethod
|
|
558
|
-
def tree_unflatten(cls, aux_data, children):
|
|
559
|
-
(initial_condition_dict, loss_weights) = children
|
|
560
|
-
loss_ode = cls(
|
|
561
|
-
loss_weights=loss_weights,
|
|
562
|
-
initial_condition_dict=initial_condition_dict,
|
|
563
|
-
**aux_data,
|
|
564
|
-
)
|
|
565
|
-
|
|
566
|
-
return loss_ode
|