jinns 1.6.1-py3-none-any.whl → 1.7.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,7 @@ from jaxtyping import Float, Array, PyTree
 import jax
 import jax.numpy as jnp
 from jinns.parameters._params import EqParams
+from jinns.nn import SPINN


 # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
@@ -38,6 +39,7 @@ def _decorator_heteregeneous_params(evaluate):
             self._eval_heterogeneous_parameters(
                 inputs, u, params, self.eq_params_heterogeneity
             ),
+            is_leaf=lambda x: x is None,
         )
         new_args = args[:-1] + (_params,)
         res = evaluate(*new_args)
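
The added `is_leaf=lambda x: x is None` matters because JAX treats `None` as an empty subtree: without it, a mapped function is never called on `None` entries. A minimal sketch of that behaviour, with a plain dict as a hypothetical stand-in for an `eq_params` container:

import jax

params = {"D": 1.0, "r": None}  # hypothetical eq_params-like dict with one unset entry
fill = lambda x: 0.0 if x is None else x

# Default: None is an empty subtree, so `fill` never sees it and nothing changes.
jax.tree.map(fill, params)                               # {'D': 1.0, 'r': None}

# With is_leaf, None entries are handed to `fill` like any other leaf.
jax.tree.map(fill, params, is_leaf=lambda x: x is None)  # {'D': 1.0, 'r': 0.0}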
@@ -152,7 +154,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
                 "The output of dynamic loss must be vectorial, "
                 "i.e. of shape (d,) with d >= 1"
             )
-        if len(evaluation.shape) > 1:
+        if len(evaluation.shape) > 1 and not isinstance(u, SPINN):
             warnings.warn(
                 "Return value from DynamicLoss' equation has more "
                 "than one dimension. This is in general a mistake (probably from "

jinns/loss/_LossODE.py CHANGED
@@ -28,13 +28,13 @@ from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysOD
 from jinns.loss._loss_weights import LossWeightsODE
 from jinns.loss._abstract_loss import AbstractLoss
 from jinns.loss._loss_components import ODEComponents
+from jinns.loss import ODE
 from jinns.parameters._params import Params
 from jinns.data._Batchs import ODEBatch

 if TYPE_CHECKING:
     # imports only used in type hints
     from jinns.nn._abstract_pinn import AbstractPINN
-    from jinns.loss import ODE

     InitialConditionUser = (
         tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
@@ -47,7 +47,11 @@ if TYPE_CHECKING:
     )


-class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]]):
+class LossODE(
+    AbstractLoss[
+        LossWeightsODE, ODEBatch, ODEComponents[Array | None], DerivativeKeysODE
+    ]
+):
     r"""Loss object for an ordinary differential equation

     $$
@@ -57,44 +61,46 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
     where $\mathcal{N}[\cdot]$ is a differential operator and the
     initial condition is $u(t_0)=u_0$.

-
     Parameters
     ----------
     u : eqx.Module
         the PINN
-    dynamic_loss : ODE
+    dynamic_loss : tuple[ODE, ...] | ODE | None
         the ODE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](t)$. Should implement a method
         `dynamic_loss.evaluate(t, u, params)`.
         Can be None in order to access only some part of the evaluate call.
-    loss_weights : LossWeightsODE, default=None
+    loss_weights : LossWeightsODE | None, default=None
         The loss weights for the differents term : dynamic loss,
         initial condition and eventually observations if any.
         Can be updated according to a specific algorithm. See
         `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'] | None, default=None
        Default is None meaning no update for loss weights. Otherwise a string
-    derivative_keys : DerivativeKeysODE, default=None
+    keep_initial_loss_weight_scales : bool, default=True
+        Only used if an update weight method is specified. It decides whether
+        the updated loss weights are multiplied by the initial `loss_weights`
+        passed by the user at initialization. This is useful to force some
+        scale difference between the adaptative loss weights even after the
+        update method is applied.
+    derivative_keys : DerivativeKeysODE | None, default=None
        Specify which field of `params` should be differentiated for each
        composant of the total loss. Particularily useful for inverse problems.
        Fields can be "nn_params", "eq_params" or "both". Those that should not
        be updated will have a `jax.lax.stop_gradient` called on them. Default
        is `"nn_params"` for each composant of the loss.
-    initial_condition : tuple[
-        Float[Array, "n_cond "],
-        Float[Array, "n_cond dim"]
-    ] |
-    tuple[int | float | Float[Array, " "],
-        int | float | Float[Array, " dim"]
-    ], default=None
+    initial_condition : InitialConditionUser | None, default=None
        Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
        From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
-    obs_slice : EllipsisType | slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
        Slice object specifying the begininning/ending
        slice of u output(s) that is observed. This is useful for
        multidimensional PINN, with partially observed outputs.
        Default is None (whole output is observed).
-    params : InitVar[Params[Array]], default=None
+        **Note**: If several observation datasets are passed this arguments need to be set as a
+        tuple of jnp.slice objects with the same length as the number of
+        observation datasets
+    params : InitVar[Params[Array]] | None, default=None
        The main Params object of the problem needed to instanciate the
        DerivativeKeysODE if the latter is not specified.
    Raises
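
As a concrete illustration of the `initial_condition` format documented above (the values below are hypothetical): two conditions, at t = 0 and t = 10, for a two-dimensional state can be passed as arrays of shape (n_cond, 1) and (n_cond, dim).

import jax.numpy as jnp

t_cond = jnp.array([[0.0], [10.0]])           # shape (n_cond, 1) with n_cond = 2
u_cond = jnp.array([[1.0, 0.0], [0.2, 0.8]])  # shape (n_cond, dim) with dim = 2
initial_condition = (t_cond, u_cond)          # value for LossODE(initial_condition=...)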
@@ -106,23 +112,26 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
     u: AbstractPINN
-    dynamic_loss: ODE | None
+    dynamic_loss: tuple[ODE | None, ...]
     vmap_in_axes: tuple[int] = eqx.field(static=True)
     derivative_keys: DerivativeKeysODE
     loss_weights: LossWeightsODE
     initial_condition: InitialCondition | None
-    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
     params: InitVar[Params[Array] | None]

     def __init__(
         self,
         *,
         u: AbstractPINN,
-        dynamic_loss: ODE | None,
+        dynamic_loss: tuple[ODE, ...] | ODE | None,
         loss_weights: LossWeightsODE | None = None,
         derivative_keys: DerivativeKeysODE | None = None,
         initial_condition: InitialConditionUser | None = None,
-        obs_slice: EllipsisType | slice | None = None,
+        obs_slice: tuple[EllipsisType | slice, ...]
+        | EllipsisType
+        | slice
+        | None = None,
         params: Params[Array] | None = None,
         **kwargs: Any,  # this is for arguments for super()
     ):
@@ -131,10 +140,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
         else:
             self.loss_weights = loss_weights

-        super().__init__(loss_weights=self.loss_weights, **kwargs)
-        self.u = u
-        self.dynamic_loss = dynamic_loss
-        self.vmap_in_axes = (0,)
         if derivative_keys is None:
             # by default we only take gradient wrt nn_params
             if params is None:
@@ -142,9 +147,28 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
                     "Problem at derivative_keys initialization "
                     f"received {derivative_keys=} and {params=}"
                 )
-            self.derivative_keys = DerivativeKeysODE(params=params)
+            derivative_keys = DerivativeKeysODE(params=params)
+
+        super().__init__(
+            loss_weights=self.loss_weights,
+            derivative_keys=derivative_keys,
+            vmap_in_axes=(0,),
+            **kwargs,
+        )
+        self.u = u
+        if not isinstance(dynamic_loss, tuple):
+            self.dynamic_loss = (dynamic_loss,)
         else:
-            self.derivative_keys = derivative_keys
+            self.dynamic_loss = dynamic_loss
+        if self.update_weight_method is not None and jnp.any(
+            jnp.array(jax.tree.leaves(self.loss_weights)) == 0
+        ):
+            warnings.warn(
+                "self.update_weight_method is activated while some loss "
+                "weights are zero. The update weight method will likely "
+                "update the zero weight to some non-zero value. Check that "
+                "this is the desired behaviour."
+            )

         if initial_condition is None:
             warnings.warn(
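
The new zero-weight check flattens the loss weights with `jax.tree.leaves` before comparing. A minimal sketch of that check in isolation, using a plain dict with hypothetical values in place of a `LossWeightsODE` instance:

import jax
import jax.numpy as jnp

weights = {"dyn_loss": 1.0, "initial_condition": 0.0, "observations": 5.0}

has_zero = jnp.any(jnp.array(jax.tree.leaves(weights)) == 0)
print(bool(has_zero))  # True -> the constructor above would emit the warning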
@@ -215,12 +239,21 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
             self.initial_condition = (t0, u0)

         if obs_slice is None:
-            self.obs_slice = jnp.s_[...]
+            self.obs_slice = (jnp.s_[...],)
+        elif not isinstance(obs_slice, tuple):
+            self.obs_slice = (obs_slice,)
         else:
             self.obs_slice = obs_slice

+        if self.loss_weights is None:
+            self.loss_weights = LossWeightsODE()
+
     def evaluate_by_terms(
-        self, params: Params[Array], batch: ODEBatch
+        self,
+        opt_params: Params[Array],
+        batch: ODEBatch,
+        *,
+        non_opt_params: Params[Array] | None = None,
     ) -> tuple[
         ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
     ]:
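
The constructor now normalizes `obs_slice` to a tuple with one entry per observation dataset. The same pattern, extracted into a standalone sketch (the helper name and the toy output below are hypothetical):

import jax.numpy as jnp

def normalize_obs_slice(obs_slice):
    # Mirrors the branches above: None -> observe everything, single slice -> 1-tuple.
    if obs_slice is None:
        return (jnp.s_[...],)
    if not isinstance(obs_slice, tuple):
        return (obs_slice,)
    return obs_slice

u_out = jnp.arange(12.0).reshape(4, 3)  # hypothetical PINN output with 3 components
sl = normalize_obs_slice(jnp.s_[0:2])   # only the first two components are observed
print(u_out[:, sl[0]].shape)            # (4, 2)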
@@ -231,15 +264,22 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]

         Parameters
         ---------
-        params
-            Parameters at which the loss is evaluated
+        opt_params
+            Parameters, which are optimized, at which the loss is evaluated
         batch
             Composed of a batch of points in the
             domain, a batch of points in the domain
             border and an optional additional batch of parameters (eg. for
             metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
+        non_opt_params
+            Parameters, which are not optimized, at which the loss is evaluated
         """
+        if non_opt_params is not None:
+            params = eqx.combine(opt_params, non_opt_params)
+        else:
+            params = opt_params
+
         temporal_batch = batch.temporal_batch

         # Retrieve the optional eq_params_batch
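
`evaluate_by_terms` now accepts a split parameter set and merges it back with `eqx.combine`. A minimal sketch of how such a split can be produced and recombined, using a plain dict and filter spec in place of a jinns `Params` object (hypothetical structure):

import equinox as eqx
import jax.numpy as jnp

params = {"nn_params": jnp.ones(3), "eq_params": {"D": jnp.array(2.0)}}
filter_spec = {"nn_params": True, "eq_params": {"D": False}}  # optimize nn_params only

# eqx.partition keeps the selected leaves in one tree and the rest in another,
# filling the holes with None; eqx.combine merges the two back leaf by leaf.
opt_params, non_opt_params = eqx.partition(params, filter_spec)
restored = eqx.combine(opt_params, non_opt_params)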
@@ -253,23 +293,29 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
             cast(eqx.Module, batch.param_batch_dict), params
         )

-        ## dynamic part
-        if self.dynamic_loss is not None:
-            dyn_loss_eval = self.dynamic_loss.evaluate
-            dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
-                lambda p: dynamic_loss_apply(
-                    dyn_loss_eval,
-                    self.u,
-                    temporal_batch,
-                    _set_derivatives(p, self.derivative_keys.dyn_loss),
-                    self.vmap_in_axes + vmap_in_axes_params,
+        if self.dynamic_loss != (None,):
+            # Note, for the record, multiple dynamic losses
+            # have been introduced in MR 92
+            dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d: lambda p: dynamic_loss_apply(
+                        d.evaluate,
+                        self.u,
+                        temporal_batch,
+                        _set_derivatives(p, self.derivative_keys.dyn_loss),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                    ),
+                    self.dynamic_loss,
+                    is_leaf=lambda x: isinstance(x, ODE),  # do not traverse
+                    # further than first level
                 )
             )
         else:
             dyn_loss_fun = None

         if self.initial_condition is not None:
-            # initial condition
+            # Note, for the record, multiple initial conditions for LossODEs
+            # have been introduced in MR 77
             t0, u0 = self.initial_condition

             # first construct the plain init loss no vmaping
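
Because `eqx.Module` instances are themselves pytrees, mapping over `self.dynamic_loss` needs `is_leaf` to stop at each module instead of recursing into its arrays. A self-contained sketch of the pattern, where `Residual` is a hypothetical stand-in for an `ODE` dynamic loss:

import equinox as eqx
import jax
import jax.numpy as jnp

class Residual(eqx.Module):
    # Hypothetical stand-in for an ODE dynamic loss exposing an `evaluate` method.
    scale: jax.Array

    def evaluate(self, t):
        return self.scale * t

residuals = (Residual(jnp.array(1.0)), Residual(jnp.array(2.0)))

# Treat each module as a leaf so the lambda receives whole Residual objects.
fns = jax.tree.map(
    lambda d: (lambda t: d.evaluate(t)),
    residuals,
    is_leaf=lambda x: isinstance(x, Residual),
)
print(fns[1](jnp.array(3.0)))  # 6.0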
@@ -304,34 +350,50 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
                 # if there is no parameter batch to vmap over we cannot call
                 # vmap because calling vmap must be done with at least one non
                 # None in_axes or out_axes
-                initial_condition_fun = initial_condition_fun_
+                initial_condition_fun = (initial_condition_fun_,)
             else:
-                initial_condition_fun: Callable[[Params[Array]], Array] | None = (
+                initial_condition_fun: (
+                    tuple[Callable[[Params[Array]], Array], ...] | None
+                ) = (
                     lambda p: jnp.mean(
                         vmap(initial_condition_fun_, vmap_in_axes_params)(p)
-                    )
+                    ),
                 )
+                # Note that since MR 92
+                # initial_condition_fun is formed as a tuple for
+                # consistency with dynamic and observation
+                # losses and more modularity for later
         else:
             initial_condition_fun = None

         if batch.obs_batch_dict is not None:
-            # update params with the batches of observed params
-            params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
-
-            pinn_in, val = (
-                batch.obs_batch_dict["pinn_in"],
-                batch.obs_batch_dict["val"],
-            )  # the reason for this intruction is https://github.com/microsoft/pyright/discussions/8340
-
-            # MSE loss wrt to an observed batch
-            obs_loss_fun: Callable[[Params[Array]], Array] | None = (
-                lambda po: observations_loss_apply(
-                    self.u,
-                    pinn_in,
-                    _set_derivatives(po, self.derivative_keys.observations),
-                    self.vmap_in_axes + vmap_in_axes_params,
-                    val,
+            # Note, for the record, multiple DGObs
+            # (leading to batch.obs_batch_dict being tuple | None)
+            # have been introduced in MR 92
+            if len(batch.obs_batch_dict) != len(self.obs_slice):
+                raise ValueError(
+                    "There must be the same number of "
+                    "observation datasets as the number of "
+                    "obs_slice"
+                )
+            params_obs = jax.tree.map(
+                lambda d: update_eq_params(params, d["eq_params"]),
+                batch.obs_batch_dict,
+                is_leaf=lambda x: isinstance(x, dict),
+            )
+            obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d, slice_: lambda po: observations_loss_apply(
+                        self.u,
+                        d["pinn_in"],
+                        _set_derivatives(po, self.derivative_keys.observations),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                        d["val"],
+                        slice_,
+                    ),
+                    batch.obs_batch_dict,
                     self.obs_slice,
+                    is_leaf=lambda x: isinstance(x, dict),
                 )
             )
         else:
@@ -339,33 +401,36 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
             obs_loss_fun = None

         # get the unweighted mses for each loss term as well as the gradients
-        all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
+        all_funs: ODEComponents[tuple[Callable[[Params[Array]], Array], ...] | None] = (
             ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
         )
-        all_params: ODEComponents[Params[Array] | None] = ODEComponents(
-            params, params, params_obs
+        all_params: ODEComponents[tuple[Params[Array], ...] | None] = ODEComponents(
+            jax.tree.map(lambda l: params, dyn_loss_fun),
+            jax.tree.map(lambda l: params, initial_condition_fun),
+            params_obs,
         )

-        # Note that the lambda functions below are with type: ignore just
-        # because the lambda are not type annotated, but there is no proper way
-        # to do this and we should assign the lambda to a type hinted variable
-        # before hand: this is not practical, let us not get mad at this
         mses_grads = jax.tree.map(
             self.get_gradients,
             all_funs,
             all_params,
             is_leaf=lambda x: x is None,
         )
-
+        # NOTE the is_leaf below is more complex since it must pass possible the tuple
+        # of dyn_loss and then stops (but also account it should not stop when
+        # the tuple of dyn_loss is of length 2)
         mses = jax.tree.map(
-            lambda leaf: leaf[0],  # type: ignore
+            lambda leaf: leaf[0],
             mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
         )
         grads = jax.tree.map(
-            lambda leaf: leaf[1],  # type: ignore
+            lambda leaf: leaf[1],
             mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
         )
-
         return mses, grads
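
The final `is_leaf` predicate identifies the `(mse, grad)` pairs returned by `get_gradients` so the tree can be split into one tree of scalar losses and one tree of gradients. The same pattern in isolation, with plain dicts standing in for `Params` containers (hypothetical values):

import jax

mses_grads = {
    "dyn_loss": (0.12, {"nn_params": 1.0}),
    "observations": (0.03, {"nn_params": -0.5}),
}

# A leaf is a 2-tuple whose second element is the gradient container.
is_pair = lambda x: isinstance(x, tuple) and len(x) == 2 and isinstance(x[1], dict)

mses = jax.tree.map(lambda leaf: leaf[0], mses_grads, is_leaf=is_pair)
grads = jax.tree.map(lambda leaf: leaf[1], mses_grads, is_leaf=is_pair)
print(mses)  # {'dyn_loss': 0.12, 'observations': 0.03}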