jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Utility functions for plotting
|
|
2
|
+
Utility functions for plotting in 1D and 2D, with and without time.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from functools import partial
|
|
@@ -8,48 +8,55 @@ import matplotlib.pyplot as plt
|
|
|
8
8
|
import jax.numpy as jnp
|
|
9
9
|
from jax import vmap
|
|
10
10
|
from mpl_toolkits.axes_grid1 import ImageGrid
|
|
11
|
+
from typing import Callable, List
|
|
12
|
+
from jaxtyping import Array, Float, Bool
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
def plot2d(
|
|
14
|
-
fun,
|
|
15
|
-
xy_data,
|
|
16
|
-
times=None,
|
|
17
|
-
Tmax=1,
|
|
18
|
-
title="",
|
|
19
|
-
figsize=(7, 7),
|
|
20
|
-
cmap="inferno",
|
|
21
|
-
spinn=False,
|
|
22
|
-
vmin_vmax=None,
|
|
23
|
-
ax_for_plot=None,
|
|
16
|
+
fun: Callable,
|
|
17
|
+
xy_data: tuple[Float[Array, "nx"], Float[Array, "ny"]],
|
|
18
|
+
times: Float[Array, "nt"] | List[float] | None = None,
|
|
19
|
+
Tmax: float = 1,
|
|
20
|
+
title: str = "",
|
|
21
|
+
figsize: tuple = (7, 7),
|
|
22
|
+
cmap: str = "inferno",
|
|
23
|
+
spinn: bool = False,
|
|
24
|
+
vmin_vmax: tuple[float, float] | None = None,
|
|
25
|
+
ax_for_plot: plt.Axes | None = None,
|
|
24
26
|
):
|
|
25
27
|
r"""Generic function for plotting functions over rectangular 2-D domains
|
|
26
|
-
|
|
27
|
-
non-stationnary case :math:`u(t, x)`.
|
|
28
|
+
$\Omega$. It handles both the
|
|
28
29
|
|
|
29
|
-
|
|
30
|
-
|
|
30
|
+
1. the stationary case $u(x)$
|
|
31
|
+
2. the non-stationnary case $u(t, x)$
|
|
32
|
+
|
|
33
|
+
In the non-stationnary case, the `times` argument gives the time
|
|
34
|
+
slices $t_i$ at which to plot $u(t_i, x)$.
|
|
31
35
|
|
|
32
36
|
|
|
33
37
|
Parameters
|
|
34
38
|
----------
|
|
35
|
-
fun :
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
39
|
+
fun :
|
|
40
|
+
the function $u$ to plot on the meshgrid, and eventually the time
|
|
41
|
+
slices. It's suppose to have signature `u(x)` in the stationnary case,, and `u(t, x)` in the non-stationnary case. Use `partial` or `lambda to freeze / reorder any other arguments.
|
|
42
|
+
xy_data :
|
|
43
|
+
A list of 2 `jnp.Array` providing grid values for meshgrid creation
|
|
44
|
+
times :
|
|
45
|
+
list or Array of time slices where to plot the function. Use Tmax if
|
|
46
|
+
you trained with time-rescaling.
|
|
47
|
+
Tmax :
|
|
48
|
+
Useful if you used time rescaling in the differential equation for training, default to 1 (no rescaling).
|
|
49
|
+
title :
|
|
45
50
|
plot title, by default ""
|
|
46
|
-
figsize :
|
|
47
|
-
|
|
48
|
-
cmap :
|
|
49
|
-
|
|
50
|
-
vmin_vmax :
|
|
51
|
+
figsize :
|
|
52
|
+
By default (7, 7)
|
|
53
|
+
cmap :
|
|
54
|
+
the matplotlib color map used in the ImageGrid.
|
|
55
|
+
vmin_vmax :
|
|
51
56
|
The colorbar minimum and maximum value. Defaults None.
|
|
52
|
-
|
|
57
|
+
spinn :
|
|
58
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
59
|
+
ax_for_plot :
|
|
53
60
|
If None, jinns triggers the plotting. Otherwise this argument
|
|
54
61
|
corresponds to the axis which will host the plot. Default is None.
|
|
55
62
|
NOTE: that this argument will have an effect only if times is None.
|
|
@@ -57,7 +64,7 @@ def plot2d(
|
|
|
57
64
|
Raises
|
|
58
65
|
------
|
|
59
66
|
ValueError
|
|
60
|
-
if xy_data is not a list of
|
|
67
|
+
if xy_data is not a list of length 2
|
|
61
68
|
"""
|
|
62
69
|
|
|
63
70
|
# if not isinstance(xy_data, jnp.ndarray) and not xy_data.shape[-1] == 2:
|
|
@@ -179,13 +186,13 @@ def plot2d(
|
|
|
179
186
|
|
|
180
187
|
def _plot_2D_statio(
|
|
181
188
|
v_fun,
|
|
182
|
-
mesh,
|
|
183
|
-
plot=True,
|
|
184
|
-
colorbar=True,
|
|
185
|
-
cmap="inferno",
|
|
186
|
-
figsize=(7, 7),
|
|
187
|
-
spinn=False,
|
|
188
|
-
vmin_vmax=None,
|
|
189
|
+
mesh: Float[Array, "nx*ny nx*ny"],
|
|
190
|
+
plot: Bool = True,
|
|
191
|
+
colorbar: Bool = True,
|
|
192
|
+
cmap: str = "inferno",
|
|
193
|
+
figsize: tuple[int, int] = (7, 7),
|
|
194
|
+
spinn: Bool = False,
|
|
195
|
+
vmin_vmax: tuple[float, float] = None,
|
|
189
196
|
):
|
|
190
197
|
"""Function that plot the function u(x) with 2-D input x using pcolormesh()
|
|
191
198
|
|
|
@@ -200,6 +207,12 @@ def _plot_2D_statio(
|
|
|
200
207
|
either show or return the plot, by default True
|
|
201
208
|
colorbar : bool, optional
|
|
202
209
|
add a colorbar, by default True
|
|
210
|
+
cmap :
|
|
211
|
+
the matplotlib color map used in the ImageGrid.
|
|
212
|
+
figsize :
|
|
213
|
+
By default (7, 7)
|
|
214
|
+
spinn :
|
|
215
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
203
216
|
vmin_vmax: tuple, optional
|
|
204
217
|
The colorbar minimum and maximum value. Defaults None.
|
|
205
218
|
|
|
@@ -238,33 +251,37 @@ def _plot_2D_statio(
|
|
|
238
251
|
|
|
239
252
|
|
|
240
253
|
def plot1d_slice(
|
|
241
|
-
fun,
|
|
242
|
-
xdata,
|
|
243
|
-
time_slices
|
|
244
|
-
Tmax=1,
|
|
245
|
-
title="",
|
|
246
|
-
figsize=(10, 10),
|
|
247
|
-
spinn=False,
|
|
254
|
+
fun: Callable[[float, float], float],
|
|
255
|
+
xdata: Float[Array, "nx"],
|
|
256
|
+
time_slices: Float[Array, "nt"] | None = None,
|
|
257
|
+
Tmax: float = 1.0,
|
|
258
|
+
title: str = "",
|
|
259
|
+
figsize: tuple[int, int] = (10, 10),
|
|
260
|
+
spinn: Bool = False,
|
|
248
261
|
):
|
|
249
262
|
"""Function for plotting time slices of a function :math:`f(t_i, x)` where
|
|
250
|
-
`
|
|
263
|
+
`t_i` is time (1-D) and x is 1-D
|
|
251
264
|
|
|
252
265
|
Parameters
|
|
253
266
|
----------
|
|
254
|
-
fun
|
|
267
|
+
fun
|
|
255
268
|
f(t, x)
|
|
256
|
-
xdata
|
|
269
|
+
xdata
|
|
257
270
|
the discretization of space
|
|
258
|
-
time_slices
|
|
259
|
-
the time slices :math:`t_i` at which to plot
|
|
260
|
-
Tmax
|
|
271
|
+
time_slices
|
|
272
|
+
the time slices :math:`t_i` at which to plot.
|
|
273
|
+
Tmax
|
|
261
274
|
Useful if you used time re-scaling in the differential equation, by
|
|
262
275
|
default 1
|
|
263
|
-
title
|
|
276
|
+
title
|
|
264
277
|
title of the plot, by default ""
|
|
265
|
-
|
|
278
|
+
spinn :
|
|
279
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
280
|
+
figsize
|
|
266
281
|
size of the figure, by default (10, 10)
|
|
267
282
|
"""
|
|
283
|
+
if time_slices is None:
|
|
284
|
+
time_slices = jnp.array([0])
|
|
268
285
|
plt.figure(figsize=figsize)
|
|
269
286
|
for t in time_slices:
|
|
270
287
|
if not spinn:
|
|
@@ -284,16 +301,16 @@ def plot1d_slice(
|
|
|
284
301
|
|
|
285
302
|
|
|
286
303
|
def plot1d_image(
|
|
287
|
-
fun,
|
|
288
|
-
xdata,
|
|
289
|
-
times,
|
|
290
|
-
Tmax=1,
|
|
291
|
-
title="",
|
|
292
|
-
figsize=(10, 10),
|
|
293
|
-
colorbar=True,
|
|
294
|
-
cmap="inferno",
|
|
295
|
-
spinn=False,
|
|
296
|
-
vmin_vmax=None,
|
|
304
|
+
fun: Callable[[float, float], float],
|
|
305
|
+
xdata: Float[Array, "nx"],
|
|
306
|
+
times: Float[Array, "nt"],
|
|
307
|
+
Tmax: float = 1.0,
|
|
308
|
+
title: str = "",
|
|
309
|
+
figsize: tuple[int, int] = (10, 10),
|
|
310
|
+
colorbar: Bool = True,
|
|
311
|
+
cmap: str = "inferno",
|
|
312
|
+
spinn: Bool = False,
|
|
313
|
+
vmin_vmax: tuple[float, float] = None,
|
|
297
314
|
):
|
|
298
315
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
299
316
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -302,19 +319,25 @@ def plot1d_image(
|
|
|
302
319
|
|
|
303
320
|
Parameters
|
|
304
321
|
----------
|
|
305
|
-
fun :
|
|
306
|
-
the function to plot
|
|
307
|
-
xdata :
|
|
322
|
+
fun :
|
|
323
|
+
callable with two arguments t and x the function to plot
|
|
324
|
+
xdata :
|
|
308
325
|
the discretization of space
|
|
309
|
-
times :
|
|
326
|
+
times :
|
|
310
327
|
the discretization of time
|
|
311
|
-
Tmax :
|
|
312
|
-
|
|
313
|
-
title :
|
|
314
|
-
|
|
315
|
-
figsize :
|
|
316
|
-
|
|
317
|
-
|
|
328
|
+
Tmax :
|
|
329
|
+
by default 1
|
|
330
|
+
title :
|
|
331
|
+
by default ""
|
|
332
|
+
figsize :
|
|
333
|
+
by default (10, 10)
|
|
334
|
+
colorbar :
|
|
335
|
+
Whether to add a colobar
|
|
336
|
+
cmap :
|
|
337
|
+
the matplotlib color map used in the ImageGrid.
|
|
338
|
+
spinn :
|
|
339
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
340
|
+
vmin_vmax:
|
|
318
341
|
The colorbar minimum and maximum value. Defaults None.
|
|
319
342
|
"""
|
|
320
343
|
|
jinns/solver/_rar.py
CHANGED
|
@@ -1,21 +1,32 @@
|
|
|
1
|
+
from __future__ import (
|
|
2
|
+
annotations,
|
|
3
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Callable
|
|
6
|
+
from functools import partial
|
|
1
7
|
import jax
|
|
2
8
|
from jax import vmap
|
|
3
9
|
import jax.numpy as jnp
|
|
10
|
+
import equinox as eqx
|
|
11
|
+
from jaxtyping import Int, Bool
|
|
12
|
+
|
|
13
|
+
from jinns.data._Batchs import *
|
|
14
|
+
from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
15
|
+
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
4
16
|
from jinns.data._DataGenerators import (
|
|
5
17
|
DataGeneratorODE,
|
|
6
18
|
CubicMeshPDEStatio,
|
|
7
19
|
CubicMeshPDENonStatio,
|
|
8
20
|
)
|
|
9
|
-
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
10
|
-
from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
11
|
-
from jinns.loss._DynamicLossAbstract import PDEStatio
|
|
12
|
-
|
|
13
|
-
from functools import partial
|
|
14
21
|
from jinns.utils._hyperpinn import HYPERPINN
|
|
15
22
|
from jinns.utils._spinn import SPINN
|
|
16
23
|
|
|
17
24
|
|
|
18
|
-
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from jinns.utils._types import *
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
|
|
19
30
|
"""Utilility function with various check to ensure we can proceed with the rar_step.
|
|
20
31
|
Return True if yes, and False otherwise"""
|
|
21
32
|
|
|
@@ -30,13 +41,13 @@ def _proceed_to_rar(data, i):
|
|
|
30
41
|
# Memory allocation checks (depends on the type of DataGenerator)
|
|
31
42
|
# check if we still have room to append new collocation points in the
|
|
32
43
|
# allocated jnp.array (can concern `data.p_times` or `p_omega`)
|
|
33
|
-
if isinstance(data, DataGeneratorODE
|
|
44
|
+
if isinstance(data, (DataGeneratorODE, CubicMeshPDENonStatio)):
|
|
34
45
|
check_list.append(
|
|
35
46
|
data.rar_parameters["selected_sample_size_times"]
|
|
36
47
|
<= jnp.count_nonzero(data.p_times == 0),
|
|
37
48
|
)
|
|
38
49
|
|
|
39
|
-
if isinstance(data, CubicMeshPDEStatio
|
|
50
|
+
if isinstance(data, (CubicMeshPDEStatio, CubicMeshPDENonStatio)):
|
|
40
51
|
# for now the above check are redundants but there may be a time when
|
|
41
52
|
# we drop inheritence
|
|
42
53
|
check_list.append(
|
|
@@ -49,7 +60,14 @@ def _proceed_to_rar(data, i):
|
|
|
49
60
|
|
|
50
61
|
|
|
51
62
|
@partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
|
|
52
|
-
def trigger_rar(
|
|
63
|
+
def trigger_rar(
|
|
64
|
+
i: Int,
|
|
65
|
+
loss: AnyLoss,
|
|
66
|
+
params: AnyParams,
|
|
67
|
+
data: AnyDataGenerator,
|
|
68
|
+
_rar_step_true: Callable[[rar_operands], AnyDataGenerator],
|
|
69
|
+
_rar_step_false: Callable[[rar_operands], AnyDataGenerator],
|
|
70
|
+
) -> tuple[AnyLoss, AnyParams, AnyDataGenerator]:
|
|
53
71
|
|
|
54
72
|
if data.rar_parameters is None:
|
|
55
73
|
# do nothing.
|
|
@@ -65,7 +83,13 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
|
|
|
65
83
|
return loss, params, data
|
|
66
84
|
|
|
67
85
|
|
|
68
|
-
def init_rar(
|
|
86
|
+
def init_rar(
|
|
87
|
+
data: AnyDataGenerator,
|
|
88
|
+
) -> tuple[
|
|
89
|
+
AnyDataGenerator,
|
|
90
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
91
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
92
|
+
]:
|
|
69
93
|
"""
|
|
70
94
|
Separated from the main rar, because the initialization to get _true and
|
|
71
95
|
_false cannot be jit-ted.
|
|
@@ -100,13 +124,21 @@ def init_rar(data):
|
|
|
100
124
|
data.rar_parameters["sample_size_omega"],
|
|
101
125
|
data.rar_parameters["selected_sample_size_omega"],
|
|
102
126
|
)
|
|
127
|
+
else:
|
|
128
|
+
raise ValueError(f"Wrong type for data got {type(data)}")
|
|
103
129
|
|
|
104
|
-
data.
|
|
130
|
+
if isinstance(data, eqx.Module):
|
|
131
|
+
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
132
|
+
else:
|
|
133
|
+
data.rar_iter_from_last_sampling = 0
|
|
105
134
|
|
|
106
135
|
return data, _rar_step_true, _rar_step_false
|
|
107
136
|
|
|
108
137
|
|
|
109
|
-
def _rar_step_init(sample_size, selected_sample_size)
|
|
138
|
+
def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
139
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
140
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
141
|
+
]:
|
|
110
142
|
"""
|
|
111
143
|
This is a wrapper because the sampling size and
|
|
112
144
|
selected_sample_size, must be treated as static
|
|
@@ -116,11 +148,17 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
116
148
|
This is a kind of manual declaration of static argnums
|
|
117
149
|
"""
|
|
118
150
|
|
|
119
|
-
def rar_step_true(operands):
|
|
151
|
+
def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
|
|
120
152
|
loss, params, data, i = operands
|
|
121
153
|
|
|
122
154
|
if isinstance(data, DataGeneratorODE):
|
|
123
|
-
|
|
155
|
+
|
|
156
|
+
if isinstance(data, eqx.Module):
|
|
157
|
+
new_key, subkey = jax.random.split(data.key)
|
|
158
|
+
new_omega_samples = data.sample_in_time_domain(subkey, sample_size)
|
|
159
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
160
|
+
else:
|
|
161
|
+
new_omega_samples = data.sample_in_time_domain(sample_size)
|
|
124
162
|
|
|
125
163
|
# We can have different types of Loss
|
|
126
164
|
if isinstance(loss, LossODE):
|
|
@@ -162,17 +200,29 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
162
200
|
## add the new points in times
|
|
163
201
|
# start indices of update can be dynamic but the the shape (length)
|
|
164
202
|
# of the slice
|
|
165
|
-
|
|
203
|
+
new_times = jax.lax.dynamic_update_slice(
|
|
166
204
|
data.times,
|
|
167
205
|
higher_residual_points,
|
|
168
206
|
(data.nt_start + data.rar_iter_nb * selected_sample_size,),
|
|
169
207
|
)
|
|
170
208
|
|
|
209
|
+
if isinstance(data, eqx.Module):
|
|
210
|
+
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
211
|
+
else:
|
|
212
|
+
data.times = new_times
|
|
171
213
|
## rearrange probabilities so that the probabilities of the new
|
|
172
214
|
## points are non-zero
|
|
173
215
|
new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
|
|
174
216
|
# the next work because nt_start is static
|
|
175
|
-
|
|
217
|
+
new_p_times = data.p_times.at[: data.nt_start].set(new_proba)
|
|
218
|
+
if isinstance(data, eqx.Module):
|
|
219
|
+
data = eqx.tree_at(
|
|
220
|
+
lambda m: m.p_times,
|
|
221
|
+
data,
|
|
222
|
+
new_p_times,
|
|
223
|
+
)
|
|
224
|
+
else:
|
|
225
|
+
data.p_times = new_p_times
|
|
176
226
|
|
|
177
227
|
# the next requires a fori_loop because the range is dynamic
|
|
178
228
|
def update_slices(i, p):
|
|
@@ -182,16 +232,29 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
182
232
|
((data.nt_start + i * selected_sample_size),),
|
|
183
233
|
)
|
|
184
234
|
|
|
185
|
-
data.rar_iter_nb
|
|
186
|
-
|
|
187
|
-
data.p_times = jax.lax.fori_loop(
|
|
235
|
+
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
236
|
+
new_p_times = jax.lax.fori_loop(
|
|
188
237
|
0, data.rar_iter_nb, update_slices, data.p_times
|
|
189
238
|
)
|
|
239
|
+
if isinstance(data, eqx.Module):
|
|
240
|
+
data = eqx.tree_at(
|
|
241
|
+
lambda m: (m.rar_iter_nb, m.p_times),
|
|
242
|
+
data,
|
|
243
|
+
(new_rar_iter_nb, new_p_times),
|
|
244
|
+
)
|
|
245
|
+
else:
|
|
246
|
+
data.rar_iter_nb = new_rar_iter_nb
|
|
247
|
+
data.p_times = new_p_times
|
|
190
248
|
|
|
191
249
|
elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
|
|
192
250
|
data, CubicMeshPDENonStatio
|
|
193
251
|
):
|
|
194
|
-
|
|
252
|
+
if isinstance(data, eqx.Module):
|
|
253
|
+
new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
|
|
254
|
+
new_omega_samples = data.sample_in_omega_domain(subkeys, sample_size)
|
|
255
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
256
|
+
else:
|
|
257
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size)
|
|
195
258
|
|
|
196
259
|
# We can have different types of Loss
|
|
197
260
|
if isinstance(loss, LossPDEStatio):
|
|
@@ -209,7 +272,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
209
272
|
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
210
273
|
else:
|
|
211
274
|
mse_on_s = dyn_on_s**2
|
|
212
|
-
elif isinstance(loss,
|
|
275
|
+
elif isinstance(loss, SystemLossODE):
|
|
213
276
|
mse_on_s = 0
|
|
214
277
|
for i in loss.dynamic_loss_dict.keys():
|
|
215
278
|
# only the case LossPDEStatio here
|
|
@@ -237,17 +300,30 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
237
300
|
## add the new points in omega
|
|
238
301
|
# start indices of update can be dynamic but not the shape (length)
|
|
239
302
|
# of the slice
|
|
240
|
-
|
|
303
|
+
new_omega = jax.lax.dynamic_update_slice(
|
|
241
304
|
data.omega,
|
|
242
305
|
higher_residual_points,
|
|
243
306
|
(data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
|
|
244
307
|
)
|
|
245
308
|
|
|
309
|
+
if isinstance(data, eqx.Module):
|
|
310
|
+
data = eqx.tree_at(lambda m: m.omega, data, new_omega)
|
|
311
|
+
else:
|
|
312
|
+
data.omega = new_omega
|
|
313
|
+
|
|
246
314
|
## rearrange probabilities so that the probabilities of the new
|
|
247
315
|
## points are non-zero
|
|
248
316
|
new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
|
|
249
317
|
# the next work because n_start is static
|
|
250
|
-
|
|
318
|
+
new_p_omega = data.p_omega.at[: data.n_start].set(new_proba)
|
|
319
|
+
if isinstance(data, eqx.Module):
|
|
320
|
+
data = eqx.tree_at(
|
|
321
|
+
lambda m: m.p_omega,
|
|
322
|
+
data,
|
|
323
|
+
new_p_omega,
|
|
324
|
+
)
|
|
325
|
+
else:
|
|
326
|
+
data.p_omega = new_p_omega
|
|
251
327
|
|
|
252
328
|
# the next requires a fori_loop because the range is dynamic
|
|
253
329
|
def update_slices(i, p):
|
|
@@ -257,11 +333,19 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
257
333
|
((data.n_start + i * selected_sample_size),),
|
|
258
334
|
)
|
|
259
335
|
|
|
260
|
-
data.rar_iter_nb
|
|
261
|
-
|
|
262
|
-
data.p_omega = jax.lax.fori_loop(
|
|
336
|
+
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
337
|
+
new_p_omega = jax.lax.fori_loop(
|
|
263
338
|
0, data.rar_iter_nb, update_slices, data.p_omega
|
|
264
339
|
)
|
|
340
|
+
if isinstance(data, eqx.Module):
|
|
341
|
+
data = eqx.tree_at(
|
|
342
|
+
lambda m: (m.rar_iter_nb, m.p_omega),
|
|
343
|
+
data,
|
|
344
|
+
(new_rar_iter_nb, new_p_omega),
|
|
345
|
+
)
|
|
346
|
+
else:
|
|
347
|
+
data.rar_iter_nb = new_rar_iter_nb
|
|
348
|
+
data.p_omega = new_p_omega
|
|
265
349
|
|
|
266
350
|
elif isinstance(data, CubicMeshPDENonStatio):
|
|
267
351
|
if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
|
|
@@ -274,8 +358,19 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
274
358
|
)
|
|
275
359
|
sample_size_times, sample_size_omega = sample_size
|
|
276
360
|
|
|
277
|
-
|
|
278
|
-
|
|
361
|
+
if isinstance(data, eqx.Module):
|
|
362
|
+
new_key, subkey = jax.random.split(data.key)
|
|
363
|
+
new_times_samples = data.sample_in_time_domain(
|
|
364
|
+
subkey, sample_size_times
|
|
365
|
+
)
|
|
366
|
+
new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
|
|
367
|
+
new_omega_samples = data.sample_in_omega_domain(
|
|
368
|
+
subkeys, sample_size_omega
|
|
369
|
+
)
|
|
370
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
371
|
+
else:
|
|
372
|
+
new_times_samples = data.sample_in_time_domain(sample_size_times)
|
|
373
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
|
|
279
374
|
|
|
280
375
|
if not data.cartesian_product:
|
|
281
376
|
times = new_times_samples
|
|
@@ -342,14 +437,23 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
342
437
|
## add the new points in times
|
|
343
438
|
# start indices of update can be dynamic but not the shape (length)
|
|
344
439
|
# of the slice
|
|
345
|
-
|
|
440
|
+
new_times = jax.lax.dynamic_update_slice(
|
|
346
441
|
data.times,
|
|
347
442
|
higher_residual_points_times,
|
|
348
|
-
(
|
|
443
|
+
(
|
|
444
|
+
data.n_start
|
|
445
|
+
+ data.rar_iter_nb # NOTE typo here nt_start ?
|
|
446
|
+
* selected_sample_size_times,
|
|
447
|
+
),
|
|
349
448
|
)
|
|
350
449
|
|
|
450
|
+
if isinstance(data, eqx.Module):
|
|
451
|
+
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
452
|
+
else:
|
|
453
|
+
data.times = new_times
|
|
454
|
+
|
|
351
455
|
## add the new points in omega
|
|
352
|
-
|
|
456
|
+
new_omega = jax.lax.dynamic_update_slice(
|
|
353
457
|
data.omega,
|
|
354
458
|
higher_residual_points_omega,
|
|
355
459
|
(
|
|
@@ -358,19 +462,38 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
358
462
|
),
|
|
359
463
|
)
|
|
360
464
|
|
|
465
|
+
if isinstance(data, eqx.Module):
|
|
466
|
+
data = eqx.tree_at(lambda m: m.omega, data, new_omega)
|
|
467
|
+
else:
|
|
468
|
+
data.omega = new_omega
|
|
469
|
+
|
|
361
470
|
## rearrange probabilities so that the probabilities of the new
|
|
362
471
|
## points are non-zero
|
|
363
472
|
new_p_times = 1 / (
|
|
364
473
|
data.nt_start + data.rar_iter_nb * selected_sample_size_times
|
|
365
474
|
)
|
|
366
475
|
# the next work because nt_start is static
|
|
367
|
-
|
|
476
|
+
if isinstance(data, eqx.Module):
|
|
477
|
+
data = eqx.tree_at(
|
|
478
|
+
lambda m: m.p_times,
|
|
479
|
+
data,
|
|
480
|
+
data.p_times.at[: data.nt_start].set(new_p_times),
|
|
481
|
+
)
|
|
482
|
+
else:
|
|
483
|
+
data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
|
|
368
484
|
|
|
369
485
|
# same for p_omega (work because n_start is static)
|
|
370
486
|
new_p_omega = 1 / (
|
|
371
487
|
data.n_start + data.rar_iter_nb * selected_sample_size_omega
|
|
372
488
|
)
|
|
373
|
-
|
|
489
|
+
if isinstance(data, eqx.Module):
|
|
490
|
+
data = eqx.tree_at(
|
|
491
|
+
lambda m: m.p_omega,
|
|
492
|
+
data,
|
|
493
|
+
data.p_omega.at[: data.n_start].set(new_p_omega),
|
|
494
|
+
)
|
|
495
|
+
else:
|
|
496
|
+
data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
|
|
374
497
|
|
|
375
498
|
# the part of data.p_* after n_start requires a fori_loop because
|
|
376
499
|
# the range is dynamic
|
|
@@ -385,13 +508,13 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
385
508
|
|
|
386
509
|
return update_slices
|
|
387
510
|
|
|
388
|
-
data.rar_iter_nb
|
|
511
|
+
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
389
512
|
|
|
390
513
|
## update rest of p_times
|
|
391
514
|
update_slices_times = create_update_slices(
|
|
392
515
|
new_p_times, selected_sample_size_times
|
|
393
516
|
)
|
|
394
|
-
|
|
517
|
+
new_p_times = jax.lax.fori_loop(
|
|
395
518
|
0,
|
|
396
519
|
data.rar_iter_nb,
|
|
397
520
|
update_slices_times,
|
|
@@ -401,21 +524,34 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
401
524
|
update_slices_omega = create_update_slices(
|
|
402
525
|
new_p_omega, selected_sample_size_omega
|
|
403
526
|
)
|
|
404
|
-
|
|
527
|
+
new_p_omega = jax.lax.fori_loop(
|
|
405
528
|
0,
|
|
406
529
|
data.rar_iter_nb,
|
|
407
530
|
update_slices_omega,
|
|
408
531
|
data.p_omega,
|
|
409
532
|
)
|
|
533
|
+
if isinstance(data, eqx.Module):
|
|
534
|
+
data = eqx.tree_at(
|
|
535
|
+
lambda m: (m.rar_iter_nb, m.p_omega, m.p_times),
|
|
536
|
+
data,
|
|
537
|
+
(new_rar_iter_nb, new_p_omega, new_p_times),
|
|
538
|
+
)
|
|
539
|
+
else:
|
|
540
|
+
data.rar_iter_nb = new_rar_iter_nb
|
|
541
|
+
data.p_times = new_p_times
|
|
542
|
+
data.p_omega = new_p_omega
|
|
410
543
|
|
|
411
544
|
# update RAR parameters for all cases
|
|
412
|
-
data.
|
|
545
|
+
if isinstance(data, eqx.Module):
|
|
546
|
+
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
547
|
+
else:
|
|
548
|
+
data.rar_iter_from_last_sampling = 0
|
|
413
549
|
|
|
414
550
|
# NOTE must return data to be correctly updated because we cannot
|
|
415
551
|
# have side effects in this function that will be jitted
|
|
416
552
|
return data
|
|
417
553
|
|
|
418
|
-
def rar_step_false(operands):
|
|
554
|
+
def rar_step_false(operands: rar_operands) -> AnyDataGenerator:
|
|
419
555
|
_, _, data, i = operands
|
|
420
556
|
|
|
421
557
|
# Add 1 only if we are after the burn in period
|
|
@@ -425,7 +561,15 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
425
561
|
lambda: 1,
|
|
426
562
|
)
|
|
427
563
|
|
|
428
|
-
data.rar_iter_from_last_sampling
|
|
564
|
+
new_rar_iter_from_last_sampling = data.rar_iter_from_last_sampling + increment
|
|
565
|
+
if isinstance(data, eqx.Module):
|
|
566
|
+
data = eqx.tree_at(
|
|
567
|
+
lambda m: m.rar_iter_from_last_sampling,
|
|
568
|
+
data,
|
|
569
|
+
new_rar_iter_from_last_sampling,
|
|
570
|
+
)
|
|
571
|
+
else:
|
|
572
|
+
data.rar_iter_from_last_sampling = new_rar_iter_from_last_sampling
|
|
429
573
|
return data
|
|
430
574
|
|
|
431
575
|
return rar_step_true, rar_step_false
|