jinns 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,7 @@ from jaxtyping import Float, Array, PyTree
16
16
  import jax
17
17
  import jax.numpy as jnp
18
18
  from jinns.parameters._params import EqParams
19
+ from jinns.nn import SPINN
19
20
 
20
21
 
21
22
  # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
@@ -38,6 +39,7 @@ def _decorator_heteregeneous_params(evaluate):
38
39
  self._eval_heterogeneous_parameters(
39
40
  inputs, u, params, self.eq_params_heterogeneity
40
41
  ),
42
+ is_leaf=lambda x: x is None,
41
43
  )
42
44
  new_args = args[:-1] + (_params,)
43
45
  res = evaluate(*new_args)
@@ -152,7 +154,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
152
154
  "The output of dynamic loss must be vectorial, "
153
155
  "i.e. of shape (d,) with d >= 1"
154
156
  )
155
- if len(evaluation.shape) > 1:
157
+ if len(evaluation.shape) > 1 and not isinstance(u, SPINN):
156
158
  warnings.warn(
157
159
  "Return value from DynamicLoss' equation has more "
158
160
  "than one dimension. This is in general a mistake (probably from "
jinns/loss/_LossODE.py CHANGED
@@ -28,13 +28,13 @@ from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysOD
28
28
  from jinns.loss._loss_weights import LossWeightsODE
29
29
  from jinns.loss._abstract_loss import AbstractLoss
30
30
  from jinns.loss._loss_components import ODEComponents
31
+ from jinns.loss import ODE
31
32
  from jinns.parameters._params import Params
32
33
  from jinns.data._Batchs import ODEBatch
33
34
 
34
35
  if TYPE_CHECKING:
35
36
  # imports only used in type hints
36
37
  from jinns.nn._abstract_pinn import AbstractPINN
37
- from jinns.loss import ODE
38
38
 
39
39
  InitialConditionUser = (
40
40
  tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
@@ -65,39 +65,42 @@ class LossODE(
65
65
  ----------
66
66
  u : eqx.Module
67
67
  the PINN
68
- dynamic_loss : ODE
68
+ dynamic_loss : tuple[ODE, ...] | ODE | None
69
69
  the ODE dynamic part of the loss, basically the differential
70
70
  operator $\mathcal{N}[u](t)$. Should implement a method
71
71
  `dynamic_loss.evaluate(t, u, params)`.
72
72
  Can be None in order to access only some part of the evaluate call.
73
- loss_weights : LossWeightsODE, default=None
73
+ loss_weights : LossWeightsODE | None, default=None
74
74
  The loss weights for the differents term : dynamic loss,
75
75
  initial condition and eventually observations if any.
76
76
  Can be updated according to a specific algorithm. See
77
77
  `update_weight_method`
78
- update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
78
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'] | None, default=None
79
79
  Default is None meaning no update for loss weights. Otherwise a string
80
- derivative_keys : DerivativeKeysODE, default=None
80
+ keep_initial_loss_weight_scales : bool, default=True
81
+ Only used if an update weight method is specified. It decides whether
82
+ the updated loss weights are multiplied by the initial `loss_weights`
83
+ passed by the user at initialization. This is useful to force some
84
+ scale difference between the adaptative loss weights even after the
85
+ update method is applied.
86
+ derivative_keys : DerivativeKeysODE | None, default=None
81
87
  Specify which field of `params` should be differentiated for each
82
88
  composant of the total loss. Particularily useful for inverse problems.
83
89
  Fields can be "nn_params", "eq_params" or "both". Those that should not
84
90
  be updated will have a `jax.lax.stop_gradient` called on them. Default
85
91
  is `"nn_params"` for each composant of the loss.
86
- initial_condition : tuple[
87
- Float[Array, "n_cond "],
88
- Float[Array, "n_cond dim"]
89
- ] |
90
- tuple[int | float | Float[Array, " "],
91
- int | float | Float[Array, " dim"]
92
- ], default=None
92
+ initial_condition : InitialConditionUser | None, default=None
93
93
  Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
94
94
  From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
95
- obs_slice : EllipsisType | slice, default=None
95
+ obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
96
96
  Slice object specifying the begininning/ending
97
97
  slice of u output(s) that is observed. This is useful for
98
98
  multidimensional PINN, with partially observed outputs.
99
99
  Default is None (whole output is observed).
100
- params : InitVar[Params[Array]], default=None
100
+ **Note**: If several observation datasets are passed this arguments need to be set as a
101
+ tuple of jnp.slice objects with the same length as the number of
102
+ observation datasets
103
+ params : InitVar[Params[Array]] | None, default=None
101
104
  The main Params object of the problem needed to instanciate the
102
105
  DerivativeKeysODE if the latter is not specified.
103
106
  Raises
@@ -109,22 +112,26 @@ class LossODE(
109
112
  # NOTE static=True only for leaf attributes that are not valid JAX types
110
113
  # (ie. jax.Array cannot be static) and that we do not expect to change
111
114
  u: AbstractPINN
112
- dynamic_loss: ODE | None
115
+ dynamic_loss: tuple[ODE | None, ...]
116
+ vmap_in_axes: tuple[int] = eqx.field(static=True)
113
117
  derivative_keys: DerivativeKeysODE
114
118
  loss_weights: LossWeightsODE
115
119
  initial_condition: InitialCondition | None
116
- obs_slice: EllipsisType | slice = eqx.field(static=True)
120
+ obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
117
121
  params: InitVar[Params[Array] | None]
118
122
 
119
123
  def __init__(
120
124
  self,
121
125
  *,
122
126
  u: AbstractPINN,
123
- dynamic_loss: ODE | None,
127
+ dynamic_loss: tuple[ODE, ...] | ODE | None,
124
128
  loss_weights: LossWeightsODE | None = None,
125
129
  derivative_keys: DerivativeKeysODE | None = None,
126
130
  initial_condition: InitialConditionUser | None = None,
127
- obs_slice: EllipsisType | slice | None = None,
131
+ obs_slice: tuple[EllipsisType | slice, ...]
132
+ | EllipsisType
133
+ | slice
134
+ | None = None,
128
135
  params: Params[Array] | None = None,
129
136
  **kwargs: Any, # this is for arguments for super()
130
137
  ):
@@ -141,8 +148,6 @@ class LossODE(
141
148
  f"received {derivative_keys=} and {params=}"
142
149
  )
143
150
  derivative_keys = DerivativeKeysODE(params=params)
144
- else:
145
- derivative_keys = derivative_keys
146
151
 
147
152
  super().__init__(
148
153
  loss_weights=self.loss_weights,
@@ -151,7 +156,10 @@ class LossODE(
151
156
  **kwargs,
152
157
  )
153
158
  self.u = u
154
- self.dynamic_loss = dynamic_loss
159
+ if not isinstance(dynamic_loss, tuple):
160
+ self.dynamic_loss = (dynamic_loss,)
161
+ else:
162
+ self.dynamic_loss = dynamic_loss
155
163
  if self.update_weight_method is not None and jnp.any(
156
164
  jnp.array(jax.tree.leaves(self.loss_weights)) == 0
157
165
  ):
@@ -231,10 +239,15 @@ class LossODE(
231
239
  self.initial_condition = (t0, u0)
232
240
 
233
241
  if obs_slice is None:
234
- self.obs_slice = jnp.s_[...]
242
+ self.obs_slice = (jnp.s_[...],)
243
+ elif not isinstance(obs_slice, tuple):
244
+ self.obs_slice = (obs_slice,)
235
245
  else:
236
246
  self.obs_slice = obs_slice
237
247
 
248
+ if self.loss_weights is None:
249
+ self.loss_weights = LossWeightsODE()
250
+
238
251
  def evaluate_by_terms(
239
252
  self,
240
253
  opt_params: Params[Array],
@@ -280,23 +293,29 @@ class LossODE(
280
293
  cast(eqx.Module, batch.param_batch_dict), params
281
294
  )
282
295
 
283
- ## dynamic part
284
- if self.dynamic_loss is not None:
285
- dyn_loss_eval = self.dynamic_loss.evaluate
286
- dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
287
- lambda p: dynamic_loss_apply(
288
- dyn_loss_eval,
289
- self.u,
290
- temporal_batch,
291
- _set_derivatives(p, self.derivative_keys.dyn_loss),
292
- self.vmap_in_axes + vmap_in_axes_params,
296
+ if self.dynamic_loss != (None,):
297
+ # Note, for the record, multiple dynamic losses
298
+ # have been introduced in MR 92
299
+ dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
300
+ jax.tree.map(
301
+ lambda d: lambda p: dynamic_loss_apply(
302
+ d.evaluate,
303
+ self.u,
304
+ temporal_batch,
305
+ _set_derivatives(p, self.derivative_keys.dyn_loss),
306
+ self.vmap_in_axes + vmap_in_axes_params,
307
+ ),
308
+ self.dynamic_loss,
309
+ is_leaf=lambda x: isinstance(x, ODE), # do not traverse
310
+ # further than first level
293
311
  )
294
312
  )
295
313
  else:
296
314
  dyn_loss_fun = None
297
315
 
298
316
  if self.initial_condition is not None:
299
- # initial condition
317
+ # Note, for the record, multiple initial conditions for LossODEs
318
+ # have been introduced in MR 77
300
319
  t0, u0 = self.initial_condition
301
320
 
302
321
  # first construct the plain init loss no vmaping
@@ -331,34 +350,50 @@ class LossODE(
331
350
  # if there is no parameter batch to vmap over we cannot call
332
351
  # vmap because calling vmap must be done with at least one non
333
352
  # None in_axes or out_axes
334
- initial_condition_fun = initial_condition_fun_
353
+ initial_condition_fun = (initial_condition_fun_,)
335
354
  else:
336
- initial_condition_fun: Callable[[Params[Array]], Array] | None = (
355
+ initial_condition_fun: (
356
+ tuple[Callable[[Params[Array]], Array], ...] | None
357
+ ) = (
337
358
  lambda p: jnp.mean(
338
359
  vmap(initial_condition_fun_, vmap_in_axes_params)(p)
339
- )
360
+ ),
340
361
  )
362
+ # Note that since MR 92
363
+ # initial_condition_fun is formed as a tuple for
364
+ # consistency with dynamic and observation
365
+ # losses and more modularity for later
341
366
  else:
342
367
  initial_condition_fun = None
343
368
 
344
369
  if batch.obs_batch_dict is not None:
345
- # update params with the batches of observed params
346
- params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
347
-
348
- pinn_in, val = (
349
- batch.obs_batch_dict["pinn_in"],
350
- batch.obs_batch_dict["val"],
351
- ) # the reason for this intruction is https://github.com/microsoft/pyright/discussions/8340
352
-
353
- # MSE loss wrt to an observed batch
354
- obs_loss_fun: Callable[[Params[Array]], Array] | None = (
355
- lambda po: observations_loss_apply(
356
- self.u,
357
- pinn_in,
358
- _set_derivatives(po, self.derivative_keys.observations),
359
- self.vmap_in_axes + vmap_in_axes_params,
360
- val,
370
+ # Note, for the record, multiple DGObs
371
+ # (leading to batch.obs_batch_dict being tuple | None)
372
+ # have been introduced in MR 92
373
+ if len(batch.obs_batch_dict) != len(self.obs_slice):
374
+ raise ValueError(
375
+ "There must be the same number of "
376
+ "observation datasets as the number of "
377
+ "obs_slice"
378
+ )
379
+ params_obs = jax.tree.map(
380
+ lambda d: update_eq_params(params, d["eq_params"]),
381
+ batch.obs_batch_dict,
382
+ is_leaf=lambda x: isinstance(x, dict),
383
+ )
384
+ obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
385
+ jax.tree.map(
386
+ lambda d, slice_: lambda po: observations_loss_apply(
387
+ self.u,
388
+ d["pinn_in"],
389
+ _set_derivatives(po, self.derivative_keys.observations),
390
+ self.vmap_in_axes + vmap_in_axes_params,
391
+ d["val"],
392
+ slice_,
393
+ ),
394
+ batch.obs_batch_dict,
361
395
  self.obs_slice,
396
+ is_leaf=lambda x: isinstance(x, dict),
362
397
  )
363
398
  )
364
399
  else:
@@ -366,33 +401,36 @@ class LossODE(
366
401
  obs_loss_fun = None
367
402
 
368
403
  # get the unweighted mses for each loss term as well as the gradients
369
- all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
404
+ all_funs: ODEComponents[tuple[Callable[[Params[Array]], Array], ...] | None] = (
370
405
  ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
371
406
  )
372
- all_params: ODEComponents[Params[Array] | None] = ODEComponents(
373
- params, params, params_obs
407
+ all_params: ODEComponents[tuple[Params[Array], ...] | None] = ODEComponents(
408
+ jax.tree.map(lambda l: params, dyn_loss_fun),
409
+ jax.tree.map(lambda l: params, initial_condition_fun),
410
+ params_obs,
374
411
  )
375
412
 
376
- # Note that the lambda functions below are with type: ignore just
377
- # because the lambda are not type annotated, but there is no proper way
378
- # to do this and we should assign the lambda to a type hinted variable
379
- # before hand: this is not practical, let us not get mad at this
380
413
  mses_grads = jax.tree.map(
381
414
  self.get_gradients,
382
415
  all_funs,
383
416
  all_params,
384
417
  is_leaf=lambda x: x is None,
385
418
  )
386
-
419
+ # NOTE the is_leaf below is more complex since it must pass possible the tuple
420
+ # of dyn_loss and then stops (but also account it should not stop when
421
+ # the tuple of dyn_loss is of length 2)
387
422
  mses = jax.tree.map(
388
- lambda leaf: leaf[0], # type: ignore
423
+ lambda leaf: leaf[0],
389
424
  mses_grads,
390
- is_leaf=lambda x: isinstance(x, tuple),
425
+ is_leaf=lambda x: isinstance(x, tuple)
426
+ and len(x) == 2
427
+ and isinstance(x[1], Params),
391
428
  )
392
429
  grads = jax.tree.map(
393
- lambda leaf: leaf[1], # type: ignore
430
+ lambda leaf: leaf[1],
394
431
  mses_grads,
395
- is_leaf=lambda x: isinstance(x, tuple),
432
+ is_leaf=lambda x: isinstance(x, tuple)
433
+ and len(x) == 2
434
+ and isinstance(x[1], Params),
396
435
  )
397
-
398
436
  return mses, grads