jinns 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_rar.py CHANGED
@@ -13,6 +13,7 @@ from jaxtyping import Int, Bool
13
13
  from jinns.data._Batchs import *
14
14
  from jinns.loss._LossODE import LossODE, SystemLossODE
15
15
  from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
16
+ from jinns.loss._loss_utils import dynamic_loss_apply
16
17
  from jinns.data._DataGenerators import (
17
18
  DataGeneratorODE,
18
19
  CubicMeshPDEStatio,
@@ -30,7 +31,7 @@ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
30
31
  """Utility function with various checks to ensure we can proceed with the rar_step.
31
32
  Return True if yes, and False otherwise"""
32
33
 
33
- # Overall checks (universal for any data generator)
34
+ # Overall checks
34
35
  check_list = [
35
36
  # check if burn-in period has ended
36
37
  data.rar_parameters["start_iter"] <= i,
@@ -38,22 +39,12 @@ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
38
39
  (data.rar_parameters["update_every"] - 1) == data.rar_iter_from_last_sampling,
39
40
  ]
40
41
 
41
- # Memory allocation checks (depends on the type of DataGenerator)
42
+ # Memory allocation checks
42
43
  # check if we still have room to append new collocation points in the
43
- # allocated jnp.array (can concern `data.p_times` or `p_omega`)
44
- if isinstance(data, (DataGeneratorODE, CubicMeshPDENonStatio)):
45
- check_list.append(
46
- data.rar_parameters["selected_sample_size_times"]
47
- <= jnp.count_nonzero(data.p_times == 0),
48
- )
49
-
50
- if isinstance(data, (CubicMeshPDEStatio, CubicMeshPDENonStatio)):
51
- # for now the above check are redundants but there may be a time when
52
- # we drop inheritence
53
- check_list.append(
54
- data.rar_parameters["selected_sample_size_omega"]
55
- <= jnp.count_nonzero(data.p_omega == 0),
56
- )
44
+ # allocated jnp.array
45
+ check_list.append(
46
+ data.rar_parameters["selected_sample_size"] <= jnp.count_nonzero(data.p == 0),
47
+ )
57
48
 
58
49
  proceed = jnp.all(jnp.array(check_list))
59
50
  return proceed
@@ -99,38 +90,12 @@ def init_rar(
99
90
  if data.rar_parameters is None:
100
91
  _rar_step_true, _rar_step_false = None, None
101
92
  else:
102
- if isinstance(data, DataGeneratorODE):
103
- # In this case we only need rar parameters related to `times`
104
- _rar_step_true, _rar_step_false = _rar_step_init(
105
- data.rar_parameters["sample_size_times"],
106
- data.rar_parameters["selected_sample_size_times"],
107
- )
108
- elif isinstance(data, CubicMeshPDENonStatio):
109
- # In this case we need rar parameters related to both `times`
110
- # and`omega`
111
- _rar_step_true, _rar_step_false = _rar_step_init(
112
- (
113
- data.rar_parameters["sample_size_times"],
114
- data.rar_parameters["sample_size_omega"],
115
- ),
116
- (
117
- data.rar_parameters["selected_sample_size_times"],
118
- data.rar_parameters["selected_sample_size_omega"],
119
- ),
120
- )
121
- elif isinstance(data, CubicMeshPDEStatio):
122
- # In this case we only need rar parameters related to `omega`
123
- _rar_step_true, _rar_step_false = _rar_step_init(
124
- data.rar_parameters["sample_size_omega"],
125
- data.rar_parameters["selected_sample_size_omega"],
126
- )
127
- else:
128
- raise ValueError(f"Wrong type for data got {type(data)}")
93
+ _rar_step_true, _rar_step_false = _rar_step_init(
94
+ data.rar_parameters["sample_size"],
95
+ data.rar_parameters["selected_sample_size"],
96
+ )
129
97
 
130
- if isinstance(data, eqx.Module):
131
- data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
132
- else:
133
- data.rar_iter_from_last_sampling = 0
98
+ data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
134
99
 
135
100
  return data, _rar_step_true, _rar_step_false
136
101
 
@@ -150,402 +115,132 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
150
115
 
151
116
  def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
152
117
  loss, params, data, i = operands
118
+ if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
119
+ raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
153
120
 
154
121
  if isinstance(data, DataGeneratorODE):
155
122
 
156
- if isinstance(data, eqx.Module):
157
- new_key, subkey = jax.random.split(data.key)
158
- new_omega_samples = data.sample_in_time_domain(subkey, sample_size)
159
- data = eqx.tree_at(lambda m: m.key, data, new_key)
160
- else:
161
- new_omega_samples = data.sample_in_time_domain(sample_size)
123
+ new_key, subkey = jax.random.split(data.key)
124
+ new_samples = data.sample_in_time_domain(subkey, sample_size)
125
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
162
126
 
163
- # We can have different types of Loss
164
- if isinstance(loss, LossODE):
165
- v_dyn_loss = vmap(
166
- lambda t: loss.dynamic_loss.evaluate(t, loss.u, params),
167
- (0),
168
- 0,
169
- )
170
- dyn_on_s = v_dyn_loss(new_omega_samples)
171
- if dyn_on_s.ndim > 1:
172
- mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
173
- else:
174
- mse_on_s = dyn_on_s**2
175
- elif isinstance(loss, SystemLossODE):
176
- mse_on_s = 0
177
-
178
- for i in loss.dynamic_loss_dict.keys():
179
- v_dyn_loss = vmap(
180
- lambda t: loss.dynamic_loss_dict[i].evaluate(
181
- t, loss.u_dict, params
182
- ),
183
- (0),
184
- 0,
185
- )
186
- dyn_on_s = v_dyn_loss(new_omega_samples)
187
- if dyn_on_s.ndim > 1:
188
- mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
189
- else:
190
- mse_on_s += dyn_on_s**2
191
-
192
- ## Select the m points with higher dynamic loss
193
- higher_residual_idx = jax.lax.dynamic_slice(
194
- jnp.argsort(mse_on_s),
195
- (mse_on_s.shape[0] - selected_sample_size,),
196
- (selected_sample_size,),
197
- )
198
- higher_residual_points = new_omega_samples[higher_residual_idx]
127
+ elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
128
+ data, CubicMeshPDENonStatio
129
+ ):
130
+ new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
131
+ new_samples = data.sample_in_omega_domain(subkeys, sample_size)
132
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
199
133
 
200
- ## add the new points in times
201
- # start indices of update can be dynamic but the the shape (length)
202
- # of the slice
203
- new_times = jax.lax.dynamic_update_slice(
204
- data.times,
205
- higher_residual_points,
206
- (data.nt_start + data.rar_iter_nb * selected_sample_size,),
134
+ elif isinstance(data, CubicMeshPDENonStatio):
135
+ new_key, subkey = jax.random.split(data.key)
136
+ new_samples_times = data.sample_in_time_domain(subkey, sample_size)
137
+ if data.dim == 1:
138
+ new_key, subkeys = jax.random.split(new_key, 2)
139
+ else:
140
+ new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
141
+ new_samples_omega = data.sample_in_omega_domain(subkeys, sample_size)
142
+ new_samples = jnp.concatenate(
143
+ [new_samples_times, new_samples_omega], axis=1
207
144
  )
208
145
 
209
- if isinstance(data, eqx.Module):
210
- data = eqx.tree_at(lambda m: m.times, data, new_times)
211
- else:
212
- data.times = new_times
213
- ## rearrange probabilities so that the probabilities of the new
214
- ## points are non-zero
215
- new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
216
- # the next work because nt_start is static
217
- new_p_times = data.p_times.at[: data.nt_start].set(new_proba)
218
- if isinstance(data, eqx.Module):
219
- data = eqx.tree_at(
220
- lambda m: m.p_times,
221
- data,
222
- new_p_times,
223
- )
224
- else:
225
- data.p_times = new_p_times
226
-
227
- # the next requires a fori_loop because the range is dynamic
228
- def update_slices(i, p):
229
- return jax.lax.dynamic_update_slice(
230
- p,
231
- 1 / new_proba * jnp.ones((selected_sample_size,)),
232
- ((data.nt_start + i * selected_sample_size),),
233
- )
146
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
234
147
 
235
- new_rar_iter_nb = data.rar_iter_nb + 1
236
- new_p_times = jax.lax.fori_loop(
237
- 0, data.rar_iter_nb, update_slices, data.p_times
148
+ # We can have different types of Loss
149
+ if isinstance(loss, (LossODE, LossPDEStatio, LossPDENonStatio)):
150
+ v_dyn_loss = vmap(
151
+ lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
238
152
  )
239
- if isinstance(data, eqx.Module):
240
- data = eqx.tree_at(
241
- lambda m: (m.rar_iter_nb, m.p_times),
242
- data,
243
- (new_rar_iter_nb, new_p_times),
244
- )
245
- else:
246
- data.rar_iter_nb = new_rar_iter_nb
247
- data.p_times = new_p_times
153
+ dyn_on_s = v_dyn_loss(new_samples)
248
154
 
249
- elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
250
- data, CubicMeshPDENonStatio
251
- ):
252
- if isinstance(data, eqx.Module):
253
- new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
254
- new_omega_samples = data.sample_in_omega_domain(subkeys, sample_size)
255
- data = eqx.tree_at(lambda m: m.key, data, new_key)
155
+ if dyn_on_s.ndim > 1:
156
+ mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
256
157
  else:
257
- new_omega_samples = data.sample_in_omega_domain(sample_size)
158
+ mse_on_s = dyn_on_s**2
159
+ elif isinstance(loss, SystemLossODE, SystemLossPDE):
160
+ mse_on_s = 0
258
161
 
259
- # We can have different types of Loss
260
- if isinstance(loss, LossPDEStatio):
162
+ for i in loss.dynamic_loss_dict.keys():
261
163
  v_dyn_loss = vmap(
262
- lambda x: loss.dynamic_loss.evaluate(
263
- x,
264
- loss.u,
265
- params,
164
+ lambda inputs: loss.dynamic_loss_dict[i].evaluate(
165
+ inputs, loss.u_dict, params
266
166
  ),
267
167
  (0),
268
168
  0,
269
169
  )
270
- dyn_on_s = v_dyn_loss(new_omega_samples)
170
+ dyn_on_s = v_dyn_loss(new_samples)
271
171
  if dyn_on_s.ndim > 1:
272
- mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
172
+ mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
273
173
  else:
274
- mse_on_s = dyn_on_s**2
275
- elif isinstance(loss, SystemLossODE):
276
- mse_on_s = 0
277
- for i in loss.dynamic_loss_dict.keys():
278
- # only the case LossPDEStatio here
279
- v_dyn_loss = vmap(
280
- lambda x: loss.dynamic_loss_dict[i].evaluate(
281
- x, loss.u_dict, params
282
- ),
283
- 0,
284
- 0,
285
- )
286
- dyn_on_s = v_dyn_loss(new_omega_samples)
287
- if dyn_on_s.ndim > 1:
288
- mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
289
- else:
290
- mse_on_s += dyn_on_s**2
291
-
292
- ## Select the m points with higher dynamic loss
293
- higher_residual_idx = jax.lax.dynamic_slice(
294
- jnp.argsort(mse_on_s),
295
- (mse_on_s.shape[0] - selected_sample_size,),
296
- (selected_sample_size,),
174
+ mse_on_s += dyn_on_s**2
175
+
176
+ ## Select the m points with higher dynamic loss
177
+ higher_residual_idx = jax.lax.dynamic_slice(
178
+ jnp.argsort(mse_on_s),
179
+ (mse_on_s.shape[0] - selected_sample_size,),
180
+ (selected_sample_size,),
181
+ )
182
+ higher_residual_points = new_samples[higher_residual_idx]
183
+
184
+ # add the new points
185
+ # start indices of update can be dynamic but not the shape (length)
186
+ # of the slice
187
+ if isinstance(data, DataGeneratorODE):
188
+ new_times = jax.lax.dynamic_update_slice(
189
+ data.times,
190
+ higher_residual_points,
191
+ (data.n_start + data.rar_iter_nb * selected_sample_size,),
297
192
  )
298
- higher_residual_points = new_omega_samples[higher_residual_idx]
299
193
 
300
- ## add the new points in omega
301
- # start indices of update can be dynamic but not the shape (length)
302
- # of the slice
194
+ data = eqx.tree_at(lambda m: m.times, data, new_times)
195
+ elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
196
+ data, CubicMeshPDENonStatio
197
+ ):
303
198
  new_omega = jax.lax.dynamic_update_slice(
304
199
  data.omega,
305
200
  higher_residual_points,
306
201
  (data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
307
202
  )
308
203
 
309
- if isinstance(data, eqx.Module):
310
- data = eqx.tree_at(lambda m: m.omega, data, new_omega)
311
- else:
312
- data.omega = new_omega
313
-
314
- ## rearrange probabilities so that the probabilities of the new
315
- ## points are non-zero
316
- new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
317
- # the next work because n_start is static
318
- new_p_omega = data.p_omega.at[: data.n_start].set(new_proba)
319
- if isinstance(data, eqx.Module):
320
- data = eqx.tree_at(
321
- lambda m: m.p_omega,
322
- data,
323
- new_p_omega,
324
- )
325
- else:
326
- data.p_omega = new_p_omega
327
-
328
- # the next requires a fori_loop because the range is dynamic
329
- def update_slices(i, p):
330
- return jax.lax.dynamic_update_slice(
331
- p,
332
- 1 / new_proba * jnp.ones((selected_sample_size,)),
333
- ((data.n_start + i * selected_sample_size),),
334
- )
335
-
336
- new_rar_iter_nb = data.rar_iter_nb + 1
337
- new_p_omega = jax.lax.fori_loop(
338
- 0, data.rar_iter_nb, update_slices, data.p_omega
339
- )
340
- if isinstance(data, eqx.Module):
341
- data = eqx.tree_at(
342
- lambda m: (m.rar_iter_nb, m.p_omega),
343
- data,
344
- (new_rar_iter_nb, new_p_omega),
345
- )
346
- else:
347
- data.rar_iter_nb = new_rar_iter_nb
348
- data.p_omega = new_p_omega
204
+ data = eqx.tree_at(lambda m: m.omega, data, new_omega)
349
205
 
350
206
  elif isinstance(data, CubicMeshPDENonStatio):
351
- if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
352
- raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
353
-
354
- # NOTE in this case sample_size and selected_sample_size
355
- # are tuples (times, omega) => we unpack them for clarity
356
- selected_sample_size_times, selected_sample_size_omega = (
357
- selected_sample_size
358
- )
359
- sample_size_times, sample_size_omega = sample_size
360
-
361
- if isinstance(data, eqx.Module):
362
- new_key, subkey = jax.random.split(data.key)
363
- new_times_samples = data.sample_in_time_domain(
364
- subkey, sample_size_times
365
- )
366
- new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
367
- new_omega_samples = data.sample_in_omega_domain(
368
- subkeys, sample_size_omega
369
- )
370
- data = eqx.tree_at(lambda m: m.key, data, new_key)
371
- else:
372
- new_times_samples = data.sample_in_time_domain(sample_size_times)
373
- new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
374
-
375
- if not data.cartesian_product:
376
- times = new_times_samples
377
- omega = new_omega_samples
378
- else:
379
- # do cartesian product on new points
380
- omega = jnp.tile(
381
- new_omega_samples, reps=(sample_size_times, 1)
382
- ) # it is tiled
383
- times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
384
- ..., None
385
- ] # it is repeated + add an axis
386
-
387
- if isinstance(loss, LossPDENonStatio):
388
- v_dyn_loss = vmap(
389
- lambda t, x: loss.dynamic_loss.evaluate(t, x, loss.u, params),
390
- (0, 0),
391
- 0,
392
- )
393
- dyn_on_s = v_dyn_loss(times, omega).reshape(
394
- (sample_size_times, sample_size_omega)
395
- )
396
- mse_on_s = dyn_on_s**2
397
- elif isinstance(loss, SystemLossPDE):
398
- dyn_on_s = jnp.zeros((sample_size_times, sample_size_omega))
399
- for i in loss.dynamic_loss_dict.keys():
400
- v_dyn_loss = vmap(
401
- lambda t, x: loss.dynamic_loss_dict[i].evaluate(
402
- t, x, loss.u_dict, params
403
- ),
404
- (0, 0),
405
- 0,
406
- )
407
- dyn_on_s += v_dyn_loss(times, omega).reshape(
408
- (sample_size_times, sample_size_omega)
409
- )
410
-
411
- mse_on_s = dyn_on_s**2
412
- # -- Select the m points with highest average residuals on time and
413
- # -- space (times in rows / omega in columns)
414
- # mean_times = mse_on_s.mean(axis=1)
415
- # mean_omega = mse_on_s.mean(axis=0)
416
- # times_idx = jax.lax.dynamic_slice(
417
- # jnp.argsort(mean_times),
418
- # (mse_on_s.shape[0] - selected_sample_size_times,),
419
- # (selected_sample_size_times,),
420
- # )
421
- # omega_idx = jax.lax.dynamic_slice(
422
- # jnp.argsort(mean_omega),
423
- # (mse_on_s.shape[1] - selected_sample_size_omega,),
424
- # (selected_sample_size_omega,),
425
- # )
426
-
427
- # -- Select the m worst points (t, x) with highest residuals
428
- n_select = max(selected_sample_size_times, selected_sample_size_omega)
429
- _, idx = jax.lax.top_k(mse_on_s.flatten(), k=n_select)
430
- arr_idx = jnp.unravel_index(idx, mse_on_s.shape)
431
- times_idx = arr_idx[0][:selected_sample_size_times]
432
- omega_idx = arr_idx[1][:selected_sample_size_omega]
433
-
434
- higher_residual_points_times = new_times_samples[times_idx]
435
- higher_residual_points_omega = new_omega_samples[omega_idx]
436
-
437
- ## add the new points in times
438
- # start indices of update can be dynamic but not the shape (length)
439
- # of the slice
440
- new_times = jax.lax.dynamic_update_slice(
441
- data.times,
442
- higher_residual_points_times,
443
- (
444
- data.n_start
445
- + data.rar_iter_nb # NOTE typo here nt_start ?
446
- * selected_sample_size_times,
447
- ),
448
- )
449
-
450
- if isinstance(data, eqx.Module):
451
- data = eqx.tree_at(lambda m: m.times, data, new_times)
452
- else:
453
- data.times = new_times
454
-
455
- ## add the new points in omega
456
- new_omega = jax.lax.dynamic_update_slice(
457
- data.omega,
458
- higher_residual_points_omega,
459
- (
460
- data.n_start + data.rar_iter_nb * selected_sample_size_omega,
461
- data.dim,
462
- ),
207
+ new_domain = jax.lax.dynamic_update_slice(
208
+ data.domain,
209
+ higher_residual_points,
210
+ (data.n_start + data.rar_iter_nb * selected_sample_size, 1 + data.dim),
463
211
  )
464
212
 
465
- if isinstance(data, eqx.Module):
466
- data = eqx.tree_at(lambda m: m.omega, data, new_omega)
467
- else:
468
- data.omega = new_omega
213
+ data = eqx.tree_at(lambda m: m.domain, data, new_domain)
214
+
215
+ ## rearrange probabilities so that the probabilities of the new
216
+ ## points are non-zero
217
+ new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
218
+ # the next work because nt_start is static
219
+ new_p = data.p.at[: data.n_start].set(new_proba)
220
+ data = eqx.tree_at(
221
+ lambda m: m.p,
222
+ data,
223
+ new_p,
224
+ )
469
225
 
470
- ## rearrange probabilities so that the probabilities of the new
471
- ## points are non-zero
472
- new_p_times = 1 / (
473
- data.nt_start + data.rar_iter_nb * selected_sample_size_times
226
+ # the next requires a fori_loop because the range is dynamic
227
+ def update_slices(i, p):
228
+ return jax.lax.dynamic_update_slice(
229
+ p,
230
+ 1 / new_proba * jnp.ones((selected_sample_size,)),
231
+ ((data.n_start + i * selected_sample_size),),
474
232
  )
475
- # the next work because nt_start is static
476
- if isinstance(data, eqx.Module):
477
- data = eqx.tree_at(
478
- lambda m: m.p_times,
479
- data,
480
- data.p_times.at[: data.nt_start].set(new_p_times),
481
- )
482
- else:
483
- data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
484
233
 
485
- # same for p_omega (work because n_start is static)
486
- new_p_omega = 1 / (
487
- data.n_start + data.rar_iter_nb * selected_sample_size_omega
488
- )
489
- if isinstance(data, eqx.Module):
490
- data = eqx.tree_at(
491
- lambda m: m.p_omega,
492
- data,
493
- data.p_omega.at[: data.n_start].set(new_p_omega),
494
- )
495
- else:
496
- data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
497
-
498
- # the part of data.p_* after n_start requires a fori_loop because
499
- # the range is dynamic
500
- def create_update_slices(new_val, selected_sample_size):
501
- def update_slices(i, p):
502
- new_p = jax.lax.dynamic_update_slice(
503
- p,
504
- new_val * jnp.ones((selected_sample_size,)),
505
- ((data.n_start + i * selected_sample_size),),
506
- )
507
- return new_p
508
-
509
- return update_slices
510
-
511
- new_rar_iter_nb = data.rar_iter_nb + 1
512
-
513
- ## update rest of p_times
514
- update_slices_times = create_update_slices(
515
- new_p_times, selected_sample_size_times
516
- )
517
- new_p_times = jax.lax.fori_loop(
518
- 0,
519
- data.rar_iter_nb,
520
- update_slices_times,
521
- data.p_times,
522
- )
523
- ## update rest of p_omega
524
- update_slices_omega = create_update_slices(
525
- new_p_omega, selected_sample_size_omega
526
- )
527
- new_p_omega = jax.lax.fori_loop(
528
- 0,
529
- data.rar_iter_nb,
530
- update_slices_omega,
531
- data.p_omega,
532
- )
533
- if isinstance(data, eqx.Module):
534
- data = eqx.tree_at(
535
- lambda m: (m.rar_iter_nb, m.p_omega, m.p_times),
536
- data,
537
- (new_rar_iter_nb, new_p_omega, new_p_times),
538
- )
539
- else:
540
- data.rar_iter_nb = new_rar_iter_nb
541
- data.p_times = new_p_times
542
- data.p_omega = new_p_omega
234
+ new_rar_iter_nb = data.rar_iter_nb + 1
235
+ new_p = jax.lax.fori_loop(0, new_rar_iter_nb, update_slices, data.p)
236
+ data = eqx.tree_at(
237
+ lambda m: (m.rar_iter_nb, m.p),
238
+ data,
239
+ (new_rar_iter_nb, new_p),
240
+ )
543
241
 
544
242
  # update RAR parameters for all cases
545
- if isinstance(data, eqx.Module):
546
- data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
547
- else:
548
- data.rar_iter_from_last_sampling = 0
243
+ data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
549
244
 
550
245
  # NOTE must return data to be correctly updated because we cannot
551
246
  # have side effects in this function that will be jitted