jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -7,9 +7,8 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
from dataclasses import InitVar
|
|
10
|
-
from typing import TYPE_CHECKING,
|
|
10
|
+
from typing import TYPE_CHECKING, Callable, Any, cast
|
|
11
11
|
from types import EllipsisType
|
|
12
|
-
import abc
|
|
13
12
|
import warnings
|
|
14
13
|
import jax
|
|
15
14
|
import jax.numpy as jnp
|
|
@@ -19,133 +18,36 @@ from jaxtyping import Float, Array
|
|
|
19
18
|
from jinns.loss._loss_utils import (
|
|
20
19
|
dynamic_loss_apply,
|
|
21
20
|
observations_loss_apply,
|
|
21
|
+
initial_condition_check,
|
|
22
22
|
)
|
|
23
23
|
from jinns.parameters._params import (
|
|
24
24
|
_get_vmap_in_axes_params,
|
|
25
|
-
|
|
25
|
+
update_eq_params,
|
|
26
26
|
)
|
|
27
27
|
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
28
28
|
from jinns.loss._loss_weights import LossWeightsODE
|
|
29
29
|
from jinns.loss._abstract_loss import AbstractLoss
|
|
30
30
|
from jinns.loss._loss_components import ODEComponents
|
|
31
31
|
from jinns.parameters._params import Params
|
|
32
|
+
from jinns.data._Batchs import ODEBatch
|
|
32
33
|
|
|
33
34
|
if TYPE_CHECKING:
|
|
34
35
|
# imports only used in type hints
|
|
35
|
-
from jinns.data._Batchs import ODEBatch
|
|
36
36
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
37
37
|
from jinns.loss import ODE
|
|
38
38
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
observations: Float[Array, " "]
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class _LossODEAbstract(AbstractLoss):
|
|
46
|
-
"""
|
|
47
|
-
Parameters
|
|
48
|
-
----------
|
|
49
|
-
|
|
50
|
-
loss_weights : LossWeightsODE, default=None
|
|
51
|
-
The loss weights for the differents term : dynamic loss,
|
|
52
|
-
initial condition and eventually observations if any.
|
|
53
|
-
Can be updated according to a specific algorithm. See
|
|
54
|
-
`update_weight_method`
|
|
55
|
-
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
56
|
-
Default is None meaning no update for loss weights. Otherwise a string
|
|
57
|
-
derivative_keys : DerivativeKeysODE, default=None
|
|
58
|
-
Specify which field of `params` should be differentiated for each
|
|
59
|
-
composant of the total loss. Particularily useful for inverse problems.
|
|
60
|
-
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
61
|
-
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
62
|
-
is `"nn_params"` for each composant of the loss.
|
|
63
|
-
initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
|
|
64
|
-
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
65
|
-
obs_slice : EllipsisType | slice | None, default=None
|
|
66
|
-
Slice object specifying the begininning/ending
|
|
67
|
-
slice of u output(s) that is observed. This is useful for
|
|
68
|
-
multidimensional PINN, with partially observed outputs.
|
|
69
|
-
Default is None (whole output is observed).
|
|
70
|
-
params : InitVar[Params[Array]], default=None
|
|
71
|
-
The main Params object of the problem needed to instanciate the
|
|
72
|
-
DerivativeKeysODE if the latter is not specified.
|
|
73
|
-
"""
|
|
74
|
-
|
|
75
|
-
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
76
|
-
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
77
|
-
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
78
|
-
derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
|
|
79
|
-
loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
|
|
80
|
-
initial_condition: (
|
|
81
|
-
tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
|
|
82
|
-
) = eqx.field(kw_only=True, default=None)
|
|
83
|
-
obs_slice: EllipsisType | slice | None = eqx.field(
|
|
84
|
-
kw_only=True, default=None, static=True
|
|
39
|
+
InitialConditionUser = (
|
|
40
|
+
tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
|
|
41
|
+
| tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
|
|
85
42
|
)
|
|
86
43
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
self.loss_weights = LossWeightsODE()
|
|
92
|
-
|
|
93
|
-
if self.derivative_keys is None:
|
|
94
|
-
# by default we only take gradient wrt nn_params
|
|
95
|
-
if params is None:
|
|
96
|
-
raise ValueError(
|
|
97
|
-
"Problem at self.derivative_keys initialization "
|
|
98
|
-
f"received {self.derivative_keys=} and {params=}"
|
|
99
|
-
)
|
|
100
|
-
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
101
|
-
if self.initial_condition is None:
|
|
102
|
-
warnings.warn(
|
|
103
|
-
"Initial condition wasn't provided. Be sure to cover for that"
|
|
104
|
-
"case (e.g by. hardcoding it into the PINN output)."
|
|
105
|
-
)
|
|
106
|
-
else:
|
|
107
|
-
if (
|
|
108
|
-
not isinstance(self.initial_condition, tuple)
|
|
109
|
-
or len(self.initial_condition) != 2
|
|
110
|
-
):
|
|
111
|
-
raise ValueError(
|
|
112
|
-
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
113
|
-
f"{self.initial_condition} was passed."
|
|
114
|
-
)
|
|
115
|
-
# some checks/reshaping for t0
|
|
116
|
-
t0, u0 = self.initial_condition
|
|
117
|
-
if isinstance(t0, Array):
|
|
118
|
-
if not t0.shape: # e.g. user input: jnp.array(0.)
|
|
119
|
-
t0 = jnp.array([t0])
|
|
120
|
-
elif t0.shape != (1,):
|
|
121
|
-
raise ValueError(
|
|
122
|
-
f"Wrong t0 input (self.initial_condition[0]) It should be"
|
|
123
|
-
f"a float or an array of shape (1,). Got shape: {t0.shape}"
|
|
124
|
-
)
|
|
125
|
-
if isinstance(t0, float): # e.g. user input: 0.
|
|
126
|
-
t0 = jnp.array([t0])
|
|
127
|
-
if isinstance(t0, int): # e.g. user input: 0
|
|
128
|
-
t0 = jnp.array([float(t0)])
|
|
129
|
-
self.initial_condition = (t0, u0)
|
|
130
|
-
|
|
131
|
-
if self.obs_slice is None:
|
|
132
|
-
self.obs_slice = jnp.s_[...]
|
|
133
|
-
|
|
134
|
-
if self.loss_weights is None:
|
|
135
|
-
self.loss_weights = LossWeightsODE()
|
|
136
|
-
|
|
137
|
-
@abc.abstractmethod
|
|
138
|
-
def __call__(self, *_, **__):
|
|
139
|
-
pass
|
|
140
|
-
|
|
141
|
-
@abc.abstractmethod
|
|
142
|
-
def evaluate(
|
|
143
|
-
self: eqx.Module, params: Params[Array], batch: ODEBatch
|
|
144
|
-
) -> tuple[Float[Array, " "], LossDictODE]:
|
|
145
|
-
raise NotImplementedError
|
|
44
|
+
InitialCondition = (
|
|
45
|
+
tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
|
|
46
|
+
| tuple[Float[Array, " "], Float[Array, " dim"]]
|
|
47
|
+
)
|
|
146
48
|
|
|
147
49
|
|
|
148
|
-
class LossODE(
|
|
50
|
+
class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]]):
|
|
149
51
|
r"""Loss object for an ordinary differential equation
|
|
150
52
|
|
|
151
53
|
$$
|
|
@@ -158,6 +60,13 @@ class LossODE(_LossODEAbstract):
|
|
|
158
60
|
|
|
159
61
|
Parameters
|
|
160
62
|
----------
|
|
63
|
+
u : eqx.Module
|
|
64
|
+
the PINN
|
|
65
|
+
dynamic_loss : ODE
|
|
66
|
+
the ODE dynamic part of the loss, basically the differential
|
|
67
|
+
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
68
|
+
`dynamic_loss.evaluate(t, u, params)`.
|
|
69
|
+
Can be None in order to access only some part of the evaluate call.
|
|
161
70
|
loss_weights : LossWeightsODE, default=None
|
|
162
71
|
The loss weights for the different terms: dynamic loss,
|
|
163
72
|
initial condition and eventually observations if any.
|
|
@@ -171,9 +80,16 @@ class LossODE(_LossODEAbstract):
|
|
|
171
80
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
172
81
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
173
82
|
is `"nn_params"` for each component of the loss.
|
|
174
|
-
initial_condition : tuple[
|
|
175
|
-
|
|
176
|
-
|
|
83
|
+
initial_condition : tuple[
|
|
84
|
+
Float[Array, "n_cond "],
|
|
85
|
+
Float[Array, "n_cond dim"]
|
|
86
|
+
] |
|
|
87
|
+
tuple[int | float | Float[Array, " "],
|
|
88
|
+
int | float | Float[Array, " dim"]
|
|
89
|
+
], default=None
|
|
90
|
+
Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
91
|
+
From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, *e.g.* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
|
|
92
|
+
obs_slice : EllipsisType | slice, default=None
|
|
177
93
|
Slice object specifying the beginning/ending
|
|
178
94
|
slice of u output(s) that is observed. This is useful for
|
|
179
95
|
multidimensional PINN, with partially observed outputs.
|
|
@@ -181,14 +97,6 @@ class LossODE(_LossODEAbstract):
|
|
|
181
97
|
params : InitVar[Params[Array]], default=None
|
|
182
98
|
The main Params object of the problem needed to instantiate the
|
|
183
99
|
DerivativeKeysODE if the latter is not specified.
|
|
184
|
-
u : eqx.Module
|
|
185
|
-
the PINN
|
|
186
|
-
dynamic_loss : ODE
|
|
187
|
-
the ODE dynamic part of the loss, basically the differential
|
|
188
|
-
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
189
|
-
`dynamic_loss.evaluate(t, u, params)`.
|
|
190
|
-
Can be None in order to access only some part of the evaluate call.
|
|
191
|
-
|
|
192
100
|
Raises
|
|
193
101
|
------
|
|
194
102
|
ValueError
|
|
@@ -199,19 +107,117 @@ class LossODE(_LossODEAbstract):
|
|
|
199
107
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
200
108
|
u: AbstractPINN
|
|
201
109
|
dynamic_loss: ODE | None
|
|
110
|
+
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
111
|
+
derivative_keys: DerivativeKeysODE
|
|
112
|
+
loss_weights: LossWeightsODE
|
|
113
|
+
initial_condition: InitialCondition | None
|
|
114
|
+
obs_slice: EllipsisType | slice = eqx.field(static=True)
|
|
115
|
+
params: InitVar[Params[Array] | None]
|
|
116
|
+
|
|
117
|
+
def __init__(
|
|
118
|
+
self,
|
|
119
|
+
*,
|
|
120
|
+
u: AbstractPINN,
|
|
121
|
+
dynamic_loss: ODE | None,
|
|
122
|
+
loss_weights: LossWeightsODE | None = None,
|
|
123
|
+
derivative_keys: DerivativeKeysODE | None = None,
|
|
124
|
+
initial_condition: InitialConditionUser | None = None,
|
|
125
|
+
obs_slice: EllipsisType | slice | None = None,
|
|
126
|
+
params: Params[Array] | None = None,
|
|
127
|
+
**kwargs: Any, # this is for arguments for super()
|
|
128
|
+
):
|
|
129
|
+
if loss_weights is None:
|
|
130
|
+
self.loss_weights = LossWeightsODE()
|
|
131
|
+
else:
|
|
132
|
+
self.loss_weights = loss_weights
|
|
202
133
|
|
|
203
|
-
|
|
134
|
+
super().__init__(loss_weights=self.loss_weights, **kwargs)
|
|
135
|
+
self.u = u
|
|
136
|
+
self.dynamic_loss = dynamic_loss
|
|
137
|
+
self.vmap_in_axes = (0,)
|
|
138
|
+
if derivative_keys is None:
|
|
139
|
+
# by default we only take gradient wrt nn_params
|
|
140
|
+
if params is None:
|
|
141
|
+
raise ValueError(
|
|
142
|
+
"Problem at derivative_keys initialization "
|
|
143
|
+
f"received {derivative_keys=} and {params=}"
|
|
144
|
+
)
|
|
145
|
+
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
146
|
+
else:
|
|
147
|
+
self.derivative_keys = derivative_keys
|
|
204
148
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
149
|
+
if initial_condition is None:
|
|
150
|
+
warnings.warn(
|
|
151
|
+
"Initial condition wasn't provided. Be sure to cover for that"
|
|
152
|
+
"case (e.g by. hardcoding it into the PINN output)."
|
|
153
|
+
)
|
|
154
|
+
self.initial_condition = initial_condition
|
|
155
|
+
else:
|
|
156
|
+
if len(initial_condition) != 2:
|
|
157
|
+
raise ValueError(
|
|
158
|
+
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
159
|
+
f"{initial_condition} was passed."
|
|
160
|
+
)
|
|
161
|
+
# some checks/reshaping for t0 and u0
|
|
162
|
+
t0, u0 = initial_condition
|
|
163
|
+
if isinstance(t0, Array):
|
|
164
|
+
# at the end we want to end up with t0 of shape (:, 1) to account for
|
|
165
|
+
# possibly several data points
|
|
166
|
+
if t0.ndim <= 1:
|
|
167
|
+
# in this case we assume t0 belongs to one (initial)
|
|
168
|
+
# condition
|
|
169
|
+
t0 = initial_condition_check(t0, dim_size=1)[
|
|
170
|
+
None, :
|
|
171
|
+
] # make a (1, 1) here
|
|
172
|
+
if t0.ndim > 2:
|
|
173
|
+
raise ValueError(
|
|
174
|
+
"It t0 is an Array, it represents n_cond"
|
|
175
|
+
" imposed conditions and must be of shape (n_cond, 1)"
|
|
176
|
+
)
|
|
177
|
+
else:
|
|
178
|
+
# in this case t0 clearly represents one (initial) condition
|
|
179
|
+
t0 = initial_condition_check(t0, dim_size=1)[
|
|
180
|
+
None, :
|
|
181
|
+
] # make a (1, 1) here
|
|
182
|
+
if isinstance(u0, Array):
|
|
183
|
+
# at the end we want to end up with u0 of shape (:, dim) to account for
|
|
184
|
+
# possibly several data points
|
|
185
|
+
if not u0.shape:
|
|
186
|
+
# in this case we assume u0 belongs to one (initial)
|
|
187
|
+
# condition
|
|
188
|
+
u0 = initial_condition_check(u0, dim_size=1)[
|
|
189
|
+
None, :
|
|
190
|
+
] # make a (1, 1) here
|
|
191
|
+
elif u0.ndim == 1:
|
|
192
|
+
# in this case we assume u0 belongs to one (initial)
|
|
193
|
+
# condition
|
|
194
|
+
u0 = initial_condition_check(u0, dim_size=u0.shape[0])[
|
|
195
|
+
None, :
|
|
196
|
+
] # make a (1, dim) here
|
|
197
|
+
if u0.ndim > 2:
|
|
198
|
+
raise ValueError(
|
|
199
|
+
"It u0 is an Array, it represents n_cond "
|
|
200
|
+
"imposed conditions and must be of shape (n_cond, dim)"
|
|
201
|
+
)
|
|
202
|
+
else:
|
|
203
|
+
# at the end we want to end up with u0 of shape (:, dim) to account for
|
|
204
|
+
# possibly several data points
|
|
205
|
+
u0 = initial_condition_check(u0, dim_size=None)[
|
|
206
|
+
None, :
|
|
207
|
+
] # make a (1, 1) here
|
|
210
208
|
|
|
211
|
-
|
|
209
|
+
if t0.shape[0] != u0.shape[0] or t0.ndim != u0.ndim:
|
|
210
|
+
raise ValueError(
|
|
211
|
+
"t0 and u0 must represent a same number of initial"
|
|
212
|
+
" conditial conditions"
|
|
213
|
+
)
|
|
212
214
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
+
self.initial_condition = (t0, u0)
|
|
216
|
+
|
|
217
|
+
if obs_slice is None:
|
|
218
|
+
self.obs_slice = jnp.s_[...]
|
|
219
|
+
else:
|
|
220
|
+
self.obs_slice = obs_slice
|
|
215
221
|
|
|
216
222
|
def evaluate_by_terms(
|
|
217
223
|
self, params: Params[Array], batch: ODEBatch
|
|
@@ -241,63 +247,92 @@ class LossODE(_LossODEAbstract):
|
|
|
241
247
|
# and update vmap_in_axes
|
|
242
248
|
if batch.param_batch_dict is not None:
|
|
243
249
|
# update params with the batches of generated params
|
|
244
|
-
params =
|
|
250
|
+
params = update_eq_params(params, batch.param_batch_dict)
|
|
245
251
|
|
|
246
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
252
|
+
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
253
|
+
cast(eqx.Module, batch.param_batch_dict), params
|
|
254
|
+
)
|
|
247
255
|
|
|
248
256
|
## dynamic part
|
|
249
257
|
if self.dynamic_loss is not None:
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
258
|
+
dyn_loss_eval = self.dynamic_loss.evaluate
|
|
259
|
+
dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
260
|
+
lambda p: dynamic_loss_apply(
|
|
261
|
+
dyn_loss_eval,
|
|
262
|
+
self.u,
|
|
263
|
+
temporal_batch,
|
|
264
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss),
|
|
265
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
266
|
+
)
|
|
256
267
|
)
|
|
257
268
|
else:
|
|
258
269
|
dyn_loss_fun = None
|
|
259
270
|
|
|
260
|
-
# initial condition
|
|
261
271
|
if self.initial_condition is not None:
|
|
262
|
-
|
|
263
|
-
if not jax.tree_util.tree_leaves(vmap_in_axes):
|
|
264
|
-
# test if only None in vmap_in_axes to avoid the value error:
|
|
265
|
-
# `vmap must have at least one non-None value in in_axes`
|
|
266
|
-
v_u = self.u
|
|
267
|
-
else:
|
|
268
|
-
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
272
|
+
# initial condition
|
|
269
273
|
t0, u0 = self.initial_condition
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
274
|
+
|
|
275
|
+
# first construct the plain init loss no vmaping
|
|
276
|
+
initial_condition_fun__: Callable[[Array, Array, Params[Array]], Array] = (
|
|
277
|
+
lambda t, u, p: jnp.sum(
|
|
273
278
|
(
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
_set_derivatives(
|
|
279
|
+
self.u(
|
|
280
|
+
t,
|
|
281
|
+
_set_derivatives(
|
|
282
|
+
p,
|
|
283
|
+
self.derivative_keys.initial_condition,
|
|
284
|
+
),
|
|
277
285
|
)
|
|
278
|
-
-
|
|
286
|
+
- u
|
|
279
287
|
)
|
|
280
288
|
** 2,
|
|
281
|
-
axis
|
|
289
|
+
axis=0,
|
|
282
290
|
)
|
|
283
291
|
)
|
|
292
|
+
# now vmap over the number of conditions (first dim of t0 and u0)
|
|
293
|
+
# and take the mean
|
|
294
|
+
initial_condition_fun_: Callable[[Params[Array]], Array] = (
|
|
295
|
+
lambda p: jnp.mean(
|
|
296
|
+
vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
|
|
297
|
+
)
|
|
298
|
+
)
|
|
299
|
+
# now vmap over the possible batch of parameters and take the
|
|
300
|
+
# average. Note that we then finally have a cartesian product
|
|
301
|
+
# between the batch of parameters (if any) and the number of
|
|
302
|
+
# conditions (if any)
|
|
303
|
+
if not jax.tree_util.tree_leaves(vmap_in_axes_params):
|
|
304
|
+
# if there is no parameter batch to vmap over we cannot call
|
|
305
|
+
# vmap because calling vmap must be done with at least one non
|
|
306
|
+
# None in_axes or out_axes
|
|
307
|
+
initial_condition_fun = initial_condition_fun_
|
|
308
|
+
else:
|
|
309
|
+
initial_condition_fun: Callable[[Params[Array]], Array] | None = (
|
|
310
|
+
lambda p: jnp.mean(
|
|
311
|
+
vmap(initial_condition_fun_, vmap_in_axes_params)(p)
|
|
312
|
+
)
|
|
313
|
+
)
|
|
284
314
|
else:
|
|
285
315
|
initial_condition_fun = None
|
|
286
316
|
|
|
287
317
|
if batch.obs_batch_dict is not None:
|
|
288
318
|
# update params with the batches of observed params
|
|
289
|
-
params_obs =
|
|
290
|
-
params, batch.obs_batch_dict["eq_params"]
|
|
291
|
-
)
|
|
319
|
+
params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
|
|
292
320
|
|
|
293
|
-
|
|
294
|
-
obs_loss_fun = lambda po: observations_loss_apply(
|
|
295
|
-
self.u,
|
|
321
|
+
pinn_in, val = (
|
|
296
322
|
batch.obs_batch_dict["pinn_in"],
|
|
297
|
-
_set_derivatives(po, self.derivative_keys.observations), # type: ignore
|
|
298
|
-
self.vmap_in_axes + vmap_in_axes_params,
|
|
299
323
|
batch.obs_batch_dict["val"],
|
|
300
|
-
|
|
324
|
+
) # the reason for this instruction is https://github.com/microsoft/pyright/discussions/8340
|
|
325
|
+
|
|
326
|
+
# MSE loss wrt to an observed batch
|
|
327
|
+
obs_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
328
|
+
lambda po: observations_loss_apply(
|
|
329
|
+
self.u,
|
|
330
|
+
pinn_in,
|
|
331
|
+
_set_derivatives(po, self.derivative_keys.observations),
|
|
332
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
333
|
+
val,
|
|
334
|
+
self.obs_slice,
|
|
335
|
+
)
|
|
301
336
|
)
|
|
302
337
|
else:
|
|
303
338
|
params_obs = None
|
|
@@ -310,43 +345,27 @@ class LossODE(_LossODEAbstract):
|
|
|
310
345
|
all_params: ODEComponents[Params[Array] | None] = ODEComponents(
|
|
311
346
|
params, params, params_obs
|
|
312
347
|
)
|
|
348
|
+
|
|
349
|
+
# Note that the lambda functions below are with type: ignore just
|
|
350
|
+
# because the lambda are not type annotated, but there is no proper way
|
|
351
|
+
# to do this and we should assign the lambda to a type hinted variable
|
|
352
|
+
# before hand: this is not practical, let us not get mad at this
|
|
313
353
|
mses_grads = jax.tree.map(
|
|
314
|
-
|
|
354
|
+
self.get_gradients,
|
|
315
355
|
all_funs,
|
|
316
356
|
all_params,
|
|
317
357
|
is_leaf=lambda x: x is None,
|
|
318
358
|
)
|
|
319
359
|
|
|
320
360
|
mses = jax.tree.map(
|
|
321
|
-
lambda leaf: leaf[0],
|
|
361
|
+
lambda leaf: leaf[0], # type: ignore
|
|
362
|
+
mses_grads,
|
|
363
|
+
is_leaf=lambda x: isinstance(x, tuple),
|
|
322
364
|
)
|
|
323
365
|
grads = jax.tree.map(
|
|
324
|
-
lambda leaf: leaf[1],
|
|
366
|
+
lambda leaf: leaf[1], # type: ignore
|
|
367
|
+
mses_grads,
|
|
368
|
+
is_leaf=lambda x: isinstance(x, tuple),
|
|
325
369
|
)
|
|
326
370
|
|
|
327
371
|
return mses, grads
|
|
328
|
-
|
|
329
|
-
def evaluate(
|
|
330
|
-
self, params: Params[Array], batch: ODEBatch
|
|
331
|
-
) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
|
|
332
|
-
"""
|
|
333
|
-
Evaluate the loss function at a batch of points for given parameters.
|
|
334
|
-
|
|
335
|
-
We retrieve the total value itself and a PyTree with loss values for each term
|
|
336
|
-
|
|
337
|
-
Parameters
|
|
338
|
-
---------
|
|
339
|
-
params
|
|
340
|
-
Parameters at which the loss is evaluated
|
|
341
|
-
batch
|
|
342
|
-
Composed of a batch of points in the
|
|
343
|
-
domain, a batch of points in the domain
|
|
344
|
-
border and an optional additional batch of parameters (eg. for
|
|
345
|
-
metamodeling) and an optional additional batch of observed
|
|
346
|
-
inputs/outputs/parameters
|
|
347
|
-
"""
|
|
348
|
-
loss_terms, _ = self.evaluate_by_terms(params, batch)
|
|
349
|
-
|
|
350
|
-
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
351
|
-
|
|
352
|
-
return loss_val, loss_terms
|