jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/plot/_plot.py
CHANGED
|
@@ -2,26 +2,24 @@
|
|
|
2
2
|
Utility functions for plotting in 1D and 2D, with and without time.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
import warnings
|
|
7
|
-
import matplotlib.pyplot as plt
|
|
5
|
+
from typing import Callable
|
|
8
6
|
import jax.numpy as jnp
|
|
9
7
|
from jax import vmap
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
10
9
|
from mpl_toolkits.axes_grid1 import ImageGrid
|
|
11
|
-
from
|
|
12
|
-
from jaxtyping import Array, Float, Bool
|
|
10
|
+
from jaxtyping import Array, Float
|
|
13
11
|
|
|
14
12
|
|
|
15
13
|
def plot2d(
|
|
16
14
|
fun: Callable,
|
|
17
|
-
xy_data: tuple[Float[Array, "nx"], Float[Array, "ny"]],
|
|
18
|
-
times: Float[Array, "nt"] |
|
|
15
|
+
xy_data: tuple[Float[Array, " nx"], Float[Array, " ny"]],
|
|
16
|
+
times: Float[Array, " nt"] | list[float] | None = None,
|
|
19
17
|
Tmax: float = 1,
|
|
20
18
|
title: str = "",
|
|
21
19
|
figsize: tuple = (7, 7),
|
|
22
20
|
cmap: str = "inferno",
|
|
23
21
|
spinn: bool = False,
|
|
24
|
-
vmin_vmax: tuple[float, float]
|
|
22
|
+
vmin_vmax: tuple[float | None, float | None] | None = None,
|
|
25
23
|
):
|
|
26
24
|
r"""Generic function for plotting functions over rectangular 2-D domains
|
|
27
25
|
$\Omega$. It handles both the
|
|
@@ -61,6 +59,8 @@ def plot2d(
|
|
|
61
59
|
ValueError
|
|
62
60
|
if xy_data is not a list of length 2
|
|
63
61
|
"""
|
|
62
|
+
if vmin_vmax is None:
|
|
63
|
+
vmin_vmax = (None, None)
|
|
64
64
|
|
|
65
65
|
# if not isinstance(xy_data, jnp.ndarray) and not xy_data.shape[-1] == 2:
|
|
66
66
|
if not isinstance(xy_data, list) and not len(xy_data) == 2:
|
|
@@ -100,7 +100,7 @@ def plot2d(
|
|
|
100
100
|
im = ax.pcolormesh(
|
|
101
101
|
mesh[0],
|
|
102
102
|
mesh[1],
|
|
103
|
-
ret
|
|
103
|
+
ret,
|
|
104
104
|
cmap=cmap,
|
|
105
105
|
vmin=vmin_vmax[0],
|
|
106
106
|
vmax=vmin_vmax[1],
|
|
@@ -113,7 +113,7 @@ def plot2d(
|
|
|
113
113
|
if not isinstance(times, list):
|
|
114
114
|
try:
|
|
115
115
|
times = times.tolist()
|
|
116
|
-
except:
|
|
116
|
+
except AttributeError:
|
|
117
117
|
raise ValueError("times must be a list or an array")
|
|
118
118
|
|
|
119
119
|
fig = plt.figure(figsize=figsize)
|
|
@@ -129,7 +129,7 @@ def plot2d(
|
|
|
129
129
|
cbar_pad=0.4,
|
|
130
130
|
)
|
|
131
131
|
|
|
132
|
-
for idx, (t, ax) in enumerate(zip(times, grid)):
|
|
132
|
+
for idx, (t, ax) in enumerate(zip(times, iter(grid))):
|
|
133
133
|
if not spinn:
|
|
134
134
|
x_grid, y_grid = mesh
|
|
135
135
|
v_fun_at_t = vmap(fun)(
|
|
@@ -142,12 +142,11 @@ def plot2d(
|
|
|
142
142
|
axis=-1,
|
|
143
143
|
)
|
|
144
144
|
)
|
|
145
|
-
t_slice
|
|
145
|
+
t_slice = _plot_2D_statio(
|
|
146
146
|
v_fun_at_t,
|
|
147
147
|
mesh,
|
|
148
148
|
plot=False, # only use to compute t_slice
|
|
149
149
|
colorbar=False,
|
|
150
|
-
cmap=None,
|
|
151
150
|
vmin_vmax=vmin_vmax,
|
|
152
151
|
)
|
|
153
152
|
elif spinn:
|
|
@@ -161,7 +160,7 @@ def plot2d(
|
|
|
161
160
|
axis=-1,
|
|
162
161
|
)
|
|
163
162
|
values_grid = jnp.squeeze(fun(t_x)[0]).T
|
|
164
|
-
t_slice
|
|
163
|
+
t_slice = _plot_2D_statio(
|
|
165
164
|
values_grid,
|
|
166
165
|
mesh,
|
|
167
166
|
plot=False, # only use to compute t_slice
|
|
@@ -182,14 +181,14 @@ def plot2d(
|
|
|
182
181
|
|
|
183
182
|
|
|
184
183
|
def _plot_2D_statio(
|
|
185
|
-
v_fun: Callable | Float[Array, "(nx*ny)^2 1"],
|
|
186
|
-
mesh: Float[Array, "nx*ny nx*ny"],
|
|
187
|
-
plot:
|
|
188
|
-
colorbar:
|
|
184
|
+
v_fun: Callable | Float[Array, " (nx*ny)^2 1"],
|
|
185
|
+
mesh: list[Float[Array, " nx*ny nx*ny"]],
|
|
186
|
+
plot: bool = True,
|
|
187
|
+
colorbar: bool = True,
|
|
189
188
|
cmap: str = "inferno",
|
|
190
189
|
figsize: tuple[int, int] = (7, 7),
|
|
191
|
-
vmin_vmax: tuple[float, float]
|
|
192
|
-
):
|
|
190
|
+
vmin_vmax: tuple[float | None, float | None] | None = None,
|
|
191
|
+
) -> Array | None:
|
|
193
192
|
"""Function that plot the function u(x) with 2-D input x using pcolormesh()
|
|
194
193
|
|
|
195
194
|
|
|
@@ -217,6 +216,8 @@ def _plot_2D_statio(
|
|
|
217
216
|
Either None or the values of u() over the meshgrid and the current plt axis
|
|
218
217
|
|
|
219
218
|
"""
|
|
219
|
+
if vmin_vmax is None:
|
|
220
|
+
vmin_vmax = (None, None)
|
|
220
221
|
|
|
221
222
|
x_grid, y_grid = mesh
|
|
222
223
|
if callable(v_fun):
|
|
@@ -238,19 +239,18 @@ def _plot_2D_statio(
|
|
|
238
239
|
|
|
239
240
|
if colorbar:
|
|
240
241
|
fig.colorbar(im, format="%0.2f")
|
|
241
|
-
# don't plt.show() because it is done in plot2d()
|
|
242
242
|
else:
|
|
243
|
-
return values_grid
|
|
243
|
+
return values_grid
|
|
244
244
|
|
|
245
245
|
|
|
246
246
|
def plot1d_slice(
|
|
247
|
-
fun: Callable[[
|
|
248
|
-
xdata: Float[Array, "nx"],
|
|
249
|
-
time_slices: Float[Array, "nt"] | None = None,
|
|
247
|
+
fun: Callable[[Float[Array, " "]], Float[Array, " "]],
|
|
248
|
+
xdata: Float[Array, " nx"],
|
|
249
|
+
time_slices: Float[Array, " nt"] | None = None,
|
|
250
250
|
Tmax: float = 1.0,
|
|
251
251
|
title: str = "",
|
|
252
252
|
figsize: tuple[int, int] = (10, 10),
|
|
253
|
-
spinn:
|
|
253
|
+
spinn: bool = False,
|
|
254
254
|
ax=None,
|
|
255
255
|
):
|
|
256
256
|
"""Function for plotting time slices of a function :math:`f(t_i, x)` where
|
|
@@ -284,7 +284,7 @@ def plot1d_slice(
|
|
|
284
284
|
if time_slices is None:
|
|
285
285
|
time_slices = jnp.array([0])
|
|
286
286
|
if ax is None:
|
|
287
|
-
|
|
287
|
+
_, ax = plt.subplots(figsize=figsize)
|
|
288
288
|
|
|
289
289
|
for t in time_slices:
|
|
290
290
|
t_xdata = jnp.concatenate(
|
|
@@ -306,16 +306,16 @@ def plot1d_slice(
|
|
|
306
306
|
|
|
307
307
|
|
|
308
308
|
def plot1d_image(
|
|
309
|
-
fun: Callable[[
|
|
310
|
-
xdata: Float[Array, "nx"],
|
|
311
|
-
times: Float[Array, "nt"],
|
|
309
|
+
fun: Callable[[Float[Array, " "]], Float[Array, " "]],
|
|
310
|
+
xdata: Float[Array, " nx"],
|
|
311
|
+
times: Float[Array, " nt"],
|
|
312
312
|
Tmax: float = 1.0,
|
|
313
313
|
title: str = "",
|
|
314
314
|
figsize: tuple[int, int] = (10, 10),
|
|
315
|
-
colorbar:
|
|
315
|
+
colorbar: bool = True,
|
|
316
316
|
cmap: str = "inferno",
|
|
317
|
-
spinn:
|
|
318
|
-
vmin_vmax: tuple[float, float]
|
|
317
|
+
spinn: bool = False,
|
|
318
|
+
vmin_vmax: tuple[float | None, float | None] | None = None,
|
|
319
319
|
):
|
|
320
320
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
321
321
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -350,7 +350,8 @@ def plot1d_image(
|
|
|
350
350
|
fig, ax
|
|
351
351
|
A `matplotlib` `Figure` and `Axes` objects with the figure.
|
|
352
352
|
"""
|
|
353
|
-
|
|
353
|
+
if vmin_vmax is None:
|
|
354
|
+
vmin_vmax = (None, None)
|
|
354
355
|
mesh = jnp.meshgrid(times, xdata) # cartesian product
|
|
355
356
|
if not spinn:
|
|
356
357
|
# the trick is to use _plot2Dstatio
|
jinns/solver/_rar.py
CHANGED
|
@@ -2,41 +2,73 @@ from __future__ import (
|
|
|
2
2
|
annotations,
|
|
3
3
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
4
4
|
|
|
5
|
-
from typing import TYPE_CHECKING, Callable
|
|
5
|
+
from typing import TYPE_CHECKING, Callable, TypeAlias, Any, TypedDict
|
|
6
6
|
from functools import partial
|
|
7
|
+
from jaxtyping import Float, Array, Bool
|
|
7
8
|
import jax
|
|
8
9
|
from jax import vmap
|
|
9
10
|
import jax.numpy as jnp
|
|
10
11
|
import equinox as eqx
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
from jinns.data.
|
|
14
|
-
from jinns.
|
|
15
|
-
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
16
|
-
from jinns.loss._loss_utils import dynamic_loss_apply
|
|
17
|
-
from jinns.data._DataGenerators import (
|
|
18
|
-
DataGeneratorODE,
|
|
19
|
-
CubicMeshPDEStatio,
|
|
20
|
-
CubicMeshPDENonStatio,
|
|
21
|
-
)
|
|
12
|
+
|
|
13
|
+
from jinns.data._DataGeneratorODE import DataGeneratorODE
|
|
14
|
+
from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
|
|
15
|
+
from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
|
|
22
16
|
from jinns.nn._hyperpinn import HyperPINN
|
|
23
17
|
from jinns.nn._spinn import SPINN
|
|
24
18
|
|
|
25
19
|
|
|
26
20
|
if TYPE_CHECKING:
|
|
27
|
-
from jinns.
|
|
21
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
22
|
+
from jinns.utils._types import AnyLoss
|
|
23
|
+
from jinns.parameters._params import Params
|
|
24
|
+
|
|
25
|
+
class DataGeneratorWithRAR(AbstractDataGenerator):
|
|
26
|
+
"""
|
|
27
|
+
Add the required RAR operands for type checks
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
rar_parameters: RarParameterDict
|
|
31
|
+
n_start: int
|
|
32
|
+
rar_iter_from_last_sampling: int
|
|
33
|
+
rar_iter_nb: int
|
|
34
|
+
p: Float[Array, " n 1"]
|
|
35
|
+
|
|
36
|
+
rar_operands: TypeAlias = tuple[Any, Params, DataGeneratorWithRAR, int]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RarParameterDict(TypedDict):
|
|
40
|
+
"""
|
|
41
|
+
TypedDict to specify the Residual Adaptative Resampling procedure
|
|
42
|
+
Otherwise a dictionary with keys
|
|
43
|
+
- `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
|
|
44
|
+
- `update_every`: the number of gradient steps taken between
|
|
45
|
+
each update of collocation points in the RAR algo.
|
|
46
|
+
- `sample_size`: the size of the sample from which we will select new
|
|
47
|
+
collocation points.
|
|
48
|
+
- `selected_sample_size`: the number of selected
|
|
49
|
+
points from the sample to be added to the current collocation
|
|
50
|
+
points.
|
|
51
|
+
"""
|
|
28
52
|
|
|
53
|
+
start_iter: int
|
|
54
|
+
update_every: int
|
|
55
|
+
sample_size: int
|
|
56
|
+
selected_sample_size: int
|
|
29
57
|
|
|
30
|
-
|
|
58
|
+
|
|
59
|
+
def _proceed_to_rar(data: DataGeneratorWithRAR, i: int) -> Bool[Array, " "]:
|
|
31
60
|
"""Utilility function with various check to ensure we can proceed with the rar_step.
|
|
32
61
|
Return True if yes, and False otherwise"""
|
|
33
62
|
|
|
34
63
|
# Overall checks
|
|
35
64
|
check_list = [
|
|
36
65
|
# check if burn-in period has ended
|
|
37
|
-
data.rar_parameters["start_iter"] <= i,
|
|
66
|
+
jnp.asarray(data.rar_parameters["start_iter"] <= i),
|
|
38
67
|
# check if enough iterations since last points added
|
|
39
|
-
(
|
|
68
|
+
jnp.asarray(
|
|
69
|
+
(data.rar_parameters["update_every"] - 1)
|
|
70
|
+
== data.rar_iter_from_last_sampling
|
|
71
|
+
),
|
|
40
72
|
]
|
|
41
73
|
|
|
42
74
|
# Memory allocation checks
|
|
@@ -52,14 +84,13 @@ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
|
|
|
52
84
|
|
|
53
85
|
@partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
|
|
54
86
|
def trigger_rar(
|
|
55
|
-
i:
|
|
87
|
+
i: int,
|
|
56
88
|
loss: AnyLoss,
|
|
57
|
-
params:
|
|
58
|
-
data:
|
|
59
|
-
_rar_step_true: Callable[[rar_operands],
|
|
60
|
-
_rar_step_false: Callable[[rar_operands],
|
|
61
|
-
) -> tuple[AnyLoss,
|
|
62
|
-
|
|
89
|
+
params: Params,
|
|
90
|
+
data: DataGeneratorWithRAR,
|
|
91
|
+
_rar_step_true: Callable[[rar_operands], DataGeneratorWithRAR],
|
|
92
|
+
_rar_step_false: Callable[[rar_operands], DataGeneratorWithRAR],
|
|
93
|
+
) -> tuple[AnyLoss, Params, DataGeneratorWithRAR]:
|
|
63
94
|
if data.rar_parameters is None:
|
|
64
95
|
# do nothing.
|
|
65
96
|
return loss, params, data
|
|
@@ -75,11 +106,11 @@ def trigger_rar(
|
|
|
75
106
|
|
|
76
107
|
|
|
77
108
|
def init_rar(
|
|
78
|
-
data:
|
|
109
|
+
data: DataGeneratorWithRAR,
|
|
79
110
|
) -> tuple[
|
|
80
|
-
|
|
81
|
-
Callable[[rar_operands],
|
|
82
|
-
Callable[[rar_operands],
|
|
111
|
+
DataGeneratorWithRAR,
|
|
112
|
+
Callable[[rar_operands], DataGeneratorWithRAR] | None,
|
|
113
|
+
Callable[[rar_operands], DataGeneratorWithRAR] | None,
|
|
83
114
|
]:
|
|
84
115
|
"""
|
|
85
116
|
Separated from the main rar, because the initialization to get _true and
|
|
@@ -100,9 +131,11 @@ def init_rar(
|
|
|
100
131
|
return data, _rar_step_true, _rar_step_false
|
|
101
132
|
|
|
102
133
|
|
|
103
|
-
def _rar_step_init(
|
|
104
|
-
|
|
105
|
-
|
|
134
|
+
def _rar_step_init(
|
|
135
|
+
sample_size: int, selected_sample_size: int
|
|
136
|
+
) -> tuple[
|
|
137
|
+
Callable[[rar_operands], DataGeneratorWithRAR],
|
|
138
|
+
Callable[[rar_operands], DataGeneratorWithRAR],
|
|
106
139
|
]:
|
|
107
140
|
"""
|
|
108
141
|
This is a wrapper because the sampling size and
|
|
@@ -113,13 +146,12 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
|
113
146
|
This is a kind of manual declaration of static argnums
|
|
114
147
|
"""
|
|
115
148
|
|
|
116
|
-
def rar_step_true(operands: rar_operands) ->
|
|
117
|
-
loss, params, data,
|
|
149
|
+
def rar_step_true(operands: rar_operands) -> DataGeneratorWithRAR:
|
|
150
|
+
loss, params, data, _ = operands
|
|
118
151
|
if isinstance(loss.u, HyperPINN) or isinstance(loss.u, SPINN):
|
|
119
152
|
raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
|
|
120
153
|
|
|
121
154
|
if isinstance(data, DataGeneratorODE):
|
|
122
|
-
|
|
123
155
|
new_key, subkey = jax.random.split(data.key)
|
|
124
156
|
new_samples = data.sample_in_time_domain(subkey, sample_size)
|
|
125
157
|
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
@@ -145,33 +177,15 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
|
145
177
|
|
|
146
178
|
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
147
179
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
)
|
|
153
|
-
dyn_on_s = v_dyn_loss(new_samples)
|
|
180
|
+
v_dyn_loss = vmap(
|
|
181
|
+
lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
|
|
182
|
+
)
|
|
183
|
+
dyn_on_s = v_dyn_loss(new_samples)
|
|
154
184
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
elif isinstance(loss, SystemLossODE, SystemLossPDE):
|
|
160
|
-
mse_on_s = 0
|
|
161
|
-
|
|
162
|
-
for i in loss.dynamic_loss_dict.keys():
|
|
163
|
-
v_dyn_loss = vmap(
|
|
164
|
-
lambda inputs: loss.dynamic_loss_dict[i].evaluate(
|
|
165
|
-
inputs, loss.u_dict, params
|
|
166
|
-
),
|
|
167
|
-
(0),
|
|
168
|
-
0,
|
|
169
|
-
)
|
|
170
|
-
dyn_on_s = v_dyn_loss(new_samples)
|
|
171
|
-
if dyn_on_s.ndim > 1:
|
|
172
|
-
mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
173
|
-
else:
|
|
174
|
-
mse_on_s += dyn_on_s**2
|
|
185
|
+
if dyn_on_s.ndim > 1:
|
|
186
|
+
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
187
|
+
else:
|
|
188
|
+
mse_on_s = dyn_on_s**2
|
|
175
189
|
|
|
176
190
|
## Select the m points with higher dynamic loss
|
|
177
191
|
higher_residual_idx = jax.lax.dynamic_slice(
|
|
@@ -188,7 +202,7 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
|
188
202
|
new_times = jax.lax.dynamic_update_slice(
|
|
189
203
|
data.times,
|
|
190
204
|
higher_residual_points,
|
|
191
|
-
(data.n_start + data.rar_iter_nb * selected_sample_size,),
|
|
205
|
+
(data.n_start + data.rar_iter_nb * selected_sample_size,), # type: ignore
|
|
192
206
|
)
|
|
193
207
|
|
|
194
208
|
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
@@ -198,7 +212,7 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
|
198
212
|
new_omega = jax.lax.dynamic_update_slice(
|
|
199
213
|
data.omega,
|
|
200
214
|
higher_residual_points,
|
|
201
|
-
(data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
|
|
215
|
+
(data.n_start + data.rar_iter_nb * selected_sample_size, data.dim), # type: ignore
|
|
202
216
|
)
|
|
203
217
|
|
|
204
218
|
data = eqx.tree_at(lambda m: m.omega, data, new_omega)
|
|
@@ -207,7 +221,10 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
|
207
221
|
new_domain = jax.lax.dynamic_update_slice(
|
|
208
222
|
data.domain,
|
|
209
223
|
higher_residual_points,
|
|
210
|
-
(
|
|
224
|
+
(
|
|
225
|
+
data.n_start + data.rar_iter_nb * selected_sample_size, # type: ignore
|
|
226
|
+
1 + data.dim,
|
|
227
|
+
),
|
|
211
228
|
)
|
|
212
229
|
|
|
213
230
|
data = eqx.tree_at(lambda m: m.domain, data, new_domain)
|
|
@@ -246,7 +263,7 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
|
246
263
|
# have side effects in this function that will be jitted
|
|
247
264
|
return data
|
|
248
265
|
|
|
249
|
-
def rar_step_false(operands: rar_operands) ->
|
|
266
|
+
def rar_step_false(operands: rar_operands) -> DataGeneratorWithRAR:
|
|
250
267
|
_, _, data, i = operands
|
|
251
268
|
|
|
252
269
|
# Add 1 only if we are after the burn in period
|