jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/nn/_spinn_mlp.py CHANGED
@@ -4,13 +4,12 @@ https://arxiv.org/abs/2211.08761
4
4
  """
5
5
 
6
6
  from dataclasses import InitVar
7
- from typing import Callable, Literal, Self, Union, Any
7
+ from typing import Callable, Literal, Self, Union, Any, TypeGuard
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  import equinox as eqx
11
11
  from jaxtyping import Key, Array, Float, PyTree
12
12
 
13
- from jinns.parameters._params import Params, ParamsDict
14
13
  from jinns.nn._mlp import MLP
15
14
  from jinns.nn._spinn import SPINN
16
15
 
@@ -26,7 +25,7 @@ class SMLP(eqx.Module):
26
25
  d : int
27
26
  The number of dimensions to treat separately, including time `t` if
28
27
  used for non-stationnary equations.
29
- eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
28
+ eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
30
29
  A tuple of tuples of successive equinox modules and activation functions to
31
30
  describe the PINN architecture. The inner tuples must have the eqx module or
32
31
  activation function as first item, other items represents arguments
@@ -34,18 +33,18 @@ class SMLP(eqx.Module):
34
33
  The `key` argument need not be given.
35
34
  Thus typical example is `eqx_list=
36
35
  ((eqx.nn.Linear, 1, 20),
37
- jax.nn.tanh,
36
+ (jax.nn.tanh,),
38
37
  (eqx.nn.Linear, 20, 20),
39
- jax.nn.tanh,
38
+ (jax.nn.tanh,),
40
39
  (eqx.nn.Linear, 20, 20),
41
- jax.nn.tanh,
40
+ (jax.nn.tanh,),
42
41
  (eqx.nn.Linear, 20, r * m)
43
42
  )`.
44
43
  """
45
44
 
46
45
  key: InitVar[Key] = eqx.field(kw_only=True)
47
- eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
48
- kw_only=True
46
+ eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
47
+ eqx.field(kw_only=True)
49
48
  )
50
49
  d: int = eqx.field(static=True, kw_only=True)
51
50
 
@@ -58,8 +57,8 @@ class SMLP(eqx.Module):
58
57
  ]
59
58
 
60
59
  def __call__(
61
- self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
62
- ) -> Float[Array, "d embed_dim*output_dim"]:
60
+ self, inputs: Float[Array, " dim"] | Float[Array, " dim+1"]
61
+ ) -> Float[Array, " d embed_dim*output_dim"]:
63
62
  outputs = []
64
63
  for d in range(self.d):
65
64
  x_i = inputs[d : d + 1]
@@ -78,11 +77,11 @@ class SPINN_MLP(SPINN):
78
77
  key: Key,
79
78
  d: int,
80
79
  r: int,
81
- eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
80
+ eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
82
81
  eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
83
82
  m: int = 1,
84
83
  filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
85
- ) -> tuple[Self, PyTree]:
84
+ ) -> tuple[Self, SPINN]:
86
85
  """
87
86
  Utility function to create a SPINN neural network with the equinox
88
87
  library.
@@ -108,11 +107,11 @@ class SPINN_MLP(SPINN):
108
107
  The `key` argument need not be given.
109
108
  Thus typical example is
110
109
  `eqx_list=((eqx.nn.Linear, 1, 20),
111
- jax.nn.tanh,
110
+ (jax.nn.tanh,),
111
+ (eqx.nn.Linea)r, 20, 20),
112
+ (jax.nn.tanh,),
112
113
  (eqx.nn.Linear, 20, 20),
113
- jax.nn.tanh,
114
- (eqx.nn.Linear, 20, 20),
115
- jax.nn.tanh,
114
+ (jax.nn.tanh,),
116
115
  (eqx.nn.Linear, 20, r * m)
117
116
  )`.
118
117
  eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
@@ -158,24 +157,31 @@ class SPINN_MLP(SPINN):
158
157
  if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
159
158
  raise RuntimeError("Wrong parameter value for eq_type")
160
159
 
161
- try:
160
+ def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
161
+ return len(element) > 1
162
+
163
+ if element_is_layer(eqx_list[0]):
162
164
  nb_inputs_declared = eqx_list[0][
163
165
  1
164
166
  ] # normally we look for 2nd ele of 1st layer
165
- except IndexError:
167
+ elif element_is_layer(eqx_list[1]):
166
168
  nb_inputs_declared = eqx_list[1][
167
169
  1
168
170
  ] # but we can have, eg, a flatten first layer
169
- if nb_inputs_declared != 1:
171
+ else:
172
+ nb_inputs_declared = None
173
+ if nb_inputs_declared is None or nb_inputs_declared != 1:
170
174
  raise ValueError("Input dim must be set to 1 in SPINN!")
171
175
 
172
- try:
176
+ if element_is_layer(eqx_list[-1]):
173
177
  nb_outputs_declared = eqx_list[-1][2] # normally we look for 3rd ele of
174
178
  # last layer
175
- except IndexError:
179
+ elif element_is_layer(eqx_list[-2]):
176
180
  nb_outputs_declared = eqx_list[-2][2]
177
181
  # but we can have, eg, a `jnp.exp` last layer
178
- if nb_outputs_declared != r * m:
182
+ else:
183
+ nb_outputs_declared = None
184
+ if nb_outputs_declared is None or nb_outputs_declared != r * m:
179
185
  raise ValueError("Output dim must be set to r * m in SPINN!")
180
186
 
181
187
  if d > 24:
jinns/nn/_utils.py ADDED
@@ -0,0 +1,38 @@
1
+ from typing import Any, ParamSpec, Callable, Concatenate
2
+ from jaxtyping import PyTree, Array
3
+ from jinns.parameters._params import Params
4
+
5
+
6
+ P = ParamSpec("P")
7
+
8
+
9
+ def _PyTree_to_Params(
10
+ call_fun: Callable[
11
+ Concatenate[Any, Any, PyTree | Params[Array], P],
12
+ Any,
13
+ ],
14
+ ) -> Callable[
15
+ Concatenate[Any, Any, PyTree | Params[Array], P],
16
+ Any,
17
+ ]:
18
+ """
19
+ Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
20
+ authorizes the __call__ with `params` being directly be the
21
+ PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
22
+
23
+ This generic approach enables to cleanly handle type hints, up to the small
24
+ effort required to understand type hints for decorators (ie ParamSpec).
25
+ """
26
+
27
+ def wrapper(
28
+ self: Any,
29
+ inputs: Any,
30
+ params: PyTree | Params[Array],
31
+ *args: P.args,
32
+ **kwargs: P.kwargs,
33
+ ):
34
+ if isinstance(params, PyTree) and not isinstance(params, Params):
35
+ params = Params(nn_params=params, eq_params={})
36
+ return call_fun(self, inputs, params, *args, **kwargs)
37
+
38
+ return wrapper
jinns/parameters/__init__.py CHANGED
@@ -1,6 +1,13 @@
1
- from ._params import Params, ParamsDict
1
+ from ._params import Params
2
2
  from ._derivative_keys import (
3
3
  DerivativeKeysODE,
4
4
  DerivativeKeysPDEStatio,
5
5
  DerivativeKeysPDENonStatio,
6
6
  )
7
+
8
+ __all__ = [
9
+ "Params",
10
+ "DerivativeKeysODE",
11
+ "DerivativeKeysPDEStatio",
12
+ "DerivativeKeysPDENonStatio",
13
+ ]