jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -1,24 +1,23 @@
|
|
|
1
|
-
# pylint: disable=unsubscriptable-object, no-member
|
|
2
1
|
"""
|
|
3
2
|
Main module to implement a ODE loss in jinns
|
|
4
3
|
"""
|
|
4
|
+
|
|
5
5
|
from __future__ import (
|
|
6
6
|
annotations,
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
|
-
from dataclasses import InitVar
|
|
10
|
-
from typing import TYPE_CHECKING,
|
|
9
|
+
from dataclasses import InitVar
|
|
10
|
+
from typing import TYPE_CHECKING, TypedDict
|
|
11
|
+
from types import EllipsisType
|
|
11
12
|
import abc
|
|
12
13
|
import warnings
|
|
13
14
|
import jax
|
|
14
15
|
import jax.numpy as jnp
|
|
15
16
|
from jax import vmap
|
|
16
17
|
import equinox as eqx
|
|
17
|
-
from jaxtyping import Float, Array
|
|
18
|
-
from jinns.data._DataGenerators import append_obs_batch
|
|
18
|
+
from jaxtyping import Float, Array
|
|
19
19
|
from jinns.loss._loss_utils import (
|
|
20
20
|
dynamic_loss_apply,
|
|
21
|
-
constraints_system_loss_apply,
|
|
22
21
|
observations_loss_apply,
|
|
23
22
|
)
|
|
24
23
|
from jinns.parameters._params import (
|
|
@@ -26,15 +25,23 @@ from jinns.parameters._params import (
|
|
|
26
25
|
_update_eq_params_dict,
|
|
27
26
|
)
|
|
28
27
|
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
29
|
-
from jinns.loss._loss_weights import LossWeightsODE
|
|
30
|
-
from jinns.loss.
|
|
31
|
-
from jinns.utils._pinn import PINN
|
|
28
|
+
from jinns.loss._loss_weights import LossWeightsODE
|
|
29
|
+
from jinns.loss._abstract_loss import AbstractLoss
|
|
32
30
|
|
|
33
31
|
if TYPE_CHECKING:
|
|
34
|
-
|
|
32
|
+
# imports only used in type hints
|
|
33
|
+
from jinns.parameters._params import Params
|
|
34
|
+
from jinns.data._Batchs import ODEBatch
|
|
35
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
36
|
+
from jinns.loss import ODE
|
|
37
|
+
|
|
38
|
+
class LossDictODE(TypedDict):
|
|
39
|
+
dyn_loss: Float[Array, " "]
|
|
40
|
+
initial_condition: Float[Array, " "]
|
|
41
|
+
observations: Float[Array, " "]
|
|
35
42
|
|
|
36
43
|
|
|
37
|
-
class _LossODEAbstract(
|
|
44
|
+
class _LossODEAbstract(AbstractLoss):
|
|
38
45
|
"""
|
|
39
46
|
Parameters
|
|
40
47
|
----------
|
|
@@ -49,14 +56,14 @@ class _LossODEAbstract(eqx.Module):
|
|
|
49
56
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
50
57
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
51
58
|
is `"nn_params"` for each composant of the loss.
|
|
52
|
-
initial_condition : tuple, default=None
|
|
59
|
+
initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
|
|
53
60
|
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
54
|
-
obs_slice :
|
|
61
|
+
obs_slice : EllipsisType | slice | None, default=None
|
|
55
62
|
Slice object specifying the begininning/ending
|
|
56
63
|
slice of u output(s) that is observed. This is useful for
|
|
57
64
|
multidimensional PINN, with partially observed outputs.
|
|
58
65
|
Default is None (whole output is observed).
|
|
59
|
-
params : InitVar[Params], default=None
|
|
66
|
+
params : InitVar[Params[Array]], default=None
|
|
60
67
|
The main Params object of the problem needed to instanciate the
|
|
61
68
|
DerivativeKeysODE if the latter is not specified.
|
|
62
69
|
"""
|
|
@@ -66,24 +73,27 @@ class _LossODEAbstract(eqx.Module):
|
|
|
66
73
|
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
67
74
|
derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
|
|
68
75
|
loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
|
|
69
|
-
initial_condition:
|
|
70
|
-
|
|
76
|
+
initial_condition: (
|
|
77
|
+
tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
|
|
78
|
+
) = eqx.field(kw_only=True, default=None)
|
|
79
|
+
obs_slice: EllipsisType | slice | None = eqx.field(
|
|
80
|
+
kw_only=True, default=None, static=True
|
|
81
|
+
)
|
|
71
82
|
|
|
72
|
-
params: InitVar[Params] = eqx.field(default=None, kw_only=True)
|
|
83
|
+
params: InitVar[Params[Array]] = eqx.field(default=None, kw_only=True)
|
|
73
84
|
|
|
74
|
-
def __post_init__(self, params=None):
|
|
85
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
75
86
|
if self.loss_weights is None:
|
|
76
87
|
self.loss_weights = LossWeightsODE()
|
|
77
88
|
|
|
78
89
|
if self.derivative_keys is None:
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
82
|
-
except ValueError as exc:
|
|
90
|
+
# by default we only take gradient wrt nn_params
|
|
91
|
+
if params is None:
|
|
83
92
|
raise ValueError(
|
|
84
93
|
"Problem at self.derivative_keys initialization "
|
|
85
94
|
f"received {self.derivative_keys=} and {params=}"
|
|
86
|
-
)
|
|
95
|
+
)
|
|
96
|
+
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
87
97
|
if self.initial_condition is None:
|
|
88
98
|
warnings.warn(
|
|
89
99
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
@@ -98,6 +108,19 @@ class _LossODEAbstract(eqx.Module):
|
|
|
98
108
|
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
99
109
|
f"{self.initial_condition} was passed."
|
|
100
110
|
)
|
|
111
|
+
# some checks/reshaping for t0
|
|
112
|
+
t0, u0 = self.initial_condition
|
|
113
|
+
if isinstance(t0, Array):
|
|
114
|
+
if not t0.shape: # e.g. user input: jnp.array(0.)
|
|
115
|
+
t0 = jnp.array([t0])
|
|
116
|
+
elif t0.shape != (1,):
|
|
117
|
+
raise ValueError(
|
|
118
|
+
f"Wrong t0 input (self.initial_condition[0]) It should be"
|
|
119
|
+
f"a float or an array of shape (1,). Got shape: {t0.shape}"
|
|
120
|
+
)
|
|
121
|
+
if isinstance(t0, float): # e.g. user input: 0
|
|
122
|
+
t0 = jnp.array([t0])
|
|
123
|
+
self.initial_condition = (t0, u0)
|
|
101
124
|
|
|
102
125
|
if self.obs_slice is None:
|
|
103
126
|
self.obs_slice = jnp.s_[...]
|
|
@@ -105,10 +128,14 @@ class _LossODEAbstract(eqx.Module):
|
|
|
105
128
|
if self.loss_weights is None:
|
|
106
129
|
self.loss_weights = LossWeightsODE()
|
|
107
130
|
|
|
131
|
+
@abc.abstractmethod
|
|
132
|
+
def __call__(self, *_, **__):
|
|
133
|
+
pass
|
|
134
|
+
|
|
108
135
|
@abc.abstractmethod
|
|
109
136
|
def evaluate(
|
|
110
|
-
self: eqx.Module, params: Params, batch: ODEBatch
|
|
111
|
-
) -> tuple[Float,
|
|
137
|
+
self: eqx.Module, params: Params[Array], batch: ODEBatch
|
|
138
|
+
) -> tuple[Float[Array, " "], LossDictODE]:
|
|
112
139
|
raise NotImplementedError
|
|
113
140
|
|
|
114
141
|
|
|
@@ -135,19 +162,19 @@ class LossODE(_LossODEAbstract):
|
|
|
135
162
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
136
163
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
137
164
|
is `"nn_params"` for each composant of the loss.
|
|
138
|
-
initial_condition : tuple, default=None
|
|
165
|
+
initial_condition : tuple[float | Float[Array, " 1"]], default=None
|
|
139
166
|
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
140
|
-
obs_slice
|
|
167
|
+
obs_slice : EllipsisType | slice | None, default=None
|
|
141
168
|
Slice object specifying the begininning/ending
|
|
142
169
|
slice of u output(s) that is observed. This is useful for
|
|
143
170
|
multidimensional PINN, with partially observed outputs.
|
|
144
171
|
Default is None (whole output is observed).
|
|
145
|
-
params : InitVar[Params], default=None
|
|
172
|
+
params : InitVar[Params[Array]], default=None
|
|
146
173
|
The main Params object of the problem needed to instanciate the
|
|
147
174
|
DerivativeKeysODE if the latter is not specified.
|
|
148
175
|
u : eqx.Module
|
|
149
176
|
the PINN
|
|
150
|
-
dynamic_loss :
|
|
177
|
+
dynamic_loss : ODE
|
|
151
178
|
the ODE dynamic part of the loss, basically the differential
|
|
152
179
|
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
153
180
|
`dynamic_loss.evaluate(t, u, params)`.
|
|
@@ -161,12 +188,12 @@ class LossODE(_LossODEAbstract):
|
|
|
161
188
|
|
|
162
189
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
163
190
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
164
|
-
u:
|
|
165
|
-
dynamic_loss:
|
|
191
|
+
u: AbstractPINN
|
|
192
|
+
dynamic_loss: ODE | None
|
|
166
193
|
|
|
167
|
-
vmap_in_axes: tuple[
|
|
194
|
+
vmap_in_axes: tuple[int] = eqx.field(init=False, static=True)
|
|
168
195
|
|
|
169
|
-
def __post_init__(self, params=None):
|
|
196
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
170
197
|
super().__post_init__(
|
|
171
198
|
params=params
|
|
172
199
|
) # because __init__ or __post_init__ of Base
|
|
@@ -178,8 +205,8 @@ class LossODE(_LossODEAbstract):
|
|
|
178
205
|
return self.evaluate(*args, **kwargs)
|
|
179
206
|
|
|
180
207
|
def evaluate(
|
|
181
|
-
self, params: Params, batch: ODEBatch
|
|
182
|
-
) -> tuple[Float[Array, "
|
|
208
|
+
self, params: Params[Array], batch: ODEBatch
|
|
209
|
+
) -> tuple[Float[Array, " "], LossDictODE]:
|
|
183
210
|
"""
|
|
184
211
|
Evaluate the loss function at a batch of points for given parameters.
|
|
185
212
|
|
|
@@ -210,9 +237,9 @@ class LossODE(_LossODEAbstract):
|
|
|
210
237
|
self.dynamic_loss.evaluate,
|
|
211
238
|
self.u,
|
|
212
239
|
temporal_batch,
|
|
213
|
-
_set_derivatives(params, self.derivative_keys.dyn_loss),
|
|
240
|
+
_set_derivatives(params, self.derivative_keys.dyn_loss), # type: ignore
|
|
214
241
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
215
|
-
self.loss_weights.dyn_loss,
|
|
242
|
+
self.loss_weights.dyn_loss, # type: ignore
|
|
216
243
|
)
|
|
217
244
|
else:
|
|
218
245
|
mse_dyn_loss = jnp.array(0.0)
|
|
@@ -226,17 +253,17 @@ class LossODE(_LossODEAbstract):
|
|
|
226
253
|
v_u = self.u
|
|
227
254
|
else:
|
|
228
255
|
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
229
|
-
t0, u0 = self.initial_condition
|
|
230
|
-
t0 = jnp.array([t0])
|
|
256
|
+
t0, u0 = self.initial_condition
|
|
231
257
|
u0 = jnp.array(u0)
|
|
232
258
|
mse_initial_condition = jnp.mean(
|
|
233
|
-
self.loss_weights.initial_condition
|
|
259
|
+
self.loss_weights.initial_condition # type: ignore
|
|
234
260
|
* jnp.sum(
|
|
235
261
|
(
|
|
236
262
|
v_u(
|
|
237
263
|
t0,
|
|
238
264
|
_set_derivatives(
|
|
239
|
-
params,
|
|
265
|
+
params,
|
|
266
|
+
self.derivative_keys.initial_condition, # type: ignore
|
|
240
267
|
),
|
|
241
268
|
)
|
|
242
269
|
- u0
|
|
@@ -255,11 +282,11 @@ class LossODE(_LossODEAbstract):
|
|
|
255
282
|
# MSE loss wrt to an observed batch
|
|
256
283
|
mse_observation_loss = observations_loss_apply(
|
|
257
284
|
self.u,
|
|
258
|
-
|
|
259
|
-
_set_derivatives(params, self.derivative_keys.observations),
|
|
285
|
+
batch.obs_batch_dict["pinn_in"],
|
|
286
|
+
_set_derivatives(params, self.derivative_keys.observations), # type: ignore
|
|
260
287
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
261
288
|
batch.obs_batch_dict["val"],
|
|
262
|
-
self.loss_weights.observations,
|
|
289
|
+
self.loss_weights.observations, # type: ignore
|
|
263
290
|
self.obs_slice,
|
|
264
291
|
)
|
|
265
292
|
else:
|
|
@@ -274,295 +301,3 @@ class LossODE(_LossODEAbstract):
|
|
|
274
301
|
"observations": mse_observation_loss,
|
|
275
302
|
}
|
|
276
303
|
)
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
class SystemLossODE(eqx.Module):
|
|
280
|
-
r"""
|
|
281
|
-
Class to implement a system of ODEs.
|
|
282
|
-
The goal is to give maximum freedom to the user. The class is created with
|
|
283
|
-
a dict of dynamic loss and a dict of initial conditions. Then, it iterates
|
|
284
|
-
over the dynamic losses that compose the system. All PINNs are passed as
|
|
285
|
-
arguments to each dynamic loss evaluate functions, along with all the
|
|
286
|
-
parameter dictionaries. All specification is left to the responsability
|
|
287
|
-
of the user, inside the dynamic loss.
|
|
288
|
-
|
|
289
|
-
**Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
|
|
290
|
-
Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
|
|
291
|
-
solution.
|
|
292
|
-
|
|
293
|
-
Parameters
|
|
294
|
-
----------
|
|
295
|
-
u_dict : Dict[str, eqx.Module]
|
|
296
|
-
dict of PINNs
|
|
297
|
-
loss_weights : LossWeightsODEDict
|
|
298
|
-
A dictionary of LossWeightsODE
|
|
299
|
-
derivative_keys_dict : Dict[str, DerivativeKeysODE], default=None
|
|
300
|
-
A dictionnary of DerivativeKeysODE specifying what field of `params`
|
|
301
|
-
should be used during gradient computations for each of the terms of
|
|
302
|
-
the total loss, for each of the loss in the system. Default is
|
|
303
|
-
`"nn_params`" everywhere.
|
|
304
|
-
initial_condition_dict : Dict[str, tuple], default=None
|
|
305
|
-
dict of tuple of length 2 with initial condition $(t_0, u_0)$
|
|
306
|
-
Must share the keys of `u_dict`. Default is None. No initial
|
|
307
|
-
condition is permitted when the initial condition is hardcoded in
|
|
308
|
-
the PINN architecture for example
|
|
309
|
-
dynamic_loss_dict : Dict[str, ODE]
|
|
310
|
-
dict of dynamic part of the loss, basically the differential
|
|
311
|
-
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
312
|
-
`dynamic_loss.evaluate(t, u, params)`
|
|
313
|
-
obs_slice_dict : Dict[str, Slice]
|
|
314
|
-
dict of obs_slice, with keys from `u_dict` to designate the
|
|
315
|
-
output(s) channels that are observed, for each
|
|
316
|
-
PINNs. Default is None. But if a value is given, all the entries of
|
|
317
|
-
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
318
|
-
if no particular slice is to be given.
|
|
319
|
-
params_dict : InitVar[ParamsDict], default=None
|
|
320
|
-
The main Params object of the problem needed to instanciate the
|
|
321
|
-
DerivativeKeysODE if the latter is not specified.
|
|
322
|
-
|
|
323
|
-
Raises
|
|
324
|
-
------
|
|
325
|
-
ValueError
|
|
326
|
-
if initial condition is not a dict of tuple.
|
|
327
|
-
ValueError
|
|
328
|
-
if the dictionaries that should share the keys of u_dict do not.
|
|
329
|
-
"""
|
|
330
|
-
|
|
331
|
-
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
332
|
-
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
333
|
-
u_dict: Dict[str, eqx.Module]
|
|
334
|
-
dynamic_loss_dict: Dict[str, ODE]
|
|
335
|
-
derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
|
|
336
|
-
kw_only=True, default=None
|
|
337
|
-
)
|
|
338
|
-
initial_condition_dict: Dict[str, tuple] | None = eqx.field(
|
|
339
|
-
kw_only=True, default=None
|
|
340
|
-
)
|
|
341
|
-
|
|
342
|
-
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
343
|
-
kw_only=True, default=None, static=True
|
|
344
|
-
) # We are at an "leaf" attribute here (slice, not valid JAX type). Since
|
|
345
|
-
# we do not expect it to change with put a static=True here. But note that
|
|
346
|
-
# this is the only static for all the SystemLossODE attribute, since all
|
|
347
|
-
# other are composed of more complex structures ("non-leaf")
|
|
348
|
-
|
|
349
|
-
# For the user loss_weights are passed as a LossWeightsODEDict (with internal
|
|
350
|
-
# dictionary having keys in u_dict and / or dynamic_loss_dict)
|
|
351
|
-
loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
|
|
352
|
-
kw_only=True, default=None
|
|
353
|
-
)
|
|
354
|
-
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
355
|
-
|
|
356
|
-
u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
|
|
357
|
-
derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
|
|
358
|
-
|
|
359
|
-
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
360
|
-
# internally the loss weights are handled with a dictionary
|
|
361
|
-
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
362
|
-
|
|
363
|
-
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
364
|
-
# a dictionary that will be useful at different places
|
|
365
|
-
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
366
|
-
if self.initial_condition_dict is None:
|
|
367
|
-
self.initial_condition_dict = self.u_dict_with_none
|
|
368
|
-
else:
|
|
369
|
-
if self.u_dict.keys() != self.initial_condition_dict.keys():
|
|
370
|
-
raise ValueError(
|
|
371
|
-
"initial_condition_dict should have same keys as u_dict"
|
|
372
|
-
)
|
|
373
|
-
if self.obs_slice_dict is None:
|
|
374
|
-
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
375
|
-
else:
|
|
376
|
-
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
377
|
-
raise ValueError("obs_slice_dict should have same keys as u_dict")
|
|
378
|
-
|
|
379
|
-
if self.derivative_keys_dict is None:
|
|
380
|
-
self.derivative_keys_dict = {
|
|
381
|
-
k: None
|
|
382
|
-
for k in set(
|
|
383
|
-
list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
|
|
384
|
-
)
|
|
385
|
-
}
|
|
386
|
-
# set() because we can have duplicate entries and in this case we
|
|
387
|
-
# say it corresponds to the same derivative_keys_dict entry
|
|
388
|
-
# we need both because the constraints (all but dyn_loss) will be
|
|
389
|
-
# done by iterating on u_dict while the dyn_loss will be by
|
|
390
|
-
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
391
|
-
# derivative_keys_dict
|
|
392
|
-
|
|
393
|
-
# derivative keys for the u_constraints. Note that we create missing
|
|
394
|
-
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
395
|
-
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
396
|
-
for k in self.u_dict.keys():
|
|
397
|
-
if self.derivative_keys_dict[k] is None:
|
|
398
|
-
self.derivative_keys_dict[k] = DerivativeKeysODE(
|
|
399
|
-
params=params_dict.extract_params(k)
|
|
400
|
-
)
|
|
401
|
-
|
|
402
|
-
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
403
|
-
|
|
404
|
-
# The constaints on the solutions will be implemented by reusing a
|
|
405
|
-
# LossODE class without dynamic loss term
|
|
406
|
-
self.u_constraints_dict = {}
|
|
407
|
-
for i in self.u_dict.keys():
|
|
408
|
-
self.u_constraints_dict[i] = LossODE(
|
|
409
|
-
u=self.u_dict[i],
|
|
410
|
-
loss_weights=LossWeightsODE(
|
|
411
|
-
dyn_loss=0.0,
|
|
412
|
-
initial_condition=1.0,
|
|
413
|
-
observations=1.0,
|
|
414
|
-
),
|
|
415
|
-
dynamic_loss=None,
|
|
416
|
-
derivative_keys=self.derivative_keys_dict[i],
|
|
417
|
-
initial_condition=self.initial_condition_dict[i],
|
|
418
|
-
obs_slice=self.obs_slice_dict[i],
|
|
419
|
-
)
|
|
420
|
-
|
|
421
|
-
# derivative keys for the dynamic loss. Note that we create a
|
|
422
|
-
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
423
|
-
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
424
|
-
# happen inside it)
|
|
425
|
-
self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
|
|
426
|
-
|
|
427
|
-
def set_loss_weights(self, loss_weights_init):
|
|
428
|
-
"""
|
|
429
|
-
This rather complex function enables the user to specify a simple
|
|
430
|
-
loss_weights=LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)
|
|
431
|
-
for ponderating values being applied to all the equations of the
|
|
432
|
-
system... So all the transformations are handled here
|
|
433
|
-
"""
|
|
434
|
-
_loss_weights = {}
|
|
435
|
-
for k in fields(loss_weights_init):
|
|
436
|
-
v = getattr(loss_weights_init, k.name)
|
|
437
|
-
if isinstance(v, dict):
|
|
438
|
-
for vv in v.values():
|
|
439
|
-
if not isinstance(vv, (int, float)) and not (
|
|
440
|
-
isinstance(vv, Array)
|
|
441
|
-
and ((vv.shape == (1,) or len(vv.shape) == 0))
|
|
442
|
-
):
|
|
443
|
-
# TODO improve that
|
|
444
|
-
raise ValueError(
|
|
445
|
-
f"loss values cannot be vectorial here, got {vv}"
|
|
446
|
-
)
|
|
447
|
-
if k.name == "dyn_loss":
|
|
448
|
-
if v.keys() == self.dynamic_loss_dict.keys():
|
|
449
|
-
_loss_weights[k.name] = v
|
|
450
|
-
else:
|
|
451
|
-
raise ValueError(
|
|
452
|
-
"Keys in nested dictionary of loss_weights"
|
|
453
|
-
" do not match dynamic_loss_dict keys"
|
|
454
|
-
)
|
|
455
|
-
else:
|
|
456
|
-
if v.keys() == self.u_dict.keys():
|
|
457
|
-
_loss_weights[k.name] = v
|
|
458
|
-
else:
|
|
459
|
-
raise ValueError(
|
|
460
|
-
"Keys in nested dictionary of loss_weights"
|
|
461
|
-
" do not match u_dict keys"
|
|
462
|
-
)
|
|
463
|
-
elif v is None:
|
|
464
|
-
_loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
|
|
465
|
-
else:
|
|
466
|
-
if not isinstance(v, (int, float)) and not (
|
|
467
|
-
isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
|
|
468
|
-
):
|
|
469
|
-
# TODO improve that
|
|
470
|
-
raise ValueError(f"loss values cannot be vectorial here, got {v}")
|
|
471
|
-
if k.name == "dyn_loss":
|
|
472
|
-
_loss_weights[k.name] = {
|
|
473
|
-
kk: v for kk in self.dynamic_loss_dict.keys()
|
|
474
|
-
}
|
|
475
|
-
else:
|
|
476
|
-
_loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
|
|
477
|
-
|
|
478
|
-
return _loss_weights
|
|
479
|
-
|
|
480
|
-
def __call__(self, *args, **kwargs):
|
|
481
|
-
return self.evaluate(*args, **kwargs)
|
|
482
|
-
|
|
483
|
-
def evaluate(self, params_dict: ParamsDict, batch: ODEBatch) -> Float[Array, "1"]:
|
|
484
|
-
"""
|
|
485
|
-
Evaluate the loss function at a batch of points for given parameters.
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
Parameters
|
|
489
|
-
---------
|
|
490
|
-
params
|
|
491
|
-
A ParamsDict object
|
|
492
|
-
batch
|
|
493
|
-
A ODEBatch object.
|
|
494
|
-
Such a named tuple is composed of a batch of time points
|
|
495
|
-
at which to evaluate an optional additional batch of parameters (eg. for
|
|
496
|
-
metamodeling) and an optional additional batch of observed
|
|
497
|
-
inputs/outputs/parameters
|
|
498
|
-
"""
|
|
499
|
-
if (
|
|
500
|
-
isinstance(params_dict.nn_params, dict)
|
|
501
|
-
and self.u_dict.keys() != params_dict.nn_params.keys()
|
|
502
|
-
):
|
|
503
|
-
raise ValueError("u_dict and params_dict.nn_params should have same keys ")
|
|
504
|
-
|
|
505
|
-
temporal_batch = batch.temporal_batch
|
|
506
|
-
|
|
507
|
-
vmap_in_axes_t = (0,)
|
|
508
|
-
|
|
509
|
-
# Retrieve the optional eq_params_batch
|
|
510
|
-
# and update eq_params with the latter
|
|
511
|
-
# and update vmap_in_axes
|
|
512
|
-
if batch.param_batch_dict is not None:
|
|
513
|
-
# update params with the batches of generated params
|
|
514
|
-
params = _update_eq_params_dict(params, batch.param_batch_dict)
|
|
515
|
-
|
|
516
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
517
|
-
batch.param_batch_dict, params_dict
|
|
518
|
-
)
|
|
519
|
-
|
|
520
|
-
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
521
|
-
"""This function is used in tree_map"""
|
|
522
|
-
return dynamic_loss_apply(
|
|
523
|
-
dyn_loss.evaluate,
|
|
524
|
-
self.u_dict,
|
|
525
|
-
temporal_batch,
|
|
526
|
-
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
527
|
-
vmap_in_axes_t + vmap_in_axes_params,
|
|
528
|
-
loss_weight,
|
|
529
|
-
u_type=PINN,
|
|
530
|
-
)
|
|
531
|
-
|
|
532
|
-
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
533
|
-
dyn_loss_for_one_key,
|
|
534
|
-
self.dynamic_loss_dict,
|
|
535
|
-
self._loss_weights["dyn_loss"],
|
|
536
|
-
is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
|
|
537
|
-
# where plain (unregister pytree) node classes, we could not traverse
|
|
538
|
-
# this level. Now that dynamic losses are eqx.Module they can be
|
|
539
|
-
# traversed by tree map recursion. Hence we need to specify to that
|
|
540
|
-
# we want to stop at this level
|
|
541
|
-
)
|
|
542
|
-
mse_dyn_loss = jax.tree_util.tree_reduce(
|
|
543
|
-
lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
|
|
544
|
-
)
|
|
545
|
-
|
|
546
|
-
# initial conditions and observation_loss via the internal LossODE
|
|
547
|
-
loss_weight_struct = {
|
|
548
|
-
"dyn_loss": "*",
|
|
549
|
-
"observations": "*",
|
|
550
|
-
"initial_condition": "*",
|
|
551
|
-
}
|
|
552
|
-
|
|
553
|
-
# we need to do the following for the tree_mapping to work
|
|
554
|
-
if batch.obs_batch_dict is None:
|
|
555
|
-
batch = append_obs_batch(batch, self.u_dict_with_none)
|
|
556
|
-
|
|
557
|
-
total_loss, res_dict = constraints_system_loss_apply(
|
|
558
|
-
self.u_constraints_dict,
|
|
559
|
-
batch,
|
|
560
|
-
params_dict,
|
|
561
|
-
self._loss_weights,
|
|
562
|
-
loss_weight_struct,
|
|
563
|
-
)
|
|
564
|
-
|
|
565
|
-
# Add the mse_dyn_loss from the previous computations
|
|
566
|
-
total_loss += mse_dyn_loss
|
|
567
|
-
res_dict["dyn_loss"] += mse_dyn_loss
|
|
568
|
-
return total_loss, res_dict
|