jinns 1.5.1__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +55 -34
- jinns/data/_CubicMeshPDEStatio.py +63 -35
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +86 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +139 -184
- jinns/loss/_LossPDE.py +440 -358
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +50 -25
- jinns/solver/_solve.py +3 -3
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/METADATA +2 -2
- jinns-1.6.1.dist-info/RECORD +57 -0
- jinns-1.5.1.dist-info/RECORD +0 -55
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/WHEEL +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -7,9 +7,8 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
from dataclasses import InitVar
|
|
10
|
-
from typing import TYPE_CHECKING,
|
|
10
|
+
from typing import TYPE_CHECKING, Callable, Any, cast
|
|
11
11
|
from types import EllipsisType
|
|
12
|
-
import abc
|
|
13
12
|
import warnings
|
|
14
13
|
import jax
|
|
15
14
|
import jax.numpy as jnp
|
|
@@ -23,31 +22,51 @@ from jinns.loss._loss_utils import (
|
|
|
23
22
|
)
|
|
24
23
|
from jinns.parameters._params import (
|
|
25
24
|
_get_vmap_in_axes_params,
|
|
26
|
-
|
|
25
|
+
update_eq_params,
|
|
27
26
|
)
|
|
28
27
|
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
29
28
|
from jinns.loss._loss_weights import LossWeightsODE
|
|
30
29
|
from jinns.loss._abstract_loss import AbstractLoss
|
|
31
30
|
from jinns.loss._loss_components import ODEComponents
|
|
32
31
|
from jinns.parameters._params import Params
|
|
32
|
+
from jinns.data._Batchs import ODEBatch
|
|
33
33
|
|
|
34
34
|
if TYPE_CHECKING:
|
|
35
35
|
# imports only used in type hints
|
|
36
|
-
from jinns.data._Batchs import ODEBatch
|
|
37
36
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
38
37
|
from jinns.loss import ODE
|
|
39
38
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
39
|
+
InitialConditionUser = (
|
|
40
|
+
tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
|
|
41
|
+
| tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
InitialCondition = (
|
|
45
|
+
tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
|
|
46
|
+
| tuple[Float[Array, " "], Float[Array, " dim"]]
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]]):
|
|
51
|
+
r"""Loss object for an ordinary differential equation
|
|
52
|
+
|
|
53
|
+
$$
|
|
54
|
+
\mathcal{N}[u](t) = 0, \forall t \in I
|
|
55
|
+
$$
|
|
56
|
+
|
|
57
|
+
where $\mathcal{N}[\cdot]$ is a differential operator and the
|
|
58
|
+
initial condition is $u(t_0)=u_0$.
|
|
44
59
|
|
|
45
60
|
|
|
46
|
-
class _LossODEAbstract(AbstractLoss):
|
|
47
|
-
r"""
|
|
48
61
|
Parameters
|
|
49
62
|
----------
|
|
50
|
-
|
|
63
|
+
u : eqx.Module
|
|
64
|
+
the PINN
|
|
65
|
+
dynamic_loss : ODE
|
|
66
|
+
the ODE dynamic part of the loss, basically the differential
|
|
67
|
+
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
68
|
+
`dynamic_loss.evaluate(t, u, params)`.
|
|
69
|
+
Can be None in order to access only some part of the evaluate call.
|
|
51
70
|
loss_weights : LossWeightsODE, default=None
|
|
52
71
|
The loss weights for the differents term : dynamic loss,
|
|
53
72
|
initial condition and eventually observations if any.
|
|
@@ -67,10 +86,10 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
67
86
|
] |
|
|
68
87
|
tuple[int | float | Float[Array, " "],
|
|
69
88
|
int | float | Float[Array, " dim"]
|
|
70
|
-
]
|
|
89
|
+
], default=None
|
|
71
90
|
Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
72
91
|
From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
|
|
73
|
-
obs_slice : EllipsisType | slice
|
|
92
|
+
obs_slice : EllipsisType | slice, default=None
|
|
74
93
|
Slice object specifying the begininning/ending
|
|
75
94
|
slice of u output(s) that is observed. This is useful for
|
|
76
95
|
multidimensional PINN, with partially observed outputs.
|
|
@@ -78,52 +97,69 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
78
97
|
params : InitVar[Params[Array]], default=None
|
|
79
98
|
The main Params object of the problem needed to instanciate the
|
|
80
99
|
DerivativeKeysODE if the latter is not specified.
|
|
100
|
+
Raises
|
|
101
|
+
------
|
|
102
|
+
ValueError
|
|
103
|
+
if initial condition is not a tuple.
|
|
81
104
|
"""
|
|
82
105
|
|
|
83
106
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
84
107
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
108
|
+
u: AbstractPINN
|
|
109
|
+
dynamic_loss: ODE | None
|
|
110
|
+
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
111
|
+
derivative_keys: DerivativeKeysODE
|
|
112
|
+
loss_weights: LossWeightsODE
|
|
113
|
+
initial_condition: InitialCondition | None
|
|
114
|
+
obs_slice: EllipsisType | slice = eqx.field(static=True)
|
|
115
|
+
params: InitVar[Params[Array] | None]
|
|
116
|
+
|
|
117
|
+
def __init__(
|
|
118
|
+
self,
|
|
119
|
+
*,
|
|
120
|
+
u: AbstractPINN,
|
|
121
|
+
dynamic_loss: ODE | None,
|
|
122
|
+
loss_weights: LossWeightsODE | None = None,
|
|
123
|
+
derivative_keys: DerivativeKeysODE | None = None,
|
|
124
|
+
initial_condition: InitialConditionUser | None = None,
|
|
125
|
+
obs_slice: EllipsisType | slice | None = None,
|
|
126
|
+
params: Params[Array] | None = None,
|
|
127
|
+
**kwargs: Any, # this is for arguments for super()
|
|
128
|
+
):
|
|
129
|
+
if loss_weights is None:
|
|
101
130
|
self.loss_weights = LossWeightsODE()
|
|
131
|
+
else:
|
|
132
|
+
self.loss_weights = loss_weights
|
|
102
133
|
|
|
103
|
-
|
|
134
|
+
super().__init__(loss_weights=self.loss_weights, **kwargs)
|
|
135
|
+
self.u = u
|
|
136
|
+
self.dynamic_loss = dynamic_loss
|
|
137
|
+
self.vmap_in_axes = (0,)
|
|
138
|
+
if derivative_keys is None:
|
|
104
139
|
# by default we only take gradient wrt nn_params
|
|
105
140
|
if params is None:
|
|
106
141
|
raise ValueError(
|
|
107
|
-
"Problem at
|
|
108
|
-
f"received {
|
|
142
|
+
"Problem at derivative_keys initialization "
|
|
143
|
+
f"received {derivative_keys=} and {params=}"
|
|
109
144
|
)
|
|
110
145
|
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
111
|
-
|
|
146
|
+
else:
|
|
147
|
+
self.derivative_keys = derivative_keys
|
|
148
|
+
|
|
149
|
+
if initial_condition is None:
|
|
112
150
|
warnings.warn(
|
|
113
151
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
114
152
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
115
153
|
)
|
|
154
|
+
self.initial_condition = initial_condition
|
|
116
155
|
else:
|
|
117
|
-
if (
|
|
118
|
-
not isinstance(self.initial_condition, tuple)
|
|
119
|
-
or len(self.initial_condition) != 2
|
|
120
|
-
):
|
|
156
|
+
if len(initial_condition) != 2:
|
|
121
157
|
raise ValueError(
|
|
122
158
|
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
123
|
-
f"{
|
|
159
|
+
f"{initial_condition} was passed."
|
|
124
160
|
)
|
|
125
161
|
# some checks/reshaping for t0 and u0
|
|
126
|
-
t0, u0 =
|
|
162
|
+
t0, u0 = initial_condition
|
|
127
163
|
if isinstance(t0, Array):
|
|
128
164
|
# at the end we want to end up with t0 of shape (:, 1) to account for
|
|
129
165
|
# possibly several data points
|
|
@@ -178,90 +214,10 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
178
214
|
|
|
179
215
|
self.initial_condition = (t0, u0)
|
|
180
216
|
|
|
181
|
-
if
|
|
217
|
+
if obs_slice is None:
|
|
182
218
|
self.obs_slice = jnp.s_[...]
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
self.loss_weights = LossWeightsODE()
|
|
186
|
-
|
|
187
|
-
@abc.abstractmethod
|
|
188
|
-
def __call__(self, *_, **__):
|
|
189
|
-
pass
|
|
190
|
-
|
|
191
|
-
@abc.abstractmethod
|
|
192
|
-
def evaluate(
|
|
193
|
-
self: eqx.Module, params: Params[Array], batch: ODEBatch
|
|
194
|
-
) -> tuple[Float[Array, " "], LossDictODE]:
|
|
195
|
-
raise NotImplementedError
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
class LossODE(_LossODEAbstract):
|
|
199
|
-
r"""Loss object for an ordinary differential equation
|
|
200
|
-
|
|
201
|
-
$$
|
|
202
|
-
\mathcal{N}[u](t) = 0, \forall t \in I
|
|
203
|
-
$$
|
|
204
|
-
|
|
205
|
-
where $\mathcal{N}[\cdot]$ is a differential operator and the
|
|
206
|
-
initial condition is $u(t_0)=u_0$.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
Parameters
|
|
210
|
-
----------
|
|
211
|
-
loss_weights : LossWeightsODE, default=None
|
|
212
|
-
The loss weights for the differents term : dynamic loss,
|
|
213
|
-
initial condition and eventually observations if any.
|
|
214
|
-
Can be updated according to a specific algorithm. See
|
|
215
|
-
`update_weight_method`
|
|
216
|
-
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
217
|
-
Default is None meaning no update for loss weights. Otherwise a string
|
|
218
|
-
derivative_keys : DerivativeKeysODE, default=None
|
|
219
|
-
Specify which field of `params` should be differentiated for each
|
|
220
|
-
composant of the total loss. Particularily useful for inverse problems.
|
|
221
|
-
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
222
|
-
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
223
|
-
is `"nn_params"` for each composant of the loss.
|
|
224
|
-
initial_condition : tuple[float | Float[Array, " 1"]], default=None
|
|
225
|
-
tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
226
|
-
obs_slice : EllipsisType | slice | None, default=None
|
|
227
|
-
Slice object specifying the begininning/ending
|
|
228
|
-
slice of u output(s) that is observed. This is useful for
|
|
229
|
-
multidimensional PINN, with partially observed outputs.
|
|
230
|
-
Default is None (whole output is observed).
|
|
231
|
-
params : InitVar[Params[Array]], default=None
|
|
232
|
-
The main Params object of the problem needed to instanciate the
|
|
233
|
-
DerivativeKeysODE if the latter is not specified.
|
|
234
|
-
u : eqx.Module
|
|
235
|
-
the PINN
|
|
236
|
-
dynamic_loss : ODE
|
|
237
|
-
the ODE dynamic part of the loss, basically the differential
|
|
238
|
-
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
239
|
-
`dynamic_loss.evaluate(t, u, params)`.
|
|
240
|
-
Can be None in order to access only some part of the evaluate call.
|
|
241
|
-
|
|
242
|
-
Raises
|
|
243
|
-
------
|
|
244
|
-
ValueError
|
|
245
|
-
if initial condition is not a tuple.
|
|
246
|
-
"""
|
|
247
|
-
|
|
248
|
-
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
249
|
-
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
250
|
-
u: AbstractPINN
|
|
251
|
-
dynamic_loss: ODE | None
|
|
252
|
-
|
|
253
|
-
vmap_in_axes: tuple[int] = eqx.field(init=False, static=True)
|
|
254
|
-
|
|
255
|
-
def __post_init__(self, params: Params[Array] | None = None):
|
|
256
|
-
super().__post_init__(
|
|
257
|
-
params=params
|
|
258
|
-
) # because __init__ or __post_init__ of Base
|
|
259
|
-
# class is not automatically called
|
|
260
|
-
|
|
261
|
-
self.vmap_in_axes = (0,)
|
|
262
|
-
|
|
263
|
-
def __call__(self, *args, **kwargs):
|
|
264
|
-
return self.evaluate(*args, **kwargs)
|
|
219
|
+
else:
|
|
220
|
+
self.obs_slice = obs_slice
|
|
265
221
|
|
|
266
222
|
def evaluate_by_terms(
|
|
267
223
|
self, params: Params[Array], batch: ODEBatch
|
|
@@ -291,46 +247,54 @@ class LossODE(_LossODEAbstract):
|
|
|
291
247
|
# and update vmap_in_axes
|
|
292
248
|
if batch.param_batch_dict is not None:
|
|
293
249
|
# update params with the batches of generated params
|
|
294
|
-
params =
|
|
250
|
+
params = update_eq_params(params, batch.param_batch_dict)
|
|
295
251
|
|
|
296
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
252
|
+
vmap_in_axes_params = _get_vmap_in_axes_params(
|
|
253
|
+
cast(eqx.Module, batch.param_batch_dict), params
|
|
254
|
+
)
|
|
297
255
|
|
|
298
256
|
## dynamic part
|
|
299
257
|
if self.dynamic_loss is not None:
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
258
|
+
dyn_loss_eval = self.dynamic_loss.evaluate
|
|
259
|
+
dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
260
|
+
lambda p: dynamic_loss_apply(
|
|
261
|
+
dyn_loss_eval,
|
|
262
|
+
self.u,
|
|
263
|
+
temporal_batch,
|
|
264
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss),
|
|
265
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
266
|
+
)
|
|
306
267
|
)
|
|
307
268
|
else:
|
|
308
269
|
dyn_loss_fun = None
|
|
309
270
|
|
|
310
|
-
# initial condition
|
|
311
271
|
if self.initial_condition is not None:
|
|
272
|
+
# initial condition
|
|
312
273
|
t0, u0 = self.initial_condition
|
|
313
|
-
u0 = jnp.array(u0)
|
|
314
274
|
|
|
315
275
|
# first construct the plain init loss no vmaping
|
|
316
|
-
initial_condition_fun__
|
|
317
|
-
(
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
276
|
+
initial_condition_fun__: Callable[[Array, Array, Params[Array]], Array] = (
|
|
277
|
+
lambda t, u, p: jnp.sum(
|
|
278
|
+
(
|
|
279
|
+
self.u(
|
|
280
|
+
t,
|
|
281
|
+
_set_derivatives(
|
|
282
|
+
p,
|
|
283
|
+
self.derivative_keys.initial_condition,
|
|
284
|
+
),
|
|
285
|
+
)
|
|
286
|
+
- u
|
|
324
287
|
)
|
|
325
|
-
|
|
288
|
+
** 2,
|
|
289
|
+
axis=0,
|
|
326
290
|
)
|
|
327
|
-
** 2,
|
|
328
|
-
axis=0,
|
|
329
291
|
)
|
|
330
292
|
# now vmap over the number of conditions (first dim of t0 and u0)
|
|
331
293
|
# and take the mean
|
|
332
|
-
initial_condition_fun_
|
|
333
|
-
|
|
294
|
+
initial_condition_fun_: Callable[[Params[Array]], Array] = (
|
|
295
|
+
lambda p: jnp.mean(
|
|
296
|
+
vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
|
|
297
|
+
)
|
|
334
298
|
)
|
|
335
299
|
# now vmap over the the possible batch of parameters and take the
|
|
336
300
|
# average. Note that we then finally have a cartesian product
|
|
@@ -342,26 +306,33 @@ class LossODE(_LossODEAbstract):
|
|
|
342
306
|
# None in_axes or out_axes
|
|
343
307
|
initial_condition_fun = initial_condition_fun_
|
|
344
308
|
else:
|
|
345
|
-
initial_condition_fun
|
|
346
|
-
|
|
309
|
+
initial_condition_fun: Callable[[Params[Array]], Array] | None = (
|
|
310
|
+
lambda p: jnp.mean(
|
|
311
|
+
vmap(initial_condition_fun_, vmap_in_axes_params)(p)
|
|
312
|
+
)
|
|
347
313
|
)
|
|
348
314
|
else:
|
|
349
315
|
initial_condition_fun = None
|
|
350
316
|
|
|
351
317
|
if batch.obs_batch_dict is not None:
|
|
352
318
|
# update params with the batches of observed params
|
|
353
|
-
params_obs =
|
|
354
|
-
params, batch.obs_batch_dict["eq_params"]
|
|
355
|
-
)
|
|
319
|
+
params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
|
|
356
320
|
|
|
357
|
-
|
|
358
|
-
obs_loss_fun = lambda po: observations_loss_apply(
|
|
359
|
-
self.u,
|
|
321
|
+
pinn_in, val = (
|
|
360
322
|
batch.obs_batch_dict["pinn_in"],
|
|
361
|
-
_set_derivatives(po, self.derivative_keys.observations), # type: ignore
|
|
362
|
-
self.vmap_in_axes + vmap_in_axes_params,
|
|
363
323
|
batch.obs_batch_dict["val"],
|
|
364
|
-
|
|
324
|
+
) # the reason for this intruction is https://github.com/microsoft/pyright/discussions/8340
|
|
325
|
+
|
|
326
|
+
# MSE loss wrt to an observed batch
|
|
327
|
+
obs_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
328
|
+
lambda po: observations_loss_apply(
|
|
329
|
+
self.u,
|
|
330
|
+
pinn_in,
|
|
331
|
+
_set_derivatives(po, self.derivative_keys.observations),
|
|
332
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
333
|
+
val,
|
|
334
|
+
self.obs_slice,
|
|
335
|
+
)
|
|
365
336
|
)
|
|
366
337
|
else:
|
|
367
338
|
params_obs = None
|
|
@@ -374,43 +345,27 @@ class LossODE(_LossODEAbstract):
|
|
|
374
345
|
all_params: ODEComponents[Params[Array] | None] = ODEComponents(
|
|
375
346
|
params, params, params_obs
|
|
376
347
|
)
|
|
348
|
+
|
|
349
|
+
# Note that the lambda functions below are with type: ignore just
|
|
350
|
+
# because the lambda are not type annotated, but there is no proper way
|
|
351
|
+
# to do this and we should assign the lambda to a type hinted variable
|
|
352
|
+
# before hand: this is not practical, let us not get mad at this
|
|
377
353
|
mses_grads = jax.tree.map(
|
|
378
|
-
|
|
354
|
+
self.get_gradients,
|
|
379
355
|
all_funs,
|
|
380
356
|
all_params,
|
|
381
357
|
is_leaf=lambda x: x is None,
|
|
382
358
|
)
|
|
383
359
|
|
|
384
360
|
mses = jax.tree.map(
|
|
385
|
-
lambda leaf: leaf[0],
|
|
361
|
+
lambda leaf: leaf[0], # type: ignore
|
|
362
|
+
mses_grads,
|
|
363
|
+
is_leaf=lambda x: isinstance(x, tuple),
|
|
386
364
|
)
|
|
387
365
|
grads = jax.tree.map(
|
|
388
|
-
lambda leaf: leaf[1],
|
|
366
|
+
lambda leaf: leaf[1], # type: ignore
|
|
367
|
+
mses_grads,
|
|
368
|
+
is_leaf=lambda x: isinstance(x, tuple),
|
|
389
369
|
)
|
|
390
370
|
|
|
391
371
|
return mses, grads
|
|
392
|
-
|
|
393
|
-
def evaluate(
|
|
394
|
-
self, params: Params[Array], batch: ODEBatch
|
|
395
|
-
) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
|
|
396
|
-
"""
|
|
397
|
-
Evaluate the loss function at a batch of points for given parameters.
|
|
398
|
-
|
|
399
|
-
We retrieve the total value itself and a PyTree with loss values for each term
|
|
400
|
-
|
|
401
|
-
Parameters
|
|
402
|
-
---------
|
|
403
|
-
params
|
|
404
|
-
Parameters at which the loss is evaluated
|
|
405
|
-
batch
|
|
406
|
-
Composed of a batch of points in the
|
|
407
|
-
domain, a batch of points in the domain
|
|
408
|
-
border and an optional additional batch of parameters (eg. for
|
|
409
|
-
metamodeling) and an optional additional batch of observed
|
|
410
|
-
inputs/outputs/parameters
|
|
411
|
-
"""
|
|
412
|
-
loss_terms, _ = self.evaluate_by_terms(params, batch)
|
|
413
|
-
|
|
414
|
-
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
415
|
-
|
|
416
|
-
return loss_val, loss_terms
|