jinns 1.0.0-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/loss/_operators.py
CHANGED
@@ -2,6 +2,8 @@
 Implements diverse operators for dynamic losses
 """
 
+from typing import Literal
+
 import jax
 import jax.numpy as jnp
 from jax import grad
@@ -12,35 +14,76 @@ from jinns.utils._spinn import SPINN
 from jinns.parameters._params import Params
 
 
-def
-
+def divergence_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
 ) -> float:
     r"""
     Compute the divergence of a vector field $\mathbf{u}$, i.e.,
-    $\
-    field from $\mathbb{R}^d$ to $\mathbb{R}^d
+    $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
+    field from $\mathbb{R}^d$ to $\mathbb{R}^d$ or $\mathbb{R}^{1+d}$
+    to $\mathbb{R}^{1+d}$. Thus, this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$, in the second case
+    case $\mathrm{inputs}=\mathbf{t,x}$.
     The computation is done using backward AD
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
     def scan_fun(_, i):
-        if
-            du_dxi = grad(lambda
+        if eq_type == "nonstatio_PDE":
+            du_dxi = grad(lambda inputs, params: u(inputs, params)[1 + i])(
+                inputs, params
+            )[1 + i]
         else:
-            du_dxi = grad(lambda
+            du_dxi = grad(lambda inputs, params: u(inputs, params)[i])(inputs, params)[
+                i
+            ]
         return _, du_dxi
 
-
+    if eq_type == "nonstatio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0] - 1))
+    elif eq_type == "statio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+    else:
+        raise ValueError("Unexpected u.eq_type!")
     return jnp.sum(accu)
 
 
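For readers skimming the new `divergence_rev` above: it accumulates the diagonal Jacobian entries $\partial u_i/\partial x_i$ with one reverse-mode `grad` call per coordinate and sums them. A minimal standalone sketch of that pattern, using a hypothetical analytic field `u` instead of a jinns PINN (the unused `params` argument only mirrors the assumed signature), could look like:

    import jax.numpy as jnp
    from jax import grad

    def u(x, params):
        # hypothetical stationary field R^2 -> R^2: u(x, y) = (x * y, y ** 2)
        return jnp.array([x[0] * x[1], x[1] ** 2])

    def divergence(x, params):
        # one reverse-mode grad call per component, keep the i-th entry, then sum
        return sum(
            grad(lambda x, params: u(x, params)[i])(x, params)[i]
            for i in range(x.shape[0])
        )

    x = jnp.array([2.0, 3.0])
    print(divergence(x, None))  # d(xy)/dx + d(y^2)/dy = y + 2y = 9.0

The released function additionally offsets the index by one for non-stationary PINNs so that the time coordinate is skipped.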
-def
-
-
+def divergence_fwd(
+    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
+    u: eqx.Module,
+    params: Params,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
     r"""
     Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
-    $\
+    $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
     field from $\mathbb{R}^{b \times d}$ to $\mathbb{R}^{b \times b
-    \times d}
+    \times d}$ or from $\mathbb{R}^{b \times d+1}$ to $\mathbb{R}^{b \times b
+    \times d+1}$. Thus, this
+    function can be used for stationary or non-stationary PINNs.
     Because of the embedding that happens in SPINNs the
     computation is most efficient with forward AD. This is the idea behind
     Separable PINNs.
@@ -48,79 +91,180 @@ def _div_fwd(
     !!! warning "Warning"
 
         This function is to be used in the context of SPINNs only.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
     def scan_fun(_, i):
-
-
-
-
-
+        if eq_type == "nonstatio_PDE":
+            tangent_vec = jnp.repeat(
+                jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                inputs.shape[0],
+                axis=0,
+            )
+            __, du_dxi = jax.jvp(
+                lambda inputs: u(inputs, params)[..., 1 + i], (inputs,), (tangent_vec,)
+            )
         else:
+            tangent_vec = jnp.repeat(
+                jax.nn.one_hot(i, inputs.shape[-1])[None],
+                inputs.shape[0],
+                axis=0,
+            )
             __, du_dxi = jax.jvp(
-                lambda
+                lambda inputs: u(inputs, params)[..., i], (inputs,), (tangent_vec,)
             )
         return _, du_dxi
 
-
+    if eq_type == "nonstatio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1] - 1))
+    elif eq_type == "statio_PDE":
+        _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1]))
+    else:
+        raise ValueError("Unexpected u.eq_type!")
     return jnp.sum(accu, axis=0)
 
 
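As a side note on `divergence_fwd`: the forward-mode variant pushes a batched one-hot tangent through `jax.jvp` once per coordinate, which is the cheap direction for the batched embeddings used by Separable PINNs. A rough sketch of the same mechanism on a hypothetical, plainly batched field (so without the $b \times b$ SPINN output structure) might be:

    import jax
    import jax.numpy as jnp

    def u(x, params):
        # hypothetical batched field R^{b x 2} -> R^{b x 2}
        return jnp.stack([x[:, 0] * x[:, 1], x[:, 1] ** 2], axis=-1)

    def divergence_batched(x, params):
        div = jnp.zeros(x.shape[0])
        for i in range(x.shape[-1]):
            # batched one-hot tangent: selects d u_i / d x_i for every batch element
            tangent = jnp.repeat(jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0)
            _, du_dxi = jax.jvp(lambda x: u(x, params)[..., i], (x,), (tangent,))
            div = div + du_dxi
        return div

    x = jnp.array([[2.0, 3.0], [1.0, 4.0]])
    print(divergence_batched(x, None))  # y + 2y per row -> [9., 12.]

The one-hot tangent extracts one Jacobian column per `jvp` call without ever materialising the full Jacobian.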
-def
-
+def laplacian_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
 ) -> float:
     r"""
-    Compute the Laplacian of a scalar field $u$
-    to $\mathbb{R}$
-
+    Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
+    to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+    $\Delta_\mathbf{x} u(\mathbf{x})=\nabla_\mathbf{x}\cdot\nabla_\mathbf{x} u(\mathbf{x})$.
+    In the second case $inputs=\mathbf{t,x}$, but we still compute
+    $\Delta_\mathbf{x} u(\mathrm{inputs})$.
     The computation is done using backward AD.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    method
+        how to compute the Laplacian. `"trace_hessian_x"` means that we take
+        the trace of the Hessian matrix computed with `x` only (`t` is excluded
+        from the beginning, we compute less derivatives at the price of a
+        concatenation). `"trace_hessian_t_x"` means that the computation
+        of the Hessian integrates `t` which is excluded at the end (we avoid a
+        concatenate but we compute more derivatives). `"loop"` means that we
+        directly compute the second order derivatives with a loop (we avoid
+        non-diagonal derivatives at the cost of a loop).
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
+    if method == "trace_hessian_x":
+        # NOTE we afford a concatenate here to avoid computing Hessian elements for
+        # nothing. In case of simple derivatives we prefer the vectorial
+        # computation and then discarding elements but for higher order derivatives
+        # it might not be worth it. See other options below for computating the
+        # Laplacian
+        if eq_type == "nonstatio_PDE":
+            u_ = lambda x: jnp.squeeze(
+                u(jnp.concatenate([inputs[:1], x], axis=0), params)
+            )
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs[1:])))
+        if eq_type == "statio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+        raise ValueError("Unexpected eq_type!")
+    if method == "trace_hessian_t_x":
+        # NOTE that it is unclear whether it is better to vectorially compute the
+        # Hessian (despite a useless time dimension) as below
+        if eq_type == "nonstatio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:])
+        if eq_type == "statio_PDE":
+            u_ = lambda inputs: jnp.squeeze(u(inputs, params))
+            return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
+        raise ValueError("Unexpected eq_type!")
+
+    if method == "loop":
+        # For a small d, we found out that trace of the Hessian is faster, see
+        # https://stackoverflow.com/questions/77517357/jax-grad-derivate-with-respect-an-specific-variable-in-a-matrix
+        # but could the trick below for taking directly the diagonal elements
+        # prove useful in higher dimensions?
+
+        u_ = lambda inputs: u(inputs, params).squeeze()
+
+        def scan_fun(_, i):
+            if eq_type == "nonstatio_PDE":
+                d2u_dxi2 = grad(
+                    lambda inputs: grad(u_)(inputs)[1 + i],
+                )(
+                    inputs
+                )[1 + i]
+            else:
+                d2u_dxi2 = grad(
+                    lambda inputs: grad(u_, 0)(inputs)[i],
+                    0,
+                )(
+                    inputs
+                )[i]
+            return _, d2u_dxi2
+
+        if eq_type == "nonstatio_PDE":
+            _, trace_hessian = jax.lax.scan(
+                scan_fun, {}, jnp.arange(inputs.shape[0] - 1)
+            )
+        elif eq_type == "statio_PDE":
+            _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
+        else:
+            raise ValueError("Unexpected eq_type!")
+        return jnp.sum(trace_hessian)
+    raise ValueError("Unexpected method argument!")
+
+
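The `method` options of `laplacian_rev` boil down to two reverse-mode strategies: take the trace of a full Hessian, or loop over coordinates and only build the diagonal second derivatives. A small self-contained comparison on a hypothetical scalar field `f` (not a jinns PINN) is sketched below; both values should agree up to floating-point error:

    import jax
    import jax.numpy as jnp
    from jax import grad

    # hypothetical scalar field on R^2
    f = lambda x: x[0] ** 2 + jnp.sin(x[1]) + x[0] * x[1]
    x = jnp.array([1.0, 2.0])

    # "trace_hessian" style: build every second derivative, keep only the diagonal
    lap_trace = jnp.trace(jax.hessian(f)(x))

    # "loop" style: one nested reverse-mode grad per coordinate, no cross terms
    lap_loop = sum(grad(lambda x: grad(f)(x)[i])(x)[i] for i in range(x.shape[0]))

    print(lap_trace, lap_loop)  # both equal 2 - sin(2)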
+def laplacian_fwd(
+    inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
     u: eqx.Module,
     params: Params,
-
+    method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
+) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
     r"""
     Compute the Laplacian of a **batched** scalar field $u$
-
-
+    from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
+    from $\mathbb{R}^{b\times (1 + d)}$ to $\mathbb{R}^{b\times b}$ or, i.e., this
+    function can be used for stationary or non-stationary PINNs
+    for $\mathbf{x}$ of arbitrary dimension $d$ or $1+d$ with batch
     dimension $b$.
     Because of the embedding that happens in SPINNs the
     computation is most efficient with forward AD. This is the idea behind
@@ -129,90 +273,228 @@ def _laplacian_fwd(
     !!! warning "Warning"
 
         This function is to be used in the context of SPINNs only.
+
+    !!! warning "Warning"
+
+        Because of the batch dimension, the current implementation of
+        `method="trace_hessian_t_x"` or `method="trace_hessian_x"`
+        should not be used except for debugging
+        purposes. Indeed, computing the Hessian is very costly.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    method
+        how to compute the Laplacian. `"trace_hessian_t_x"` means that the computation
+        of the Hessian integrates `t` which is excluded at the end (**see
+        Warning**). `"trace_hessian_x"` means an Hessian computation which
+        excludes `t` (**see Warning**). `"loop"` means that we
+        directly compute the second order derivatives with a loop (we avoid
+        non-diagonal derivatives at the cost of a loop).
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
 
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        eq_type = u.eq_type
+    except AttributeError:
+        pass  # use the value passed as argument
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+
+    if method == "loop":
+
+        def scan_fun(_, i):
+            if eq_type == "nonstatio_PDE":
+                tangent_vec = jnp.repeat(
+                    jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
+                    inputs.shape[0],
+                    axis=0,
+                )
+            else:
+                tangent_vec = jnp.repeat(
+                    jax.nn.one_hot(i, inputs.shape[-1])[None],
+                    inputs.shape[0],
+                    axis=0,
+                )
+
+            du_dxi_fun = lambda inputs: jax.jvp(
+                lambda inputs: u(inputs, params),
+                (inputs,),
+                (tangent_vec,),
+            )[1]
+            __, d2u_dxi2 = jax.jvp(du_dxi_fun, (inputs,), (tangent_vec,))
+            return _, d2u_dxi2
+
+        if eq_type == "nonstatio_PDE":
+            _, trace_hessian = jax.lax.scan(
+                scan_fun, {}, jnp.arange(inputs.shape[-1] - 1)
+            )
+        elif eq_type == "statio_PDE":
+            _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[-1]))
         else:
-
-
-
-
-
-
-
+            raise ValueError("Unexpected eq_type!")
+        return jnp.sum(trace_hessian, axis=0)
+    if method == "trace_hessian_t_x":
+        if eq_type == "nonstatio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,3) (1 for time, 2 for x_dim)
+            # then r.shape=(10,10,10,1,10,3,10,3)
+            # there are way too much derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces by avoid the time derivatives
+            # after that r.shape=(10,10,10,10)
+            r = jnp.trace(r[..., :, 1:, :, 1:], axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        if eq_type == "statio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+            # there are way too much derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces, after that r.shape=(10,10,10,10)
+            r = jnp.trace(r, axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        raise ValueError("Unexpected eq_type!")
+    if method == "trace_hessian_x":
+        if eq_type == "statio_PDE":
+            # compute the Hessian including the batch dimension, get rid of the
+            # (..,1,..) axis that is here because of the scalar output
+            # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
+            # there are way too much derivatives!
+            r = jax.hessian(u)(inputs, params).squeeze()
+            # compute the traces, after that r.shape=(10,10,10,10)
+            r = jnp.trace(r, axis1=-3, axis2=-1)
+            # but then we are in a cartesian product, for each coordinate on
+            # the first two dimensions we only want the trace at the same
+            # coordinate on the last two dimensions
+            # this is done easily with einsum but we need to automate the
+            # formula according to the input dim
+            res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
+            lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
+            return lap[..., None]
+        raise ValueError("Unexpected eq_type!")
+    raise ValueError("Unexpected method argument!")
+
+
+def vectorial_laplacian_rev(
+    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
+    u: eqx.Module,
+    params: Params,
+    dim_out: int = None,
+) -> Float[Array, "dim_out"]:
+    r"""
+    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
+    $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
+    $\mathbb{R}^n$, i.e., this
+    function can be used for stationary or non-stationary PINNs. In the first
+    case $\mathrm{inputs}=\mathbf{x}$ is of arbitrary dimension, i.e.,
+    $\Delta_\mathbf{x} \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
+    \mathbf{u}(\mathbf{x})$.
+    In the second case $inputs=\mathbf{t,x}$, and we perform
+    $\Delta_\mathbf{x} \mathbf{u}(\mathrm{inputs})=\nabla\cdot\nabla
+    \mathbf{u}(\mathrm{inputs})$.
+
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    dim_out
+        Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
+        provided if it is different than that of $\mathrm{inputs}$.
+    """
+    if dim_out is None:
+        dim_out = inputs.shape[0]
+
+    def scan_fun(_, j):
+        # The loop over the components of u(x). We compute one Laplacian for
+        # each of these components
+        # Note the jnp.expand_dims call
+        uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
+        lap_on_j = laplacian_rev(inputs, uj, params, eq_type=u.eq_type)
 
-
-        return jnp.sum(trace_hessian, axis=0)
+        return _, lap_on_j
 
+    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+    return vec_lap
 
-
-
-
+
+def vectorial_laplacian_fwd(
+    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
     u: eqx.Module,
     params: Params,
-
-) -> (
-    Float[Array, "dimension_out"] | Float[Array, "batch_size batch_size dimension_out"]
-):
+    dim_out: int = None,
+) -> Float[Array, "batch_size * (1+dim) n"] | Float[Array, "batch_size * (dim) n"]:
     r"""
-    Compute the vectorial Laplacian of a vector field $\mathbf{u}$
-
-
-    $\
-
+    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
+    `u` is a SPINN, in this case, it corresponds to a vector
+    field from from $\mathbb{R}^{b\times d}$ to
+    $\mathbb{R}^{b\times b\times n}$ or from $\mathbb{R}^{b\times 1+d}$ to
+    $\mathbb{R}^{b\times b\times n}$, i.e., this
+    function can be used for stationary or non-stationary PINNs.
+
+    Forward mode AD is used.
+
+    !!! warning "Warning"
 
-
-        $\mathbf{u}(\mathbf{x})$ if it is different than that of
-        $\mathbf{x}$.
+        This function is to be used in the context of SPINNs only.
 
-
-
-
-
+    Parameters
+    ----------
+    inputs
+        `x` or `t_x`
+    u
+        the PINN
+    params
+        the PINN parameters
+    dim_out
+        the value of the output dimension ($n$ in the formula above). Must be
+        set if different from $d$.
     """
-    if
-
+    if dim_out is None:
+        dim_out = inputs.shape[0]
 
     def scan_fun(_, j):
         # The loop over the components of u(x). We compute one Laplacian for
         # each of these components
         # Note the expand_dims
-
-
-            uj = lambda x, params: jnp.expand_dims(u(x, params)[j], axis=-1)
-        else:
-            uj = lambda t, x, params: jnp.expand_dims(u(t, x, params)[j], axis=-1)
-        lap_on_j = _laplacian_rev(t, x, uj, params)
-    elif isinstance(u, SPINN):
-        if t is None:
-            uj = lambda x, params: jnp.expand_dims(u(x, params)[..., j], axis=-1)
-        else:
-            uj = lambda t, x, params: jnp.expand_dims(
-                u(t, x, params)[..., j], axis=-1
-            )
-        lap_on_j = _laplacian_fwd(t, x, uj, params)
-    else:
-        raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
+        uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
+        lap_on_j = laplacian_fwd(inputs, uj, params, eq_type=u.eq_type)
 
         return _, lap_on_j
 
-    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(
-    return vec_lap
+    _, vec_lap = jax.lax.scan(scan_fun, {}, jnp.arange(dim_out))
+    return jnp.moveaxis(vec_lap.squeeze(), 0, -1)
 
 
 def _u_dot_nabla_times_u_rev(
-
+    x: Float[Array, "2"], u: eqx.Module, params: Params
 ) -> Float[Array, "2"]:
     r"""
     Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
@@ -223,43 +505,25 @@ def _u_dot_nabla_times_u_rev(
     We do not use loops but code explicitly the expression to avoid
     computing twice some terms
     """
-
-
-
-        uy = lambda x: u(x, params)[1]
-
-        dux_dx = lambda x: grad(ux, 0)(x)[0]
-        dux_dy = lambda x: grad(ux, 0)(x)[1]
-
-        duy_dx = lambda x: grad(uy, 0)(x)[0]
-        duy_dy = lambda x: grad(uy, 0)(x)[1]
-
-        return jnp.array(
-            [
-                ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
-                ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
-            ]
-        )
-        ux = lambda t, x: u(t, x, params)[0]
-        uy = lambda t, x: u(t, x, params)[1]
+    assert x.shape[0] == 2
+    ux = lambda x: u(x, params)[0]
+    uy = lambda x: u(x, params)[1]
 
-
-
+    dux_dx = lambda x: grad(ux, 0)(x)[0]
+    dux_dy = lambda x: grad(ux, 0)(x)[1]
 
-
-
+    duy_dx = lambda x: grad(uy, 0)(x)[0]
+    duy_dy = lambda x: grad(uy, 0)(x)[1]
 
-
-
-
-
-
-
-    raise NotImplementedError("x.ndim must be 2")
+    return jnp.array(
+        [
+            ux(x) * dux_dx(x) + uy(x) * dux_dy(x),
+            ux(x) * duy_dx(x) + uy(x) * duy_dy(x),
+        ]
+    )
 
 
 def _u_dot_nabla_times_u_fwd(
-    t: Float[Array, "batch_size 1"],
     x: Float[Array, "batch_size 2"],
     u: eqx.Module,
     params: Params,
@@ -277,29 +541,22 @@ def _u_dot_nabla_times_u_fwd(
     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
     This function is to be used in the context of SPINNs only.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        u_at_x
-
-
-
-            [
-                u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
-                u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
-            ],
-            axis=-1,
-        )
-    raise NotImplementedError("x.ndim must be 2")
+    assert x.shape[-1] == 2
+    tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
+    tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
+    u_at_x, du_dx = jax.jvp(
+        lambda x: u(x, params), (x,), (tangent_vec_0,)
+    )  # thanks to forward AD this gets dux_dx and duy_dx in a vector
+    # ie the derivatives of both components of u wrt x
+    # this also gets the vector of u evaluated at x
+    u_at_x, du_dy = jax.jvp(
+        lambda x: u(x, params), (x,), (tangent_vec_1,)
+    )  # thanks to forward AD this gets dux_dy and duy_dy in a vector
+    # ie the derivatives of both components of u wrt y
+    return jnp.stack(
+        [
+            u_at_x[..., 0] * du_dx[..., 0] + u_at_x[..., 1] * du_dy[..., 0],
+            u_at_x[..., 0] * du_dx[..., 1] + u_at_x[..., 1] * du_dy[..., 1],
+        ],
+        axis=-1,
+    )
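Finally, the `"loop"` branch of `laplacian_fwd` above obtains batched second derivatives purely in forward mode by nesting two `jax.jvp` calls along the same one-hot tangent. A minimal sketch of that nested-jvp trick on a hypothetical batched scalar field `f` (standing in for a SPINN output, and not jinns' own code) is:

    import jax
    import jax.numpy as jnp

    # hypothetical batched scalar field R^{b x 2} -> R^{b}
    f = lambda x: jnp.sin(x[:, 0]) + x[:, 1] ** 3
    x = jnp.array([[0.5, 1.0], [1.5, 2.0]])

    def d2f_dxi2(i):
        # same one-hot tangent pushed through jax.jvp twice: forward-over-forward
        tangent = jnp.repeat(jax.nn.one_hot(i, x.shape[-1])[None], x.shape[0], axis=0)
        df_dxi = lambda x: jax.jvp(f, (x,), (tangent,))[1]   # batched d f / d x_i
        return jax.jvp(df_dxi, (x,), (tangent,))[1]          # batched d^2 f / d x_i^2

    laplacian = sum(d2f_dxi2(i) for i in range(x.shape[-1]))
    print(laplacian)  # -sin(x0) + 6 * x1, one value per batch row

Summing these per-coordinate terms gives the batched Laplacian without ever forming a Hessian, which is why the diff makes `"loop"` the default method for the SPINN variant.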