jinns 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +292 -309
  10. jinns/loss/_LossPDE.py +625 -1010
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +95 -44
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/METADATA +4 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.9.0.dist-info/RECORD +0 -36
  41. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  """
2
- Utility functions for plotting
2
+ Utility functions for plotting in 1D and 2D, with and without time.
3
3
  """
4
4
 
5
5
  from functools import partial
@@ -8,48 +8,55 @@ import matplotlib.pyplot as plt
8
8
  import jax.numpy as jnp
9
9
  from jax import vmap
10
10
  from mpl_toolkits.axes_grid1 import ImageGrid
11
+ from typing import Callable, List
12
+ from jaxtyping import Array, Float, Bool
11
13
 
12
14
 
13
15
  def plot2d(
14
- fun,
15
- xy_data,
16
- times=None,
17
- Tmax=1,
18
- title="",
19
- figsize=(7, 7),
20
- cmap="inferno",
21
- spinn=False,
22
- vmin_vmax=None,
23
- ax_for_plot=None,
16
+ fun: Callable,
17
+ xy_data: tuple[Float[Array, "nx"], Float[Array, "ny"]],
18
+ times: Float[Array, "nt"] | List[float] | None = None,
19
+ Tmax: float = 1,
20
+ title: str = "",
21
+ figsize: tuple = (7, 7),
22
+ cmap: str = "inferno",
23
+ spinn: bool = False,
24
+ vmin_vmax: tuple[float, float] | None = None,
25
+ ax_for_plot: plt.Axes | None = None,
24
26
  ):
25
27
  r"""Generic function for plotting functions over rectangular 2-D domains
26
- :math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
27
- non-stationnary case :math:`u(t, x)`.
28
+ $\Omega$. It handles both the
28
29
 
29
- When in the non-stationnary case, the `times` argument gives the time
30
- slices :math:`t_i` at which to plot :math:`u(t_i, x)`.
30
+ 1. the stationary case $u(x)$
31
+ 2. the non-stationnary case $u(t, x)$
32
+
33
+ In the non-stationnary case, the `times` argument gives the time
34
+ slices $t_i$ at which to plot $u(t_i, x)$.
31
35
 
32
36
 
33
37
  Parameters
34
38
  ----------
35
- fun : _type_
36
- _description_
37
- xy_data : _type_
38
- _description_
39
- times : _type_, optional
40
- _description_, by default None
41
- Tmax : float, only in non-stationary cases
42
- Useful if you used time re-scaling in the differential equation, by
43
- default 1
44
- title : str, optional
39
+ fun :
40
+ the function $u$ to plot on the meshgrid, and eventually the time
41
+ slices. It's suppose to have signature `u(x)` in the stationnary case,, and `u(t, x)` in the non-stationnary case. Use `partial` or `lambda to freeze / reorder any other arguments.
42
+ xy_data :
43
+ A list of 2 `jnp.Array` providing grid values for meshgrid creation
44
+ times :
45
+ list or Array of time slices where to plot the function. Use Tmax if
46
+ you trained with time-rescaling.
47
+ Tmax :
48
+ Useful if you used time rescaling in the differential equation for training, default to 1 (no rescaling).
49
+ title :
45
50
  plot title, by default ""
46
- figsize : tuple, optional
47
- _description_, by default (7, 7)
48
- cmap : str, optional
49
- _description_, by default "inferno"
50
- vmin_vmax : tuple, optional
51
+ figsize :
52
+ By default (7, 7)
53
+ cmap :
54
+ the matplotlib color map used in the ImageGrid.
55
+ vmin_vmax :
51
56
  The colorbar minimum and maximum value. Defaults None.
52
- ax_for_plot : Matplotlib axis, optional
57
+ spinn :
58
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
59
+ ax_for_plot :
53
60
  If None, jinns triggers the plotting. Otherwise this argument
54
61
  corresponds to the axis which will host the plot. Default is None.
55
62
  NOTE: that this argument will have an effect only if times is None.
@@ -57,7 +64,7 @@ def plot2d(
57
64
  Raises
58
65
  ------
59
66
  ValueError
60
- if xy_data is not a list of type 2
67
+ if xy_data is not a list of length 2
61
68
  """
62
69
 
63
70
  # if not isinstance(xy_data, jnp.ndarray) and not xy_data.shape[-1] == 2:
@@ -179,13 +186,13 @@ def plot2d(
179
186
 
180
187
  def _plot_2D_statio(
181
188
  v_fun,
182
- mesh,
183
- plot=True,
184
- colorbar=True,
185
- cmap="inferno",
186
- figsize=(7, 7),
187
- spinn=False,
188
- vmin_vmax=None,
189
+ mesh: Float[Array, "nx*ny nx*ny"],
190
+ plot: Bool = True,
191
+ colorbar: Bool = True,
192
+ cmap: str = "inferno",
193
+ figsize: tuple[int, int] = (7, 7),
194
+ spinn: Bool = False,
195
+ vmin_vmax: tuple[float, float] = None,
189
196
  ):
190
197
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
191
198
 
@@ -200,6 +207,12 @@ def _plot_2D_statio(
200
207
  either show or return the plot, by default True
201
208
  colorbar : bool, optional
202
209
  add a colorbar, by default True
210
+ cmap :
211
+ the matplotlib color map used in the ImageGrid.
212
+ figsize :
213
+ By default (7, 7)
214
+ spinn :
215
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
203
216
  vmin_vmax: tuple, optional
204
217
  The colorbar minimum and maximum value. Defaults None.
205
218
 
@@ -238,33 +251,37 @@ def _plot_2D_statio(
238
251
 
239
252
 
240
253
  def plot1d_slice(
241
- fun,
242
- xdata,
243
- time_slices=jnp.array([0]),
244
- Tmax=1,
245
- title="",
246
- figsize=(10, 10),
247
- spinn=False,
254
+ fun: Callable[[float, float], float],
255
+ xdata: Float[Array, "nx"],
256
+ time_slices: Float[Array, "nt"] | None = None,
257
+ Tmax: float = 1.0,
258
+ title: str = "",
259
+ figsize: tuple[int, int] = (10, 10),
260
+ spinn: Bool = False,
248
261
  ):
249
262
  """Function for plotting time slices of a function :math:`f(t_i, x)` where
250
- `t` is time (1-D) and x is 1-D
263
+ `t_i` is time (1-D) and x is 1-D
251
264
 
252
265
  Parameters
253
266
  ----------
254
- fun : callable with two arguments `t` and `x`
267
+ fun
255
268
  f(t, x)
256
- xdata : jnp.array
269
+ xdata
257
270
  the discretization of space
258
- time_slices : list, optional
259
- the time slices :math:`t_i` at which to plot, by default [0]
260
- Tmax : int, optional
271
+ time_slices
272
+ the time slices :math:`t_i` at which to plot.
273
+ Tmax
261
274
  Useful if you used time re-scaling in the differential equation, by
262
275
  default 1
263
- title : str, optional
276
+ title
264
277
  title of the plot, by default ""
265
- figsize : tuple, optional
278
+ spinn :
279
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
280
+ figsize
266
281
  size of the figure, by default (10, 10)
267
282
  """
283
+ if time_slices is None:
284
+ time_slices = jnp.array([0])
268
285
  plt.figure(figsize=figsize)
269
286
  for t in time_slices:
270
287
  if not spinn:
@@ -284,16 +301,16 @@ def plot1d_slice(
284
301
 
285
302
 
286
303
  def plot1d_image(
287
- fun,
288
- xdata,
289
- times,
290
- Tmax=1,
291
- title="",
292
- figsize=(10, 10),
293
- colorbar=True,
294
- cmap="inferno",
295
- spinn=False,
296
- vmin_vmax=None,
304
+ fun: Callable[[float, float], float],
305
+ xdata: Float[Array, "nx"],
306
+ times: Float[Array, "nt"],
307
+ Tmax: float = 1.0,
308
+ title: str = "",
309
+ figsize: tuple[int, int] = (10, 10),
310
+ colorbar: Bool = True,
311
+ cmap: str = "inferno",
312
+ spinn: Bool = False,
313
+ vmin_vmax: tuple[float, float] = None,
297
314
  ):
298
315
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
299
316
  `t` is time (1-D) and x is space (1-D).
@@ -302,19 +319,25 @@ def plot1d_image(
302
319
 
303
320
  Parameters
304
321
  ----------
305
- fun : callable with two arguments t and x
306
- the function to plot
307
- xdata : jnp.array
322
+ fun :
323
+ callable with two arguments t and x the function to plot
324
+ xdata :
308
325
  the discretization of space
309
- times : jnp.array
326
+ times :
310
327
  the discretization of time
311
- Tmax : int, optional
312
- _description_, by default 1
313
- title : str, optional
314
- , by default ""
315
- figsize : tuple, optional
316
- , by default (10, 10)
317
- vmin_vmax: tuple
328
+ Tmax :
329
+ by default 1
330
+ title :
331
+ by default ""
332
+ figsize :
333
+ by default (10, 10)
334
+ colorbar :
335
+ Whether to add a colobar
336
+ cmap :
337
+ the matplotlib color map used in the ImageGrid.
338
+ spinn :
339
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
340
+ vmin_vmax:
318
341
  The colorbar minimum and maximum value. Defaults None.
319
342
  """
320
343
 
jinns/solver/_rar.py CHANGED
@@ -1,21 +1,32 @@
1
+ from __future__ import (
2
+ annotations,
3
+ ) # https://docs.python.org/3/library/typing.html#constant
4
+
5
+ from typing import TYPE_CHECKING, Callable
6
+ from functools import partial
1
7
  import jax
2
8
  from jax import vmap
3
9
  import jax.numpy as jnp
10
+ import equinox as eqx
11
+ from jaxtyping import Int, Bool
12
+
13
+ from jinns.data._Batchs import *
14
+ from jinns.loss._LossODE import LossODE, SystemLossODE
15
+ from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
4
16
  from jinns.data._DataGenerators import (
5
17
  DataGeneratorODE,
6
18
  CubicMeshPDEStatio,
7
19
  CubicMeshPDENonStatio,
8
20
  )
9
- from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
10
- from jinns.loss._LossODE import LossODE, SystemLossODE
11
- from jinns.loss._DynamicLossAbstract import PDEStatio
12
-
13
- from functools import partial
14
21
  from jinns.utils._hyperpinn import HYPERPINN
15
22
  from jinns.utils._spinn import SPINN
16
23
 
17
24
 
18
- def _proceed_to_rar(data, i):
25
+ if TYPE_CHECKING:
26
+ from jinns.utils._types import *
27
+
28
+
29
+ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
19
30
  """Utilility function with various check to ensure we can proceed with the rar_step.
20
31
  Return True if yes, and False otherwise"""
21
32
 
@@ -30,13 +41,13 @@ def _proceed_to_rar(data, i):
30
41
  # Memory allocation checks (depends on the type of DataGenerator)
31
42
  # check if we still have room to append new collocation points in the
32
43
  # allocated jnp.array (can concern `data.p_times` or `p_omega`)
33
- if isinstance(data, DataGeneratorODE) or isinstance(data, CubicMeshPDENonStatio):
44
+ if isinstance(data, (DataGeneratorODE, CubicMeshPDENonStatio)):
34
45
  check_list.append(
35
46
  data.rar_parameters["selected_sample_size_times"]
36
47
  <= jnp.count_nonzero(data.p_times == 0),
37
48
  )
38
49
 
39
- if isinstance(data, CubicMeshPDEStatio) or isinstance(data, CubicMeshPDENonStatio):
50
+ if isinstance(data, (CubicMeshPDEStatio, CubicMeshPDENonStatio)):
40
51
  # for now the above check are redundants but there may be a time when
41
52
  # we drop inheritence
42
53
  check_list.append(
@@ -49,7 +60,14 @@ def _proceed_to_rar(data, i):
49
60
 
50
61
 
51
62
  @partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
52
- def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
63
+ def trigger_rar(
64
+ i: Int,
65
+ loss: AnyLoss,
66
+ params: AnyParams,
67
+ data: AnyDataGenerator,
68
+ _rar_step_true: Callable[[rar_operands], AnyDataGenerator],
69
+ _rar_step_false: Callable[[rar_operands], AnyDataGenerator],
70
+ ) -> tuple[AnyLoss, AnyParams, AnyDataGenerator]:
53
71
 
54
72
  if data.rar_parameters is None:
55
73
  # do nothing.
@@ -65,7 +83,13 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
65
83
  return loss, params, data
66
84
 
67
85
 
68
- def init_rar(data):
86
+ def init_rar(
87
+ data: AnyDataGenerator,
88
+ ) -> tuple[
89
+ AnyDataGenerator,
90
+ Callable[[rar_operands], AnyDataGenerator],
91
+ Callable[[rar_operands], AnyDataGenerator],
92
+ ]:
69
93
  """
70
94
  Separated from the main rar, because the initialization to get _true and
71
95
  _false cannot be jit-ted.
@@ -100,13 +124,21 @@ def init_rar(data):
100
124
  data.rar_parameters["sample_size_omega"],
101
125
  data.rar_parameters["selected_sample_size_omega"],
102
126
  )
127
+ else:
128
+ raise ValueError(f"Wrong type for data got {type(data)}")
103
129
 
104
- data.rar_parameters["iter_from_last_sampling"] = 0
130
+ if isinstance(data, eqx.Module):
131
+ data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
132
+ else:
133
+ data.rar_iter_from_last_sampling = 0
105
134
 
106
135
  return data, _rar_step_true, _rar_step_false
107
136
 
108
137
 
109
- def _rar_step_init(sample_size, selected_sample_size):
138
+ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
139
+ Callable[[rar_operands], AnyDataGenerator],
140
+ Callable[[rar_operands], AnyDataGenerator],
141
+ ]:
110
142
  """
111
143
  This is a wrapper because the sampling size and
112
144
  selected_sample_size, must be treated as static
@@ -116,11 +148,17 @@ def _rar_step_init(sample_size, selected_sample_size):
116
148
  This is a kind of manual declaration of static argnums
117
149
  """
118
150
 
119
- def rar_step_true(operands):
151
+ def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
120
152
  loss, params, data, i = operands
121
153
 
122
154
  if isinstance(data, DataGeneratorODE):
123
- new_omega_samples = data.sample_in_time_domain(sample_size)
155
+
156
+ if isinstance(data, eqx.Module):
157
+ new_key, subkey = jax.random.split(data.key)
158
+ new_omega_samples = data.sample_in_time_domain(subkey, sample_size)
159
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
160
+ else:
161
+ new_omega_samples = data.sample_in_time_domain(sample_size)
124
162
 
125
163
  # We can have different types of Loss
126
164
  if isinstance(loss, LossODE):
@@ -162,17 +200,29 @@ def _rar_step_init(sample_size, selected_sample_size):
162
200
  ## add the new points in times
163
201
  # start indices of update can be dynamic but the the shape (length)
164
202
  # of the slice
165
- data.times = jax.lax.dynamic_update_slice(
203
+ new_times = jax.lax.dynamic_update_slice(
166
204
  data.times,
167
205
  higher_residual_points,
168
206
  (data.nt_start + data.rar_iter_nb * selected_sample_size,),
169
207
  )
170
208
 
209
+ if isinstance(data, eqx.Module):
210
+ data = eqx.tree_at(lambda m: m.times, data, new_times)
211
+ else:
212
+ data.times = new_times
171
213
  ## rearrange probabilities so that the probabilities of the new
172
214
  ## points are non-zero
173
215
  new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
174
216
  # the next work because nt_start is static
175
- data.p_times = data.p_times.at[: data.nt_start].set(new_proba)
217
+ new_p_times = data.p_times.at[: data.nt_start].set(new_proba)
218
+ if isinstance(data, eqx.Module):
219
+ data = eqx.tree_at(
220
+ lambda m: m.p_times,
221
+ data,
222
+ new_p_times,
223
+ )
224
+ else:
225
+ data.p_times = new_p_times
176
226
 
177
227
  # the next requires a fori_loop because the range is dynamic
178
228
  def update_slices(i, p):
@@ -182,16 +232,29 @@ def _rar_step_init(sample_size, selected_sample_size):
182
232
  ((data.nt_start + i * selected_sample_size),),
183
233
  )
184
234
 
185
- data.rar_iter_nb += 1
186
-
187
- data.p_times = jax.lax.fori_loop(
235
+ new_rar_iter_nb = data.rar_iter_nb + 1
236
+ new_p_times = jax.lax.fori_loop(
188
237
  0, data.rar_iter_nb, update_slices, data.p_times
189
238
  )
239
+ if isinstance(data, eqx.Module):
240
+ data = eqx.tree_at(
241
+ lambda m: (m.rar_iter_nb, m.p_times),
242
+ data,
243
+ (new_rar_iter_nb, new_p_times),
244
+ )
245
+ else:
246
+ data.rar_iter_nb = new_rar_iter_nb
247
+ data.p_times = new_p_times
190
248
 
191
249
  elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
192
250
  data, CubicMeshPDENonStatio
193
251
  ):
194
- new_omega_samples = data.sample_in_omega_domain(sample_size)
252
+ if isinstance(data, eqx.Module):
253
+ new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
254
+ new_omega_samples = data.sample_in_omega_domain(subkeys, sample_size)
255
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
256
+ else:
257
+ new_omega_samples = data.sample_in_omega_domain(sample_size)
195
258
 
196
259
  # We can have different types of Loss
197
260
  if isinstance(loss, LossPDEStatio):
@@ -209,7 +272,7 @@ def _rar_step_init(sample_size, selected_sample_size):
209
272
  mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
210
273
  else:
211
274
  mse_on_s = dyn_on_s**2
212
- elif isinstance(loss, SystemLossPDE):
275
+ elif isinstance(loss, SystemLossODE):
213
276
  mse_on_s = 0
214
277
  for i in loss.dynamic_loss_dict.keys():
215
278
  # only the case LossPDEStatio here
@@ -237,17 +300,30 @@ def _rar_step_init(sample_size, selected_sample_size):
237
300
  ## add the new points in omega
238
301
  # start indices of update can be dynamic but not the shape (length)
239
302
  # of the slice
240
- data.omega = jax.lax.dynamic_update_slice(
303
+ new_omega = jax.lax.dynamic_update_slice(
241
304
  data.omega,
242
305
  higher_residual_points,
243
306
  (data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
244
307
  )
245
308
 
309
+ if isinstance(data, eqx.Module):
310
+ data = eqx.tree_at(lambda m: m.omega, data, new_omega)
311
+ else:
312
+ data.omega = new_omega
313
+
246
314
  ## rearrange probabilities so that the probabilities of the new
247
315
  ## points are non-zero
248
316
  new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
249
317
  # the next work because n_start is static
250
- data.p_omega = data.p_omega.at[: data.n_start].set(new_proba)
318
+ new_p_omega = data.p_omega.at[: data.n_start].set(new_proba)
319
+ if isinstance(data, eqx.Module):
320
+ data = eqx.tree_at(
321
+ lambda m: m.p_omega,
322
+ data,
323
+ new_p_omega,
324
+ )
325
+ else:
326
+ data.p_omega = new_p_omega
251
327
 
252
328
  # the next requires a fori_loop because the range is dynamic
253
329
  def update_slices(i, p):
@@ -257,11 +333,19 @@ def _rar_step_init(sample_size, selected_sample_size):
257
333
  ((data.n_start + i * selected_sample_size),),
258
334
  )
259
335
 
260
- data.rar_iter_nb += 1
261
-
262
- data.p_omega = jax.lax.fori_loop(
336
+ new_rar_iter_nb = data.rar_iter_nb + 1
337
+ new_p_omega = jax.lax.fori_loop(
263
338
  0, data.rar_iter_nb, update_slices, data.p_omega
264
339
  )
340
+ if isinstance(data, eqx.Module):
341
+ data = eqx.tree_at(
342
+ lambda m: (m.rar_iter_nb, m.p_omega),
343
+ data,
344
+ (new_rar_iter_nb, new_p_omega),
345
+ )
346
+ else:
347
+ data.rar_iter_nb = new_rar_iter_nb
348
+ data.p_omega = new_p_omega
265
349
 
266
350
  elif isinstance(data, CubicMeshPDENonStatio):
267
351
  if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
@@ -274,8 +358,19 @@ def _rar_step_init(sample_size, selected_sample_size):
274
358
  )
275
359
  sample_size_times, sample_size_omega = sample_size
276
360
 
277
- new_times_samples = data.sample_in_time_domain(sample_size_times)
278
- new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
361
+ if isinstance(data, eqx.Module):
362
+ new_key, subkey = jax.random.split(data.key)
363
+ new_times_samples = data.sample_in_time_domain(
364
+ subkey, sample_size_times
365
+ )
366
+ new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
367
+ new_omega_samples = data.sample_in_omega_domain(
368
+ subkeys, sample_size_omega
369
+ )
370
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
371
+ else:
372
+ new_times_samples = data.sample_in_time_domain(sample_size_times)
373
+ new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
279
374
 
280
375
  if not data.cartesian_product:
281
376
  times = new_times_samples
@@ -342,14 +437,23 @@ def _rar_step_init(sample_size, selected_sample_size):
342
437
  ## add the new points in times
343
438
  # start indices of update can be dynamic but not the shape (length)
344
439
  # of the slice
345
- data.times = jax.lax.dynamic_update_slice(
440
+ new_times = jax.lax.dynamic_update_slice(
346
441
  data.times,
347
442
  higher_residual_points_times,
348
- (data.n_start + data.rar_iter_nb * selected_sample_size_times,),
443
+ (
444
+ data.n_start
445
+ + data.rar_iter_nb # NOTE typo here nt_start ?
446
+ * selected_sample_size_times,
447
+ ),
349
448
  )
350
449
 
450
+ if isinstance(data, eqx.Module):
451
+ data = eqx.tree_at(lambda m: m.times, data, new_times)
452
+ else:
453
+ data.times = new_times
454
+
351
455
  ## add the new points in omega
352
- data.omega = jax.lax.dynamic_update_slice(
456
+ new_omega = jax.lax.dynamic_update_slice(
353
457
  data.omega,
354
458
  higher_residual_points_omega,
355
459
  (
@@ -358,19 +462,38 @@ def _rar_step_init(sample_size, selected_sample_size):
358
462
  ),
359
463
  )
360
464
 
465
+ if isinstance(data, eqx.Module):
466
+ data = eqx.tree_at(lambda m: m.omega, data, new_omega)
467
+ else:
468
+ data.omega = new_omega
469
+
361
470
  ## rearrange probabilities so that the probabilities of the new
362
471
  ## points are non-zero
363
472
  new_p_times = 1 / (
364
473
  data.nt_start + data.rar_iter_nb * selected_sample_size_times
365
474
  )
366
475
  # the next work because nt_start is static
367
- data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
476
+ if isinstance(data, eqx.Module):
477
+ data = eqx.tree_at(
478
+ lambda m: m.p_times,
479
+ data,
480
+ data.p_times.at[: data.nt_start].set(new_p_times),
481
+ )
482
+ else:
483
+ data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
368
484
 
369
485
  # same for p_omega (work because n_start is static)
370
486
  new_p_omega = 1 / (
371
487
  data.n_start + data.rar_iter_nb * selected_sample_size_omega
372
488
  )
373
- data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
489
+ if isinstance(data, eqx.Module):
490
+ data = eqx.tree_at(
491
+ lambda m: m.p_omega,
492
+ data,
493
+ data.p_omega.at[: data.n_start].set(new_p_omega),
494
+ )
495
+ else:
496
+ data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
374
497
 
375
498
  # the part of data.p_* after n_start requires a fori_loop because
376
499
  # the range is dynamic
@@ -385,13 +508,13 @@ def _rar_step_init(sample_size, selected_sample_size):
385
508
 
386
509
  return update_slices
387
510
 
388
- data.rar_iter_nb += 1
511
+ new_rar_iter_nb = data.rar_iter_nb + 1
389
512
 
390
513
  ## update rest of p_times
391
514
  update_slices_times = create_update_slices(
392
515
  new_p_times, selected_sample_size_times
393
516
  )
394
- data.p_times = jax.lax.fori_loop(
517
+ new_p_times = jax.lax.fori_loop(
395
518
  0,
396
519
  data.rar_iter_nb,
397
520
  update_slices_times,
@@ -401,21 +524,34 @@ def _rar_step_init(sample_size, selected_sample_size):
401
524
  update_slices_omega = create_update_slices(
402
525
  new_p_omega, selected_sample_size_omega
403
526
  )
404
- data.p_omega = jax.lax.fori_loop(
527
+ new_p_omega = jax.lax.fori_loop(
405
528
  0,
406
529
  data.rar_iter_nb,
407
530
  update_slices_omega,
408
531
  data.p_omega,
409
532
  )
533
+ if isinstance(data, eqx.Module):
534
+ data = eqx.tree_at(
535
+ lambda m: (m.rar_iter_nb, m.p_omega, m.p_times),
536
+ data,
537
+ (new_rar_iter_nb, new_p_omega, new_p_times),
538
+ )
539
+ else:
540
+ data.rar_iter_nb = new_rar_iter_nb
541
+ data.p_times = new_p_times
542
+ data.p_omega = new_p_omega
410
543
 
411
544
  # update RAR parameters for all cases
412
- data.rar_iter_from_last_sampling = 0
545
+ if isinstance(data, eqx.Module):
546
+ data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
547
+ else:
548
+ data.rar_iter_from_last_sampling = 0
413
549
 
414
550
  # NOTE must return data to be correctly updated because we cannot
415
551
  # have side effects in this function that will be jitted
416
552
  return data
417
553
 
418
- def rar_step_false(operands):
554
+ def rar_step_false(operands: rar_operands) -> AnyDataGenerator:
419
555
  _, _, data, i = operands
420
556
 
421
557
  # Add 1 only if we are after the burn in period
@@ -425,7 +561,15 @@ def _rar_step_init(sample_size, selected_sample_size):
425
561
  lambda: 1,
426
562
  )
427
563
 
428
- data.rar_iter_from_last_sampling += increment
564
+ new_rar_iter_from_last_sampling = data.rar_iter_from_last_sampling + increment
565
+ if isinstance(data, eqx.Module):
566
+ data = eqx.tree_at(
567
+ lambda m: m.rar_iter_from_last_sampling,
568
+ data,
569
+ new_rar_iter_from_last_sampling,
570
+ )
571
+ else:
572
+ data.rar_iter_from_last_sampling = new_rar_iter_from_last_sampling
429
573
  return data
430
574
 
431
575
  return rar_step_true, rar_step_false