jinns 1.3.0-py3-none-any.whl → 1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/__init__.py CHANGED
@@ -1,22 +1,20 @@
  from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
- from ._LossODE import LossODE, SystemLossODE
- from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
+ from ._LossODE import LossODE
+ from ._LossPDE import LossPDEStatio, LossPDENonStatio
  from ._DynamicLoss import (
      GeneralizedLotkaVolterra,
      BurgersEquation,
      FPENonStatioLoss2D,
      OU_FPENonStatioLoss2D,
      FisherKPP,
-     MassConservation2DStatio,
-     NavierStokes2DStatio,
+     NavierStokesMassConservation2DStatio,
  )
  from ._loss_weights import (
      LossWeightsODE,
-     LossWeightsODEDict,
      LossWeightsPDENonStatio,
      LossWeightsPDEStatio,
-     LossWeightsPDEDict,
  )
+ from ._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo

  from ._operators import (
      divergence_fwd,
@@ -26,3 +24,31 @@ from ._operators import (
      vectorial_laplacian_fwd,
      vectorial_laplacian_rev,
  )
+
+ __all__ = [
+     "DynamicLoss",
+     "ODE",
+     "PDEStatio",
+     "PDENonStatio",
+     "LossODE",
+     "LossPDEStatio",
+     "LossPDENonStatio",
+     "GeneralizedLotkaVolterra",
+     "BurgersEquation",
+     "FPENonStatioLoss2D",
+     "OU_FPENonStatioLoss2D",
+     "FisherKPP",
+     "NavierStokesMassConservation2DStatio",
+     "LossWeightsODE",
+     "LossWeightsPDEStatio",
+     "LossWeightsPDENonStatio",
+     "divergence_fwd",
+     "divergence_rev",
+     "laplacian_fwd",
+     "laplacian_rev",
+     "vectorial_laplacian_fwd",
+     "vectorial_laplacian_rev",
+     "soft_adapt",
+     "lr_annealing",
+     "ReLoBRaLo",
+ ]
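
Taken together, these two hunks tighten the public surface of `jinns.loss`: `SystemLossODE`, `SystemLossPDE`, `LossWeightsODEDict` and `LossWeightsPDEDict` are no longer re-exported, the two names `MassConservation2DStatio` and `NavierStokes2DStatio` are replaced by the single `NavierStokesMassConservation2DStatio`, the loss-weight update schemes `soft_adapt`, `lr_annealing` and `ReLoBRaLo` become importable, and an explicit `__all__` is introduced. A minimal, hedged sketch of what downstream imports look like after the upgrade (illustrative only, assuming jinns 1.5.0 is installed; not taken from the jinns documentation):

# 1.3.0-era imports that are no longer re-exported here:
# from jinns.loss import SystemLossODE, SystemLossPDE
# from jinns.loss import MassConservation2DStatio, NavierStokes2DStatio

# names exposed by the new __all__ in 1.5.0:
from jinns.loss import (
    LossODE,
    LossPDEStatio,
    LossPDENonStatio,
    NavierStokesMassConservation2DStatio,  # replaces the two names removed above
    soft_adapt,
    lr_annealing,
    ReLoBRaLo,
)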
jinns/loss/_abstract_loss.py ADDED
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+
+ import abc
+ from typing import TYPE_CHECKING, Self, Literal, Callable
+ from jaxtyping import Array, PyTree, Key
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import optax
+ from jinns.loss._loss_weights import AbstractLossWeights
+ from jinns.parameters._params import Params
+ from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
+
+ if TYPE_CHECKING:
+     from jinns.utils._types import AnyLossComponents, AnyBatch
+
+
+ class AbstractLoss(eqx.Module):
+     """
+     About the call:
+     https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
+     """
+
+     loss_weights: AbstractLossWeights
+     update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
+         eqx.field(kw_only=True, default=None, static=True)
+     )
+
+     @abc.abstractmethod
+     def __call__(self, *_, **__) -> Array:
+         pass
+
+     @abc.abstractmethod
+     def evaluate_by_terms(
+         self, params: Params[Array], batch: AnyBatch
+     ) -> tuple[AnyLossComponents, AnyLossComponents]:
+         pass
+
+     def get_gradients(
+         self, fun: Callable[[Params[Array]], Array], params: Params[Array]
+     ) -> tuple[Array, Array]:
+         """
+         params already filtered with derivative keys here
+         """
+         if fun is None:
+             return None, None
+         value_grad_loss = jax.value_and_grad(fun)
+         loss_val, grads = value_grad_loss(params)
+         return loss_val, grads
+
+     def ponderate_and_sum_loss(self, terms):
+         """
+         Get total loss from individual loss terms and weights
+
+         tree.leaves is needed to get rid of None from non used loss terms
+         """
+         weights = jax.tree.leaves(
+             self.loss_weights,
+             is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
+         )
+         terms = jax.tree.leaves(
+             terms, is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None
+         )
+         if len(weights) == len(terms):
+             return jnp.sum(jnp.array(weights) * jnp.array(terms))
+         else:
+             raise ValueError(
+                 "The numbers of declared loss weights and "
+                 "declared loss terms do not concord "
+                 f" got {len(weights)} and {len(terms)}"
+             )
+
+     def ponderate_and_sum_gradient(self, terms):
+         """
+         Get total gradients from individual loss gradients and weights
+         for each parameter
+
+         tree.leaves is needed to get rid of None from non used loss terms
+         """
+         weights = jax.tree.leaves(
+             self.loss_weights,
+             is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
+         )
+         grads = jax.tree.leaves(terms, is_leaf=lambda x: isinstance(x, Params))
+         # gradient terms for each individual loss for each parameter (several
+         # Params structures)
+         weights_pytree = jax.tree.map(
+             lambda w: optax.tree_utils.tree_full_like(grads[0], w), weights
+         ) # We need several Params structures full of the weight scalar
+         weighted_grads = jax.tree.map(
+             lambda w, p: w * p, weights_pytree, grads, is_leaf=eqx.is_inexact_array
+         ) # Now we can multiply
+         return jax.tree.map(
+             lambda *grads: jnp.sum(jnp.array(grads), axis=0),
+             *weighted_grads,
+             is_leaf=eqx.is_inexact_array,
+         )
+
+     def update_weights(
+         self: Self,
+         iteration_nb: int,
+         loss_terms: PyTree,
+         stored_loss_terms: PyTree,
+         grad_terms: PyTree,
+         key: Key,
+     ) -> Self:
+         """
+         Update the loss weights according to a predefined scheme
+         """
+         if self.update_weight_method == "soft_adapt":
+             new_weights = soft_adapt(
+                 self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
+             )
+         elif self.update_weight_method == "lr_annealing":
+             new_weights = lr_annealing(self.loss_weights, grad_terms)
+         elif self.update_weight_method == "ReLoBRaLo":
+             new_weights = ReLoBRaLo(
+                 self.loss_weights, iteration_nb, loss_terms, stored_loss_terms, key
+             )
+         else:
+             raise ValueError("update_weight_method for loss weights not implemented")
+
+         # Below we update the non None entry in the PyTree self.loss_weights
+         # we directly get the non None entries because None is not treated as a
+         # leaf
+         return eqx.tree_at(
+             lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
+         )
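
The heart of the new AbstractLoss is the pairing of a loss-weight PyTree with a loss-term PyTree: `ponderate_and_sum_loss` flattens both with `jax.tree.leaves` (which drops `None` entries for unused terms), checks that the counts match, and returns the weighted sum, while `update_weights` dispatches to `soft_adapt`, `lr_annealing` or `ReLoBRaLo` and writes the new weights back with `eqx.tree_at`. A toy, hedged illustration of the flatten-and-pair logic in plain JAX (not jinns code; the dictionaries stand in for the actual loss-weight and loss-component containers):

import jax
import jax.numpy as jnp

# None marks an unused term; jax.tree.leaves drops it on both sides.
loss_weights = {"dyn_loss": 1.0, "norm_loss": None, "observations": 10.0}
loss_terms = {"dyn_loss": jnp.array(0.3), "norm_loss": None, "observations": jnp.array(0.02)}

weights = jax.tree.leaves(loss_weights)  # [1.0, 10.0]
terms = jax.tree.leaves(loss_terms)      # [Array(0.3), Array(0.02)]
assert len(weights) == len(terms)        # a mismatch raises ValueError in jinns

total = jnp.sum(jnp.array(weights) * jnp.array(terms))
print(total)  # 1.0 * 0.3 + 10.0 * 0.02 = 0.5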
jinns/loss/_boundary_conditions.py CHANGED
@@ -7,31 +7,31 @@ from __future__ import (
  ) # https://docs.python.org/3/library/typing.html#constant

  from typing import TYPE_CHECKING, Callable
+ from jaxtyping import Array, Float
  import jax
  import jax.numpy as jnp
  from jax import vmap, grad
- import equinox as eqx
  from jinns.utils._utils import get_grid, _subtract_with_check
- from jinns.data._Batchs import *
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
  from jinns.nn._pinn import PINN
  from jinns.nn._spinn import SPINN

  if TYPE_CHECKING:
-     from jinns.utils._types import *
+     from jinns.parameters._params import Params
+     from jinns.utils._types import BoundaryConditionFun
+     from jinns.nn._abstract_pinn import AbstractPINN


  def _compute_boundary_loss(
      boundary_condition_type: str,
-     f: Callable[
-         [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
-     ],
+     f: BoundaryConditionFun,
      batch: PDEStatioBatch | PDENonStatioBatch,
-     u: eqx.Module,
-     params: AnyParams,
+     u: AbstractPINN,
+     params: Params[Array],
      facet: int,
      dim_to_apply: slice,
      vmap_in_axes: tuple,
- ) -> float:
+ ) -> Float[Array, " "]:
      r"""A generic function that will compute the mini-batch MSE of a
      boundary condition in the stationary case, resp. non-stationary, given by:

@@ -67,7 +67,7 @@ def _compute_boundary_loss(
      u
          a PINN
      params
-         Params or ParamsDict
+         Params
      facet
          An integer which represents the id of the facet which is currently
          considered (in the order provided by the DataGenerator which is fixed)
@@ -96,15 +96,15 @@ def _compute_boundary_loss(

  def boundary_dirichlet(
      f: Callable[
-         [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
+         [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
      ],
      batch: PDEStatioBatch | PDENonStatioBatch,
-     u: eqx.Module,
-     params: Params | ParamsDict,
+     u: AbstractPINN,
+     params: Params[Array],
      facet: int,
      dim_to_apply: slice,
      vmap_in_axes: tuple,
- ) -> float:
+ ) -> Float[Array, " "]:
      r"""
      This omega boundary condition enforces a solution that is equal to `f`
      at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
@@ -135,6 +135,7 @@ def boundary_dirichlet(
      vmap_in_axes
          A tuple object which specifies the in_axes of the vmapping
      """
+     assert batch.border_batch is not None
      batch_array = batch.border_batch
      batch_array = batch_array[..., facet]

@@ -168,15 +169,15 @@ def boundary_dirichlet(

  def boundary_neumann(
      f: Callable[
-         [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
+         [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
      ],
      batch: PDEStatioBatch | PDENonStatioBatch,
-     u: eqx.Module,
-     params: Params | ParamsDict,
+     u: AbstractPINN,
+     params: Params[Array],
      facet: int,
      dim_to_apply: slice,
      vmap_in_axes: tuple,
- ) -> float:
+ ) -> Float[Array, " "]:
      r"""
      This omega boundary condition enforces a solution where $\nabla u\cdot
      n$ is equal to `f` at the cartesian product of `time_batch` x `omega
@@ -208,6 +209,7 @@ def boundary_neumann(
      vmap_in_axes
          A tuple object which specifies the in_axes of the vmapping
      """
+     assert batch.border_batch is not None
      batch_array = batch.border_batch
      batch_array = batch_array[..., facet]

@@ -223,7 +225,6 @@ def boundary_neumann(
      n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])

      if isinstance(u, PINN):
-
          u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])

          if u.eq_type == "statio_PDE":
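
Most of the changes in this file tighten annotations rather than behavior: the network argument goes from a bare `eqx.Module` to `AbstractPINN`, parameters from `Params | ParamsDict` to `Params[Array]`, the return type from Python's `float` to the jaxtyping scalar `Float[Array, " "]`, and asserts now guard `batch.border_batch` before indexing. A small, hedged sketch (plain jaxtyping, not jinns code) of what the scalar annotation expresses:

from jaxtyping import Array, Float
import jax.numpy as jnp

# Float[Array, " "] denotes a floating-point JAX array of shape (), i.e. the traced
# scalar that an MSE reduction actually returns under jit, unlike a plain Python float.
def mse(residuals: Float[Array, " batch"]) -> Float[Array, " "]:
    return jnp.mean(residuals**2)

print(mse(jnp.array([0.1, -0.2, 0.3])))  # shape () array, approximately 0.0467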
jinns/loss/_loss_components.py ADDED
@@ -0,0 +1,43 @@
+ from typing import TypeVar, Generic
+ from dataclasses import fields
+ import equinox as eqx
+
+ T = TypeVar("T")
+
+
+ class XDEComponentsAbstract(eqx.Module, Generic[T]):
+     """
+     Provides a template for ODE components with generic types.
+     One can inherit to specialize and add methods and attributes
+     We do not enforce keyword only to avoid being to verbose (this then can
+     work like a tuple)
+     """
+
+     def items(self):
+         """
+         For the dataclass to be iterated like a dictionary.
+         Practical and retrocompatible with old code when loss components were
+         dictionaries
+         """
+         return {
+             field.name: getattr(self, field.name)
+             for field in fields(self)
+             if getattr(self, field.name) is not None
+         }.items()
+
+
+ class ODEComponents(XDEComponentsAbstract[T]):
+     dyn_loss: T
+     initial_condition: T
+     observations: T
+
+
+ class PDEStatioComponents(XDEComponentsAbstract[T]):
+     dyn_loss: T
+     norm_loss: T
+     boundary_loss: T
+     observations: T
+
+
+ class PDENonStatioComponents(PDEStatioComponents[T]):
+     initial_condition: T
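
The new components module replaces the old dictionaries of loss terms with typed equinox containers: `XDEComponentsAbstract` is generic in `T`, so the same container can hold, e.g., scalar loss values or other per-term PyTrees; fields are positional so the objects can be built like tuples; and `items()` keeps the old dict-style iteration while skipping `None` fields. A hedged usage sketch (assuming jinns 1.5.0 is installed; `_loss_components` is a private module, so the import path is illustrative):

import jax.numpy as jnp
from jinns.loss._loss_components import ODEComponents

# Positional construction: dyn_loss, initial_condition, observations.
components = ODEComponents(jnp.array(0.12), jnp.array(0.03), None)

# items() mimics the former dict API and drops None entries.
for name, value in components.items():
    print(name, value)
# dyn_loss 0.12
# initial_condition 0.03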