jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/utils/_utils_uspinn.py
DELETED
@@ -1,727 +0,0 @@
import numpy as np
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
from functools import reduce
from operator import getitem


def _check_nan_in_pytree(pytree):
    """
    Check if there is a NaN value anywhere in the pytree

    Parameters
    ----------
    pytree
        A pytree

    Returns
    -------
    res
        A boolean. True if any of the pytree content is NaN
    """
    return jnp.any(
        jnp.array(
            [
                value
                for value in jax.tree_util.tree_leaves(
                    jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
                )
            ]
        )
    )
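
A minimal usage sketch (illustrative values, not from the package): any NaN
leaf in a nested parameter pytree makes the check return True.

params = {"nn_params": jnp.ones((3,)), "eq_params": {"nu": jnp.array(jnp.nan)}}
_check_nan_in_pytree(params)  # -> Array(True): one leaf contains a NaN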


def _tracked_parameters(params, tracked_params_key_list):
    """
    Returns a pytree with the same structure as params, with True if the
    parameter is tracked and False otherwise
    """

    def set_nested_item(dataDict, mapList, val):
        """
        Set an item in a nested dictionary
        https://stackoverflow.com/questions/54137991/how-to-update-values-in-nested-dictionary-if-keys-are-in-a-list
        """
        reduce(getitem, mapList[:-1], dataDict)[mapList[-1]] = val
        return dataDict

    tracked_params = jax.tree_util.tree_map(
        lambda x: False, params
    )  # init with all False

    for key_list in tracked_params_key_list:
        tracked_params = set_nested_item(tracked_params, key_list, True)

    return tracked_params
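
A quick illustration with an assumed params layout: each entry of
tracked_params_key_list is the path of keys down to the leaf to track.

params = {"eq_params": {"alpha": 0.5, "mu": 1.0}, "nn_params": {"w": 0.1}}
_tracked_parameters(params, [["eq_params", "alpha"]])
# -> {"eq_params": {"alpha": True, "mu": False}, "nn_params": {"w": False}}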


class _MLP(eqx.Module):
    """
    Class to construct an equinox module from a key and an eqx_list. To be
    used in pair with the function `create_PINN`
    """

    layers: list

    def __init__(self, key, eqx_list):
        """
        Parameters
        ----------
        key
            A jax random key
        eqx_list
            A list of lists of successive equinox modules and activation
            functions to describe the PINN architecture. The inner lists have
            the eqx module or activation function as first item; the other
            items represent arguments that could be required (e.g. the size
            of the layer). __Note:__ the `key` argument need not be given.
            A typical example is `eqx_list=
            [[eqx.nn.Linear, 2, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 1]
            ]`
        """

        self.layers = []
        # TODO we are currently limited in the number of layer types we can
        # parse and we lack some safety checks
        for l in eqx_list:
            if len(l) == 1:
                self.layers.append(l[0])
            else:
                # By default we append a random key at the end of the
                # arguments fed into a layer module call
                key, subkey = jax.random.split(key, 2)
                # the argument key is keyword only
                self.layers.append(l[0](*l[1:], key=subkey))

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t
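
A minimal construction sketch following the docstring above (the layer sizes
are arbitrary):

key = jax.random.PRNGKey(0)
mlp = _MLP(key, [[eqx.nn.Linear, 2, 20], [jax.nn.tanh], [eqx.nn.Linear, 20, 1]])
y = mlp(jnp.ones((2,)))  # forward pass; y has shape (1,)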


class PINN:
    """
    Basically a wrapper around the `__call__` function to be able to give a
    type to our former `self.u`.
    The function create_PINN has the role of populating the `__call__` function
    """

    def __init__(self, key, eqx_list, output_slice=None):
        _pinn = _MLP(key, eqx_list)
        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
        self.output_slice = output_slice

    def init_params(self):
        return self.params

    def __call__(self, *args, **kwargs):
        return self.apply_fn(self, *args, **kwargs)


def create_PINN(
    key,
    eqx_list,
    eq_type,
    dim_x=0,
    with_eq_params=None,
    input_transform=None,
    output_transform=None,
    shared_pinn_outputs=None,
):
    """
    Utility function to create a standard PINN neural network with the equinox
    library.

    Parameters
    ----------
    key
        A jax random key that will be used to initialize the network parameters
    eqx_list
        A list of lists of successive equinox modules and activation functions
        to describe the PINN architecture. The inner lists have the eqx module
        or activation function as first item; the other items represent
        arguments that could be required (e.g. the size of the layer).
        __Note:__ the `key` argument need not be given.
        A typical example is `eqx_list=
        [[eqx.nn.Linear, 2, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 1]
        ]`
    eq_type
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`; `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`; `x`
        can be high dimensional.
        **Note: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` + the number of
        parameters in `eq_params` if with_eq_params is given (see below)**
    dim_x
        An integer. The dimension of `x`. Default `0`
    with_eq_params
        Default is None. Otherwise a list of keys from the dict of equation
        parameters (`eq_params`) that the network also takes as inputs.
        **If some keys are provided, the input dimension as given in eqx_list
        must take into account the number of such provided keys (i.e., the
        input dimension is the sum of the dimension of ``t`` + the dimension
        of ``x`` + the number of ``eq_params``)**
    input_transform
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs.
    output_transform
        A function whose arguments are the same input(s) as the PINN AND the
        PINN output; it will be called after exiting the PINN
    shared_pinn_outputs
        A tuple of jnp.s_[] (slices) to determine the different output for
        each network. In this case we return a list of PINNs, one for each
        output in shared_pinn_outputs. This is useful to create PINNs that
        share the same network and the same parameters. Default is None: we
        only return one PINN.

    Returns
    -------
    pinn
        A PINN object (or a list of PINN objects if shared_pinn_outputs is
        given) whose `init_params` method (re-)initializes the PINN parameters
        and whose call applies the neural network to given inputs for given
        parameters. A typical call will be of the form `u(t, nn_params)` for
        an ODE or `u(t, x, nn_params)` for nD PDEs (`x` being
        multidimensional), or even `u(t, x, nn_params, eq_params)` if
        with_eq_params is given

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have `dim_x > 0` and `eq_type == "ODE"`,
        or if we have `dim_x == 0` and `eq_type != "ODE"`
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    dim_t = 0 if eq_type == "statio_PDE" else 1
    dim_in_params = len(with_eq_params) if with_eq_params is not None else 0
    try:
        nb_inputs_declared = eqx_list[0][1]  # normally the 2nd item of the 1st layer
    except IndexError:
        nb_inputs_declared = eqx_list[1][1]
        # but we can have, e.g., a flatten first layer

    try:
        nb_outputs_declared = eqx_list[-1][2]  # normally the 3rd item of the
        # last layer
    except IndexError:
        nb_outputs_declared = eqx_list[-2][2]
        # but we can have, e.g., a `jnp.exp` last layer

    # NOTE Currently the check below is disabled because we added
    # input_transform
    # if dim_t + dim_x + dim_in_params != nb_inputs_declared:
    #     raise RuntimeError("Error in the declarations of the number of parameters")

    if eq_type == "ODE":
        if with_eq_params is None:

            def apply_fn(self, t, u_params, eq_params=None):
                model = eqx.combine(u_params, self.static)
                t = t[None]  # add the dimension to t which is lacking in the ODE batches
                if output_transform is None:
                    if input_transform is not None:
                        res = model(input_transform(t)).squeeze()
                    else:
                        res = model(t).squeeze()
                else:
                    if input_transform is not None:
                        res = output_transform(t, model(input_transform(t)).squeeze())
                    else:
                        res = output_transform(t, model(t).squeeze())
                if self.output_slice is not None:
                    return res[self.output_slice]
                else:
                    return res

        else:

            def apply_fn(self, t, u_params, eq_params):
                model = eqx.combine(u_params, self.static)
                t = t[None]  # add the dimension to t which is lacking in the ODE batches
                eq_params_flatten = jnp.concatenate(
                    [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
                )
                t_eq_params = jnp.concatenate([t, eq_params_flatten], axis=-1)

                if output_transform is None:
                    if input_transform is not None:
                        res = model(input_transform(t_eq_params)).squeeze()
                    else:
                        res = model(t_eq_params).squeeze()
                else:
                    if input_transform is not None:
                        res = output_transform(
                            t_eq_params,
                            model(input_transform(t_eq_params)).squeeze(),
                        )
                    else:
                        res = output_transform(
                            t_eq_params, model(t_eq_params).squeeze()
                        )

                if self.output_slice is not None:
                    return res[self.output_slice]
                else:
                    return res

    elif eq_type == "statio_PDE":
        # Here we add an argument `x` which can be high dimensional
        if with_eq_params is None:

            def apply_fn(self, x, u_params, eq_params=None):
                model = eqx.combine(u_params, self.static)

                if output_transform is None:
                    if input_transform is not None:
                        res = model(input_transform(x)).squeeze()
                    else:
                        res = model(x).squeeze()
                else:
                    if input_transform is not None:
                        res = output_transform(x, model(input_transform(x)).squeeze())
                    else:
                        res = output_transform(x, model(x).squeeze()).squeeze()

                if self.output_slice is not None:
                    res = res[self.output_slice]

                # force (1,) output for non vectorial solution (consistency)
                if not res.shape:
                    return jnp.expand_dims(res, axis=-1)
                else:
                    return res

        else:

            def apply_fn(self, x, u_params, eq_params):
                model = eqx.combine(u_params, self.static)
                eq_params_flatten = jnp.concatenate(
                    [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
                )
                x_eq_params = jnp.concatenate([x, eq_params_flatten], axis=-1)

                if output_transform is None:
                    if input_transform is not None:
                        res = model(input_transform(x_eq_params)).squeeze()
                    else:
                        res = model(x_eq_params).squeeze()
                else:
                    if input_transform is not None:
                        res = output_transform(
                            x_eq_params,
                            model(input_transform(x_eq_params)).squeeze(),
                        )
                    else:
                        res = output_transform(
                            x_eq_params, model(x_eq_params).squeeze()
                        )

                if self.output_slice is not None:
                    res = res[self.output_slice]

                # force (1,) output for non vectorial solution (consistency)
                if not res.shape:
                    return jnp.expand_dims(res, axis=-1)
                else:
                    return res

    elif eq_type == "nonstatio_PDE":
        # Here we add an argument `x` which can be high dimensional
        if with_eq_params is None:

            def apply_fn(self, t, x, u_params, eq_params=None):
                model = eqx.combine(u_params, self.static)
                t_x = jnp.concatenate([t, x], axis=-1)

                if output_transform is None:
                    if input_transform is not None:
                        res = model(input_transform(t_x)).squeeze()
                    else:
                        res = model(t_x).squeeze()
                else:
                    if input_transform is not None:
                        res = output_transform(
                            t_x, model(input_transform(t_x)).squeeze()
                        )
                    else:
                        res = output_transform(t_x, model(t_x).squeeze())

                if self.output_slice is not None:
                    res = res[self.output_slice]

                # force (1,) output for non vectorial solution (consistency)
                if not res.shape:
                    return jnp.expand_dims(res, axis=-1)
                else:
                    return res

        else:

            def apply_fn(self, t, x, u_params, eq_params):
                model = eqx.combine(u_params, self.static)
                t_x = jnp.concatenate([t, x], axis=-1)
                eq_params_flatten = jnp.concatenate(
                    [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
                )
                t_x_eq_params = jnp.concatenate([t_x, eq_params_flatten], axis=-1)

                if output_transform is None:
                    if input_transform is not None:
                        res = model(input_transform(t_x_eq_params)).squeeze()
                    else:
                        res = model(t_x_eq_params).squeeze()
                else:
                    if input_transform is not None:
                        res = output_transform(
                            t_x_eq_params,
                            model(input_transform(t_x_eq_params)).squeeze(),
                        )
                    else:
                        res = output_transform(
                            t_x_eq_params,
                            model(t_x_eq_params).squeeze(),
                        )

                if self.output_slice is not None:
                    res = res[self.output_slice]

                # force (1,) output for non vectorial solution (consistency)
                if not res.shape:
                    return jnp.expand_dims(res, axis=-1)
                else:
                    return res

    else:
        raise RuntimeError("Wrong parameter value for eq_type")

    if shared_pinn_outputs is not None:
        pinns = []
        static = None
        for output_slice in shared_pinn_outputs:
            pinn = PINN(key, eqx_list, output_slice)
            pinn.apply_fn = apply_fn
            # all the pinns are in fact the same so we share the same static
            if static is None:
                static = pinn.static
            else:
                pinn.static = static
            pinns.append(pinn)
        return pinns
    else:
        pinn = PINN(key, eqx_list)
        pinn.apply_fn = apply_fn
        return pinn
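
A hedged usage sketch for the ODE case (architecture and inputs are
arbitrary): the returned object is called as `u(t, nn_params)`.

key = jax.random.PRNGKey(0)
u = create_PINN(
    key, [[eqx.nn.Linear, 1, 20], [jax.nn.tanh], [eqx.nn.Linear, 20, 1]], "ODE"
)
nn_params = u.init_params()
u_t = u(jnp.array(0.5), nn_params)  # apply_fn adds the missing batch dimension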


class _SPINN(eqx.Module):
    """
    Construct a Separable PINN as proposed in
    Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
    """

    layers: list
    separated_mlp: list
    d: int
    r: int
    m: int

    def __init__(self, key, d, r, eqx_list, m=1):
        """
        Parameters
        ----------
        key
            A jax random key
        d
            An integer. The number of dimensions to treat separately
        r
            An integer. The dimension of the embedding
        eqx_list
            A list of lists of successive equinox modules and activation
            functions to describe *each separable PINN architecture*.
            The inner lists have the eqx module or activation function as
            first item; the other items represent arguments that could be
            required (e.g. the size of the layer).
            __Note:__ the `key` argument need not be given.
            A typical example is `eqx_list=
            [[eqx.nn.Linear, d, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, r]
            ]`
        """
        keys = jax.random.split(key, 8)

        self.d = d
        self.r = r
        self.m = m

        self.separated_mlp = []
        for d in range(self.d):
            self.layers = []
            for l in eqx_list:
                if len(l) == 1:
                    self.layers.append(l[0])
                else:
                    key, subkey = jax.random.split(key, 2)
                    self.layers.append(l[0](*l[1:], key=subkey))
            self.separated_mlp.append(self.layers)

    def __call__(self, t, x):
        if t is not None:
            dimensions = jnp.concatenate([t, x.flatten()], axis=0)
        else:
            dimensions = jnp.concatenate([x.flatten()], axis=0)
        outputs = []
        for d in range(self.d):
            t_ = dimensions[d][None]
            for layer in self.separated_mlp[d]:
                t_ = layer(t_)
            outputs += [t_]
        return jnp.asarray(outputs)
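
An instantiation sketch: in `__call__` each separated network receives a
single scalar coordinate, so its input size is assumed to be 1 here, and the
output stacks one r-dimensional embedding per treated dimension.

key = jax.random.PRNGKey(1)
d, r = 2, 8  # e.g. one time dimension plus one space dimension
spinn = _SPINN(
    key, d, r, [[eqx.nn.Linear, 1, 16], [jax.nn.tanh], [eqx.nn.Linear, 16, r]]
)
emb = spinn(jnp.array([0.1]), jnp.array([0.5]))  # shape (d, r)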


def _get_grid(in_array):
    """
    From an array of shape (B, D), D > 1, get the grid array, i.e., an array
    of shape (B, B, ...(D times)..., B, D): along the last axis we have the
    array of values
    """
    if in_array.shape[-1] > 1 or in_array.ndim > 1:
        return jnp.stack(
            jnp.meshgrid(
                *(in_array[..., d] for d in range(in_array.shape[-1])), indexing="ij"
            ),
            axis=-1,
        )
    else:
        return in_array
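
A shape check as a sketch: with B = 3 batch points in D = 2 dimensions, the
grid has shape (3, 3, 2).

pts = jnp.arange(6.0).reshape(3, 2)  # (B, D)
grid = _get_grid(pts)
assert grid.shape == (3, 3, 2)  # (B, ..., B, D) with D copies of B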


def _get_vmap_in_axes_params(eq_params_batch_dict, params):
    """
    Return the input vmap axes when there is (are) batch(es) of parameters to
    vmap over. The latter are designated by keys in eq_params_batch_dict.
    If eq_params_batch_dict is None (i.e. no additional parameter batch), we
    return (None,)
    """
    if eq_params_batch_dict is None:
        return (None,)
    else:
        # We use pytree indexing of vmapped axes and vmap on axis
        # 0 of the eq_parameters for which we have a batch
        # this is for a fine-grained vmapping
        # scheme over the params
        vmap_in_axes_params = (
            {
                "eq_params": {
                    k: (0 if k in eq_params_batch_dict.keys() else None)
                    for k in params["eq_params"].keys()
                },
                "nn_params": None,
            },
        )
        return vmap_in_axes_params
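
For instance (a sketch with assumed dict layouts), only the parameters present
in the batch dict get a vmapped axis:

params = {"nn_params": None, "eq_params": {"alpha": 1.0, "sigma": 0.2}}
_get_vmap_in_axes_params({"alpha": jnp.ones((16, 1))}, params)
# -> ({"eq_params": {"alpha": 0, "sigma": None}, "nn_params": None},)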


def _check_user_func_return(r, shape):
    """
    Correctly handles the result from a user defined function (e.g. a boundary
    condition) to get the correct broadcast
    """
    if isinstance(r, (int, float)):
        # if we have a scalar cast it to float
        return float(r)
    if r.shape == () or len(r.shape) == 1:
        # if we have a scalar (or a vector, but no batch dim) inside an array
        return r.astype(float)
    else:
        # if we have an array of the shape of the batch dimension(s), check
        # that we have the correct broadcast:
        # the reshape below avoids a missing (1,) ending dimension
        # depending on how the user has coded the initial function
        return r.reshape(shape)
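
Two quick illustrations (values arbitrary): a Python scalar is cast to float,
and a squeezed array is reshaped back to restore the trailing dimension.

_check_user_func_return(0, (10, 1))  # -> 0.0
_check_user_func_return(jnp.ones((10, 2, 1)).squeeze(), (10, 2, 1))
# the (10, 2) array is reshaped back to (10, 2, 1)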


def alternate_optax_solver(
    steps, parameters_set1, parameters_set2, lr_set1, lr_set2, label_fn=None
):
    """
    This function creates an optax optimizer that alternates the optimization
    between two sets of parameters (i.e., while some parameters are updated
    with a given learning rate, the others are not updated (learning rate = 0)).
    The optimizers are scaled by adam parameters.

    __Note:__ The alternating pattern relies on
    `optax.piecewise_constant_schedule` which __multiplies__ the learning
    rates of previous steps (current included) to set the new learning rate.
    Hence, the strategy used here relies on potentially cancelling powers of
    ten to create the alternating scheme.

    Parameters
    ----------
    steps
        An array which describes the epoch numbers at which we alternate the
        optimization: the parameter set that is currently being updated stops
        updating, and the other parameter set starts updating.
        __Note:__ The step 0 should not be included
    parameters_set1
        A list of leaf-level keys which must be found in the general `params`
        dict. The parameters in this `set1` will be the parameters which are
        updated first in the alternating scheme.
    parameters_set2
        A list of leaf-level keys which must be found in the general `params`
        dict. The parameters in this `set2` will be the parameters which are
        not updated first in the alternating scheme.
    lr_set1
        A float. The learning rate of updates for set1.
    lr_set2
        A float. The learning rate of updates for set2.
    label_fn
        The same function as the label_fn function passed to optax's
        `multi_transform`
        ([see here](https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform)).
        Default None, i.e., we internally provide the default one (as proposed
        in the optax documentation) which may suit many use cases

    Returns
    -------
    tx
        The optax optimizer object
    """

    def map_nested_fn(fn):
        """
        Recursively apply `fn` to the key-value pairs of a nested dict.
        We follow the example from
        https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform
        for different learning rates
        """

        def map_fn(nested_dict):
            return {
                k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
                for k, v in nested_dict.items()
            }

        return map_fn

    if label_fn is None:
        label_fn = map_nested_fn(lambda k, _: k)

    power_to_0 = 1e-25  # power of ten used to force a learning rate to 0
    power_to_lr = 1 / power_to_0  # power of ten used to force a learning rate back to lr
    nn_params_scheduler = optax.piecewise_constant_schedule(
        init_value=lr_set1,
        boundaries_and_scales={
            k: (
                power_to_0
                if even_odd % 2 == 0  # set lr to 0, e.g., if even_odd is even, i.e., at
                # the first step
                else power_to_lr
            )
            for even_odd, k in enumerate(steps)
        },
    )
    eq_params_scheduler = optax.piecewise_constant_schedule(
        init_value=power_to_0 * lr_set2,  # set2 starts (almost) frozen
        boundaries_and_scales={
            k: (power_to_lr if even_odd % 2 == 0 else power_to_0)
            for even_odd, k in enumerate(steps)
        },
    )

    # the scheduler for set1 is called nn_chain because we usually start by
    # updating the NN parameters
    nn_chain = optax.chain(
        optax.scale_by_adam(),
        optax.scale_by_schedule(nn_params_scheduler),
        optax.scale(-1.0),
    )
    eq_chain = optax.chain(
        optax.scale_by_adam(),
        optax.scale_by_schedule(eq_params_scheduler),
        optax.scale(-1.0),
    )
    dict_params_set1 = {p: nn_chain for p in parameters_set1}
    dict_params_set2 = {p: eq_chain for p in parameters_set2}
    tx = optax.multi_transform(
        {**dict_params_set1, **dict_params_set2},
        label_fn,
    )

    return tx
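
A hedged usage sketch (the step values and parameter names are assumed): the
result is a standard optax GradientTransformation.

steps = [1000, 2000, 3000]  # alternate every 1000 epochs; step 0 excluded
tx = alternate_optax_solver(
    steps, ["nn_params"], ["alpha_sde", "mu_sde"], lr_set1=1e-3, lr_set2=1e-3
)
# then, as usual:
# opt_state = tx.init(params)
# updates, opt_state = tx.update(grads, opt_state, params)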


def euler_maruyama_density(t, x, s, y, params, Tmax=1):
    # Gaussian transition density of the Euler-Maruyama discretization of a
    # mean-reverting (Ornstein-Uhlenbeck-type) SDE between (s, y) and (t, x);
    # a small eps keeps the density strictly positive
    eps = 1e-6
    delta = jnp.abs(t - s) * Tmax
    mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
    var = params["sigma_sde"] ** 2 * delta
    return (
        1 / jnp.sqrt(2 * jnp.pi * var) * jnp.exp(-0.5 * ((x - y) - mu) ** 2 / var) + eps
    )


def log_euler_maruyama_density(t, x, s, y, params):
    # log-density counterpart of the above, with a log-variance
    # parametrization (params["logvar_sde"])
    eps = 1e-6
    delta = jnp.abs(t - s)
    mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
    logvar = params["logvar_sde"]
    return (
        -0.5
        * (jnp.log(2 * jnp.pi * delta) + logvar + ((x - y) - mu) ** 2 / jnp.exp(logvar))
        + eps
    )


def euler_maruyama(x0, alpha, mu, sigma, T, N):
    """
    Simulate a 1D diffusion process with a simple parametrization using the
    Euler-Maruyama method in the interval [0, T]
    """
    path = [np.array([x0])]

    time_steps, step_size = np.linspace(0, T, N, retstep=True)
    for _ in time_steps[1:]:
        path.append(
            path[-1]
            + step_size * (alpha * (mu - path[-1]))  # mean-reverting drift
            + sigma * np.random.normal(loc=0.0, scale=np.sqrt(step_size))  # diffusion term
        )

    return time_steps, np.stack(path)