jinns 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,9 @@
  import jax
  import jax.numpy as jnp
  from jax import vmap, grad
+ from jinns.utils._utils import _get_grid, _check_user_func_return
+ from jinns.utils._pinn import PINN
+ from jinns.utils._spinn import SPINN
 
 
  def _compute_boundary_loss_statio(
@@ -127,18 +130,32 @@ def boundary_dirichlet_statio(f, border_batch, u, params):
          dictionaries: `eq_params` and `nn_params``, respectively the
          differential equation parameters and the neural network parameter
      """
-     v_u_boundary = vmap(
-         lambda dx: u(
-             dx,
-             u_params=params["nn_params"],
-             eq_params=jax.lax.stop_gradient(params["eq_params"]),
+     if isinstance(u, PINN):
+         v_u_boundary = vmap(
+             lambda dx: u(
+                 dx,
+                 u_params=params["nn_params"],
+                 eq_params=jax.lax.stop_gradient(params["eq_params"]),
+             )
+             - f(dx),
+             (0),
+             0,
          )
-         - f(dx),
-         (0),
-         0,
-     )
 
-     mse_u_boundary = jnp.mean((v_u_boundary(border_batch)) ** 2, axis=0)
+         mse_u_boundary = jnp.sum((v_u_boundary(border_batch)) ** 2, axis=-1)
+     elif isinstance(u, SPINN):
+         values = u(
+             border_batch,
+             params["nn_params"],
+             jax.lax.stop_gradient(params["eq_params"]),
+         )
+         x_grid = _get_grid(border_batch)
+         boundaries = _check_user_func_return(f(x_grid), values.shape)
+         res = values - boundaries
+         mse_u_boundary = jnp.sum(
+             res**2,
+             axis=-1,
+         )
 
      return mse_u_boundary
 
 
@@ -177,17 +194,69 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
      # border_batch shape (batch_size, ndim, nfacets)
      n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
 
-     v_neumann = vmap(
-         lambda dx: jnp.dot(
-             grad(u, 0)(
-                 dx, params["nn_params"], jax.lax.stop_gradient(params["eq_params"])
-             ),
-             n[..., facet],
+     if isinstance(u, PINN):
+         u_ = lambda x, nn, eq: u(t, x, nn, eq)[0]
+         v_neumann = vmap(
+             lambda dx: jnp.dot(
+                 grad(u_, 0)(
+                     dx, params["nn_params"], jax.lax.stop_gradient(params["eq_params"])
+                 ),
+                 n[..., facet],
+             )
+             - f(dx),
+             0,
          )
-         - f(dx),
-         0,
-     )
-     mse_u_boundary = jnp.mean((v_neumann(border_batch)) ** 2, axis=0)
+         mse_u_boundary = jnp.sum((v_neumann(border_batch)) ** 2, axis=-1)
+     elif isinstance(u, SPINN):
+         # the gradient we see in the PINN case can get gradients wrt to x
+         # dimensions at once. But it would be very inefficient in SPINN because
+         # of the high dim output of u. So we do 2 explicit forward AD, handling all the
+         # high dim output at once
+         if border_batch.shape[0] == 1:  # i.e. case 1D
+             _, du_dx = jax.jvp(
+                 lambda x: u(
+                     x,
+                     params["nn_params"],
+                     jax.lax.stop_gradient(params["eq_params"]),
+                 ),
+                 (omega_border_batch,),
+                 (jnp.ones_like(x),),
+             )
+             values = du_dx * n[facet]
+         elif omega_border_batch.shape[-1] == 2:
+             tangent_vec_0 = jnp.repeat(
+                 jnp.array([1.0, 0.0])[None], omega_border_batch.shape[0], axis=0
+             )
+             tangent_vec_1 = jnp.repeat(
+                 jnp.array([0.0, 1.0])[None], omega_border_batch.shape[0], axis=0
+             )
+             _, du_dx1 = jax.jvp(
+                 lambda x: u(
+                     x,
+                     params["nn_params"],
+                     jax.lax.stop_gradient(params["eq_params"]),
+                 ),
+                 (omega_border_batch,),
+                 (tangent_vec_0,),
+             )
+             _, du_dx2 = jax.jvp(
+                 lambda x: u(
+                     x,
+                     params["nn_params"],
+                     jax.lax.stop_gradient(params["eq_params"]),
+                 ),
+                 (omega_border_batch,),
+                 (tangent_vec_1,),
+             )
+             values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet]  # dot product
+             # explicitly written
+         else:
+             raise ValueError("Not implemented, we'll do that with a loop")
+
+         x_grid = _get_grid(border_batch)
+         boundaries = _check_user_func_return(f(x_grid), values.shape)
+         res = values - boundaries
+         mse_u_boundary = jnp.sum(res**2, axis=-1)
 
      return mse_u_boundary
 
 
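The SPINN branches above trade reverse-mode `grad` for forward-mode `jax.jvp` calls along fixed tangent directions, so a whole batch of points is handled in one pass. A minimal, self-contained sketch of that pattern (not jinns code; `u_toy` is a hypothetical stand-in for the network) showing that a `jvp` against the facet normal matches `grad` followed by a dot product:

```python
import jax
import jax.numpy as jnp

# Toy scalar field standing in for the network: u(x) = sin(x0) * x1
u_toy = lambda x: jnp.sin(x[..., 0]) * x[..., 1]

x = jnp.array([[0.3, 1.0], [0.7, 2.0]])  # a small batch of 2D border points
normal = jnp.array([1.0, 0.0])           # outward normal of the facet

# Forward mode: push the normal through u in one pass; du_dn[i] is the
# derivative of u at x[i] along the normal direction.
tangent = jnp.repeat(normal[None], x.shape[0], axis=0)
_, du_dn = jax.jvp(u_toy, (x,), (tangent,))

# Reverse mode for comparison: gradient then dot product, one point at a time.
du_dn_rev = jax.vmap(lambda xi: jnp.dot(jax.grad(u_toy)(xi), normal))(x)
print(jnp.allclose(du_dn, du_dn_rev))  # True
```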
@@ -212,30 +281,56 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
          dictionaries: `eq_params` and `nn_params``, respectively the
          differential equation parameters and the neural network parameter
      """
-     tile_omega_border_batch = jnp.tile(
-         omega_border_batch, reps=(times_batch.shape[0], 1)
-     )
-
-     def rep_times(k):
-         return jnp.repeat(times_batch, k, axis=0)
-
-     v_u_boundary = vmap(
-         lambda t, dx: u(
-             t,
-             dx,
-             u_params=params["nn_params"],
-             eq_params=jax.lax.stop_gradient(params["eq_params"]),
+     if isinstance(u, PINN):
+         tile_omega_border_batch = jnp.tile(
+             omega_border_batch, reps=(times_batch.shape[0], 1)
+         )
+
+         def rep_times(k):
+             return jnp.repeat(times_batch, k, axis=0)
+
+         v_u_boundary = vmap(
+             lambda t, dx: u(
+                 t,
+                 dx,
+                 u_params=params["nn_params"],
+                 eq_params=jax.lax.stop_gradient(params["eq_params"]),
+             )
+             - f(t, dx),
+             (0, 0),
+             0,
+         )
+         res = v_u_boundary(
+             rep_times(omega_border_batch.shape[0]), tile_omega_border_batch
+         )  # TODO check if this cartesian product is always relevant
+         mse_u_boundary = jnp.sum(
+             res**2,
+             axis=-1,
+         )
+     elif isinstance(u, SPINN):
+         tile_omega_border_batch = jnp.tile(
+             omega_border_batch, reps=(times_batch.shape[0], 1)
+         )
+
+         if omega_border_batch.shape[0] == 1:
+             omega_border_batch = jnp.tile(
+                 omega_border_batch, reps=(times_batch.shape[0], 1)
+             )
+         # otherwise we require batches to have same shape and we do not need
+         # this operation
+
+         values = u(
+             times_batch,
+             tile_omega_border_batch,
+             params["nn_params"],
+             jax.lax.stop_gradient(params["eq_params"]),
+         )
+         tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
+         boundaries = _check_user_func_return(
+             f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
          )
-         - f(t, dx),
-         (0, 0),
-         0,
-     )
-
-     mse_u_boundary = jnp.mean(
-         (v_u_boundary(rep_times(omega_border_batch.shape[0]), tile_omega_border_batch))
-         ** 2,
-         axis=0,
-     )
+         res = values - boundaries
+         mse_u_boundary = jnp.sum(res**2, axis=-1)
 
      return mse_u_boundary
 
 
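In the PINN branch above, `jnp.tile` and `jnp.repeat` build the Cartesian product of the time batch and the border batch before `vmap` evaluates the residual on every (t, x) pair. A small standalone sketch of that pairing (not jinns code), with hypothetical shapes:

```python
import jax.numpy as jnp

times_batch = jnp.array([[0.0], [0.5], [1.0]])   # (nt, 1)
omega_border_batch = jnp.array([[-1.0], [1.0]])  # (nb, 1)

# Every border point repeated for each time...
tile_border = jnp.tile(omega_border_batch, reps=(times_batch.shape[0], 1))  # (6, 1)
# ...and every time repeated for each border point.
rep_times = jnp.repeat(times_batch, omega_border_batch.shape[0], axis=0)    # (6, 1)

# Row i of the pair (rep_times[i], tile_border[i]) is one (t, x) combination.
pairs = jnp.concatenate([rep_times, tile_border], axis=1)
print(pairs.shape)  # (6, 2): 3 times x 2 border points
```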
@@ -264,13 +359,6 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
          An integer which represents the id of the facet which is currently
          considered (in the order provided wy the DataGenerator which is fixed)
      """
-     tile_omega_border_batch = jnp.tile(
-         omega_border_batch, reps=(times_batch.shape[0], 1)
-     )
-
-     def rep_times(k):
-         return jnp.repeat(times_batch, k, axis=0)
-
      # We resort to the shape of the border_batch to determine the dimension as
      # described in the border_batch function
@@ -282,21 +370,99 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
      # border_batch shape (batch_size, ndim, nfacets)
      n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
 
-     v_neumann = vmap(
-         lambda t, dx: jnp.dot(
-             grad(u, 1)(
-                 t, dx, params["nn_params"], jax.lax.stop_gradient(params["eq_params"])
-             ),
-             n[..., facet],
+     if isinstance(u, PINN):
+         tile_omega_border_batch = jnp.tile(
+             omega_border_batch, reps=(times_batch.shape[0], 1)
          )
-         - f(t, dx),
-         0,
-         0,
-     )
-     mse_u_boundary = jnp.mean(
-         (v_neumann(rep_times(omega_border_batch.shape[0]), tile_omega_border_batch))
-         ** 2,
-         axis=0,
-     )
 
+         def rep_times(k):
+             return jnp.repeat(times_batch, k, axis=0)
+
+         u_ = lambda t, x, nn, eq: u(t, x, nn, eq)[0]
+         v_neumann = vmap(
+             lambda t, dx: jnp.dot(
+                 grad(u_, 1)(
+                     t,
+                     dx,
+                     params["nn_params"],
+                     jax.lax.stop_gradient(params["eq_params"]),
+                 ),
+                 n[..., facet],
+             )
+             - f(t, dx),
+             0,
+             0,
+         )
+         mse_u_boundary = jnp.sum(
+             (v_neumann(rep_times(omega_border_batch.shape[0]), tile_omega_border_batch))
+             ** 2,
+             axis=-1,
+         )  # TODO check if this cartesian product is always relevant
+
+     elif isinstance(u, SPINN):
+         if omega_border_batch.shape[0] == 1:
+             omega_border_batch = jnp.tile(
+                 omega_border_batch, reps=(times_batch.shape[0], 1)
+             )
+         # ie case 1D
+         # otherwise we require batches to have same shape and we do not need
+         # this operation
+
+         # the gradient we see in the PINN case can get gradients wrt to x
+         # dimensions at once. But it would be very inefficient in SPINN because
+         # of the high dim output of u. So we do 2 explicit forward AD, handling all the
+         # high dim output at once
+         if omega_border_batch.shape[0] == 1:  # i.e. case 1D
+             _, du_dx = jax.jvp(
+                 lambda x: u(
+                     times_batch,
+                     x,
+                     params["nn_params"],
+                     jax.lax.stop_gradient(params["eq_params"]),
+                 ),
+                 (omega_border_batch,),
+                 (jnp.ones_like(x),),
+             )
+             values = du_dx * n[facet]
+         elif omega_border_batch.shape[-1] == 2:
+             tangent_vec_0 = jnp.repeat(
+                 jnp.array([1.0, 0.0])[None], omega_border_batch.shape[0], axis=0
+             )
+             tangent_vec_1 = jnp.repeat(
+                 jnp.array([0.0, 1.0])[None], omega_border_batch.shape[0], axis=0
+             )
+             _, du_dx1 = jax.jvp(
+                 lambda x: u(
+                     times_batch,
+                     x,
+                     params["nn_params"],
+                     jax.lax.stop_gradient(params["eq_params"]),
+                 ),
+                 (omega_border_batch,),
+                 (tangent_vec_0,),
+             )
+             _, du_dx2 = jax.jvp(
+                 lambda x: u(
+                     times_batch,
+                     x,
+                     params["nn_params"],
+                     jax.lax.stop_gradient(params["eq_params"]),
+                 ),
+                 (omega_border_batch,),
+                 (tangent_vec_1,),
+             )
+             values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet]  # dot product
+             # explicitly written
+         else:
+             raise ValueError("Not implemented, we'll do that with a loop")
+
+         tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
+         boundaries = _check_user_func_return(
+             f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
+         )
+         res = values - boundaries
+         mse_u_boundary = jnp.sum(
+             res**2,
+             axis=-1,
+         )
      return mse_u_boundary
jinns/loss/_operators.py CHANGED
@@ -2,13 +2,16 @@ import jax
  import jax.numpy as jnp
  from jax import grad
  from functools import partial
+ from jinns.utils._pinn import PINN
+ from jinns.utils._spinn import SPINN
 
 
- def _div(u, nn_params, eq_params, x, t=None):
+ def _div_rev(u, nn_params, eq_params, x, t=None):
      r"""
-     Compute the divergence of a vector field :math:`\mathbf{u}` ie
-     :math:`\nabla \cdot u(x)` with u a vector field from :math:`\mathbb{R}^n`
-     to :math:`\mathbb{R}^n`
+     Compute the divergence of a vector field :math:`\mathbf{u}`, i.e.,
+     :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
+     field from :math:`\mathbb{R}^d` to :math:`\mathbb{R}^d`.
+     The computation is done using backward AD
      """
 
      def scan_fun(_, i):
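`_div_rev` accumulates the partial derivatives of each component of `u` with reverse-mode AD inside a `jax.lax.scan`. A standalone sketch of the same quantity on a toy vector field (not jinns code; a plain Python sum replaces the scan), cross-checked against the trace of the Jacobian:

```python
import jax
import jax.numpy as jnp

# Toy vector field R^2 -> R^2: u(x) = (x0 * x1, x1 ** 2)
u_toy = lambda x: jnp.array([x[0] * x[1], x[1] ** 2])

x = jnp.array([2.0, 3.0])

# Divergence as a sum of reverse-mode partials, one component at a time.
div = sum(jax.grad(lambda x, i=i: u_toy(x)[i])(x)[i] for i in range(x.shape[0]))

# Cross-check: the divergence is the trace of the Jacobian.
print(div, jnp.trace(jax.jacobian(u_toy)(x)))  # 9.0 9.0
```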
@@ -26,42 +29,131 @@ def _div(u, nn_params, eq_params, x, t=None):
      return jnp.sum(accu)
 
 
- def _laplacian(u, nn_params, eq_params, x, t=None):
+ def _div_fwd(u, nn_params, eq_params, x, t=None):
      r"""
-     Compute the Laplacian of a scalar field u (from :math:`\mathbb{R}^n`
-     to :math:`\mathbb{R}`) for x of arbitrary dimension ie
-     :math:`\Delta u(x)=\nabla\cdot\nabla u(x)`
-     For computational reason we do not compute the trace of the Hessian but
-     we explicitly call the gradient twice
+     Compute the divergence of a **batched** vector field :math:`\mathbf{u}`, i.e.,
+     :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
+     field from :math:`\mathbb{R}^{b \times d}` to :math:`\mathbb{R}^{b \times b
+     \times d}`. The result is then in :math:`\mathbb{R}^{b\times b}`.
+     Because of the embedding that happens in SPINNs the
+     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
+     This function is to be used in the context of SPINNs only.
      """
 
      def scan_fun(_, i):
+         tangent_vec = jnp.repeat(
+             jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
+         )
          if t is None:
-             d2u_dxi2 = grad(
-                 lambda x, nn_params, eq_params: grad(u, 0)(x, nn_params, eq_params)[i],
-                 0,
-             )(x, nn_params, eq_params)[i]
+             __, du_dxi = jax.jvp(
+                 lambda x: u(x, nn_params, eq_params)[..., i], (x,), (tangent_vec,)
+             )
          else:
-             d2u_dxi2 = grad(
-                 lambda t, x, nn_params, eq_params: grad(u, 1)(
-                     t, x, nn_params, eq_params
-                 )[i],
-                 1,
-             )(t, x, nn_params, eq_params)[i]
+             __, du_dxi = jax.jvp(
+                 lambda x: u(t, x, nn_params, eq_params)[..., i], (x,), (tangent_vec,)
+             )
+         return _, du_dxi
+
+     _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
+     return jnp.sum(accu, axis=0)
+
+
+ def _laplacian_rev(u, nn_params, eq_params, x, t=None):
+     r"""
+     Compute the Laplacian of a scalar field :math:`u` (from :math:`\mathbb{R}^d`
+     to :math:`\mathbb{R}`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
+     :math:`\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})`.
+     The computation is done using backward AD.
+     """
+
+     # Note that the last dim of u is nec. 1
+     if t is None:
+         u_ = lambda x: u(x, nn_params, eq_params)[0]
+     else:
+         u_ = lambda t, x: u(t, x, nn_params, eq_params)[0]
+
+     if t is None:
+         return jnp.trace(jax.hessian(u_)(x))
+     else:
+         return jnp.trace(jax.hessian(u_, argnums=1)(t, x))
+
+     # For a small d, we found out that trace of the Hessian is faster, but the
+     # trick below for taking directly the diagonal elements might prove useful
+     # in higher dimensions?
+
+     # def scan_fun(_, i):
+     #     if t is None:
+     #         d2u_dxi2 = grad(
+     #             lambda x: grad(u_, 0)(x)[i],
+     #             0,
+     #         )(
+     #             x
+     #         )[i]
+     #     else:
+     #         d2u_dxi2 = grad(
+     #             lambda t, x: grad(u_, 1)(t, x)[i],
+     #             1,
+     #         )(
+     #             t, x
+     #         )[i]
+     #     return _, d2u_dxi2
+
+     # _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
+     # return jnp.sum(trace_hessian)
+
+
+ def _laplacian_fwd(u, nn_params, eq_params, x, t=None):
+     r"""
+     Compute the Laplacian of a **batched** scalar field :math:`u`
+     (from :math:`\mathbb{R}^{b\times d}` to :math:`\mathbb{R}^{b\times b}`)
+     for :math:`\mathbf{x}` of arbitrary dimension :math:`d` with batch
+     dimension :math:`b`.
+     Because of the embedding that happens in SPINNs the
+     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
+     This function is to be used in the context of SPINNs only.
+     """
+
+     def scan_fun(_, i):
+         tangent_vec = jnp.repeat(
+             jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
+         )
+
+         if t is None:
+             du_dxi_fun = lambda x: jax.jvp(
+                 lambda x: u(x, nn_params, eq_params)[..., 0], (x,), (tangent_vec,)
+             )[
+                 1
+             ]  # Note the indexing [..., 0]
+             __, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
+         else:
+             du_dxi_fun = lambda x: jax.jvp(
+                 lambda x: u(t, x, nn_params, eq_params)[..., 0], (x,), (tangent_vec,)
+             )[
+                 1
+             ]  # Note the indexing [..., 0]
+             __, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
          return _, d2u_dxi2
 
-     _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
-     return jnp.sum(trace_hessian)
+     _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
+     return jnp.sum(trace_hessian, axis=0)
 
 
  def _vectorial_laplacian(u, nn_params, eq_params, x, t=None, u_vec_ndim=None):
      r"""
-     Compute the vectorial Laplacian of a vector field u (from :math:`\mathbb{R}^m`
-     to :math:`\mathbb{R}^n`) for x of arbitrary dimension ie
-     :math:`\Delta \mathbf{u}(x)=\nabla\cdot\nabla \mathbf{u}(x)`
+     Compute the vectorial Laplacian of a vector field :math:`\mathbf{u}` (from
+     :math:`\mathbb{R}^d`
+     to :math:`\mathbb{R}^n`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
+     :math:`\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
+     \mathbf{u}(\mathbf{x})`.
+
+     **Note:** We need to provide `u_vec_ndim` the dimension of the vector
+     :math:`\mathbf{u}(\mathbf{x})` if it is different than that of
+     :math:`\mathbf{x}`.
 
-     **Note:** We need to provide in u_vec_ndim the dimension of the vector
-     :math:`\mathbf{u}(x)` if it is different than that of x
+     **Note:** `u` can be a SPINN, in this case, it corresponds to a vector
+     field from (from :math:`\mathbb{R}^{b\times d}` to
+     :math:`\mathbb{R}^{b\times b\times n}`) and forward mode AD is used.
+     Technically, the return is of dimension :math:`n\times b \times b`.
      """
      if u_vec_ndim is None:
          u_vec_ndim = x.shape[0]
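The two Laplacian implementations above differ only in how second derivatives are obtained. A minimal sketch (not jinns code) computing the Laplacian of a toy scalar field both as the trace of the Hessian, as in `_laplacian_rev`, and with nested `jax.jvp` calls along coordinate directions, the pure forward-mode route taken by `_laplacian_fwd`:

```python
import jax
import jax.numpy as jnp

# Toy scalar field R^2 -> R: u(x) = x0 ** 3 + x0 * x1 ** 2
u_toy = lambda x: x[0] ** 3 + x[0] * x[1] ** 2

x = jnp.array([1.0, 2.0])

# Reverse/mixed mode: trace of the Hessian.
lap_hessian = jnp.trace(jax.hessian(u_toy)(x))

# Pure forward mode: push the basis vector e_i through u twice; the second
# jvp returns d2u/dxi2.
def second_derivative(i):
    e_i = jax.nn.one_hot(i, x.shape[0])
    first = lambda x: jax.jvp(u_toy, (x,), (e_i,))[1]  # du/dxi
    return jax.jvp(first, (x,), (e_i,))[1]             # d2u/dxi2

lap_fwd = sum(second_derivative(i) for i in range(x.shape[0]))
print(lap_hessian, lap_fwd)  # 8.0 8.0 (= 6 * x0 + 2 * x0)
```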
@@ -69,25 +161,42 @@ def _vectorial_laplacian(u, nn_params, eq_params, x, t=None, u_vec_ndim=None):
      def scan_fun(_, j):
          # The loop over the components of u(x). We compute one Laplacian for
          # each of these components
-         if t is None:
-             uj = lambda x, nn_params, eq_params: u(x, nn_params, eq_params)[j]
-         else:
-             uj = lambda t, x, nn_params, eq_params: u(t, x, nn_params, eq_params)[j]
-         lap_on_j = _laplacian(uj, nn_params, eq_params, x, t)
+         # Note the expand_dims
+         if isinstance(u, PINN):
+             if t is None:
+                 uj = lambda x, nn_params, eq_params: jnp.expand_dims(
+                     u(x, nn_params, eq_params)[j], axis=-1
+                 )
+             else:
+                 uj = lambda t, x, nn_params, eq_params: jnp.expand_dims(
+                     u(t, x, nn_params, eq_params)[j], axis=-1
+                 )
+             lap_on_j = _laplacian_rev(uj, nn_params, eq_params, x, t)
+         elif isinstance(u, SPINN):
+             if t is None:
+                 uj = lambda x, nn_params, eq_params: jnp.expand_dims(
+                     u(x, nn_params, eq_params)[..., j], axis=-1
+                 )
+             else:
+                 uj = lambda t, x, nn_params, eq_params: jnp.expand_dims(
+                     u(t, x, nn_params, eq_params)[..., j], axis=-1
+                 )
+             lap_on_j = _laplacian_fwd(uj, nn_params, eq_params, x, t)
+
          return _, lap_on_j
 
      _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(u_vec_ndim))
      return vec_lap
 
 
- def _u_dot_nabla_times_u(u, nn_params, eq_params, x, t=None):
+ def _u_dot_nabla_times_u_rev(u, nn_params, eq_params, x, t=None):
      r"""
-     Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(x)` for x of arbitrary
-     dimension. Note that :math:`\mathbf{u}` is a vector field from :math:`\mathbb{R}^n`
-     to :math:`\mathbb{R}^n`
-     Currently for `x.ndim=2`
-
-     **Note:** We do not use loops but code explicitly the expression to avoid
+     Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
+     :math:`\mathbf{x}` of arbitrary
+     dimension. :math:`\mathbf{u}` is a vector field from :math:`\mathbb{R}^n`
+     to :math:`\mathbb{R}^n`. **Currently for** `x.ndim=2` **only**.
+     The computation is done using backward AD.
+     We do not use loops but code explicitly the expression to avoid
      computing twice some terms
      """
      if x.shape[0] == 2:
@@ -127,17 +236,64 @@ def _u_dot_nabla_times_u(u, nn_params, eq_params, x, t=None):
          raise NotImplementedError("x.ndim must be 2")
 
 
+ def _u_dot_nabla_times_u_fwd(u, nn_params, eq_params, x, t=None):
+     r"""
+     Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
+     :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
+     I.e., :math:`\mathbf{u}` is a vector field from :math:`\mathbb{R}^{b\times
+     b}`
+     to :math:`\mathbb{R}^{b\times b \times d}`. **Currently for** :math:`d=2`
+     **only**.
+     We do not use loops but code explicitly the expression to avoid
+     computing twice some terms.
+     Because of the embedding that happens in SPINNs the
+     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
+     This function is to be used in the context of SPINNs only.
+     """
+     if x.shape[-1] == 2:
+         tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
+         tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+         if t is None:
+             u_at_x, du_dx = jax.jvp(
+                 lambda x: u(x, nn_params, eq_params), (x,), (tangent_vec_0,)
+             )  # thanks to forward AD this gets dux_dx and duy_dx in a vector
+             # ie the derivatives of both components of u wrt x
+             # this also gets the vector of u evaluated at x
+             u_at_x, du_dy = jax.jvp(
+                 lambda x: u(x, nn_params, eq_params), (x,), (tangent_vec_1,)
+             )  # thanks to forward AD this gets dux_dy and duy_dy in a vector
+             # ie the derivatives of both components of u wrt y
+
+         else:
+             u_at_x, du_dx = jax.jvp(
+                 lambda x: u(t, x, nn_params, eq_params), (x,), (tangent_vec_0,)
+             )
+             u_at_x, du_dy = jax.jvp(
+                 lambda x: u(t, x, nn_params, eq_params), (x,), (tangent_vec_1,)
+             )
+
+         return jnp.stack(
+             [
+                 u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
+                 u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
+             ],
+             axis=-1,
+         )
+     else:
+         raise NotImplementedError("x.ndim must be 2")
+
+
  def _sobolev(u, m, statio=True):
      r"""
-     Compute the Sobolev regularization of order m
-     of a scalar field u (from :math:`\mathbb{R}^d1` to :math:`\mathbb{R}`)
-     for x of arbitrary dimension i.e.
+     Compute the Sobolev regularization of order :math:`m`
+     of a scalar field :math:`u` (from :math:`\mathbb{R}^{d}` to :math:`\mathbb{R}`)
+     for :math:`\mathbf{x}` of arbitrary dimension :math:`d`, i.e.,
      :math:`\frac{1}{n_l}\sum_{l=1}^{n_l}\sum_{|\alpha|=1}^{m+1} ||\partial^{\alpha} u(x_l)||_2^2` where
-     :math:`m\geq\max(d_1 // 2, K)` with `K` the order of the differential
+     :math:`m\geq\max(d_1 // 2, K)` with :math:`K` the order of the differential
      operator.
 
-     This regularization is proposed in _Convergence and error analysis of
-     PINNs_, Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
+     This regularization is proposed in *Convergence and error analysis of
+     PINNs*, Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
      """
 
      def jac_recursive(u, order, start):
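`_u_dot_nabla_times_u_fwd` obtains the two Jacobian columns `du/dx` and `du/dy` with only two `jax.jvp` calls and then writes out the convection term explicitly. A standalone 2D sketch of that trick (not jinns code), checked against the Jacobian-vector-product definition of (u · ∇)u:

```python
import jax
import jax.numpy as jnp

# Toy 2D velocity field: u(x) = (x0 * x1, x0 + x1 ** 2)
u_toy = lambda x: jnp.array([x[0] * x[1], x[0] + x[1] ** 2])

x = jnp.array([1.0, 2.0])
e0, e1 = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])

# Two forward passes give u(x) and the two Jacobian columns du/dx0 and du/dx1.
u_at_x, du_dx0 = jax.jvp(u_toy, (x,), (e0,))
_, du_dx1 = jax.jvp(u_toy, (x,), (e1,))
conv_fwd = u_at_x[0] * du_dx0 + u_at_x[1] * du_dx1  # (u . grad) u, written out

# Reference: (u . grad) u is the Jacobian of u applied to u itself.
conv_ref = jax.jacobian(u_toy)(x) @ u_toy(x)
print(jnp.allclose(conv_fwd, conv_ref))  # True
```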
jinns/utils/__init__.py CHANGED
@@ -3,5 +3,6 @@ from ._utils import (
      euler_maruyama_density,
      log_euler_maruyama_density,
      alternate_optax_solver,
-     create_PINN,
  )
+ from ._pinn import create_PINN
+ from ._spinn import create_SPINN
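With this change `create_PINN` moves to a dedicated `_pinn` module and a new `create_SPINN` factory is added in `_spinn`, both re-exported from `jinns.utils`. Assuming the re-exports behave as written above, user code can keep a flat import path:

```python
# Hypothetical user-side import after this release; only the import path is
# illustrated, the create_* factory signatures are not part of this diff.
from jinns.utils import create_PINN, create_SPINN
```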