jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/nn/_mlp.py
CHANGED
|
@@ -2,16 +2,30 @@
|
|
|
2
2
|
Implements utility function to create PINNs
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Literal, Self, Union, Any, TYPE_CHECKING, cast
|
|
6
8
|
from dataclasses import InitVar
|
|
7
9
|
import jax
|
|
8
10
|
import equinox as eqx
|
|
9
|
-
|
|
11
|
+
from typing import Protocol
|
|
10
12
|
from jaxtyping import Array, Key, PyTree, Float
|
|
11
13
|
|
|
12
14
|
from jinns.parameters._params import Params
|
|
13
15
|
from jinns.nn._pinn import PINN
|
|
14
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
|
|
19
|
+
class CallableMLPModule(Protocol):
|
|
20
|
+
"""
|
|
21
|
+
Basically just a way to add a __call__ to an eqx.Module.
|
|
22
|
+
https://github.com/patrick-kidger/equinox/issues/1002
|
|
23
|
+
We chose the structural subtyping of protocols instead of subclassing an
|
|
24
|
+
eqx.Module just to add a __call__ here
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __call__(self, *_, **__) -> Array: ...
|
|
28
|
+
|
|
15
29
|
|
|
16
30
|
class MLP(eqx.Module):
|
|
17
31
|
"""
|
|
@@ -21,7 +35,7 @@ class MLP(eqx.Module):
|
|
|
21
35
|
----------
|
|
22
36
|
key : InitVar[Key]
|
|
23
37
|
A jax random key for the layer initializations.
|
|
24
|
-
eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
|
|
38
|
+
eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
|
|
25
39
|
A tuple of tuples of successive equinox modules and activation functions to
|
|
26
40
|
describe the PINN architecture. The inner tuples must have the eqx module or
|
|
27
41
|
activation function as first item, other items represent arguments
|
|
@@ -29,23 +43,23 @@ class MLP(eqx.Module):
|
|
|
29
43
|
The `key` argument need not be given.
|
|
30
44
|
Thus typical example is `eqx_list=
|
|
31
45
|
((eqx.nn.Linear, 2, 20),
|
|
32
|
-
jax.nn.tanh,
|
|
46
|
+
(jax.nn.tanh,),
|
|
33
47
|
(eqx.nn.Linear, 20, 20),
|
|
34
|
-
jax.nn.tanh,
|
|
48
|
+
(jax.nn.tanh,),
|
|
35
49
|
(eqx.nn.Linear, 20, 20),
|
|
36
|
-
jax.nn.tanh,
|
|
50
|
+
(jax.nn.tanh,),
|
|
37
51
|
(eqx.nn.Linear, 20, 1)
|
|
38
52
|
)`.
|
|
39
53
|
"""
|
|
40
54
|
|
|
41
55
|
key: InitVar[Key] = eqx.field(kw_only=True)
|
|
42
|
-
eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] =
|
|
43
|
-
kw_only=True
|
|
56
|
+
eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
|
|
57
|
+
eqx.field(kw_only=True)
|
|
44
58
|
)
|
|
45
59
|
|
|
46
60
|
# NOTE that the following should NOT be declared as static otherwise the
|
|
47
61
|
# eqx.partition that we use in the PINN module will misbehave
|
|
48
|
-
layers: list[
|
|
62
|
+
layers: list[CallableMLPModule | Callable[[Array], Array]] = eqx.field(init=False)
|
|
49
63
|
|
|
50
64
|
def __post_init__(self, key, eqx_list):
|
|
51
65
|
self.layers = []
|
|
@@ -63,7 +77,7 @@ class MLP(eqx.Module):
|
|
|
63
77
|
self.layers.append(l[0](*l[1:], key=subkey))
|
|
64
78
|
k += 1
|
|
65
79
|
|
|
66
|
-
def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
|
|
80
|
+
def __call__(self, t: Float[Array, " input_dim"]) -> Float[Array, " output_dim"]:
|
|
67
81
|
for layer in self.layers:
|
|
68
82
|
t = layer(t)
|
|
69
83
|
return t
|
|
@@ -81,19 +95,30 @@ class PINN_MLP(PINN):
|
|
|
81
95
|
def create(
|
|
82
96
|
cls,
|
|
83
97
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
84
|
-
eqx_network: eqx.nn.MLP = None,
|
|
98
|
+
eqx_network: eqx.nn.MLP | MLP | None = None,
|
|
85
99
|
key: Key = None,
|
|
86
|
-
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
|
|
87
|
-
input_transform:
|
|
88
|
-
[
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
100
|
+
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
|
|
101
|
+
input_transform: (
|
|
102
|
+
Callable[
|
|
103
|
+
[Float[Array, " input_dim"], Params[Array]],
|
|
104
|
+
Float[Array, " output_dim"],
|
|
105
|
+
]
|
|
106
|
+
| None
|
|
107
|
+
) = None,
|
|
108
|
+
output_transform: (
|
|
109
|
+
Callable[
|
|
110
|
+
[
|
|
111
|
+
Float[Array, " input_dim"],
|
|
112
|
+
Float[Array, " output_dim"],
|
|
113
|
+
Params[Array],
|
|
114
|
+
],
|
|
115
|
+
Float[Array, " output_dim"],
|
|
116
|
+
]
|
|
117
|
+
| None
|
|
118
|
+
) = None,
|
|
119
|
+
slice_solution: slice | None = None,
|
|
95
120
|
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
|
|
96
|
-
) -> tuple[Self,
|
|
121
|
+
) -> tuple[Self, PINN]:
|
|
97
122
|
r"""
|
|
98
123
|
Instantiate a standard PINN MLP object. The actual NN is either passed as
|
|
99
124
|
an eqx.nn.MLP (`eqx_network` argument) or constructed as a custom
|
|
@@ -179,14 +204,14 @@ class PINN_MLP(PINN):
|
|
|
179
204
|
raise ValueError(
|
|
180
205
|
"If eqx_network is None, then key and eqx_list must be provided"
|
|
181
206
|
)
|
|
182
|
-
eqx_network = MLP(key=key, eqx_list=eqx_list)
|
|
207
|
+
eqx_network = cast(MLP, MLP(key=key, eqx_list=eqx_list))
|
|
183
208
|
|
|
184
209
|
mlp = cls(
|
|
185
210
|
eqx_network=eqx_network,
|
|
186
|
-
slice_solution=slice_solution,
|
|
211
|
+
slice_solution=slice_solution, # type: ignore
|
|
187
212
|
eq_type=eq_type,
|
|
188
|
-
input_transform=input_transform,
|
|
189
|
-
output_transform=output_transform,
|
|
213
|
+
input_transform=input_transform, # type: ignore
|
|
214
|
+
output_transform=output_transform, # type: ignore
|
|
190
215
|
filter_spec=filter_spec,
|
|
191
216
|
)
|
|
192
217
|
return mlp, mlp.init_params
|
jinns/nn/_pinn.py
CHANGED
|
@@ -2,15 +2,19 @@
|
|
|
2
2
|
Implement abstract class for PINN architectures
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Union, Any, Literal, overload
|
|
6
8
|
from dataclasses import InitVar
|
|
7
9
|
import equinox as eqx
|
|
8
10
|
from jaxtyping import Float, Array, PyTree
|
|
9
11
|
import jax.numpy as jnp
|
|
10
|
-
from jinns.parameters._params import Params
|
|
12
|
+
from jinns.parameters._params import Params
|
|
13
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
14
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
11
15
|
|
|
12
16
|
|
|
13
|
-
class PINN(
|
|
17
|
+
class PINN(AbstractPINN):
|
|
14
18
|
r"""
|
|
15
19
|
Base class for PINN objects. It can be seen as a wrapper on
|
|
16
20
|
an `eqx.Module` which actually implement the NN architectures, with extra
|
|
@@ -57,12 +61,12 @@ class PINN(eqx.Module):
|
|
|
57
61
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
58
62
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
59
63
|
after the `input_transform` function.
|
|
60
|
-
input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
|
|
64
|
+
input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
61
65
|
A function that will be called before entering the PINN. Its output(s)
|
|
62
66
|
must match the PINN inputs (except for the parameters).
|
|
63
67
|
Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
|
|
64
68
|
and the parameters. Default is no operation.
|
|
65
|
-
output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
|
|
69
|
+
output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
66
70
|
A function with arguments being the same input as the PINN, the PINN
|
|
67
71
|
output and the parameters. This function will be called after exiting the PINN.
|
|
68
72
|
Default is no operation.
|
|
@@ -84,16 +88,16 @@ class PINN(eqx.Module):
|
|
|
84
88
|
"nonstatio_PDE"]`
|
|
85
89
|
"""
|
|
86
90
|
|
|
87
|
-
slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
|
|
88
91
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
|
|
89
92
|
static=True, kw_only=True
|
|
90
93
|
)
|
|
94
|
+
slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
|
|
91
95
|
input_transform: Callable[
|
|
92
|
-
[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
|
|
96
|
+
[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]
|
|
93
97
|
] = eqx.field(static=True, kw_only=True, default=None)
|
|
94
98
|
output_transform: Callable[
|
|
95
|
-
[Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
|
|
96
|
-
Float[Array, "output_dim"],
|
|
99
|
+
[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]],
|
|
100
|
+
Float[Array, " output_dim"],
|
|
97
101
|
] = eqx.field(static=True, kw_only=True, default=None)
|
|
98
102
|
|
|
99
103
|
eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
@@ -101,11 +105,10 @@ class PINN(eqx.Module):
|
|
|
101
105
|
static=True, kw_only=True, default=eqx.is_inexact_array
|
|
102
106
|
)
|
|
103
107
|
|
|
104
|
-
init_params:
|
|
105
|
-
static:
|
|
108
|
+
init_params: PINN = eqx.field(init=False)
|
|
109
|
+
static: PINN = eqx.field(init=False, static=True)
|
|
106
110
|
|
|
107
111
|
def __post_init__(self, eqx_network):
|
|
108
|
-
|
|
109
112
|
if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
|
|
110
113
|
raise RuntimeError("Wrong parameter value for eq_type")
|
|
111
114
|
# saving the static part of the model and initial parameters
|
|
@@ -154,18 +157,32 @@ class PINN(eqx.Module):
|
|
|
154
157
|
|
|
155
158
|
return network(inputs)
|
|
156
159
|
|
|
160
|
+
@overload
|
|
161
|
+
@_PyTree_to_Params
|
|
162
|
+
def __call__(
|
|
163
|
+
self,
|
|
164
|
+
inputs: Float[Array, " input_dim"],
|
|
165
|
+
params: PyTree,
|
|
166
|
+
*args,
|
|
167
|
+
**kwargs,
|
|
168
|
+
) -> Float[Array, " output_dim"]: ...
|
|
169
|
+
|
|
170
|
+
@_PyTree_to_Params
|
|
157
171
|
def __call__(
|
|
158
172
|
self,
|
|
159
|
-
inputs: Float[Array, "input_dim"],
|
|
160
|
-
params: Params
|
|
173
|
+
inputs: Float[Array, " input_dim"],
|
|
174
|
+
params: Params[Array],
|
|
161
175
|
*args,
|
|
162
176
|
**kwargs,
|
|
163
|
-
) -> Float[Array, "output_dim"]:
|
|
177
|
+
) -> Float[Array, " output_dim"]:
|
|
164
178
|
"""
|
|
165
179
|
A proper __call__ implementation performs an eqx.combine here with
|
|
166
180
|
`params` and `self.static` to recreate the callable eqx.Module
|
|
167
181
|
architecture. The rest of the content of this function is dependent on
|
|
168
182
|
the network.
|
|
183
|
+
|
|
184
|
+
Note that thanks to the decorator, params can also directly be the
|
|
185
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
169
186
|
"""
|
|
170
187
|
|
|
171
188
|
if len(inputs.shape) == 0:
|
|
@@ -174,10 +191,7 @@ class PINN(eqx.Module):
|
|
|
174
191
|
# DataGenerators)
|
|
175
192
|
inputs = inputs[None]
|
|
176
193
|
|
|
177
|
-
|
|
178
|
-
model = eqx.combine(params.nn_params, self.static)
|
|
179
|
-
except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
180
|
-
model = eqx.combine(params, self.static)
|
|
194
|
+
model = eqx.combine(params.nn_params, self.static)
|
|
181
195
|
|
|
182
196
|
# evaluate the model
|
|
183
197
|
res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)
|
jinns/nn/_ppinn.py
CHANGED
|
@@ -2,17 +2,20 @@
|
|
|
2
2
|
Implements utility function to create PINNs
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Literal, Self, cast, overload
|
|
6
8
|
from dataclasses import InitVar
|
|
7
9
|
import jax
|
|
8
10
|
import jax.numpy as jnp
|
|
9
11
|
import equinox as eqx
|
|
10
12
|
|
|
11
|
-
from jaxtyping import Array, Key,
|
|
13
|
+
from jaxtyping import Array, Key, Float, PyTree
|
|
12
14
|
|
|
13
|
-
from jinns.parameters._params import Params
|
|
15
|
+
from jinns.parameters._params import Params
|
|
14
16
|
from jinns.nn._pinn import PINN
|
|
15
17
|
from jinns.nn._mlp import MLP
|
|
18
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
16
19
|
|
|
17
20
|
|
|
18
21
|
class PPINN_MLP(PINN):
|
|
@@ -39,12 +42,12 @@ class PPINN_MLP(PINN):
|
|
|
39
42
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
40
43
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
41
44
|
after the `input_transform` function.
|
|
42
|
-
input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
|
|
45
|
+
input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
43
46
|
A function that will be called before entering the PPINN. Its output(s)
|
|
44
47
|
must match the PPINN inputs (except for the parameters).
|
|
45
48
|
Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
|
|
46
49
|
and the parameters. Default is no operation.
|
|
47
|
-
output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
|
|
50
|
+
output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
48
51
|
A function with arguments being the same input as the PPINN, the PPINN
|
|
49
52
|
output and the parameters. This function will be called after exiting
|
|
50
53
|
the PPINN.
|
|
@@ -63,25 +66,46 @@ class PPINN_MLP(PINN):
|
|
|
63
66
|
"""
|
|
64
67
|
|
|
65
68
|
eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
|
|
69
|
+
init_params: tuple[PINN, ...] = eqx.field(
|
|
70
|
+
init=False
|
|
71
|
+
) # overriding parent attribute type
|
|
72
|
+
static: tuple[PINN, ...] = eqx.field(
|
|
73
|
+
init=False, static=True
|
|
74
|
+
) # overriding parent attribute type
|
|
66
75
|
|
|
67
76
|
def __post_init__(self, eqx_network, eqx_network_list):
|
|
68
77
|
super().__post_init__(
|
|
69
78
|
eqx_network=eqx_network_list[0], # this is not used since it is
|
|
70
79
|
# overwritten just below
|
|
71
80
|
)
|
|
72
|
-
|
|
73
|
-
|
|
81
|
+
params, static = eqx.partition(eqx_network_list[0], self.filter_spec)
|
|
82
|
+
self.init_params, self.static = (params,), (static,)
|
|
83
|
+
for eqx_network_ in eqx_network_list[1:]:
|
|
74
84
|
params, static = eqx.partition(eqx_network_, self.filter_spec)
|
|
75
85
|
self.init_params = self.init_params + (params,)
|
|
76
86
|
self.static = self.static + (static,)
|
|
77
87
|
|
|
88
|
+
@overload
|
|
89
|
+
@_PyTree_to_Params
|
|
78
90
|
def __call__(
|
|
79
91
|
self,
|
|
80
|
-
inputs: Float[Array, "
|
|
92
|
+
inputs: Float[Array, " input_dim"],
|
|
81
93
|
params: PyTree,
|
|
82
|
-
|
|
94
|
+
*args,
|
|
95
|
+
**kwargs,
|
|
96
|
+
) -> Float[Array, " output_dim"]: ...
|
|
97
|
+
|
|
98
|
+
@_PyTree_to_Params
|
|
99
|
+
def __call__(
|
|
100
|
+
self,
|
|
101
|
+
inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
102
|
+
params: Params[Array],
|
|
103
|
+
) -> Float[Array, " output_dim"]:
|
|
83
104
|
"""
|
|
84
105
|
Evaluate the PPINN on some inputs with some params.
|
|
106
|
+
|
|
107
|
+
Note that thanks to the decorator, params can also directly be the
|
|
108
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
85
109
|
"""
|
|
86
110
|
if len(inputs.shape) == 0:
|
|
87
111
|
# This can happen often when the user directly provides some
|
|
@@ -92,14 +116,14 @@ class PPINN_MLP(PINN):
|
|
|
92
116
|
|
|
93
117
|
outs = []
|
|
94
118
|
|
|
95
|
-
try:
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
except (KeyError, AttributeError, TypeError) as e:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
119
|
+
# try:
|
|
120
|
+
for params_, static in zip(params.nn_params, self.static):
|
|
121
|
+
model = eqx.combine(params_, static)
|
|
122
|
+
outs += [model(transformed_inputs)] # type: ignore
|
|
123
|
+
# except (KeyError, AttributeError, TypeError) as e:
|
|
124
|
+
# for params_, static in zip(params, self.static):
|
|
125
|
+
# model = eqx.combine(params_, static)
|
|
126
|
+
# outs += [model(transformed_inputs)]
|
|
103
127
|
# Note that below is then a global output transform
|
|
104
128
|
res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)
|
|
105
129
|
|
|
@@ -112,18 +136,31 @@ class PPINN_MLP(PINN):
|
|
|
112
136
|
def create(
|
|
113
137
|
cls,
|
|
114
138
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
115
|
-
eqx_network_list: list[eqx.nn.MLP] = None,
|
|
139
|
+
eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
|
|
116
140
|
key: Key = None,
|
|
117
|
-
eqx_list_list:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
141
|
+
eqx_list_list: (
|
|
142
|
+
list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
|
|
143
|
+
) = None,
|
|
144
|
+
input_transform: (
|
|
145
|
+
Callable[
|
|
146
|
+
[Float[Array, " input_dim"], Params[Array]],
|
|
147
|
+
Float[Array, " output_dim"],
|
|
148
|
+
]
|
|
149
|
+
| None
|
|
150
|
+
) = None,
|
|
151
|
+
output_transform: (
|
|
152
|
+
Callable[
|
|
153
|
+
[
|
|
154
|
+
Float[Array, " input_dim"],
|
|
155
|
+
Float[Array, " output_dim"],
|
|
156
|
+
Params[Array],
|
|
157
|
+
],
|
|
158
|
+
Float[Array, " output_dim"],
|
|
159
|
+
]
|
|
160
|
+
| None
|
|
161
|
+
) = None,
|
|
162
|
+
slice_solution: slice | None = None,
|
|
163
|
+
) -> tuple[Self, tuple[PINN, ...]]:
|
|
127
164
|
r"""
|
|
128
165
|
Utility function to create a Parallel PINN neural network for Jinns.
|
|
129
166
|
|
|
@@ -189,15 +226,14 @@ class PPINN_MLP(PINN):
|
|
|
189
226
|
eqx_network_list = []
|
|
190
227
|
for eqx_list in eqx_list_list:
|
|
191
228
|
key, subkey = jax.random.split(key, 2)
|
|
192
|
-
print(subkey)
|
|
193
229
|
eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
|
|
194
230
|
|
|
195
231
|
ppinn = cls(
|
|
196
|
-
eqx_network=None,
|
|
197
|
-
eqx_network_list=eqx_network_list,
|
|
198
|
-
slice_solution=slice_solution,
|
|
232
|
+
eqx_network=None, # type: ignore
|
|
233
|
+
eqx_network_list=cast(list[eqx.Module], eqx_network_list),
|
|
234
|
+
slice_solution=slice_solution, # type: ignore
|
|
199
235
|
eq_type=eq_type,
|
|
200
|
-
input_transform=input_transform,
|
|
201
|
-
output_transform=output_transform,
|
|
236
|
+
input_transform=input_transform, # type: ignore
|
|
237
|
+
output_transform=output_transform, # type: ignore
|
|
202
238
|
)
|
|
203
239
|
return ppinn, ppinn.init_params
|
jinns/nn/_save_load.py
CHANGED
|
@@ -12,7 +12,7 @@ from jinns.nn._spinn import SPINN
|
|
|
12
12
|
from jinns.nn._mlp import PINN_MLP
|
|
13
13
|
from jinns.nn._spinn_mlp import SPINN_MLP
|
|
14
14
|
from jinns.nn._hyperpinn import HyperPINN
|
|
15
|
-
from jinns.parameters._params import Params
|
|
15
|
+
from jinns.parameters._params import Params
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
def function_to_string(
|
|
@@ -87,7 +87,7 @@ def string_to_function(
|
|
|
87
87
|
def save_pinn(
|
|
88
88
|
filename: str,
|
|
89
89
|
u: PINN | HyperPINN | SPINN,
|
|
90
|
-
params: Params
|
|
90
|
+
params: Params,
|
|
91
91
|
kwargs_creation,
|
|
92
92
|
):
|
|
93
93
|
"""
|
|
@@ -105,15 +105,12 @@ def save_pinn(
|
|
|
105
105
|
tree_serialise_leaves`).
|
|
106
106
|
|
|
107
107
|
Equation parameters are saved apart because the initial type of attribute
|
|
108
|
-
`params` in PINN / HyperPINN / SPINN is not `Params`
|
|
108
|
+
`params` in PINN / HyperPINN / SPINN is not `Params`
|
|
109
109
|
but `PyTree` as inherited from `eqx.partition`.
|
|
110
110
|
Therefore, if we want to ensure a proper serialization/deserialization:
|
|
111
111
|
- we cannot save a `Params` object at this
|
|
112
112
|
attribute field ; the `Params` object must be split into `Params.nn_params`
|
|
113
113
|
(type `PyTree`) and `Params.eq_params` (type `dict`).
|
|
114
|
-
- in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
|
|
115
|
-
the attribute field `params` because it is not a `PyTree` (as expected in
|
|
116
|
-
the PINN / HyperPINN / SPINN signature) but it is still a dictionary.
|
|
117
114
|
|
|
118
115
|
Parameters
|
|
119
116
|
----------
|
|
@@ -122,28 +119,16 @@ def save_pinn(
|
|
|
122
119
|
u
|
|
123
120
|
The PINN
|
|
124
121
|
params
|
|
125
|
-
Params
|
|
122
|
+
Params to be saved
|
|
126
123
|
kwargs_creation
|
|
127
124
|
The dictionary of arguments that were used to create the PINN, e.g.
|
|
128
125
|
the layers list, O/PDE type, etc.
|
|
129
126
|
"""
|
|
130
|
-
if isinstance(
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
eqx.tree_serialise_leaves(filename + "-module.eqx", u)
|
|
136
|
-
|
|
137
|
-
elif isinstance(params, ParamsDict):
|
|
138
|
-
for key, params_ in params.nn_params.items():
|
|
139
|
-
if isinstance(u, HyperPINN):
|
|
140
|
-
u = eqx.tree_at(lambda m: m.init_params_hyper, u, params_)
|
|
141
|
-
elif isinstance(u, (PINN, SPINN)):
|
|
142
|
-
u = eqx.tree_at(lambda m: m.init_params, u, params_)
|
|
143
|
-
eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
|
|
144
|
-
|
|
145
|
-
else:
|
|
146
|
-
raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
|
|
127
|
+
if isinstance(u, HyperPINN):
|
|
128
|
+
u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
|
|
129
|
+
elif isinstance(u, (PINN, SPINN)):
|
|
130
|
+
u = eqx.tree_at(lambda m: m.init_params, u, params)
|
|
131
|
+
eqx.tree_serialise_leaves(filename + "-module.eqx", u)
|
|
147
132
|
|
|
148
133
|
with open(filename + "-eq_params.pkl", "wb") as f:
|
|
149
134
|
pickle.dump(params.eq_params, f)
|
|
@@ -170,8 +155,7 @@ def save_pinn(
|
|
|
170
155
|
def load_pinn(
|
|
171
156
|
filename: str,
|
|
172
157
|
type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
|
|
173
|
-
|
|
174
|
-
) -> tuple[eqx.Module, Params | ParamsDict]:
|
|
158
|
+
) -> tuple[eqx.Module, Params]:
|
|
175
159
|
"""
|
|
176
160
|
Load a PINN model. This function needs to access 3 files :
|
|
177
161
|
`{filename}-module.eqx`, `{filename}-parameters.pkl` and
|
|
@@ -190,8 +174,6 @@ def load_pinn(
|
|
|
190
174
|
Filename (prefix) without extension.
|
|
191
175
|
type_
|
|
192
176
|
Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn_mlp"].
|
|
193
|
-
key_list_for_paramsdict
|
|
194
|
-
Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
|
|
195
177
|
|
|
196
178
|
Returns
|
|
197
179
|
-------
|
|
@@ -228,29 +210,17 @@ def load_pinn(
|
|
|
228
210
|
)
|
|
229
211
|
else:
|
|
230
212
|
raise ValueError(f"{type_} is not valid")
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
213
|
+
# now the empty structure is populated with the actual saved array values
|
|
214
|
+
# stored in the eqx file
|
|
215
|
+
u_reloaded = eqx.tree_deserialise_leaves(
|
|
216
|
+
filename + "-module.eqx", u_reloaded_shallow
|
|
217
|
+
)
|
|
218
|
+
if isinstance(u_reloaded, HyperPINN):
|
|
219
|
+
params = Params(
|
|
220
|
+
nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
|
|
236
221
|
)
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
|
|
240
|
-
)
|
|
241
|
-
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
242
|
-
params = Params(
|
|
243
|
-
nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded
|
|
244
|
-
)
|
|
222
|
+
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
223
|
+
params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
|
|
245
224
|
else:
|
|
246
|
-
|
|
247
|
-
for key in key_list_for_paramsdict:
|
|
248
|
-
u_reloaded = eqx.tree_deserialise_leaves(
|
|
249
|
-
filename + f"-module_{key}.eqx", u_reloaded_shallow
|
|
250
|
-
)
|
|
251
|
-
if isinstance(u_reloaded, HyperPINN):
|
|
252
|
-
nn_params_dict[key] = u_reloaded.init_params_hyper
|
|
253
|
-
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
254
|
-
nn_params_dict[key] = u_reloaded.init_params
|
|
255
|
-
params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
|
|
225
|
+
raise ValueError("Wrong type for u_reloaded")
|
|
256
226
|
return u_reloaded, params
|
jinns/nn/_spinn.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Union, Callable, Any, Literal, overload
|
|
2
3
|
from dataclasses import InitVar
|
|
3
4
|
from jaxtyping import PyTree, Float, Array
|
|
4
5
|
import jax
|
|
5
6
|
import jax.numpy as jnp
|
|
6
7
|
import equinox as eqx
|
|
7
8
|
|
|
8
|
-
from jinns.parameters._params import Params
|
|
9
|
+
from jinns.parameters._params import Params
|
|
10
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
11
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
9
12
|
|
|
10
13
|
|
|
11
|
-
class SPINN(
|
|
14
|
+
class SPINN(AbstractPINN):
|
|
12
15
|
"""
|
|
13
16
|
A Separable PINN object compatible with the rest of jinns.
|
|
14
17
|
|
|
@@ -47,21 +50,21 @@ class SPINN(eqx.Module):
|
|
|
47
50
|
|
|
48
51
|
"""
|
|
49
52
|
|
|
53
|
+
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
|
|
54
|
+
static=True, kw_only=True
|
|
55
|
+
)
|
|
50
56
|
d: int = eqx.field(static=True, kw_only=True)
|
|
51
57
|
r: int = eqx.field(static=True, kw_only=True)
|
|
52
|
-
eq_type: str = eqx.field(static=True, kw_only=True)
|
|
53
58
|
m: int = eqx.field(static=True, kw_only=True, default=1)
|
|
54
|
-
|
|
55
59
|
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
|
|
56
60
|
static=True, kw_only=True, default=None
|
|
57
61
|
)
|
|
58
62
|
eqx_spinn_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
59
63
|
|
|
60
|
-
init_params:
|
|
61
|
-
static:
|
|
64
|
+
init_params: SPINN = eqx.field(init=False)
|
|
65
|
+
static: SPINN = eqx.field(init=False, static=True)
|
|
62
66
|
|
|
63
67
|
def __post_init__(self, eqx_spinn_network):
|
|
64
|
-
|
|
65
68
|
if self.filter_spec is None:
|
|
66
69
|
self.filter_spec = eqx.is_inexact_array
|
|
67
70
|
|
|
@@ -69,20 +72,34 @@ class SPINN(eqx.Module):
|
|
|
69
72
|
eqx_spinn_network, self.filter_spec
|
|
70
73
|
)
|
|
71
74
|
|
|
75
|
+
@overload
|
|
76
|
+
@_PyTree_to_Params
|
|
72
77
|
def __call__(
|
|
73
78
|
self,
|
|
74
|
-
|
|
75
|
-
params:
|
|
76
|
-
|
|
79
|
+
inputs: Float[Array, " input_dim"],
|
|
80
|
+
params: PyTree,
|
|
81
|
+
*args,
|
|
82
|
+
**kwargs,
|
|
83
|
+
) -> Float[Array, " output_dim"]: ...
|
|
84
|
+
|
|
85
|
+
@_PyTree_to_Params
|
|
86
|
+
def __call__(
|
|
87
|
+
self,
|
|
88
|
+
t_x: Float[Array, " batch_size 1+dim"],
|
|
89
|
+
params: Params[Array],
|
|
90
|
+
) -> Float[Array, " output_dim"]:
|
|
77
91
|
"""
|
|
78
92
|
Evaluate the SPINN on some inputs with some params.
|
|
93
|
+
|
|
94
|
+
Note that thanks to the decorator, params can also directly be the
|
|
95
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
79
96
|
"""
|
|
80
|
-
try:
|
|
81
|
-
|
|
82
|
-
except (KeyError, AttributeError, TypeError) as e:
|
|
83
|
-
|
|
97
|
+
# try:
|
|
98
|
+
spinn = eqx.combine(params.nn_params, self.static)
|
|
99
|
+
# except (KeyError, AttributeError, TypeError) as e:
|
|
100
|
+
# spinn = eqx.combine(params, self.static)
|
|
84
101
|
v_model = jax.vmap(spinn)
|
|
85
|
-
res = v_model(t_x)
|
|
102
|
+
res = v_model(t_x) # type: ignore
|
|
86
103
|
|
|
87
104
|
a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
|
|
88
105
|
b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
|