jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_operators.py CHANGED
@@ -2,6 +2,8 @@
 Implements diverse operators for dynamic losses
 """
 
+from typing import Literal
+
 import jax
 import jax.numpy as jnp
 from jax import grad
@@ -12,35 +14,76 @@ from jinns.utils._spinn import SPINN
 from jinns.parameters._params import Params
 
 
-def _div_rev(
-    t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+def divergence_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
 ) -> float:
     r"""
     Compute the divergence of a vector field $\mathbf{u}$, i.e.,
-    $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
-    field from $\mathbb{R}^d$ to $\mathbb{R}^d$.
+    $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
+    field from $\mathbb{R}^d$ to $\mathbb{R}^d$ or $\mathbb{R}^{1+d}$
+    to $\mathbb{R}^{1+d}$. Thus, this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$, in the second case
+    $\mathrm{inputs}=\mathbf{t,x}$.
     The computation is done using backward AD
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    eq_type
+        whether we consider a stationary or non-stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
     def scan_fun(_, i):
-        if t is None:
-            du_dxi = grad(lambda x, params: u(x, params)[i], 0)(x, params)[i]
+        if eq_type == "nonstatio_PDE":
+            du_dxi = grad(lambda inputs, params: u(inputs, params)[1 + i])(
+                inputs, params
+            )[1 + i]
         else:
-            du_dxi = grad(lambda t, x, params: u(t, x, params)[i], 1)(t, x, params)[i]
+            du_dxi = grad(lambda inputs, params: u(inputs, params)[i])(inputs, params)[
+                i
+            ]
         return _, du_dxi
 
-    _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
+    if eq_type == "nonstatio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0] - 1))
+    elif eq_type == "statio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+    else:
+        raise ValueError("Unexpected u.eq_type!")
     return jnp.sum(accu)
 
 
-def _div_fwd(
-    t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
-) -> float:
+def divergence_fwd(
+    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
+    u: eqx.Module,
+    params: Params,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
     r"""
     Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
-    $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
+    $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
     field from $\mathbb{R}^{b \times d}$ to $\mathbb{R}^{b \times b
-    \times d}$. The result is then in $\mathbb{R}^{b\times b}$.
+    \times d}$ or from $\mathbb{R}^{b \times d+1}$ to $\mathbb{R}^{b \times b
+    \times d+1}$. Thus, this
+    function can be used for stationary or non-stationary PINNs.
     Because of the embedding that happens in SPINNs the
     computation is most efficient with forward AD. This is the idea behind
     Separable PINNs.
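
For readers tracking the API change, here is a minimal usage sketch of the new single-`inputs` signature (a hypothetical example, not part of the package diff; it assumes jinns 1.2.0 is installed and that `divergence_rev` is importable from `jinns.loss._operators` as defined above, with a plain callable standing in for a PINN, which is why `eq_type` must be passed explicitly):

import jax.numpy as jnp

from jinns.loss._operators import divergence_rev

# Plain callable standing in for a PINN: the vector field u(x) = x**2 (componentwise).
# It carries no `eq_type` attribute, hence the explicit eq_type argument.
u = lambda inputs, params: inputs**2

x = jnp.array([1.0, 2.0, 3.0])
# Stationary case: divergence = sum_i 2 * x_i = 12
print(divergence_rev(x, u, params=None, eq_type="statio_PDE"))

# Non-stationary case: inputs = (t, x); the time coordinate is skipped,
# so only the spatial components contribute: 2 * (2 + 3) = 10
t_x = jnp.array([0.5, 2.0, 3.0])
print(divergence_rev(t_x, u, params=None, eq_type="nonstatio_PDE"))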
@@ -48,79 +91,180 @@ def _div_fwd(
     !!! warning "Warning"
 
         This function is to be used in the context of SPINNs only.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    eq_type
+        whether we consider a stationary or non-stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
     def scan_fun(_, i):
-        tangent_vec = jnp.repeat(
-            jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
-        )
-        if t is None:
-            __, du_dxi = jax.jvp(lambda x: u(x, params)[..., i], (x,), (tangent_vec,))
+        if eq_type == "nonstatio_PDE":
+            tangent_vec = jnp.repeat(
+                jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                inputs.shape[0],
+                axis=0,
+            )
+            __, du_dxi = jax.jvp(
+                lambda inputs: u(inputs, params)[..., 1 + i], (inputs,), (tangent_vec,)
+            )
         else:
+            tangent_vec = jnp.repeat(
+                jax.nn.one_hot(i, inputs.shape[-1])[None],
+                inputs.shape[0],
+                axis=0,
+            )
             __, du_dxi = jax.jvp(
-                lambda x: u(t, x, params)[..., i], (x,), (tangent_vec,)
+                lambda inputs: u(inputs, params)[..., i], (inputs,), (tangent_vec,)
             )
         return _, du_dxi
 
-    _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
+    if eq_type == "nonstatio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1] - 1))
+    elif eq_type == "statio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1]))
+    else:
+        raise ValueError("Unexpected u.eq_type!")
     return jnp.sum(accu, axis=0)
 
 
-def _laplacian_rev(
-    t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+def laplacian_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
 ) -> float:
     r"""
-    Compute the Laplacian of a scalar field $u$ (from $\mathbb{R}^d$
-    to $\mathbb{R}$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
-    $\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})$.
+    Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
+    to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+    $\Delta_\mathbf{x} u(\mathbf{x})=\nabla_\mathbf{x}\cdot\nabla_\mathbf{x} u(\mathbf{x})$.
+    In the second case $\mathrm{inputs}=\mathbf{t,x}$, but we still compute
+    $\Delta_\mathbf{x} u(\mathrm{inputs})$.
     The computation is done using backward AD.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    method
+        how to compute the Laplacian. `"trace_hessian_x"` means that we take
+        the trace of the Hessian matrix computed with `x` only (`t` is excluded
+        from the beginning, we compute fewer derivatives at the price of a
+        concatenation). `"trace_hessian_t_x"` means that the computation
+        of the Hessian integrates `t` which is excluded at the end (we avoid a
+        concatenate but we compute more derivatives). `"loop"` means that we
+        directly compute the second order derivatives with a loop (we avoid
+        non-diagonal derivatives at the cost of a loop).
+    eq_type
+        whether we consider a stationary or non-stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
-    # Note that the last dim of u is nec. 1
-    if t is None:
-        u_ = lambda x: u(x, params)[0]
-    else:
-        u_ = lambda t, x: u(t, x, params)[0]
-
-    if t is None:
-        return jnp.trace(jax.hessian(u_)(x))
-    return jnp.trace(jax.hessian(u_, argnums=1)(t, x))
-
-    # For a small d, we found out that trace of the Hessian is faster, but the
-    # trick below for taking directly the diagonal elements might prove useful
-    # in higher dimensions?
-
-    # def scan_fun(_, i):
-    #     if t is None:
-    #         d2u_dxi2 = grad(
-    #             lambda x: grad(u_, 0)(x)[i],
-    #             0,
-    #         )(
-    #             x
-    #         )[i]
-    #     else:
-    #         d2u_dxi2 = grad(
-    #             lambda t, x: grad(u_, 1)(t, x)[i],
-    #             1,
-    #         )(
-    #             t, x
-    #         )[i]
-    #     return _, d2u_dxi2
-
-    # _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
-    # return jnp.sum(trace_hessian)
-
-
-def _laplacian_fwd(
-    t: Float[Array, "batch_size 1"],
-    x: Float[Array, "batch_size dimension"],
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
+    if method == "trace_hessian_x":
+        # NOTE we afford a concatenate here to avoid computing Hessian elements for
+        # nothing. In case of simple derivatives we prefer the vectorial
+        # computation and then discarding elements but for higher order derivatives
+        # it might not be worth it. See other options below for computing the
+        # Laplacian
+        if eq_type == "nonstatio_PDE":
+            u_ = lambda x: jnp.squeeze(
+                u(jnp.concatenate([inputs[:1], x], axis=0), params)
+            )
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs[1:])))
+        if eq_type == "statio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+        raise ValueError("Unexpected eq_type!")
+    if method == "trace_hessian_t_x":
+        # NOTE that it is unclear whether it is better to vectorially compute the
+        # Hessian (despite a useless time dimension) as below
+        if eq_type == "nonstatio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:])
+        if eq_type == "statio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+        raise ValueError("Unexpected eq_type!")
+
+    if method == "loop":
+        # For a small d, we found out that trace of the Hessian is faster, see
+        # https://stackoverflow.com/questions/77517357/jax-grad-derivate-with-respect-an-specific-variable-in-a-matrix
+        # but could the trick below for taking directly the diagonal elements
+        # prove useful in higher dimensions?
+
+        u_ = lambda inputs: u(inputs, params).squeeze()
+
+        def scan_fun(_, i):
+            if eq_type == "nonstatio_PDE":
+                d2u_dxi2 = grad(
+                    lambda inputs: grad(u_)(inputs)[1 + i],
+                )(
+                    inputs
+                )[1 + i]
+            else:
+                d2u_dxi2 = grad(
+                    lambda inputs: grad(u_, 0)(inputs)[i],
+                    0,
+                )(
+                    inputs
+                )[i]
+            return _, d2u_dxi2
+
+        if eq_type == "nonstatio_PDE":
+            _, trace_hessian = jax.lax.scan(
+                scan_fun, {}, jnp.arange(inputs.shape[0] - 1)
+            )
+        elif eq_type == "statio_PDE":
+            _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+        else:
+            raise ValueError("Unexpected eq_type!")
+        return jnp.sum(trace_hessian)
+    raise ValueError("Unexpected method argument!")
+
+
+def laplacian_fwd(
+    inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
     u: eqx.Module,
     params: Params,
-) -> Float[Array, "batch_size batch_size"]:
+    method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
     r"""
     Compute the Laplacian of a **batched** scalar field $u$
-    (from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$)
-    for $\mathbf{x}$ of arbitrary dimension $d$ with batch
+    from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
+    from $\mathbb{R}^{b\times (1 + d)}$ to $\mathbb{R}^{b\times b}$, i.e., this
+    function can be used for stationary or non-stationary PINNs
+    for $\mathbf{x}$ of arbitrary dimension $d$ or $1+d$ with batch
     dimension $b$.
     Because of the embedding that happens in SPINNs the
     computation is most efficient with forward AD. This is the idea behind
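
The `method` and `eq_type` arguments introduced for `laplacian_rev` can be exercised as follows (a hypothetical sketch, not part of the package diff; again a plain callable stands in for the PINN, so `eq_type` is passed explicitly):

import jax.numpy as jnp

from jinns.loss._operators import laplacian_rev

# Scalar field u(inputs) = sum(inputs**2); its Laplacian over d coordinates is 2 * d.
u = lambda inputs, params: jnp.sum(inputs**2)

x = jnp.array([1.0, 2.0, 3.0])
# Stationary case, default method="trace_hessian_x": 2 * 3 = 6
print(laplacian_rev(x, u, params=None, eq_type="statio_PDE"))

# Non-stationary case: inputs = (t, x); the Laplacian is taken over x only,
# so all three methods should agree on 2 * 2 = 4
t_x = jnp.array([0.5, 1.0, 2.0])
for method in ("trace_hessian_x", "trace_hessian_t_x", "loop"):
    print(method, laplacian_rev(t_x, u, params=None, method=method, eq_type="nonstatio_PDE"))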
@@ -129,90 +273,228 @@ def _laplacian_fwd(
     !!! warning "Warning"
 
         This function is to be used in the context of SPINNs only.
+
+    !!! warning "Warning"
+
+        Because of the batch dimension, the current implementation of
+        `method="trace_hessian_t_x"` or `method="trace_hessian_x"`
+        should not be used except for debugging
+        purposes. Indeed, computing the Hessian is very costly.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    method
+        how to compute the Laplacian. `"trace_hessian_t_x"` means that the computation
+        of the Hessian integrates `t` which is excluded at the end (**see
+        Warning**). `"trace_hessian_x"` means a Hessian computation which
+        excludes `t` (**see Warning**). `"loop"` means that we
+        directly compute the second order derivatives with a loop (we avoid
+        non-diagonal derivatives at the cost of a loop).
+    eq_type
+        whether we consider a stationary or non-stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
-    def scan_fun(_, i):
-        tangent_vec = jnp.repeat(
-            jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0
-        )
-
-        if t is None:
-            du_dxi_fun = lambda x: jax.jvp(
-                lambda x: u(x, params)[..., 0], (x,), (tangent_vec,)
-            )[
-                1
-            ]  # Note the indexing [..., 0]
-            __, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
+    if method == "loop":
+
+        def scan_fun(_, i):
+            if eq_type == "nonstatio_PDE":
+                tangent_vec = jnp.repeat(
+                    jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                    inputs.shape[0],
+                    axis=0,
+                )
+            else:
+                tangent_vec = jnp.repeat(
+                    jax.nn.one_hot(i, inputs.shape[-1])[None],
+                    inputs.shape[0],
+                    axis=0,
+                )
+
+            du_dxi_fun = lambda inputs: jax.jvp(
+                lambda inputs: u(inputs, params),
+                (inputs,),
+                (tangent_vec,),
+            )[1]
+            __, d2u_dxi2 = jax.jvp(du_dxi_fun, (inputs,), (tangent_vec,))
+            return _, d2u_dxi2
+
+        if eq_type == "nonstatio_PDE":
+            _, trace_hessian = jax.lax.scan(
+                scan_fun, {}, jnp.arange(inputs.shape[-1] - 1)
+            )
+        elif eq_type == "statio_PDE":
+            _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[-1]))
         else:
-            du_dxi_fun = lambda x: jax.jvp(
-                lambda x: u(t, x, params)[..., 0], (x,), (tangent_vec,)
-            )[
-                1
-            ]  # Note the indexing [..., 0]
-            __, d2u_dxi2 = jax.jvp(du_dxi_fun, (x,), (tangent_vec,))
-            return _, d2u_dxi2
+            raise ValueError("Unexpected eq_type!")
+        return jnp.sum(trace_hessian, axis=0)
+    if method == "trace_hessian_t_x":
+        if eq_type == "nonstatio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,3) (1 for time, 2 for x_dim)
+            # then r.shape=(10,10,10,1,10,3,10,3)
+            # there are way too many derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces, avoiding the time derivatives
+            # after that r.shape=(10,10,10,10)
+            r = jnp.trace(r[..., :, 1:, :, 1:], axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        if eq_type == "statio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+            # there are way too many derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces, after that r.shape=(10,10,10,10)
+            r = jnp.trace(r, axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        raise ValueError("Unexpected eq_type!")
+    if method == "trace_hessian_x":
+        if eq_type == "statio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+            # there are way too many derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces, after that r.shape=(10,10,10,10)
+            r = jnp.trace(r, axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        raise ValueError("Unexpected eq_type!")
+    raise ValueError("Unexpected method argument!")
+
+
+def vectorial_laplacian_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    dim_out: int = None,
+) -> Float[Array, "dim_out"]:
+    r"""
+    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
+    $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
+    $\mathbb{R}^n$, i.e., this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+    $\Delta_\mathbf{x} \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
+    \mathbf{u}(\mathbf{x})$.
+    In the second case $\mathrm{inputs}=\mathbf{t,x}$, and we perform
+    $\Delta_\mathbf{x} \mathbf{u}(\mathrm{inputs})=\nabla\cdot\nabla
+    \mathbf{u}(\mathrm{inputs})$.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    dim_out
+        Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
+        provided if it is different than that of $\mathrm{inputs}$.
+    """
+    if dim_out is None:
+        dim_out = inputs.shape[0]
+
+    def scan_fun(_, j):
+        # The loop over the components of u(x). We compute one Laplacian for
+        # each of these components
+        # Note the jnp.expand_dims call
+        uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
+        lap_on_j = laplacian_rev(inputs, uj, params, eq_type=u.eq_type)
 
-    _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[1]))
-    return jnp.sum(trace_hessian, axis=0)
+        return _, lap_on_j
 
+    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+    return vec_lap
 
-def _vectorial_laplacian(
-    t: Float[Array, "1"] | Float[Array, "batch_size 1"],
-    x: Float[Array, "dimension_in"] | Float[Array, "batch_size dimension"],
+
+def vectorial_laplacian_fwd(
+    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
     u: eqx.Module,
     params: Params,
-    u_vec_ndim: int = None,
-) -> (
-    Float[Array, "dimension_out"] | Float[Array, "batch_size batch_size dimension_out"]
-):
+    dim_out: int = None,
+) -> Float[Array, "batch_size * (1+dim) n"] | Float[Array, "batch_size * (dim) n"]:
     r"""
-    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ (from
-    $\mathbb{R}^d$
-    to $\mathbb{R}^n$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
-    $\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
-    \mathbf{u}(\mathbf{x})$.
+    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
+    `u` is a SPINN, in this case, it corresponds to a vector
+    field from $\mathbb{R}^{b\times d}$ to
+    $\mathbb{R}^{b\times b\times n}$ or from $\mathbb{R}^{b\times 1+d}$ to
+    $\mathbb{R}^{b\times b\times n}$, i.e., this
+    function can be used for stationary or non-stationary PINNs.
+
+    Forward mode AD is used.
+
+    !!! warning "Warning"
 
-    **Note:** We need to provide `u_vec_ndim` the dimension of the vector
-    $\mathbf{u}(\mathbf{x})$ if it is different than that of
-    $\mathbf{x}$.
+        This function is to be used in the context of SPINNs only.
 
-    **Note:** `u` can be a SPINN, in this case, it corresponds to a vector
-    field from (from $\mathbb{R}^{b\times d}$ to
-    $\mathbb{R}^{b\times b\times n}$) and forward mode AD is used.
-    Technically, the return is of dimension $n\times b \times b$.
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    dim_out
+        the value of the output dimension ($n$ in the formula above). Must be
+        set if different from $d$.
     """
-    if u_vec_ndim is None:
-        u_vec_ndim = x.shape[0]
+    if dim_out is None:
+        dim_out = inputs.shape[0]
 
     def scan_fun(_, j):
         # The loop over the components of u(x). We compute one Laplacian for
         # each of these components
         # Note the expand_dims
-        if isinstance(u, PINN):
-            if t is None:
-                uj = lambda x, params: jnp.expand_dims(u(x, params)[j], axis=-1)
-            else:
-                uj = lambda t, x, params: jnp.expand_dims(u(t, x, params)[j], axis=-1)
-            lap_on_j = _laplacian_rev(t, x, uj, params)
-        elif isinstance(u, SPINN):
-            if t is None:
-                uj = lambda x, params: jnp.expand_dims(u(x, params)[..., j], axis=-1)
-            else:
-                uj = lambda t, x, params: jnp.expand_dims(
-                    u(t, x, params)[..., j], axis=-1
-                )
-            lap_on_j = _laplacian_fwd(t, x, uj, params)
-        else:
-            raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
+        uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
+        lap_on_j = laplacian_fwd(inputs, uj, params, eq_type=u.eq_type)
 
         return _, lap_on_j
 
-    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(u_vec_ndim))
-    return vec_lap
+    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+    return jnp.moveaxis(vec_lap.squeeze(), 0, -1)
 
 
 def _u_dot_nabla_times_u_rev(
-    t: Float[Array, "1"], x: Float[Array, "2"], u: eqx.Module, params: Params
+    x: Float[Array, "2"], u: eqx.Module, params: Params
 ) -> Float[Array, "2"]:
     r"""
     Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
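
Since `vectorial_laplacian_rev` reads `u.eq_type` directly, a quick check requires an object carrying that attribute (a hypothetical stand-in for a PINN, not a jinns class, and not part of the package diff):

import jax.numpy as jnp

from jinns.loss._operators import vectorial_laplacian_rev

class QuadraticField:
    # minimal stand-in for a PINN: exposes eq_type and is callable as u(inputs, params)
    eq_type = "statio_PDE"

    def __call__(self, inputs, params):
        return inputs**2  # u_j(x) = x_j**2, so each component has Laplacian 2

u = QuadraticField()
x = jnp.array([1.0, 2.0, 3.0])
print(vectorial_laplacian_rev(x, u, params=None))  # expected: [2., 2., 2.]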
@@ -223,43 +505,25 @@ def _u_dot_nabla_times_u_rev(
     We do not use loops but code explicitly the expression to avoid
     computing twice some terms
     """
-    if x.shape[0] == 2:
-        if t is None:
-            ux = lambda x: u(x, params)[0]
-            uy = lambda x: u(x, params)[1]
-
-            dux_dx = lambda x: grad(ux, 0)(x)[0]
-            dux_dy = lambda x: grad(ux, 0)(x)[1]
-
-            duy_dx = lambda x: grad(uy, 0)(x)[0]
-            duy_dy = lambda x: grad(uy, 0)(x)[1]
-
-            return jnp.array(
-                [
-                    ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
-                    ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
-                ]
-            )
-        ux = lambda t, x: u(t, x, params)[0]
-        uy = lambda t, x: u(t, x, params)[1]
+    assert x.shape[0] == 2
+    ux = lambda x: u(x, params)[0]
+    uy = lambda x: u(x, params)[1]
 
-        dux_dx = lambda t, x: grad(ux, 1)(t, x)[0]
-        dux_dy = lambda t, x: grad(ux, 1)(t, x)[1]
+    dux_dx = lambda x: grad(ux, 0)(x)[0]
+    dux_dy = lambda x: grad(ux, 0)(x)[1]
 
-        duy_dx = lambda t, x: grad(uy, 1)(t, x)[0]
-        duy_dy = lambda t, x: grad(uy, 1)(t, x)[1]
+    duy_dx = lambda x: grad(uy, 0)(x)[0]
+    duy_dy = lambda x: grad(uy, 0)(x)[1]
 
-        return jnp.array(
-            [
-                ux(t, x) * dux_dx(t, x) + uy(t, x) * dux_dy(t, x),
-                ux(t, x) * duy_dx(t, x) + uy(t, x) * duy_dy(t, x),
-            ]
-        )
-    raise NotImplementedError("x.ndim must be 2")
+    return jnp.array(
+        [
+            ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
+            ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
+        ]
+    )
 
 
 def _u_dot_nabla_times_u_fwd(
-    t: Float[Array, "batch_size 1"],
     x: Float[Array, "batch_size 2"],
     u: eqx.Module,
     params: Params,
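
A small worked example of the simplified 2-D convection term above (hypothetical, not part of the package diff; note that `_u_dot_nabla_times_u_rev` keeps its leading underscore and is therefore internal API):

import jax.numpy as jnp

from jinns.loss._operators import _u_dot_nabla_times_u_rev

# 2-D velocity field u(x) = (x0 * x1, x0 + x1)
u = lambda x, params: jnp.array([x[0] * x[1], x[0] + x[1]])

x = jnp.array([1.0, 2.0])
# (u . grad) u = (u0 * du0/dx0 + u1 * du0/dx1, u0 * du1/dx0 + u1 * du1/dx1)
#              = (2 * 2 + 3 * 1, 2 * 1 + 3 * 1) = (7, 5)
print(_u_dot_nabla_times_u_rev(x, u, params=None))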
@@ -277,29 +541,22 @@ def _u_dot_nabla_times_u_fwd(
     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
     This function is to be used in the context of SPINNs only.
     """
-    if x.shape[-1] == 2:
-        tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
-        tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
-        if t is None:
-            u_at_x, du_dx = jax.jvp(
-                lambda x: u(x, params), (x,), (tangent_vec_0,)
-            )  # thanks to forward AD this gets dux_dx and duy_dx in a vector
-            # ie the derivatives of both components of u wrt x
-            # this also gets the vector of u evaluated at x
-            u_at_x, du_dy = jax.jvp(
-                lambda x: u(x, params), (x,), (tangent_vec_1,)
-            )  # thanks to forward AD this gets dux_dy and duy_dy in a vector
-            # ie the derivatives of both components of u wrt y
-
-        else:
-            u_at_x, du_dx = jax.jvp(lambda x: u(t, x, params), (x,), (tangent_vec_0,))
-            u_at_x, du_dy = jax.jvp(lambda x: u(t, x, params), (x,), (tangent_vec_1,))
-
-        return jnp.stack(
-            [
-                u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
-                u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
-            ],
-            axis=-1,
-        )
-    raise NotImplementedError("x.ndim must be 2")
+    assert x.shape[-1] == 2
+    tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
+    tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+    u_at_x, du_dx = jax.jvp(
+        lambda x: u(x, params), (x,), (tangent_vec_0,)
+    )  # thanks to forward AD this gets dux_dx and duy_dx in a vector
+    # ie the derivatives of both components of u wrt x
+    # this also gets the vector of u evaluated at x
+    u_at_x, du_dy = jax.jvp(
+        lambda x: u(x, params), (x,), (tangent_vec_1,)
+    )  # thanks to forward AD this gets dux_dy and duy_dy in a vector
+    # ie the derivatives of both components of u wrt y
+    return jnp.stack(
+        [
+            u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
+            u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
+        ],
+        axis=-1,
+    )
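
Overall, the divergence and Laplacian operators in 1.2.0 take a single `inputs` array instead of the former `(t, x)` pair (the time coordinate is simply prepended for non-stationary problems, and the network is called as `u(inputs, params)` rather than `u(t, x, params)`), while the `_u_dot_nabla_times_u_*` helpers drop the `t` argument altogether. A hypothetical migration sketch for the non-stationary case (not part of the package diff):

import jax.numpy as jnp

from jinns.loss._operators import laplacian_rev

# 1.0.0 style (removed): _laplacian_rev(t, x, u_old, params) with u_old(t, x, params)
# 1.2.0 style: concatenate t and x into one array and call u(t_x, params)
u = lambda t_x, params: jnp.sum(t_x[1:] ** 2)  # ignores t; Laplacian over x is 2 * dim(x)

t = jnp.array([0.3])
x = jnp.array([1.0, -1.0])
lap = laplacian_rev(jnp.concatenate([t, x]), u, params=None, eq_type="nonstatio_PDE")
print(lap)  # expected: ~4.0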