jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/nn/_spinn.py ADDED
@@ -0,0 +1,123 @@
+ from __future__ import annotations
+ from typing import Union, Callable, Any, Literal, overload
+ from dataclasses import InitVar
+ from jaxtyping import PyTree, Float, Array
+ import jax
+ import jax.numpy as jnp
+ import equinox as eqx
+
+ from jinns.parameters._params import Params
+ from jinns.nn._abstract_pinn import AbstractPINN
+ from jinns.nn._utils import _PyTree_to_Params
+
+
+ class SPINN(AbstractPINN):
+     """
+     A Separable PINN object compatible with the rest of jinns.
+
+     Parameters
+     ----------
+     d : int
+         The number of dimensions to treat separately, including time `t` if
+         used for non-stationary equations.
+     r : int
+         The dimension of the embedding.
+     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+         A string with three possibilities.
+         "ODE": the PINN is called with one input `t`.
+         "statio_PDE": the PINN is called with one input `x`; `x` can be
+         high dimensional.
+         "nonstatio_PDE": the PINN is called with two inputs `t` and `x`;
+         `x` can be high dimensional.
+         **Note**: the input dimension as given in `eqx_list` has to match
+         the dimension of `t` plus the dimension of `x`.
+     m : int
+         The output dimension of the neural network. According to the SPINN
+         article, a total embedding dimension of `r*m` is defined. We then
+         sum groups of `r` embedding dimensions to compute each output.
+         Default is 1.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells jinns what to
+         consider as a trainable parameter. Quoting from the equinox
+         documentation: a PyTree whose structure should be a prefix of the
+         structure of pytree. Each of its leaves should either be 1) True,
+         in which case the leaf or subtree is kept; 2) False, in which case
+         the leaf or subtree is replaced with replace; 3) a callable
+         Leaf -> bool, in which case this is evaluated on the leaf or
+         mapped over the subtree, and the leaf kept or replaced as
+         appropriate.
+     eqx_spinn_network : eqx.Module
+         The actual neural network instantiated as an eqx.Module. It should
+         be an architecture taking `d` inputs and returning `d` embeddings
+         of dimension `r*m`. See the Separable PINN paper for more details.
+     """
+
+     eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
+         static=True, kw_only=True
+     )
+     d: int = eqx.field(static=True, kw_only=True)
+     r: int = eqx.field(static=True, kw_only=True)
+     m: int = eqx.field(static=True, kw_only=True, default=1)
+     filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
+         static=True, kw_only=True, default=None
+     )
+     eqx_spinn_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
+
+     init_params: SPINN = eqx.field(init=False)
+     static: SPINN = eqx.field(init=False, static=True)
+
+     def __post_init__(self, eqx_spinn_network):
+         if self.filter_spec is None:
+             self.filter_spec = eqx.is_inexact_array
+
+         # split the network into trainable parameters and static structure
+         self.init_params, self.static = eqx.partition(
+             eqx_spinn_network, self.filter_spec
+         )
+
+     @overload
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " input_dim"],
+         params: PyTree,
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]: ...
+
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         t_x: Float[Array, " batch_size 1+dim"],
+         params: Params[Array],
+     ) -> Float[Array, " output_dim"]:
+         """
+         Evaluate the SPINN on some inputs with some params.
+
+         Note that, thanks to the decorator, `params` can also directly be
+         the PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`.
+         """
+         spinn = eqx.combine(params.nn_params, self.static)
+         v_model = jax.vmap(spinn)
+         res = v_model(t_x)  # type: ignore  # shape (batch, d, r * m)
+
+         # build the einsum spec, e.g. "az, bz -> ab" for d = 2: contract
+         # the shared embedding axis z, keep one batch axis per dimension
+         a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
+         b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
+         res = jnp.stack(
+             [
+                 jnp.einsum(
+                     f"{a} -> {b}",
+                     *(
+                         res[:, d, m * self.r : (m + 1) * self.r]
+                         for d in range(res.shape[1])
+                     ),
+                 )
+                 for m in range(self.m)
+             ],
+             axis=-1,
+         )  # compute each output dimension
+
+         # force a (1,) output axis for non-vectorial solutions (consistency)
+         if len(res.shape) == self.d:
+             return jnp.expand_dims(res, axis=-1)
+         return res
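To make the recombination above concrete: for `d = 2` the generated spec is `"az, bz -> ab"`, an outer product over the two batch axes contracted over the shared embedding axis `z`. Below is a minimal standalone sketch in plain JAX (not the jinns API; all names are illustrative):

    import jax.numpy as jnp

    batch, r = 4, 8
    f_t = jnp.ones((batch, r))  # embedding of the time axis, shape (batch, r)
    f_x = jnp.ones((batch, r))  # embedding of the space axis, shape (batch, r)

    # contract the shared embedding axis z, keep one batch axis per
    # separated input: the solution is evaluated on the full mesh grid
    u = jnp.einsum("az, bz -> ab", f_t, f_x)
    assert u.shape == (batch, batch)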
jinns/nn/_spinn_mlp.py ADDED
@@ -0,0 +1,202 @@
+ """
+ Implements utilities to create Separable PINNs
+ https://arxiv.org/abs/2211.08761
+ """
+
+ from dataclasses import InitVar
+ from typing import Callable, Literal, Self, Union, Any, TypeGuard
+ import jax
+ import jax.numpy as jnp
+ import equinox as eqx
+ from jaxtyping import Key, Array, Float, PyTree
+
+ from jinns.nn._mlp import MLP
+ from jinns.nn._spinn import SPINN
+
+
+ class SMLP(eqx.Module):
+     """
+     Construct a Separable MLP.
+
+     Parameters
+     ----------
+     key : InitVar[Key]
+         A JAX random key for the layer initializations.
+     d : int
+         The number of dimensions to treat separately, including time `t` if
+         used for non-stationary equations.
+     eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
+         A tuple of tuples of successive equinox modules and activation
+         functions describing the PINN architecture. Each inner tuple must
+         have the eqx module or activation function as first item; the
+         other items are the arguments it may require (e.g. the size of the
+         layer). The `key` argument need not be given. A typical example is
+         `eqx_list=(
+             (eqx.nn.Linear, 1, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, r * m),
+         )`.
+     """
+
+     key: InitVar[Key] = eqx.field(kw_only=True)
+     eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
+         eqx.field(kw_only=True)
+     )
+     d: int = eqx.field(static=True, kw_only=True)
+
+     separated_mlp: list[MLP] = eqx.field(init=False)
+
+     def __post_init__(self, key, eqx_list):
+         # one independent MLP per separated input dimension
+         keys = jax.random.split(key, self.d)
+         self.separated_mlp = [
+             MLP(key=keys[d_], eqx_list=eqx_list) for d_ in range(self.d)
+         ]
+
+     def __call__(
+         self, inputs: Float[Array, " dim"] | Float[Array, " dim+1"]
+     ) -> Float[Array, " d embed_dim*output_dim"]:
+         outputs = []
+         for d in range(self.d):
+             x_i = inputs[d : d + 1]  # coordinate d goes through MLP d
+             outputs += [self.separated_mlp[d](x_i)]
+         return jnp.asarray(outputs)
+
+
+ class SPINN_MLP(SPINN):
+     """
+     An implementable SPINN based on an MLP architecture.
+     """
+
+     @classmethod
+     def create(
+         cls,
+         key: Key,
+         d: int,
+         r: int,
+         eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
+         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         m: int = 1,
+         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
+     ) -> tuple[Self, SPINN]:
+         """
+         Utility function to create a SPINN neural network with the equinox
+         library.
+
+         *Note* that a SPINN is not vmapped and expects the same batch size
+         for each of its input axes. It directly outputs a solution of shape
+         `(batchsize,) * d`. See the paper for more details.
+
+         Parameters
+         ----------
+         key : Key
+             A JAX random key that will be used to initialize the network
+             parameters.
+         d : int
+             The number of dimensions to treat separately.
+         r : int
+             The dimension of the embedding.
+         eqx_list : tuple[tuple[Callable, int, int] | tuple[Callable], ...]
+             A tuple of tuples of successive equinox modules and activation
+             functions describing the PINN architecture. Each inner tuple
+             must have the eqx module or activation function as first item;
+             the other items are the arguments it may require (e.g. the
+             size of the layer). The `key` argument need not be given. A
+             typical example is
+             `eqx_list=(
+                 (eqx.nn.Linear, 1, 20),
+                 (jax.nn.tanh,),
+                 (eqx.nn.Linear, 20, 20),
+                 (jax.nn.tanh,),
+                 (eqx.nn.Linear, 20, 20),
+                 (jax.nn.tanh,),
+                 (eqx.nn.Linear, 20, r * m),
+             )`.
+         eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+             A string with three possibilities.
+             "ODE": the PINN is called with one input `t`.
+             "statio_PDE": the PINN is called with one input `x`; `x` can
+             be high dimensional.
+             "nonstatio_PDE": the PINN is called with two inputs `t` and
+             `x`; `x` can be high dimensional.
+             **Note**: the input dimension as given in `eqx_list` has to
+             match the dimension of `t` plus the dimension of `x`.
+         m : int
+             The output dimension of the neural network. According to the
+             SPINN article, a total embedding dimension of `r*m` is
+             defined. We then sum groups of `r` embedding dimensions to
+             compute each output. Default is 1.
+         filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+             Default is None, which leads to `eqx.is_inexact_array` at
+             class instantiation. This tells jinns what to consider as a
+             trainable parameter. Quoting from the equinox documentation: a
+             PyTree whose structure should be a prefix of the structure of
+             pytree. Each of its leaves should either be 1) True, in which
+             case the leaf or subtree is kept; 2) False, in which case the
+             leaf or subtree is replaced with replace; 3) a callable
+             Leaf -> bool, in which case this is evaluated on the leaf or
+             mapped over the subtree, and the leaf kept or replaced as
+             appropriate.
+
+         Returns
+         -------
+         spinn
+             An instantiated SPINN.
+         spinn.init_params
+             The initial set of parameters of the model.
+
+         Raises
+         ------
+         RuntimeError
+             If the parameter value for eq_type is not in `["ODE",
+             "statio_PDE", "nonstatio_PDE"]`, or if various consistency
+             checks fail.
+         """
+
+         if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+             raise RuntimeError("Wrong parameter value for eq_type")
+
+         def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
+             return len(element) > 1
+
+         if element_is_layer(eqx_list[0]):
+             # normally the input dimension is the 2nd element of the 1st layer
+             nb_inputs_declared = eqx_list[0][1]
+         elif element_is_layer(eqx_list[1]):
+             # but we can have, e.g., a flatten first layer
+             nb_inputs_declared = eqx_list[1][1]
+         else:
+             nb_inputs_declared = None
+         if nb_inputs_declared is None or nb_inputs_declared != 1:
+             raise ValueError("Input dim must be set to 1 in SPINN!")
+
+         if element_is_layer(eqx_list[-1]):
+             # normally the output dimension is the 3rd element of the last layer
+             nb_outputs_declared = eqx_list[-1][2]
+         elif element_is_layer(eqx_list[-2]):
+             # but we can have, e.g., a `jnp.exp` last layer
+             nb_outputs_declared = eqx_list[-2][2]
+         else:
+             nb_outputs_declared = None
+         if nb_outputs_declared is None or nb_outputs_declared != r * m:
+             raise ValueError("Output dim must be set to r * m in SPINN!")
+
+         if d > 24:
+             raise ValueError(
+                 "Too many dimensions, not enough letters available in jnp.einsum"
+             )
+
+         smlp = SMLP(key=key, d=d, eqx_list=eqx_list)
+         spinn = cls(
+             eqx_spinn_network=smlp,
+             d=d,
+             r=r,
+             eq_type=eq_type,
+             m=m,
+             filter_spec=filter_spec,
+         )
+
+         return spinn, spinn.init_params
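A minimal usage sketch of the factory above, assuming jinns 1.4.0 and importing from the private module path shown in this diff (hyperparameter values are illustrative):

    import jax
    import jax.numpy as jnp
    import equinox as eqx
    from jinns.nn._spinn_mlp import SPINN_MLP

    key = jax.random.PRNGKey(0)
    d, r, m = 2, 32, 1  # time + 1D space, embedding size 32, scalar output
    eqx_list = (
        (eqx.nn.Linear, 1, 20),  # input dim must be 1: one MLP per axis
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, r * m),  # output dim must be r * m
    )
    spinn, init_params = SPINN_MLP.create(
        key=key, d=d, r=r, eqx_list=eqx_list, eq_type="nonstatio_PDE", m=m
    )

    t_x = jnp.ones((16, 2))  # same batch size for each separated axis
    u = spinn(t_x, init_params)  # shape (16, 16, 1): one axis per separated
                                 # input, plus the trailing output axis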
jinns/nn/_utils.py ADDED
@@ -0,0 +1,38 @@
+ from typing import Any, ParamSpec, Callable, Concatenate
+ from jaxtyping import PyTree, Array
+ from jinns.parameters._params import Params
+
+
+ P = ParamSpec("P")
+
+
+ def _PyTree_to_Params(
+     call_fun: Callable[
+         Concatenate[Any, Any, PyTree | Params[Array], P],
+         Any,
+     ],
+ ) -> Callable[
+     Concatenate[Any, Any, PyTree | Params[Array], P],
+     Any,
+ ]:
+     """
+     Decorator to be used around the __call__ functions of PINNs, SPINNs,
+     etc. It allows __call__ to receive as `params` directly the PyTree
+     (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`.
+
+     This generic approach cleanly handles type hints, up to the small
+     effort required to understand type hints for decorators (i.e.
+     ParamSpec).
+     """
+
+     def wrapper(
+         self: Any,
+         inputs: Any,
+         params: PyTree | Params[Array],
+         *args: P.args,
+         **kwargs: P.kwargs,
+     ):
+         # wrap a raw parameter PyTree into a Params container, leaving
+         # actual Params instances untouched
+         if isinstance(params, PyTree) and not isinstance(params, Params):
+             params = Params(nn_params=params, eq_params={})
+         return call_fun(self, inputs, params, *args, **kwargs)
+
+     return wrapper
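What the decorator buys in practice: the two calls below are equivalent. This sketch reuses the `spinn`, `init_params`, and `t_x` names from the SPINN_MLP example above; `Params` is imported from the public path re-exported in the hunk that follows.

    from jinns.parameters import Params

    u1 = spinn(t_x, Params(nn_params=init_params, eq_params={}))
    u2 = spinn(t_x, init_params)  # raw PyTree, auto-wrapped by the decorator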
jinns/parameters/__init__.py CHANGED
@@ -1,6 +1,13 @@
- from ._params import Params, ParamsDict
+ from ._params import Params
  from ._derivative_keys import (
      DerivativeKeysODE,
      DerivativeKeysPDEStatio,
      DerivativeKeysPDENonStatio,
  )
+
+ __all__ = [
+     "Params",
+     "DerivativeKeysODE",
+     "DerivativeKeysPDEStatio",
+     "DerivativeKeysPDENonStatio",
+ ]
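After this hunk, `ParamsDict` is no longer re-exported and `__all__` pins the public surface to exactly four names, so downstream imports look like:

    from jinns.parameters import (
        Params,
        DerivativeKeysODE,
        DerivativeKeysPDEStatio,
        DerivativeKeysPDENonStatio,
    )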