jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Formalize the loss weights data structure
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict
|
|
6
|
+
from jaxtyping import Array, Float
|
|
7
|
+
import equinox as eqx
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class LossWeightsODE(eqx.Module):
    """
    Weights applied to each term of an ODE loss.

    Every weight is keyword-only and equals 1.0 unless overridden.
    """

    # weight of the dynamic loss term
    dyn_loss: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the initial condition loss term
    initial_condition: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the observations loss term
    observations: Array | Float | None = eqx.field(default=1.0, kw_only=True)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LossWeightsODEDict(eqx.Module):
    """
    Weights of the ODE loss terms, each given as a dictionary keyed by string.

    Every field is keyword-only and defaults to ``None``.
    """

    # The annotations include ``| None`` because ``None`` is the default value
    # of every field.
    dyn_loss: Dict[str, Array | Float | None] | None = eqx.field(
        kw_only=True, default=None
    )
    initial_condition: Dict[str, Array | Float | None] | None = eqx.field(
        kw_only=True, default=None
    )
    observations: Dict[str, Array | Float | None] | None = eqx.field(
        kw_only=True, default=None
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LossWeightsPDEStatio(eqx.Module):
    """
    Weights applied to each term of a stationary PDE loss.

    Every weight is keyword-only and equals 1.0 unless overridden.
    """

    # weight of the dynamic loss term
    dyn_loss: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the `norm_loss` term
    norm_loss: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the boundary loss term
    boundary_loss: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the observations loss term
    observations: Array | Float | None = eqx.field(default=1.0, kw_only=True)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class LossWeightsPDENonStatio(eqx.Module):
    """
    Weights applied to each term of a non-stationary PDE loss.

    Same terms as the stationary case plus an initial condition term. Every
    weight is keyword-only and equals 1.0 unless overridden.
    """

    # weight of the dynamic loss term
    dyn_loss: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the `norm_loss` term
    norm_loss: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the boundary loss term
    boundary_loss: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the observations loss term
    observations: Array | Float | None = eqx.field(default=1.0, kw_only=True)
    # weight of the initial condition loss term
    initial_condition: Array | Float | None = eqx.field(default=1.0, kw_only=True)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LossWeightsPDEDict(eqx.Module):
    """
    The single loss-weights data structure used by the SystemLossPDE.

    It always includes an `initial_condition` field (even for stationary
    problems) so that the code handling it can stay generic.

    Each field may be a dictionary of weights keyed by string. Every field is
    keyword-only and defaults to the scalar 1.0.
    """

    # The annotations include ``| float`` because the default value of every
    # field is the scalar 1.0, not a dictionary.
    dyn_loss: Dict[str, Array | Float | None] | float = eqx.field(
        kw_only=True, default=1.0
    )
    norm_loss: Dict[str, Array | Float | None] | float = eqx.field(
        kw_only=True, default=1.0
    )
    boundary_loss: Dict[str, Array | Float | None] | float = eqx.field(
        kw_only=True, default=1.0
    )
    observations: Dict[str, Array | Float | None] | float = eqx.field(
        kw_only=True, default=1.0
    )
    initial_condition: Dict[str, Array | Float | None] | float = eqx.field(
        kw_only=True, default=1.0
    )
|
jinns/loss/_operators.py
CHANGED
|
@@ -5,15 +5,20 @@ Implements diverse operators for dynamic losses
|
|
|
5
5
|
import jax
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
from jax import grad
|
|
8
|
+
import equinox as eqx
|
|
9
|
+
from jaxtyping import Float, Array
|
|
8
10
|
from jinns.utils._pinn import PINN
|
|
9
11
|
from jinns.utils._spinn import SPINN
|
|
12
|
+
from jinns.parameters._params import Params
|
|
10
13
|
|
|
11
14
|
|
|
12
|
-
def _div_rev(
|
|
15
|
+
def _div_rev(
|
|
16
|
+
t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
|
|
17
|
+
) -> float:
|
|
13
18
|
r"""
|
|
14
|
-
Compute the divergence of a vector field
|
|
15
|
-
|
|
16
|
-
field from
|
|
19
|
+
Compute the divergence of a vector field $\mathbf{u}$, i.e.,
|
|
20
|
+
$\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
|
|
21
|
+
field from $\mathbb{R}^d$ to $\mathbb{R}^d$.
|
|
17
22
|
The computation is done using backward AD
|
|
18
23
|
"""
|
|
19
24
|
|
|
@@ -28,15 +33,21 @@ def _div_rev(t, x, u, params):
|
|
|
28
33
|
return jnp.sum(accu)
|
|
29
34
|
|
|
30
35
|
|
|
31
|
-
def _div_fwd(
|
|
36
|
+
def _div_fwd(
|
|
37
|
+
t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
|
|
38
|
+
) -> float:
|
|
32
39
|
r"""
|
|
33
|
-
Compute the divergence of a **batched** vector field
|
|
34
|
-
|
|
35
|
-
field from
|
|
36
|
-
\times d}
|
|
40
|
+
Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
|
|
41
|
+
$\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
|
|
42
|
+
field from $\mathbb{R}^{b \times d}$ to $\mathbb{R}^{b \times b
|
|
43
|
+
\times d}$. The result is then in $\mathbb{R}^{b\times b}$.
|
|
37
44
|
Because of the embedding that happens in SPINNs the
|
|
38
|
-
computation is most efficient with forward AD. This is the idea behind
|
|
39
|
-
|
|
45
|
+
computation is most efficient with forward AD. This is the idea behind
|
|
46
|
+
Separable PINNs.
|
|
47
|
+
|
|
48
|
+
!!! warning "Warning"
|
|
49
|
+
|
|
50
|
+
This function is to be used in the context of SPINNs only.
|
|
40
51
|
"""
|
|
41
52
|
|
|
42
53
|
def scan_fun(_, i):
|
|
@@ -55,11 +66,13 @@ def _div_fwd(t, x, u, params):
|
|
|
55
66
|
return jnp.sum(accu, axis=0)
|
|
56
67
|
|
|
57
68
|
|
|
58
|
-
def _laplacian_rev(
|
|
69
|
+
def _laplacian_rev(
|
|
70
|
+
t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
|
|
71
|
+
) -> float:
|
|
59
72
|
r"""
|
|
60
|
-
Compute the Laplacian of a scalar field
|
|
61
|
-
to
|
|
62
|
-
|
|
73
|
+
Compute the Laplacian of a scalar field $u$ (from $\mathbb{R}^d$
|
|
74
|
+
to $\mathbb{R}$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
|
|
75
|
+
$\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})$.
|
|
63
76
|
The computation is done using backward AD.
|
|
64
77
|
"""
|
|
65
78
|
|
|
@@ -98,15 +111,24 @@ def _laplacian_rev(t, x, u, params):
|
|
|
98
111
|
# return jnp.sum(trace_hessian)
|
|
99
112
|
|
|
100
113
|
|
|
101
|
-
def _laplacian_fwd(
|
|
114
|
+
def _laplacian_fwd(
|
|
115
|
+
t: Float[Array, "batch_size 1"],
|
|
116
|
+
x: Float[Array, "batch_size dimension"],
|
|
117
|
+
u: eqx.Module,
|
|
118
|
+
params: Params,
|
|
119
|
+
) -> Float[Array, "batch_size batch_size"]:
|
|
102
120
|
r"""
|
|
103
|
-
Compute the Laplacian of a **batched** scalar field
|
|
104
|
-
(from
|
|
105
|
-
for
|
|
106
|
-
dimension
|
|
121
|
+
Compute the Laplacian of a **batched** scalar field $u$
|
|
122
|
+
(from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$)
|
|
123
|
+
for $\mathbf{x}$ of arbitrary dimension $d$ with batch
|
|
124
|
+
dimension $b$.
|
|
107
125
|
Because of the embedding that happens in SPINNs the
|
|
108
|
-
computation is most efficient with forward AD. This is the idea behind
|
|
109
|
-
|
|
126
|
+
computation is most efficient with forward AD. This is the idea behind
|
|
127
|
+
Separable PINNs.
|
|
128
|
+
|
|
129
|
+
!!! warning "Warning"
|
|
130
|
+
|
|
131
|
+
This function is to be used in the context of SPINNs only.
|
|
110
132
|
"""
|
|
111
133
|
|
|
112
134
|
def scan_fun(_, i):
|
|
@@ -134,22 +156,30 @@ def _laplacian_fwd(t, x, u, params):
|
|
|
134
156
|
return jnp.sum(trace_hessian, axis=0)
|
|
135
157
|
|
|
136
158
|
|
|
137
|
-
def _vectorial_laplacian(
|
|
159
|
+
def _vectorial_laplacian(
|
|
160
|
+
t: Float[Array, "1"] | Float[Array, "batch_size 1"],
|
|
161
|
+
x: Float[Array, "dimension_in"] | Float[Array, "batch_size dimension"],
|
|
162
|
+
u: eqx.Module,
|
|
163
|
+
params: Params,
|
|
164
|
+
u_vec_ndim: int = None,
|
|
165
|
+
) -> (
|
|
166
|
+
Float[Array, "dimension_out"] | Float[Array, "batch_size batch_size dimension_out"]
|
|
167
|
+
):
|
|
138
168
|
r"""
|
|
139
|
-
Compute the vectorial Laplacian of a vector field
|
|
140
|
-
|
|
141
|
-
to
|
|
142
|
-
|
|
143
|
-
\mathbf{u}(\mathbf{x})
|
|
169
|
+
Compute the vectorial Laplacian of a vector field $\mathbf{u}$ (from
|
|
170
|
+
$\mathbb{R}^d$
|
|
171
|
+
to $\mathbb{R}^n$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
|
|
172
|
+
$\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
|
|
173
|
+
\mathbf{u}(\mathbf{x})$.
|
|
144
174
|
|
|
145
175
|
**Note:** We need to provide `u_vec_ndim` the dimension of the vector
|
|
146
|
-
|
|
147
|
-
|
|
176
|
+
$\mathbf{u}(\mathbf{x})$ if it is different than that of
|
|
177
|
+
$\mathbf{x}$.
|
|
148
178
|
|
|
149
179
|
**Note:** `u` can be a SPINN, in this case, it corresponds to a vector
|
|
150
|
-
field from (from
|
|
151
|
-
|
|
152
|
-
Technically, the return is of dimension
|
|
180
|
+
field from (from $\mathbb{R}^{b\times d}$ to
|
|
181
|
+
$\mathbb{R}^{b\times b\times n}$) and forward mode AD is used.
|
|
182
|
+
Technically, the return is of dimension $n\times b \times b$.
|
|
153
183
|
"""
|
|
154
184
|
if u_vec_ndim is None:
|
|
155
185
|
u_vec_ndim = x.shape[0]
|
|
@@ -172,6 +202,8 @@ def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
|
|
|
172
202
|
u(t, x, params)[..., j], axis=-1
|
|
173
203
|
)
|
|
174
204
|
lap_on_j = _laplacian_fwd(t, x, uj, params)
|
|
205
|
+
else:
|
|
206
|
+
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
175
207
|
|
|
176
208
|
return _, lap_on_j
|
|
177
209
|
|
|
@@ -179,12 +211,14 @@ def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
|
|
|
179
211
|
return vec_lap
|
|
180
212
|
|
|
181
213
|
|
|
182
|
-
def _u_dot_nabla_times_u_rev(
|
|
214
|
+
def _u_dot_nabla_times_u_rev(
|
|
215
|
+
t: Float[Array, "1"], x: Float[Array, "2"], u: eqx.Module, params: Params
|
|
216
|
+
) -> Float[Array, "2"]:
|
|
183
217
|
r"""
|
|
184
|
-
Implement
|
|
185
|
-
|
|
186
|
-
dimension.
|
|
187
|
-
to
|
|
218
|
+
Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
|
|
219
|
+
$\mathbf{x}$ of arbitrary
|
|
220
|
+
dimension. $\mathbf{u}$ is a vector field from $\mathbb{R}^n$
|
|
221
|
+
to $\mathbb{R}^n$. **Currently for** `x.ndim=2` **only**.
|
|
188
222
|
The computation is done using backward AD.
|
|
189
223
|
We do not use loops but code explicitly the expression to avoid
|
|
190
224
|
computing twice some terms
|
|
@@ -224,7 +258,12 @@ def _u_dot_nabla_times_u_rev(t, x, u, params):
|
|
|
224
258
|
raise NotImplementedError("x.ndim must be 2")
|
|
225
259
|
|
|
226
260
|
|
|
227
|
-
def _u_dot_nabla_times_u_fwd(
|
|
261
|
+
def _u_dot_nabla_times_u_fwd(
|
|
262
|
+
t: Float[Array, "batch_size 1"],
|
|
263
|
+
x: Float[Array, "batch_size 2"],
|
|
264
|
+
u: eqx.Module,
|
|
265
|
+
params: Params,
|
|
266
|
+
) -> Float[Array, "batch_size batch_size 2"]:
|
|
228
267
|
r"""
|
|
229
268
|
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
|
|
230
269
|
:math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
|
|
@@ -264,36 +303,3 @@ def _u_dot_nabla_times_u_fwd(t, x, u, params):
|
|
|
264
303
|
axis=-1,
|
|
265
304
|
)
|
|
266
305
|
raise NotImplementedError("x.ndim must be 2")
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
def _sobolev(u, m, statio=True):
    r"""
    Build the pointwise Sobolev-type regularization of order :math:`m` of a
    scalar field :math:`u` (from :math:`\mathbb{R}^{d}` to :math:`\mathbb{R}`).

    The returned function evaluates, at a single point :math:`x`, the squared
    Frobenius norm of the derivative tensor of order :math:`m+1` of
    :math:`u`, i.e.
    :math:`\sum_{i_1,\dots,i_{m+1}}(\partial_{i_1}\cdots\partial_{i_{m+1}}u(x))^2`.
    Mixed partials are counted once per ordering of their indices.

    This is inspired by the regularization proposed in *Convergence and error
    analysis of PINNs*, Doumeche et al., 2023,
    https://arxiv.org/pdf/2305.01240.pdf. That regularization is
    :math:`\frac{1}{n_l}\sum_{l=1}^{n_l}\sum_{|\alpha|=1}^{m+1}
    ||\partial^{\alpha} u(x_l)||_2^2` with :math:`m\geq\max(d_1 // 2, K)`,
    where :math:`K` is the order of the differential operator.

    Note that this implementation only computes the highest order
    :math:`m+1` term. It also performs no averaging over points.

    Parameters
    ----------
    u
        If ``statio`` is True, a callable ``u(x, params)``; otherwise a
        callable ``u(t, x, params)``. Expected to be scalar-valued.
    m
        The regularization order; derivatives of order ``m + 1`` are used.
    statio
        Whether ``u`` is stationary (i.e. has no time argument).

    Returns
    -------
    A callable ``(x, params) -> scalar`` if ``statio``, otherwise
    ``(t, x, params) -> scalar``. In the non-stationary case, ``t`` and ``x``
    are concatenated and differentiated jointly. ``t`` is assumed to hold a
    single element, since ``tx[0:1]`` is used to split it back out.
    """

    def _derivative(f, order):
        # Derivative of order `order` of `f`. The first differentiation uses
        # reverse mode and all the following ones use forward mode
        # (forward-over-reverse).
        if order == 0:
            return f
        f = jax.jacrev(f)
        for _ in range(order - 1):
            f = jax.jacfwd(f)
        return f

    if statio:
        return lambda x, params: jnp.sum(
            _derivative(lambda x_: u(x_, params), m + 1)(x) ** 2
        )
    return lambda t, x, params: jnp.sum(
        _derivative(lambda tx: u(tx[0:1], tx[1:], params), m + 1)(
            jnp.concatenate([t, x], axis=0)
        )
        ** 2
    )
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Formalize the data structure for the derivative keys
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import fields
|
|
6
|
+
from typing import Literal
|
|
7
|
+
import jax
|
|
8
|
+
import equinox as eqx
|
|
9
|
+
|
|
10
|
+
from jinns.parameters._params import Params
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DerivativeKeysODE(eqx.Module):
    """
    Selects, for each ODE loss term, which parameter group(s) the term is
    differentiated with respect to.

    Each field takes one of "nn_params", "eq_params", "both" or None. Each
    field is keyword-only and defaults to "nn_params".
    """

    # Every field holds a plain string, so it is declared static=True. Static
    # fields are part of the PyTree structure rather than leaves, which keeps
    # them invisible to JAX transforms (jit, etc.).
    dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        static=True, kw_only=True, default="nn_params"
    )
    observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        static=True, kw_only=True, default="nn_params"
    )
    initial_condition: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        static=True, kw_only=True, default="nn_params"
    )
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DerivativeKeysPDEStatio(eqx.Module):
    """
    Selects, for each stationary PDE loss term, which parameter group(s) the
    term is differentiated with respect to.

    Each field takes one of "nn_params", "eq_params", "both" or None. Each
    field is keyword-only, static (plain strings are kept out of JAX
    transforms) and defaults to "nn_params".
    """

    dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        static=True, kw_only=True, default="nn_params"
    )
    observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        static=True, kw_only=True, default="nn_params"
    )
    boundary_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        static=True, kw_only=True, default="nn_params"
    )
    norm_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        static=True, kw_only=True, default="nn_params"
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
    """
    Derivative keys of a non-stationary PDE loss. These are the stationary
    keys plus `initial_condition`.
    """

    # Annotated with `| None` to be consistent with the inherited fields of
    # DerivativeKeysPDEStatio, which all accept None.
    initial_condition: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
        kw_only=True, default="nn_params", static=True
    )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _set_derivatives(params, derivative_keys):
    """
    Return `params` with `jax.lax.stop_gradient` applied to every `Params`
    field that is not selected by `derivative_keys`.

    Parameters
    ----------
    params
        Either a single parameters module (anything exposing the fields of
        `Params`, e.g. `Params` or `ParamsDict`) or a plain `dict` whose
        values are such modules.
    derivative_keys
        The derivative setting of one loss term, presumably one field of a
        DerivativeKeys* module. The string "both" leaves `params` untouched.
        Any other value is compared against the field names of `Params`
        ("nn_params", "eq_params"), and every field whose name differs from
        it gets a stop_gradient. In particular, None blocks the gradient on
        all fields.

    Returns
    -------
    The parameters (or dict of parameters, with the same keys) with gradients
    blocked on the non-selected fields.
    """

    def _stop_gradient_except(p):
        # Put a stop_gradient around the fields that do not appear in
        # derivative_keys. Currently only nn_params and eq_params exist, but
        # iterating over fields(Params) keeps this valid if more are added.
        return eqx.tree_at(
            lambda q: tuple(
                getattr(q, f.name)
                for f in fields(Params)
                if f.name != derivative_keys
            ),
            p,
            replace_fn=jax.lax.stop_gradient,
        )

    if derivative_keys == "both":
        return params
    if not isinstance(params, dict):
        return _stop_gradient_except(params)
    # Iterating a dict directly yields only its keys, so `.items()` is
    # required to get (key, value) pairs.
    return {k: _stop_gradient_except(params_) for k, params_ in params.items()}
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Formalize the data structure for the parameters
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import equinox as eqx
|
|
7
|
+
from typing import Dict
|
|
8
|
+
from jaxtyping import Array, PyTree
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Params(eqx.Module):
    """
    Container holding the parameters of a single problem.

    Parameters
    ----------
    nn_params : Pytree
        The non-static part of the PINN eqx.Module, i.e. the parameters of
        the PINN, as a PyTree.
    eq_params : Dict[str, Array]
        The equation parameters, keyed by parameter name.

    Both fields are keyword-only and default to None.
    """

    nn_params: PyTree = eqx.field(default=None, kw_only=True)
    eq_params: Dict[str, Array] = eqx.field(default=None, kw_only=True)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ParamsDict(eqx.Module):
    """
    The equinox module for a dictionary of parameters with different keys
    corresponding to different equations.

    Parameters
    ----------
    nn_params : Dict[str, PyTree]
        The neural networks' parameters, one entry per key. Most of the time,
        each entry will be the Array part of an `eqx.Module` obtained by
        `eqx.partition(module, eqx.is_inexact_array)`.
    eq_params : Dict[str, Array]
        A dictionary of the equation parameters. Dict keys are the parameter
        names as defined in your custom loss. It may also be keyed like
        `nn_params`, in which case `extract_params` picks the matching entry.
    """

    nn_params: Dict[str, PyTree] = eqx.field(kw_only=True, default=None)
    eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)

    def extract_params(self, nn_key: str) -> Params:
        """
        Extract the `nn_params` and `eq_params` corresponding to `nn_key` and
        return them in the form of a `Params` object.

        `nn_params` must contain `nn_key`; a missing key raises the lookup
        error (e.g. `KeyError`). `eq_params` is first indexed with `nn_key`.
        If that lookup fails with `KeyError`/`IndexError`, or `eq_params` is
        None, the whole `eq_params` is used (presumably shared by every
        network).
        """
        # Looked up outside any try block so that a missing network key
        # surfaces directly instead of from inside an exception handler.
        nn_params = self.nn_params[nn_key]
        if self.eq_params is None:
            # None is the field default; indexing it would raise TypeError.
            return Params(nn_params=nn_params, eq_params=None)
        try:
            eq_params = self.eq_params[nn_key]
        except (KeyError, IndexError):
            # eq_params is not organized per network key: use it as a whole
            eq_params = self.eq_params
        return Params(nn_params=nn_params, eq_params=eq_params)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _update_eq_params_dict(
    params: Params, param_batch_dict: Dict[str, Array]
) -> Params:
    """
    Return a copy of `params` whose `params.eq_params` entries are replaced
    by the corresponding entries of `param_batch_dict`, a batch of equation
    parameters for one or more keys.

    Keys of `params.eq_params` that are absent from `param_batch_dict`, or
    mapped to None in it, keep their current value. Every key of
    `param_batch_dict` is expected to exist in `params.eq_params`, since
    `jax.tree_util.tree_map` requires matching structures.
    """
    # Pad the batch dict with None for every equation parameter that is not
    # batched, so that both dicts share the same PyTree structure.
    missing_keys = set(params.eq_params.keys()) - set(param_batch_dict.keys())
    padded_batch = {**param_batch_dict, **dict.fromkeys(missing_keys)}

    def _prefer_batch(current_value, batch_value):
        # None marks "no batch for this key": keep the current value
        return current_value if batch_value is None else batch_value

    new_eq_params = jax.tree_util.tree_map(
        _prefer_batch, params.eq_params, padded_batch
    )
    return eqx.tree_at(lambda p: p.eq_params, params, new_eq_params)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _get_vmap_in_axes_params(
    eq_params_batch_dict: Dict[str, Array] | None, params: Params | ParamsDict
) -> tuple[Params | ParamsDict | None]:
    """
    Return the input vmap axes when there is batch(es) of parameters to vmap
    over. The latter are designated by the keys of `eq_params_batch_dict`.

    Parameters
    ----------
    eq_params_batch_dict
        The batched equation parameters; only its keys are used. If None
        (i.e. no additional parameter batch), ``(None,)`` is returned.
    params
        The parameters whose structure the returned axes mirror.

    Returns
    -------
    A 1-tuple holding either None, or an object of the same type as `params`
    with `nn_params=None` and `eq_params` mapping every key of
    `params.eq_params` to 0 (batched) or None (not batched).
    """
    if eq_params_batch_dict is None:
        return (None,)
    # We use pytree indexing of vmapped axes and vmap on axis 0 of the
    # eq_parameters for which we have a batch. This gives a fine-grained
    # vmapping scheme over the params.
    return (
        type(params)(
            nn_params=None,
            eq_params={
                k: 0 if k in eq_params_batch_dict else None
                for k in params.eq_params
            },
        ),
    )
|