jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_operators.py CHANGED
@@ -2,45 +2,86 @@
  Implements diverse operators for dynamic losses
  """

+ from typing import Literal
+
  import jax
  import jax.numpy as jnp
  from jax import grad
  import equinox as eqx
  from jaxtyping import Float, Array
- from jinns.utils._pinn import PINN
- from jinns.utils._spinn import SPINN
  from jinns.parameters._params import Params


- def _div_rev(
-     t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+ def divergence_rev(
+     inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+     u: eqx.Module,
+     params: Params,
+     eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
  ) -> float:
      r"""
      Compute the divergence of a vector field $\mathbf{u}$, i.e.,
-     $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
-     field from $\mathbb{R}^d$ to $\mathbb{R}^d$.
+     $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
+     field from $\mathbb{R}^d$ to $\mathbb{R}^d$ or from $\mathbb{R}^{1+d}$
+     to $\mathbb{R}^{1+d}$. Thus, this
+     function can be used for stationary or non-stationary PINNs. In the first
+     case $\mathrm{inputs}=\mathbf{x}$, in the second case
+     $\mathrm{inputs}=\mathbf{t,x}$.
      The computation is done using backward AD
+
+     Parameters
+     ----------
+     inputs
+         `x` or `t_x`
+     u
+         the PINN
+     params
+         the PINN parameters
+     eq_type
+         whether we consider a stationary or non-stationary PINN. Most often we
+         can know that by inspecting the `u` argument (PINN object). But if `u` is
+         a function, we must set this attribute.
      """

+     try:
+         eq_type = u.eq_type
+     except AttributeError:
+         pass  # use the value passed as argument
+     if eq_type is None:
+         raise ValueError("eq_type could not be set!")
+
      def scan_fun(_, i):
-         if t is None:
-             du_dxi = grad(lambda x, params: u(x, params)[i], 0)(x, params)[i]
+         if eq_type == "nonstatio_PDE":
+             du_dxi = grad(lambda inputs, params: u(inputs, params)[1 + i])(
+                 inputs, params
+             )[1 + i]
          else:
-             du_dxi = grad(lambda t, x, params: u(t, x, params)[i], 1)(t, x, params)[i]
+             du_dxi = grad(lambda inputs, params: u(inputs, params)[i])(inputs, params)[
+                 i
+             ]
          return _, du_dxi

-     _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
+     if eq_type == "nonstatio_PDE":
+         _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0] - 1))
+     elif eq_type == "statio_PDE":
+         _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+     else:
+         raise ValueError("Unexpected u.eq_type!")
      return jnp.sum(accu)


- def _div_fwd(
-     t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
- ) -> float:
+ def divergence_fwd(
+     inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
+     u: eqx.Module,
+     params: Params,
+     eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+ ) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
      r"""
      Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
-     $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
+     $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
      field from $\mathbb{R}^{b \times d}$ to $\mathbb{R}^{b \times b
-     \times d}$. The result is then in $\mathbb{R}^{b\times b}$.
+     \times d}$ or from $\mathbb{R}^{b \times (d+1)}$ to $\mathbb{R}^{b \times b
+     \times (d+1)}$. Thus, this
+     function can be used for stationary or non-stationary PINNs.
      Because of the embedding that happens in SPINNs the
      computation is most efficient with forward AD. This is the idea behind
      Separable PINNs.
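A minimal usage sketch of the new `divergence_rev` signature shown above (not taken from the package documentation; the toy field and values are made up): a plain callable has no `eq_type` attribute, so `eq_type` must be passed explicitly.

```python
import jax.numpy as jnp
from jinns.loss._operators import divergence_rev  # module shown in this diff

# toy stationary vector field u(x) = a * x, whose divergence is a * d
u = lambda inputs, params: params * inputs
x = jnp.array([0.5, -1.2, 2.0])
print(divergence_rev(x, u, 3.0, eq_type="statio_PDE"))  # expected ~9.0
```

For a non-stationary PINN, `inputs` is `t_x` (time prepended) and `eq_type="nonstatio_PDE"` makes the divergence run over the spatial components only.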
@@ -48,79 +89,180 @@ def _div_fwd(
      !!! warning "Warning"

          This function is to be used in the context of SPINNs only.
+
+     Parameters
+     ----------
+     inputs
+         `x` or `t_x`
+     u
+         the PINN
+     params
+         the PINN parameters
+     eq_type
+         whether we consider a stationary or non-stationary PINN. Most often we
+         can know that by inspecting the `u` argument (PINN object). But if `u` is
+         a function, we must set this attribute.
      """

+     try:
+         eq_type = u.eq_type
+     except AttributeError:
+         pass  # use the value passed as argument
+     if eq_type is None:
+         raise ValueError("eq_type could not be set!")
+
      def scan_fun(_, i):
-         tangent_vec = jnp.repeat(
-             jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
-         )
-         if t is None:
-             __, du_dxi = jax.jvp(lambda x: u(x, params)[..., i], (x,), (tangent_vec,))
+         if eq_type == "nonstatio_PDE":
+             tangent_vec = jnp.repeat(
+                 jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                 inputs.shape[0],
+                 axis=0,
+             )
+             __, du_dxi = jax.jvp(
+                 lambda inputs: u(inputs, params)[..., 1 + i], (inputs,), (tangent_vec,)
+             )
          else:
+             tangent_vec = jnp.repeat(
+                 jax.nn.one_hot(i, inputs.shape[-1])[None],
+                 inputs.shape[0],
+                 axis=0,
+             )
              __, du_dxi = jax.jvp(
-                 lambda x: u(t, x, params)[..., i], (x,), (tangent_vec,)
+                 lambda inputs: u(inputs, params)[..., i], (inputs,), (tangent_vec,)
              )
          return _, du_dxi

-     _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
+     if eq_type == "nonstatio_PDE":
+         _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1] - 1))
+     elif eq_type == "statio_PDE":
+         _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1]))
+     else:
+         raise ValueError("Unexpected u.eq_type!")
      return jnp.sum(accu, axis=0)


- def _laplacian_rev(
-     t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+ def laplacian_rev(
+     inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+     u: eqx.Module,
+     params: Params,
+     method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
+     eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
  ) -> float:
      r"""
-     Compute the Laplacian of a scalar field $u$ (from $\mathbb{R}^d$
-     to $\mathbb{R}$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
-     $\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})$.
+     Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
+     to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
+     function can be used for stationary or non-stationary PINNs. In the first
+     case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+     $\Delta_\mathbf{x} u(\mathbf{x})=\nabla_\mathbf{x}\cdot\nabla_\mathbf{x} u(\mathbf{x})$.
+     In the second case $\mathrm{inputs}=\mathbf{t,x}$, but we still compute
+     $\Delta_\mathbf{x} u(\mathrm{inputs})$.
      The computation is done using backward AD.
+
+     Parameters
+     ----------
+     inputs
+         `x` or `t_x`
+     u
+         the PINN
+     params
+         the PINN parameters
+     method
+         how to compute the Laplacian. `"trace_hessian_x"` means that we take
+         the trace of the Hessian matrix computed with `x` only (`t` is excluded
+         from the beginning, we compute fewer derivatives at the price of a
+         concatenation). `"trace_hessian_t_x"` means that the computation
+         of the Hessian integrates `t` which is excluded at the end (we avoid a
+         concatenate but we compute more derivatives). `"loop"` means that we
+         directly compute the second order derivatives with a loop (we avoid
+         non-diagonal derivatives at the cost of a loop).
+     eq_type
+         whether we consider a stationary or non-stationary PINN. Most often we
+         can know that by inspecting the `u` argument (PINN object). But if `u` is
+         a function, we must set this attribute.
      """

-     # Note that the last dim of u is nec. 1
-     if t is None:
-         u_ = lambda x: u(x, params)[0]
-     else:
-         u_ = lambda t, x: u(t, x, params)[0]
-
-     if t is None:
-         return jnp.trace(jax.hessian(u_)(x))
-     return jnp.trace(jax.hessian(u_, argnums=1)(t, x))
-
-     # For a small d, we found out that trace of the Hessian is faster, but the
-     # trick below for taking directly the diagonal elements might prove useful
-     # in higher dimensions?
-
-     # def scan_fun(_, i):
-     #     if t is None:
-     #         d2u_dxi2 = grad(
-     #             lambda x: grad(u_, 0)(x)[i],
-     #             0,
-     #         )(
-     #             x
-     #         )[i]
-     #     else:
-     #         d2u_dxi2 = grad(
-     #             lambda t, x: grad(u_, 1)(t, x)[i],
-     #             1,
-     #         )(
-     #             t, x
-     #         )[i]
-     #     return _, d2u_dxi2
-
-     # _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
-     # return jnp.sum(trace_hessian)
-
-
- def _laplacian_fwd(
-     t: Float[Array, "batch_size 1"],
-     x: Float[Array, "batch_size dimension"],
+     try:
+         eq_type = u.eq_type
+     except AttributeError:
+         pass  # use the value passed as argument
+     if eq_type is None:
+         raise ValueError("eq_type could not be set!")
+
+     if method == "trace_hessian_x":
+         # NOTE we afford a concatenate here to avoid computing Hessian elements for
+         # nothing. In case of simple derivatives we prefer the vectorial
+         # computation and then discarding elements but for higher order derivatives
+         # it might not be worth it. See other options below for computing the
+         # Laplacian
+         if eq_type == "nonstatio_PDE":
+             u_ = lambda x: jnp.squeeze(
+                 u(jnp.concatenate([inputs[:1], x], axis=0), params)
+             )
+             return jnp.sum(jnp.diag(jax.hessian(u_)(inputs[1:])))
+         if eq_type == "statio_PDE":
+             u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+             return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+         raise ValueError("Unexpected eq_type!")
+     if method == "trace_hessian_t_x":
+         # NOTE that it is unclear whether it is better to vectorially compute the
+         # Hessian (despite a useless time dimension) as below
+         if eq_type == "nonstatio_PDE":
+             u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+             return jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:])
+         if eq_type == "statio_PDE":
+             u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+             return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+         raise ValueError("Unexpected eq_type!")
+
+     if method == "loop":
+         # For a small d, we found out that trace of the Hessian is faster, see
+         # https://stackoverflow.com/questions/77517357/jax-grad-derivate-with-respect-an-specific-variable-in-a-matrix
+         # but could the trick below for taking directly the diagonal elements
+         # prove useful in higher dimensions?
+
+         u_ = lambda inputs: u(inputs, params).squeeze()
+
+         def scan_fun(_, i):
+             if eq_type == "nonstatio_PDE":
+                 d2u_dxi2 = grad(
+                     lambda inputs: grad(u_)(inputs)[1 + i],
+                 )(
+                     inputs
+                 )[1 + i]
+             else:
+                 d2u_dxi2 = grad(
+                     lambda inputs: grad(u_, 0)(inputs)[i],
+                     0,
+                 )(
+                     inputs
+                 )[i]
+             return _, d2u_dxi2

+         if eq_type == "nonstatio_PDE":
+             _, trace_hessian = jax.lax.scan(
+                 scan_fun, {}, jnp.arange(inputs.shape[0] - 1)
+             )
+         elif eq_type == "statio_PDE":
+             _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+         else:
+             raise ValueError("Unexpected eq_type!")
+         return jnp.sum(trace_hessian)
+     raise ValueError("Unexpected method argument!")
+
+
+ def laplacian_fwd(
+     inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
      u: eqx.Module,
      params: Params,
- ) -> Float[Array, "batch_size batch_size"]:
+     method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
+     eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+ ) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
      r"""
      Compute the Laplacian of a **batched** scalar field $u$
-     (from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$)
-     for $\mathbf{x}$ of arbitrary dimension $d$ with batch
+     from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
+     from $\mathbb{R}^{b\times (1 + d)}$ to $\mathbb{R}^{b\times b}$, i.e., this
+     function can be used for stationary or non-stationary PINNs,
+     for $\mathbf{x}$ of arbitrary dimension $d$ or $1+d$ with batch
      dimension $b$.
      Because of the embedding that happens in SPINNs the
      computation is most efficient with forward AD. This is the idea behind
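A hedged sketch of the new `laplacian_rev` options shown above (the toy scalar field and values are made up; the import path is simply the module in this diff): all three `method` choices compute the same quantity, and in the non-stationary case the Laplacian is still taken over `x` only.

```python
import jax.numpy as jnp
from jinns.loss._operators import laplacian_rev  # module shown in this diff

# toy scalar field u(x) = sum(x**2) with a (1,)-shaped output like a PINN;
# its Laplacian is 2 * d
u = lambda inputs, params: jnp.sum(inputs**2, keepdims=True)
x = jnp.array([0.3, -0.7, 1.1])

for method in ("trace_hessian_x", "trace_hessian_t_x", "loop"):
    print(method, laplacian_rev(x, u, None, method=method, eq_type="statio_PDE"))
# all three should print ~6.0

# non-stationary case: time is prepended but the Laplacian stays over x only
t_x = jnp.concatenate([jnp.array([0.5]), x])
print(laplacian_rev(t_x, u, None, eq_type="nonstatio_PDE"))  # also ~6.0
```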
@@ -129,90 +271,228 @@ def _laplacian_fwd(
      !!! warning "Warning"

          This function is to be used in the context of SPINNs only.
+
+     !!! warning "Warning"
+
+         Because of the batch dimension, the current implementation of
+         `method="trace_hessian_t_x"` or `method="trace_hessian_x"`
+         should not be used except for debugging
+         purposes. Indeed, computing the Hessian is very costly.
+
+     Parameters
+     ----------
+     inputs
+         `x` or `t_x`
+     u
+         the PINN
+     params
+         the PINN parameters
+     method
+         how to compute the Laplacian. `"trace_hessian_t_x"` means that the computation
+         of the Hessian integrates `t` which is excluded at the end (**see
+         Warning**). `"trace_hessian_x"` means a Hessian computation which
+         excludes `t` (**see Warning**). `"loop"` means that we
+         directly compute the second order derivatives with a loop (we avoid
+         non-diagonal derivatives at the cost of a loop).
+     eq_type
+         whether we consider a stationary or non-stationary PINN. Most often we
+         can know that by inspecting the `u` argument (PINN object). But if `u` is
+         a function, we must set this attribute.
      """

-     def scan_fun(_, i):
-         tangent_vec = jnp.repeat(
-             jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
-         )
-
-         if t is None:
-             du_dxi_fun = lambda x: jax.jvp(
-                 lambda x: u(x, params)[..., 0], (x,), (tangent_vec,)
-             )[
-                 1
-             ]  # Note the indexing [..., 0]
-             __, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
+     try:
+         eq_type = u.eq_type
+     except AttributeError:
+         pass  # use the value passed as argument
+     if eq_type is None:
+         raise ValueError("eq_type could not be set!")
+
+     if method == "loop":
+
+         def scan_fun(_, i):
+             if eq_type == "nonstatio_PDE":
+                 tangent_vec = jnp.repeat(
+                     jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                     inputs.shape[0],
+                     axis=0,
+                 )
+             else:
+                 tangent_vec = jnp.repeat(
+                     jax.nn.one_hot(i, inputs.shape[-1])[None],
+                     inputs.shape[0],
+                     axis=0,
+                 )
+
+             du_dxi_fun = lambda inputs: jax.jvp(
+                 lambda inputs: u(inputs, params),
+                 (inputs,),
+                 (tangent_vec,),
+             )[1]
+             __, d2u_dxi2 = jax.jvp(du_dxi_fun, (inputs,), (tangent_vec,))
+             return _, d2u_dxi2
+
+         if eq_type == "nonstatio_PDE":
+             _, trace_hessian = jax.lax.scan(
+                 scan_fun, {}, jnp.arange(inputs.shape[-1] - 1)
+             )
+         elif eq_type == "statio_PDE":
+             _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[-1]))
          else:
-             du_dxi_fun = lambda x: jax.jvp(
-                 lambda x: u(t, x, params)[..., 0], (x,), (tangent_vec,)
-             )[
-                 1
-             ]  # Note the indexing [..., 0]
-             __, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
-         return _, d2u_dxi2
+             raise ValueError("Unexpected eq_type!")
+         return jnp.sum(trace_hessian, axis=0)
+     if method == "trace_hessian_t_x":
+         if eq_type == "nonstatio_PDE":
+             # compute the Hessian including the batch dimension, get rid of the
+             # (..,1,..) axis that is here because of the scalar output
+             # if inputs.shape==(10,3) (1 for time, 2 for x_dim)
+             # then r.shape=(10,10,10,1,10,3,10,3)
+             # there are way too many derivatives!
+             r = jax.hessian(u)(inputs, params).squeeze()
+             # compute the traces while avoiding the time derivatives
+             # after that r.shape=(10,10,10,10)
+             r = jnp.trace(r[..., :, 1:, :, 1:], axis1=-3, axis2=-1)
+             # but then we are in a cartesian product, for each coordinate on
+             # the first two dimensions we only want the trace at the same
+             # coordinate on the last two dimensions
+             # this is done easily with einsum but we need to automate the
+             # formula according to the input dim
+             res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+             lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+             return lap[..., None]
+         if eq_type == "statio_PDE":
+             # compute the Hessian including the batch dimension, get rid of the
+             # (..,1,..) axis that is here because of the scalar output
+             # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+             # there are way too many derivatives!
+             r = jax.hessian(u)(inputs, params).squeeze()
+             # compute the traces, after that r.shape=(10,10,10,10)
+             r = jnp.trace(r, axis1=-3, axis2=-1)
+             # but then we are in a cartesian product, for each coordinate on
+             # the first two dimensions we only want the trace at the same
+             # coordinate on the last two dimensions
+             # this is done easily with einsum but we need to automate the
+             # formula according to the input dim
+             res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+             lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+             return lap[..., None]
+         raise ValueError("Unexpected eq_type!")
+     if method == "trace_hessian_x":
+         if eq_type == "statio_PDE":
+             # compute the Hessian including the batch dimension, get rid of the
+             # (..,1,..) axis that is here because of the scalar output
+             # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+             # there are way too many derivatives!
+             r = jax.hessian(u)(inputs, params).squeeze()
+             # compute the traces, after that r.shape=(10,10,10,10)
+             r = jnp.trace(r, axis1=-3, axis2=-1)
+             # but then we are in a cartesian product, for each coordinate on
+             # the first two dimensions we only want the trace at the same
+             # coordinate on the last two dimensions
+             # this is done easily with einsum but we need to automate the
+             # formula according to the input dim
+             res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+             lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+             return lap[..., None]
+         raise ValueError("Unexpected eq_type!")
+     raise ValueError("Unexpected method argument!")
+
+
+ def vectorial_laplacian_rev(
+     inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+     u: eqx.Module,
+     params: Params,
+     dim_out: int = None,
+ ) -> Float[Array, "dim_out"]:
+     r"""
+     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
+     $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
+     $\mathbb{R}^n$, i.e., this
+     function can be used for stationary or non-stationary PINNs. In the first
+     case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+     $\Delta_\mathbf{x} \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
+     \mathbf{u}(\mathbf{x})$.
+     In the second case $\mathrm{inputs}=\mathbf{t,x}$, and we perform
+     $\Delta_\mathbf{x} \mathbf{u}(\mathrm{inputs})=\nabla\cdot\nabla
+     \mathbf{u}(\mathrm{inputs})$.
+
+     Parameters
+     ----------
+     inputs
+         `x` or `t_x`
+     u
+         the PINN
+     params
+         the PINN parameters
+     dim_out
+         Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
+         provided if it is different than that of $\mathrm{inputs}$.
+     """
+     if dim_out is None:
+         dim_out = inputs.shape[0]
+
+     def scan_fun(_, j):
+         # The loop over the components of u(x). We compute one Laplacian for
+         # each of these components
+         # Note the jnp.expand_dims call
+         uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
+         lap_on_j = laplacian_rev(inputs, uj, params, eq_type=u.eq_type)

-     _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
-     return jnp.sum(trace_hessian, axis=0)
+         return _, lap_on_j

+     _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+     return vec_lap

- def _vectorial_laplacian(
-     t: Float[Array, "1"] | Float[Array, "batch_size 1"],
-     x: Float[Array, "dimension_in"] | Float[Array, "batch_size dimension"],
+
+ def vectorial_laplacian_fwd(
+     inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
      u: eqx.Module,
      params: Params,
-     u_vec_ndim: int = None,
- ) -> (
-     Float[Array, "dimension_out"] | Float[Array, "batch_size batch_size dimension_out"]
- ):
+     dim_out: int = None,
+ ) -> Float[Array, "batch_size * (1+dim) n"] | Float[Array, "batch_size * (dim) n"]:
      r"""
-     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ (from
-     $\mathbb{R}^d$
-     to $\mathbb{R}^n$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
-     $\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
-     \mathbf{u}(\mathbf{x})$.
+     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
+     `u` is a SPINN; in this case, it corresponds to a vector
+     field from $\mathbb{R}^{b\times d}$ to
+     $\mathbb{R}^{b\times b\times n}$ or from $\mathbb{R}^{b\times (1+d)}$ to
+     $\mathbb{R}^{b\times b\times n}$, i.e., this
+     function can be used for stationary or non-stationary PINNs.
+
+     Forward mode AD is used.
+
+     !!! warning "Warning"

-     **Note:** We need to provide `u_vec_ndim` the dimension of the vector
-     $\mathbf{u}(\mathbf{x})$ if it is different than that of
-     $\mathbf{x}$.
+         This function is to be used in the context of SPINNs only.

-     **Note:** `u` can be a SPINN, in this case, it corresponds to a vector
-     field from (from $\mathbb{R}^{b\times d}$ to
-     $\mathbb{R}^{b\times b\times n}$) and forward mode AD is used.
-     Technically, the return is of dimension $n\times b \times b$.
+     Parameters
+     ----------
+     inputs
+         `x` or `t_x`
+     u
+         the PINN
+     params
+         the PINN parameters
+     dim_out
+         the value of the output dimension ($n$ in the formula above). Must be
+         set if different from $d$.
      """
-     if u_vec_ndim is None:
-         u_vec_ndim = x.shape[0]
+     if dim_out is None:
+         dim_out = inputs.shape[0]

      def scan_fun(_, j):
          # The loop over the components of u(x). We compute one Laplacian for
          # each of these components
          # Note the expand_dims
-         if isinstance(u, PINN):
-             if t is None:
-                 uj = lambda x, params: jnp.expand_dims(u(x, params)[j], axis=-1)
-             else:
-                 uj = lambda t, x, params: jnp.expand_dims(u(t, x, params)[j], axis=-1)
-             lap_on_j = _laplacian_rev(t, x, uj, params)
-         elif isinstance(u, SPINN):
-             if t is None:
-                 uj = lambda x, params: jnp.expand_dims(u(x, params)[..., j], axis=-1)
-             else:
-                 uj = lambda t, x, params: jnp.expand_dims(
-                     u(t, x, params)[..., j], axis=-1
-                 )
-             lap_on_j = _laplacian_fwd(t, x, uj, params)
-         else:
-             raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
+         uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
+         lap_on_j = laplacian_fwd(inputs, uj, params, eq_type=u.eq_type)

          return _, lap_on_j

-     _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(u_vec_ndim))
-     return vec_lap
+     _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+     return jnp.moveaxis(vec_lap.squeeze(), 0, -1)


  def _u_dot_nabla_times_u_rev(
-     t: Float[Array, "1"], x: Float[Array, "2"], u: eqx.Module, params: Params
+     x: Float[Array, "2"], u: eqx.Module, params: Params
  ) -> Float[Array, "2"]:
      r"""
      Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
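Since `vectorial_laplacian_rev` shown above reads `u.eq_type` directly, the callable passed as `u` needs that attribute. A sketch with a hypothetical `ToyField` equinox module standing in for a PINN (made up for illustration, not part of jinns):

```python
import equinox as eqx
import jax.numpy as jnp
from jinns.loss._operators import vectorial_laplacian_rev  # module shown in this diff

class ToyField(eqx.Module):
    # hypothetical stand-in for a PINN: callable as u(inputs, params) and
    # exposing eq_type, which vectorial_laplacian_rev reads directly
    eq_type: str = "statio_PDE"

    def __call__(self, inputs, params):
        return params * inputs**2  # u_j(x) = a * x_j**2, so each Laplacian is 2 * a

x = jnp.array([1.0, 2.0, 3.0])
print(vectorial_laplacian_rev(x, ToyField(), 4.0))  # expected ~[8. 8. 8.]
```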
@@ -223,43 +503,25 @@ def _u_dot_nabla_times_u_rev(
      We do not use loops but code explicitly the expression to avoid
      computing twice some terms
      """
-     if x.shape[0] == 2:
-         if t is None:
-             ux = lambda x: u(x, params)[0]
-             uy = lambda x: u(x, params)[1]
-
-             dux_dx = lambda x: grad(ux, 0)(x)[0]
-             dux_dy = lambda x: grad(ux, 0)(x)[1]
-
-             duy_dx = lambda x: grad(uy, 0)(x)[0]
-             duy_dy = lambda x: grad(uy, 0)(x)[1]
-
-             return jnp.array(
-                 [
-                     ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
-                     ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
-                 ]
-             )
-         ux = lambda t, x: u(t, x, params)[0]
-         uy = lambda t, x: u(t, x, params)[1]
+     assert x.shape[0] == 2
+     ux = lambda x: u(x, params)[0]
+     uy = lambda x: u(x, params)[1]

-         dux_dx = lambda t, x: grad(ux, 1)(t, x)[0]
-         dux_dy = lambda t, x: grad(ux, 1)(t, x)[1]
+     dux_dx = lambda x: grad(ux, 0)(x)[0]
+     dux_dy = lambda x: grad(ux, 0)(x)[1]

-         duy_dx = lambda t, x: grad(uy, 1)(t, x)[0]
-         duy_dy = lambda t, x: grad(uy, 1)(t, x)[1]
+     duy_dx = lambda x: grad(uy, 0)(x)[0]
+     duy_dy = lambda x: grad(uy, 0)(x)[1]

-         return jnp.array(
-             [
-                 ux(t, x) * dux_dx(t, x) + uy(t, x) * dux_dy(t, x),
-                 ux(t, x) * duy_dx(t, x) + uy(t, x) * duy_dy(t, x),
-             ]
-         )
-     raise NotImplementedError("x.ndim must be 2")
+     return jnp.array(
+         [
+             ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
+             ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
+         ]
+     )


  def _u_dot_nabla_times_u_fwd(
-     t: Float[Array, "batch_size 1"],
      x: Float[Array, "batch_size 2"],
      u: eqx.Module,
      params: Params,
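A quick numerical check of the simplified `_u_dot_nabla_times_u_rev` signature above (toy field made up for illustration): for the rotation field $\mathbf{u}(x, y) = (y, -x)$ one has $((\mathbf{u}\cdot\nabla)\mathbf{u})(x, y) = (-x, -y)$.

```python
import jax.numpy as jnp
from jinns.loss._operators import _u_dot_nabla_times_u_rev  # module shown in this diff

u = lambda x, params: jnp.array([x[1], -x[0]])  # rotation field (y, -x)
x = jnp.array([0.4, 1.5])
print(_u_dot_nabla_times_u_rev(x, u, None))  # expected ~[-0.4 -1.5]
```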
@@ -277,29 +539,22 @@ def _u_dot_nabla_times_u_fwd(
      computation is most efficient with forward AD. This is the idea behind Separable PINNs.
      This function is to be used in the context of SPINNs only.
      """
-     if x.shape[-1] == 2:
-         tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
-         tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
-         if t is None:
-             u_at_x, du_dx = jax.jvp(
-                 lambda x: u(x, params), (x,), (tangent_vec_0,)
-             )  # thanks to forward AD this gets dux_dx and duy_dx in a vector
-             # ie the derivatives of both components of u wrt x
-             # this also gets the vector of u evaluated at x
-             u_at_x, du_dy = jax.jvp(
-                 lambda x: u(x, params), (x,), (tangent_vec_1,)
-             )  # thanks to forward AD this gets dux_dy and duy_dy in a vector
-             # ie the derivatives of both components of u wrt y
-
-         else:
-             u_at_x, du_dx = jax.jvp(lambda x: u(t, x, params), (x,), (tangent_vec_0,))
-             u_at_x, du_dy = jax.jvp(lambda x: u(t, x, params), (x,), (tangent_vec_1,))
-
-         return jnp.stack(
-             [
-                 u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
-                 u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
-             ],
-             axis=-1,
-         )
-     raise NotImplementedError("x.ndim must be 2")
+     assert x.shape[-1] == 2
+     tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
+     tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+     u_at_x, du_dx = jax.jvp(
+         lambda x: u(x, params), (x,), (tangent_vec_0,)
+     )  # thanks to forward AD this gets dux_dx and duy_dx in a vector
+     # ie the derivatives of both components of u wrt x
+     # this also gets the vector of u evaluated at x
+     u_at_x, du_dy = jax.jvp(
+         lambda x: u(x, params), (x,), (tangent_vec_1,)
+     )  # thanks to forward AD this gets dux_dy and duy_dy in a vector
+     # ie the derivatives of both components of u wrt y
+     return jnp.stack(
+         [
+             u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
+             u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
+         ],
+         axis=-1,
+     )
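The same check for the batched `_u_dot_nabla_times_u_fwd` above, as a sketch only: the row-wise toy field below is not an actual SPINN (the function is meant for SPINNs); it is only used to verify the forward-AD formula on a `(batch, 2)` input.

```python
import jax.numpy as jnp
from jinns.loss._operators import _u_dot_nabla_times_u_fwd  # module shown in this diff

# row-wise rotation field; row k of the result should be (-x[k, 0], -x[k, 1])
u = lambda x, params: jnp.stack([x[:, 1], -x[:, 0]], axis=-1)
x = jnp.arange(8.0).reshape(4, 2)
print(_u_dot_nabla_times_u_fwd(x, u, None))
```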