jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/solver/_rar.py
CHANGED
|
@@ -1,21 +1,32 @@
|
|
|
1
|
+
from __future__ import (
|
|
2
|
+
annotations,
|
|
3
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Callable
|
|
6
|
+
from functools import partial
|
|
1
7
|
import jax
|
|
2
8
|
from jax import vmap
|
|
3
9
|
import jax.numpy as jnp
|
|
10
|
+
import equinox as eqx
|
|
11
|
+
from jaxtyping import Int, Bool
|
|
12
|
+
|
|
13
|
+
from jinns.data._Batchs import *
|
|
14
|
+
from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
15
|
+
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
4
16
|
from jinns.data._DataGenerators import (
|
|
5
17
|
DataGeneratorODE,
|
|
6
18
|
CubicMeshPDEStatio,
|
|
7
19
|
CubicMeshPDENonStatio,
|
|
8
20
|
)
|
|
9
|
-
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
10
|
-
from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
11
|
-
from jinns.loss._DynamicLossAbstract import PDEStatio
|
|
12
|
-
|
|
13
|
-
from functools import partial
|
|
14
21
|
from jinns.utils._hyperpinn import HYPERPINN
|
|
15
22
|
from jinns.utils._spinn import SPINN
|
|
16
23
|
|
|
17
24
|
|
|
18
|
-
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from jinns.utils._types import *
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
|
|
19
30
|
"""Utilility function with various check to ensure we can proceed with the rar_step.
|
|
20
31
|
Return True if yes, and False otherwise"""
|
|
21
32
|
|
|
@@ -30,13 +41,13 @@ def _proceed_to_rar(data, i):
|
|
|
30
41
|
# Memory allocation checks (depends on the type of DataGenerator)
|
|
31
42
|
# check if we still have room to append new collocation points in the
|
|
32
43
|
# allocated jnp.array (can concern `data.p_times` or `p_omega`)
|
|
33
|
-
if isinstance(data, DataGeneratorODE
|
|
44
|
+
if isinstance(data, (DataGeneratorODE, CubicMeshPDENonStatio)):
|
|
34
45
|
check_list.append(
|
|
35
46
|
data.rar_parameters["selected_sample_size_times"]
|
|
36
47
|
<= jnp.count_nonzero(data.p_times == 0),
|
|
37
48
|
)
|
|
38
49
|
|
|
39
|
-
if isinstance(data, CubicMeshPDEStatio
|
|
50
|
+
if isinstance(data, (CubicMeshPDEStatio, CubicMeshPDENonStatio)):
|
|
40
51
|
# for now the above checks are redundant but there may be a time when
|
|
41
52
|
# we drop inheritance
|
|
42
53
|
check_list.append(
|
|
@@ -49,7 +60,14 @@ def _proceed_to_rar(data, i):
|
|
|
49
60
|
|
|
50
61
|
|
|
51
62
|
@partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
|
|
52
|
-
def trigger_rar(
|
|
63
|
+
def trigger_rar(
|
|
64
|
+
i: Int,
|
|
65
|
+
loss: AnyLoss,
|
|
66
|
+
params: AnyParams,
|
|
67
|
+
data: AnyDataGenerator,
|
|
68
|
+
_rar_step_true: Callable[[rar_operands], AnyDataGenerator],
|
|
69
|
+
_rar_step_false: Callable[[rar_operands], AnyDataGenerator],
|
|
70
|
+
) -> tuple[AnyLoss, AnyParams, AnyDataGenerator]:
|
|
53
71
|
|
|
54
72
|
if data.rar_parameters is None:
|
|
55
73
|
# do nothing.
|
|
@@ -65,7 +83,13 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
|
|
|
65
83
|
return loss, params, data
|
|
66
84
|
|
|
67
85
|
|
|
68
|
-
def init_rar(
|
|
86
|
+
def init_rar(
|
|
87
|
+
data: AnyDataGenerator,
|
|
88
|
+
) -> tuple[
|
|
89
|
+
AnyDataGenerator,
|
|
90
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
91
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
92
|
+
]:
|
|
69
93
|
"""
|
|
70
94
|
Separated from the main rar, because the initialization to get _true and
|
|
71
95
|
_false cannot be jit-ted.
|
|
@@ -100,13 +124,21 @@ def init_rar(data):
|
|
|
100
124
|
data.rar_parameters["sample_size_omega"],
|
|
101
125
|
data.rar_parameters["selected_sample_size_omega"],
|
|
102
126
|
)
|
|
127
|
+
else:
|
|
128
|
+
raise ValueError(f"Wrong type for data got {type(data)}")
|
|
103
129
|
|
|
104
|
-
data.
|
|
130
|
+
if isinstance(data, eqx.Module):
|
|
131
|
+
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
132
|
+
else:
|
|
133
|
+
data.rar_iter_from_last_sampling = 0
|
|
105
134
|
|
|
106
135
|
return data, _rar_step_true, _rar_step_false
|
|
107
136
|
|
|
108
137
|
|
|
109
|
-
def _rar_step_init(sample_size, selected_sample_size)
|
|
138
|
+
def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
139
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
140
|
+
Callable[[rar_operands], AnyDataGenerator],
|
|
141
|
+
]:
|
|
110
142
|
"""
|
|
111
143
|
This is a wrapper because the sampling size and
|
|
112
144
|
selected_sample_size, must be treated as static
|
|
@@ -116,11 +148,17 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
116
148
|
This is a kind of manual declaration of static argnums
|
|
117
149
|
"""
|
|
118
150
|
|
|
119
|
-
def rar_step_true(operands):
|
|
151
|
+
def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
|
|
120
152
|
loss, params, data, i = operands
|
|
121
153
|
|
|
122
154
|
if isinstance(data, DataGeneratorODE):
|
|
123
|
-
|
|
155
|
+
|
|
156
|
+
if isinstance(data, eqx.Module):
|
|
157
|
+
new_key, subkey = jax.random.split(data.key)
|
|
158
|
+
new_omega_samples = data.sample_in_time_domain(subkey, sample_size)
|
|
159
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
160
|
+
else:
|
|
161
|
+
new_omega_samples = data.sample_in_time_domain(sample_size)
|
|
124
162
|
|
|
125
163
|
# We can have different types of Loss
|
|
126
164
|
if isinstance(loss, LossODE):
|
|
@@ -162,17 +200,29 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
162
200
|
## add the new points in times
|
|
163
201
|
# start indices of update can be dynamic but not the shape (length)
|
|
164
202
|
# of the slice
|
|
165
|
-
|
|
203
|
+
new_times = jax.lax.dynamic_update_slice(
|
|
166
204
|
data.times,
|
|
167
205
|
higher_residual_points,
|
|
168
206
|
(data.nt_start + data.rar_iter_nb * selected_sample_size,),
|
|
169
207
|
)
|
|
170
208
|
|
|
209
|
+
if isinstance(data, eqx.Module):
|
|
210
|
+
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
211
|
+
else:
|
|
212
|
+
data.times = new_times
|
|
171
213
|
## rearrange probabilities so that the probabilities of the new
|
|
172
214
|
## points are non-zero
|
|
173
215
|
new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
|
|
174
216
|
# the next line works because nt_start is static
|
|
175
|
-
|
|
217
|
+
new_p_times = data.p_times.at[: data.nt_start].set(new_proba)
|
|
218
|
+
if isinstance(data, eqx.Module):
|
|
219
|
+
data = eqx.tree_at(
|
|
220
|
+
lambda m: m.p_times,
|
|
221
|
+
data,
|
|
222
|
+
new_p_times,
|
|
223
|
+
)
|
|
224
|
+
else:
|
|
225
|
+
data.p_times = new_p_times
|
|
176
226
|
|
|
177
227
|
# the next requires a fori_loop because the range is dynamic
|
|
178
228
|
def update_slices(i, p):
|
|
@@ -182,16 +232,29 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
182
232
|
((data.nt_start + i * selected_sample_size),),
|
|
183
233
|
)
|
|
184
234
|
|
|
185
|
-
data.rar_iter_nb
|
|
186
|
-
|
|
187
|
-
data.p_times = jax.lax.fori_loop(
|
|
235
|
+
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
236
|
+
new_p_times = jax.lax.fori_loop(
|
|
188
237
|
0, data.rar_iter_nb, update_slices, data.p_times
|
|
189
238
|
)
|
|
239
|
+
if isinstance(data, eqx.Module):
|
|
240
|
+
data = eqx.tree_at(
|
|
241
|
+
lambda m: (m.rar_iter_nb, m.p_times),
|
|
242
|
+
data,
|
|
243
|
+
(new_rar_iter_nb, new_p_times),
|
|
244
|
+
)
|
|
245
|
+
else:
|
|
246
|
+
data.rar_iter_nb = new_rar_iter_nb
|
|
247
|
+
data.p_times = new_p_times
|
|
190
248
|
|
|
191
249
|
elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
|
|
192
250
|
data, CubicMeshPDENonStatio
|
|
193
251
|
):
|
|
194
|
-
|
|
252
|
+
if isinstance(data, eqx.Module):
|
|
253
|
+
new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
|
|
254
|
+
new_omega_samples = data.sample_in_omega_domain(subkeys, sample_size)
|
|
255
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
256
|
+
else:
|
|
257
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size)
|
|
195
258
|
|
|
196
259
|
# We can have different types of Loss
|
|
197
260
|
if isinstance(loss, LossPDEStatio):
|
|
@@ -209,7 +272,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
209
272
|
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
210
273
|
else:
|
|
211
274
|
mse_on_s = dyn_on_s**2
|
|
212
|
-
elif isinstance(loss,
|
|
275
|
+
elif isinstance(loss, SystemLossODE):
|
|
213
276
|
mse_on_s = 0
|
|
214
277
|
for i in loss.dynamic_loss_dict.keys():
|
|
215
278
|
# only the case LossPDEStatio here
|
|
@@ -237,17 +300,30 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
237
300
|
## add the new points in omega
|
|
238
301
|
# start indices of update can be dynamic but not the shape (length)
|
|
239
302
|
# of the slice
|
|
240
|
-
|
|
303
|
+
new_omega = jax.lax.dynamic_update_slice(
|
|
241
304
|
data.omega,
|
|
242
305
|
higher_residual_points,
|
|
243
306
|
(data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
|
|
244
307
|
)
|
|
245
308
|
|
|
309
|
+
if isinstance(data, eqx.Module):
|
|
310
|
+
data = eqx.tree_at(lambda m: m.omega, data, new_omega)
|
|
311
|
+
else:
|
|
312
|
+
data.omega = new_omega
|
|
313
|
+
|
|
246
314
|
## rearrange probabilities so that the probabilities of the new
|
|
247
315
|
## points are non-zero
|
|
248
316
|
new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
|
|
249
317
|
# the next line works because n_start is static
|
|
250
|
-
|
|
318
|
+
new_p_omega = data.p_omega.at[: data.n_start].set(new_proba)
|
|
319
|
+
if isinstance(data, eqx.Module):
|
|
320
|
+
data = eqx.tree_at(
|
|
321
|
+
lambda m: m.p_omega,
|
|
322
|
+
data,
|
|
323
|
+
new_p_omega,
|
|
324
|
+
)
|
|
325
|
+
else:
|
|
326
|
+
data.p_omega = new_p_omega
|
|
251
327
|
|
|
252
328
|
# the next requires a fori_loop because the range is dynamic
|
|
253
329
|
def update_slices(i, p):
|
|
@@ -257,13 +333,24 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
257
333
|
((data.n_start + i * selected_sample_size),),
|
|
258
334
|
)
|
|
259
335
|
|
|
260
|
-
data.rar_iter_nb
|
|
261
|
-
|
|
262
|
-
data.p_omega = jax.lax.fori_loop(
|
|
336
|
+
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
337
|
+
new_p_omega = jax.lax.fori_loop(
|
|
263
338
|
0, data.rar_iter_nb, update_slices, data.p_omega
|
|
264
339
|
)
|
|
340
|
+
if isinstance(data, eqx.Module):
|
|
341
|
+
data = eqx.tree_at(
|
|
342
|
+
lambda m: (m.rar_iter_nb, m.p_omega),
|
|
343
|
+
data,
|
|
344
|
+
(new_rar_iter_nb, new_p_omega),
|
|
345
|
+
)
|
|
346
|
+
else:
|
|
347
|
+
data.rar_iter_nb = new_rar_iter_nb
|
|
348
|
+
data.p_omega = new_p_omega
|
|
265
349
|
|
|
266
350
|
elif isinstance(data, CubicMeshPDENonStatio):
|
|
351
|
+
if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
|
|
352
|
+
raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
|
|
353
|
+
|
|
267
354
|
# NOTE in this case sample_size and selected_sample_size
|
|
268
355
|
# are tuples (times, omega) => we unpack them for clarity
|
|
269
356
|
selected_sample_size_times, selected_sample_size_omega = (
|
|
@@ -271,17 +358,29 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
271
358
|
)
|
|
272
359
|
sample_size_times, sample_size_omega = sample_size
|
|
273
360
|
|
|
274
|
-
|
|
275
|
-
|
|
361
|
+
if isinstance(data, eqx.Module):
|
|
362
|
+
new_key, subkey = jax.random.split(data.key)
|
|
363
|
+
new_times_samples = data.sample_in_time_domain(
|
|
364
|
+
subkey, sample_size_times
|
|
365
|
+
)
|
|
366
|
+
new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
|
|
367
|
+
new_omega_samples = data.sample_in_omega_domain(
|
|
368
|
+
subkeys, sample_size_omega
|
|
369
|
+
)
|
|
370
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
371
|
+
else:
|
|
372
|
+
new_times_samples = data.sample_in_time_domain(sample_size_times)
|
|
373
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
|
|
276
374
|
|
|
277
|
-
if
|
|
278
|
-
|
|
375
|
+
if not data.cartesian_product:
|
|
376
|
+
times = new_times_samples
|
|
377
|
+
omega = new_omega_samples
|
|
279
378
|
else:
|
|
280
379
|
# do cartesian product on new points
|
|
281
|
-
|
|
380
|
+
omega = jnp.tile(
|
|
282
381
|
new_omega_samples, reps=(sample_size_times, 1)
|
|
283
382
|
) # it is tiled
|
|
284
|
-
|
|
383
|
+
times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
|
|
285
384
|
..., None
|
|
286
385
|
] # it is repeated + add an axis
|
|
287
386
|
|
|
@@ -291,7 +390,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
291
390
|
(0, 0),
|
|
292
391
|
0,
|
|
293
392
|
)
|
|
294
|
-
dyn_on_s = v_dyn_loss(
|
|
393
|
+
dyn_on_s = v_dyn_loss(times, omega).reshape(
|
|
295
394
|
(sample_size_times, sample_size_omega)
|
|
296
395
|
)
|
|
297
396
|
mse_on_s = dyn_on_s**2
|
|
@@ -305,7 +404,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
305
404
|
(0, 0),
|
|
306
405
|
0,
|
|
307
406
|
)
|
|
308
|
-
dyn_on_s += v_dyn_loss(
|
|
407
|
+
dyn_on_s += v_dyn_loss(times, omega).reshape(
|
|
309
408
|
(sample_size_times, sample_size_omega)
|
|
310
409
|
)
|
|
311
410
|
|
|
@@ -338,14 +437,23 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
338
437
|
## add the new points in times
|
|
339
438
|
# start indices of update can be dynamic but not the shape (length)
|
|
340
439
|
# of the slice
|
|
341
|
-
|
|
440
|
+
new_times = jax.lax.dynamic_update_slice(
|
|
342
441
|
data.times,
|
|
343
442
|
higher_residual_points_times,
|
|
344
|
-
(
|
|
443
|
+
(
|
|
444
|
+
data.n_start
|
|
445
|
+
+ data.rar_iter_nb # NOTE typo here nt_start ?
|
|
446
|
+
* selected_sample_size_times,
|
|
447
|
+
),
|
|
345
448
|
)
|
|
346
449
|
|
|
450
|
+
if isinstance(data, eqx.Module):
|
|
451
|
+
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
452
|
+
else:
|
|
453
|
+
data.times = new_times
|
|
454
|
+
|
|
347
455
|
## add the new points in omega
|
|
348
|
-
|
|
456
|
+
new_omega = jax.lax.dynamic_update_slice(
|
|
349
457
|
data.omega,
|
|
350
458
|
higher_residual_points_omega,
|
|
351
459
|
(
|
|
@@ -354,19 +462,38 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
354
462
|
),
|
|
355
463
|
)
|
|
356
464
|
|
|
465
|
+
if isinstance(data, eqx.Module):
|
|
466
|
+
data = eqx.tree_at(lambda m: m.omega, data, new_omega)
|
|
467
|
+
else:
|
|
468
|
+
data.omega = new_omega
|
|
469
|
+
|
|
357
470
|
## rearrange probabilities so that the probabilities of the new
|
|
358
471
|
## points are non-zero
|
|
359
472
|
new_p_times = 1 / (
|
|
360
473
|
data.nt_start + data.rar_iter_nb * selected_sample_size_times
|
|
361
474
|
)
|
|
362
475
|
# the next line works because nt_start is static
|
|
363
|
-
|
|
476
|
+
if isinstance(data, eqx.Module):
|
|
477
|
+
data = eqx.tree_at(
|
|
478
|
+
lambda m: m.p_times,
|
|
479
|
+
data,
|
|
480
|
+
data.p_times.at[: data.nt_start].set(new_p_times),
|
|
481
|
+
)
|
|
482
|
+
else:
|
|
483
|
+
data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
|
|
364
484
|
|
|
365
485
|
# same for p_omega (work because n_start is static)
|
|
366
486
|
new_p_omega = 1 / (
|
|
367
487
|
data.n_start + data.rar_iter_nb * selected_sample_size_omega
|
|
368
488
|
)
|
|
369
|
-
|
|
489
|
+
if isinstance(data, eqx.Module):
|
|
490
|
+
data = eqx.tree_at(
|
|
491
|
+
lambda m: m.p_omega,
|
|
492
|
+
data,
|
|
493
|
+
data.p_omega.at[: data.n_start].set(new_p_omega),
|
|
494
|
+
)
|
|
495
|
+
else:
|
|
496
|
+
data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
|
|
370
497
|
|
|
371
498
|
# the part of data.p_* after n_start requires a fori_loop because
|
|
372
499
|
# the range is dynamic
|
|
@@ -381,13 +508,13 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
381
508
|
|
|
382
509
|
return update_slices
|
|
383
510
|
|
|
384
|
-
data.rar_iter_nb
|
|
511
|
+
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
385
512
|
|
|
386
513
|
## update rest of p_times
|
|
387
514
|
update_slices_times = create_update_slices(
|
|
388
515
|
new_p_times, selected_sample_size_times
|
|
389
516
|
)
|
|
390
|
-
|
|
517
|
+
new_p_times = jax.lax.fori_loop(
|
|
391
518
|
0,
|
|
392
519
|
data.rar_iter_nb,
|
|
393
520
|
update_slices_times,
|
|
@@ -397,21 +524,34 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
397
524
|
update_slices_omega = create_update_slices(
|
|
398
525
|
new_p_omega, selected_sample_size_omega
|
|
399
526
|
)
|
|
400
|
-
|
|
527
|
+
new_p_omega = jax.lax.fori_loop(
|
|
401
528
|
0,
|
|
402
529
|
data.rar_iter_nb,
|
|
403
530
|
update_slices_omega,
|
|
404
531
|
data.p_omega,
|
|
405
532
|
)
|
|
533
|
+
if isinstance(data, eqx.Module):
|
|
534
|
+
data = eqx.tree_at(
|
|
535
|
+
lambda m: (m.rar_iter_nb, m.p_omega, m.p_times),
|
|
536
|
+
data,
|
|
537
|
+
(new_rar_iter_nb, new_p_omega, new_p_times),
|
|
538
|
+
)
|
|
539
|
+
else:
|
|
540
|
+
data.rar_iter_nb = new_rar_iter_nb
|
|
541
|
+
data.p_times = new_p_times
|
|
542
|
+
data.p_omega = new_p_omega
|
|
406
543
|
|
|
407
544
|
# update RAR parameters for all cases
|
|
408
|
-
data.
|
|
545
|
+
if isinstance(data, eqx.Module):
|
|
546
|
+
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
547
|
+
else:
|
|
548
|
+
data.rar_iter_from_last_sampling = 0
|
|
409
549
|
|
|
410
550
|
# NOTE must return data to be correctly updated because we cannot
|
|
411
551
|
# have side effects in this function that will be jitted
|
|
412
552
|
return data
|
|
413
553
|
|
|
414
|
-
def rar_step_false(operands):
|
|
554
|
+
def rar_step_false(operands: rar_operands) -> AnyDataGenerator:
|
|
415
555
|
_, _, data, i = operands
|
|
416
556
|
|
|
417
557
|
# Add 1 only if we are after the burn in period
|
|
@@ -421,7 +561,15 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
421
561
|
lambda: 1,
|
|
422
562
|
)
|
|
423
563
|
|
|
424
|
-
data.rar_iter_from_last_sampling
|
|
564
|
+
new_rar_iter_from_last_sampling = data.rar_iter_from_last_sampling + increment
|
|
565
|
+
if isinstance(data, eqx.Module):
|
|
566
|
+
data = eqx.tree_at(
|
|
567
|
+
lambda m: m.rar_iter_from_last_sampling,
|
|
568
|
+
data,
|
|
569
|
+
new_rar_iter_from_last_sampling,
|
|
570
|
+
)
|
|
571
|
+
else:
|
|
572
|
+
data.rar_iter_from_last_sampling = new_rar_iter_from_last_sampling
|
|
425
573
|
return data
|
|
426
574
|
|
|
427
575
|
return rar_step_true, rar_step_false
|