jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +534 -343
- jinns/loss/_DynamicLoss.py +152 -175
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +4 -4
- jinns/loss/_LossPDE.py +102 -74
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +150 -281
- jinns/loss/_loss_utils.py +95 -67
- jinns/loss/_operators.py +441 -186
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +47 -31
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +113 -100
- jinns/solver/_rar.py +104 -409
- jinns/solver/_solve.py +87 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +1 -4
- jinns/utils/_containers.py +3 -1
- jinns/utils/_types.py +5 -4
- jinns/utils/_utils.py +40 -12
- jinns-1.3.0.dist-info/METADATA +127 -0
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -410
- jinns/utils/_pinn.py +0 -334
- jinns/utils/_spinn.py +0 -268
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
jinns/loss/_operators.py
CHANGED
@@ -2,45 +2,86 @@
 Implements diverse operators for dynamic losses
 """
 
+from typing import Literal
+
 import jax
 import jax.numpy as jnp
 from jax import grad
 import equinox as eqx
 from jaxtyping import Float, Array
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
 from jinns.parameters._params import Params
 
 
-def
- …
+def divergence_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
 ) -> float:
     r"""
     Compute the divergence of a vector field $\mathbf{u}$, i.e.,
-    $\
-    field from $\mathbb{R}^d$ to $\mathbb{R}^d
+    $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
+    field from $\mathbb{R}^d$ to $\mathbb{R}^d$ or $\mathbb{R}^{1+d}$
+    to $\mathbb{R}^{1+d}$. Thus, this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$, in the second case
+    case $\mathrm{inputs}=\mathbf{t,x}$.
     The computation is done using backward AD
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
     def scan_fun(_, i):
-        if
-            du_dxi = grad(lambda
+        if eq_type == "nonstatio_PDE":
+            du_dxi = grad(lambda inputs, params: u(inputs, params)[1 + i])(
+                inputs, params
+            )[1 + i]
         else:
-            du_dxi = grad(lambda
+            du_dxi = grad(lambda inputs, params: u(inputs, params)[i])(inputs, params)[
+                i
+            ]
         return _, du_dxi
 
- …
+    if eq_type == "nonstatio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0] - 1))
+    elif eq_type == "statio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+    else:
+        raise ValueError("Unexpected u.eq_type!")
     return jnp.sum(accu)
 
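The 1.3.0 operators take a single `inputs` array (`x` or a concatenated `t_x`) instead of separate `t` and `x` arguments, and read `eq_type` from the PINN object or, failing that, from the keyword. Below is a minimal usage sketch of `divergence_rev` under that signature; the toy field is hypothetical (it ignores `params`), and the import goes through the private module path shown in this diff (the operators may also be re-exported from `jinns.loss`, but those `__init__.py` changes are not shown here).

```python
import jax.numpy as jnp
from jinns.loss._operators import divergence_rev

# Hypothetical stand-in for a non-stationary PINN: maps t_x = (t, x1, x2) to a
# (1+d)-dimensional output whose spatial components are x1**2 and 3*x2.
u = lambda t_x, params: jnp.array([t_x[0], t_x[1] ** 2, 3.0 * t_x[2]])

t_x = jnp.array([0.5, 1.0, 2.0])
# `u` is a plain function, so eq_type cannot be inferred and must be passed.
div = divergence_rev(t_x, u, params=None, eq_type="nonstatio_PDE")
# Spatial divergence: d(x1**2)/dx1 + d(3*x2)/dx2 = 2*1.0 + 3.0 = 5.0
print(div)
```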
 
-def
- … (2 lines not shown)
+def divergence_fwd(
+    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
+    u: eqx.Module,
+    params: Params,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
     r"""
     Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
-    $\
+    $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
     field from $\mathbb{R}^{b \times d}$ to $\mathbb{R}^{b \times b
-    \times d}
+    \times d}$ or from $\mathbb{R}^{b \times d+1}$ to $\mathbb{R}^{b \times b
+    \times d+1}$. Thus, this
+    function can be used for stationary or non-stationary PINNs.
     Because of the embedding that happens in SPINNs the
     computation is most efficient with forward AD. This is the idea behind
     Separable PINNs.
@@ -48,79 +89,180 @@ def _div_fwd(
     !!! warning "Warning"
 
         This function is to be used in the context of SPINNs only.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
     def scan_fun(_, i):
- … (5 lines not shown)
+        if eq_type == "nonstatio_PDE":
+            tangent_vec = jnp.repeat(
+                jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                inputs.shape[0],
+                axis=0,
+            )
+            __, du_dxi = jax.jvp(
+                lambda inputs: u(inputs, params)[..., 1 + i], (inputs,), (tangent_vec,)
+            )
         else:
+            tangent_vec = jnp.repeat(
+                jax.nn.one_hot(i, inputs.shape[-1])[None],
+                inputs.shape[0],
+                axis=0,
+            )
             __, du_dxi = jax.jvp(
-                lambda
+                lambda inputs: u(inputs, params)[..., i], (inputs,), (tangent_vec,)
             )
         return _, du_dxi
 
- …
+    if eq_type == "nonstatio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1] - 1))
+    elif eq_type == "statio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1]))
+    else:
+        raise ValueError("Unexpected u.eq_type!")
     return jnp.sum(accu, axis=0)
 
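For the forward-mode, batched variant (meant for SPINNs), the same convention applies but `inputs` carries a leading batch dimension. A sketch of the calling convention follows, with a hypothetical element-wise toy field standing in for a SPINN output (not a real SPINN); `params` is again ignored by the toy.

```python
import jax.numpy as jnp
from jinns.loss._operators import divergence_fwd

# Hypothetical element-wise field: (batch_size, dim) -> (batch_size, dim).
u = lambda inputs, params: inputs ** 2

x = jnp.arange(6.0).reshape(3, 2)  # 3 points in R^2
div = divergence_fwd(x, u, params=None, eq_type="statio_PDE")
# Per point: d(x1**2)/dx1 + d(x2**2)/dx2 = 2 * (x1 + x2)
print(div)  # equals 2 * x.sum(axis=1)
```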
 
-def
- …
+def laplacian_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
 ) -> float:
     r"""
-    Compute the Laplacian of a scalar field $u$
-    to $\mathbb{R}$
- …
+    Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
+    to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+    $\Delta_\mathbf{x} u(\mathbf{x})=\nabla_\mathbf{x}\cdot\nabla_\mathbf{x} u(\mathbf{x})$.
+    In the second case $inputs=\mathbf{t,x}$, but we still compute
+    $\Delta_\mathbf{x} u(\mathrm{inputs})$.
     The computation is done using backward AD.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    method
+        how to compute the Laplacian. `"trace_hessian_x"` means that we take
+        the trace of the Hessian matrix computed with `x` only (`t` is excluded
+        from the beginning, we compute less derivatives at the price of a
+        concatenation). `"trace_hessian_t_x"` means that the computation
+        of the Hessian integrates `t` which is excluded at the end (we avoid a
+        concatenate but we compute more derivatives). `"loop"` means that we
+        directly compute the second order derivatives with a loop (we avoid
+        non-diagonal derivatives at the cost of a loop).
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
- … (38 lines not shown)
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
+    if method == "trace_hessian_x":
+        # NOTE we afford a concatenate here to avoid computing Hessian elements for
+        # nothing. In case of simple derivatives we prefer the vectorial
+        # computation and then discarding elements but for higher order derivatives
+        # it might not be worth it. See other options below for computating the
+        # Laplacian
+        if eq_type == "nonstatio_PDE":
+            u_ = lambda x: jnp.squeeze(
+                u(jnp.concatenate([inputs[:1], x], axis=0), params)
+            )
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs[1:])))
+        if eq_type == "statio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+        raise ValueError("Unexpected eq_type!")
+    if method == "trace_hessian_t_x":
+        # NOTE that it is unclear whether it is better to vectorially compute the
+        # Hessian (despite a useless time dimension) as below
+        if eq_type == "nonstatio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:])
+        if eq_type == "statio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+        raise ValueError("Unexpected eq_type!")
+
+    if method == "loop":
+        # For a small d, we found out that trace of the Hessian is faster, see
+        # https://stackoverflow.com/questions/77517357/jax-grad-derivate-with-respect-an-specific-variable-in-a-matrix
+        # but could the trick below for taking directly the diagonal elements
+        # prove useful in higher dimensions?
+
+        u_ = lambda inputs: u(inputs, params).squeeze()
+
+        def scan_fun(_, i):
+            if eq_type == "nonstatio_PDE":
+                d2u_dxi2 = grad(
+                    lambda inputs: grad(u_)(inputs)[1 + i],
+                )(
+                    inputs
+                )[1 + i]
+            else:
+                d2u_dxi2 = grad(
+                    lambda inputs: grad(u_, 0)(inputs)[i],
+                    0,
+                )(
+                    inputs
+                )[i]
+            return _, d2u_dxi2
+
+        if eq_type == "nonstatio_PDE":
+            _, trace_hessian = jax.lax.scan(
+                scan_fun, {}, jnp.arange(inputs.shape[0] - 1)
+            )
+        elif eq_type == "statio_PDE":
+            _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+        else:
+            raise ValueError("Unexpected eq_type!")
+        return jnp.sum(trace_hessian)
+    raise ValueError("Unexpected method argument!")
+
+
+def laplacian_fwd(
+    inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
     u: eqx.Module,
     params: Params,
- …
+    method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
     r"""
     Compute the Laplacian of a **batched** scalar field $u$
- … (2 lines not shown)
+    from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
+    from $\mathbb{R}^{b\times (1 + d)}$ to $\mathbb{R}^{b\times b}$ or, i.e., this
+    function can be used for stationary or non-stationary PINNs
+    for $\mathbf{x}$ of arbitrary dimension $d$ or $1+d$ with batch
     dimension $b$.
     Because of the embedding that happens in SPINNs the
     computation is most efficient with forward AD. This is the idea behind
@@ -129,90 +271,228 @@ def _laplacian_fwd(
     !!! warning "Warning"
 
         This function is to be used in the context of SPINNs only.
+
+    !!! warning "Warning"
+
+        Because of the batch dimension, the current implementation of
+        `method="trace_hessian_t_x"` or `method="trace_hessian_x"`
+        should not be used except for debugging
+        purposes. Indeed, computing the Hessian is very costly.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    method
+        how to compute the Laplacian. `"trace_hessian_t_x"` means that the computation
+        of the Hessian integrates `t` which is excluded at the end (**see
+        Warning**). `"trace_hessian_x"` means an Hessian computation which
+        excludes `t` (**see Warning**). `"loop"` means that we
+        directly compute the second order derivatives with a loop (we avoid
+        non-diagonal derivatives at the cost of a loop).
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
- … (12 lines not shown)
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
+    if method == "loop":
+
+        def scan_fun(_, i):
+            if eq_type == "nonstatio_PDE":
+                tangent_vec = jnp.repeat(
+                    jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                    inputs.shape[0],
+                    axis=0,
+                )
+            else:
+                tangent_vec = jnp.repeat(
+                    jax.nn.one_hot(i, inputs.shape[-1])[None],
+                    inputs.shape[0],
+                    axis=0,
+                )
+
+            du_dxi_fun = lambda inputs: jax.jvp(
+                lambda inputs: u(inputs, params),
+                (inputs,),
+                (tangent_vec,),
+            )[1]
+            __, d2u_dxi2 = jax.jvp(du_dxi_fun, (inputs,), (tangent_vec,))
+            return _, d2u_dxi2
+
+        if eq_type == "nonstatio_PDE":
+            _, trace_hessian = jax.lax.scan(
+                scan_fun, {}, jnp.arange(inputs.shape[-1] - 1)
+            )
+        elif eq_type == "statio_PDE":
+            _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[-1]))
         else:
- … (7 lines not shown)
+            raise ValueError("Unexpected eq_type!")
+        return jnp.sum(trace_hessian, axis=0)
+    if method == "trace_hessian_t_x":
+        if eq_type == "nonstatio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,3) (1 for time, 2 for x_dim)
+            # then r.shape=(10,10,10,1,10,3,10,3)
+            # there are way too much derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces by avoid the time derivatives
+            # after that r.shape=(10,10,10,10)
+            r = jnp.trace(r[..., :, 1:, :, 1:], axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        if eq_type == "statio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+            # there are way too much derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces, after that r.shape=(10,10,10,10)
+            r = jnp.trace(r, axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        raise ValueError("Unexpected eq_type!")
+    if method == "trace_hessian_x":
+        if eq_type == "statio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+            # there are way too much derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces, after that r.shape=(10,10,10,10)
+            r = jnp.trace(r, axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        raise ValueError("Unexpected eq_type!")
+    raise ValueError("Unexpected method argument!")
+
+
+def vectorial_laplacian_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    dim_out: int = None,
+) -> Float[Array, "dim_out"]:
+    r"""
+    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
+    $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
+    $\mathbb{R}^n$, i.e., this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+    $\Delta_\mathbf{x} \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
+    \mathbf{u}(\mathbf{x})$.
+    In the second case $inputs=\mathbf{t,x}$, and we perform
+    $\Delta_\mathbf{x} \mathbf{u}(\mathrm{inputs})=\nabla\cdot\nabla
+    \mathbf{u}(\mathrm{inputs})$.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    dim_out
+        Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
+        provided if it is different than that of $\mathrm{inputs}$.
+    """
+    if dim_out is None:
+        dim_out = inputs.shape[0]
+
+    def scan_fun(_, j):
+        # The loop over the components of u(x). We compute one Laplacian for
+        # each of these components
+        # Note the jnp.expand_dims call
+        uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
+        lap_on_j = laplacian_rev(inputs, uj, params, eq_type=u.eq_type)
 
- …
-    return jnp.sum(trace_hessian, axis=0)
+        return _, lap_on_j
 
+    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+    return vec_lap
 
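`vectorial_laplacian_rev` reads `eq_type` from the `u` object (it forwards `u.eq_type` to `laplacian_rev`), so a bare lambda is not enough here. A sketch with a hypothetical callable carrying that attribute; in real use a jinns PINN (an `eqx.Module`) would play this role.

```python
import jax.numpy as jnp
from jinns.loss._operators import vectorial_laplacian_rev

class ToyField:
    # Hypothetical stand-in for a stationary PINN: component j of the output is
    # inputs[j] ** 2, so every component has Laplacian 2.
    eq_type = "statio_PDE"

    def __call__(self, inputs, params):
        return inputs ** 2

x = jnp.array([1.0, 2.0, 3.0])
vec_lap = vectorial_laplacian_rev(x, ToyField(), params=None)
print(vec_lap)  # [2.0, 2.0, 2.0]
```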
- … (3 lines not shown)
+
+def vectorial_laplacian_fwd(
+    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
     u: eqx.Module,
     params: Params,
- …
-) -> (
-    Float[Array, "dimension_out"] | Float[Array, "batch_size batch_size dimension_out"]
-):
+    dim_out: int = None,
+) -> Float[Array, "batch_size * (1+dim) n"] | Float[Array, "batch_size * (dim) n"]:
     r"""
-    Compute the vectorial Laplacian of a vector field $\mathbf{u}$
- … (2 lines not shown)
-    $\
- …
+    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
+    `u` is a SPINN, in this case, it corresponds to a vector
+    field from from $\mathbb{R}^{b\times d}$ to
+    $\mathbb{R}^{b\times b\times n}$ or from $\mathbb{R}^{b\times 1+d}$ to
+    $\mathbb{R}^{b\times b\times n}$, i.e., this
+    function can be used for stationary or non-stationary PINNs.
+
+    Forward mode AD is used.
+
+    !!! warning "Warning"
 
- …
-        $\mathbf{u}(\mathbf{x})$ if it is different than that of
-        $\mathbf{x}$.
+        This function is to be used in the context of SPINNs only.
 
- … (4 lines not shown)
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    dim_out
+        the value of the output dimension ($n$ in the formula above). Must be
+        set if different from $d$.
     """
-    if
- …
+    if dim_out is None:
+        dim_out = inputs.shape[0]
 
     def scan_fun(_, j):
         # The loop over the components of u(x). We compute one Laplacian for
         # each of these components
         # Note the expand_dims
- … (2 lines not shown)
-                uj = lambda x, params: jnp.expand_dims(u(x, params)[j], axis=-1)
-            else:
-                uj = lambda t, x, params: jnp.expand_dims(u(t, x, params)[j], axis=-1)
-            lap_on_j = _laplacian_rev(t, x, uj, params)
-        elif isinstance(u, SPINN):
-            if t is None:
-                uj = lambda x, params: jnp.expand_dims(u(x, params)[..., j], axis=-1)
-            else:
-                uj = lambda t, x, params: jnp.expand_dims(
-                    u(t, x, params)[..., j], axis=-1
-                )
-            lap_on_j = _laplacian_fwd(t, x, uj, params)
-        else:
-            raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
+        uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
+        lap_on_j = laplacian_fwd(inputs, uj, params, eq_type=u.eq_type)
 
         return _, lap_on_j
 
-    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(
-    return vec_lap
+    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+    return jnp.moveaxis(vec_lap.squeeze(), 0, -1)
 
 
 def _u_dot_nabla_times_u_rev(
- …
+    x: Float[Array, "2"], u: eqx.Module, params: Params
 ) -> Float[Array, "2"]:
     r"""
     Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
@@ -223,43 +503,25 @@ def _u_dot_nabla_times_u_rev(
     We do not use loops but code explicitly the expression to avoid
     computing twice some terms
     """
- … (3 lines not shown)
-        uy = lambda x: u(x, params)[1]
- …
-        dux_dx = lambda x: grad(ux, 0)(x)[0]
-        dux_dy = lambda x: grad(ux, 0)(x)[1]
- …
-        duy_dx = lambda x: grad(uy, 0)(x)[0]
-        duy_dy = lambda x: grad(uy, 0)(x)[1]
- …
-        return jnp.array(
-            [
-                ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
-                ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
-            ]
-        )
-        ux = lambda t, x: u(t, x, params)[0]
-        uy = lambda t, x: u(t, x, params)[1]
+    assert x.shape[0] == 2
+    ux = lambda x: u(x, params)[0]
+    uy = lambda x: u(x, params)[1]
 
- … (2 lines not shown)
+    dux_dx = lambda x: grad(ux, 0)(x)[0]
+    dux_dy = lambda x: grad(ux, 0)(x)[1]
 
- … (2 lines not shown)
+    duy_dx = lambda x: grad(uy, 0)(x)[0]
+    duy_dy = lambda x: grad(uy, 0)(x)[1]
 
- … (6 lines not shown)
-    raise NotImplementedError("x.ndim must be 2")
+    return jnp.array(
+        [
+            ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
+            ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
+        ]
+    )
 
 
 def _u_dot_nabla_times_u_fwd(
-    t: Float[Array, "batch_size 1"],
     x: Float[Array, "batch_size 2"],
     u: eqx.Module,
     params: Params,
@@ -277,29 +539,22 @@ def _u_dot_nabla_times_u_fwd(
     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
     This function is to be used in the context of SPINNs only.
     """
- … (15 lines not shown)
-        u_at_x
- … (3 lines not shown)
-            [
-                u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
-                u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
-            ],
-            axis=-1,
-        )
-    raise NotImplementedError("x.ndim must be 2")
+    assert x.shape[-1] == 2
+    tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
+    tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+    u_at_x, du_dx = jax.jvp(
+        lambda x: u(x, params), (x,), (tangent_vec_0,)
+    )  # thanks to forward AD this gets dux_dx and duy_dx in a vector
+    # ie the derivatives of both components of u wrt x
+    # this also gets the vector of u evaluated at x
+    u_at_x, du_dy = jax.jvp(
+        lambda x: u(x, params), (x,), (tangent_vec_1,)
+    )  # thanks to forward AD this gets dux_dy and duy_dy in a vector
+    # ie the derivatives of both components of u wrt y
+    return jnp.stack(
+        [
+            u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
+            u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
+        ],
+        axis=-1,
+    )