jinns-1.5.0-py3-none-any.whl → jinns-1.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +7 -7
  2. jinns/data/_AbstractDataGenerator.py +1 -1
  3. jinns/data/_Batchs.py +47 -13
  4. jinns/data/_CubicMeshPDENonStatio.py +203 -54
  5. jinns/data/_CubicMeshPDEStatio.py +190 -54
  6. jinns/data/_DataGeneratorODE.py +48 -22
  7. jinns/data/_DataGeneratorObservations.py +75 -32
  8. jinns/data/_DataGeneratorParameter.py +152 -101
  9. jinns/data/__init__.py +2 -1
  10. jinns/data/_utils.py +22 -10
  11. jinns/loss/_DynamicLoss.py +21 -20
  12. jinns/loss/_DynamicLossAbstract.py +51 -36
  13. jinns/loss/_LossODE.py +210 -191
  14. jinns/loss/_LossPDE.py +441 -368
  15. jinns/loss/_abstract_loss.py +60 -25
  16. jinns/loss/_loss_components.py +4 -25
  17. jinns/loss/_loss_utils.py +23 -0
  18. jinns/loss/_loss_weight_updates.py +6 -7
  19. jinns/loss/_loss_weights.py +34 -35
  20. jinns/nn/_abstract_pinn.py +0 -2
  21. jinns/nn/_hyperpinn.py +34 -23
  22. jinns/nn/_mlp.py +5 -4
  23. jinns/nn/_pinn.py +1 -16
  24. jinns/nn/_ppinn.py +5 -16
  25. jinns/nn/_save_load.py +11 -4
  26. jinns/nn/_spinn.py +1 -16
  27. jinns/nn/_spinn_mlp.py +5 -5
  28. jinns/nn/_utils.py +33 -38
  29. jinns/parameters/__init__.py +3 -1
  30. jinns/parameters/_derivative_keys.py +99 -41
  31. jinns/parameters/_params.py +58 -25
  32. jinns/solver/_solve.py +14 -8
  33. jinns/utils/_DictToModuleMeta.py +66 -0
  34. jinns/utils/_ItemizableModule.py +19 -0
  35. jinns/utils/__init__.py +2 -1
  36. jinns/utils/_types.py +25 -15
  37. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  38. jinns-1.6.0.dist-info/RECORD +57 -0
  39. jinns-1.5.0.dist-info/RECORD +0 -55
  40. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  41. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  42. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py CHANGED
@@ -7,9 +7,8 @@ from __future__ import (
) # https://docs.python.org/3/library/typing.html#constant

from dataclasses import InitVar
-from typing import TYPE_CHECKING, TypedDict, Callable
+from typing import TYPE_CHECKING, Callable, Any, cast
from types import EllipsisType
-import abc
import warnings
import jax
import jax.numpy as jnp
@@ -19,133 +18,36 @@ from jaxtyping import Float, Array
from jinns.loss._loss_utils import (
    dynamic_loss_apply,
    observations_loss_apply,
+    initial_condition_check,
)
from jinns.parameters._params import (
    _get_vmap_in_axes_params,
-    _update_eq_params_dict,
+    update_eq_params,
)
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
from jinns.loss._loss_weights import LossWeightsODE
from jinns.loss._abstract_loss import AbstractLoss
from jinns.loss._loss_components import ODEComponents
from jinns.parameters._params import Params
+from jinns.data._Batchs import ODEBatch

if TYPE_CHECKING:
    # imports only used in type hints
-    from jinns.data._Batchs import ODEBatch
    from jinns.nn._abstract_pinn import AbstractPINN
    from jinns.loss import ODE

-class LossDictODE(TypedDict):
-    dyn_loss: Float[Array, " "]
-    initial_condition: Float[Array, " "]
-    observations: Float[Array, " "]
-
-
-class _LossODEAbstract(AbstractLoss):
-    """
-    Parameters
-    ----------
-
-    loss_weights : LossWeightsODE, default=None
-        The loss weights for the differents term : dynamic loss,
-        initial condition and eventually observations if any.
-        Can be updated according to a specific algorithm. See
-        `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
-        Default is None meaning no update for loss weights. Otherwise a string
-    derivative_keys : DerivativeKeysODE, default=None
-        Specify which field of `params` should be differentiated for each
-        composant of the total loss. Particularily useful for inverse problems.
-        Fields can be "nn_params", "eq_params" or "both". Those that should not
-        be updated will have a `jax.lax.stop_gradient` called on them. Default
-        is `"nn_params"` for each composant of the loss.
-    initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
-        tuple of length 2 with initial condition $(t_0, u_0)$.
-    obs_slice : EllipsisType | slice | None, default=None
-        Slice object specifying the begininning/ending
-        slice of u output(s) that is observed. This is useful for
-        multidimensional PINN, with partially observed outputs.
-        Default is None (whole output is observed).
-    params : InitVar[Params[Array]], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-    """
-
-    # NOTE static=True only for leaf attributes that are not valid JAX types
-    # (ie. jax.Array cannot be static) and that we do not expect to change
-    # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
-    derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
-    loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
-    initial_condition: (
-        tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
-    ) = eqx.field(kw_only=True, default=None)
-    obs_slice: EllipsisType | slice | None = eqx.field(
-        kw_only=True, default=None, static=True
+InitialConditionUser = (
+    tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
+    | tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
)

-    params: InitVar[Params[Array]] = eqx.field(default=None, kw_only=True)
-
-    def __post_init__(self, params: Params[Array] | None = None):
-        if self.loss_weights is None:
-            self.loss_weights = LossWeightsODE()
-
-        if self.derivative_keys is None:
-            # by default we only take gradient wrt nn_params
-            if params is None:
-                raise ValueError(
-                    "Problem at self.derivative_keys initialization "
-                    f"received {self.derivative_keys=} and {params=}"
-                )
-            self.derivative_keys = DerivativeKeysODE(params=params)
-        if self.initial_condition is None:
-            warnings.warn(
-                "Initial condition wasn't provided. Be sure to cover for that"
-                "case (e.g by. hardcoding it into the PINN output)."
-            )
-        else:
-            if (
-                not isinstance(self.initial_condition, tuple)
-                or len(self.initial_condition) != 2
-            ):
-                raise ValueError(
-                    "Initial condition should be a tuple of len 2 with (t0, u0), "
-                    f"{self.initial_condition} was passed."
-                )
-            # some checks/reshaping for t0
-            t0, u0 = self.initial_condition
-            if isinstance(t0, Array):
-                if not t0.shape:  # e.g. user input: jnp.array(0.)
-                    t0 = jnp.array([t0])
-                elif t0.shape != (1,):
-                    raise ValueError(
-                        f"Wrong t0 input (self.initial_condition[0]) It should be"
-                        f"a float or an array of shape (1,). Got shape: {t0.shape}"
-                    )
-            if isinstance(t0, float):  # e.g. user input: 0.
-                t0 = jnp.array([t0])
-            if isinstance(t0, int):  # e.g. user input: 0
-                t0 = jnp.array([float(t0)])
-            self.initial_condition = (t0, u0)
-
-        if self.obs_slice is None:
-            self.obs_slice = jnp.s_[...]
-
-        if self.loss_weights is None:
-            self.loss_weights = LossWeightsODE()
-
-    @abc.abstractmethod
-    def __call__(self, *_, **__):
-        pass
-
-    @abc.abstractmethod
-    def evaluate(
-        self: eqx.Module, params: Params[Array], batch: ODEBatch
-    ) -> tuple[Float[Array, " "], LossDictODE]:
-        raise NotImplementedError
+InitialCondition = (
+    tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
+    | tuple[Float[Array, " "], Float[Array, " dim"]]
+)


-class LossODE(_LossODEAbstract):
+class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]]):
    r"""Loss object for an ordinary differential equation

    $$
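
The two new module-level aliases separate what callers may pass (`InitialConditionUser`) from the normalized form stored on the loss (`InitialCondition`): either a single condition given as scalars or unbatched arrays, or a batch of `n_cond` conditions given as arrays. A minimal sketch of the two accepted forms, with illustrative values:

import jax.numpy as jnp

# Form 1: a single initial condition (t0, u0) as plain scalars/arrays.
# The constructor normalizes these to shapes (1, 1) and (1, dim).
single = (0.0, jnp.array([1.0, 0.5]))  # t0 = 0., u0 in R^2

# Form 2: n_cond imposed conditions, e.g. an initial AND a final condition.
# t0 must have shape (n_cond, 1) and u0 shape (n_cond, dim).
multi = (
    jnp.array([[0.0], [10.0]]),           # two time points
    jnp.array([[1.0, 0.5], [0.2, 0.1]]),  # imposed values in R^2
)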
@@ -158,6 +60,13 @@ class LossODE(_LossODEAbstract):

    Parameters
    ----------
+    u : eqx.Module
+        the PINN
+    dynamic_loss : ODE
+        the ODE dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](t)$. Should implement a method
+        `dynamic_loss.evaluate(t, u, params)`.
+        Can be None in order to access only some part of the evaluate call.
    loss_weights : LossWeightsODE, default=None
        The loss weights for the different terms: dynamic loss,
        initial condition and possibly observations if any.
@@ -171,9 +80,16 @@
        Fields can be "nn_params", "eq_params" or "both". Those that should not
        be updated will have a `jax.lax.stop_gradient` called on them. Default
        is `"nn_params"` for each component of the loss.
-    initial_condition : tuple[float | Float[Array, " 1"]], default=None
-        tuple of length 2 with initial condition $(t_0, u_0)$.
-    obs_slice : EllipsisType | slice | None, default=None
+    initial_condition : tuple[
+            Float[Array, "n_cond "],
+            Float[Array, "n_cond dim"]
+        ] | tuple[
+            int | float | Float[Array, " "],
+            int | float | Float[Array, " dim"]
+        ], default=None
+        Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
+        From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1)
+        for t0 and (n_cond, dim) for u0. This is useful to include observed
+        conditions at different time points, such as *e.g.* final conditions.
+        It was designed to implement $\mathcal{L}^{aux}$ from *Systems biology
+        informed deep learning for inferring parameters and hidden dynamics*,
+        Alireza Yazdani et al., 2020.
+    obs_slice : EllipsisType | slice, default=None
        Slice object specifying the beginning/ending
        slice of u output(s) that is observed. This is useful for
        multidimensional PINN, with partially observed outputs.
@@ -181,14 +97,6 @@ class LossODE(_LossODEAbstract):
    params : InitVar[Params[Array]], default=None
        The main Params object of the problem needed to instantiate the
        DerivativeKeysODE if the latter is not specified.
-    u : eqx.Module
-        the PINN
-    dynamic_loss : ODE
-        the ODE dynamic part of the loss, basically the differential
-        operator $\mathcal{N}[u](t)$. Should implement a method
-        `dynamic_loss.evaluate(t, u, params)`.
-        Can be None in order to access only some part of the evaluate call.
-

    Raises
    ------
@@ -199,19 +107,117 @@
    # (ie. jax.Array cannot be static) and that we do not expect to change
    u: AbstractPINN
    dynamic_loss: ODE | None
+    vmap_in_axes: tuple[int] = eqx.field(static=True)
+    derivative_keys: DerivativeKeysODE
+    loss_weights: LossWeightsODE
+    initial_condition: InitialCondition | None
+    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    params: InitVar[Params[Array] | None]
+
+    def __init__(
+        self,
+        *,
+        u: AbstractPINN,
+        dynamic_loss: ODE | None,
+        loss_weights: LossWeightsODE | None = None,
+        derivative_keys: DerivativeKeysODE | None = None,
+        initial_condition: InitialConditionUser | None = None,
+        obs_slice: EllipsisType | slice | None = None,
+        params: Params[Array] | None = None,
+        **kwargs: Any,  # arguments passed on to super()
+    ):
+        if loss_weights is None:
+            self.loss_weights = LossWeightsODE()
+        else:
+            self.loss_weights = loss_weights

-    vmap_in_axes: tuple[int] = eqx.field(init=False, static=True)
+        super().__init__(loss_weights=self.loss_weights, **kwargs)
+        self.u = u
+        self.dynamic_loss = dynamic_loss
+        self.vmap_in_axes = (0,)
+        if derivative_keys is None:
+            # by default we only take gradient wrt nn_params
+            if params is None:
+                raise ValueError(
+                    "Problem at derivative_keys initialization "
+                    f"received {derivative_keys=} and {params=}"
+                )
+            self.derivative_keys = DerivativeKeysODE(params=params)
+        else:
+            self.derivative_keys = derivative_keys

-    def __post_init__(self, params: Params[Array] | None = None):
-        super().__post_init__(
-            params=params
-        )  # because __init__ or __post_init__ of Base
-        # class is not automatically called
+        if initial_condition is None:
+            warnings.warn(
+                "Initial condition wasn't provided. Be sure to cover for that "
+                "case (e.g. by hardcoding it into the PINN output)."
+            )
+            self.initial_condition = initial_condition
+        else:
+            if len(initial_condition) != 2:
+                raise ValueError(
+                    "Initial condition should be a tuple of len 2 with (t0, u0), "
+                    f"{initial_condition} was passed."
+                )
+            # some checks/reshaping for t0 and u0
+            t0, u0 = initial_condition
+            if isinstance(t0, Array):
+                # at the end we want to end up with t0 of shape (:, 1) to
+                # account for possibly several data points
+                if t0.ndim <= 1:
+                    # in this case we assume t0 belongs to one (initial)
+                    # condition
+                    t0 = initial_condition_check(t0, dim_size=1)[
+                        None, :
+                    ]  # make a (1, 1) here
+                if t0.ndim > 2:
+                    raise ValueError(
+                        "If t0 is an Array, it represents n_cond"
+                        " imposed conditions and must be of shape (n_cond, 1)"
+                    )
+            else:
+                # in this case t0 clearly represents one (initial) condition
+                t0 = initial_condition_check(t0, dim_size=1)[
+                    None, :
+                ]  # make a (1, 1) here
+            if isinstance(u0, Array):
+                # at the end we want to end up with u0 of shape (:, dim) to
+                # account for possibly several data points
+                if not u0.shape:
+                    # in this case we assume u0 belongs to one (initial)
+                    # condition
+                    u0 = initial_condition_check(u0, dim_size=1)[
+                        None, :
+                    ]  # make a (1, 1) here
+                elif u0.ndim == 1:
+                    # in this case we assume u0 belongs to one (initial)
+                    # condition
+                    u0 = initial_condition_check(u0, dim_size=u0.shape[0])[
+                        None, :
+                    ]  # make a (1, dim) here
+                if u0.ndim > 2:
+                    raise ValueError(
+                        "If u0 is an Array, it represents n_cond "
+                        "imposed conditions and must be of shape (n_cond, dim)"
+                    )
+            else:
+                # at the end we want to end up with u0 of shape (:, dim) to
+                # account for possibly several data points
+                u0 = initial_condition_check(u0, dim_size=None)[
+                    None, :
+                ]  # make a (1, 1) here

-        self.vmap_in_axes = (0,)
+            if t0.shape[0] != u0.shape[0] or t0.ndim != u0.ndim:
+                raise ValueError(
+                    "t0 and u0 must represent the same number of"
+                    " initial conditions"
+                )

-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
+            self.initial_condition = (t0, u0)
+
+        if obs_slice is None:
+            self.obs_slice = jnp.s_[...]
+        else:
+            self.obs_slice = obs_slice

    def evaluate_by_terms(
        self, params: Params[Array], batch: ODEBatch
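
The constructor above funnels every accepted input through `initial_condition_check` plus a `[None, :]` expansion so that `t0` always ends up with shape `(n_cond, 1)` and `u0` with shape `(n_cond, dim)`. A standalone paraphrase of the `t0` branch under that assumption (this is not the jinns helper itself, just the shape contract the code appears to enforce):

import jax.numpy as jnp

def normalize_t0(t0):
    # scalars, 0-d and length-1 1-d arrays are one condition -> (1, 1);
    # a (n_cond, 1) array is kept as-is; deeper arrays are rejected
    t0 = jnp.asarray(t0, dtype=float)
    if t0.ndim <= 1 and t0.size == 1:
        return t0.reshape(1, 1)
    if t0.ndim == 2 and t0.shape[1] == 1:
        return t0
    raise ValueError("t0 must be a scalar or an array of shape (n_cond, 1)")

assert normalize_t0(0).shape == (1, 1)
assert normalize_t0(jnp.array(0.0)).shape == (1, 1)
assert normalize_t0(jnp.zeros((3, 1))).shape == (3, 1)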
@@ -241,63 +247,92 @@
        # and update vmap_in_axes
        if batch.param_batch_dict is not None:
            # update params with the batches of generated params
-            params = _update_eq_params_dict(params, batch.param_batch_dict)
+            params = update_eq_params(params, batch.param_batch_dict)

-        vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
+        vmap_in_axes_params = _get_vmap_in_axes_params(
+            cast(eqx.Module, batch.param_batch_dict), params
+        )

        ## dynamic part
        if self.dynamic_loss is not None:
-            dyn_loss_fun = lambda p: dynamic_loss_apply(
-                self.dynamic_loss.evaluate,  # type: ignore
-                self.u,
-                temporal_batch,
-                _set_derivatives(p, self.derivative_keys.dyn_loss),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
+            dyn_loss_eval = self.dynamic_loss.evaluate
+            dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: dynamic_loss_apply(
+                    dyn_loss_eval,
+                    self.u,
+                    temporal_batch,
+                    _set_derivatives(p, self.derivative_keys.dyn_loss),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                )
            )
        else:
            dyn_loss_fun = None

-        # initial condition
        if self.initial_condition is not None:
-            vmap_in_axes = (None,) + vmap_in_axes_params
-            if not jax.tree_util.tree_leaves(vmap_in_axes):
-                # test if only None in vmap_in_axes to avoid the value error:
-                # `vmap must have at least one non-None value in in_axes`
-                v_u = self.u
-            else:
-                v_u = vmap(self.u, (None,) + vmap_in_axes_params)
+            # initial condition
            t0, u0 = self.initial_condition
-            u0 = jnp.array(u0)
-            initial_condition_fun = lambda p: jnp.mean(
-                jnp.sum(
+
+            # first construct the plain init loss, no vmapping
+            initial_condition_fun__: Callable[[Array, Array, Params[Array]], Array] = (
+                lambda t, u, p: jnp.sum(
                    (
-                        v_u(
-                            t0,
-                            _set_derivatives(p, self.derivative_keys.initial_condition),  # type: ignore
+                        self.u(
+                            t,
+                            _set_derivatives(
+                                p,
+                                self.derivative_keys.initial_condition,
+                            ),
                        )
-                        - u0
+                        - u
                    )
                    ** 2,
-                    axis=-1,
+                    axis=0,
                )
            )
+            # now vmap over the number of conditions (first dim of t0 and u0)
+            # and take the mean
+            initial_condition_fun_: Callable[[Params[Array]], Array] = (
+                lambda p: jnp.mean(
+                    vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
+                )
+            )
+            # now vmap over the possible batch of parameters and take the
+            # average. Note that we then finally have a cartesian product
+            # between the batch of parameters (if any) and the number of
+            # conditions (if any)
+            if not jax.tree_util.tree_leaves(vmap_in_axes_params):
+                # if there is no parameter batch to vmap over we cannot call
+                # vmap because calling vmap must be done with at least one
+                # non-None in_axes or out_axes
+                initial_condition_fun = initial_condition_fun_
+            else:
+                initial_condition_fun: Callable[[Params[Array]], Array] | None = (
+                    lambda p: jnp.mean(
+                        vmap(initial_condition_fun_, vmap_in_axes_params)(p)
+                    )
+                )
        else:
            initial_condition_fun = None

        if batch.obs_batch_dict is not None:
            # update params with the batches of observed params
-            params_obs = _update_eq_params_dict(
-                params, batch.obs_batch_dict["eq_params"]
-            )
+            params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])

-            # MSE loss wrt to an observed batch
-            obs_loss_fun = lambda po: observations_loss_apply(
-                self.u,
+            pinn_in, val = (
                batch.obs_batch_dict["pinn_in"],
-                _set_derivatives(po, self.derivative_keys.observations),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
                batch.obs_batch_dict["val"],
-                self.obs_slice,
+            )  # the reason for this instruction is https://github.com/microsoft/pyright/discussions/8340
+
+            # MSE loss wrt an observed batch
+            obs_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda po: observations_loss_apply(
+                    self.u,
+                    pinn_in,
+                    _set_derivatives(po, self.derivative_keys.observations),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                    val,
+                    self.obs_slice,
+                )
            )
        else:
            params_obs = None
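
The comments in the new initial-condition block describe a cartesian product of two vmaps: the inner one over the `n_cond` conditions, the outer one over an optional batch of equation parameters. A self-contained toy version of that pattern (`u` stands in for the PINN call; names and values are illustrative, not from the diff):

import jax
import jax.numpy as jnp

def u(t, theta):                   # stand-in for self.u(t, params)
    return theta * t               # output of shape (dim,)

def per_condition(t, u0, theta):   # ~ initial_condition_fun__
    return jnp.sum((u(t, theta) - u0) ** 2, axis=0)

def over_conditions(theta):        # ~ initial_condition_fun_
    t0 = jnp.array([[0.0], [5.0]])   # (n_cond, 1)
    u0 = jnp.array([[1.0], [0.3]])   # (n_cond, dim)
    return jnp.mean(jax.vmap(per_condition, (0, 0, None))(t0, u0, theta))

thetas = jnp.linspace(0.5, 1.5, 4)   # a batch of equation parameters
loss = jnp.mean(jax.vmap(over_conditions)(thetas))  # outer vmap: the product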
@@ -310,43 +345,27 @@
        all_params: ODEComponents[Params[Array] | None] = ODEComponents(
            params, params, params_obs
        )
+
+        # Note that the lambdas below carry a type: ignore just because
+        # lambdas cannot be type annotated; the proper fix would be to assign
+        # each lambda to a type-hinted variable beforehand, which is not
+        # practical, so let us not get mad at this
        mses_grads = jax.tree.map(
-            lambda fun, params: self.get_gradients(fun, params),
+            self.get_gradients,
            all_funs,
            all_params,
            is_leaf=lambda x: x is None,
        )

        mses = jax.tree.map(
-            lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+            lambda leaf: leaf[0],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
        )
        grads = jax.tree.map(
-            lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+            lambda leaf: leaf[1],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
        )

        return mses, grads
-
-    def evaluate(
-        self, params: Params[Array], batch: ODEBatch
-    ) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-
-        We retrieve the total value itself and a PyTree with loss values for each term
-
-        Parameters
-        ----------
-        params
-            Parameters at which the loss is evaluated
-        batch
-            Composed of a batch of points in the
-            domain, a batch of points in the domain
-            border and an optional additional batch of parameters (eg. for
-            metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        loss_terms, _ = self.evaluate_by_terms(params, batch)
-
-        loss_val = self.ponderate_and_sum_loss(loss_terms)
-
-        return loss_val, loss_terms
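
This final hunk leans on two pytree conventions that keep the term-wise bookkeeping terse: `None` is an empty subtree (so absent loss terms pass through untouched) and `is_leaf` stops traversal at the `(value, grad)` tuples so they can be split into two trees. A minimal standalone illustration, where `value_and_grad_or_none` stands in for `self.get_gradients` (whose exact signature this diff does not show):

import jax
import jax.numpy as jnp

def value_and_grad_or_none(fun, p):
    return jax.value_and_grad(fun)(p)  # -> (value, grad) tuple

funs = {"dyn_loss": lambda p: p**2, "observations": None}
params = {"dyn_loss": jnp.array(3.0), "observations": None}

# None leaves are empty subtrees, so tree.map skips them,
# mirroring how absent loss terms stay None in ODEComponents
vg = jax.tree.map(value_and_grad_or_none, funs, params)
mses = jax.tree.map(lambda l: l[0], vg, is_leaf=lambda x: isinstance(x, tuple))
grads = jax.tree.map(lambda l: l[1], vg, is_leaf=lambda x: isinstance(x, tuple))
print(mses)   # {'dyn_loss': Array(9., ...), 'observations': None}
print(grads)  # {'dyn_loss': Array(6., ...), 'observations': None}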