jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

@@ -13,17 +13,18 @@ from jax import grad
  import jax.numpy as jnp
  import equinox as eqx

- from jinns.utils._pinn import PINN
- from jinns.utils._spinn import SPINN
+ from jinns.nn._pinn import PINN
+ from jinns.nn._spinn_mlp import SPINN

- from jinns.utils._utils import _get_grid
+ from jinns.utils._utils import get_grid
  from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
  from jinns.loss._operators import (
-     _laplacian_rev,
-     _laplacian_fwd,
-     _div_rev,
-     _div_fwd,
-     _vectorial_laplacian,
+     laplacian_rev,
+     laplacian_fwd,
+     divergence_rev,
+     divergence_fwd,
+     vectorial_laplacian_rev,
+     vectorial_laplacian_fwd,
      _u_dot_nabla_times_u_rev,
      _u_dot_nabla_times_u_fwd,
  )
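
The loss operators lose their private underscore prefix, and the divergence and vectorial-Laplacian helpers now come in explicit reverse-mode (`_rev`) and forward-mode (`_fwd`) variants. As a point of reference, here is a minimal sketch, under the assumption (not taken from jinns' source) that the reverse-mode Laplacian of a scalar network on a concatenated `(t, x)` input is the trace of the space block of the Hessian:

```python
import jax
import jax.numpy as jnp

def laplacian_rev_sketch(t_x, u_scalar):
    # u_scalar: R^{1+dim} -> R, evaluated at the concatenated (t, x) point.
    # The spatial Laplacian is the trace of the space-space block of the
    # Hessian, i.e. the time row/column (index 0) is excluded.
    hess = jax.hessian(u_scalar)(t_x)  # shape (1 + dim, 1 + dim)
    return jnp.trace(hess[1:, 1:])

# Toy check: u(t, x) = x1**2 + x2**2 has Laplacian 4 everywhere.
u = lambda t_x: t_x[1] ** 2 + t_x[2] ** 2
print(laplacian_rev_sketch(jnp.array([0.0, 1.0, 2.0]), u))  # 4.0
```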
@@ -42,24 +43,29 @@ class FisherKPP(PDENonStatio):
      $$
      \frac{\partial}{\partial t} u(t,x)=D\Delta u(t,x) + u(t,x)(r(x) - \gamma(x)u(t,x))
      $$
+
+     Parameters
+     ----------
+     dim_x : int, default=1
+         The dimension of x, the space domain.
      """

+     dim_x: int = eqx.field(default=1, static=True)
+
      def equation(
          self,
-         t: Float[Array, "1"],
-         x: Float[Array, "dim"],
+         t_x: Float[Array, "1+dim"],
          u: eqx.Module,
          params: Params,
      ) -> Float[Array, "1"]:
          r"""
-         Evaluate the dynamic loss at $(t,x)$.
+         Evaluate the dynamic loss at $(t, x)$.

          Parameters
          ----------
-         t
-             A time point.
-         x
-             A point in $\Omega$.
+         t_x
+             A jnp array containing the concatenation of a time point
+             and a point in $\Omega$.
          u
              The PINN
          params
@@ -70,28 +76,31 @@ class FisherKPP(PDENonStatio):
          """
          if isinstance(u, PINN):
              # Note that the last dim of u is nec. 1
-             u_ = lambda t, x: u(t, x, params)[0]
+             u_ = lambda t_x: u(t_x, params)[0]

-             du_dt = grad(u_, 0)(t, x)
+             du_dt = grad(u_)(t_x)[0]

-             lap = _laplacian_rev(t, x, u, params)[..., None]
+             lap = laplacian_rev(t_x, u, params, eq_type=u.eq_type)[..., None]

              return du_dt + self.Tmax * (
                  -params.eq_params["D"] * lap
-                 - u(t, x, params)
-                 * (params.eq_params["r"] - params.eq_params["g"] * u(t, x, params))
+                 - u(t_x, params)
+                 * (params.eq_params["r"] - params.eq_params["g"] * u(t_x, params))
              )
          if isinstance(u, SPINN):
+             s = jnp.zeros((1, self.dim_x + 1))
+             s = s.at[0].set(1.0)
+             v0 = jnp.repeat(s, t_x.shape[0], axis=0)
              u_tx, du_dt = jax.jvp(
-                 lambda t: u(t, x, params),
-                 (t,),
-                 (jnp.ones_like(t),),
+                 lambda t_x: u(t_x, params),
+                 (t_x,),
+                 (v0,),
              )
-             lap = _laplacian_fwd(t, x, u, params)[..., None]
+             lap = laplacian_fwd(t_x, u, params, eq_type=u.eq_type)
+
              return du_dt + self.Tmax * (
                  -params.eq_params["D"] * lap
-                 - u_tx
-                 * (params.eq_params["r"][..., None] - params.eq_params["g"] * u_tx)
+                 - u_tx * (params.eq_params["r"] - params.eq_params["g"] * u_tx)
              )
          raise ValueError("u is not among the recognized types (PINN or SPINN)")
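
The SPINN branch above now builds its time derivative with a single batched JVP: a tangent that is non-zero only in the time column makes forward mode propagate exactly du/dt for every row of the batch. A self-contained sketch of the trick (the network below is a hypothetical stand-in, and the tangent is written as an explicit one-hot on the time column):

```python
import jax
import jax.numpy as jnp

dim_x = 1
t_x = jnp.array([[0.1, 0.5], [0.2, 0.7]])  # each row is a concatenated (t, x)

u = lambda t_x: jnp.sin(t_x[:, 0:1]) * t_x[:, 1:2]  # stand-in for u(t_x, params)

# One-hot tangent on the time column, repeated over the batch.
s = jnp.zeros((1, dim_x + 1)).at[0, 0].set(1.0)
v0 = jnp.repeat(s, t_x.shape[0], axis=0)

# One forward pass yields both the value and du/dt for the whole batch.
u_tx, du_dt = jax.jvp(u, (t_x,), (v0,))
print(du_dt)  # cos(t) * x for each row
```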
 
@@ -181,9 +190,9 @@ class GeneralizedLotkaVolterra(ODE):
      )


- class BurgerEquation(PDENonStatio):
+ class BurgersEquation(PDENonStatio):
      r"""
-     Return the Burger dynamic loss term (in 1 space dimension):
+     Return the Burgers dynamic loss term (in 1 space dimension):

      $$
      \frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
@@ -207,8 +216,7 @@ class BurgerEquation(PDENonStatio):

      def equation(
          self,
-         t: Float[Array, "1"],
-         x: Float[Array, "dim"],
+         t_x: Float[Array, "1+dim"],
          u: eqx.Module,
          params: Params,
      ) -> Float[Array, "1"]:
@@ -216,44 +224,51 @@ class BurgerEquation(PDENonStatio):
          Evaluate the dynamic loss at :math:`(t,x)`.

          Parameters
-         ---------
-         t
-             A time point
-         x
-             A point in $\Omega$
+         ----------
+         t_x
+             A jnp array containing the concatenation of a time point
+             and a point in $\Omega$.
          u
              The PINN
          params
              The dictionary of parameters of the model.
          """
          if isinstance(u, PINN):
-             # Note that the last dim of u is nec. 1
-             u_ = lambda t, x: jnp.squeeze(u(t, x, params)[u.slice_solution])
-             du_dt = grad(u_, 0)
-             du_dx = grad(u_, 1)
-             d2u_dx2 = grad(
-                 lambda t, x: du_dx(t, x)[0],
-                 1,
-             )
-
-             return du_dt(t, x) + self.Tmax * (
-                 u(t, x, params) * du_dx(t, x) - params.eq_params["nu"] * d2u_dx2(t, x)
+             u_ = lambda t_x: jnp.squeeze(u(t_x, params)[u.slice_solution])
+             du_dtx = grad(u_)
+             d2u_dx_dtx = grad(lambda t_x: du_dtx(t_x)[1])
+             du_dtx_values = du_dtx(t_x)
+
+             return du_dtx_values[0:1] + self.Tmax * (
+                 u_(t_x) * du_dtx_values[1:2]
+                 - params.eq_params["nu"] * d2u_dx_dtx(t_x)[1:2]
              )

          if isinstance(u, SPINN):
              # d=2 JVP calls are expected since we have time and x
              # then with a batch of size B, we have Bd JVP calls
+             v0 = jnp.repeat(jnp.array([[1.0, 0.0]]), t_x.shape[0], axis=0)
+             v1 = jnp.repeat(jnp.array([[0.0, 1.0]]), t_x.shape[0], axis=0)
              u_tx, du_dt = jax.jvp(
-                 lambda t: u(t, x, params),
-                 (t,),
-                 (jnp.ones_like(t),),
+                 lambda t_x: u(t_x, params),
+                 (t_x,),
+                 (v0,),
+             )
+             _, du_dx = jax.jvp(
+                 lambda t_x: u(t_x, params),
+                 (t_x,),
+                 (v1,),
              )
-             du_dx_fun = lambda x: jax.jvp(
-                 lambda x: u(t, x, params),
-                 (x,),
-                 (jnp.ones_like(x),),
+             # both calls above could be condensed into a single jacfwd call:
+             # u_ = lambda t_x: u(t_x, params)
+             # J = jax.jacfwd(u_)(t_x)
+
+             du_dx_fun = lambda t_x: jax.jvp(
+                 lambda t_x: u(t_x, params),
+                 (t_x,),
+                 (v1,),
              )[1]
-             du_dx, d2u_dx2 = jax.jvp(du_dx_fun, (x,), (jnp.ones_like(x),))
+             _, d2u_dx2 = jax.jvp(du_dx_fun, (t_x,), (v1,))
              # Note that ones_like(x) works because x is Bx1 !
              return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params["nu"] * d2u_dx2)
          raise ValueError("u is not among the recognized types (PINN or SPINN)")
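
With the concatenated `t_x` input, one `grad` call now returns every first derivative at once: index 0 of the gradient is du/dt and index 1 is du/dx, and the second space derivative comes from differentiating that slice again. A toy check of this indexing, with a hypothetical stand-in for the network:

```python
import jax.numpy as jnp
from jax import grad

u_ = lambda t_x: jnp.sin(t_x[0]) * jnp.cos(t_x[1])  # stand-in: u(t, x) = sin(t)cos(x)

t_x = jnp.array([0.3, 1.2])

du_dtx = grad(u_)                          # gradient wrt the whole (t, x) vector
d2u_dx_dtx = grad(lambda p: du_dtx(p)[1])  # differentiate du/dx once more

du = du_dtx(t_x)
du_dt, du_dx = du[0], du[1]                # slice out each first derivative
d2u_dx2 = d2u_dx_dtx(t_x)[1]               # second derivative in space

# Burgers residual at this point: du/dt + u * du/dx - nu * d2u/dx2
nu = 0.01
residual = du_dt + u_(t_x) * du_dx - nu * d2u_dx2
```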
@@ -297,8 +312,7 @@ class FPENonStatioLoss2D(PDENonStatio):

      def equation(
          self,
-         t: Float[Array, "1"],
-         x: Float[Array, "dim"],
+         t_x: Float[Array, "1+dim"],
          u: eqx.Module,
          params: Params,
      ) -> Float[Array, "1"]:
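
For orientation, the residual assembled in this method encodes the standard 2-D Fokker-Planck equation, with drift $a$ and diffusion tensor $D = \frac{1}{2}\sigma\sigma^\top$; the `grad_order_1` and `grad_grad_order_2` terms below correspond to the two sums on the right-hand side:

$$
\frac{\partial}{\partial t} u(t,x) = -\sum_i \frac{\partial}{\partial x_i}\left[a_i(x)u(t,x)\right] + \sum_{i,j} \frac{\partial^2}{\partial x_i \partial x_j}\left[D_{ij}(x)u(t,x)\right]
$$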
@@ -307,10 +321,8 @@ class FPENonStatioLoss2D(PDENonStatio):

          Parameters
          ----------
-         t
-             A time point
-         x
-             A point in $\Omega$
+         t_x
+             A collocation point in $I\times\Omega$
          u
              The PINN
          params
@@ -321,114 +333,87 @@ class FPENonStatioLoss2D(PDENonStatio):
          """
          if isinstance(u, PINN):
              # Note that the last dim of u is nec. 1
-             u_ = lambda t, x: u(t, x, params)[0]
-
-             order_1 = (
-                 grad(
-                     lambda t, x: self.drift(t, x, params.eq_params)[0] * u_(t, x),
-                     1,
-                 )(
-                     t, x
-                 )[0:1]
-                 + grad(
-                     lambda t, x: self.drift(t, x, params.eq_params)[1] * u_(t, x),
-                     1,
-                 )(t, x)[1:2]
-             )
+             u_ = lambda t_x: u(t_x, params)[0]
+
+             order_1_fun = lambda t_x: self.drift(t_x[1:], params.eq_params) * u_(t_x)
+             grad_order_1 = jnp.trace(jax.jacrev(order_1_fun)(t_x)[..., 1:])[None]

-             order_2 = (
-                 grad(
-                     lambda t, x: grad(
-                         lambda t, x: u_(t, x)
-                         * self.diffusion(t, x, params.eq_params)[0, 0],
-                         1,
-                     )(t, x)[0],
-                     1,
-                 )(t, x)[0:1]
-                 + grad(
-                     lambda t, x: grad(
-                         lambda t, x: u_(t, x)
-                         * self.diffusion(t, x, params.eq_params)[1, 0],
-                         1,
-                     )(t, x)[1],
-                     1,
-                 )(t, x)[0:1]
-                 + grad(
-                     lambda t, x: grad(
-                         lambda t, x: u_(t, x)
-                         * self.diffusion(t, x, params.eq_params)[0, 1],
-                         1,
-                     )(t, x)[0],
-                     1,
-                 )(t, x)[1:2]
-                 + grad(
-                     lambda t, x: grad(
-                         lambda t, x: u_(t, x)
-                         * self.diffusion(t, x, params.eq_params)[1, 1],
-                         1,
-                     )(t, x)[1],
-                     1,
-                 )(t, x)[1:2]
+             order_2_fun = lambda t_x: self.diffusion(t_x[1:], params.eq_params) * u_(
+                 t_x
+             )
+             grad_order_2_fun = lambda t_x: jax.jacrev(order_2_fun)(t_x)[..., 1:]
+             grad_grad_order_2 = (
+                 jnp.trace(
+                     jax.jacrev(lambda t_x: grad_order_2_fun(t_x)[0, :, 0])(t_x)[..., 1:]
+                 )[None]
+                 + jnp.trace(
+                     jax.jacrev(lambda t_x: grad_order_2_fun(t_x)[1, :, 1])(t_x)[..., 1:]
+                 )[None]
              )
+             # A more condensed form of the explicit version above, but less
+             # efficient since 4 jacrev calls are made (as compared to 2):
+             # grad_order_2_fun = lambda t_x, i, j: jax.jacrev(order_2_fun)(t_x)[i, j, 1:]
+             # grad_grad_order_2 = (
+             #     jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 0, 0))(t_x)[0, 1] +
+             #     jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 1, 0))(t_x)[1, 1] +
+             #     jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 0, 1))(t_x)[0, 2] +
+             #     jax.jacrev(lambda t_x: grad_order_2_fun(t_x, 1, 1))(t_x)[1, 2]
+             # )[None]

-             du_dt = grad(u_, 0)(t, x)
+             du_dt = grad(u_)(t_x)[0:1]

-             return -du_dt + self.Tmax * (-order_1 + order_2)
+             return -du_dt + self.Tmax * (-grad_order_1 + grad_grad_order_2)

          if isinstance(u, SPINN):
-             x_grid = _get_grid(x)
+             v0 = jnp.repeat(jnp.array([[1.0, 0.0, 0.0]]), t_x.shape[0], axis=0)
              _, du_dt = jax.jvp(
-                 lambda t: u(t, x, params),
-                 (t,),
-                 (jnp.ones_like(t),),
+                 lambda t_x: u(t_x, params),
+                 (t_x,),
+                 (v0,),
              )

              # in forward AD we do not have the results for all the input
              # dimensions at once (as is the case with grad), so we write
              # two jvp calls
-             tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
-             tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+             v1 = jnp.repeat(jnp.array([[0.0, 1.0, 0.0]]), t_x.shape[0], axis=0)
+             v2 = jnp.repeat(jnp.array([[0.0, 0.0, 1.0]]), t_x.shape[0], axis=0)
              _, dau_dx1 = jax.jvp(
-                 lambda x: self.drift(t, _get_grid(x), params.eq_params)[None, ..., 0:1]
-                 * u(t, x, params)[..., 0:1],
-                 (x,),
-                 (tangent_vec_0,),
+                 lambda t_x: self.drift(get_grid(t_x[:, 1:]), params.eq_params)[
+                     None, ..., 0:1
+                 ]
+                 * u(t_x, params),
+                 (t_x,),
+                 (v1,),
              )
              _, dau_dx2 = jax.jvp(
-                 lambda x: self.drift(t, _get_grid(x), params.eq_params)[None, ..., 1:2]
-                 * u(t, x, params)[..., 0:1],
-                 (x,),
-                 (tangent_vec_1,),
+                 lambda t_x: self.drift(get_grid(t_x[:, 1:]), params.eq_params)[
+                     None, ..., 1:2
+                 ]
+                 * u(t_x, params),
+                 (t_x,),
+                 (v2,),
              )

-             dsu_dx1_fun = lambda x, i, j: jax.jvp(
-                 lambda x: self.diffusion(t, _get_grid(x), params.eq_params, i, j)[
-                     None, None, None, None
-                 ]
-                 * u(t, x, params)[..., 0:1],
-                 (x,),
-                 (tangent_vec_0,),
+             dsu_dx1_fun = lambda t_x, i, j: jax.jvp(
+                 lambda t_x: self.diffusion(
+                     get_grid(t_x[:, 1:]), params.eq_params, i, j
+                 )[None, None, None, None]
+                 * u(t_x, params),
+                 (t_x,),
+                 (v1,),
              )[1]
-             dsu_dx2_fun = lambda x, i, j: jax.jvp(
-                 lambda x: self.diffusion(t, _get_grid(x), params.eq_params, i, j)[
-                     None, None, None, None
-                 ]
-                 * u(t, x, params)[..., 0:1],
-                 (x,),
-                 (tangent_vec_1,),
+             dsu_dx2_fun = lambda t_x, i, j: jax.jvp(
+                 lambda t_x: self.diffusion(
+                     get_grid(t_x[:, 1:]), params.eq_params, i, j
+                 )[None, None, None, None]
+                 * u(t_x, params),
+                 (t_x,),
+                 (v2,),
              )[1]
-             _, d2su_dx12 = jax.jvp(
-                 lambda x: dsu_dx1_fun(x, 0, 0), (x,), (tangent_vec_0,)
-             )
-             _, d2su_dx1dx2 = jax.jvp(
-                 lambda x: dsu_dx1_fun(x, 0, 1), (x,), (tangent_vec_1,)
-             )
-             _, d2su_dx22 = jax.jvp(
-                 lambda x: dsu_dx2_fun(x, 1, 1), (x,), (tangent_vec_1,)
-             )
-             _, d2su_dx2dx1 = jax.jvp(
-                 lambda x: dsu_dx2_fun(x, 1, 0), (x,), (tangent_vec_0,)
-             )
+             _, d2su_dx12 = jax.jvp(lambda t_x: dsu_dx1_fun(t_x, 0, 0), (t_x,), (v1,))
+             _, d2su_dx1dx2 = jax.jvp(lambda t_x: dsu_dx1_fun(t_x, 0, 1), (t_x,), (v2,))
+             _, d2su_dx22 = jax.jvp(lambda t_x: dsu_dx2_fun(t_x, 1, 1), (t_x,), (v2,))
+             _, d2su_dx2dx1 = jax.jvp(lambda t_x: dsu_dx2_fun(t_x, 1, 0), (t_x,), (v1,))

              return -du_dt + self.Tmax * (
                  -(dau_dx1 + dau_dx2)
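
The rewrite replaces the four hand-written `grad` compositions of the old PINN branch with Jacobian traces: for a vector field on the concatenated `(t, x)` input, the divergence with respect to space is the trace of the Jacobian once its time column (index 0) is dropped. A minimal sketch of that identity (not jinns' code):

```python
import jax
import jax.numpy as jnp

def divergence_space(f, t_x):
    # f: R^{1+d} -> R^d. The Jacobian has shape (d, 1+d); dropping column 0
    # (the time coordinate) leaves the (d, d) space block, whose trace is
    # sum_i d f_i / d x_i.
    jac = jax.jacrev(f)(t_x)
    return jnp.trace(jac[:, 1:])

# Toy check: f(t, x) = (x1, x2) has divergence 2.
f = lambda t_x: t_x[1:]
print(divergence_space(f, jnp.array([0.0, 1.0, -1.0])))  # 2.0
```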
@@ -474,14 +459,12 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
      heterogeneity for no parameters.
      """

-     def drift(self, t, x, eq_params):
+     def drift(self, x, eq_params):
          r"""
          Return the drift term

          Parameters
          ----------
-         t
-             A time point
          x
              A point in $\Omega$
          eq_params
@@ -489,15 +472,13 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
          """
          return eq_params["alpha"] * (eq_params["mu"] - x)

-     def sigma_mat(self, t, x, eq_params):
+     def sigma_mat(self, x, eq_params):
          r"""
          Return the square root of the diffusion tensor in the sense of the outer
          product used to create the diffusion tensor

          Parameters
          ----------
-         t
-             A time point
          x
              A point in $\Omega$
          eq_params
@@ -506,15 +487,13 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):

          return jnp.diag(eq_params["sigma"])

-     def diffusion(self, t, x, eq_params, i=None, j=None):
+     def diffusion(self, x, eq_params, i=None, j=None):
          r"""
          Return the computation of the diffusion tensor term in 2D (or
          higher)

          Parameters
          ----------
-         t
-             A time point
          x
              A point in $\Omega$
          eq_params
@@ -523,14 +502,14 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
          if i is None or j is None:
              return 0.5 * (
                  jnp.matmul(
-                     self.sigma_mat(t, x, eq_params),
-                     jnp.transpose(self.sigma_mat(t, x, eq_params)),
+                     self.sigma_mat(x, eq_params),
+                     jnp.transpose(self.sigma_mat(x, eq_params)),
                  )
              )
          return 0.5 * (
              jnp.matmul(
-                 self.sigma_mat(t, x, eq_params),
-                 jnp.transpose(self.sigma_mat(t, x, eq_params)),
+                 self.sigma_mat(x, eq_params),
+                 jnp.transpose(self.sigma_mat(x, eq_params)),
              )[i, j]
          )
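
The time argument disappears from the whole drift/diffusion API since the Ornstein-Uhlenbeck coefficients depend on space only. The tensor assembled here is $D = \frac{1}{2}\sigma\sigma^\top$; a short illustration with a diagonal sigma, as returned by `sigma_mat` above:

```python
import jax.numpy as jnp

sigma = jnp.diag(jnp.array([0.2, 0.4]))  # sigma_mat(x, eq_params), diagonal noise
D = 0.5 * sigma @ sigma.T                # diffusion(x, eq_params): full tensor
D_01 = D[0, 1]                           # diffusion(x, eq_params, i=0, j=1): one entry
```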
 
@@ -591,12 +570,12 @@ class MassConservation2DStatio(PDEStatio):
          if isinstance(u_dict[self.nn_key], PINN):
              u = u_dict[self.nn_key]

-             return _div_rev(None, x, u, params)[..., None]
+             return divergence_rev(x, u, params)[..., None]

          if isinstance(u_dict[self.nn_key], SPINN):
              u = u_dict[self.nn_key]

-             return _div_fwd(None, x, u, params)[..., None]
+             return divergence_fwd(x, u, params)[..., None]
          raise ValueError("u is not among the recognized types (PINN or SPINN)")


@@ -614,12 +593,14 @@ class NavierStokes2DStatio(PDEStatio):


      $$
-     \begin{pmatrix}u_x\frac{\partial}{\partial x} u_x + u_y\frac{\partial}{\partial y} u_x \\
+     \begin{pmatrix}u_x\frac{\partial}{\partial x} u_x +
+     u_y\frac{\partial}{\partial y} u_x, \\
      u_x\frac{\partial}{\partial x} u_y + u_y\frac{\partial}{\partial y} u_y \end{pmatrix} +
-     \frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p \\ \frac{\partial}{\partial y} p \end{pmatrix}
+     \frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p, \\ \frac{\partial}{\partial y} p \end{pmatrix}
      - \theta
      \begin{pmatrix}
-     \frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2} u_x \\
+     \frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2}
+     u_x, \\
      \frac{\partial^2}{\partial x^2} u_y + \frac{\partial^2}{\partial y^2} u_y
      \end{pmatrix} = 0,
      $$
@@ -680,12 +661,12 @@ class NavierStokes2DStatio(PDEStatio):
          if isinstance(u_dict[self.u_key], PINN):
              u = u_dict[self.u_key]

-             u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(None, x, u, u_params)
+             u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(x, u, u_params)

              p = lambda x: u_dict[self.p_key](x, p_params)
              jac_p = jax.jacrev(p, 0)(x)  # compute the gradient

-             vec_laplacian_u = _vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2)
+             vec_laplacian_u = vectorial_laplacian_rev(x, u, u_params, dim_out=2)

              # dynamic loss on x axis
              result_x = (
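
Both branches need the pressure gradient: the PINN branch takes it in reverse mode with `jacrev`, while the SPINN branch below rebuilds it in forward mode from one JVP per input axis, using canonical basis tangents repeated over the batch. A side-by-side sketch with a hypothetical pressure field:

```python
import jax
import jax.numpy as jnp

p = lambda x: x[..., 0:1] ** 2 + x[..., 1:2]  # stand-in pressure network

x = jnp.array([[0.5, 1.0], [1.5, -2.0]])  # batch of 2-D points

# Reverse mode, one point at a time (PINN-style).
grad_p = jax.jacrev(lambda xi: p(xi)[0])(x[0])  # (2,) gradient at x[0]

# Forward mode over the whole batch (SPINN-style): one JVP per input axis.
tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
_, dp_dx = jax.jvp(p, (x,), (tangent_vec_0,))  # (B, 1) batch of dp/dx
_, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))  # (B, 1) batch of dp/dy
```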
@@ -707,7 +688,7 @@ class NavierStokes2DStatio(PDEStatio):
          if isinstance(u_dict[self.u_key], SPINN):
              u = u_dict[self.u_key]

-             u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(None, x, u, u_params)
+             u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(x, u, u_params)

              p = lambda x: u_dict[self.p_key](x, p_params)

@@ -716,11 +697,7 @@ class NavierStokes2DStatio(PDEStatio):
              tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
              _, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))

-             vec_laplacian_u = jnp.moveaxis(
-                 _vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2),
-                 source=0,
-                 destination=-1,
-             )
+             vec_laplacian_u = vectorial_laplacian_fwd(x, u, u_params, dim_out=2)

              # dynamic loss on x axis
              result_x = (