jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/utils/_pinn.py
CHANGED
```diff
@@ -2,44 +2,53 @@
 Implements utility function to create PINNs
 """
 
-from typing import Callable
+from typing import Callable, Literal
+from dataclasses import InitVar
 import jax
 import jax.numpy as jnp
-from jax.typing import ArrayLike
 import equinox as eqx
 
+from jaxtyping import Array, Key, PyTree, Float
+
+from jinns.parameters._params import Params
+
 
 class _MLP(eqx.Module):
     """
     Class to construct an equinox module from a key and a eqx_list. To be used
-    in pair with the function `create_PINN
+    in pair with the function `create_PINN`.
+
+    Parameters
+    ----------
+    key : InitVar[Key]
+        A jax random key for the layer initializations.
+    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
+        A tuple of tuples of successive equinox modules and activation functions to
+        describe the PINN architecture. The inner tuples must have the eqx module or
+        activation function as first item, other items represents arguments
+        that could be required (eg. the size of the layer).
+        The `key` argument need not be given.
+        Thus typical example is `eqx_list=
+        ((eqx.nn.Linear, 2, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 1)
+        )`.
     """
 
-
+    key: InitVar[Key] = eqx.field(kw_only=True)
+    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
+        kw_only=True
+    )
 
-
-
-
-        ----------
-        key
-            A jax random key
-        eqx_list
-            A list of list of successive equinox modules and activation functions to
-            describe the PINN architecture. The inner lists have the eqx module or
-            axtivation function as first item, other items represents arguments
-            that could be required (eg. the size of the layer).
-            __Note:__ the `key` argument need not be given.
-            Thus typical example is `eqx_list=
-            [[eqx.nn.Linear, 2, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, 1]
-            ]`
-        """
+    # NOTE that the following should NOT be declared as static otherwise the
+    # eqx.partition that we use in the PINN module will misbehave
+    layers: list[eqx.Module] = eqx.field(init=False)
 
+    def __post_init__(self, key, eqx_list):
         self.layers = []
         for l in eqx_list:
             if len(l) == 1:
@@ -48,75 +57,118 @@ class _MLP(eqx.Module):
                 key, subkey = jax.random.split(key, 2)
                 self.layers.append(l[0](*l[1:], key=subkey))
 
-    def __call__(self, t):
+    def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
         for layer in self.layers:
             t = layer(t)
         return t
 
 
 class PINN(eqx.Module):
-    """
-
-
-
-
+    r"""
+    A PINN object, i.e., a neural network compatible with the rest of jinns.
+    This is typically created with `create_PINN` which creates iternally a
+    `_MLP` object. However, a user could directly creates their PINN using this
+    class by passing a eqx.Module (for argument `mlp`)
+    that plays the role of the NN and that is
+    already instanciated.
 
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    slice_solution : slice
+        A jnp.s\_ object which indicates which axis of the PINN output is
+        dedicated to the actual equation solution. Default None
+        means that slice_solution = the whole PINN output. This argument is useful
+        when the PINN is also used to output equation parameters for example
+        Note that it must be a slice and not an integer (a preprocessing of the
+        user provided argument takes care of it).
+    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+        A string with three possibilities.
+        "ODE": the PINN is called with one input `t`.
+        "statio_PDE": the PINN is called with one input `x`, `x`
+        can be high dimensional.
+        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+        can be high dimensional.
+        **Note**: the input dimension as given in eqx_list has to match the sum
+        of the dimension of `t` + the dimension of `x` or the output dimension
+        after the `input_transform` function.
+    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+        A function that will be called before entering the PINN. Its output(s)
+        must match the PINN inputs (except for the parameters).
+        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+        and the parameters. Default is no operation.
+    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+        A function with arguments begin the same input as the PINN, the PINN
+        output and the parameter. This function will be called after exiting the PINN.
+        Default is no operation.
+    output_slice : slice, default=None
+        A jnp.s\_[] to determine the different dimension for the PINN.
+        See `shared_pinn_outputs` argument of `create_PINN`.
+    mlp : eqx.Module
+        The actual neural network instanciated as an eqx.Module.
+    """
 
-
-
-
-
-
-
-
-
-
+    slice_solution: slice = eqx.field(static=True, kw_only=True)
+    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
+        static=True, kw_only=True
+    )
+    input_transform: Callable[
+        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+    ] = eqx.field(static=True, kw_only=True)
+    output_transform: Callable[
+        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+        Float[Array, "output_dim"],
+    ] = eqx.field(static=True, kw_only=True)
+    output_slice: slice = eqx.field(static=True, kw_only=True, default=None)
+
+    mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
+
+    params: PyTree = eqx.field(init=False)
+    static: PyTree = eqx.field(init=False, static=True)
+
+    def __post_init__(self, mlp):
         self.params, self.static = eqx.partition(mlp, eqx.is_inexact_array)
-        self.slice_solution = slice_solution
-        self.eq_type = eq_type
-        self.input_transform = input_transform
-        self.output_transform = output_transform
-        self.output_slice = output_slice
 
-    def init_params(self):
+    def init_params(self) -> PyTree:
+        """
+        Returns an initial set of parameters
+        """
        return self.params
 
-    def __call__(self, *args):
+    def __call__(self, *args) -> Float[Array, "output_dim"]:
+        """
+        Calls `eval_nn` with rearranged arguments
+        """
         if self.eq_type == "ODE":
             (t, params) = args
             if len(t.shape) == 0:
                 t = t[..., None]  # Add mandatory dimension which can be lacking
                 # (eg. for the ODE batches) but this dimension can already
                 # exists (eg. for user provided observation times)
-            return self.
+            return self.eval_nn(t, params)
         if self.eq_type == "statio_PDE":
             (x, params) = args
-            return self.
+            return self.eval_nn(x, params)
         if self.eq_type == "nonstatio_PDE":
             (t, x, params) = args
             t_x = jnp.concatenate([t, x], axis=-1)
-            return self.
-                t_x, params, self.input_transform, self.output_transform
-            )
+            return self.eval_nn(t_x, params)
         raise ValueError("Wrong value for self.eq_type")
 
-    def
+    def eval_nn(
+        self,
+        inputs: Float[Array, "input_dim"],
+        params: Params | PyTree,
+    ) -> Float[Array, "output_dim"]:
         """
-
-        call _eval_nn which always have the same content.
+        Evaluate the PINN on some inputs with some params.
         """
         try:
-            model = eqx.combine(params
-        except (KeyError, TypeError) as e:  # give more flexibility
+            model = eqx.combine(params.nn_params, self.static)
+        except (KeyError, AttributeError, TypeError) as e:  # give more flexibility
             model = eqx.combine(params, self.static)
-        res = output_transform(
+        res = self.output_transform(
+            inputs, model(self.input_transform(inputs, params)).squeeze(), params
+        )
 
         if self.output_slice is not None:
             res = res[self.output_slice]
@@ -128,15 +180,20 @@ class PINN(eqx.Module):
 
 
 def create_PINN(
-    key,
-    eqx_list,
-    eq_type,
-    dim_x=0,
-    input_transform
-
-
-
-
+    key: Key,
+    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
+    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+    dim_x: int = 0,
+    input_transform: Callable[
+        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+    ] = None,
+    output_transform: Callable[
+        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+        Float[Array, "output_dim"],
+    ] = None,
+    shared_pinn_outputs: tuple[slice] = None,
+    slice_solution: slice = None,
+) -> PINN | list[PINN]:
     r"""
     Utility function to create a standard PINN neural network with the equinox
     library.
@@ -144,22 +201,25 @@ def create_PINN(
     Parameters
     ----------
     key
-        A
+        A JAX random key that will be used to initialize the network
+        parameters.
     eqx_list
-        A
-        describe the PINN architecture. The inner
-
-        that could be required (eg. the size of the layer).
-
-
-
-
-
-
-
-
-
-
+        A tuple of tuples of successive equinox modules and activation
+        functions to describe the PINN architecture. The inner tuples must have
+        the eqx module or activation function as first item, other items
+        represent arguments that could be required (eg. the size of the layer).
+
+        The `key` argument do not need to be given.
+
+        A typical example is `eqx_list = (
+        (eqx.nn.Linear, input_dim, 20),
+        (jax.nn.tanh,),
+        (eqx.nn.Linear, 20, 20),
+        (jax.nn.tanh,),
+        (eqx.nn.Linear, 20, 20),
+        (jax.nn.tanh,),
+        (eqx.nn.Linear, 20, output_dim)
+        )`.
     eq_type
         A string with three possibilities.
         "ODE": the PINN is called with one input `t`.
@@ -169,17 +229,19 @@ def create_PINN(
         can be high dimensional.
         **Note**: the input dimension as given in eqx_list has to match the sum
         of the dimension of `t` + the dimension of `x` or the output dimension
-        after the `input_transform` function
+        after the `input_transform` function.
     dim_x
-        An integer. The dimension of `x`. Default `0
+        An integer. The dimension of `x`. Default `0`.
     input_transform
         A function that will be called before entering the PINN. Its output(s)
-        must match the PINN inputs
-
+        must match the PINN inputs (except for the parameters).
+        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+        and the parameters. Default is no operation.
     output_transform
-        A function with arguments the same input
-        output
-
+        A function with arguments begin the same input as the PINN, the PINN
+        output and the parameter. This function will be called after exiting
+        the PINN.
+        Default is no operation.
     shared_pinn_outputs
         Default is None, for a stantard PINN.
         A tuple of jnp.s\_[] (slices) to determine the different output for each
@@ -192,15 +254,18 @@ def create_PINN(
     slice_solution
         A jnp.s\_ object which indicates which axis of the PINN output is
         dedicated to the actual equation solution. Default None
-        means that slice_solution = the whole PINN output. This argument is
-        when the PINN is also used to output equation parameters for
-        Note that it must be a slice and not an integer (a
-        user provided argument takes care of it)
+        means that slice_solution = the whole PINN output. This argument is
+        useful when the PINN is also used to output equation parameters for
+        example Note that it must be a slice and not an integer (a
+        preprocessing of the user provided argument takes care of it).
 
 
     Returns
     -------
-
+    pinn
+        A PINN instance or, when `shared_pinn_ouput` is not None,
+        a list of PINN instances with the same structure is returned,
+        only differing by there final slicing of the network output.
 
     Raises
     ------
@@ -240,23 +305,30 @@ def create_PINN(
 
     if output_transform is None:
 
-        def output_transform(_in_pinn, _out_pinn):
+        def output_transform(_in_pinn, _out_pinn, _params):
             return _out_pinn
 
-    mlp = _MLP(key, eqx_list)
+    mlp = _MLP(key=key, eqx_list=eqx_list)
 
     if shared_pinn_outputs is not None:
         pinns = []
         for output_slice in shared_pinn_outputs:
             pinn = PINN(
-                mlp,
-                slice_solution,
-                eq_type,
-                input_transform,
-                output_transform,
-                output_slice,
+                mlp=mlp,
+                slice_solution=slice_solution,
+                eq_type=eq_type,
+                input_transform=input_transform,
+                output_transform=output_transform,
+                output_slice=output_slice,
             )
             pinns.append(pinn)
         return pinns
-    pinn = PINN(
+    pinn = PINN(
+        mlp=mlp,
+        slice_solution=slice_solution,
+        eq_type=eq_type,
+        input_transform=input_transform,
+        output_transform=output_transform,
+        output_slice=None,
+    )
    return pinn
```
jinns/utils/_save_load.py
CHANGED
```diff
@@ -2,58 +2,64 @@
 Implements save and load functions
 """
 
+from typing import Callable, Literal
 import pickle
 import jax
 import equinox as eqx
 
-from jinns.utils._pinn import create_PINN
-from jinns.utils._spinn import create_SPINN
-from jinns.utils._hyperpinn import create_HYPERPINN
+from jinns.utils._pinn import create_PINN, PINN
+from jinns.utils._spinn import create_SPINN, SPINN
+from jinns.utils._hyperpinn import create_HYPERPINN, HYPERPINN
+from jinns.parameters._params import Params, ParamsDict
 
 
-def function_to_string(
+def function_to_string(
+    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...]
+) -> tuple[tuple[str, int, int] | str, ...]:
     """
     We need this transformation for eqx_list to be pickled
 
-    From `
-
-
-
-
-
-
-    `
-
-
-
-
-
-
+    From `((eqx.nn.Linear, 2, 20),
+    (jax.nn.tanh),
+    (eqx.nn.Linear, 20, 20),
+    (jax.nn.tanh),
+    (eqx.nn.Linear, 20, 20),
+    (jax.nn.tanh),
+    (eqx.nn.Linear, 20, 1))` to
+    `(("Linear", 2, 20),
+    ("tanh"),
+    ("Linear", 20, 20),
+    ("tanh"),
+    ("Linear", 20, 20),
+    ("tanh"),
+    ("Linear", 20, 1))`
     """
     return jax.tree_util.tree_map(
         lambda x: x.__name__ if hasattr(x, "__call__") else x, eqx_list
     )
 
 
-def string_to_function(
+def string_to_function(
+    eqx_list_with_string: tuple[tuple[str, int, int] | str, ...]
+) -> tuple[tuple[Callable, int, int] | Callable, ...]:
     """
     We need this transformation for eqx_list at the loading ("unpickling")
     operation.
 
-    From `
-
-
-
-
-
-
-    to `
-
-
-
-
-
-
+    From `(("Linear", 2, 20),
+    ("tanh"),
+    ("Linear", 20, 20),
+    ("tanh"),
+    ("Linear", 20, 20),
+    ("tanh"),
+    ("Linear", 20, 1))`
+    to `((eqx.nn.Linear, 2, 20),
+    (jax.nn.tanh),
+    (eqx.nn.Linear, 20, 20),
+    (jax.nn.tanh),
+    (eqx.nn.Linear, 20, 20),
+    (jax.nn.tanh),
+    (eqx.nn.Linear, 20, 1))`
     """
 
     def _str_to_fun(l):
@@ -76,16 +82,36 @@ def string_to_function(eqx_list_with_string):
     )
 
 
-def save_pinn(
+def save_pinn(
+    filename: str,
+    u: PINN | HYPERPINN | SPINN,
+    params: Params | ParamsDict,
+    kwargs_creation,
+):
     """
     Save a PINN / HyperPINN / SPINN model
     This function creates 3 files, beggining by `filename`
 
     1. an eqx file to save the eqx.Module (the PINN, HyperPINN, ...)
-    2. a pickle file for the parameters
-    3. a pickle file for the arguments that have been used at PINN
-
-
+    2. a pickle file for the parameters of the equation
+    3. a pickle file for the arguments that have been used at PINN creation
+    and that we need to reconstruct the eqx.module later on.
+
+    Note that the equation parameters `Params.eq_params` go in the
+    pickle file while the neural network parameters `Params.nn_params` go in
+    the `"*-module.eqx"` file (normal behaviour with
+    `eqx.tree_serialise_leaves`).
+
+    Equation parameters are saved apart because the initial type of attribute
+    `params` in PINN / HYPERPINN / SPINN is not `Params` nor `ParamsDict`
+    but `PyTree` as inherited from `eqx.partition`.
+    Therefore, if we want to ensure a proper serialization/deserialization:
+    - we cannot save a `Params` object at this
+    attribute field ; the `Params` object must be split into `Params.nn_params`
+    (type `PyTree`) and `Params.eq_params` (type `dict`).
+    - in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
+    the attribute field `params` because it is not a `PyTree` (as expected in
+    the PINN / HYPERPINN / SPINN signature) but it is still a dictionary.
 
     Parameters
     ----------
@@ -94,17 +120,32 @@ def save_pinn(filename, u, params, kwargs_creation):
     u
         The PINN
     params
-
-        Typically, it is a dictionary of
-        dictionaries: `eq_params` and `nn_params`, respectively the
-        differential equation parameters and the neural network parameter
+        Params or ParamsDict to be save
     kwargs_creation
         The dictionary of arguments that were used to create the PINN, e.g.
         the layers list, O/PDE type, etc.
     """
-
-
-
+    if isinstance(params, Params):
+        if isinstance(u, HYPERPINN):
+            u = eqx.tree_at(lambda m: m.params_hyper, u, params)
+        elif isinstance(u, (PINN, SPINN)):
+            u = eqx.tree_at(lambda m: m.params, u, params)
+        eqx.tree_serialise_leaves(filename + "-module.eqx", u)
+
+    elif isinstance(params, ParamsDict):
+        for key, params_ in params.nn_params.items():
+            if isinstance(u, HYPERPINN):
+                u = eqx.tree_at(lambda m: m.params_hyper, u, params_)
+            elif isinstance(u, (PINN, SPINN)):
+                u = eqx.tree_at(lambda m: m.params, u, params_)
+            eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
+
+    else:
+        raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
+
+    with open(filename + "-eq_params.pkl", "wb") as f:
+        pickle.dump(params.eq_params, f)
+
     kwargs_creation = kwargs_creation.copy()  # avoid side-effect that would be
     # very probably harmless anyway
 
@@ -124,7 +165,11 @@ def save_pinn(filename, u, params, kwargs_creation):
         pickle.dump(kwargs_creation, f)
 
 
-def load_pinn(
+def load_pinn(
+    filename: str,
+    type_: Literal["pinn", "hyperpinn", "spinn"],
+    key_list_for_paramsdict: list[str] = None,
+) -> tuple[eqx.Module, Params | ParamsDict]:
     """
     Load a PINN model. This function needs to access 3 files :
     `{filename}-module.eqx`, `{filename}-parameters.pkl` and
@@ -132,27 +177,35 @@ def load_pinn(filename, type_):
 
     These files are created by `jinns.utils.save_pinn`.
 
-    Note that this requires equinox
+    Note that this requires equinox>v0.11.3 for the
     `eqx.filter_eval_shape` to work.
 
+    See note in `save_pinn` for more details about the saving process
+
     Parameters
     ----------
     filename
         Filename (prefix) without extension.
     type_
         Type of model to load. Must be in ["pinn", "hyperpinn", "spinn"].
+    key_list_for_paramsdict
+        Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
 
     Returns
     -------
     u_reloaded
         The reloaded PINN
-
+    params
         The reloaded parameters
     """
     with open(filename + "-arguments.pkl", "rb") as f:
         kwargs_reloaded = pickle.load(f)
-
-
+    try:
+        with open(filename + "-eq_params.pkl", "rb") as f:
+            eq_params_reloaded = pickle.load(f)
+    except FileNotFoundError:
+        eq_params_reloaded = {}
+        print("No pickle file for equation parameters found!")
     kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
     if type_ == "pinn":
         # next line creates a shallow model, the jax arrays are just shapes and
@@ -167,9 +220,21 @@ def load_pinn(filename, type_):
         u_reloaded_shallow = eqx.filter_eval_shape(create_HYPERPINN, **kwargs_reloaded)
     else:
         raise ValueError(f"{type_} is not valid")
-
-
-
-
-
-
+    if key_list_for_paramsdict is None:
+        # now the empty structure is populated with the actual saved array values
+        # stored in the eqx file
+        u_reloaded = eqx.tree_deserialise_leaves(
+            filename + "-module.eqx", u_reloaded_shallow
+        )
+        params = Params(
+            nn_params=u_reloaded.init_params(), eq_params=eq_params_reloaded
+        )
+    else:
+        nn_params_dict = {}
+        for key in key_list_for_paramsdict:
+            u_reloaded = eqx.tree_deserialise_leaves(
+                filename + f"-module_{key}.eqx", u_reloaded_shallow
+            )
+            nn_params_dict[key] = u_reloaded.init_params()
+        params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
+    return u_reloaded, params
```