jinns 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_CubicMeshPDENonStatio.py +156 -28
- jinns/data/_CubicMeshPDEStatio.py +132 -24
- jinns/loss/_DynamicLossAbstract.py +30 -2
- jinns/loss/_LossODE.py +177 -64
- jinns/loss/_LossPDE.py +146 -68
- jinns/loss/__init__.py +4 -0
- jinns/loss/_abstract_loss.py +116 -3
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +34 -24
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +72 -16
- jinns/parameters/_params.py +8 -0
- jinns/solver/_solve.py +141 -46
- jinns/utils/_containers.py +5 -2
- jinns/utils/_types.py +12 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/METADATA +5 -2
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/RECORD +22 -20
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/WHEEL +1 -1
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
|
@@ -11,6 +11,7 @@ from dataclasses import InitVar
|
|
|
11
11
|
from typing import TYPE_CHECKING, Callable, TypedDict
|
|
12
12
|
from types import EllipsisType
|
|
13
13
|
import warnings
|
|
14
|
+
import jax
|
|
14
15
|
import jax.numpy as jnp
|
|
15
16
|
import equinox as eqx
|
|
16
17
|
from jaxtyping import Float, Array, Key, Int
|
|
@@ -20,6 +21,7 @@ from jinns.loss._loss_utils import (
|
|
|
20
21
|
normalization_loss_apply,
|
|
21
22
|
observations_loss_apply,
|
|
22
23
|
initial_condition_apply,
|
|
24
|
+
initial_condition_check,
|
|
23
25
|
)
|
|
24
26
|
from jinns.parameters._params import (
|
|
25
27
|
_get_vmap_in_axes_params,
|
|
@@ -31,16 +33,17 @@ from jinns.parameters._derivative_keys import (
|
|
|
31
33
|
DerivativeKeysPDENonStatio,
|
|
32
34
|
)
|
|
33
35
|
from jinns.loss._abstract_loss import AbstractLoss
|
|
36
|
+
from jinns.loss._loss_components import PDEStatioComponents, PDENonStatioComponents
|
|
34
37
|
from jinns.loss._loss_weights import (
|
|
35
38
|
LossWeightsPDEStatio,
|
|
36
39
|
LossWeightsPDENonStatio,
|
|
37
40
|
)
|
|
38
41
|
from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
42
|
+
from jinns.parameters._params import Params
|
|
39
43
|
|
|
40
44
|
|
|
41
45
|
if TYPE_CHECKING:
|
|
42
46
|
# imports for type hints only
|
|
43
|
-
from jinns.parameters._params import Params
|
|
44
47
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
45
48
|
from jinns.loss import PDENonStatio, PDEStatio
|
|
46
49
|
from jinns.utils._types import BoundaryConditionFun
|
|
@@ -71,7 +74,10 @@ class _LossPDEAbstract(AbstractLoss):
|
|
|
71
74
|
The loss weights for the differents term : dynamic loss,
|
|
72
75
|
initial condition (if LossWeightsPDENonStatio), boundary conditions if
|
|
73
76
|
any, normalization loss if any and observations if any.
|
|
74
|
-
|
|
77
|
+
Can be updated according to a specific algorithm. See
|
|
78
|
+
`update_weight_method`
|
|
79
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
80
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
75
81
|
derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
|
|
76
82
|
Specify which field of `params` should be differentiated for each
|
|
77
83
|
composant of the total loss. Particularily useful for inverse problems.
|
|
@@ -336,7 +342,10 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
336
342
|
The loss weights for the differents term : dynamic loss,
|
|
337
343
|
boundary conditions if any, normalization loss if any and
|
|
338
344
|
observations if any.
|
|
339
|
-
|
|
345
|
+
Can be updated according to a specific algorithm. See
|
|
346
|
+
`update_weight_method`
|
|
347
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
348
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
340
349
|
derivative_keys : DerivativeKeysPDEStatio, default=None
|
|
341
350
|
Specify which field of `params` should be differentiated for each
|
|
342
351
|
composant of the total loss. Particularily useful for inverse problems.
|
|
@@ -432,12 +441,13 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
432
441
|
def __call__(self, *args, **kwargs):
|
|
433
442
|
return self.evaluate(*args, **kwargs)
|
|
434
443
|
|
|
435
|
-
def
|
|
444
|
+
def evaluate_by_terms(
|
|
436
445
|
self, params: Params[Array], batch: PDEStatioBatch
|
|
437
|
-
) -> tuple[
|
|
446
|
+
) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
|
|
438
447
|
"""
|
|
439
448
|
Evaluate the loss function at a batch of points for given parameters.
|
|
440
449
|
|
|
450
|
+
We retrieve two PyTrees with loss values and gradients for each term
|
|
441
451
|
|
|
442
452
|
Parameters
|
|
443
453
|
---------
|
|
@@ -461,29 +471,27 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
461
471
|
|
|
462
472
|
# dynamic part
|
|
463
473
|
if self.dynamic_loss is not None:
|
|
464
|
-
|
|
465
|
-
self.dynamic_loss.evaluate,
|
|
474
|
+
dyn_loss_fun = lambda p: dynamic_loss_apply(
|
|
475
|
+
self.dynamic_loss.evaluate, # type: ignore
|
|
466
476
|
self.u,
|
|
467
477
|
self._get_dynamic_loss_batch(batch),
|
|
468
|
-
_set_derivatives(
|
|
478
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
|
|
469
479
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
470
|
-
self.loss_weights.dyn_loss, # type: ignore
|
|
471
480
|
)
|
|
472
481
|
else:
|
|
473
|
-
|
|
482
|
+
dyn_loss_fun = None
|
|
474
483
|
|
|
475
484
|
# normalization part
|
|
476
485
|
if self.norm_samples is not None:
|
|
477
|
-
|
|
486
|
+
norm_loss_fun = lambda p: normalization_loss_apply(
|
|
478
487
|
self.u,
|
|
479
488
|
self._get_normalization_loss_batch(batch),
|
|
480
|
-
_set_derivatives(
|
|
489
|
+
_set_derivatives(p, self.derivative_keys.norm_loss), # type: ignore
|
|
481
490
|
vmap_in_axes_params,
|
|
482
491
|
self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
|
|
483
|
-
self.loss_weights.norm_loss, # type: ignore
|
|
484
492
|
)
|
|
485
493
|
else:
|
|
486
|
-
|
|
494
|
+
norm_loss_fun = None
|
|
487
495
|
|
|
488
496
|
# boundary part
|
|
489
497
|
if (
|
|
@@ -491,47 +499,84 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
491
499
|
and self.omega_boundary_dim is not None
|
|
492
500
|
and self.omega_boundary_fun is not None
|
|
493
501
|
): # pyright cannot narrow down the three None otherwise as it is class attribute
|
|
494
|
-
|
|
502
|
+
boundary_loss_fun = lambda p: boundary_condition_apply(
|
|
495
503
|
self.u,
|
|
496
504
|
batch,
|
|
497
|
-
_set_derivatives(
|
|
498
|
-
self.omega_boundary_fun,
|
|
499
|
-
self.omega_boundary_condition,
|
|
500
|
-
self.omega_boundary_dim,
|
|
501
|
-
self.loss_weights.boundary_loss, # type: ignore
|
|
505
|
+
_set_derivatives(p, self.derivative_keys.boundary_loss), # type: ignore
|
|
506
|
+
self.omega_boundary_fun, # type: ignore
|
|
507
|
+
self.omega_boundary_condition, # type: ignore
|
|
508
|
+
self.omega_boundary_dim, # type: ignore
|
|
502
509
|
)
|
|
503
510
|
else:
|
|
504
|
-
|
|
511
|
+
boundary_loss_fun = None
|
|
505
512
|
|
|
506
513
|
# Observation mse
|
|
507
514
|
if batch.obs_batch_dict is not None:
|
|
508
515
|
# update params with the batches of observed params
|
|
509
|
-
|
|
516
|
+
params_obs = _update_eq_params_dict(
|
|
517
|
+
params, batch.obs_batch_dict["eq_params"]
|
|
518
|
+
)
|
|
510
519
|
|
|
511
|
-
|
|
520
|
+
obs_loss_fun = lambda po: observations_loss_apply(
|
|
512
521
|
self.u,
|
|
513
522
|
self._get_observations_loss_batch(batch),
|
|
514
|
-
_set_derivatives(
|
|
523
|
+
_set_derivatives(po, self.derivative_keys.observations), # type: ignore
|
|
515
524
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
516
525
|
batch.obs_batch_dict["val"],
|
|
517
|
-
self.loss_weights.observations, # type: ignore
|
|
518
526
|
self.obs_slice,
|
|
519
527
|
)
|
|
520
528
|
else:
|
|
521
|
-
|
|
529
|
+
params_obs = None
|
|
530
|
+
obs_loss_fun = None
|
|
522
531
|
|
|
523
|
-
#
|
|
524
|
-
|
|
525
|
-
|
|
532
|
+
# get the unweighted mses for each loss term as well as the gradients
|
|
533
|
+
all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
|
|
534
|
+
PDEStatioComponents(
|
|
535
|
+
dyn_loss_fun, norm_loss_fun, boundary_loss_fun, obs_loss_fun
|
|
536
|
+
)
|
|
537
|
+
)
|
|
538
|
+
all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
|
|
539
|
+
params, params, params, params_obs
|
|
526
540
|
)
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
"observations": mse_observation_loss,
|
|
533
|
-
}
|
|
541
|
+
mses_grads = jax.tree.map(
|
|
542
|
+
lambda fun, params: self.get_gradients(fun, params),
|
|
543
|
+
all_funs,
|
|
544
|
+
all_params,
|
|
545
|
+
is_leaf=lambda x: x is None,
|
|
534
546
|
)
|
|
547
|
+
mses = jax.tree.map(
|
|
548
|
+
lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
549
|
+
)
|
|
550
|
+
grads = jax.tree.map(
|
|
551
|
+
lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
return mses, grads
|
|
555
|
+
|
|
556
|
+
def evaluate(
|
|
557
|
+
self, params: Params[Array], batch: PDEStatioBatch
|
|
558
|
+
) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
|
|
559
|
+
"""
|
|
560
|
+
Evaluate the loss function at a batch of points for given parameters.
|
|
561
|
+
|
|
562
|
+
We retrieve the total value itself and a PyTree with loss values for each term
|
|
563
|
+
|
|
564
|
+
Parameters
|
|
565
|
+
---------
|
|
566
|
+
params
|
|
567
|
+
Parameters at which the loss is evaluated
|
|
568
|
+
batch
|
|
569
|
+
Composed of a batch of points in the
|
|
570
|
+
domain, a batch of points in the domain
|
|
571
|
+
border and an optional additional batch of parameters (eg. for
|
|
572
|
+
metamodeling) and an optional additional batch of observed
|
|
573
|
+
inputs/outputs/parameters
|
|
574
|
+
"""
|
|
575
|
+
loss_terms, _ = self.evaluate_by_terms(params, batch)
|
|
576
|
+
|
|
577
|
+
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
578
|
+
|
|
579
|
+
return loss_val, loss_terms
|
|
535
580
|
|
|
536
581
|
|
|
537
582
|
class LossPDENonStatio(LossPDEStatio):
|
|
@@ -569,7 +614,10 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
569
614
|
The loss weights for the differents term : dynamic loss,
|
|
570
615
|
boundary conditions if any, initial condition, normalization loss if any and
|
|
571
616
|
observations if any.
|
|
572
|
-
|
|
617
|
+
Can be updated according to a specific algorithm. See
|
|
618
|
+
`update_weight_method`
|
|
619
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
620
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
573
621
|
derivative_keys : DerivativeKeysPDENonStatio, default=None
|
|
574
622
|
Specify which field of `params` should be differentiated for each
|
|
575
623
|
composant of the total loss. Particularily useful for inverse problems.
|
|
@@ -653,20 +701,12 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
653
701
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
654
702
|
)
|
|
655
703
|
# some checks for t0
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
elif self.t0.shape != (1,):
|
|
660
|
-
raise ValueError(
|
|
661
|
-
f"Wrong self.t0 input. It should be"
|
|
662
|
-
f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
|
|
663
|
-
)
|
|
664
|
-
elif isinstance(self.t0, float): # e.g. user input: 0
|
|
665
|
-
self.t0 = jnp.array([self.t0])
|
|
666
|
-
elif self.t0 is None:
|
|
667
|
-
self.t0 = jnp.array([0])
|
|
704
|
+
t0 = self.t0
|
|
705
|
+
if t0 is None:
|
|
706
|
+
t0 = jnp.array([0])
|
|
668
707
|
else:
|
|
669
|
-
|
|
708
|
+
t0 = initial_condition_check(t0, dim_size=1)
|
|
709
|
+
self.t0 = t0
|
|
670
710
|
|
|
671
711
|
# witht the variables below we avoid memory overflow since a cartesian
|
|
672
712
|
# product is taken
|
|
@@ -696,22 +736,25 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
696
736
|
def __call__(self, *args, **kwargs):
|
|
697
737
|
return self.evaluate(*args, **kwargs)
|
|
698
738
|
|
|
699
|
-
def
|
|
739
|
+
def evaluate_by_terms(
|
|
700
740
|
self, params: Params[Array], batch: PDENonStatioBatch
|
|
701
|
-
) -> tuple[
|
|
741
|
+
) -> tuple[
|
|
742
|
+
PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
|
|
743
|
+
]:
|
|
702
744
|
"""
|
|
703
745
|
Evaluate the loss function at a batch of points for given parameters.
|
|
704
746
|
|
|
747
|
+
We retrieve two PyTrees with loss values and gradients for each term
|
|
705
748
|
|
|
706
749
|
Parameters
|
|
707
750
|
---------
|
|
708
751
|
params
|
|
709
752
|
Parameters at which the loss is evaluated
|
|
710
753
|
batch
|
|
711
|
-
Composed of a batch of points in
|
|
712
|
-
|
|
713
|
-
border
|
|
714
|
-
|
|
754
|
+
Composed of a batch of points in the
|
|
755
|
+
domain, a batch of points in the domain
|
|
756
|
+
border and an optional additional batch of parameters (eg. for
|
|
757
|
+
metamodeling) and an optional additional batch of observed
|
|
715
758
|
inputs/outputs/parameters
|
|
716
759
|
"""
|
|
717
760
|
omega_batch = batch.initial_batch
|
|
@@ -728,27 +771,62 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
728
771
|
|
|
729
772
|
# For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
|
|
730
773
|
# mse_observation_loss we use the evaluate from parent class
|
|
731
|
-
|
|
774
|
+
# As well as for their gradients
|
|
775
|
+
partial_mses, partial_grads = super().evaluate_by_terms(params, batch) # type: ignore
|
|
732
776
|
# ignore because batch is not PDEStatioBatch. We could use typing.cast though
|
|
733
777
|
|
|
734
778
|
# initial condition
|
|
735
779
|
if self.initial_condition_fun is not None:
|
|
736
|
-
|
|
780
|
+
mse_initial_condition_fun = lambda p: initial_condition_apply(
|
|
737
781
|
self.u,
|
|
738
782
|
omega_batch,
|
|
739
|
-
_set_derivatives(
|
|
783
|
+
_set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
|
|
740
784
|
(0,) + vmap_in_axes_params,
|
|
741
|
-
self.initial_condition_fun,
|
|
785
|
+
self.initial_condition_fun, # type: ignore
|
|
742
786
|
self.t0, # type: ignore can't get the narrowing in __post_init__
|
|
743
|
-
|
|
787
|
+
)
|
|
788
|
+
mse_initial_condition, grad_initial_condition = self.get_gradients(
|
|
789
|
+
mse_initial_condition_fun, params
|
|
744
790
|
)
|
|
745
791
|
else:
|
|
746
|
-
mse_initial_condition =
|
|
792
|
+
mse_initial_condition = None
|
|
793
|
+
grad_initial_condition = None
|
|
794
|
+
|
|
795
|
+
mses = PDENonStatioComponents(
|
|
796
|
+
partial_mses.dyn_loss,
|
|
797
|
+
partial_mses.norm_loss,
|
|
798
|
+
partial_mses.boundary_loss,
|
|
799
|
+
partial_mses.observations,
|
|
800
|
+
mse_initial_condition,
|
|
801
|
+
)
|
|
802
|
+
|
|
803
|
+
grads = PDENonStatioComponents(
|
|
804
|
+
partial_grads.dyn_loss,
|
|
805
|
+
partial_grads.norm_loss,
|
|
806
|
+
partial_grads.boundary_loss,
|
|
807
|
+
partial_grads.observations,
|
|
808
|
+
grad_initial_condition,
|
|
809
|
+
)
|
|
810
|
+
|
|
811
|
+
return mses, grads
|
|
812
|
+
|
|
813
|
+
def evaluate(
|
|
814
|
+
self, params: Params[Array], batch: PDENonStatioBatch
|
|
815
|
+
) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
|
|
816
|
+
"""
|
|
817
|
+
Evaluate the loss function at a batch of points for given parameters.
|
|
818
|
+
We retrieve the total value itself and a PyTree with loss values for each term
|
|
747
819
|
|
|
748
|
-
# total loss
|
|
749
|
-
total_loss = partial_mse + mse_initial_condition
|
|
750
820
|
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
821
|
+
Parameters
|
|
822
|
+
---------
|
|
823
|
+
params
|
|
824
|
+
Parameters at which the loss is evaluated
|
|
825
|
+
batch
|
|
826
|
+
Composed of a batch of points in
|
|
827
|
+
the domain, a batch of points in the domain
|
|
828
|
+
border, a batch of time points and an optional additional batch
|
|
829
|
+
of parameters (eg. for metamodeling) and an optional additional batch of observed
|
|
830
|
+
inputs/outputs/parameters
|
|
831
|
+
"""
|
|
832
|
+
return super().evaluate(params, batch) # type: ignore
|
jinns/loss/__init__.py
CHANGED
|
@@ -14,6 +14,7 @@ from ._loss_weights import (
|
|
|
14
14
|
LossWeightsPDENonStatio,
|
|
15
15
|
LossWeightsPDEStatio,
|
|
16
16
|
)
|
|
17
|
+
from ._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
|
|
17
18
|
|
|
18
19
|
from ._operators import (
|
|
19
20
|
divergence_fwd,
|
|
@@ -47,4 +48,7 @@ __all__ = [
|
|
|
47
48
|
"laplacian_rev",
|
|
48
49
|
"vectorial_laplacian_fwd",
|
|
49
50
|
"vectorial_laplacian_rev",
|
|
51
|
+
"soft_adapt",
|
|
52
|
+
"lr_annealing",
|
|
53
|
+
"ReLoBRaLo",
|
|
50
54
|
]
|
jinns/loss/_abstract_loss.py
CHANGED
|
@@ -1,15 +1,128 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import abc
|
|
2
|
-
from
|
|
4
|
+
from typing import TYPE_CHECKING, Self, Literal, Callable
|
|
5
|
+
from jaxtyping import Array, PyTree, Key
|
|
3
6
|
import equinox as eqx
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import optax
|
|
10
|
+
from jinns.loss._loss_weights import AbstractLossWeights
|
|
11
|
+
from jinns.parameters._params import Params
|
|
12
|
+
from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from jinns.utils._types import AnyLossComponents, AnyBatch
|
|
4
16
|
|
|
5
17
|
|
|
6
18
|
class AbstractLoss(eqx.Module):
|
|
7
19
|
"""
|
|
8
|
-
|
|
9
|
-
The way to go for correct type hints apparently
|
|
20
|
+
About the call:
|
|
10
21
|
https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
|
|
11
22
|
"""
|
|
12
23
|
|
|
24
|
+
loss_weights: AbstractLossWeights
|
|
25
|
+
update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
|
|
26
|
+
eqx.field(kw_only=True, default=None, static=True)
|
|
27
|
+
)
|
|
28
|
+
|
|
13
29
|
@abc.abstractmethod
|
|
14
30
|
def __call__(self, *_, **__) -> Array:
|
|
15
31
|
pass
|
|
32
|
+
|
|
33
|
+
@abc.abstractmethod
|
|
34
|
+
def evaluate_by_terms(
|
|
35
|
+
self, params: Params[Array], batch: AnyBatch
|
|
36
|
+
) -> tuple[AnyLossComponents, AnyLossComponents]:
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
def get_gradients(
|
|
40
|
+
self, fun: Callable[[Params[Array]], Array], params: Params[Array]
|
|
41
|
+
) -> tuple[Array, Array]:
|
|
42
|
+
"""
|
|
43
|
+
params already filtered with derivative keys here
|
|
44
|
+
"""
|
|
45
|
+
if fun is None:
|
|
46
|
+
return None, None
|
|
47
|
+
value_grad_loss = jax.value_and_grad(fun)
|
|
48
|
+
loss_val, grads = value_grad_loss(params)
|
|
49
|
+
return loss_val, grads
|
|
50
|
+
|
|
51
|
+
def ponderate_and_sum_loss(self, terms):
|
|
52
|
+
"""
|
|
53
|
+
Get total loss from individual loss terms and weights
|
|
54
|
+
|
|
55
|
+
tree.leaves is needed to get rid of None from non used loss terms
|
|
56
|
+
"""
|
|
57
|
+
weights = jax.tree.leaves(
|
|
58
|
+
self.loss_weights,
|
|
59
|
+
is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
|
|
60
|
+
)
|
|
61
|
+
terms = jax.tree.leaves(
|
|
62
|
+
terms, is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None
|
|
63
|
+
)
|
|
64
|
+
if len(weights) == len(terms):
|
|
65
|
+
return jnp.sum(jnp.array(weights) * jnp.array(terms))
|
|
66
|
+
else:
|
|
67
|
+
raise ValueError(
|
|
68
|
+
"The numbers of declared loss weights and "
|
|
69
|
+
"declared loss terms do not concord "
|
|
70
|
+
f" got {len(weights)} and {len(terms)}"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
def ponderate_and_sum_gradient(self, terms):
|
|
74
|
+
"""
|
|
75
|
+
Get total gradients from individual loss gradients and weights
|
|
76
|
+
for each parameter
|
|
77
|
+
|
|
78
|
+
tree.leaves is needed to get rid of None from non used loss terms
|
|
79
|
+
"""
|
|
80
|
+
weights = jax.tree.leaves(
|
|
81
|
+
self.loss_weights,
|
|
82
|
+
is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
|
|
83
|
+
)
|
|
84
|
+
grads = jax.tree.leaves(terms, is_leaf=lambda x: isinstance(x, Params))
|
|
85
|
+
# gradient terms for each individual loss for each parameter (several
|
|
86
|
+
# Params structures)
|
|
87
|
+
weights_pytree = jax.tree.map(
|
|
88
|
+
lambda w: optax.tree_utils.tree_full_like(grads[0], w), weights
|
|
89
|
+
) # We need several Params structures full of the weight scalar
|
|
90
|
+
weighted_grads = jax.tree.map(
|
|
91
|
+
lambda w, p: w * p, weights_pytree, grads, is_leaf=eqx.is_inexact_array
|
|
92
|
+
) # Now we can multiply
|
|
93
|
+
return jax.tree.map(
|
|
94
|
+
lambda *grads: jnp.sum(jnp.array(grads), axis=0),
|
|
95
|
+
*weighted_grads,
|
|
96
|
+
is_leaf=eqx.is_inexact_array,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
def update_weights(
|
|
100
|
+
self: Self,
|
|
101
|
+
iteration_nb: int,
|
|
102
|
+
loss_terms: PyTree,
|
|
103
|
+
stored_loss_terms: PyTree,
|
|
104
|
+
grad_terms: PyTree,
|
|
105
|
+
key: Key,
|
|
106
|
+
) -> Self:
|
|
107
|
+
"""
|
|
108
|
+
Update the loss weights according to a predefined scheme
|
|
109
|
+
"""
|
|
110
|
+
if self.update_weight_method == "soft_adapt":
|
|
111
|
+
new_weights = soft_adapt(
|
|
112
|
+
self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
|
|
113
|
+
)
|
|
114
|
+
elif self.update_weight_method == "lr_annealing":
|
|
115
|
+
new_weights = lr_annealing(self.loss_weights, grad_terms)
|
|
116
|
+
elif self.update_weight_method == "ReLoBRaLo":
|
|
117
|
+
new_weights = ReLoBRaLo(
|
|
118
|
+
self.loss_weights, iteration_nb, loss_terms, stored_loss_terms, key
|
|
119
|
+
)
|
|
120
|
+
else:
|
|
121
|
+
raise ValueError("update_weight_method for loss weights not implemented")
|
|
122
|
+
|
|
123
|
+
# Below we update the non None entry in the PyTree self.loss_weights
|
|
124
|
+
# we directly get the non None entries because None is not treated as a
|
|
125
|
+
# leaf
|
|
126
|
+
return eqx.tree_at(
|
|
127
|
+
lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
|
|
128
|
+
)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import TypeVar, Generic
|
|
2
|
+
from dataclasses import fields
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
|
|
5
|
+
T = TypeVar("T")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class XDEComponentsAbstract(eqx.Module, Generic[T]):
|
|
9
|
+
"""
|
|
10
|
+
Provides a template for ODE components with generic types.
|
|
11
|
+
One can inherit to specialize and add methods and attributes
|
|
12
|
+
We do not enforce keyword only to avoid being to verbose (this then can
|
|
13
|
+
work like a tuple)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
def items(self):
|
|
17
|
+
"""
|
|
18
|
+
For the dataclass to be iterated like a dictionary.
|
|
19
|
+
Practical and retrocompatible with old code when loss components were
|
|
20
|
+
dictionaries
|
|
21
|
+
"""
|
|
22
|
+
return {
|
|
23
|
+
field.name: getattr(self, field.name)
|
|
24
|
+
for field in fields(self)
|
|
25
|
+
if getattr(self, field.name) is not None
|
|
26
|
+
}.items()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ODEComponents(XDEComponentsAbstract[T]):
|
|
30
|
+
dyn_loss: T
|
|
31
|
+
initial_condition: T
|
|
32
|
+
observations: T
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PDEStatioComponents(XDEComponentsAbstract[T]):
|
|
36
|
+
dyn_loss: T
|
|
37
|
+
norm_loss: T
|
|
38
|
+
boundary_loss: T
|
|
39
|
+
observations: T
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PDENonStatioComponents(PDEStatioComponents[T]):
|
|
43
|
+
initial_condition: T
|