jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/utils/_spinn.py
CHANGED
|
@@ -3,54 +3,53 @@ Implements utility function to create Separable PINNs
|
|
|
3
3
|
https://arxiv.org/abs/2211.08761
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
from dataclasses import InitVar
|
|
7
|
+
from typing import Callable, Literal
|
|
6
8
|
import jax
|
|
7
9
|
import jax.numpy as jnp
|
|
8
10
|
import equinox as eqx
|
|
11
|
+
from jaxtyping import Key, Array, Float, PyTree
|
|
9
12
|
|
|
10
13
|
|
|
11
14
|
class _SPINN(eqx.Module):
|
|
12
15
|
"""
|
|
13
16
|
Construct a Separable PINN as proposed in
|
|
14
17
|
Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
|
|
18
|
+
|
|
19
|
+
Parameters
|
|
20
|
+
----------
|
|
21
|
+
key : InitVar[Key]
|
|
22
|
+
A jax random key for the layer initializations.
|
|
23
|
+
d : int
|
|
24
|
+
The number of dimensions to treat separately.
|
|
25
|
+
eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
|
|
26
|
+
A tuple of tuples of successive equinox modules and activation functions to
|
|
27
|
+
describe the PINN architecture. The inner tuples must have the eqx module or
|
|
28
|
+
activation function as first item, other items represent arguments
|
|
29
|
+
that could be required (eg. the size of the layer).
|
|
30
|
+
The `key` argument need not be given.
|
|
31
|
+
Thus a typical example is `eqx_list=
|
|
32
|
+
((eqx.nn.Linear, 2, 20),
|
|
33
|
+
jax.nn.tanh,
|
|
34
|
+
(eqx.nn.Linear, 20, 20),
|
|
35
|
+
jax.nn.tanh,
|
|
36
|
+
(eqx.nn.Linear, 20, 20),
|
|
37
|
+
jax.nn.tanh,
|
|
38
|
+
(eqx.nn.Linear, 20, 1)
|
|
39
|
+
)`.
|
|
15
40
|
"""
|
|
16
41
|
|
|
17
|
-
|
|
18
|
-
separated_mlp: list
|
|
19
|
-
d: int
|
|
20
|
-
r: int
|
|
21
|
-
m: int
|
|
42
|
+
d: int = eqx.field(static=True, kw_only=True)
|
|
22
43
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
An integer. The number of dimensions to treat separately
|
|
31
|
-
r
|
|
32
|
-
An integer. The dimension of the embedding
|
|
33
|
-
eqx_list
|
|
34
|
-
A list of list of successive equinox modules and activation functions to
|
|
35
|
-
describe *each separable PINN architecture*.
|
|
36
|
-
The inner lists have the eqx module or
|
|
37
|
-
axtivation function as first item, other items represents arguments
|
|
38
|
-
that could be required (eg. the size of the layer).
|
|
39
|
-
__Note:__ the `key` argument need not be given.
|
|
40
|
-
Thus typical example is `eqx_list=
|
|
41
|
-
[[eqx.nn.Linear, 1, 20],
|
|
42
|
-
[jax.nn.tanh],
|
|
43
|
-
[eqx.nn.Linear, 20, 20],
|
|
44
|
-
[jax.nn.tanh],
|
|
45
|
-
[eqx.nn.Linear, 20, 20],
|
|
46
|
-
[jax.nn.tanh],
|
|
47
|
-
[eqx.nn.Linear, 20, r]
|
|
48
|
-
]`
|
|
49
|
-
"""
|
|
50
|
-
self.d = d
|
|
51
|
-
self.r = r
|
|
52
|
-
self.m = m
|
|
44
|
+
key: InitVar[Key] = eqx.field(kw_only=True)
|
|
45
|
+
eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
|
|
46
|
+
kw_only=True
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
layers: list = eqx.field(init=False)
|
|
50
|
+
separated_mlp: list = eqx.field(init=False)
|
|
53
51
|
|
|
52
|
+
def __post_init__(self, key, eqx_list):
|
|
54
53
|
self.separated_mlp = []
|
|
55
54
|
for _ in range(self.d):
|
|
56
55
|
self.layers = []
|
|
@@ -62,7 +61,9 @@ class _SPINN(eqx.Module):
|
|
|
62
61
|
self.layers.append(l[0](*l[1:], key=subkey))
|
|
63
62
|
self.separated_mlp.append(self.layers)
|
|
64
63
|
|
|
65
|
-
def __call__(
|
|
64
|
+
def __call__(
|
|
65
|
+
self, t: Float[Array, "1"], x: Float[Array, "omega_dim"]
|
|
66
|
+
) -> Float[Array, "d embed_dim*output_dim"]:
|
|
66
67
|
if t is not None:
|
|
67
68
|
dimensions = jnp.concatenate([t, x.flatten()], axis=0)
|
|
68
69
|
else:
|
|
@@ -78,53 +79,67 @@ class _SPINN(eqx.Module):
|
|
|
78
79
|
|
|
79
80
|
class SPINN(eqx.Module):
|
|
80
81
|
"""
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
The function create_SPINN has the role to population the `__call__` function
|
|
82
|
+
A SPINN object compatible with the rest of jinns.
|
|
83
|
+
This is typically created with `create_SPINN`.
|
|
84
84
|
|
|
85
85
|
**NOTE**: SPINNs with `t` and `x` as inputs are best used with a
|
|
86
86
|
DataGenerator with `self.cartesian_product=False` for memory consideration
|
|
87
|
+
|
|
88
|
+
Parameters
|
|
89
|
+
----------
|
|
90
|
+
d : int
|
|
91
|
+
The number of dimensions to treat separately.
|
|
92
|
+
|
|
87
93
|
"""
|
|
88
94
|
|
|
89
|
-
d: int
|
|
90
|
-
r: int
|
|
91
|
-
eq_type: str = eqx.field(static=True)
|
|
92
|
-
m: int
|
|
93
|
-
|
|
94
|
-
|
|
95
|
+
d: int = eqx.field(static=True, kw_only=True)
|
|
96
|
+
r: int = eqx.field(static=True, kw_only=True)
|
|
97
|
+
eq_type: str = eqx.field(static=True, kw_only=True)
|
|
98
|
+
m: int = eqx.field(static=True, kw_only=True)
|
|
99
|
+
|
|
100
|
+
spinn_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
101
|
+
|
|
102
|
+
params: PyTree = eqx.field(init=False)
|
|
103
|
+
static: PyTree = eqx.field(init=False, static=True)
|
|
95
104
|
|
|
96
|
-
def
|
|
97
|
-
self.d, self.r, self.m = d, r, m
|
|
105
|
+
def __post_init__(self, spinn_mlp):
|
|
98
106
|
self.params, self.static = eqx.partition(spinn_mlp, eqx.is_inexact_array)
|
|
99
|
-
self.eq_type = eq_type
|
|
100
107
|
|
|
101
|
-
def init_params(self):
|
|
108
|
+
def init_params(self) -> PyTree:
|
|
109
|
+
"""
|
|
110
|
+
Returns an initial set of parameters
|
|
111
|
+
"""
|
|
102
112
|
return self.params
|
|
103
113
|
|
|
104
|
-
def __call__(self, *args):
|
|
114
|
+
def __call__(self, *args) -> Float[Array, "output_dim"]:
|
|
115
|
+
"""
|
|
116
|
+
Calls `eval_nn` with rearranged arguments
|
|
117
|
+
"""
|
|
105
118
|
if self.eq_type == "statio_PDE":
|
|
106
119
|
(x, params) = args
|
|
107
120
|
try:
|
|
108
|
-
spinn = eqx.combine(params
|
|
109
|
-
except (KeyError, TypeError) as e:
|
|
121
|
+
spinn = eqx.combine(params.nn_params, self.static)
|
|
122
|
+
except (KeyError, AttributeError, TypeError) as e:
|
|
110
123
|
spinn = eqx.combine(params, self.static)
|
|
111
124
|
v_model = jax.vmap(spinn, (0))
|
|
112
125
|
res = v_model(t=None, x=x)
|
|
113
|
-
return self.
|
|
126
|
+
return self.eval_nn(res)
|
|
114
127
|
if self.eq_type == "nonstatio_PDE":
|
|
115
128
|
(t, x, params) = args
|
|
116
129
|
try:
|
|
117
|
-
spinn = eqx.combine(params
|
|
118
|
-
except (KeyError, TypeError) as e:
|
|
130
|
+
spinn = eqx.combine(params.nn_params, self.static)
|
|
131
|
+
except (KeyError, AttributeError, TypeError) as e:
|
|
119
132
|
spinn = eqx.combine(params, self.static)
|
|
120
133
|
v_model = jax.vmap(spinn, ((0, 0)))
|
|
121
134
|
res = v_model(t, x)
|
|
122
|
-
return self.
|
|
135
|
+
return self.eval_nn(res)
|
|
123
136
|
raise RuntimeError("Wrong parameter value for eq_type")
|
|
124
137
|
|
|
125
|
-
def
|
|
138
|
+
def eval_nn(
|
|
139
|
+
self, res: Float[Array, "d embed_dim*output_dim"]
|
|
140
|
+
) -> Float[Array, "output_dim"]:
|
|
126
141
|
"""
|
|
127
|
-
|
|
142
|
+
Evaluate the SPINN on some inputs with some params.
|
|
128
143
|
"""
|
|
129
144
|
a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
|
|
130
145
|
b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
|
|
@@ -148,59 +163,71 @@ class SPINN(eqx.Module):
|
|
|
148
163
|
return res
|
|
149
164
|
|
|
150
165
|
|
|
151
|
-
def create_SPINN(
|
|
166
|
+
def create_SPINN(
|
|
167
|
+
key: Key,
|
|
168
|
+
d: int,
|
|
169
|
+
r: int,
|
|
170
|
+
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
|
|
171
|
+
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
172
|
+
m: int = 1,
|
|
173
|
+
) -> SPINN:
|
|
152
174
|
"""
|
|
153
175
|
Utility function to create a SPINN neural network with the equinox
|
|
154
176
|
library.
|
|
155
177
|
|
|
156
|
-
*Note* that a SPINN is not vmapped
|
|
157
|
-
same size for each input. It outputs
|
|
158
|
-
(batchsize, batchsize)
|
|
178
|
+
*Note* that a SPINN is not vmapped and expects the
|
|
179
|
+
same batch size for each of its input axes. It directly outputs a solution
|
|
180
|
+
of shape `(batchsize, batchsize)`. See the paper for more details.
|
|
159
181
|
|
|
160
182
|
Parameters
|
|
161
183
|
----------
|
|
162
184
|
key
|
|
163
|
-
A
|
|
185
|
+
A JAX random key that will be used to initialize the network parameters
|
|
164
186
|
d
|
|
165
|
-
|
|
187
|
+
The number of dimensions to treat separately.
|
|
166
188
|
r
|
|
167
|
-
An integer. The dimension of the embedding
|
|
189
|
+
An integer. The dimension of the embedding.
|
|
168
190
|
eqx_list
|
|
169
|
-
A
|
|
170
|
-
describe
|
|
171
|
-
|
|
172
|
-
axtivation function as first item, other items represents arguments
|
|
191
|
+
A tuple of tuples of successive equinox modules and activation functions to
|
|
192
|
+
describe the PINN architecture. The inner tuples must have the eqx module or
|
|
193
|
+
activation function as first item, other items represent arguments
|
|
173
194
|
that could be required (eg. the size of the layer).
|
|
174
|
-
|
|
175
|
-
Thus typical example is
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
eq_type
|
|
195
|
+
The `key` argument need not be given.
|
|
196
|
+
Thus a typical example is
|
|
197
|
+
`eqx_list=((eqx.nn.Linear, 2, 20),
|
|
198
|
+
jax.nn.tanh,
|
|
199
|
+
(eqx.nn.Linear, 20, 20),
|
|
200
|
+
jax.nn.tanh,
|
|
201
|
+
(eqx.nn.Linear, 20, 20),
|
|
202
|
+
jax.nn.tanh,
|
|
203
|
+
(eqx.nn.Linear, 20, 1)
|
|
204
|
+
)`.
|
|
205
|
+
eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
|
|
185
206
|
A string with three possibilities.
|
|
186
207
|
"ODE": the PINN is called with one input `t`.
|
|
187
208
|
"statio_PDE": the PINN is called with one input `x`, `x`
|
|
188
209
|
can be high dimensional.
|
|
189
210
|
"nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
|
|
190
211
|
can be high dimensional.
|
|
212
|
+
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
213
|
+
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
214
|
+
after the `input_transform` function.
|
|
191
215
|
m
|
|
192
|
-
|
|
216
|
+
The output dimension of the neural network. According to
|
|
193
217
|
the SPINN article, a total embedding dimension of `r*m` is defined. We
|
|
194
218
|
then sum groups of `r` embedding dimensions to compute each output.
|
|
195
219
|
Default is 1.
|
|
196
220
|
|
|
197
|
-
|
|
198
|
-
|
|
221
|
+
!!! note
|
|
222
|
+
SPINNs with `t` and `x` as inputs are best used with a
|
|
223
|
+
DataGenerator with `self.cartesian_product=False` for memory
|
|
224
|
+
consideration
|
|
199
225
|
|
|
200
226
|
|
|
201
227
|
Returns
|
|
202
228
|
-------
|
|
203
|
-
|
|
229
|
+
spinn
|
|
230
|
+
An instantiated SPINN
|
|
204
231
|
|
|
205
232
|
Raises
|
|
206
233
|
------
|
|
@@ -235,7 +262,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
|
|
|
235
262
|
"Too many dimensions, not enough letters available in jnp.einsum"
|
|
236
263
|
)
|
|
237
264
|
|
|
238
|
-
spinn_mlp = _SPINN(key, d,
|
|
239
|
-
spinn = SPINN(spinn_mlp, d, r, eq_type, m)
|
|
265
|
+
spinn_mlp = _SPINN(key=key, d=d, eqx_list=eqx_list)
|
|
266
|
+
spinn = SPINN(spinn_mlp=spinn_mlp, d=d, r=r, eq_type=eq_type, m=m)
|
|
240
267
|
|
|
241
268
|
return spinn
|
jinns/utils/_types.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import (
|
|
2
|
+
annotations,
|
|
3
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
4
|
+
|
|
5
|
+
from typing import TypeAlias, TYPE_CHECKING, NewType
|
|
6
|
+
from jaxtyping import Int
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from jinns.loss._LossPDE import (
|
|
10
|
+
LossPDEStatio,
|
|
11
|
+
LossPDENonStatio,
|
|
12
|
+
SystemLossPDE,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
16
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
17
|
+
from jinns.data._DataGenerators import (
|
|
18
|
+
DataGeneratorODE,
|
|
19
|
+
CubicMeshPDEStatio,
|
|
20
|
+
CubicMeshPDENonStatio,
|
|
21
|
+
DataGeneratorObservations,
|
|
22
|
+
DataGeneratorParameter,
|
|
23
|
+
DataGeneratorObservationsMultiPINNs,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from jinns.loss import DynamicLoss
|
|
27
|
+
from jinns.data._Batchs import *
|
|
28
|
+
from jinns.utils._pinn import PINN
|
|
29
|
+
from jinns.utils._hyperpinn import HYPERPINN
|
|
30
|
+
from jinns.utils._spinn import SPINN
|
|
31
|
+
from jinns.utils._containers import *
|
|
32
|
+
from jinns.validation._validation import AbstractValidationModule
|
|
33
|
+
|
|
34
|
+
AnyLoss: TypeAlias = (
|
|
35
|
+
LossPDEStatio | LossPDENonStatio | SystemLossPDE | LossODE | SystemLossODE
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
AnyParams: TypeAlias = Params | ParamsDict
|
|
39
|
+
|
|
40
|
+
AnyDataGenerator: TypeAlias = (
|
|
41
|
+
DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
AnyPINN: TypeAlias = PINN | HYPERPINN | SPINN
|
|
45
|
+
|
|
46
|
+
AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
|
|
47
|
+
rar_operands = NewType(
|
|
48
|
+
"rar_operands", tuple[AnyLoss, AnyParams, AnyDataGenerator, Int]
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
main_carry = NewType(
|
|
52
|
+
"main_carry",
|
|
53
|
+
tuple[
|
|
54
|
+
Int,
|
|
55
|
+
AnyLoss,
|
|
56
|
+
OptimizationContainer,
|
|
57
|
+
OptimizationExtraContainer,
|
|
58
|
+
DataGeneratorContainer,
|
|
59
|
+
AbstractValidationModule,
|
|
60
|
+
LossContainer,
|
|
61
|
+
StoredObjectContainer,
|
|
62
|
+
Float[Array, "n_iter"],
|
|
63
|
+
],
|
|
64
|
+
)
|
jinns/utils/_utils.py
CHANGED
|
@@ -7,9 +7,10 @@ from operator import getitem
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import jax
|
|
9
9
|
import jax.numpy as jnp
|
|
10
|
+
from jaxtyping import PyTree, Array
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
def _check_nan_in_pytree(pytree):
|
|
13
|
+
def _check_nan_in_pytree(pytree: PyTree) -> bool:
|
|
13
14
|
"""
|
|
14
15
|
Check if there is a NaN value anywhere in the pytree
|
|
15
16
|
|
|
@@ -25,40 +26,14 @@ def _check_nan_in_pytree(pytree):
|
|
|
25
26
|
"""
|
|
26
27
|
return jnp.any(
|
|
27
28
|
jnp.array(
|
|
28
|
-
|
|
29
|
-
jax.tree_util.
|
|
30
|
-
jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
|
|
31
|
-
)
|
|
29
|
+
jax.tree_util.tree_leaves(
|
|
30
|
+
jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
|
|
32
31
|
)
|
|
33
32
|
)
|
|
34
33
|
)
|
|
35
34
|
|
|
36
35
|
|
|
37
|
-
def
|
|
38
|
-
"""
|
|
39
|
-
Returns a pytree with the same structure as params with True is the
|
|
40
|
-
parameter is tracked False otherwise
|
|
41
|
-
"""
|
|
42
|
-
|
|
43
|
-
def set_nested_item(dataDict, mapList, val):
|
|
44
|
-
"""
|
|
45
|
-
Set item in nested dictionary
|
|
46
|
-
https://stackoverflow.com/questions/54137991/how-to-update-values-in-nested-dictionary-if-keys-are-in-a-list
|
|
47
|
-
"""
|
|
48
|
-
reduce(getitem, mapList[:-1], dataDict)[mapList[-1]] = val
|
|
49
|
-
return dataDict
|
|
50
|
-
|
|
51
|
-
tracked_params = jax.tree_util.tree_map(
|
|
52
|
-
lambda x: False, params
|
|
53
|
-
) # init with all False
|
|
54
|
-
|
|
55
|
-
for key_list in tracked_params_key_list:
|
|
56
|
-
tracked_params = set_nested_item(tracked_params, key_list, True)
|
|
57
|
-
|
|
58
|
-
return tracked_params
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def _get_grid(in_array):
|
|
36
|
+
def _get_grid(in_array: Array) -> Array:
|
|
62
37
|
"""
|
|
63
38
|
From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
|
|
64
39
|
shape (B, B, ...(D times)..., B, D): along the last axis we have the array
|
|
@@ -74,31 +49,7 @@ def _get_grid(in_array):
|
|
|
74
49
|
return in_array
|
|
75
50
|
|
|
76
51
|
|
|
77
|
-
def
|
|
78
|
-
"""
|
|
79
|
-
Return the input vmap axes when there is batch(es) of parameters to vmap
|
|
80
|
-
over. The latter are designated by keys in eq_params_batch_dict
|
|
81
|
-
If eq_params_batch_dict (ie no additional parameter batch), we return None
|
|
82
|
-
"""
|
|
83
|
-
if eq_params_batch_dict is None:
|
|
84
|
-
return (None,)
|
|
85
|
-
# We use pytree indexing of vmapped axes and vmap on axis
|
|
86
|
-
# 0 of the eq_parameters for which we have a batch
|
|
87
|
-
# this is for a fine-grained vmaping
|
|
88
|
-
# scheme over the params
|
|
89
|
-
vmap_in_axes_params = (
|
|
90
|
-
{
|
|
91
|
-
"nn_params": None,
|
|
92
|
-
"eq_params": {
|
|
93
|
-
k: (0 if k in eq_params_batch_dict.keys() else None)
|
|
94
|
-
for k in params["eq_params"].keys()
|
|
95
|
-
},
|
|
96
|
-
},
|
|
97
|
-
)
|
|
98
|
-
return vmap_in_axes_params
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def _check_user_func_return(r, shape):
|
|
52
|
+
def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
|
|
102
53
|
"""
|
|
103
54
|
Correctly handles the result from a user defined function (eg a boundary
|
|
104
55
|
condition) to get the correct broadcast
|
|
@@ -115,108 +66,3 @@ def _check_user_func_return(r, shape):
|
|
|
115
66
|
# the reshape below avoids a missing (1,) ending dimension
|
|
116
67
|
# depending on how the user has coded the initial function
|
|
117
68
|
return r.reshape(shape)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
def _set_derivatives(params, loss_term, derivative_keys):
|
|
121
|
-
"""
|
|
122
|
-
Given derivative_keys, the parameters wrt which we want to compute
|
|
123
|
-
gradients in the loss, we set stop_gradient operators to not take the
|
|
124
|
-
derivatives with respect to the others. Note that we only operator at
|
|
125
|
-
top level
|
|
126
|
-
"""
|
|
127
|
-
try:
|
|
128
|
-
params = {
|
|
129
|
-
k: (
|
|
130
|
-
value
|
|
131
|
-
if k in derivative_keys[loss_term]
|
|
132
|
-
else jax.lax.stop_gradient(value)
|
|
133
|
-
)
|
|
134
|
-
for k, value in params.items()
|
|
135
|
-
}
|
|
136
|
-
except KeyError: # if the loss_term key has not been specified we
|
|
137
|
-
# only take gradients wrt "nn_params", all the other entries have
|
|
138
|
-
# stopped gradient
|
|
139
|
-
params = {
|
|
140
|
-
k: value if k in ["nn_params"] else jax.lax.stop_gradient(value)
|
|
141
|
-
for k, value in params.items()
|
|
142
|
-
}
|
|
143
|
-
|
|
144
|
-
return params
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def _extract_nn_params(params_dict, nn_key):
|
|
148
|
-
"""
|
|
149
|
-
Given a params_dict for system loss (ie "nn_params" and "eq_params" as main
|
|
150
|
-
keys which contain dicts for each PINN (the nn_keys)) we extract the
|
|
151
|
-
corresponding "nn_params" for `nn_key` and reform a dict with "nn_params"
|
|
152
|
-
as main key as expected by the PINN/SPINN apply_fn
|
|
153
|
-
"""
|
|
154
|
-
try:
|
|
155
|
-
return {
|
|
156
|
-
"nn_params": params_dict["nn_params"][nn_key],
|
|
157
|
-
"eq_params": params_dict["eq_params"][nn_key],
|
|
158
|
-
}
|
|
159
|
-
except (KeyError, IndexError) as e:
|
|
160
|
-
return {
|
|
161
|
-
"nn_params": params_dict["nn_params"][nn_key],
|
|
162
|
-
"eq_params": params_dict["eq_params"],
|
|
163
|
-
}
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
def euler_maruyama_density(t, x, s, y, params, Tmax=1):
|
|
167
|
-
eps = 1e-6
|
|
168
|
-
delta = jnp.abs(t - s) * Tmax
|
|
169
|
-
mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
|
|
170
|
-
var = params["sigma_sde"] ** 2 * delta
|
|
171
|
-
return (
|
|
172
|
-
1 / jnp.sqrt(2 * jnp.pi * var) * jnp.exp(-0.5 * ((x - y) - mu) ** 2 / var) + eps
|
|
173
|
-
)
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
def log_euler_maruyama_density(t, x, s, y, params):
|
|
177
|
-
eps = 1e-6
|
|
178
|
-
delta = jnp.abs(t - s)
|
|
179
|
-
mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
|
|
180
|
-
logvar = params["logvar_sde"]
|
|
181
|
-
return (
|
|
182
|
-
-0.5
|
|
183
|
-
* (jnp.log(2 * jnp.pi * delta) + logvar + ((x - y) - mu) ** 2 / jnp.exp(logvar))
|
|
184
|
-
+ eps
|
|
185
|
-
)
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
def euler_maruyama(x0, alpha, mu, sigma, T, N):
|
|
189
|
-
"""
|
|
190
|
-
Simulate 1D diffusion process with simple parametrization using the Euler
|
|
191
|
-
Maruyama method in the interval [0, T]
|
|
192
|
-
"""
|
|
193
|
-
path = [np.array([x0])]
|
|
194
|
-
|
|
195
|
-
time_steps, step_size = np.linspace(0, T, N, retstep=True)
|
|
196
|
-
for _ in time_steps[1:]:
|
|
197
|
-
path.append(
|
|
198
|
-
path[-1]
|
|
199
|
-
+ step_size * (alpha * (mu - path[-1]))
|
|
200
|
-
+ sigma * np.random.normal(loc=0.0, scale=np.sqrt(step_size))
|
|
201
|
-
)
|
|
202
|
-
|
|
203
|
-
return time_steps, np.stack(path)
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
def _update_eq_params_dict(params, param_batch_dict):
|
|
207
|
-
# update params["eq_params"] with a batch of eq_params
|
|
208
|
-
# we avoid side_effect by recreating the dict `params`
|
|
209
|
-
# TODO transform `params` in a NamedTuple to be able to use _replace
|
|
210
|
-
# see Issue #1
|
|
211
|
-
param_batch_dict_ = param_batch_dict | {
|
|
212
|
-
k: None for k in set(params["eq_params"].keys()) - set(param_batch_dict.keys())
|
|
213
|
-
}
|
|
214
|
-
params = {"nn_params": params["nn_params"]} | {
|
|
215
|
-
"eq_params": jax.tree_util.tree_map(
|
|
216
|
-
lambda p, q: q if q is not None else p,
|
|
217
|
-
params["eq_params"],
|
|
218
|
-
param_batch_dict_,
|
|
219
|
-
)
|
|
220
|
-
}
|
|
221
|
-
|
|
222
|
-
return params
|