jinns 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_display.py +78 -21
- jinns/loss/_DynamicLoss.py +405 -907
- jinns/loss/_LossPDE.py +303 -154
- jinns/loss/__init__.py +0 -6
- jinns/loss/_boundary_conditions.py +231 -65
- jinns/loss/_operators.py +201 -45
- jinns/utils/__init__.py +2 -1
- jinns/utils/_pinn.py +308 -0
- jinns/utils/_spinn.py +237 -0
- jinns/utils/_utils.py +32 -306
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/METADATA +15 -2
- jinns-0.5.0.dist-info/RECORD +24 -0
- jinns-0.4.2.dist-info/RECORD +0 -22
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/LICENSE +0 -0
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/WHEEL +0 -0
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
import jax
|
|
2
2
|
import jax.numpy as jnp
|
|
3
3
|
from jax import vmap, grad
|
|
4
|
+
from jinns.utils._utils import _get_grid, _check_user_func_return
|
|
5
|
+
from jinns.utils._pinn import PINN
|
|
6
|
+
from jinns.utils._spinn import SPINN
|
|
4
7
|
|
|
5
8
|
|
|
6
9
|
def _compute_boundary_loss_statio(
|
|
@@ -127,18 +130,32 @@ def boundary_dirichlet_statio(f, border_batch, u, params):
|
|
|
127
130
|
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
128
131
|
differential equation parameters and the neural network parameter
|
|
129
132
|
"""
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
dx
|
|
133
|
-
|
|
134
|
-
|
|
133
|
+
if isinstance(u, PINN):
|
|
134
|
+
v_u_boundary = vmap(
|
|
135
|
+
lambda dx: u(
|
|
136
|
+
dx,
|
|
137
|
+
u_params=params["nn_params"],
|
|
138
|
+
eq_params=jax.lax.stop_gradient(params["eq_params"]),
|
|
139
|
+
)
|
|
140
|
+
- f(dx),
|
|
141
|
+
(0),
|
|
142
|
+
0,
|
|
135
143
|
)
|
|
136
|
-
- f(dx),
|
|
137
|
-
(0),
|
|
138
|
-
0,
|
|
139
|
-
)
|
|
140
144
|
|
|
141
|
-
|
|
145
|
+
mse_u_boundary = jnp.sum((v_u_boundary(border_batch)) ** 2, axis=-1)
|
|
146
|
+
elif isinstance(u, SPINN):
|
|
147
|
+
values = u(
|
|
148
|
+
border_batch,
|
|
149
|
+
params["nn_params"],
|
|
150
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
151
|
+
)
|
|
152
|
+
x_grid = _get_grid(border_batch)
|
|
153
|
+
boundaries = _check_user_func_return(f(x_grid), values.shape)
|
|
154
|
+
res = values - boundaries
|
|
155
|
+
mse_u_boundary = jnp.sum(
|
|
156
|
+
res**2,
|
|
157
|
+
axis=-1,
|
|
158
|
+
)
|
|
142
159
|
return mse_u_boundary
|
|
143
160
|
|
|
144
161
|
|
|
@@ -177,17 +194,69 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
|
|
|
177
194
|
# border_batch shape (batch_size, ndim, nfacets)
|
|
178
195
|
n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
|
|
179
196
|
|
|
180
|
-
|
|
181
|
-
lambda
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
197
|
+
if isinstance(u, PINN):
|
|
198
|
+
u_ = lambda x, nn, eq: u(t, x, nn, eq)[0]
|
|
199
|
+
v_neumann = vmap(
|
|
200
|
+
lambda dx: jnp.dot(
|
|
201
|
+
grad(u_, 0)(
|
|
202
|
+
dx, params["nn_params"], jax.lax.stop_gradient(params["eq_params"])
|
|
203
|
+
),
|
|
204
|
+
n[..., facet],
|
|
205
|
+
)
|
|
206
|
+
- f(dx),
|
|
207
|
+
0,
|
|
186
208
|
)
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
209
|
+
mse_u_boundary = jnp.sum((v_neumann(border_batch)) ** 2, axis=-1)
|
|
210
|
+
elif isinstance(u, SPINN):
|
|
211
|
+
# the gradient we see in the PINN case can get gradients wrt to x
|
|
212
|
+
# dimensions at once. But it would be very inefficient in SPINN because
|
|
213
|
+
# of the high dim output of u. So we do 2 explicit forward AD, handling all the
|
|
214
|
+
# high dim output at once
|
|
215
|
+
if border_batch.shape[0] == 1: # i.e. case 1D
|
|
216
|
+
_, du_dx = jax.jvp(
|
|
217
|
+
lambda x: u(
|
|
218
|
+
x,
|
|
219
|
+
params["nn_params"],
|
|
220
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
221
|
+
),
|
|
222
|
+
(omega_border_batch,),
|
|
223
|
+
(jnp.ones_like(x),),
|
|
224
|
+
)
|
|
225
|
+
values = du_dx * n[facet]
|
|
226
|
+
elif omega_border_batch.shape[-1] == 2:
|
|
227
|
+
tangent_vec_0 = jnp.repeat(
|
|
228
|
+
jnp.array([1.0, 0.0])[None], omega_border_batch.shape[0], axis=0
|
|
229
|
+
)
|
|
230
|
+
tangent_vec_1 = jnp.repeat(
|
|
231
|
+
jnp.array([0.0, 1.0])[None], omega_border_batch.shape[0], axis=0
|
|
232
|
+
)
|
|
233
|
+
_, du_dx1 = jax.jvp(
|
|
234
|
+
lambda x: u(
|
|
235
|
+
x,
|
|
236
|
+
params["nn_params"],
|
|
237
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
238
|
+
),
|
|
239
|
+
(omega_border_batch,),
|
|
240
|
+
(tangent_vec_0,),
|
|
241
|
+
)
|
|
242
|
+
_, du_dx2 = jax.jvp(
|
|
243
|
+
lambda x: u(
|
|
244
|
+
x,
|
|
245
|
+
params["nn_params"],
|
|
246
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
247
|
+
),
|
|
248
|
+
(omega_border_batch,),
|
|
249
|
+
(tangent_vec_1,),
|
|
250
|
+
)
|
|
251
|
+
values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
|
|
252
|
+
# explicitly written
|
|
253
|
+
else:
|
|
254
|
+
raise ValueError("Not implemented, we'll do that with a loop")
|
|
255
|
+
|
|
256
|
+
x_grid = _get_grid(border_batch)
|
|
257
|
+
boundaries = _check_user_func_return(f(x_grid), values.shape)
|
|
258
|
+
res = values - boundaries
|
|
259
|
+
mse_u_boundary = jnp.sum(res**2, axis=-1)
|
|
191
260
|
return mse_u_boundary
|
|
192
261
|
|
|
193
262
|
|
|
@@ -212,30 +281,56 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
|
|
|
212
281
|
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
213
282
|
differential equation parameters and the neural network parameter
|
|
214
283
|
"""
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
t,
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
284
|
+
if isinstance(u, PINN):
|
|
285
|
+
tile_omega_border_batch = jnp.tile(
|
|
286
|
+
omega_border_batch, reps=(times_batch.shape[0], 1)
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
def rep_times(k):
|
|
290
|
+
return jnp.repeat(times_batch, k, axis=0)
|
|
291
|
+
|
|
292
|
+
v_u_boundary = vmap(
|
|
293
|
+
lambda t, dx: u(
|
|
294
|
+
t,
|
|
295
|
+
dx,
|
|
296
|
+
u_params=params["nn_params"],
|
|
297
|
+
eq_params=jax.lax.stop_gradient(params["eq_params"]),
|
|
298
|
+
)
|
|
299
|
+
- f(t, dx),
|
|
300
|
+
(0, 0),
|
|
301
|
+
0,
|
|
302
|
+
)
|
|
303
|
+
res = v_u_boundary(
|
|
304
|
+
rep_times(omega_border_batch.shape[0]), tile_omega_border_batch
|
|
305
|
+
) # TODO check if this cartesian product is always relevant
|
|
306
|
+
mse_u_boundary = jnp.sum(
|
|
307
|
+
res**2,
|
|
308
|
+
axis=-1,
|
|
309
|
+
)
|
|
310
|
+
elif isinstance(u, SPINN):
|
|
311
|
+
tile_omega_border_batch = jnp.tile(
|
|
312
|
+
omega_border_batch, reps=(times_batch.shape[0], 1)
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
if omega_border_batch.shape[0] == 1:
|
|
316
|
+
omega_border_batch = jnp.tile(
|
|
317
|
+
omega_border_batch, reps=(times_batch.shape[0], 1)
|
|
318
|
+
)
|
|
319
|
+
# otherwise we require batches to have same shape and we do not need
|
|
320
|
+
# this operation
|
|
321
|
+
|
|
322
|
+
values = u(
|
|
323
|
+
times_batch,
|
|
324
|
+
tile_omega_border_batch,
|
|
325
|
+
params["nn_params"],
|
|
326
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
327
|
+
)
|
|
328
|
+
tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
|
|
329
|
+
boundaries = _check_user_func_return(
|
|
330
|
+
f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
|
|
228
331
|
)
|
|
229
|
-
-
|
|
230
|
-
(
|
|
231
|
-
0,
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
mse_u_boundary = jnp.mean(
|
|
235
|
-
(v_u_boundary(rep_times(omega_border_batch.shape[0]), tile_omega_border_batch))
|
|
236
|
-
** 2,
|
|
237
|
-
axis=0,
|
|
238
|
-
)
|
|
332
|
+
res = values - boundaries
|
|
333
|
+
mse_u_boundary = jnp.sum(res**2, axis=-1)
|
|
239
334
|
return mse_u_boundary
|
|
240
335
|
|
|
241
336
|
|
|
@@ -264,13 +359,6 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
|
|
|
264
359
|
An integer which represents the id of the facet which is currently
|
|
265
360
|
considered (in the order provided wy the DataGenerator which is fixed)
|
|
266
361
|
"""
|
|
267
|
-
tile_omega_border_batch = jnp.tile(
|
|
268
|
-
omega_border_batch, reps=(times_batch.shape[0], 1)
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
def rep_times(k):
|
|
272
|
-
return jnp.repeat(times_batch, k, axis=0)
|
|
273
|
-
|
|
274
362
|
# We resort to the shape of the border_batch to determine the dimension as
|
|
275
363
|
# described in the border_batch function
|
|
276
364
|
if jnp.squeeze(omega_border_batch).ndim == 0: # case 1D borders (just a scalar)
|
|
@@ -282,21 +370,99 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
|
|
|
282
370
|
# border_batch shape (batch_size, ndim, nfacets)
|
|
283
371
|
n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
|
|
284
372
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
t, dx, params["nn_params"], jax.lax.stop_gradient(params["eq_params"])
|
|
289
|
-
),
|
|
290
|
-
n[..., facet],
|
|
373
|
+
if isinstance(u, PINN):
|
|
374
|
+
tile_omega_border_batch = jnp.tile(
|
|
375
|
+
omega_border_batch, reps=(times_batch.shape[0], 1)
|
|
291
376
|
)
|
|
292
|
-
- f(t, dx),
|
|
293
|
-
0,
|
|
294
|
-
0,
|
|
295
|
-
)
|
|
296
|
-
mse_u_boundary = jnp.mean(
|
|
297
|
-
(v_neumann(rep_times(omega_border_batch.shape[0]), tile_omega_border_batch))
|
|
298
|
-
** 2,
|
|
299
|
-
axis=0,
|
|
300
|
-
)
|
|
301
377
|
|
|
378
|
+
def rep_times(k):
|
|
379
|
+
return jnp.repeat(times_batch, k, axis=0)
|
|
380
|
+
|
|
381
|
+
u_ = lambda t, x, nn, eq: u(t, x, nn, eq)[0]
|
|
382
|
+
v_neumann = vmap(
|
|
383
|
+
lambda t, dx: jnp.dot(
|
|
384
|
+
grad(u_, 1)(
|
|
385
|
+
t,
|
|
386
|
+
dx,
|
|
387
|
+
params["nn_params"],
|
|
388
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
389
|
+
),
|
|
390
|
+
n[..., facet],
|
|
391
|
+
)
|
|
392
|
+
- f(t, dx),
|
|
393
|
+
0,
|
|
394
|
+
0,
|
|
395
|
+
)
|
|
396
|
+
mse_u_boundary = jnp.sum(
|
|
397
|
+
(v_neumann(rep_times(omega_border_batch.shape[0]), tile_omega_border_batch))
|
|
398
|
+
** 2,
|
|
399
|
+
axis=-1,
|
|
400
|
+
) # TODO check if this cartesian product is always relevant
|
|
401
|
+
|
|
402
|
+
elif isinstance(u, SPINN):
|
|
403
|
+
if omega_border_batch.shape[0] == 1:
|
|
404
|
+
omega_border_batch = jnp.tile(
|
|
405
|
+
omega_border_batch, reps=(times_batch.shape[0], 1)
|
|
406
|
+
)
|
|
407
|
+
# ie case 1D
|
|
408
|
+
# otherwise we require batches to have same shape and we do not need
|
|
409
|
+
# this operation
|
|
410
|
+
|
|
411
|
+
# the gradient we see in the PINN case can get gradients wrt to x
|
|
412
|
+
# dimensions at once. But it would be very inefficient in SPINN because
|
|
413
|
+
# of the high dim output of u. So we do 2 explicit forward AD, handling all the
|
|
414
|
+
# high dim output at once
|
|
415
|
+
if omega_border_batch.shape[0] == 1: # i.e. case 1D
|
|
416
|
+
_, du_dx = jax.jvp(
|
|
417
|
+
lambda x: u(
|
|
418
|
+
times_batch,
|
|
419
|
+
x,
|
|
420
|
+
params["nn_params"],
|
|
421
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
422
|
+
),
|
|
423
|
+
(omega_border_batch,),
|
|
424
|
+
(jnp.ones_like(x),),
|
|
425
|
+
)
|
|
426
|
+
values = du_dx * n[facet]
|
|
427
|
+
elif omega_border_batch.shape[-1] == 2:
|
|
428
|
+
tangent_vec_0 = jnp.repeat(
|
|
429
|
+
jnp.array([1.0, 0.0])[None], omega_border_batch.shape[0], axis=0
|
|
430
|
+
)
|
|
431
|
+
tangent_vec_1 = jnp.repeat(
|
|
432
|
+
jnp.array([0.0, 1.0])[None], omega_border_batch.shape[0], axis=0
|
|
433
|
+
)
|
|
434
|
+
_, du_dx1 = jax.jvp(
|
|
435
|
+
lambda x: u(
|
|
436
|
+
times_batch,
|
|
437
|
+
x,
|
|
438
|
+
params["nn_params"],
|
|
439
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
440
|
+
),
|
|
441
|
+
(omega_border_batch,),
|
|
442
|
+
(tangent_vec_0,),
|
|
443
|
+
)
|
|
444
|
+
_, du_dx2 = jax.jvp(
|
|
445
|
+
lambda x: u(
|
|
446
|
+
times_batch,
|
|
447
|
+
x,
|
|
448
|
+
params["nn_params"],
|
|
449
|
+
jax.lax.stop_gradient(params["eq_params"]),
|
|
450
|
+
),
|
|
451
|
+
(omega_border_batch,),
|
|
452
|
+
(tangent_vec_1,),
|
|
453
|
+
)
|
|
454
|
+
values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
|
|
455
|
+
# explicitly written
|
|
456
|
+
else:
|
|
457
|
+
raise ValueError("Not implemented, we'll do that with a loop")
|
|
458
|
+
|
|
459
|
+
tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
|
|
460
|
+
boundaries = _check_user_func_return(
|
|
461
|
+
f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
|
|
462
|
+
)
|
|
463
|
+
res = values - boundaries
|
|
464
|
+
mse_u_boundary = jnp.sum(
|
|
465
|
+
res**2,
|
|
466
|
+
axis=-1,
|
|
467
|
+
)
|
|
302
468
|
return mse_u_boundary
|
jinns/loss/_operators.py
CHANGED
|
@@ -2,13 +2,16 @@ import jax
|
|
|
2
2
|
import jax.numpy as jnp
|
|
3
3
|
from jax import grad
|
|
4
4
|
from functools import partial
|
|
5
|
+
from jinns.utils._pinn import PINN
|
|
6
|
+
from jinns.utils._spinn import SPINN
|
|
5
7
|
|
|
6
8
|
|
|
7
|
-
def
|
|
9
|
+
def _div_rev(u, nn_params, eq_params, x, t=None):
|
|
8
10
|
r"""
|
|
9
|
-
Compute the divergence of a vector field :math:`\mathbf{u}
|
|
10
|
-
:math:`\nabla \cdot u(x)` with
|
|
11
|
-
to :math:`\mathbb{R}^
|
|
11
|
+
Compute the divergence of a vector field :math:`\mathbf{u}`, i.e.,
|
|
12
|
+
:math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
|
|
13
|
+
field from :math:`\mathbb{R}^d` to :math:`\mathbb{R}^d`.
|
|
14
|
+
The computation is done using backward AD
|
|
12
15
|
"""
|
|
13
16
|
|
|
14
17
|
def scan_fun(_, i):
|
|
@@ -26,42 +29,131 @@ def _div(u, nn_params, eq_params, x, t=None):
|
|
|
26
29
|
return jnp.sum(accu)
|
|
27
30
|
|
|
28
31
|
|
|
29
|
-
def
|
|
32
|
+
def _div_fwd(u, nn_params, eq_params, x, t=None):
|
|
30
33
|
r"""
|
|
31
|
-
Compute the
|
|
32
|
-
|
|
33
|
-
:math:`\
|
|
34
|
-
|
|
35
|
-
|
|
34
|
+
Compute the divergence of a **batched** vector field :math:`\mathbf{u}`, i.e.,
|
|
35
|
+
:math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
|
|
36
|
+
field from :math:`\mathbb{R}^{b \times d}` to :math:`\mathbb{R}^{b \times b
|
|
37
|
+
\times d}`. The result is then in :math:`\mathbb{R}^{b\times b}`.
|
|
38
|
+
Because of the embedding that happens in SPINNs the
|
|
39
|
+
computation is most efficient with forward AD. This is the idea behind Separable PINNs.
|
|
40
|
+
This function is to be used in the context of SPINNs only.
|
|
36
41
|
"""
|
|
37
42
|
|
|
38
43
|
def scan_fun(_, i):
|
|
44
|
+
tangent_vec = jnp.repeat(
|
|
45
|
+
jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
|
|
46
|
+
)
|
|
39
47
|
if t is None:
|
|
40
|
-
|
|
41
|
-
lambda x
|
|
42
|
-
|
|
43
|
-
)(x, nn_params, eq_params)[i]
|
|
48
|
+
__, du_dxi = jax.jvp(
|
|
49
|
+
lambda x: u(x, nn_params, eq_params)[..., i], (x,), (tangent_vec,)
|
|
50
|
+
)
|
|
44
51
|
else:
|
|
45
|
-
|
|
46
|
-
lambda t, x, nn_params, eq_params
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
52
|
+
__, du_dxi = jax.jvp(
|
|
53
|
+
lambda x: u(t, x, nn_params, eq_params)[..., i], (x,), (tangent_vec,)
|
|
54
|
+
)
|
|
55
|
+
return _, du_dxi
|
|
56
|
+
|
|
57
|
+
_, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
|
|
58
|
+
return jnp.sum(accu, axis=0)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _laplacian_rev(u, nn_params, eq_params, x, t=None):
|
|
62
|
+
r"""
|
|
63
|
+
Compute the Laplacian of a scalar field :math:`u` (from :math:`\mathbb{R}^d`
|
|
64
|
+
to :math:`\mathbb{R}`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
|
|
65
|
+
:math:`\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})`.
|
|
66
|
+
The computation is done using backward AD.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
# Note that the last dim of u is nec. 1
|
|
70
|
+
if t is None:
|
|
71
|
+
u_ = lambda x: u(x, nn_params, eq_params)[0]
|
|
72
|
+
else:
|
|
73
|
+
u_ = lambda t, x: u(t, x, nn_params, eq_params)[0]
|
|
74
|
+
|
|
75
|
+
if t is None:
|
|
76
|
+
return jnp.trace(jax.hessian(u_)(x))
|
|
77
|
+
else:
|
|
78
|
+
return jnp.trace(jax.hessian(u_, argnums=1)(t, x))
|
|
79
|
+
|
|
80
|
+
# For a small d, we found out that trace of the Hessian is faster, but the
|
|
81
|
+
# trick below for taking directly the diagonal elements might prove useful
|
|
82
|
+
# in higher dimensions?
|
|
83
|
+
|
|
84
|
+
# def scan_fun(_, i):
|
|
85
|
+
# if t is None:
|
|
86
|
+
# d2u_dxi2 = grad(
|
|
87
|
+
# lambda x: grad(u_, 0)(x)[i],
|
|
88
|
+
# 0,
|
|
89
|
+
# )(
|
|
90
|
+
# x
|
|
91
|
+
# )[i]
|
|
92
|
+
# else:
|
|
93
|
+
# d2u_dxi2 = grad(
|
|
94
|
+
# lambda t, x: grad(u_, 1)(t, x)[i],
|
|
95
|
+
# 1,
|
|
96
|
+
# )(
|
|
97
|
+
# t, x
|
|
98
|
+
# )[i]
|
|
99
|
+
# return _, d2u_dxi2
|
|
100
|
+
|
|
101
|
+
# _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
|
|
102
|
+
# return jnp.sum(trace_hessian)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _laplacian_fwd(u, nn_params, eq_params, x, t=None):
|
|
106
|
+
r"""
|
|
107
|
+
Compute the Laplacian of a **batched** scalar field :math:`u`
|
|
108
|
+
(from :math:`\mathbb{R}^{b\times d}` to :math:`\mathbb{R}^{b\times b}`)
|
|
109
|
+
for :math:`\mathbf{x}` of arbitrary dimension :math:`d` with batch
|
|
110
|
+
dimension :math:`b`.
|
|
111
|
+
Because of the embedding that happens in SPINNs the
|
|
112
|
+
computation is most efficient with forward AD. This is the idea behind Separable PINNs.
|
|
113
|
+
This function is to be used in the context of SPINNs only.
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
def scan_fun(_, i):
|
|
117
|
+
tangent_vec = jnp.repeat(
|
|
118
|
+
jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
if t is None:
|
|
122
|
+
du_dxi_fun = lambda x: jax.jvp(
|
|
123
|
+
lambda x: u(x, nn_params, eq_params)[..., 0], (x,), (tangent_vec,)
|
|
124
|
+
)[
|
|
125
|
+
1
|
|
126
|
+
] # Note the indexing [..., 0]
|
|
127
|
+
__, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
|
|
128
|
+
else:
|
|
129
|
+
du_dxi_fun = lambda x: jax.jvp(
|
|
130
|
+
lambda x: u(t, x, nn_params, eq_params)[..., 0], (x,), (tangent_vec,)
|
|
131
|
+
)[
|
|
132
|
+
1
|
|
133
|
+
] # Note the indexing [..., 0]
|
|
134
|
+
__, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
|
|
51
135
|
return _, d2u_dxi2
|
|
52
136
|
|
|
53
|
-
_, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[
|
|
54
|
-
return jnp.sum(trace_hessian)
|
|
137
|
+
_, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
|
|
138
|
+
return jnp.sum(trace_hessian, axis=0)
|
|
55
139
|
|
|
56
140
|
|
|
57
141
|
def _vectorial_laplacian(u, nn_params, eq_params, x, t=None, u_vec_ndim=None):
|
|
58
142
|
r"""
|
|
59
|
-
Compute the vectorial Laplacian of a vector field
|
|
60
|
-
|
|
61
|
-
:math:`\
|
|
143
|
+
Compute the vectorial Laplacian of a vector field :math:`\mathbf{u}` (from
|
|
144
|
+
:math:`\mathbb{R}^d`
|
|
145
|
+
to :math:`\mathbb{R}^n`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
|
|
146
|
+
:math:`\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
|
|
147
|
+
\mathbf{u}(\mathbf{x})`.
|
|
148
|
+
|
|
149
|
+
**Note:** We need to provide `u_vec_ndim` the dimension of the vector
|
|
150
|
+
:math:`\mathbf{u}(\mathbf{x})` if it is different than that of
|
|
151
|
+
:math:`\mathbf{x}`.
|
|
62
152
|
|
|
63
|
-
**Note:**
|
|
64
|
-
:math:`\
|
|
153
|
+
**Note:** `u` can be a SPINN, in this case, it corresponds to a vector
|
|
154
|
+
field from (from :math:`\mathbb{R}^{b\times d}` to
|
|
155
|
+
:math:`\mathbb{R}^{b\times b\times n}`) and forward mode AD is used.
|
|
156
|
+
Technically, the return is of dimension :math:`n\times b \times b`.
|
|
65
157
|
"""
|
|
66
158
|
if u_vec_ndim is None:
|
|
67
159
|
u_vec_ndim = x.shape[0]
|
|
@@ -69,25 +161,42 @@ def _vectorial_laplacian(u, nn_params, eq_params, x, t=None, u_vec_ndim=None):
|
|
|
69
161
|
def scan_fun(_, j):
|
|
70
162
|
# The loop over the components of u(x). We compute one Laplacian for
|
|
71
163
|
# each of these components
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
164
|
+
# Note the expand_dims
|
|
165
|
+
if isinstance(u, PINN):
|
|
166
|
+
if t is None:
|
|
167
|
+
uj = lambda x, nn_params, eq_params: jnp.expand_dims(
|
|
168
|
+
u(x, nn_params, eq_params)[j], axis=-1
|
|
169
|
+
)
|
|
170
|
+
else:
|
|
171
|
+
uj = lambda t, x, nn_params, eq_params: jnp.expand_dims(
|
|
172
|
+
u(t, x, nn_params, eq_params)[j], axis=-1
|
|
173
|
+
)
|
|
174
|
+
lap_on_j = _laplacian_rev(uj, nn_params, eq_params, x, t)
|
|
175
|
+
elif isinstance(u, SPINN):
|
|
176
|
+
if t is None:
|
|
177
|
+
uj = lambda x, nn_params, eq_params: jnp.expand_dims(
|
|
178
|
+
u(x, nn_params, eq_params)[..., j], axis=-1
|
|
179
|
+
)
|
|
180
|
+
else:
|
|
181
|
+
uj = lambda t, x, nn_params, eq_params: jnp.expand_dims(
|
|
182
|
+
u(t, x, nn_params, eq_params)[..., j], axis=-1
|
|
183
|
+
)
|
|
184
|
+
lap_on_j = _laplacian_fwd(uj, nn_params, eq_params, x, t)
|
|
185
|
+
|
|
77
186
|
return _, lap_on_j
|
|
78
187
|
|
|
79
188
|
_, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(u_vec_ndim))
|
|
80
189
|
return vec_lap
|
|
81
190
|
|
|
82
191
|
|
|
83
|
-
def
|
|
192
|
+
def _u_dot_nabla_times_u_rev(u, nn_params, eq_params, x, t=None):
|
|
84
193
|
r"""
|
|
85
|
-
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(x)` for
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
Currently for `x.ndim=2`
|
|
89
|
-
|
|
90
|
-
|
|
194
|
+
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
|
|
195
|
+
:math:`\mathbf{x}` of arbitrary
|
|
196
|
+
dimension. :math:`\mathbf{u}` is a vector field from :math:`\mathbb{R}^n`
|
|
197
|
+
to :math:`\mathbb{R}^n`. **Currently for** `x.ndim=2` **only**.
|
|
198
|
+
The computation is done using backward AD.
|
|
199
|
+
We do not use loops but code explicitly the expression to avoid
|
|
91
200
|
computing twice some terms
|
|
92
201
|
"""
|
|
93
202
|
if x.shape[0] == 2:
|
|
@@ -127,17 +236,64 @@ def _u_dot_nabla_times_u(u, nn_params, eq_params, x, t=None):
|
|
|
127
236
|
raise NotImplementedError("x.ndim must be 2")
|
|
128
237
|
|
|
129
238
|
|
|
239
|
+
def _u_dot_nabla_times_u_fwd(u, nn_params, eq_params, x, t=None):
|
|
240
|
+
r"""
|
|
241
|
+
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
|
|
242
|
+
:math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
|
|
243
|
+
I.e., :math:`\mathbf{u}` is a vector field from :math:`\mathbb{R}^{b\times
|
|
244
|
+
b}`
|
|
245
|
+
to :math:`\mathbb{R}^{b\times b \times d}`. **Currently for** :math:`d=2`
|
|
246
|
+
**only**.
|
|
247
|
+
We do not use loops but code explicitly the expression to avoid
|
|
248
|
+
computing twice some terms.
|
|
249
|
+
Because of the embedding that happens in SPINNs the
|
|
250
|
+
computation is most efficient with forward AD. This is the idea behind Separable PINNs.
|
|
251
|
+
This function is to be used in the context of SPINNs only.
|
|
252
|
+
"""
|
|
253
|
+
if x.shape[-1] == 2:
|
|
254
|
+
tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
|
|
255
|
+
tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
|
|
256
|
+
if t is None:
|
|
257
|
+
u_at_x, du_dx = jax.jvp(
|
|
258
|
+
lambda x: u(x, nn_params, eq_params), (x,), (tangent_vec_0,)
|
|
259
|
+
) # thanks to forward AD this gets dux_dx and duy_dx in a vector
|
|
260
|
+
# ie the derivatives of both components of u wrt x
|
|
261
|
+
# this also gets the vector of u evaluated at x
|
|
262
|
+
u_at_x, du_dy = jax.jvp(
|
|
263
|
+
lambda x: u(x, nn_params, eq_params), (x,), (tangent_vec_1,)
|
|
264
|
+
) # thanks to forward AD this gets dux_dy and duy_dy in a vector
|
|
265
|
+
# ie the derivatives of both components of u wrt y
|
|
266
|
+
|
|
267
|
+
else:
|
|
268
|
+
u_at_x, du_dx = jax.jvp(
|
|
269
|
+
lambda x: u(t, x, nn_params, eq_params), (x,), (tangent_vec_0,)
|
|
270
|
+
)
|
|
271
|
+
u_at_x, du_dy = jax.jvp(
|
|
272
|
+
lambda x: u(t, x, nn_params, eq_params), (x,), (tangent_vec_1,)
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
return jnp.stack(
|
|
276
|
+
[
|
|
277
|
+
u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
|
|
278
|
+
u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
|
|
279
|
+
],
|
|
280
|
+
axis=-1,
|
|
281
|
+
)
|
|
282
|
+
else:
|
|
283
|
+
raise NotImplementedError("x.ndim must be 2")
|
|
284
|
+
|
|
285
|
+
|
|
130
286
|
def _sobolev(u, m, statio=True):
|
|
131
287
|
r"""
|
|
132
|
-
Compute the Sobolev regularization of order m
|
|
133
|
-
of a scalar field u (from :math:`\mathbb{R}^
|
|
134
|
-
for x of arbitrary dimension i.e
|
|
288
|
+
Compute the Sobolev regularization of order :math:`m`
|
|
289
|
+
of a scalar field :math:`u` (from :math:`\mathbb{R}^{d}` to :math:`\mathbb{R}`)
|
|
290
|
+
for :math:`\mathbf{x}` of arbitrary dimension :math:`d`, i.e.,
|
|
135
291
|
:math:`\frac{1}{n_l}\sum_{l=1}^{n_l}\sum_{|\alpha|=1}^{m+1} ||\partial^{\alpha} u(x_l)||_2^2` where
|
|
136
|
-
:math:`m\geq\max(d_1 // 2, K)` with
|
|
292
|
+
:math:`m\geq\max(d_1 // 2, K)` with :math:`K` the order of the differential
|
|
137
293
|
operator.
|
|
138
294
|
|
|
139
|
-
This regularization is proposed in
|
|
140
|
-
|
|
295
|
+
This regularization is proposed in *Convergence and error analysis of
|
|
296
|
+
PINNs*, Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
|
|
141
297
|
"""
|
|
142
298
|
|
|
143
299
|
def jac_recursive(u, order, start):
|