jinns 1.7.0-py3-none-any.whl → 1.7.1-py3-none-any.whl

jinns/loss/_LossPDE.py CHANGED
@@ -122,7 +122,7 @@ class _LossPDEAbstract(
        made broadcastable to `norm_samples`.
        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : EllipsisType | slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
        slice object specifying the begininning/ending of the PINN output
        that is observed (this is then useful for multidim PINN). Default is None.
    key : Key | None
@@ -136,7 +136,7 @@ class _LossPDEAbstract(
    # NOTE static=True only for leaf attributes that are not valid JAX types
    # (ie. jax.Array cannot be static) and that we do not expect to change
    u: eqx.AbstractVar[AbstractPINN]
-    dynamic_loss: eqx.AbstractVar[Y]
+    dynamic_loss: tuple[eqx.AbstractVar[Y] | None, ...]
    omega_boundary_fun: (
        BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
    ) = eqx.field(static=True)
@@ -146,7 +146,7 @@ class _LossPDEAbstract(
    omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
    norm_samples: Float[Array, " nb_norm_samples dimension"] | None
    norm_weights: Float[Array, " nb_norm_samples"] | None
-    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
    key: PRNGKeyArray | None

    def __init__(
@@ -159,7 +159,10 @@ class _LossPDEAbstract(
        omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
        norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
        norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
-        obs_slice: EllipsisType | slice | None = None,
+        obs_slice: tuple[EllipsisType | slice, ...]
+        | EllipsisType
+        | slice
+        | None = None,
        key: PRNGKeyArray | None = None,
        derivative_keys: DKPDE,
        **kwargs: Any,  # for arguments for super()
@@ -181,7 +184,9 @@ class _LossPDEAbstract(
        )

        if obs_slice is None:
-            self.obs_slice = jnp.s_[...]
+            self.obs_slice = (jnp.s_[...],)
+        elif not isinstance(obs_slice, tuple):
+            self.obs_slice = (obs_slice,)
        else:
            self.obs_slice = obs_slice

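The constructor change above normalizes `obs_slice` so that later code can always pair one slice with each observation dataset. Below is a minimal standalone sketch of that normalization (illustrative function name, not jinns code):

```python
import jax.numpy as jnp


def normalize_obs_slice(obs_slice):
    # mirrors the branches added above: None -> observe the full PINN output,
    # a bare slice/Ellipsis -> promote to a 1-tuple, a tuple -> keep as-is
    if obs_slice is None:
        return (jnp.s_[...],)
    if not isinstance(obs_slice, tuple):
        return (obs_slice,)
    return obs_slice


assert normalize_obs_slice(None) == (Ellipsis,)
assert normalize_obs_slice(jnp.s_[0:2]) == (slice(0, 2),)
assert normalize_obs_slice((jnp.s_[0:1], jnp.s_[1:2])) == (slice(0, 1), slice(1, 2))
```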
@@ -321,16 +326,23 @@ class _LossPDEAbstract(

    def _get_dyn_loss_fun(
        self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
-    ) -> Callable[[Params[Array]], Array] | None:
-        if self.dynamic_loss is not None:
-            dyn_loss_eval = self.dynamic_loss.evaluate
-            dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
-                lambda p: dynamic_loss_apply(
-                    dyn_loss_eval,
-                    self.u,
-                    self._get_dynamic_loss_batch(batch),
-                    _set_derivatives(p, self.derivative_keys.dyn_loss),
-                    self.vmap_in_axes + vmap_in_axes_params,
+    ) -> tuple[Callable[[Params[Array]], Array], ...] | None:
+        # Note, for the record, multiple dynamic losses
+        # have been introduced in MR 92
+        if self.dynamic_loss != (None,):
+            dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d: lambda p: dynamic_loss_apply(
+                        d.evaluate,
+                        self.u,
+                        self._get_dynamic_loss_batch(batch),
+                        _set_derivatives(p, self.derivative_keys.dyn_loss),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                    ),
+                    self.dynamic_loss,
+                    is_leaf=lambda x: isinstance(
+                        x, (PDEStatio, PDENonStatio)
+                    ),  # do not traverse further than first level
                )
            )
        else:
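The rewritten `_get_dyn_loss_fun` builds one residual closure per dynamic loss by mapping over the tuple and stopping traversal at the operator level. A self-contained sketch of that pattern with plain Python stand-ins (in jinns the `is_leaf` guard matters because the operators are equinox modules, i.e. pytrees themselves):

```python
import jax


class Op:
    # stand-in for a PDEStatio / PDENonStatio dynamic loss
    def __init__(self, scale):
        self.scale = scale

    def evaluate(self, x):
        return self.scale * x


ops = (Op(1.0), Op(2.0))

# one closure per operator; is_leaf stops jax.tree.map at the operator level
residual_funs = jax.tree.map(
    lambda d: (lambda x: d.evaluate(x) ** 2),
    ops,
    is_leaf=lambda x: isinstance(x, Op),
)
print(tuple(f(3.0) for f in residual_funs))  # (9.0, 36.0)
```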
@@ -340,9 +352,13 @@ class _LossPDEAbstract(

    def _get_norm_loss_fun(
        self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
-    ) -> Callable[[Params[Array]], Array] | None:
+    ) -> tuple[Callable[[Params[Array]], Array], ...] | None:
+        # Note that since MR 92
+        # norm_loss_fun is formed as a tuple for
+        # consistency with dynamic and observation
+        # losses and more modularity for later
        if self.norm_samples is not None:
-            norm_loss_fun: Callable[[Params[Array]], Array] | None = (
+            norm_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
                lambda p: normalization_loss_apply(
                    self.u,
                    cast(
@@ -351,7 +367,7 @@ class _LossPDEAbstract(
                    _set_derivatives(p, self.derivative_keys.norm_loss),
                    vmap_in_axes_params,
                    self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
-                )
+                ),
            )
        else:
            norm_loss_fun = None
@@ -359,20 +375,24 @@ class _LossPDEAbstract(

    def _get_boundary_loss_fun(
        self, batch: B
-    ) -> Callable[[Params[Array]], Array] | None:
+    ) -> tuple[Callable[[Params[Array]], Array], ...] | None:
+        # Note that since MR 92
+        # boundary_loss_fun is formed as a tuple for
+        # consistency with dynamic and observation
+        # losses and more modularity for later
        if (
            self.omega_boundary_condition is not None
            and self.omega_boundary_fun is not None
        ):
-            boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
+            boundary_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
                lambda p: boundary_condition_apply(
                    self.u,
                    batch,
                    _set_derivatives(p, self.derivative_keys.boundary_loss),
                    self.omega_boundary_fun,  # type: ignore (we are in lambda)
                    self.omega_boundary_condition,  # type: ignore
-                    self.omega_boundary_dim,  # type: ignore
-                )
+                    self.omega_boundary_dim,
+                ),
            )
        else:
            boundary_loss_fun = None
@@ -384,24 +404,37 @@ class _LossPDEAbstract(
        batch: B,
        vmap_in_axes_params: tuple[Params[int | None] | None],
        params: Params[Array],
-    ) -> tuple[Params[Array] | None, Callable[[Params[Array]], Array] | None]:
+    ) -> tuple[
+        Params[Array] | None, tuple[Callable[[Params[Array]], Array], ...] | None
+    ]:
+        # Note, for the record, multiple DGObs
+        # (leading to batch.obs_batch_dict being tuple | None)
+        # have been introduced in MR 92
        if batch.obs_batch_dict is not None:
-            # update params with the batches of observed params
-            params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
-
-            pinn_in, val = (
-                batch.obs_batch_dict["pinn_in"],
-                batch.obs_batch_dict["val"],
+            if len(batch.obs_batch_dict) != len(self.obs_slice):
+                raise ValueError(
+                    "There must be the same number of "
+                    "observation datasets as the number of "
+                    "obs_slice"
+                )
+            params_obs = jax.tree.map(
+                lambda d: update_eq_params(params, d["eq_params"]),
+                batch.obs_batch_dict,
+                is_leaf=lambda x: isinstance(x, dict),
            )
-
-            obs_loss_fun: Callable[[Params[Array]], Array] | None = (
-                lambda po: observations_loss_apply(
-                    self.u,
-                    pinn_in,
-                    _set_derivatives(po, self.derivative_keys.observations),
-                    self.vmap_in_axes + vmap_in_axes_params,
-                    val,
+            obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d, slice_: lambda po: observations_loss_apply(
+                        self.u,
+                        d["pinn_in"],
+                        _set_derivatives(po, self.derivative_keys.observations),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                        d["val"],
+                        slice_,
+                    ),
+                    batch.obs_batch_dict,
                    self.obs_slice,
+                    is_leaf=lambda x: isinstance(x, dict),
                )
            )
        else:
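With several observation datasets, the two-tree `jax.tree.map` above pairs each dataset dict positionally with its own entry of `obs_slice`, which is why the lengths must match. A standalone sketch of that pairing with made-up data:

```python
import jax

obs_batch_dicts = ({"val": 1.0}, {"val": 2.0})  # two observation datasets
obs_slices = (slice(0, 1), slice(1, 2))         # one slice per dataset

paired = jax.tree.map(
    lambda d, s: (d["val"], s),
    obs_batch_dicts,
    obs_slices,
    is_leaf=lambda x: isinstance(x, dict),  # treat each dataset dict as a single leaf
)
print(paired)  # ((1.0, slice(0, 1, None)), (2.0, slice(1, 2, None)))
```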
@@ -434,7 +467,7 @@ class LossPDEStatio(
    ----------
    u : AbstractPINN
        the PINN
-    dynamic_loss : PDEStatio | None
+    dynamic_loss : tuple[PDEStatio, ...] | PDEStatio | None
        the stationary PDE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](x)$. Should implement a method
        `dynamic_loss.evaluate(x, u, params)`.
@@ -497,7 +530,7 @@ class LossPDEStatio(
        Alternatively, the user can pass a float or an integer.
        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
        slice object specifying the begininning/ending of the PINN output
        that is observed (this is then useful for multidim PINN). Default is None.

@@ -512,7 +545,7 @@ class LossPDEStatio(
    # (ie. jax.Array cannot be static) and that we do not expect to change

    u: AbstractPINN
-    dynamic_loss: PDEStatio | None
+    dynamic_loss: tuple[PDEStatio | None, ...]
    loss_weights: LossWeightsPDEStatio
    derivative_keys: DerivativeKeysPDEStatio

@@ -522,7 +555,7 @@ class LossPDEStatio(
        self,
        *,
        u: AbstractPINN,
-        dynamic_loss: PDEStatio | None,
+        dynamic_loss: tuple[PDEStatio, ...] | PDEStatio | None,
        loss_weights: LossWeightsPDEStatio | None = None,
        derivative_keys: DerivativeKeysPDEStatio | None = None,
        params: Params[Array] | None = None,
@@ -543,15 +576,16 @@ class LossPDEStatio(
                    "Problem at derivative_keys initialization "
                    f"received {derivative_keys=} and {params=}"
                ) from exc
-        else:
-            derivative_keys = derivative_keys

        super().__init__(
            derivative_keys=derivative_keys,
            vmap_in_axes=(0,),
            **kwargs,
        )
-        self.dynamic_loss = dynamic_loss
+        if not isinstance(dynamic_loss, tuple):
+            self.dynamic_loss = (dynamic_loss,)
+        else:
+            self.dynamic_loss = dynamic_loss

    def _get_dynamic_loss_batch(
        self, batch: PDEStatioBatch
@@ -619,29 +653,44 @@ class LossPDEStatio(
        )

        # get the unweighted mses for each loss term as well as the gradients
-        all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
+        all_funs: PDEStatioComponents[
+            tuple[Callable[[Params[Array]], Array], ...] | None
+        ] = PDEStatioComponents(
+            dyn_loss_fun,
+            norm_loss_fun,
+            boundary_loss_fun,
+            obs_loss_fun,
+        )
+        all_params: PDEStatioComponents[tuple[Params[Array], ...] | None] = (
            PDEStatioComponents(
-                dyn_loss_fun, norm_loss_fun, boundary_loss_fun, obs_loss_fun
+                jax.tree.map(lambda l: params, dyn_loss_fun),
+                jax.tree.map(lambda l: params, norm_loss_fun),
+                jax.tree.map(lambda l: params, boundary_loss_fun),
+                params_obs,
            )
        )
-        all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
-            params, params, params, params_obs
-        )
        mses_grads = jax.tree.map(
            self.get_gradients,
            all_funs,
            all_params,
            is_leaf=lambda x: x is None,
        )
+        # NOTE the is_leaf below is more complex since it must pass possible the tuple
+        # of dyn_loss and then stops (but also account it should not stop when
+        # the tuple of dyn_loss is of length 2)
        mses = jax.tree.map(
-            lambda leaf: leaf[0],  # type: ignore
+            lambda leaf: leaf[0],
            mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
        )
        grads = jax.tree.map(
-            lambda leaf: leaf[1],  # type: ignore
+            lambda leaf: leaf[1],
            mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
        )

        return mses, grads
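Because each component may now hold a tuple of loss functions, `mses_grads` contains `(mse, grads)` pairs nested inside outer tuples, so a bare `isinstance(x, tuple)` check would stop too early. A standalone sketch of the stricter predicate, with a stand-in for jinns' `Params` container:

```python
import jax
import jax.numpy as jnp


class FakeParams(dict):
    # stand-in for jinns' Params pytree, only used to tag the grads slot
    pass


pair = lambda v: (jnp.array(v), FakeParams(nn_params=jnp.zeros(2)))  # (mse, grads)
mses_grads = {
    "dyn_loss": (pair(1.0), pair(2.0)),  # outer tuple: one pair per dynamic loss
    "observations": (pair(3.0),),
}

is_mse_grad_pair = (
    lambda x: isinstance(x, tuple) and len(x) == 2 and isinstance(x[1], FakeParams)
)
# recurses into the outer tuples but stops at each (mse, grads) pair
mses = jax.tree.map(lambda leaf: leaf[0], mses_grads, is_leaf=is_mse_grad_pair)
print(mses)  # first elements only: 1.0 and 2.0 for dyn_loss, 3.0 for observations
```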
@@ -673,7 +722,7 @@ class LossPDENonStatio(
    ----------
    u : AbstractPINN
        the PINN
-    dynamic_loss : PDENonStatio
+    dynamic_loss : tuple[PDENonStatio, ...] | PDENonStatio | None
        the non stationary PDE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](t, x)$. Should implement a method
        `dynamic_loss.evaluate(t, x, u, params)`.
@@ -750,14 +799,13 @@ class LossPDENonStatio(
        Alternatively, the user can pass a float or an integer.
        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
        slice object specifying the begininning/ending of the PINN output
        that is observed (this is then useful for multidim PINN). Default is None.
-
    """

    u: AbstractPINN
-    dynamic_loss: PDENonStatio | None
+    dynamic_loss: tuple[PDENonStatio | None, ...]
    loss_weights: LossWeightsPDENonStatio
    derivative_keys: DerivativeKeysPDENonStatio
    params: InitVar[Params[Array] | None]
@@ -774,7 +822,7 @@ class LossPDENonStatio(
        self,
        *,
        u: AbstractPINN,
-        dynamic_loss: PDENonStatio | None,
+        dynamic_loss: tuple[PDENonStatio, ...] | PDENonStatio | None,
        loss_weights: LossWeightsPDENonStatio | None = None,
        derivative_keys: DerivativeKeysPDENonStatio | None = None,
        initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
@@ -800,8 +848,6 @@ class LossPDENonStatio(
                    "Problem at derivative_keys initialization "
                    f"received {derivative_keys=} and {params=}"
                ) from exc
-        else:
-            derivative_keys = derivative_keys

        super().__init__(
            derivative_keys=derivative_keys,
@@ -809,7 +855,10 @@ class LossPDENonStatio(
            **kwargs,
        )

-        self.dynamic_loss = dynamic_loss
+        if not isinstance(dynamic_loss, tuple):
+            self.dynamic_loss = (dynamic_loss,)
+        else:
+            self.dynamic_loss = dynamic_loss

        if initial_condition_fun is None:
            warnings.warn(
@@ -904,7 +953,13 @@ class LossPDENonStatio(

        # initial condition
        if self.initial_condition_fun is not None:
-            mse_initial_condition_fun: Callable[[Params[Array]], Array] | None = (
+            # Note that since MR 92
+            # initial_condition_fun is formed as a tuple for
+            # consistency with dynamic and observation
+            # losses and more modularity for later
+            initial_condition_fun: (
+                tuple[Callable[[Params[Array]], Array], ...] | None
+            ) = (
                lambda p: initial_condition_apply(
                    self.u,
                    omega_initial_batch,
@@ -912,39 +967,52 @@ class LossPDENonStatio(
                    (0,) + vmap_in_axes_params,
                    self.initial_condition_fun,  # type: ignore
                    self.t0,
-                )
+                ),
            )
        else:
-            mse_initial_condition_fun = None
+            initial_condition_fun = None

        # get the unweighted mses for each loss term as well as the gradients
-        all_funs: PDENonStatioComponents[Callable[[Params[Array]], Array] | None] = (
+        all_funs: PDENonStatioComponents[
+            tuple[Callable[[Params[Array]], Array], ...] | None
+        ] = PDENonStatioComponents(
+            dyn_loss_fun,
+            norm_loss_fun,
+            boundary_loss_fun,
+            obs_loss_fun,
+            initial_condition_fun,
+        )
+        all_params: PDENonStatioComponents[tuple[Params[Array], ...] | None] = (
            PDENonStatioComponents(
-                dyn_loss_fun,
-                norm_loss_fun,
-                boundary_loss_fun,
-                obs_loss_fun,
-                mse_initial_condition_fun,
+                jax.tree.map(lambda l: params, dyn_loss_fun),
+                jax.tree.map(lambda l: params, norm_loss_fun),
+                jax.tree.map(lambda l: params, boundary_loss_fun),
+                params_obs,
+                jax.tree.map(lambda l: params, initial_condition_fun),
            )
        )
-        all_params: PDENonStatioComponents[Params[Array] | None] = (
-            PDENonStatioComponents(params, params, params, params_obs, params)
-        )
        mses_grads = jax.tree.map(
            self.get_gradients,
            all_funs,
            all_params,
            is_leaf=lambda x: x is None,
        )
+        # NOTE the is_leaf below is more complex since it must pass possible the tuple
+        # of dyn_loss and then stops (but also account it should not stop when
+        # the tuple of dyn_loss is of length 2)
        mses = jax.tree.map(
-            lambda leaf: leaf[0],  # type: ignore
+            lambda leaf: leaf[0],
            mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
        )
        grads = jax.tree.map(
-            lambda leaf: leaf[1],  # type: ignore
+            lambda leaf: leaf[1],
            mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
        )

        return mses, grads
@@ -1,14 +1,21 @@
from __future__ import annotations

import abc
-from typing import Self, Literal, Callable, TypeVar, Generic, Any
-from jaxtyping import PRNGKeyArray, Array, PyTree, Float
+import warnings
+from typing import Self, Literal, Callable, TypeVar, Generic, Any, get_args
+from dataclasses import InitVar
+from jaxtyping import Array, PyTree, Float, PRNGKeyArray
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from jinns.parameters._params import Params
-from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
+from jinns.loss._loss_weight_updates import (
+    soft_adapt,
+    lr_annealing,
+    ReLoBRaLo,
+    prior_loss,
+)
from jinns.utils._types import (
    AnyLossComponents,
    AnyBatch,
@@ -38,6 +45,11 @@ DK = TypeVar("DK", bound=AnyDerivativeKeys)
# the return types of evaluate_by_terms for example!


+AvailableUpdateWeightMethods = Literal[
+    "softadapt", "soft_adapt", "prior_loss", "lr_annealing", "ReLoBRaLo"
+]
+
+
class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
    """
    About the call:
@@ -46,10 +58,43 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):

    derivative_keys: eqx.AbstractVar[DK]
    loss_weights: eqx.AbstractVar[L]
-    update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
+    loss_weight_scales: L = eqx.field(init=False)
+    update_weight_method: AvailableUpdateWeightMethods | None = eqx.field(
+        kw_only=True, default=None, static=True
    )
    vmap_in_axes: tuple[int] = eqx.field(static=True)
+    keep_initial_loss_weight_scales: InitVar[bool] = eqx.field(
+        default=True, kw_only=True
+    )
+
+    def __init__(
+        self,
+        *,
+        loss_weights,
+        derivative_keys,
+        vmap_in_axes,
+        update_weight_method=None,
+        keep_initial_loss_weight_scales: bool = False,
+    ):
+        if update_weight_method is not None and update_weight_method not in get_args(
+            AvailableUpdateWeightMethods
+        ):
+            raise ValueError(f"{update_weight_method=} is not a valid method")
+        self.update_weight_method = update_weight_method
+        self.loss_weights = loss_weights
+        self.derivative_keys = derivative_keys
+        self.vmap_in_axes = vmap_in_axes
+        if keep_initial_loss_weight_scales:
+            self.loss_weight_scales = self.loss_weights
+            if self.update_weight_method is not None:
+                warnings.warn(
+                    "Loss weights out from update_weight_method will still be"
+                    " multiplied by the initial input loss_weights"
+                )
+        else:
+            self.loss_weight_scales = optax.tree_utils.tree_ones_like(self.loss_weights)
+        # self.loss_weight_scales will contain None where self.loss_weights
+        # has None

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.evaluate(*args, **kwargs)
@@ -127,7 +172,11 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
            raise ValueError(
                "The numbers of declared loss weights and "
                "declared loss terms do not concord "
-                f" got {len(weights)} and {len(terms_list)}"
+                f" got {len(weights)} and {len(terms_list)}. "
+                "If you passed tuple of dyn_loss, make sure to pass "
+                "tuple of loss weights at LossWeights.dyn_loss."
+                "If you passed tuple of obs datasets, make sure to pass "
+                "tuple of loss weights at LossWeights.observations."
            )

    def ponderate_and_sum_gradient(self, terms: C) -> Params[Array | None]:
@@ -171,6 +220,8 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
            new_weights = soft_adapt(
                self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
            )
+        elif self.update_weight_method == "prior_loss":
+            new_weights = prior_loss(self.loss_weights, iteration_nb, stored_loss_terms)
        elif self.update_weight_method == "lr_annealing":
            new_weights = lr_annealing(self.loss_weights, grad_terms)
        elif self.update_weight_method == "ReLoBRaLo":
@@ -183,6 +234,13 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
        # Below we update the non None entry in the PyTree self.loss_weights
        # we directly get the non None entries because None is not treated as a
        # leaf
+
+        new_weights = jax.lax.cond(
+            iteration_nb == 0,
+            lambda nw: nw,
+            lambda nw: jnp.array(jax.tree.leaves(self.loss_weight_scales)) * nw,
+            new_weights,
+        )
        return eqx.tree_at(
            lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
        )
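The `jax.lax.cond` added above multiplies the adaptively updated weights by the stored `loss_weight_scales` (the user's initial weights when `keep_initial_loss_weight_scales` is set), except at iteration 0. A standalone numerical sketch of that step with made-up values:

```python
import jax
import jax.numpy as jnp

loss_weight_scales = {"dyn_loss": 10.0, "observations": 1.0}  # initial user weights
adaptive_weights = jnp.array([0.2, 0.8])  # e.g. output of soft_adapt / prior_loss


def rescale(iteration_nb, new_weights):
    return jax.lax.cond(
        iteration_nb == 0,
        lambda nw: nw,  # first iteration: keep the adaptive output as-is
        lambda nw: jnp.array(jax.tree.leaves(loss_weight_scales)) * nw,
        new_weights,
    )


print(rescale(0, adaptive_weights))  # [0.2 0.8]
print(rescale(5, adaptive_weights))  # ~[2.0 0.8]
```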
@@ -227,7 +227,7 @@ def boundary_neumann(
    if isinstance(u, PINN):
        u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])

-        if u.eq_type == "statio_PDE":
+        if u.eq_type == "PDEStatio":
            v_neumann = vmap(
                lambda inputs, params: _subtract_with_check(
                    f(inputs),
@@ -240,7 +240,7 @@ def boundary_neumann(
                vmap_in_axes,
                0,
            )
-        elif u.eq_type == "nonstatio_PDE":
+        elif u.eq_type == "PDENonStatio":
            v_neumann = vmap(
                lambda inputs, params: _subtract_with_check(
                    f(inputs),
@@ -274,14 +274,14 @@ def boundary_neumann(
        if (batch_array.shape[0] == 1 and isinstance(batch, PDEStatioBatch)) or (
            batch_array.shape[-1] == 2 and isinstance(batch, PDENonStatioBatch)
        ):
-            if u.eq_type == "statio_PDE":
+            if u.eq_type == "PDEStatio":
                _, du_dx = jax.jvp(
                    lambda inputs: u(inputs, params)[..., dim_to_apply],
                    (batch_array,),
                    (jnp.ones_like(batch_array),),
                )
                values = du_dx * n[facet]
-            if u.eq_type == "nonstatio_PDE":
+            if u.eq_type == "PDENonStatio":
                _, du_dx = jax.jvp(
                    lambda inputs: u(inputs, params)[..., dim_to_apply],
                    (batch_array,),
@@ -291,7 +291,7 @@ def boundary_neumann(
        elif (batch_array.shape[-1] == 2 and isinstance(batch, PDEStatioBatch)) or (
            batch_array.shape[-1] == 3 and isinstance(batch, PDENonStatioBatch)
        ):
-            if u.eq_type == "statio_PDE":
+            if u.eq_type == "PDEStatio":
                tangent_vec_0 = jnp.repeat(
                    jnp.array([1.0, 0.0])[None], batch_array.shape[0], axis=0
                )
@@ -309,7 +309,7 @@ def boundary_neumann(
                    (tangent_vec_1,),
                )
                values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet]  # dot product
-            if u.eq_type == "nonstatio_PDE":
+            if u.eq_type == "PDENonStatio":
                tangent_vec_0 = jnp.repeat(
                    jnp.array([0.0, 1.0, 0.0])[None], batch_array.shape[0], axis=0
                )
jinns/loss/_loss_utils.py CHANGED
@@ -26,12 +26,12 @@ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
from jinns.parameters._params import Params

if TYPE_CHECKING:
-    from jinns.utils._types import BoundaryConditionFun
+    from jinns.utils._types import BoundaryConditionFun, AnyBatch
    from jinns.nn._abstract_pinn import AbstractPINN


def dynamic_loss_apply(
-    dyn_loss: Callable,
+    dyn_loss: Callable[[AnyBatch, AbstractPINN, Params[Array]], Array],
    u: AbstractPINN,
    batch: (
        Float[Array, " batch_size 1"]
@@ -13,6 +13,36 @@ if TYPE_CHECKING:
    from jinns.utils._types import AnyLossComponents, AnyLossWeights


+def prior_loss(
+    loss_weights: AnyLossWeights,
+    iteration_nb: int,
+    stored_loss_terms: AnyLossComponents,
+) -> Array:
+    """
+    Simple adaptative weights according to the prior loss idea:
+    the ponderation in front of a loss term is given by the inverse of the
+    value of that loss term at the previous iteration
+    """
+
+    def do_nothing(loss_weights, _):
+        return jnp.array(
+            jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+        )
+
+    def _prior_loss(_, stored_loss_terms):
+        new_weights = jax.tree.map(
+            lambda slt: 1 / (slt[iteration_nb - 1] + 1e-6), stored_loss_terms
+        )
+        return jnp.array(jax.tree.leaves(new_weights), dtype=float)
+
+    return jax.lax.cond(
+        iteration_nb == 0,
+        lambda op: do_nothing(*op),
+        lambda op: _prior_loss(*op),
+        (loss_weights, stored_loss_terms),
+    )
+
+
def soft_adapt(
    loss_weights: AnyLossWeights,
    iteration_nb: int,
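`prior_loss` weights each term by the inverse of its value at the previous iteration (plus a small epsilon), so terms that are currently large are damped. A standalone numerical sketch of that rule with a made-up loss history:

```python
import jax.numpy as jnp

stored_loss_terms = {  # per-term history, one value per past iteration
    "dyn_loss": jnp.array([4.0, 2.0]),
    "observations": jnp.array([0.5, 0.25]),
}


def prior_loss_weights(iteration_nb, eps=1e-6):
    # weight_k = 1 / (loss_k at the previous iteration + eps)
    return {
        name: 1.0 / (history[iteration_nb - 1] + eps)
        for name, history in stored_loss_terms.items()
    }


print(prior_loss_weights(2))  # {'dyn_loss': ~0.5, 'observations': ~4.0}
```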
@@ -18,6 +18,10 @@ from jinns.loss._loss_components import (
def lw_converter(x: Array | None) -> Array | None:
    if x is None:
        return x
+    elif isinstance(x, tuple):
+        # user might input tuple of scalar loss weights to account for cases
+        # when dyn loss is also a tuple of (possibly 1D) dyn_loss
+        return tuple(jnp.asarray(x_) for x_ in x)
    else:
        return jnp.asarray(x)
