jinns 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_LossODE.py CHANGED
@@ -7,7 +7,7 @@ from __future__ import (
 )  # https://docs.python.org/3/library/typing.html#constant
 
 from dataclasses import InitVar
-from typing import TYPE_CHECKING, TypedDict
+from typing import TYPE_CHECKING, TypedDict, Callable
 from types import EllipsisType
 import abc
 import warnings
@@ -19,6 +19,7 @@ from jaxtyping import Float, Array
 from jinns.loss._loss_utils import (
     dynamic_loss_apply,
     observations_loss_apply,
+    initial_condition_check,
 )
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
@@ -27,10 +28,11 @@ from jinns.parameters._params import (
 from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
 from jinns.loss._loss_weights import LossWeightsODE
 from jinns.loss._abstract_loss import AbstractLoss
+from jinns.loss._loss_components import ODEComponents
+from jinns.parameters._params import Params
 
 if TYPE_CHECKING:
     # imports only used in type hints
-    from jinns.parameters._params import Params
     from jinns.data._Batchs import ODEBatch
     from jinns.nn._abstract_pinn import AbstractPINN
     from jinns.loss import ODE
@@ -42,22 +44,32 @@ if TYPE_CHECKING:
 
 
 class _LossODEAbstract(AbstractLoss):
-    """
+    r"""
     Parameters
     ----------
 
     loss_weights : LossWeightsODE, default=None
         The loss weights for the different terms: dynamic loss,
-        initial condition and eventually observations if any. All fields are
-        set to 1.0 by default.
+        initial condition and eventually observations if any.
+        Can be updated according to a specific algorithm. See
+        `update_weight_method`.
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None, meaning no update of the loss weights; otherwise a string selecting one of the update algorithms above.
     derivative_keys : DerivativeKeysODE, default=None
         Specify which field of `params` should be differentiated for each
         component of the total loss. Particularly useful for inverse problems.
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each component of the loss.
-    initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
-        tuple of length 2 with initial condition $(t_0, u_0)$.
+    initial_condition : tuple[
+            Float[Array, " n_cond 1"],
+            Float[Array, " n_cond dim"]
+        ] |
+        tuple[int | float | Float[Array, " "],
+            int | float | Float[Array, " dim"]
+        ] | None, default=None
+        Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
+        From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g.* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020.
     obs_slice : EllipsisType | slice | None, default=None
         Slice object specifying the beginning/ending
         slice of u output(s) that is observed. This is useful for
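The two accepted `initial_condition` forms documented above can be written as follows (a minimal sketch; the concrete numbers are placeholders, not taken from the package):

    import jax.numpy as jnp

    # one initial condition, scalar form: t0 = 0.0, u0 with dim = 2
    ic_single = (0.0, jnp.array([1.0, 0.5]))

    # several imposed conditions (accepted from jinns v1.5.1): here n_cond = 2
    # time points, e.g. an initial and a final condition, each with dim = 2
    ic_multi = (
        jnp.array([[0.0], [10.0]]),           # t0 with shape (n_cond, 1)
        jnp.array([[1.0, 0.5], [0.2, 0.3]]),  # u0 with shape (n_cond, dim)
    )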
@@ -74,7 +86,9 @@ class _LossODEAbstract(AbstractLoss):
     derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
     loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
     initial_condition: (
-        tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
+        tuple[Float[Array, " n_cond 1"], Float[Array, " n_cond dim"]]
+        | tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
+        | None
     ) = eqx.field(kw_only=True, default=None)
     obs_slice: EllipsisType | slice | None = eqx.field(
         kw_only=True, default=None, static=True
@@ -108,18 +122,60 @@ class _LossODEAbstract(AbstractLoss):
                     "Initial condition should be a tuple of len 2 with (t0, u0), "
                     f"{self.initial_condition} was passed."
                 )
-            # some checks/reshaping for t0
+            # some checks/reshaping for t0 and u0
             t0, u0 = self.initial_condition
             if isinstance(t0, Array):
-                if not t0.shape:  # e.g. user input: jnp.array(0.)
-                    t0 = jnp.array([t0])
-                elif t0.shape != (1,):
+                # at the end we want to end up with t0 of shape (:, 1) to account for
+                # possibly several data points
+                if t0.ndim <= 1:
+                    # in this case we assume t0 belongs to one (initial)
+                    # condition
+                    t0 = initial_condition_check(t0, dim_size=1)[
+                        None, :
+                    ]  # make a (1, 1) here
+                if t0.ndim > 2:
+                    raise ValueError(
+                        "If t0 is an Array, it represents n_cond"
+                        " imposed conditions and must be of shape (n_cond, 1)"
+                    )
+            else:
+                # in this case t0 clearly represents one (initial) condition
+                t0 = initial_condition_check(t0, dim_size=1)[
+                    None, :
+                ]  # make a (1, 1) here
+            if isinstance(u0, Array):
+                # at the end we want to end up with u0 of shape (:, dim) to account for
+                # possibly several data points
+                if not u0.shape:
+                    # in this case we assume u0 belongs to one (initial)
+                    # condition
+                    u0 = initial_condition_check(u0, dim_size=1)[
+                        None, :
+                    ]  # make a (1, 1) here
+                elif u0.ndim == 1:
+                    # in this case we assume u0 belongs to one (initial)
+                    # condition
+                    u0 = initial_condition_check(u0, dim_size=u0.shape[0])[
+                        None, :
+                    ]  # make a (1, dim) here
+                if u0.ndim > 2:
                     raise ValueError(
-                        f"Wrong t0 input (self.initial_condition[0]) It should be"
-                        f"a float or an array of shape (1,). Got shape: {t0.shape}"
+                        "If u0 is an Array, it represents n_cond "
+                        "imposed conditions and must be of shape (n_cond, dim)"
                     )
-            if isinstance(t0, float):  # e.g. user input: 0
-                t0 = jnp.array([t0])
+            else:
+                # at the end we want to end up with u0 of shape (:, dim) to account for
+                # possibly several data points
+                u0 = initial_condition_check(u0, dim_size=None)[
+                    None, :
+                ]  # make a (1, 1) here
+
+            if t0.shape[0] != u0.shape[0] or t0.ndim != u0.ndim:
+                raise ValueError(
+                    "t0 and u0 must represent the same number of initial"
+                    " conditions"
+                )
+
             self.initial_condition = (t0, u0)
 
         if self.obs_slice is None:
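The branch above normalises whatever the user passed into a pair of arrays with shapes (n_cond, 1) and (n_cond, dim). A rough standalone sketch of that shape handling, assuming scalar or 1-d inputs denote a single condition; it ignores the error paths and the internal `initial_condition_check` helper, whose exact behaviour is not shown in this diff:

    import jax.numpy as jnp

    def normalise_ic(t0, u0):
        # hypothetical illustration of the shape handling, not the library code
        t0 = jnp.asarray(t0, dtype=float)
        if t0.ndim <= 1:       # scalar or (1,) array -> a single condition
            t0 = t0.reshape(1, 1)
        u0 = jnp.asarray(u0, dtype=float)
        if u0.ndim == 0:       # scalar -> a single 1-dimensional condition
            u0 = u0.reshape(1, 1)
        elif u0.ndim == 1:     # (dim,) array -> a single condition
            u0 = u0[None, :]
        if t0.shape[0] != u0.shape[0]:
            raise ValueError("t0 and u0 must describe the same number of conditions")
        return t0, u0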
@@ -154,8 +210,11 @@ class LossODE(_LossODEAbstract):
     ----------
     loss_weights : LossWeightsODE, default=None
         The loss weights for the different terms: dynamic loss,
-        initial condition and eventually observations if any. All fields are
-        set to 1.0 by default.
+        initial condition and eventually observations if any.
+        Can be updated according to a specific algorithm. See
+        `update_weight_method`.
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None, meaning no update of the loss weights; otherwise a string selecting one of the update algorithms above.
     derivative_keys : DerivativeKeysODE, default=None
         Specify which field of `params` should be differentiated for each
         component of the total loss. Particularly useful for inverse problems.
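As in the abstract class, `LossODE` now documents an `update_weight_method` switch for the adaptive loss-weighting schemes. A hedged construction sketch: apart from `u`, `dynamic_loss`, `loss_weights`, `initial_condition` and `update_weight_method`, which appear in this diff or its docstrings, everything below (the PINN `u`, the dynamic loss `my_ode`, the bare `LossWeightsODE()` call) is assumed rather than taken from the package:

    # assumed keyword-only construction, mirroring the documented fields
    loss = LossODE(
        u=u,                                # the PINN to be trained
        dynamic_loss=my_ode,                # the ODE dynamic loss term
        initial_condition=ic_multi,         # e.g. the (n_cond, .) tuple sketched earlier
        loss_weights=LossWeightsODE(),      # assumed default weights
        update_weight_method="soft_adapt",  # or "lr_annealing", "ReLoBRaLo"
    )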
@@ -204,21 +263,26 @@ class LossODE(_LossODEAbstract):
     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
 
-    def evaluate(
+    def evaluate_by_terms(
         self, params: Params[Array], batch: ODEBatch
-    ) -> tuple[Float[Array, " "], LossDictODE]:
+    ) -> tuple[
+        ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
+    ]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
+        We retrieve two PyTrees with loss values and gradients for each term
 
        Parameters
        ---------
        params
            Parameters at which the loss is evaluated
        batch
-            Composed of a batch of time points
-            at which to evaluate the differential operator. An optional additional batch of parameters (eg. for metamodeling) and an optional additional batch of observed inputs/outputs/parameters can
-            be supplied.
+            Composed of a batch of points in the
+            domain, a batch of points in the domain
+            border and an optional additional batch of parameters (eg. for
+            metamodeling) and an optional additional batch of observed
+            inputs/outputs/parameters
        """
        temporal_batch = batch.temporal_batch
 
@@ -233,71 +297,120 @@ class LossODE(_LossODEAbstract):
 
         ## dynamic part
         if self.dynamic_loss is not None:
-            mse_dyn_loss = dynamic_loss_apply(
-                self.dynamic_loss.evaluate,
+            dyn_loss_fun = lambda p: dynamic_loss_apply(
+                self.dynamic_loss.evaluate,  # type: ignore
                 self.u,
                 temporal_batch,
-                _set_derivatives(params, self.derivative_keys.dyn_loss),  # type: ignore
+                _set_derivatives(p, self.derivative_keys.dyn_loss),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
-                self.loss_weights.dyn_loss,  # type: ignore
             )
         else:
-            mse_dyn_loss = jnp.array(0.0)
+            dyn_loss_fun = None
 
         # initial condition
         if self.initial_condition is not None:
-            vmap_in_axes = (None,) + vmap_in_axes_params
-            if not jax.tree_util.tree_leaves(vmap_in_axes):
-                # test if only None in vmap_in_axes to avoid the value error:
-                # `vmap must have at least one non-None value in in_axes`
-                v_u = self.u
-            else:
-                v_u = vmap(self.u, (None,) + vmap_in_axes_params)
             t0, u0 = self.initial_condition
             u0 = jnp.array(u0)
-            mse_initial_condition = jnp.mean(
-                self.loss_weights.initial_condition  # type: ignore
-                * jnp.sum(
-                    (
-                        v_u(
-                            t0,
-                            _set_derivatives(
-                                params,
-                                self.derivative_keys.initial_condition,  # type: ignore
-                            ),
-                        )
-                        - u0
+
+            # first construct the plain init loss, no vmapping
+            initial_condition_fun__ = lambda t, u, p: jnp.sum(
+                (
+                    self.u(
+                        t,
+                        _set_derivatives(
+                            p,
+                            self.derivative_keys.initial_condition,  # type: ignore
+                        ),
                     )
-                    ** 2,
-                    axis=-1,
+                    - u
                 )
+                ** 2,
+                axis=0,
+            )
+            # now vmap over the number of conditions (first dim of t0 and u0)
+            # and take the mean
+            initial_condition_fun_ = lambda p: jnp.mean(
+                vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
            )
+            # now vmap over the possible batch of parameters and take the
+            # average. Note that we then finally have a cartesian product
+            # between the batch of parameters (if any) and the number of
+            # conditions (if any)
+            if not jax.tree_util.tree_leaves(vmap_in_axes_params):
+                # if there is no parameter batch to vmap over we cannot call
+                # vmap because calling vmap must be done with at least one non
+                # None in_axes or out_axes
+                initial_condition_fun = initial_condition_fun_
+            else:
+                initial_condition_fun = lambda p: jnp.mean(
+                    vmap(initial_condition_fun_, vmap_in_axes_params)(p)
+                )
         else:
-            mse_initial_condition = jnp.array(0.0)
+            initial_condition_fun = None
 
         if batch.obs_batch_dict is not None:
             # update params with the batches of observed params
-            params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
+            params_obs = _update_eq_params_dict(
+                params, batch.obs_batch_dict["eq_params"]
+            )
 
             # MSE loss wrt to an observed batch
-            mse_observation_loss = observations_loss_apply(
+            obs_loss_fun = lambda po: observations_loss_apply(
                 self.u,
                 batch.obs_batch_dict["pinn_in"],
-                _set_derivatives(params, self.derivative_keys.observations),  # type: ignore
+                _set_derivatives(po, self.derivative_keys.observations),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights.observations,  # type: ignore
                 self.obs_slice,
             )
         else:
-            mse_observation_loss = jnp.array(0.0)
-
-            # total loss
-        total_loss = mse_dyn_loss + mse_initial_condition + mse_observation_loss
-        return total_loss, (
-            {
-                "dyn_loss": mse_dyn_loss,
-                "initial_condition": mse_initial_condition,
-                "observations": mse_observation_loss,
-            }
+            params_obs = None
+            obs_loss_fun = None
+
+        # get the unweighted mses for each loss term as well as the gradients
+        all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
+            ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
+        )
+        all_params: ODEComponents[Params[Array] | None] = ODEComponents(
+            params, params, params_obs
        )
+        mses_grads = jax.tree.map(
+            lambda fun, params: self.get_gradients(fun, params),
+            all_funs,
+            all_params,
+            is_leaf=lambda x: x is None,
+        )
+
+        mses = jax.tree.map(
+            lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+        )
+        grads = jax.tree.map(
+            lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+        )
+
+        return mses, grads
+
+    def evaluate(
+        self, params: Params[Array], batch: ODEBatch
+    ) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
+        """
+        Evaluate the loss function at a batch of points for given parameters.
+
+        We retrieve the total value itself and a PyTree with loss values for each term
+
+        Parameters
+        ---------
+        params
+            Parameters at which the loss is evaluated
+        batch
+            Composed of a batch of points in the
+            domain, a batch of points in the domain
+            border and an optional additional batch of parameters (eg. for
+            metamodeling) and an optional additional batch of observed
+            inputs/outputs/parameters
+        """
+        loss_terms, _ = self.evaluate_by_terms(params, batch)
+
+        loss_val = self.ponderate_and_sum_loss(loss_terms)
+
+        return loss_val, loss_terms
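Taken together, `evaluate` keeps its public shape of a scalar total loss plus a per-term PyTree, while the new `evaluate_by_terms` exposes the unweighted per-term MSEs and their gradients for the weight-update schemes. A minimal usage sketch; `loss`, `params` and `batch` are placeholders, and `ponderate_and_sum_loss`/`get_gradients` are inherited helpers not shown in this diff:

    # per-term unweighted MSEs and their gradients, as two ODEComponents PyTrees
    mses, grads = loss.evaluate_by_terms(params, batch)

    # weighted scalar total plus the per-term PyTree, as before
    total_loss, loss_terms = loss.evaluate(params, batch)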