jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/nn/_spinn_mlp.py
CHANGED
@@ -4,13 +4,12 @@ https://arxiv.org/abs/2211.08761
 """
 
 from dataclasses import InitVar
-from typing import Callable, Literal, Self, Union, Any
+from typing import Callable, Literal, Self, Union, Any, TypeGuard
 import jax
 import jax.numpy as jnp
 import equinox as eqx
 from jaxtyping import Key, Array, Float, PyTree
 
-from jinns.parameters._params import Params, ParamsDict
 from jinns.nn._mlp import MLP
 from jinns.nn._spinn import SPINN
 
@@ -26,7 +25,7 @@ class SMLP(eqx.Module):
     d : int
         The number of dimensions to treat separately, including time `t` if
         used for non-stationnary equations.
-    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
+    eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
         A tuple of tuples of successive equinox modules and activation functions to
         describe the PINN architecture. The inner tuples must have the eqx module or
         activation function as first item, other items represents arguments
@@ -34,18 +33,18 @@ class SMLP(eqx.Module):
        The `key` argument need not be given.
        Thus typical example is `eqx_list=
        ((eqx.nn.Linear, 1, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, r * m)
        )`.
     """
 
     key: InitVar[Key] = eqx.field(kw_only=True)
-    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] =
-        kw_only=True
+    eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
+        eqx.field(kw_only=True)
     )
     d: int = eqx.field(static=True, kw_only=True)
 
@@ -58,8 +57,8 @@ class SMLP(eqx.Module):
     ]
 
     def __call__(
-        self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
-    ) -> Float[Array, "d embed_dim*output_dim"]:
+        self, inputs: Float[Array, " dim"] | Float[Array, " dim+1"]
+    ) -> Float[Array, " d embed_dim*output_dim"]:
         outputs = []
         for d in range(self.d):
             x_i = inputs[d : d + 1]
@@ -78,11 +77,11 @@ class SPINN_MLP(SPINN):
         key: Key,
         d: int,
         r: int,
-        eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
+        eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
         m: int = 1,
         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
-    ) -> tuple[Self,
+    ) -> tuple[Self, SPINN]:
         """
         Utility function to create a SPINN neural network with the equinox
         library.
@@ -108,11 +107,11 @@ class SPINN_MLP(SPINN):
        The `key` argument need not be given.
        Thus typical example is
        `eqx_list=((eqx.nn.Linear, 1, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
+        (eqx.nn.Linear, 20, 20),
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
-        (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, r * m)
        )`.
        eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
@@ -158,24 +157,31 @@ class SPINN_MLP(SPINN):
         if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
             raise RuntimeError("Wrong parameter value for eq_type")
 
-
+        def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
+            return len(element) > 1
+
+        if element_is_layer(eqx_list[0]):
             nb_inputs_declared = eqx_list[0][
                 1
             ]  # normally we look for 2nd ele of 1st layer
-
+        elif element_is_layer(eqx_list[1]):
             nb_inputs_declared = eqx_list[1][
                 1
             ]  # but we can have, eg, a flatten first layer
-
+        else:
+            nb_inputs_declared = None
+        if nb_inputs_declared is None or nb_inputs_declared != 1:
             raise ValueError("Input dim must be set to 1 in SPINN!")
 
-
+        if element_is_layer(eqx_list[-1]):
             nb_outputs_declared = eqx_list[-1][2]  # normally we look for 3rd ele of
             # last layer
-
+        elif element_is_layer(eqx_list[-2]):
             nb_outputs_declared = eqx_list[-2][2]
             # but we can have, eg, a `jnp.exp` last layer
-
+        else:
+            nb_outputs_declared = None
+        if nb_outputs_declared is None or nb_outputs_declared != r * m:
             raise ValueError("Output dim must be set to r * m in SPINN!")
 
         if d > 24:
jinns/nn/_utils.py
ADDED
@@ -0,0 +1,38 @@
+from typing import Any, ParamSpec, Callable, Concatenate
+from jaxtyping import PyTree, Array
+from jinns.parameters._params import Params
+
+
+P = ParamSpec("P")
+
+
+def _PyTree_to_Params(
+    call_fun: Callable[
+        Concatenate[Any, Any, PyTree | Params[Array], P],
+        Any,
+    ],
+) -> Callable[
+    Concatenate[Any, Any, PyTree | Params[Array], P],
+    Any,
+]:
+    """
+    Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
+    authorizes the __call__ with `params` being directly be the
+    PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
+
+    This generic approach enables to cleanly handle type hints, up to the small
+    effort required to understand type hints for decorators (ie ParamSpec).
+    """
+
+    def wrapper(
+        self: Any,
+        inputs: Any,
+        params: PyTree | Params[Array],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ):
+        if isinstance(params, PyTree) and not isinstance(params, Params):
+            params = Params(nn_params=params, eq_params={})
+        return call_fun(self, inputs, params, *args, **kwargs)
+
+    return wrapper
jinns/parameters/__init__.py
CHANGED
@@ -1,6 +1,13 @@
-from ._params import Params
+from ._params import Params
 from ._derivative_keys import (
     DerivativeKeysODE,
     DerivativeKeysPDEStatio,
     DerivativeKeysPDENonStatio,
 )
+
+__all__ = [
+    "Params",
+    "DerivativeKeysODE",
+    "DerivativeKeysPDEStatio",
+    "DerivativeKeysPDENonStatio",
+]