jinns-0.4.1-py3-none-any.whl → jinns-0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_display.py +78 -21
- jinns/loss/_DynamicLoss.py +405 -907
- jinns/loss/_LossPDE.py +303 -154
- jinns/loss/__init__.py +0 -6
- jinns/loss/_boundary_conditions.py +231 -65
- jinns/loss/_operators.py +201 -45
- jinns/solver/_solve.py +2 -3
- jinns/utils/__init__.py +2 -1
- jinns/utils/_pinn.py +308 -0
- jinns/utils/_spinn.py +237 -0
- jinns/utils/_utils.py +32 -306
- {jinns-0.4.1.dist-info → jinns-0.5.0.dist-info}/METADATA +15 -2
- jinns-0.5.0.dist-info/RECORD +24 -0
- {jinns-0.4.1.dist-info → jinns-0.5.0.dist-info}/WHEEL +1 -1
- jinns-0.4.1.dist-info/RECORD +0 -22
- {jinns-0.4.1.dist-info → jinns-0.5.0.dist-info}/LICENSE +0 -0
- {jinns-0.4.1.dist-info → jinns-0.5.0.dist-info}/top_level.txt +0 -0
jinns/utils/_pinn.py
ADDED
```diff
@@ -0,0 +1,308 @@
+import jax
+import jax.numpy as jnp
+import equinox as eqx
+
+
+class _MLP(eqx.Module):
+    """
+    Class to construct an equinox module from a key and an eqx_list. To be used
+    in tandem with the function `create_PINN`
+    """
+
+    layers: list
+
+    def __init__(self, key, eqx_list):
+        """
+        Parameters
+        ----------
+        key
+            A jax random key
+        eqx_list
+            A list of lists of successive equinox modules and activation
+            functions to describe the PINN architecture. The inner lists have
+            the eqx module or activation function as first item, the other
+            items represent arguments that may be required (e.g., the size of
+            the layer).
+            __Note:__ the `key` argument need not be given.
+            A typical example is `eqx_list=[
+                [eqx.nn.Linear, 2, 20],
+                [jax.nn.tanh],
+                [eqx.nn.Linear, 20, 20],
+                [jax.nn.tanh],
+                [eqx.nn.Linear, 20, 20],
+                [jax.nn.tanh],
+                [eqx.nn.Linear, 20, 1],
+            ]`
+        """
+
+        self.layers = []
+        # TODO we are currently limited in the number of layer types we can
+        # parse and we lack some safety checks
+        for l in eqx_list:
+            if len(l) == 1:
+                self.layers.append(l[0])
+            else:
+                # By default we append a random key at the end of the
+                # arguments fed into a layer module call
+                key, subkey = jax.random.split(key, 2)
+                # the argument key is keyword only
+                self.layers.append(l[0](*l[1:], key=subkey))
+
+    def __call__(self, t):
+        for layer in self.layers:
+            t = layer(t)
+        return t
+
+
+class PINN:
+    """
+    Basically a wrapper around the `__call__` function, to be able to give a
+    type to our former `self.u`.
+    The function create_PINN has the role of populating the `__call__` function
+    """
+
+    def __init__(self, key, eqx_list, output_slice=None):
+        _pinn = _MLP(key, eqx_list)
+        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
+        self.output_slice = output_slice
+
+    def init_params(self):
+        return self.params
+
+    def __call__(self, *args, **kwargs):
+        return self.apply_fn(self, *args, **kwargs)
+
+    def _eval_nn(self, inputs, u_params, eq_params, input_transform, output_transform):
+        """
+        Inner function to factorize code: apply_fn (which takes varying forms)
+        calls _eval_nn, which always has the same content.
+        """
+        model = eqx.combine(u_params, self.static)
+        res = output_transform(inputs, model(input_transform(inputs)).squeeze())
+
+        if self.output_slice is not None:
+            res = res[self.output_slice]
+
+        # force (1,) output for non-vectorial solutions (consistency)
+        if not res.shape:
+            return jnp.expand_dims(res, axis=-1)
+        else:
+            return res
```
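The params/static split above is plain equinox plumbing: `eqx.partition` separates the trainable arrays from the rest of the module, and `_eval_nn` rebuilds a callable model with `eqx.combine`. A minimal round-trip sketch of that mechanism (illustration only, using `eqx.nn.MLP` rather than jinns' `_MLP`):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

key = jax.random.PRNGKey(0)
mlp = eqx.nn.MLP(in_size=2, out_size=1, width_size=20, depth=2, key=key)

# partition: `params` is the pytree of trainable arrays (what an optimizer
# updates), `static` holds everything else (layer structure, activations, ...)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

# combine: rebuild a callable module from the two halves, as PINN._eval_nn does
model = eqx.combine(params, static)
print(model(jnp.array([0.5, 0.5])).shape)  # (1,)
```

The diff continues with the `create_PINN` factory: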
```diff
+def create_PINN(
+    key,
+    eqx_list,
+    eq_type,
+    dim_x=0,
+    with_eq_params=None,
+    input_transform=None,
+    output_transform=None,
+    shared_pinn_outputs=None,
+):
+    """
+    Utility function to create a standard PINN neural network with the equinox
+    library.
+
+    Parameters
+    ----------
+    key
+        A jax random key that will be used to initialize the network parameters
+    eqx_list
+        A list of lists of successive equinox modules and activation functions
+        to describe the PINN architecture. The inner lists have the eqx module
+        or activation function as first item, the other items represent
+        arguments that may be required (e.g., the size of the layer).
+        __Note:__ the `key` argument need not be given.
+        A typical example is `eqx_list=[
+            [eqx.nn.Linear, 2, 20],
+            [jax.nn.tanh],
+            [eqx.nn.Linear, 20, 20],
+            [jax.nn.tanh],
+            [eqx.nn.Linear, 20, 20],
+            [jax.nn.tanh],
+            [eqx.nn.Linear, 20, 1],
+        ]`
+    eq_type
+        A string with three possibilities.
+        "ODE": the PINN is called with one input `t`.
+        "statio_PDE": the PINN is called with one input `x`; `x`
+        can be high dimensional.
+        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`; `x`
+        can be high dimensional.
+        **Note: the input dimension as given in eqx_list has to match the sum
+        of the dimension of `t` + the dimension of `x` + the number of
+        parameters in `eq_params` if with_eq_params is not None (see below)**
+    dim_x
+        An integer. The dimension of `x`. Default `0`
+    with_eq_params
+        Default is None. Otherwise, a list of keys from the dict `eq_params`
+        that the network also takes as inputs.
+        **If some keys are provided, the input dimension as given in eqx_list
+        must take into account the number of such provided keys (i.e., the
+        input dimension is the sum of the dimension of ``t`` + the dimension
+        of ``x`` + the number of ``eq_params``)**
+    input_transform
+        A function that will be called before entering the PINN. Its output(s)
+        must match the PINN inputs. Default is the identity (no operation)
+    output_transform
+        A function whose arguments are the same input(s) as the PINN AND the
+        PINN output, called after exiting the PINN. Default is the identity
+        (no operation)
+    shared_pinn_outputs
+        A tuple of jnp.s_[] (slices) to determine the different outputs for
+        each network. In this case we return a list of PINNs, one for each
+        output in shared_pinn_outputs. This is useful to create PINNs that
+        share the same network and the same parameters. Default is None: we
+        only return one PINN.
+
+
+    Returns
+    -------
+    init_fn
+        A function which (re-)initializes the PINN parameters with the
+        provided jax random key
+    apply_fn
+        A function to apply the neural network on given inputs for given
+        parameters. A typical call will be of the form `u(t, nn_params)` for
+        an ODE or `u(t, x, nn_params)` for nD PDEs (`x` being
+        multidimensional), or even `u(t, x, nn_params, eq_params)` if
+        with_eq_params is not None
+
+    Raises
+    ------
+    RuntimeError
+        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
+        "nonstatio_PDE"]`
+    RuntimeError
+        If we have `dim_x > 0` and `eq_type == "ODE"`,
+        or if we have `dim_x == 0` and `eq_type != "ODE"`
+    """
+    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+        raise RuntimeError("Wrong parameter value for eq_type")
+
+    if eq_type == "ODE" and dim_x != 0:
+        raise RuntimeError("Wrong parameter combination eq_type and dim_x")
+
+    if eq_type != "ODE" and dim_x == 0:
+        raise RuntimeError("Wrong parameter combination eq_type and dim_x")
+
+    dim_t = 0 if eq_type == "statio_PDE" else 1
+    dim_in_params = len(with_eq_params) if with_eq_params is not None else 0
+    try:
+        nb_inputs_declared = eqx_list[0][1]  # normally the 2nd element of the 1st layer
+    except IndexError:
+        nb_inputs_declared = eqx_list[1][1]
+        # but we can have, e.g., a flatten first layer
+
+    try:
+        nb_outputs_declared = eqx_list[-1][2]  # normally the 3rd element of
+        # the last layer
+    except IndexError:
+        nb_outputs_declared = eqx_list[-2][2]
+        # but we can have, e.g., a `jnp.exp` last layer
+
+    # NOTE Currently the check below is disabled because we added
+    # input_transform
+    # if dim_t + dim_x + dim_in_params != nb_inputs_declared:
+    #     raise RuntimeError("Error in the declarations of the number of parameters")
+
+    if input_transform is None:
+
+        def input_transform(_in):
+            return _in
+
+    if output_transform is None:
+
+        def output_transform(_in_pinn, _out_pinn):
+            return _out_pinn
+
+    if eq_type == "ODE":
+        if with_eq_params is None:
+
+            def apply_fn(self, t, u_params, eq_params=None):
+                t = t[
+                    None
+                ]  # Note that we add a dimension to t, which is lacking in the ODE batches
+                return self._eval_nn(
+                    t, u_params, eq_params, input_transform, output_transform
+                ).squeeze()
+
+        else:
+
+            def apply_fn(self, t, u_params, eq_params):
+                t = t[
+                    None
+                ]  # We add a dimension to t, which is lacking in the ODE batches
+                eq_params_flatten = jnp.concatenate(
+                    [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
+                )
+                t_eq_params = jnp.concatenate([t, eq_params_flatten], axis=-1)
+                return self._eval_nn(
+                    t_eq_params, u_params, eq_params, input_transform, output_transform
+                )
+
+    elif eq_type == "statio_PDE":
+        # Here we add an argument `x` which can be high dimensional
+        if with_eq_params is None:
+
+            def apply_fn(self, x, u_params, eq_params=None):
+                return self._eval_nn(
+                    x, u_params, eq_params, input_transform, output_transform
+                )
+
+        else:
+
+            def apply_fn(self, x, u_params, eq_params):
+                eq_params_flatten = jnp.concatenate(
+                    [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
+                )
+                x_eq_params = jnp.concatenate([x, eq_params_flatten], axis=-1)
+                return self._eval_nn(
+                    x_eq_params, u_params, eq_params, input_transform, output_transform
+                )
+
+    elif eq_type == "nonstatio_PDE":
+        # Here we add an argument `x` which can be high dimensional
+        if with_eq_params is None:
+
+            def apply_fn(self, t, x, u_params, eq_params=None):
+                t_x = jnp.concatenate([t, x], axis=-1)
+                return self._eval_nn(
+                    t_x, u_params, eq_params, input_transform, output_transform
+                )
+
+        else:
+
+            def apply_fn(self, t, x, u_params, eq_params):
+                t_x = jnp.concatenate([t, x], axis=-1)
+                eq_params_flatten = jnp.concatenate(
+                    [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
+                )
+                t_x_eq_params = jnp.concatenate([t_x, eq_params_flatten], axis=-1)
+                return self._eval_nn(
+                    t_x_eq_params,
+                    u_params,
+                    eq_params,
+                    input_transform,
+                    output_transform,
+                )
+
+    else:
+        raise RuntimeError("Wrong parameter value for eq_type")
+
+    if shared_pinn_outputs is not None:
+        pinns = []
+        static = None
+        for output_slice in shared_pinn_outputs:
+            pinn = PINN(key, eqx_list, output_slice)
+            pinn.apply_fn = apply_fn
+            # all the pinns are in fact the same, so we share the same static
+            if static is None:
+                static = pinn.static
+            else:
+                pinn.static = static
+            pinns.append(pinn)
+        return pinns
+    else:
+        pinn = PINN(key, eqx_list)
+        pinn.apply_fn = apply_fn
+        return pinn
```
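A minimal usage sketch for the module above (not part of the diff; the direct import path `jinns.utils._pinn` is an assumption based on this wheel's file layout):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.utils._pinn import create_PINN  # assumed import path for this wheel

key = jax.random.PRNGKey(0)

# stationary 2D PDE: the first Linear layer takes dim_x = 2 inputs
eqx_list = [
    [eqx.nn.Linear, 2, 20],
    [jax.nn.tanh],
    [eqx.nn.Linear, 20, 20],
    [jax.nn.tanh],
    [eqx.nn.Linear, 20, 1],
]

u = create_PINN(key, eqx_list, eq_type="statio_PDE", dim_x=2)
u_params = u.init_params()

x = jnp.array([0.5, 0.5])
print(u(x, u_params).shape)  # (1,): scalar solutions are forced to shape (1,)
```

Passing `shared_pinn_outputs=(jnp.s_[0:1], jnp.s_[1:2])` with a 2-unit final layer would instead return two PINN objects that slice different outputs of one shared network.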
jinns/utils/_spinn.py
ADDED
```diff
@@ -0,0 +1,237 @@
+import jax
+import jax.numpy as jnp
+import equinox as eqx
+
+
+class _SPINN(eqx.Module):
+    """
+    Construct a Separable PINN as proposed in
+    Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
+    """
+
+    layers: list
+    separated_mlp: list
+    d: int
+    r: int
+    m: int
+
+    def __init__(self, key, d, r, eqx_list, m=1):
+        """
+        Parameters
+        ----------
+        key
+            A jax random key
+        d
+            An integer. The number of dimensions to treat separately
+        r
+            An integer. The dimension of the embedding
+        eqx_list
+            A list of lists of successive equinox modules and activation
+            functions to describe *each separable PINN architecture*.
+            The inner lists have the eqx module or activation function as
+            first item, the other items represent arguments that may be
+            required (e.g., the size of the layer).
+            __Note:__ the `key` argument need not be given.
+            A typical example is `eqx_list=[
+                [eqx.nn.Linear, d, 20],
+                [jax.nn.tanh],
+                [eqx.nn.Linear, 20, 20],
+                [jax.nn.tanh],
+                [eqx.nn.Linear, 20, 20],
+                [jax.nn.tanh],
+                [eqx.nn.Linear, 20, r],
+            ]`
+        """
+        keys = jax.random.split(key, 8)
+
+        self.d = d
+        self.r = r
+        self.m = m
+
+        self.separated_mlp = []
+        for d in range(self.d):
+            self.layers = []
+            for l in eqx_list:
+                if len(l) == 1:
+                    self.layers.append(l[0])
+                else:
+                    key, subkey = jax.random.split(key, 2)
+                    self.layers.append(l[0](*l[1:], key=subkey))
+            self.separated_mlp.append(self.layers)
+
+    def __call__(self, t, x):
+        if t is not None:
+            dimensions = jnp.concatenate([t, x.flatten()], axis=0)
+        else:
+            dimensions = jnp.concatenate([x.flatten()], axis=0)
+        outputs = []
+        for d in range(self.d):
+            t_ = dimensions[d][None]
+            for layer in self.separated_mlp[d]:
+                t_ = layer(t_)
+            outputs += [t_]
+        return jnp.asarray(outputs)
+
+
+class SPINN:
+    """
+    Basically a wrapper around the `__call__` function, to be able to give a
+    type to our former `self.u`.
+    The function create_SPINN has the role of populating the `__call__` function
+    """
+
+    def __init__(self, key, d, r, eqx_list, m=1):
+        self.d, self.r, self.m = d, r, m
+        _spinn = _SPINN(key, d, r, eqx_list, m)
+        self.params, self.static = eqx.partition(_spinn, eqx.is_inexact_array)
+
+    def init_params(self):
+        return self.params
+
+    def __call__(self, *args, **kwargs):
+        return self.apply_fn(self, *args, **kwargs)
+
+    def _eval_nn(self, res):
+        """
+        Common content of apply_fn, put here to factorize code
+        """
+        a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
+        b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
+        res = jnp.stack(
+            [
+                jnp.einsum(
+                    f"{a} -> {b}",
+                    *(
+                        res[:, d, m * self.r : (m + 1) * self.r]
+                        for d in range(res.shape[1])
+                    ),
+                )
+                for m in range(self.m)
+            ],
+            axis=-1,
+        )  # compute each output dimension
+
+        # force (1,) output for non-vectorial solutions (consistency)
+        if len(res.shape) == self.d:
+            return jnp.expand_dims(res, axis=-1)
+        else:
+            return res
```
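The `jnp.einsum` above is the separable merge step of the SPINN: for each of the `m` outputs it contracts the `r`-dimensional per-axis embeddings over their shared last index, turning `d` batches of size `batch` into a dense grid of `batch**d` values. A hand-built check of the generated subscripts for `d = 2` (illustrative shapes only, not jinns code):

```python
import jax.numpy as jnp

batch, r = 4, 3
emb_t = jnp.arange(batch * r, dtype=jnp.float32).reshape(batch, r)  # embeddings of axis 0
emb_x = jnp.ones((batch, r))                                        # embeddings of axis 1

# for d = 2 the code above generates a = "az, bz" and b = "ab", i.e.:
# sum over the embedding index z, keep one grid axis per input batch
grid = jnp.einsum("az, bz -> ab", emb_t, emb_x)
print(grid.shape)  # (4, 4)
print(jnp.allclose(grid[1, 2], emb_t[1] @ emb_x[2]))  # True: each entry is a dot product
```

The diff continues with the `create_SPINN` factory: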
```diff
+def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
+    """
+    Utility function to create a SPINN neural network with the equinox
+    library.
+
+    *Note* that a SPINN is not vmapped from the outside and expects batches of
+    the same size for each input. It directly outputs a solution of shape
+    (batchsize, batchsize). See the paper for more details.
+
+    Parameters
+    ----------
+    key
+        A jax random key that will be used to initialize the network parameters
+    d
+        An integer. The number of dimensions to treat separately
+    r
+        An integer. The dimension of the embedding
+    eqx_list
+        A list of lists of successive equinox modules and activation functions
+        to describe *each separable PINN architecture*.
+        The inner lists have the eqx module or activation function as first
+        item, the other items represent arguments that may be required
+        (e.g., the size of the layer).
+        __Note:__ the `key` argument need not be given.
+        A typical example is `eqx_list=[
+            [eqx.nn.Linear, d, 20],
+            [jax.nn.tanh],
+            [eqx.nn.Linear, 20, 20],
+            [jax.nn.tanh],
+            [eqx.nn.Linear, 20, 20],
+            [jax.nn.tanh],
+            [eqx.nn.Linear, 20, r],
+        ]`
+    eq_type
+        A string with three possibilities.
+        "ODE": the PINN is called with one input `t`.
+        "statio_PDE": the PINN is called with one input `x`; `x`
+        can be high dimensional.
+        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`; `x`
+        can be high dimensional.
+        **Note: the input dimension as given in eqx_list has to match the sum
+        of the dimension of `t` + the dimension of `x` + the number of
+        parameters in `eq_params` if with_eq_params is not None (see below)**
+    m
+        An integer. The output dimension of the neural network. According to
+        the SPINN article, a total embedding dimension of `r*m` is defined. We
+        then sum groups of `r` embedding dimensions to compute each output.
+        Default is 1.
+
+
+    Returns
+    -------
+    init_fn
+        A function which (re-)initializes the SPINN parameters with the
+        provided jax random key
+    apply_fn
+        A function to apply the neural network on given inputs for given
+        parameters. A typical call will be of the form `u(t, nn_params)` for
+        an ODE or `u(t, x, nn_params)` for nD PDEs (`x` being
+        multidimensional), or even `u(t, x, nn_params, eq_params)` if
+        with_eq_params is not None
+
+    Raises
+    ------
+    RuntimeError
+        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
+        "nonstatio_PDE"]`, and for various failing checks
+    """
+
+    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+        raise RuntimeError("Wrong parameter value for eq_type")
+
+    try:
+        nb_inputs_declared = eqx_list[0][1]  # normally the 2nd element of the 1st layer
+    except IndexError:
+        nb_inputs_declared = eqx_list[1][
+            1
+        ]  # but we can have, e.g., a flatten first layer
+    if nb_inputs_declared != 1:
+        raise ValueError("Input dim must be set to 1 in SPINN!")
+
+    try:
+        nb_outputs_declared = eqx_list[-1][2]  # normally the 3rd element of
+        # the last layer
+    except IndexError:
+        nb_outputs_declared = eqx_list[-2][2]
+        # but we can have, e.g., a `jnp.exp` last layer
+    if nb_outputs_declared != r * m:
+        raise ValueError("Output dim must be set to r * m in SPINN!")
+
+    if d > 24:
+        raise ValueError(
+            "Too many dimensions, not enough letters" " available in jnp.einsum"
+        )
+
+    if eq_type == "statio_PDE":
+
+        def apply_fn(self, x, u_params, eq_params=None):
+            spinn = eqx.combine(u_params, self.static)
+            v_model = jax.vmap(spinn, (0))
+            res = v_model(t=None, x=x)
+            return self._eval_nn(res)
+
+    elif eq_type == "nonstatio_PDE":
+
+        def apply_fn(self, t, x, u_params, eq_params=None):
+            spinn = eqx.combine(u_params, self.static)
+            v_model = jax.vmap(spinn, ((0, 0)))
+            res = v_model(t, x)
+            return self._eval_nn(res)
+
+    else:
+        raise RuntimeError("Wrong parameter value for eq_type")
+
+    spinn = SPINN(key, d, r, eqx_list, m)
+    spinn.apply_fn = apply_fn
+
+    return spinn
```
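And a matching usage sketch for the separable variant (again not part of the diff; the import path is an assumption):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.utils._spinn import create_SPINN  # assumed import path for this wheel

key = jax.random.PRNGKey(0)
d, r, m = 2, 32, 1  # separable axes (t and a 1D x), embedding dim, output dim

# each separated MLP maps one scalar coordinate to an r*m embedding
eqx_list = [
    [eqx.nn.Linear, 1, 20],
    [jax.nn.tanh],
    [eqx.nn.Linear, 20, 20],
    [jax.nn.tanh],
    [eqx.nn.Linear, 20, r * m],
]

u = create_SPINN(key, d, r, eqx_list, eq_type="nonstatio_PDE", m=m)
u_params = u.init_params()

# a SPINN consumes whole batches at once, one batch per separated axis
t = jnp.linspace(0.0, 1.0, 16)[:, None]  # (16, 1)
x = jnp.linspace(0.0, 1.0, 16)[:, None]  # (16, 1)
print(u(t, x, u_params).shape)  # (16, 16, 1): a full grid of solution values
```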