jinns 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_DataGenerators.py +2 -2
- jinns/loss/_DynamicLoss.py +2 -2
- jinns/loss/_LossODE.py +1 -1
- jinns/loss/_LossPDE.py +75 -38
- jinns/loss/_boundary_conditions.py +2 -2
- jinns/loss/_loss_utils.py +21 -15
- jinns/loss/_operators.py +0 -2
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +39 -23
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +3 -3
- jinns/solver/_rar.py +3 -3
- jinns/solver/_solve.py +23 -9
- jinns/utils/__init__.py +0 -5
- jinns/utils/_types.py +4 -4
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/METADATA +9 -9
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
jinns/nn/_pinn.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implement abstract class for PINN architectures
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Literal, Callable, Union, Any
|
|
6
|
+
from dataclasses import InitVar
|
|
7
|
+
import equinox as eqx
|
|
8
|
+
from jaxtyping import Float, Array, PyTree
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PINN(eqx.Module):
    r"""
    Base class for PINN objects. It can be seen as a wrapper around an
    `eqx.Module` which actually implements the NN architecture, with extra
    arguments handling the "physics-informed" aspect.

    !!! Note
        We use the `eqx.partition` and `eqx.combine` strategy of Equinox: a
        `filter_spec` is applied on the PyTree and splits it into two PyTrees
        with the same structure: a static one (invisible to JAX transforms
        such as JIT, grad, etc.) and a dynamic one. By convention, anything
        not static is considered a parameter in Jinns.

    For compatibility with jinns, we require that a `PINN` architecture:

    1) has an eqx.Module (`eqx_network`) InitVar passed to __post_init__
    representing the network architecture.
    2) calls `eqx.partition` in __post_init__ in order to store the
    static part of the model and the initial parameters.
    3) has an `eq_type` argument, used for handling internal operations in
    jinns.
    4) has a `slice_solution` argument. It is a `jnp.s\_` object which
    indicates which axis of the PINN output is dedicated to the actual
    equation solution. Default None means that slice_solution = the whole
    PINN output. For example, this argument is useful when the PINN is also
    used to output equation parameters. Note that it must be a slice and not
    an integer (a preprocessing of the user provided argument takes care of
    it).

    Parameters
    ----------
    slice_solution : slice
        Default is None, which is replaced by `jnp.s\_[:]` (the whole PINN
        output). A `jnp.s\_` object which indicates which axis of the PINN
        output is dedicated to the actual equation solution. This argument is
        useful when the PINN is also used to output equation parameters, for
        example. An integer `i` is also accepted: it is converted to the
        slice `jnp.s\_[i:i+1]` so that the indexed axis does not disappear.
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension of `eqx_network` has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is None, i.e. no operation.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the PINN inputs, the (squeezed) PINN
        output and the parameters. This function will be called after exiting
        the PINN. Default is None, i.e. no operation.
    eqx_network : eqx.Module
        The actual neural network instantiated as an eqx.Module.
    filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
        Default is `eqx.is_inexact_array` (also used if None is passed). This
        tells Jinns what to consider as a trainable parameter. Quoting from
        equinox documentation:
        a PyTree whose structure should be a prefix of the structure of pytree.
        Each of its leaves should either be 1) True, in which case the leaf or
        subtree is kept; 2) False, in which case the leaf or subtree is
        replaced with replace; 3) a callable Leaf -> bool, in which case this
        is evaluated on the leaf or mapped over the subtree, and the leaf kept
        or replaced as appropriate.

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    """

    slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
        static=True, kw_only=True
    )
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ] = eqx.field(static=True, kw_only=True, default=None)
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ] = eqx.field(static=True, kw_only=True, default=None)

    eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
    filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
        static=True, kw_only=True, default=eqx.is_inexact_array
    )

    # Trainable part of `eqx_network` (as selected by `filter_spec`).
    init_params: PyTree = eqx.field(init=False)
    # Non-trainable part of `eqx_network`, recombined with params at call time.
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, eqx_network):
        """Validate `eq_type`, partition the network and fill in defaults."""

        if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
            raise RuntimeError("Wrong parameter value for eq_type")

        # An explicit `filter_spec=None` falls back to the documented default.
        if self.filter_spec is None:
            self.filter_spec = eqx.is_inexact_array

        # saving the static part of the model and initial parameters
        self.init_params, self.static = eqx.partition(eqx_network, self.filter_spec)

        if self.input_transform is None:
            self.input_transform = lambda _in, _params: _in

        if self.output_transform is None:
            self.output_transform = lambda _in_pinn, _out_pinn, _params: _out_pinn

        if self.slice_solution is None:
            self.slice_solution = jnp.s_[:]

        if isinstance(self.slice_solution, int):
            # rewrite it as a slice to ensure that axis does not disappear when
            # indexing
            self.slice_solution = jnp.s_[self.slice_solution : self.slice_solution + 1]

    def eval(self, network, inputs, *args, **kwargs):
        """How to call your Equinox module `network`. The purpose of this method
        is to give more flexibility: users should re-implement `eval`
        when inheriting from `PINN` if they desire more control over how to
        evaluate the network.

        Defaults to using `network.__call__(inputs)` but it could be more
        refined, *e.g.* `network.anymethod(inputs)`.

        Parameters
        ----------
        network : eqx.Module
            Your neural network with the parameters set, usually returned by
            `eqx.combine(self.static, current_params)`.
        inputs : Array
            The inputs, possibly transformed by `self.input_transform` if
            specified by the user.

        Returns
        -------
        Array
            The output
        """

        return network(inputs)

    def __call__(
        self,
        inputs: Float[Array, "input_dim"],
        params: Params | ParamsDict | PyTree,
        *args,
        **kwargs,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the PINN on `inputs` with `params`.

        The callable network is rebuilt with `eqx.combine` from the trainable
        parameters and `self.static`, evaluated through `self.eval` on the
        input-transformed inputs, then passed (squeezed) to
        `self.output_transform`. Subclasses may override this method; the
        rest of its content is dependent on the network.

        Parameters
        ----------
        inputs
            The PINN inputs. A 0-d array is promoted to shape `(1,)`.
        params
            Either an object exposing the network weights as `nn_params`
            (e.g. `Params`) or directly the network-parameter PyTree.
        *args, **kwargs
            Forwarded to `self.eval`.

        Returns
        -------
        Array
            The transformed network output, with at least one dimension.
        """

        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]

        # Only the attribute lookup is guarded: errors raised by `eqx.combine`
        # itself must surface rather than trigger a silent retry with `params`.
        try:
            nn_params = params.nn_params
        except (KeyError, AttributeError, TypeError):  # give more flexibility
            nn_params = params
        model = eqx.combine(nn_params, self.static)

        # evaluate the model
        res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)

        res = self.output_transform(inputs, res.squeeze(), params)

        # force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
jinns/nn/_ppinn.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements utility function to create PINNs
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Literal, Self
|
|
6
|
+
from dataclasses import InitVar
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
|
|
11
|
+
from jaxtyping import Array, Key, PyTree, Float
|
|
12
|
+
|
|
13
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
14
|
+
from jinns.nn._pinn import PINN
|
|
15
|
+
from jinns.nn._mlp import MLP
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PPINN_MLP(PINN):
    r"""
    A PPINN MLP (Parallel PINN with MLPs) object which mimics the PFNN
    architecture from DeepXDE. This is in fact a PINN MLP that encompasses
    several PINN MLPs internally: every subnetwork receives the same
    (transformed) inputs and their outputs are concatenated.

    Parameters
    ----------
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PPINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PPINN output. This argument is
        useful when the PINN is also used to output equation parameters, for
        example. Note that it must be a slice and not an integer (a
        preprocessing of the user provided argument takes care of it).
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PPINN is called with one input `t`.
        "statio_PDE": the PPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PPINN. Its output(s)
        must match the PPINN inputs (except for the parameters).
        Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the PPINN inputs, the PPINN output
        (the concatenated outputs of all subnetworks) and the parameters.
        This function will be called after exiting the PPINN.
        Default is no operation.
    filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
        Default is `eqx.is_inexact_array`. This tells Jinns what to consider as
        a trainable parameter. Quoting from equinox documentation:
        a PyTree whose structure should be a prefix of the structure of pytree.
        Each of its leaves should either be 1) True, in which case the leaf or
        subtree is kept; 2) False, in which case the leaf or subtree is
        replaced with replace; 3) a callable Leaf -> bool, in which case this
        is evaluated on the leaf or mapped over the subtree, and the leaf kept
        or replaced as appropriate.
    eqx_network_list
        A list of `eqx.Module` networks (e.g. MLPs) with the same input
        dimension. They represent the parallel subnetworks of the PPINN MLP.
        Their respective outputs are concatenated.
    """

    eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)

    def __post_init__(self, eqx_network, eqx_network_list):
        """Run the parent initialisation, then partition every subnetwork.

        `init_params` and `static` become tuples with one entry per
        subnetwork, in the order of `eqx_network_list`.
        """
        super().__post_init__(
            eqx_network=eqx_network_list[0],  # this is not used since it is
            # overwritten just below
        )
        self.init_params, self.static = (), ()
        for eqx_network_ in eqx_network_list:
            params, static = eqx.partition(eqx_network_, self.filter_spec)
            self.init_params = self.init_params + (params,)
            self.static = self.static + (static,)

    def __call__(
        self,
        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
        params: PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the PPINN on some inputs with some params.

        `params` is either an object exposing the per-subnetwork weights as
        `nn_params` (e.g. `Params`) or directly the tuple of per-subnetwork
        parameter PyTrees.
        """
        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]
        transformed_inputs = self.input_transform(inputs, params)

        # Only the attribute lookup is guarded: an error raised while
        # evaluating a subnetwork must surface instead of triggering a retry
        # with `params` (which would also leave `outs` partially filled).
        try:
            nn_params = params.nn_params
        except (KeyError, AttributeError, TypeError):
            nn_params = params

        outs = [
            eqx.combine(params_, static)(transformed_inputs)
            for params_, static in zip(nn_params, self.static)
        ]
        # Note that below is then a global output transform
        res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)

        # force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res

    @classmethod
    def create(
        cls,
        eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
        eqx_network_list: list[eqx.nn.MLP] = None,
        key: Key = None,
        eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]] = None,
        input_transform: Callable[
            [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
        ] = None,
        output_transform: Callable[
            [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
            Float[Array, "output_dim"],
        ] = None,
        slice_solution: slice = None,
    ) -> tuple[Self, PyTree]:
        r"""
        Utility function to create a Parallel PINN neural network for Jinns.

        Parameters
        ----------
        eq_type
            A string with three possibilities.
            "ODE": the PPINN MLP is called with one input `t`.
            "statio_PDE": the PPINN MLP is called with one input `x`, `x`
            can be high dimensional.
            "nonstatio_PDE": the PPINN MLP is called with two inputs `t` and `x`, `x`
            can be high dimensional.
            **Note**: the input dimension as given in eqx_list has to match the sum
            of the dimension of `t` + the dimension of `x` or the output dimension
            after the `input_transform` function.
        eqx_network_list
            Default is None. A list of networks with the same input
            dimension. They represent the parallel subnetworks of the PPINN
            MLP. Their respective outputs are concatenated.
        key
            Default is None. Must be provided with `eqx_list_list` if
            `eqx_network_list` is not provided. A JAX random key that will be
            split to initialize each subnetwork's parameters.
        eqx_list_list
            Default is None. Must be provided if `eqx_network_list` is not
            provided. A list of `eqx_list` (see `PINN_MLP.create()`). The input
            dimension must be the same for each sub-`eqx_list`. Then the
            parallel subnetworks can be different. Their respective outputs
            are concatenated.
        input_transform
            A function that will be called before entering the PPINN MLP. Its output(s)
            must match the PPINN MLP inputs (except for the parameters).
            Its inputs are the PPINN MLP inputs (`t` and/or `x` concatenated together)
            and the parameters. Default is no operation.
        output_transform
            This function will be called after exiting the PPINN MLP, i.e.,
            on the concatenated outputs of all parallel networks.
            Default is no operation.
        slice_solution
            A jnp.s\_ object which indicates which axis of the PPINN MLP output is
            dedicated to the actual equation solution. Default None
            means that slice_solution = the whole PPINN MLP output. This argument is
            useful when the PPINN MLP is also used to output equation parameters, for
            example. Note that it must be a slice and not an integer (a
            preprocessing of the user provided argument takes care of it).

        Returns
        -------
        ppinn
            A PPINN MLP instance
        ppinn.init_params
            An initial set of parameters for the PPINN MLP

        Raises
        ------
        ValueError
            If `eqx_network_list` is None and either `key` or `eqx_list_list`
            is None.
        """

        if eqx_network_list is None:
            if eqx_list_list is None or key is None:
                raise ValueError(
                    "If eqx_network_list is None, then key and eqx_list_list"
                    " must be provided"
                )

            eqx_network_list = []
            for eqx_list in eqx_list_list:
                # a fresh subkey per subnetwork so they are initialised
                # independently
                key, subkey = jax.random.split(key, 2)
                eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))

        ppinn = cls(
            eqx_network=None,
            eqx_network_list=eqx_network_list,
            slice_solution=slice_solution,
            eq_type=eq_type,
            input_transform=input_transform,
            output_transform=output_transform,
        )
        return ppinn, ppinn.init_params
|
jinns/{utils → nn}/_save_load.py
RENAMED
|
@@ -7,14 +7,16 @@ import pickle
|
|
|
7
7
|
import jax
|
|
8
8
|
import equinox as eqx
|
|
9
9
|
|
|
10
|
-
from jinns.
|
|
11
|
-
from jinns.
|
|
12
|
-
from jinns.
|
|
10
|
+
from jinns.nn._pinn import PINN
|
|
11
|
+
from jinns.nn._spinn import SPINN
|
|
12
|
+
from jinns.nn._mlp import PINN_MLP
|
|
13
|
+
from jinns.nn._spinn_mlp import SPINN_MLP
|
|
14
|
+
from jinns.nn._hyperpinn import HyperPINN
|
|
13
15
|
from jinns.parameters._params import Params, ParamsDict
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def function_to_string(
|
|
17
|
-
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...]
|
|
19
|
+
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
|
|
18
20
|
) -> tuple[tuple[str, int, int] | str, ...]:
|
|
19
21
|
"""
|
|
20
22
|
We need this transformation for eqx_list to be pickled
|
|
@@ -40,7 +42,7 @@ def function_to_string(
|
|
|
40
42
|
|
|
41
43
|
|
|
42
44
|
def string_to_function(
|
|
43
|
-
eqx_list_with_string: tuple[tuple[str, int, int] | str, ...]
|
|
45
|
+
eqx_list_with_string: tuple[tuple[str, int, int] | str, ...],
|
|
44
46
|
) -> tuple[tuple[Callable, int, int] | Callable, ...]:
|
|
45
47
|
"""
|
|
46
48
|
We need this transformation for eqx_list at the loading ("unpickling")
|
|
@@ -84,7 +86,7 @@ def string_to_function(
|
|
|
84
86
|
|
|
85
87
|
def save_pinn(
|
|
86
88
|
filename: str,
|
|
87
|
-
u: PINN |
|
|
89
|
+
u: PINN | HyperPINN | SPINN,
|
|
88
90
|
params: Params | ParamsDict,
|
|
89
91
|
kwargs_creation,
|
|
90
92
|
):
|
|
@@ -103,7 +105,7 @@ def save_pinn(
|
|
|
103
105
|
tree_serialise_leaves`).
|
|
104
106
|
|
|
105
107
|
Equation parameters are saved apart because the initial type of attribute
|
|
106
|
-
`params` in PINN /
|
|
108
|
+
`params` in PINN / HyperPINN / SPINN is not `Params` nor `ParamsDict`
|
|
107
109
|
but `PyTree` as inherited from `eqx.partition`.
|
|
108
110
|
Therefore, if we want to ensure a proper serialization/deserialization:
|
|
109
111
|
- we cannot save a `Params` object at this
|
|
@@ -111,7 +113,7 @@ def save_pinn(
|
|
|
111
113
|
(type `PyTree`) and `Params.eq_params` (type `dict`).
|
|
112
114
|
- in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
|
|
113
115
|
the attribute field `params` because it is not a `PyTree` (as expected in
|
|
114
|
-
the PINN /
|
|
116
|
+
the PINN / HyperPINN / SPINN signature) but it is still a dictionary.
|
|
115
117
|
|
|
116
118
|
Parameters
|
|
117
119
|
----------
|
|
@@ -126,18 +128,18 @@ def save_pinn(
|
|
|
126
128
|
the layers list, O/PDE type, etc.
|
|
127
129
|
"""
|
|
128
130
|
if isinstance(params, Params):
|
|
129
|
-
if isinstance(u,
|
|
130
|
-
u = eqx.tree_at(lambda m: m.
|
|
131
|
+
if isinstance(u, HyperPINN):
|
|
132
|
+
u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
|
|
131
133
|
elif isinstance(u, (PINN, SPINN)):
|
|
132
|
-
u = eqx.tree_at(lambda m: m.
|
|
134
|
+
u = eqx.tree_at(lambda m: m.init_params, u, params)
|
|
133
135
|
eqx.tree_serialise_leaves(filename + "-module.eqx", u)
|
|
134
136
|
|
|
135
137
|
elif isinstance(params, ParamsDict):
|
|
136
138
|
for key, params_ in params.nn_params.items():
|
|
137
|
-
if isinstance(u,
|
|
138
|
-
u = eqx.tree_at(lambda m: m.
|
|
139
|
+
if isinstance(u, HyperPINN):
|
|
140
|
+
u = eqx.tree_at(lambda m: m.init_params_hyper, u, params_)
|
|
139
141
|
elif isinstance(u, (PINN, SPINN)):
|
|
140
|
-
u = eqx.tree_at(lambda m: m.
|
|
142
|
+
u = eqx.tree_at(lambda m: m.init_params, u, params_)
|
|
141
143
|
eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
|
|
142
144
|
|
|
143
145
|
else:
|
|
@@ -167,7 +169,7 @@ def save_pinn(
|
|
|
167
169
|
|
|
168
170
|
def load_pinn(
|
|
169
171
|
filename: str,
|
|
170
|
-
type_: Literal["
|
|
172
|
+
type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
|
|
171
173
|
key_list_for_paramsdict: list[str] = None,
|
|
172
174
|
) -> tuple[eqx.Module, Params | ParamsDict]:
|
|
173
175
|
"""
|
|
@@ -187,7 +189,7 @@ def load_pinn(
|
|
|
187
189
|
filename
|
|
188
190
|
Filename (prefix) without extension.
|
|
189
191
|
type_
|
|
190
|
-
Type of model to load. Must be in ["
|
|
192
|
+
Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn"].
|
|
191
193
|
key_list_for_paramsdict
|
|
192
194
|
Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
|
|
193
195
|
|
|
@@ -207,18 +209,22 @@ def load_pinn(
|
|
|
207
209
|
eq_params_reloaded = {}
|
|
208
210
|
print("No pickle file for equation parameters found!")
|
|
209
211
|
kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
|
|
210
|
-
if type_ == "
|
|
212
|
+
if type_ == "pinn_mlp":
|
|
211
213
|
# next line creates a shallow model, the jax arrays are just shapes and
|
|
212
214
|
# not populated, this just recreates the correct pytree structure
|
|
213
|
-
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
214
|
-
|
|
215
|
-
|
|
215
|
+
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
216
|
+
PINN_MLP.create, **kwargs_reloaded
|
|
217
|
+
)
|
|
218
|
+
elif type_ == "spinn_mlp":
|
|
219
|
+
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
220
|
+
SPINN_MLP.create, **kwargs_reloaded
|
|
221
|
+
)
|
|
216
222
|
elif type_ == "hyperpinn":
|
|
217
223
|
kwargs_reloaded["eqx_list_hyper"] = string_to_function(
|
|
218
224
|
kwargs_reloaded["eqx_list_hyper"]
|
|
219
225
|
)
|
|
220
226
|
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
221
|
-
|
|
227
|
+
HyperPINN.create, **kwargs_reloaded
|
|
222
228
|
)
|
|
223
229
|
else:
|
|
224
230
|
raise ValueError(f"{type_} is not valid")
|
|
@@ -228,13 +234,23 @@ def load_pinn(
|
|
|
228
234
|
u_reloaded = eqx.tree_deserialise_leaves(
|
|
229
235
|
filename + "-module.eqx", u_reloaded_shallow
|
|
230
236
|
)
|
|
231
|
-
|
|
237
|
+
if isinstance(u_reloaded, HyperPINN):
|
|
238
|
+
params = Params(
|
|
239
|
+
nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
|
|
240
|
+
)
|
|
241
|
+
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
242
|
+
params = Params(
|
|
243
|
+
nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded
|
|
244
|
+
)
|
|
232
245
|
else:
|
|
233
246
|
nn_params_dict = {}
|
|
234
247
|
for key in key_list_for_paramsdict:
|
|
235
248
|
u_reloaded = eqx.tree_deserialise_leaves(
|
|
236
249
|
filename + f"-module_{key}.eqx", u_reloaded_shallow
|
|
237
250
|
)
|
|
238
|
-
|
|
251
|
+
if isinstance(u_reloaded, HyperPINN):
|
|
252
|
+
nn_params_dict[key] = u_reloaded.init_params_hyper
|
|
253
|
+
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
254
|
+
nn_params_dict[key] = u_reloaded.init_params
|
|
239
255
|
params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
|
|
240
256
|
return u_reloaded, params
|
jinns/nn/_spinn.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from typing import Union, Callable, Any
|
|
2
|
+
from dataclasses import InitVar
|
|
3
|
+
from jaxtyping import PyTree, Float, Array
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import equinox as eqx
|
|
7
|
+
|
|
8
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SPINN(eqx.Module):
    """
    Separable PINN (SPINN) wrapper usable with the rest of jinns.

    The wrapped network is vmapped over a batch of points. For each of the
    separated input dimensions it is expected to return an embedding of size
    `r * m`. Each of the `m` outputs uses its own contiguous group of `r`
    embedding coordinates; the groups of all dimensions are multiplied and
    summed over the embedding index (see the Separable PINN article).

    Parameters
    ----------
    d : int
        Number of input dimensions treated separately, time `t` included
        for non-stationary equations.
    r : int
        Size of the embedding group associated with each output.
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x`.
    m : int
        Output dimension of the SPINN; the total embedding size is `r * m`.
        Default is 1.
    filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
        Default is None, which is replaced by `eqx.is_inexact_array`. This
        tells Jinns what to consider as a trainable parameter. Quoting from
        equinox documentation:
        a PyTree whose structure should be a prefix of the structure of pytree.
        Each of its leaves should either be 1) True, in which case the leaf or
        subtree is kept; 2) False, in which case the leaf or subtree is
        replaced with replace; 3) a callable Leaf -> bool, in which case this
        is evaluated on the leaf or mapped over the subtree, and the leaf kept
        or replaced as appropriate.
    eqx_spinn_network : eqx.Module
        The network, instantiated as an eqx.Module. It takes `d` inputs and
        returns `d` embeddings of dimension `r * m`. See the Separable PINN
        paper for more details.
    """

    d: int = eqx.field(static=True, kw_only=True)
    r: int = eqx.field(static=True, kw_only=True)
    eq_type: str = eqx.field(static=True, kw_only=True)
    m: int = eqx.field(static=True, kw_only=True, default=1)

    filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
        static=True, kw_only=True, default=None
    )
    eqx_spinn_network: InitVar[eqx.Module] = eqx.field(kw_only=True)

    # Trainable leaves of the network (selected by `filter_spec`).
    init_params: PyTree = eqx.field(init=False)
    # Remaining, non-trainable part of the network.
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, eqx_spinn_network):
        """Resolve the default `filter_spec` and split the network."""
        if self.filter_spec is None:
            self.filter_spec = eqx.is_inexact_array

        split = eqx.partition(eqx_spinn_network, self.filter_spec)
        self.init_params, self.static = split

    def __call__(
        self,
        t_x: Float[Array, "batch_size 1+dim"],
        params: Params | ParamsDict | PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the SPINN on a batch of points `t_x` with `params`.

        `params` is either an object exposing the network weights as
        `nn_params` or directly the network-parameter PyTree.
        """
        try:
            network = eqx.combine(params.nn_params, self.static)
        except (KeyError, AttributeError, TypeError):
            network = eqx.combine(params, self.static)

        # Indexed below as [batch, separated dimension, embedding coordinate];
        # presumably of shape (batch_size, d, r * m).
        embeddings = jax.vmap(network)(t_x)

        n_dims = embeddings.shape[1]
        letters = [chr(97 + i) for i in range(n_dims)]
        # e.g. "az, bz -> ab": sum over the shared embedding index `z`
        subscripts = ", ".join(f"{c}z" for c in letters) + " -> " + "".join(letters)

        per_output = []
        for k in range(self.m):
            group = slice(k * self.r, (k + 1) * self.r)
            factors = [embeddings[:, i, group] for i in range(n_dims)]
            per_output.append(jnp.einsum(subscripts, *factors))
        res = jnp.stack(per_output, axis=-1)  # one trailing entry per output

        # force (1,) output for non vectorial solution (consistency)
        # NOTE(review): after the stack above `res.ndim == n_dims + 1`, so this
        # appears unreachable when the network returns exactly `d` embeddings.
        if res.ndim == self.d:
            return jnp.expand_dims(res, axis=-1)
        return res
|