jinns 1.4.0-py3-none-any.whl → 1.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_LossPDE.py CHANGED
@@ -11,6 +11,7 @@ from dataclasses import InitVar
  from typing import TYPE_CHECKING, Callable, TypedDict
  from types import EllipsisType
  import warnings
+ import jax
  import jax.numpy as jnp
  import equinox as eqx
  from jaxtyping import Float, Array, Key, Int
@@ -20,6 +21,7 @@ from jinns.loss._loss_utils import (
  normalization_loss_apply,
  observations_loss_apply,
  initial_condition_apply,
+ initial_condition_check,
  )
  from jinns.parameters._params import (
  _get_vmap_in_axes_params,
@@ -31,16 +33,17 @@ from jinns.parameters._derivative_keys import (
  DerivativeKeysPDENonStatio,
  )
  from jinns.loss._abstract_loss import AbstractLoss
+ from jinns.loss._loss_components import PDEStatioComponents, PDENonStatioComponents
  from jinns.loss._loss_weights import (
  LossWeightsPDEStatio,
  LossWeightsPDENonStatio,
  )
  from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+ from jinns.parameters._params import Params


  if TYPE_CHECKING:
  # imports for type hints only
- from jinns.parameters._params import Params
  from jinns.nn._abstract_pinn import AbstractPINN
  from jinns.loss import PDENonStatio, PDEStatio
  from jinns.utils._types import BoundaryConditionFun
@@ -71,7 +74,10 @@ class _LossPDEAbstract(AbstractLoss):
  The loss weights for the differents term : dynamic loss,
  initial condition (if LossWeightsPDENonStatio), boundary conditions if
  any, normalization loss if any and observations if any.
- All fields are set to 1.0 by default.
+ Can be updated according to a specific algorithm. See
+ `update_weight_method`
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+ Default is None meaning no update for loss weights. Otherwise a string
  derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
  Specify which field of `params` should be differentiated for each
  composant of the total loss. Particularily useful for inverse problems.
@@ -336,7 +342,10 @@ class LossPDEStatio(_LossPDEAbstract):
  The loss weights for the differents term : dynamic loss,
  boundary conditions if any, normalization loss if any and
  observations if any.
- All fields are set to 1.0 by default.
+ Can be updated according to a specific algorithm. See
+ `update_weight_method`
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+ Default is None meaning no update for loss weights. Otherwise a string
  derivative_keys : DerivativeKeysPDEStatio, default=None
  Specify which field of `params` should be differentiated for each
  composant of the total loss. Particularily useful for inverse problems.
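For reference, these weights enter the training objective as a plain weighted sum of the per-term mean squared errors (this is what `ponderate_and_sum_loss`, shown further down in this diff, computes): with w_k the loss weights and MSE_k the individual terms,

L_{\text{total}} = \sum_k w_k \,\mathrm{MSE}_k, \qquad k \in \{\text{dyn\_loss}, \text{boundary\_loss}, \text{norm\_loss}, \text{observations}\}.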
@@ -432,12 +441,13 @@ class LossPDEStatio(_LossPDEAbstract):
  def __call__(self, *args, **kwargs):
  return self.evaluate(*args, **kwargs)

- def evaluate(
+ def evaluate_by_terms(
  self, params: Params[Array], batch: PDEStatioBatch
- ) -> tuple[Float[Array, " "], LossDictPDEStatio]:
+ ) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
  """
  Evaluate the loss function at a batch of points for given parameters.

+ We retrieve two PyTrees with loss values and gradients for each term

  Parameters
  ---------
@@ -461,29 +471,27 @@ class LossPDEStatio(_LossPDEAbstract):

  # dynamic part
  if self.dynamic_loss is not None:
- mse_dyn_loss = dynamic_loss_apply(
- self.dynamic_loss.evaluate,
+ dyn_loss_fun = lambda p: dynamic_loss_apply(
+ self.dynamic_loss.evaluate, # type: ignore
  self.u,
  self._get_dynamic_loss_batch(batch),
- _set_derivatives(params, self.derivative_keys.dyn_loss), # type: ignore
+ _set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
  self.vmap_in_axes + vmap_in_axes_params,
- self.loss_weights.dyn_loss, # type: ignore
  )
  else:
- mse_dyn_loss = jnp.array(0.0)
+ dyn_loss_fun = None

  # normalization part
  if self.norm_samples is not None:
- mse_norm_loss = normalization_loss_apply(
+ norm_loss_fun = lambda p: normalization_loss_apply(
  self.u,
  self._get_normalization_loss_batch(batch),
- _set_derivatives(params, self.derivative_keys.norm_loss), # type: ignore
+ _set_derivatives(p, self.derivative_keys.norm_loss), # type: ignore
  vmap_in_axes_params,
  self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
- self.loss_weights.norm_loss, # type: ignore
  )
  else:
- mse_norm_loss = jnp.array(0.0)
+ norm_loss_fun = None

  # boundary part
  if (
@@ -491,47 +499,84 @@ class LossPDEStatio(_LossPDEAbstract):
  and self.omega_boundary_dim is not None
  and self.omega_boundary_fun is not None
  ): # pyright cannot narrow down the three None otherwise as it is class attribute
- mse_boundary_loss = boundary_condition_apply(
+ boundary_loss_fun = lambda p: boundary_condition_apply(
  self.u,
  batch,
- _set_derivatives(params, self.derivative_keys.boundary_loss), # type: ignore
- self.omega_boundary_fun,
- self.omega_boundary_condition,
- self.omega_boundary_dim,
- self.loss_weights.boundary_loss, # type: ignore
+ _set_derivatives(p, self.derivative_keys.boundary_loss), # type: ignore
+ self.omega_boundary_fun, # type: ignore
+ self.omega_boundary_condition, # type: ignore
+ self.omega_boundary_dim, # type: ignore
  )
  else:
- mse_boundary_loss = jnp.array(0.0)
+ boundary_loss_fun = None

  # Observation mse
  if batch.obs_batch_dict is not None:
  # update params with the batches of observed params
- params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
+ params_obs = _update_eq_params_dict(
+ params, batch.obs_batch_dict["eq_params"]
+ )

- mse_observation_loss = observations_loss_apply(
+ obs_loss_fun = lambda po: observations_loss_apply(
  self.u,
  self._get_observations_loss_batch(batch),
- _set_derivatives(params, self.derivative_keys.observations), # type: ignore
+ _set_derivatives(po, self.derivative_keys.observations), # type: ignore
  self.vmap_in_axes + vmap_in_axes_params,
  batch.obs_batch_dict["val"],
- self.loss_weights.observations, # type: ignore
  self.obs_slice,
  )
  else:
- mse_observation_loss = jnp.array(0.0)
+ params_obs = None
+ obs_loss_fun = None

- # total loss
- total_loss = (
- mse_dyn_loss + mse_norm_loss + mse_boundary_loss + mse_observation_loss
+ # get the unweighted mses for each loss term as well as the gradients
+ all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
+ PDEStatioComponents(
+ dyn_loss_fun, norm_loss_fun, boundary_loss_fun, obs_loss_fun
+ )
+ )
+ all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
+ params, params, params, params_obs
  )
- return total_loss, (
- {
- "dyn_loss": mse_dyn_loss,
- "norm_loss": mse_norm_loss,
- "boundary_loss": mse_boundary_loss,
- "observations": mse_observation_loss,
- }
+ mses_grads = jax.tree.map(
+ lambda fun, params: self.get_gradients(fun, params),
+ all_funs,
+ all_params,
+ is_leaf=lambda x: x is None,
  )
+ mses = jax.tree.map(
+ lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+ )
+ grads = jax.tree.map(
+ lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+ )
+
+ return mses, grads
+
+ def evaluate(
+ self, params: Params[Array], batch: PDEStatioBatch
+ ) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
+ """
+ Evaluate the loss function at a batch of points for given parameters.
+
+ We retrieve the total value itself and a PyTree with loss values for each term
+
+ Parameters
+ ---------
+ params
+ Parameters at which the loss is evaluated
+ batch
+ Composed of a batch of points in the
+ domain, a batch of points in the domain
+ border and an optional additional batch of parameters (eg. for
+ metamodeling) and an optional additional batch of observed
+ inputs/outputs/parameters
+ """
+ loss_terms, _ = self.evaluate_by_terms(params, batch)
+
+ loss_val = self.ponderate_and_sum_loss(loss_terms)
+
+ return loss_val, loss_terms


  class LossPDENonStatio(LossPDEStatio):
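The pattern introduced above (one closure per loss term collected in a components PyTree, with None for unused terms, then mapped through get_gradients) can be hard to read in diff form. Below is a minimal, self-contained sketch of the same idea on a toy problem; the names and the toy loss terms are illustrative, none of this is jinns code.

import jax
import jax.numpy as jnp

params = {"theta": jnp.array([1.0, 2.0])}

# one closure per loss term, None when the term is not used
term_funs = {
    "dyn_loss": lambda p: jnp.sum(p["theta"] ** 2),
    "norm_loss": None,
    "observations": lambda p: jnp.sum((p["theta"] - 1.0) ** 2),
}
term_params = {"dyn_loss": params, "norm_loss": None, "observations": params}

def value_and_grad_or_none(fun, p):
    # mirrors AbstractLoss.get_gradients: unused terms propagate as (None, None)
    if fun is None:
        return None, None
    return jax.value_and_grad(fun)(p)

vals_grads = jax.tree.map(
    value_and_grad_or_none, term_funs, term_params, is_leaf=lambda x: x is None
)
# split the (value, gradient) tuples into two PyTrees, as evaluate_by_terms does
mses = jax.tree.map(lambda vg: vg[0], vals_grads, is_leaf=lambda x: isinstance(x, tuple))
grads = jax.tree.map(lambda vg: vg[1], vals_grads, is_leaf=lambda x: isinstance(x, tuple))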
@@ -569,7 +614,10 @@ class LossPDENonStatio(LossPDEStatio):
  The loss weights for the differents term : dynamic loss,
  boundary conditions if any, initial condition, normalization loss if any and
  observations if any.
- All fields are set to 1.0 by default.
+ Can be updated according to a specific algorithm. See
+ `update_weight_method`
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+ Default is None meaning no update for loss weights. Otherwise a string
  derivative_keys : DerivativeKeysPDENonStatio, default=None
  Specify which field of `params` should be differentiated for each
  composant of the total loss. Particularily useful for inverse problems.
@@ -653,20 +701,12 @@ class LossPDENonStatio(LossPDEStatio):
  "case (e.g by. hardcoding it into the PINN output)."
  )
  # some checks for t0
- if isinstance(self.t0, Array):
- if not self.t0.shape: # e.g. user input: jnp.array(0.)
- self.t0 = jnp.array([self.t0])
- elif self.t0.shape != (1,):
- raise ValueError(
- f"Wrong self.t0 input. It should be"
- f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
- )
- elif isinstance(self.t0, float): # e.g. user input: 0
- self.t0 = jnp.array([self.t0])
- elif self.t0 is None:
- self.t0 = jnp.array([0])
+ t0 = self.t0
+ if t0 is None:
+ t0 = jnp.array([0])
  else:
- raise ValueError("Wrong value for t0")
+ t0 = initial_condition_check(t0, dim_size=1)
+ self.t0 = t0

  # witht the variables below we avoid memory overflow since a cartesian
  # product is taken
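The body of initial_condition_check lives in jinns.loss._loss_utils and is not part of this diff. Based on the inline branch it replaces here, a hypothetical stand-in would look roughly like the sketch below (the function name, dtype handling and error message are illustrative, not the actual implementation).

import jax.numpy as jnp

def initial_condition_check_sketch(value, dim_size=1):
    # Hypothetical stand-in mirroring the removed inline checks: accept a
    # Python scalar or a scalar / shape-(dim_size,) array and always return
    # an array of shape (dim_size,); reject anything else.
    value = jnp.asarray(value, dtype=float)
    if value.shape == ():
        return jnp.full((dim_size,), value)
    if value.shape == (dim_size,):
        return value
    raise ValueError(
        f"Expected a float or an array of shape ({dim_size},), got shape {value.shape}"
    )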
@@ -696,22 +736,25 @@ class LossPDENonStatio(LossPDEStatio):
  def __call__(self, *args, **kwargs):
  return self.evaluate(*args, **kwargs)

- def evaluate(
+ def evaluate_by_terms(
  self, params: Params[Array], batch: PDENonStatioBatch
- ) -> tuple[Float[Array, " "], LossDictPDENonStatio]:
+ ) -> tuple[
+ PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
+ ]:
  """
  Evaluate the loss function at a batch of points for given parameters.

+ We retrieve two PyTrees with loss values and gradients for each term

  Parameters
  ---------
  params
  Parameters at which the loss is evaluated
  batch
- Composed of a batch of points in
- the domain, a batch of points in the domain
- border, a batch of time points and an optional additional batch
- of parameters (eg. for metamodeling) and an optional additional batch of observed
+ Composed of a batch of points in the
+ domain, a batch of points in the domain
+ border and an optional additional batch of parameters (eg. for
+ metamodeling) and an optional additional batch of observed
  inputs/outputs/parameters
  """
  omega_batch = batch.initial_batch
@@ -728,27 +771,62 @@ class LossPDENonStatio(LossPDEStatio):

  # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
  # mse_observation_loss we use the evaluate from parent class
- partial_mse, partial_mse_terms = super().evaluate(params, batch) # type: ignore
+ # As well as for their gradients
+ partial_mses, partial_grads = super().evaluate_by_terms(params, batch) # type: ignore
  # ignore because batch is not PDEStatioBatch. We could use typing.cast though

  # initial condition
  if self.initial_condition_fun is not None:
- mse_initial_condition = initial_condition_apply(
+ mse_initial_condition_fun = lambda p: initial_condition_apply(
  self.u,
  omega_batch,
- _set_derivatives(params, self.derivative_keys.initial_condition), # type: ignore
+ _set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
  (0,) + vmap_in_axes_params,
- self.initial_condition_fun,
+ self.initial_condition_fun, # type: ignore
  self.t0, # type: ignore can't get the narrowing in __post_init__
- self.loss_weights.initial_condition, # type: ignore
+ )
+ mse_initial_condition, grad_initial_condition = self.get_gradients(
+ mse_initial_condition_fun, params
  )
  else:
- mse_initial_condition = jnp.array(0.0)
+ mse_initial_condition = None
+ grad_initial_condition = None
+
+ mses = PDENonStatioComponents(
+ partial_mses.dyn_loss,
+ partial_mses.norm_loss,
+ partial_mses.boundary_loss,
+ partial_mses.observations,
+ mse_initial_condition,
+ )
+
+ grads = PDENonStatioComponents(
+ partial_grads.dyn_loss,
+ partial_grads.norm_loss,
+ partial_grads.boundary_loss,
+ partial_grads.observations,
+ grad_initial_condition,
+ )
+
+ return mses, grads
+
+ def evaluate(
+ self, params: Params[Array], batch: PDENonStatioBatch
+ ) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
+ """
+ Evaluate the loss function at a batch of points for given parameters.
+ We retrieve the total value itself and a PyTree with loss values for each term

- # total loss
- total_loss = partial_mse + mse_initial_condition

- return total_loss, {
- **partial_mse_terms,
- "initial_condition": mse_initial_condition,
- }
+ Parameters
+ ---------
+ params
+ Parameters at which the loss is evaluated
+ batch
+ Composed of a batch of points in
+ the domain, a batch of points in the domain
+ border, a batch of time points and an optional additional batch
+ of parameters (eg. for metamodeling) and an optional additional batch of observed
+ inputs/outputs/parameters
+ """
+ return super().evaluate(params, batch) # type: ignore
jinns/loss/__init__.py CHANGED
@@ -14,6 +14,7 @@ from ._loss_weights import (
  LossWeightsPDENonStatio,
  LossWeightsPDEStatio,
  )
+ from ._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo

  from ._operators import (
  divergence_fwd,
@@ -47,4 +48,7 @@ __all__ = [
  "laplacian_rev",
  "vectorial_laplacian_fwd",
  "vectorial_laplacian_rev",
+ "soft_adapt",
+ "lr_annealing",
+ "ReLoBRaLo",
  ]
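With these additions the three weight-update schemes become part of the public API; a one-line usage sketch, assuming jinns 1.5.1 is installed:

from jinns.loss import soft_adapt, lr_annealing, ReLoBRaLo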
jinns/loss/_abstract_loss.py CHANGED
@@ -1,15 +1,128 @@
+ from __future__ import annotations
+
  import abc
- from jaxtyping import Array
+ from typing import TYPE_CHECKING, Self, Literal, Callable
+ from jaxtyping import Array, PyTree, Key
  import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import optax
+ from jinns.loss._loss_weights import AbstractLossWeights
+ from jinns.parameters._params import Params
+ from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
+
+ if TYPE_CHECKING:
+ from jinns.utils._types import AnyLossComponents, AnyBatch


  class AbstractLoss(eqx.Module):
  """
- Basically just a way to add a __call__ to an eqx.Module.
- The way to go for correct type hints apparently
+ About the call:
  https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
  """

+ loss_weights: AbstractLossWeights
+ update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
+ eqx.field(kw_only=True, default=None, static=True)
+ )
+
  @abc.abstractmethod
  def __call__(self, *_, **__) -> Array:
  pass
+
+ @abc.abstractmethod
+ def evaluate_by_terms(
+ self, params: Params[Array], batch: AnyBatch
+ ) -> tuple[AnyLossComponents, AnyLossComponents]:
+ pass
+
+ def get_gradients(
+ self, fun: Callable[[Params[Array]], Array], params: Params[Array]
+ ) -> tuple[Array, Array]:
+ """
+ params already filtered with derivative keys here
+ """
+ if fun is None:
+ return None, None
+ value_grad_loss = jax.value_and_grad(fun)
+ loss_val, grads = value_grad_loss(params)
+ return loss_val, grads
+
+ def ponderate_and_sum_loss(self, terms):
+ """
+ Get total loss from individual loss terms and weights
+
+ tree.leaves is needed to get rid of None from non used loss terms
+ """
+ weights = jax.tree.leaves(
+ self.loss_weights,
+ is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
+ )
+ terms = jax.tree.leaves(
+ terms, is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None
+ )
+ if len(weights) == len(terms):
+ return jnp.sum(jnp.array(weights) * jnp.array(terms))
+ else:
+ raise ValueError(
+ "The numbers of declared loss weights and "
+ "declared loss terms do not concord "
+ f" got {len(weights)} and {len(terms)}"
+ )
+
+ def ponderate_and_sum_gradient(self, terms):
+ """
+ Get total gradients from individual loss gradients and weights
+ for each parameter
+
+ tree.leaves is needed to get rid of None from non used loss terms
+ """
+ weights = jax.tree.leaves(
+ self.loss_weights,
+ is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
+ )
+ grads = jax.tree.leaves(terms, is_leaf=lambda x: isinstance(x, Params))
+ # gradient terms for each individual loss for each parameter (several
+ # Params structures)
+ weights_pytree = jax.tree.map(
+ lambda w: optax.tree_utils.tree_full_like(grads[0], w), weights
+ ) # We need several Params structures full of the weight scalar
+ weighted_grads = jax.tree.map(
+ lambda w, p: w * p, weights_pytree, grads, is_leaf=eqx.is_inexact_array
+ ) # Now we can multiply
+ return jax.tree.map(
+ lambda *grads: jnp.sum(jnp.array(grads), axis=0),
+ *weighted_grads,
+ is_leaf=eqx.is_inexact_array,
+ )
+
+ def update_weights(
+ self: Self,
+ iteration_nb: int,
+ loss_terms: PyTree,
+ stored_loss_terms: PyTree,
+ grad_terms: PyTree,
+ key: Key,
+ ) -> Self:
+ """
+ Update the loss weights according to a predefined scheme
+ """
+ if self.update_weight_method == "soft_adapt":
+ new_weights = soft_adapt(
+ self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
+ )
+ elif self.update_weight_method == "lr_annealing":
+ new_weights = lr_annealing(self.loss_weights, grad_terms)
+ elif self.update_weight_method == "ReLoBRaLo":
+ new_weights = ReLoBRaLo(
+ self.loss_weights, iteration_nb, loss_terms, stored_loss_terms, key
+ )
+ else:
+ raise ValueError("update_weight_method for loss weights not implemented")
+
+ # Below we update the non None entry in the PyTree self.loss_weights
+ # we directly get the non None entries because None is not treated as a
+ # leaf
+ return eqx.tree_at(
+ lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
+ )
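update_weights only dispatches; the actual schemes live in _loss_weight_updates and are not shown in this diff. As a rough illustration of what a SoftAdapt-style rule computes (a generic sketch, not the jinns implementation), weights are pushed towards the terms whose loss decreases the least:

import jax
import jax.numpy as jnp

def soft_adapt_like(previous_terms, current_terms, temperature=0.1):
    # relative progress of each loss term since the previous iteration;
    # a larger ratio means the term stalled or got worse
    ratios = current_terms / (previous_terms + 1e-12)
    # softmax turns the ratios into weights that sum to one
    return jax.nn.softmax(temperature * ratios)

previous = jnp.array([1.0, 0.5, 0.2])
current = jnp.array([0.5, 0.48, 0.1])
print(soft_adapt_like(previous, current))  # the stalled second term gets the largest weight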
jinns/loss/_loss_components.py ADDED
@@ -0,0 +1,43 @@
+ from typing import TypeVar, Generic
+ from dataclasses import fields
+ import equinox as eqx
+
+ T = TypeVar("T")
+
+
+ class XDEComponentsAbstract(eqx.Module, Generic[T]):
+ """
+ Provides a template for ODE components with generic types.
+ One can inherit to specialize and add methods and attributes
+ We do not enforce keyword only to avoid being to verbose (this then can
+ work like a tuple)
+ """
+
+ def items(self):
+ """
+ For the dataclass to be iterated like a dictionary.
+ Practical and retrocompatible with old code when loss components were
+ dictionaries
+ """
+ return {
+ field.name: getattr(self, field.name)
+ for field in fields(self)
+ if getattr(self, field.name) is not None
+ }.items()
+
+
+ class ODEComponents(XDEComponentsAbstract[T]):
+ dyn_loss: T
+ initial_condition: T
+ observations: T
+
+
+ class PDEStatioComponents(XDEComponentsAbstract[T]):
+ dyn_loss: T
+ norm_loss: T
+ boundary_loss: T
+ observations: T
+
+
+ class PDENonStatioComponents(PDEStatioComponents[T]):
+ initial_condition: T
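These containers are lightweight, tuple-like records over the loss terms, and items() skips the unused (None) entries. A small usage sketch, assuming the module path matches the import shown earlier in this diff:

import jax.numpy as jnp
from jinns.loss._loss_components import PDEStatioComponents

# positional construction, in field order: dyn_loss, norm_loss, boundary_loss, observations
comps = PDEStatioComponents(jnp.array(0.30), None, jnp.array(0.10), jnp.array(0.05))

for name, value in comps.items():  # norm_loss is None and therefore skipped
    print(name, float(value))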