jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/nn/_pinn.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implement abstract class for PINN architectures
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Union, Any, Literal, overload
|
|
8
|
+
from dataclasses import InitVar
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
from jaxtyping import Float, Array, PyTree
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from jinns.parameters._params import Params
|
|
13
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
14
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PINN(AbstractPINN):
    r"""
    Base class for PINN objects. It can be seen as a wrapper on
    an `eqx.Module` which actually implements the NN architecture, with extra
    arguments handling the "physics-informed" aspect.

    !!! Note
        We use the `eqx.partition` and `eqx.combine` strategy of Equinox: a
        `filter_spec` is applied on the PyTree and splits it into two PyTrees
        with the same structure: a static one (invisible to JAX transforms
        such as JIT, grad, etc.) and a dynamic one. By convention, anything
        not static is considered a parameter in Jinns.

    For compatibility with jinns, we require that a `PINN` architecture:

    1) has an eqx.Module (`eqx_network`) InitVar passed to __post_init__
    representing the network architecture.
    2) calls `eqx.partition` in __post_init__ in order to store the
    static part of the model and the initial parameters.
    3) has an `eq_type` argument, used for handling internal operations in
    jinns.
    4) has a `slice_solution` argument. It is a `jnp.s\_` object which
    indicates which axis of the PINN output is dedicated to the actual equation
    solution. Default None means that slice_solution = the whole PINN output.
    For example, this argument is useful when the PINN is also used to output
    equation parameters. Note that it must be a slice and not an integer (a
    preprocessing of the user provided argument takes care of it).

    Parameters
    ----------
    slice_solution : slice
        Default is None, which is replaced by `jnp.s\_[:]`, i.e. the whole
        PINN output. A `jnp.s\_` object which indicates which axis of the PINN
        output is dedicated to the actual equation solution. This argument is
        useful when the PINN is also used to output equation parameters for
        example. If an integer `i` is provided it is converted to the slice
        `i:i+1` (or `-1:` when `i == -1`) so that the indexed axis does not
        disappear. Note that `__call__` does not apply this slice itself.
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is the identity on the inputs.
    output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function whose arguments are the same inputs as the PINN, the
        (squeezed) PINN output and the parameters. This function is called
        after exiting the PINN. Default returns the PINN output unchanged.
    eqx_network : eqx.Module
        The actual neural network instantiated as an eqx.Module.
    filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
        Default is `eqx.is_inexact_array`. This tells Jinns what to consider as
        a trainable parameter. Quoting from equinox documentation:
        a PyTree whose structure should be a prefix of the structure of pytree.
        Each of its leaves should either be 1) True, in which case the leaf or
        subtree is kept; 2) False, in which case the leaf or subtree is
        replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.


    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    """

    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
        static=True, kw_only=True
    )
    slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
    input_transform: Callable[
        [Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]
    ] = eqx.field(static=True, kw_only=True, default=None)
    output_transform: Callable[
        [Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]],
        Float[Array, " output_dim"],
    ] = eqx.field(static=True, kw_only=True, default=None)

    eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
    filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
        static=True, kw_only=True, default=eqx.is_inexact_array
    )

    # dynamic (trainable) part of `eqx_network`, as returned by eqx.partition
    init_params: PINN = eqx.field(init=False)
    # static (non-trainable) part of `eqx_network`, as returned by eqx.partition
    static: PINN = eqx.field(init=False, static=True)

    def __post_init__(self, eqx_network):
        """
        Validate `eq_type`, split `eqx_network` into parameters and static
        part, and replace the None defaults of the transforms and of
        `slice_solution` by concrete values.

        Raises
        ------
        RuntimeError
            If `eq_type` is not one of "ODE", "statio_PDE", "nonstatio_PDE".
        """
        if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
            raise RuntimeError("Wrong parameter value for eq_type")

        if self.filter_spec is None:
            self.filter_spec = eqx.is_inexact_array

        # saving the static part of the model and initial parameters
        self.init_params, self.static = eqx.partition(eqx_network, self.filter_spec)

        if self.input_transform is None:
            self.input_transform = lambda _in, _params: _in

        if self.output_transform is None:
            self.output_transform = lambda _in_pinn, _out_pinn, _params: _out_pinn

        if self.slice_solution is None:
            self.slice_solution = jnp.s_[:]

        if isinstance(self.slice_solution, int):
            # Rewrite it as a slice to ensure that the axis does not disappear
            # when indexing. For i == -1 the naive `i:i+1` would be `-1:0`,
            # which is an empty slice, so the stop must be open-ended instead.
            start = self.slice_solution
            stop = None if start == -1 else start + 1
            self.slice_solution = jnp.s_[start:stop]

    def eval(self, network, inputs, *args, **kwargs):
        """How to call your Equinox module `network`. The purpose of this method
        is to give more flexibility: users should re-implement `eval`
        when inheriting from `PINN` if they desire more flexibility on how to
        evaluate the network.

        Defaults to using `network.__call__(inputs)` but it could be more refined *e.g.* `network.anymethod(inputs)`.

        Parameters
        ----------
        network : eqx.Module
            Your neural network with the parameters set, usually returned by
            `eqx.combine(self.static, current_params)`.
        inputs : Array
            The inputs, eventually transformed by `self.input_transform` if
            specified by the user.

        Returns
        -------
        Array
            The output
        """

        return network(inputs)

    @overload
    @_PyTree_to_Params
    def __call__(
        self,
        inputs: Float[Array, " input_dim"],
        params: PyTree,
        *args,
        **kwargs,
    ) -> Float[Array, " output_dim"]: ...

    @_PyTree_to_Params
    def __call__(
        self,
        inputs: Float[Array, " input_dim"],
        params: Params[Array],
        *args,
        **kwargs,
    ) -> Float[Array, " output_dim"]:
        """
        Evaluate the PINN on `inputs` with `params`.

        The callable network is recreated with
        `eqx.combine(params.nn_params, self.static)`. The network is then
        evaluated through `self.eval` on the transformed inputs, its output is
        squeezed and passed to `self.output_transform`. A 0-d result is
        expanded to shape `(1,)`.

        Note that thanks to the decorator, params can also directly be the
        PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine.

        Extra positional and keyword arguments are forwarded to `self.eval`.
        """

        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]

        model = eqx.combine(params.nn_params, self.static)

        # evaluate the model
        res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)

        res = self.output_transform(inputs, res.squeeze(), params)

        # force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
jinns/nn/_ppinn.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements utility function to create PINNs
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Literal, Self, cast, overload
|
|
8
|
+
from dataclasses import InitVar
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
import equinox as eqx
|
|
12
|
+
|
|
13
|
+
from jaxtyping import Array, Key, Float, PyTree
|
|
14
|
+
|
|
15
|
+
from jinns.parameters._params import Params
|
|
16
|
+
from jinns.nn._pinn import PINN
|
|
17
|
+
from jinns.nn._mlp import MLP
|
|
18
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class PPINN_MLP(PINN):
    r"""
    A PPINN MLP (Parallel PINN with MLPs) object which mimics the PFNN
    architecture from DeepXDE. This is in fact a PINN MLP that encompasses
    several PINN MLPs internally: every sub-network receives the same
    (transformed) inputs and their outputs are concatenated.

    Parameters
    ----------
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PPINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PPINN output. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PPINN is called with one input `t`.
        "statio_PDE": the PPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function that will be called before entering the PPINN. Its output(s)
        must match the PPINN inputs (except for the parameters).
        Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function whose arguments are the same inputs as the PPINN, the
        concatenated PPINN output and the parameters. This function will be
        called after exiting the PPINN.
        Default is no operation.
    filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
        Default is `eqx.is_inexact_array`. This tells Jinns what to consider as
        a trainable parameter. Quoting from equinox documentation:
        a PyTree whose structure should be a prefix of the structure of pytree.
        Each of its leaves should either be 1) True, in which case the leaf or
        subtree is kept; 2) False, in which case the leaf or subtree is
        replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
    eqx_network_list
        A non-empty list of eqx.nn.MLP objects with the same input
        dimensions. They represent the parallel subnetworks of the PPINN MLP.
        Their respective outputs are concatenated.

    Raises
    ------
    ValueError
        If `eqx_network_list` is empty.
    """

    eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
    init_params: tuple[PINN, ...] = eqx.field(
        init=False
    )  # overriding parent attribute type
    static: tuple[PINN, ...] = eqx.field(
        init=False, static=True
    )  # overriding parent attribute type

    def __post_init__(self, eqx_network, eqx_network_list):
        """
        Run the parent initialisation, then store one (params, static)
        partition per sub-network in `self.init_params` and `self.static`.
        """
        if not eqx_network_list:
            raise ValueError("eqx_network_list must contain at least one network")

        super().__post_init__(
            # the parent partition of this network is overwritten just below;
            # the call is kept for the parent's validation and defaults
            eqx_network=eqx_network_list[0],
        )
        partitions = [
            eqx.partition(eqx_network_, self.filter_spec)
            for eqx_network_ in eqx_network_list
        ]
        self.init_params = tuple(params_ for params_, _ in partitions)
        self.static = tuple(static_ for _, static_ in partitions)

    @overload
    @_PyTree_to_Params
    def __call__(
        self,
        inputs: Float[Array, " input_dim"],
        params: PyTree,
        *args,
        **kwargs,
    ) -> Float[Array, " output_dim"]: ...

    @_PyTree_to_Params
    def __call__(
        self,
        inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
        params: Params[Array],
        *args,
        **kwargs,
    ) -> Float[Array, " output_dim"]:
        """
        Evaluate the PPINN on some inputs with some params.

        Each sub-network is recombined from its entry of `params.nn_params`
        and the matching entry of `self.static`, then evaluated via `self.eval`
        on the transformed inputs. The outputs (promoted to at least 1-d) are
        concatenated along axis 0 before the global `output_transform` is
        applied. Extra positional and keyword arguments are forwarded to
        `self.eval`.

        Note that thanks to the decorator, params can also directly be the
        PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine.

        Raises
        ------
        ValueError
            If the number of parameter sets in `params.nn_params` does not
            match the number of sub-networks.
        """
        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]
        transformed_inputs = self.input_transform(inputs, params)

        # strict=True: a mismatch between parameter sets and sub-networks must
        # not silently drop a sub-network from the output
        outs = [
            # atleast_1d so that scalar-output sub-networks can be concatenated
            jnp.atleast_1d(
                self.eval(
                    eqx.combine(params_, static_),
                    transformed_inputs,
                    *args,
                    **kwargs,
                )
            )
            for params_, static_ in zip(params.nn_params, self.static, strict=True)
        ]

        # Note that below is then a global output transform
        res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)

        # force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res

    @classmethod
    def create(
        cls,
        eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
        eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
        key: Key | None = None,
        eqx_list_list: (
            list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
        ) = None,
        input_transform: (
            Callable[
                [Float[Array, " input_dim"], Params[Array]],
                Float[Array, " output_dim"],
            ]
            | None
        ) = None,
        output_transform: (
            Callable[
                [
                    Float[Array, " input_dim"],
                    Float[Array, " output_dim"],
                    Params[Array],
                ],
                Float[Array, " output_dim"],
            ]
            | None
        ) = None,
        slice_solution: slice | None = None,
    ) -> tuple[Self, tuple[PINN, ...]]:
        r"""
        Utility function to create a Parallel PINN neural network for Jinns.

        Parameters
        ----------
        eq_type
            A string with three possibilities.
            "ODE": the PPINN MLP is called with one input `t`.
            "statio_PDE": the PPINN MLP is called with one input `x`, `x`
            can be high dimensional.
            "nonstatio_PDE": the PPINN MLP is called with two inputs `t` and `x`, `x`
            can be high dimensional.
            **Note**: the input dimension as given in eqx_list has to match the sum
            of the dimension of `t` + the dimension of `x` or the output dimension
            after the `input_transform` function.
        eqx_network_list
            Default is None. A list of eqx.nn.MLP objects with the same input
            dimensions. They represent the parallel subnetworks of the PPINN MLP.
            Their respective outputs are concatenated.
        key
            Default is None. Must be provided with `eqx_list_list` if
            `eqx_network_list` is not provided. A JAX random key that will be used
            to initialize the networks parameters.
        eqx_list_list
            Default is None. Must be provided if `eqx_network_list` is not
            provided. A list of `eqx_list` (see `PINN_MLP.create()`). The input dimension must be the
            same for each sub-`eqx_list`. Then the parallel subnetworks can be
            different. Their respective outputs are concatenated.
        input_transform
            A function that will be called before entering the PPINN MLP. Its output(s)
            must match the PPINN MLP inputs (except for the parameters).
            Its inputs are the PPINN MLP inputs (`t` and/or `x` concatenated together)
            and the parameters. Default is no operation.
        output_transform
            This function will be called after exiting
            the PPINN MLP, i.e., on the concatenated outputs of all parallel networks.
            Default is no operation.
        slice_solution
            A jnp.s\_ object which indicates which axis of the PPINN MLP output is
            dedicated to the actual equation solution. Default None
            means that slice_solution = the whole PPINN MLP output. This argument is
            useful when the PPINN MLP is also used to output equation parameters for
            example. Note that it must be a slice and not an integer (a
            preprocessing of the user provided argument takes care of it).


        Returns
        -------
        ppinn
            A PPINN MLP instance
        ppinn.init_params
            An initial set of parameters for the PPINN MLP

        Raises
        ------
        ValueError
            If `eqx_network_list` is None and `key` or `eqx_list_list` is
            missing, or if the resulting list of networks is empty.
        """

        if eqx_network_list is None:
            if eqx_list_list is None or key is None:
                raise ValueError(
                    "If eqx_network_list is None, then key and eqx_list_list"
                    " must be provided"
                )

            eqx_network_list = []
            for eqx_list in eqx_list_list:
                # a fresh subkey per sub-network so they are initialised
                # independently
                key, subkey = jax.random.split(key, 2)
                eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))

        ppinn = cls(
            eqx_network=None,  # type: ignore
            eqx_network_list=cast(list[eqx.Module], eqx_network_list),
            slice_solution=slice_solution,  # type: ignore
            eq_type=eq_type,
            input_transform=input_transform,  # type: ignore
            output_transform=output_transform,  # type: ignore
        )
        return ppinn, ppinn.init_params
|
jinns/{utils → nn}/_save_load.py
RENAMED
|
@@ -7,14 +7,16 @@ import pickle
|
|
|
7
7
|
import jax
|
|
8
8
|
import equinox as eqx
|
|
9
9
|
|
|
10
|
-
from jinns.
|
|
11
|
-
from jinns.
|
|
12
|
-
from jinns.
|
|
13
|
-
from jinns.
|
|
10
|
+
from jinns.nn._pinn import PINN
|
|
11
|
+
from jinns.nn._spinn import SPINN
|
|
12
|
+
from jinns.nn._mlp import PINN_MLP
|
|
13
|
+
from jinns.nn._spinn_mlp import SPINN_MLP
|
|
14
|
+
from jinns.nn._hyperpinn import HyperPINN
|
|
15
|
+
from jinns.parameters._params import Params
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def function_to_string(
|
|
17
|
-
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...]
|
|
19
|
+
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
|
|
18
20
|
) -> tuple[tuple[str, int, int] | str, ...]:
|
|
19
21
|
"""
|
|
20
22
|
We need this transformation for eqx_list to be pickled
|
|
@@ -40,7 +42,7 @@ def function_to_string(
|
|
|
40
42
|
|
|
41
43
|
|
|
42
44
|
def string_to_function(
|
|
43
|
-
eqx_list_with_string: tuple[tuple[str, int, int] | str, ...]
|
|
45
|
+
eqx_list_with_string: tuple[tuple[str, int, int] | str, ...],
|
|
44
46
|
) -> tuple[tuple[Callable, int, int] | Callable, ...]:
|
|
45
47
|
"""
|
|
46
48
|
We need this transformation for eqx_list at the loading ("unpickling")
|
|
@@ -84,8 +86,8 @@ def string_to_function(
|
|
|
84
86
|
|
|
85
87
|
def save_pinn(
|
|
86
88
|
filename: str,
|
|
87
|
-
u: PINN |
|
|
88
|
-
params: Params
|
|
89
|
+
u: PINN | HyperPINN | SPINN,
|
|
90
|
+
params: Params,
|
|
89
91
|
kwargs_creation,
|
|
90
92
|
):
|
|
91
93
|
"""
|
|
@@ -103,15 +105,12 @@ def save_pinn(
|
|
|
103
105
|
tree_serialise_leaves`).
|
|
104
106
|
|
|
105
107
|
Equation parameters are saved apart because the initial type of attribute
|
|
106
|
-
`params` in PINN /
|
|
108
|
+
`params` in PINN / HyperPINN / SPINN is not `Params`
|
|
107
109
|
but `PyTree` as inherited from `eqx.partition`.
|
|
108
110
|
Therefore, if we want to ensure a proper serialization/deserialization:
|
|
109
111
|
- we cannot save a `Params` object at this
|
|
110
112
|
attribute field ; the `Params` object must be split into `Params.nn_params`
|
|
111
113
|
(type `PyTree`) and `Params.eq_params` (type `dict`).
|
|
112
|
-
- in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
|
|
113
|
-
the attribute field `params` because it is not a `PyTree` (as expected in
|
|
114
|
-
the PINN / HYPERPINN / SPINN signature) but it is still a dictionary.
|
|
115
114
|
|
|
116
115
|
Parameters
|
|
117
116
|
----------
|
|
@@ -120,28 +119,16 @@ def save_pinn(
|
|
|
120
119
|
u
|
|
121
120
|
The PINN
|
|
122
121
|
params
|
|
123
|
-
Params
|
|
122
|
+
Params to be saved
|
|
124
123
|
kwargs_creation
|
|
125
124
|
The dictionary of arguments that were used to create the PINN, e.g.
|
|
126
125
|
the layers list, O/PDE type, etc.
|
|
127
126
|
"""
|
|
128
|
-
if isinstance(
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
eqx.tree_serialise_leaves(filename + "-module.eqx", u)
|
|
134
|
-
|
|
135
|
-
elif isinstance(params, ParamsDict):
|
|
136
|
-
for key, params_ in params.nn_params.items():
|
|
137
|
-
if isinstance(u, HYPERPINN):
|
|
138
|
-
u = eqx.tree_at(lambda m: m.params_hyper, u, params_)
|
|
139
|
-
elif isinstance(u, (PINN, SPINN)):
|
|
140
|
-
u = eqx.tree_at(lambda m: m.params, u, params_)
|
|
141
|
-
eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
|
|
142
|
-
|
|
143
|
-
else:
|
|
144
|
-
raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
|
|
127
|
+
if isinstance(u, HyperPINN):
|
|
128
|
+
u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
|
|
129
|
+
elif isinstance(u, (PINN, SPINN)):
|
|
130
|
+
u = eqx.tree_at(lambda m: m.init_params, u, params)
|
|
131
|
+
eqx.tree_serialise_leaves(filename + "-module.eqx", u)
|
|
145
132
|
|
|
146
133
|
with open(filename + "-eq_params.pkl", "wb") as f:
|
|
147
134
|
pickle.dump(params.eq_params, f)
|
|
@@ -167,9 +154,8 @@ def save_pinn(
|
|
|
167
154
|
|
|
168
155
|
def load_pinn(
|
|
169
156
|
filename: str,
|
|
170
|
-
type_: Literal["
|
|
171
|
-
|
|
172
|
-
) -> tuple[eqx.Module, Params | ParamsDict]:
|
|
157
|
+
type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
|
|
158
|
+
) -> tuple[eqx.Module, Params]:
|
|
173
159
|
"""
|
|
174
160
|
Load a PINN model. This function needs to access 3 files :
|
|
175
161
|
`{filename}-module.eqx`, `{filename}-parameters.pkl` and
|
|
@@ -187,9 +173,7 @@ def load_pinn(
|
|
|
187
173
|
filename
|
|
188
174
|
Filename (prefix) without extension.
|
|
189
175
|
type_
|
|
190
|
-
Type of model to load. Must be in ["
|
|
191
|
-
key_list_for_paramsdict
|
|
192
|
-
Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
|
|
176
|
+
Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn_mlp"].
|
|
193
177
|
|
|
194
178
|
Returns
|
|
195
179
|
-------
|
|
@@ -207,34 +191,36 @@ def load_pinn(
|
|
|
207
191
|
eq_params_reloaded = {}
|
|
208
192
|
print("No pickle file for equation parameters found!")
|
|
209
193
|
kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
|
|
210
|
-
if type_ == "
|
|
194
|
+
if type_ == "pinn_mlp":
|
|
211
195
|
# next line creates a shallow model, the jax arrays are just shapes and
|
|
212
196
|
# not populated, this just recreates the correct pytree structure
|
|
213
|
-
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
214
|
-
|
|
215
|
-
|
|
197
|
+
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
198
|
+
PINN_MLP.create, **kwargs_reloaded
|
|
199
|
+
)
|
|
200
|
+
elif type_ == "spinn_mlp":
|
|
201
|
+
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
202
|
+
SPINN_MLP.create, **kwargs_reloaded
|
|
203
|
+
)
|
|
216
204
|
elif type_ == "hyperpinn":
|
|
217
205
|
kwargs_reloaded["eqx_list_hyper"] = string_to_function(
|
|
218
206
|
kwargs_reloaded["eqx_list_hyper"]
|
|
219
207
|
)
|
|
220
208
|
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
221
|
-
|
|
209
|
+
HyperPINN.create, **kwargs_reloaded
|
|
222
210
|
)
|
|
223
211
|
else:
|
|
224
212
|
raise ValueError(f"{type_} is not valid")
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
213
|
+
# now the empty structure is populated with the actual saved array values
|
|
214
|
+
# stored in the eqx file
|
|
215
|
+
u_reloaded = eqx.tree_deserialise_leaves(
|
|
216
|
+
filename + "-module.eqx", u_reloaded_shallow
|
|
217
|
+
)
|
|
218
|
+
if isinstance(u_reloaded, HyperPINN):
|
|
219
|
+
params = Params(
|
|
220
|
+
nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
|
|
230
221
|
)
|
|
222
|
+
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
231
223
|
params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
|
|
232
224
|
else:
|
|
233
|
-
|
|
234
|
-
for key in key_list_for_paramsdict:
|
|
235
|
-
u_reloaded = eqx.tree_deserialise_leaves(
|
|
236
|
-
filename + f"-module_{key}.eqx", u_reloaded_shallow
|
|
237
|
-
)
|
|
238
|
-
nn_params_dict[key] = u_reloaded.init_params
|
|
239
|
-
params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
|
|
225
|
+
raise ValueError("Wrong type for u_reloaded")
|
|
240
226
|
return u_reloaded, params
|