jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +7 -7
  2. jinns/data/_AbstractDataGenerator.py +1 -1
  3. jinns/data/_Batchs.py +47 -13
  4. jinns/data/_CubicMeshPDENonStatio.py +203 -54
  5. jinns/data/_CubicMeshPDEStatio.py +190 -54
  6. jinns/data/_DataGeneratorODE.py +48 -22
  7. jinns/data/_DataGeneratorObservations.py +75 -32
  8. jinns/data/_DataGeneratorParameter.py +152 -101
  9. jinns/data/__init__.py +2 -1
  10. jinns/data/_utils.py +22 -10
  11. jinns/loss/_DynamicLoss.py +21 -20
  12. jinns/loss/_DynamicLossAbstract.py +51 -36
  13. jinns/loss/_LossODE.py +210 -191
  14. jinns/loss/_LossPDE.py +441 -368
  15. jinns/loss/_abstract_loss.py +60 -25
  16. jinns/loss/_loss_components.py +4 -25
  17. jinns/loss/_loss_utils.py +23 -0
  18. jinns/loss/_loss_weight_updates.py +6 -7
  19. jinns/loss/_loss_weights.py +34 -35
  20. jinns/nn/_abstract_pinn.py +0 -2
  21. jinns/nn/_hyperpinn.py +34 -23
  22. jinns/nn/_mlp.py +5 -4
  23. jinns/nn/_pinn.py +1 -16
  24. jinns/nn/_ppinn.py +5 -16
  25. jinns/nn/_save_load.py +11 -4
  26. jinns/nn/_spinn.py +1 -16
  27. jinns/nn/_spinn_mlp.py +5 -5
  28. jinns/nn/_utils.py +33 -38
  29. jinns/parameters/__init__.py +3 -1
  30. jinns/parameters/_derivative_keys.py +99 -41
  31. jinns/parameters/_params.py +58 -25
  32. jinns/solver/_solve.py +14 -8
  33. jinns/utils/_DictToModuleMeta.py +66 -0
  34. jinns/utils/_ItemizableModule.py +19 -0
  35. jinns/utils/__init__.py +2 -1
  36. jinns/utils/_types.py +25 -15
  37. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  38. jinns-1.6.0.dist-info/RECORD +57 -0
  39. jinns-1.5.0.dist-info/RECORD +0 -55
  40. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  41. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  42. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
@@ -8,8 +8,11 @@ from __future__ import (
8
8
  import warnings
9
9
  import equinox as eqx
10
10
  import jax
11
+ import numpy as np
11
12
  import jax.numpy as jnp
12
- from jaxtyping import Key, Array, Float
13
+ from scipy.stats import qmc
14
+ from jaxtyping import PRNGKeyArray, Array, Float
15
+ from typing import Literal
13
16
  from jinns.data._Batchs import PDEStatioBatch
14
17
  from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
15
18
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
@@ -22,7 +25,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
22
25
 
23
26
  Parameters
24
27
  ----------
25
- key : Key
28
+ key : PRNGKeyArray
26
29
  Jax random key to sample new time points and to shuffle batches
27
30
  n : int
28
31
  The number of total $\Omega$ points that will be divided in
@@ -50,11 +53,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
50
53
  A tuple of maximum values of the domain along each dimension. For a sampling
51
54
  in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
52
55
  x_{n,max})$
53
- method : str, default="uniform"
54
- Either `grid` or `uniform`, default is `uniform`.
56
+ method : Literal["grid", "uniform", "sobol", "halton"], default="uniform"
57
+ Either "grid", "uniform", "sobol" or "halton", default is `uniform`.
55
58
  The method that generates the `nt` time points. `grid` means
56
59
  regularly spaced points over the domain. `uniform` means uniformly
57
- sampled points over the domain
60
+ sampled points over the domain.
61
+ **Note** that Sobol and Halton approaches use scipy modules and will not
62
+ be JIT compatible.
58
63
  rar_parameters : dict[str, int], default=None
59
64
  Defaults to None: do not use Residual Adaptative Resampling.
60
65
  Otherwise a dictionary with keys
@@ -75,32 +80,28 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
75
80
  then corresponds to the initial number of points we train the PINN on.
76
81
  """
77
82
 
78
- # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
79
- key: Key = eqx.field(kw_only=True)
80
- n: int = eqx.field(kw_only=True, static=True)
83
+ key: PRNGKeyArray
84
+ n: int = eqx.field(static=True)
81
85
  nb: int | None = eqx.field(kw_only=True, static=True, default=None)
82
86
  omega_batch_size: int | None = eqx.field(
83
- kw_only=True,
84
87
  static=True,
85
- default=None, # can be None as
88
+ # can be None as
86
89
  # CubicMeshPDENonStatio inherits but also if omega_batch_size=n
87
90
  ) # static cause used as a
88
91
  # shape in jax.lax.dynamic_slice
89
92
  omega_border_batch_size: int | None = eqx.field(
90
- kw_only=True, static=True, default=None
93
+ static=True,
91
94
  ) # static cause used as a
92
95
  # shape in jax.lax.dynamic_slice
93
- dim: int = eqx.field(kw_only=True, static=True) # static cause used as a
96
+ dim: int = eqx.field(static=True) # static cause used as a
94
97
  # shape in jax.lax.dynamic_slice
95
- min_pts: tuple[float, ...] = eqx.field(kw_only=True)
96
- max_pts: tuple[float, ...] = eqx.field(kw_only=True)
97
- method: str = eqx.field(
98
- kw_only=True, static=True, default_factory=lambda: "uniform"
99
- )
100
- rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
101
- n_start: int = eqx.field(kw_only=True, default=None, static=True)
98
+ min_pts: tuple[float, ...]
99
+ max_pts: tuple[float, ...]
100
+ method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(static=True)
101
+ rar_parameters: None | dict[str, int]
102
+ n_start: int = eqx.field(static=True)
102
103
 
103
- # all the init=False fields are set in __post_init__
104
+ # --- Below fields are not passed as arguments to __init__
104
105
  p: Float[Array, " n"] | None = eqx.field(init=False)
105
106
  rar_iter_from_last_sampling: int | None = eqx.field(init=False)
106
107
  rar_iter_nb: int | None = eqx.field(init=False)
@@ -111,7 +112,32 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
111
112
  eqx.field(init=False)
112
113
  )
113
114
 
114
- def __post_init__(self):
115
+ def __init__(
116
+ self,
117
+ *,
118
+ key: PRNGKeyArray,
119
+ n: int,
120
+ nb: int | None = None,
121
+ omega_batch_size: int | None = None,
122
+ omega_border_batch_size: int | None = None,
123
+ dim: int,
124
+ min_pts: tuple[float, ...],
125
+ max_pts: tuple[float, ...],
126
+ method: Literal["grid", "uniform", "sobol", "halton"] = "uniform",
127
+ rar_parameters: dict[str, int] | None = None,
128
+ n_start: int | None = None,
129
+ ):
130
+ self.key = key
131
+ self.n = n
132
+ self.nb = nb
133
+ self.omega_batch_size = omega_batch_size
134
+ self.omega_border_batch_size = omega_border_batch_size
135
+ self.dim = dim
136
+ self.min_pts = min_pts
137
+ self.max_pts = max_pts
138
+ self.method = method
139
+ self.rar_parameters = rar_parameters
140
+
115
141
  assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
116
142
  assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
117
143
 
@@ -120,7 +146,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
120
146
  self.p,
121
147
  self.rar_iter_from_last_sampling,
122
148
  self.rar_iter_nb,
123
- ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
149
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, n_start)
124
150
 
125
151
  if self.method == "grid" and self.dim == 2:
126
152
  perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
@@ -132,6 +158,22 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
132
158
  )
133
159
  self.n = perfect_sq
134
160
 
161
+ if self.method in ["sobol", "halton"]:
162
+ log2_n = jnp.log2(self.n)
163
+ lower_pow = 2 ** jnp.floor(log2_n)
164
+ higher_pow = 2 ** jnp.ceil(log2_n)
165
+ closest_two_power = (
166
+ lower_pow
167
+ if (self.n - lower_pow) < (higher_pow - self.n)
168
+ else higher_pow
169
+ )
170
+ if self.n != closest_two_power:
171
+ warnings.warn(
172
+ f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
173
+ f"Modfiying self.n from {self.n} to {closest_two_power}.",
174
+ )
175
+ self.n = int(closest_two_power)
176
+
135
177
  if self.omega_batch_size is None:
136
178
  self.curr_omega_idx = 0
137
179
  else:
@@ -174,29 +216,53 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
174
216
  self.key, self.omega_border = self.generate_omega_border_data(self.key)
175
217
 
176
218
  def sample_in_omega_domain(
177
- self, keys: Key, sample_size: int
219
+ self, keys: list[PRNGKeyArray], sample_size: int
178
220
  ) -> Float[Array, " n dim"]:
221
+ if self.method == "uniform":
222
+ if self.dim == 1:
223
+ xmin, xmax = self.min_pts[0], self.max_pts[0]
224
+ return jax.random.uniform(
225
+ *keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
226
+ )
227
+
228
+ return jnp.concatenate(
229
+ [
230
+ jax.random.uniform(
231
+ keys[i],
232
+ (sample_size, 1),
233
+ minval=self.min_pts[i],
234
+ maxval=self.max_pts[i],
235
+ )
236
+ for i in range(self.dim)
237
+ ],
238
+ axis=-1,
239
+ )
240
+ else:
241
+ return self._qmc_in_omega_domain(keys[0], sample_size)
242
+
243
+ def _qmc_in_omega_domain(
244
+ self, subkey: PRNGKeyArray, sample_size: int
245
+ ) -> Float[Array, "n dim"]:
246
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
179
247
  if self.dim == 1:
180
- xmin, xmax = self.min_pts[0], self.max_pts[0]
181
- return jax.random.uniform(
182
- keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
248
+ qmc_seq = qmc_generator(
249
+ d=self.dim,
250
+ scramble=True,
251
+ rng=np.random.default_rng(np.uint32(subkey)),
183
252
  )
184
- # keys = jax.random.split(key, self.dim)
185
- return jnp.concatenate(
186
- [
187
- jax.random.uniform(
188
- keys[i],
189
- (sample_size, 1),
190
- minval=self.min_pts[i],
191
- maxval=self.max_pts[i],
192
- )
193
- for i in range(self.dim)
194
- ],
195
- axis=-1,
253
+ u = qmc_seq.random(n=sample_size)
254
+ return jnp.array(
255
+ qmc.scale(u, l_bounds=self.min_pts[0], u_bounds=self.max_pts[0])
256
+ )
257
+ sampler = qmc.Sobol(
258
+ d=self.dim, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
196
259
  )
260
+ samples = sampler.random(n=sample_size)
261
+ samples = qmc.scale(samples, l_bounds=self.min_pts, u_bounds=self.max_pts)
262
+ return jnp.array(samples)
197
263
 
198
264
  def sample_in_omega_border_domain(
199
- self, keys: Key, sample_size: int | None = None
265
+ self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
200
266
  ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
201
267
  sample_size = self.nb if sample_size is None else sample_size
202
268
  if sample_size is None:
@@ -206,6 +272,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
206
272
  xmax = self.max_pts[0]
207
273
  return jnp.array([xmin, xmax]).astype(float)
208
274
  if self.dim == 2:
275
+ assert keys is not None
209
276
  # currently hard-coded the 4 edges for d==2
210
277
  # TODO : find a general & efficient way to sample from the border
211
278
  # (facets) of the hypercube in general dim.
@@ -260,10 +327,67 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
260
327
  + f"implemented yet. You are asking for generation in dimension d={self.dim}."
261
328
  )
262
329
 
330
+ def qmc_in_omega_border_domain(
331
+ self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
332
+ ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
333
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
334
+ sample_size = self.nb if sample_size is None else sample_size
335
+ if sample_size is None:
336
+ return None
337
+ if self.dim == 1:
338
+ xmin = self.min_pts[0]
339
+ xmax = self.max_pts[0]
340
+ return jnp.array([xmin, xmax]).astype(float)
341
+ if self.dim == 2:
342
+ assert keys is not None
343
+ # currently hard-coded the 4 edges for d==2
344
+ # TODO : find a general & efficient way to sample from the border
345
+ # (facets) of the hypercube in general dim.
346
+ facet_n = sample_size // (2 * self.dim)
347
+
348
+ def generate_qmc_sample(key, min_val, max_val):
349
+ qmc_seq = qmc_generator(
350
+ d=1,
351
+ scramble=True,
352
+ rng=np.random.default_rng(np.uint32(key)),
353
+ )
354
+ u = qmc_seq.random(n=facet_n)
355
+ return jnp.array(qmc.scale(u, l_bounds=min_val, u_bounds=max_val))
356
+
357
+ xmin = jnp.hstack(
358
+ [
359
+ self.min_pts[0] * jnp.ones((facet_n, 1)),
360
+ generate_qmc_sample(keys[0], self.min_pts[1], self.max_pts[1]),
361
+ ]
362
+ )
363
+ xmax = jnp.hstack(
364
+ [
365
+ self.max_pts[0] * jnp.ones((facet_n, 1)),
366
+ generate_qmc_sample(keys[1], self.min_pts[1], self.max_pts[1]),
367
+ ]
368
+ )
369
+ ymin = jnp.hstack(
370
+ [
371
+ generate_qmc_sample(keys[2], self.min_pts[0], self.max_pts[0]),
372
+ self.min_pts[1] * jnp.ones((facet_n, 1)),
373
+ ]
374
+ )
375
+ ymax = jnp.hstack(
376
+ [
377
+ generate_qmc_sample(keys[3], self.min_pts[0], self.max_pts[0]),
378
+ self.max_pts[1] * jnp.ones((facet_n, 1)),
379
+ ]
380
+ )
381
+ return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
382
+ raise NotImplementedError(
383
+ "Generation of the border of a cube in dimension > 2 is not "
384
+ + f"implemented yet. You are asking for generation in dimension d={self.dim}."
385
+ )
386
+
263
387
  def generate_omega_data(
264
- self, key: Key, data_size: int | None = None
388
+ self, key: PRNGKeyArray, data_size: int | None = None
265
389
  ) -> tuple[
266
- Key,
390
+ PRNGKeyArray,
267
391
  Float[Array, " n dim"],
268
392
  ]:
269
393
  r"""
@@ -290,20 +414,21 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
290
414
  )
291
415
  xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
292
416
  omega = jnp.concatenate(xyz_, axis=-1)
293
- elif self.method == "uniform":
294
- if self.dim == 1:
295
- key, subkeys = jax.random.split(key, 2)
417
+ elif self.method in ["uniform", "sobol", "halton"]:
418
+ if self.dim == 1 or self.method in ["sobol", "halton"]:
419
+ key, subkey = jax.random.split(key, 2)
420
+ omega = self.sample_in_omega_domain([subkey], sample_size=data_size)
296
421
  else:
297
422
  key, *subkeys = jax.random.split(key, self.dim + 1)
298
- omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
423
+ omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
299
424
  else:
300
425
  raise ValueError("Method " + self.method + " is not implemented.")
301
426
  return key, omega
302
427
 
303
428
  def generate_omega_border_data(
304
- self, key: Key, data_size: int | None = None
429
+ self, key: PRNGKeyArray, data_size: int | None = None
305
430
  ) -> tuple[
306
- Key,
431
+ PRNGKeyArray,
307
432
  Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
308
433
  ]:
309
434
  r"""
@@ -317,15 +442,24 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
317
442
  key, *subkeys = jax.random.split(key, 5)
318
443
  else:
319
444
  subkeys = None
320
- omega_border = self.sample_in_omega_border_domain(
321
- subkeys, sample_size=data_size
322
- )
323
445
 
446
+ if self.method in ["grid", "uniform"]:
447
+ omega_border = self.sample_in_omega_border_domain(
448
+ subkeys, sample_size=data_size
449
+ )
450
+ elif self.method in ["sobol", "halton"]:
451
+ omega_border = self.qmc_in_omega_border_domain(
452
+ subkeys, sample_size=data_size
453
+ )
454
+ else:
455
+ raise ValueError("Method " + self.method + " is not implemented.")
324
456
  return key, omega_border
325
457
 
326
458
  def _get_omega_operands(
327
459
  self,
328
- ) -> tuple[Key, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None]:
460
+ ) -> tuple[
461
+ PRNGKeyArray, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None
462
+ ]:
329
463
  return (
330
464
  self.key,
331
465
  self.omega,
@@ -367,7 +501,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
367
501
  # handled above
368
502
  )
369
503
  new = eqx.tree_at(
370
- lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
504
+ lambda m: (m.key, m.omega, m.curr_omega_idx), # type: ignore
505
+ self,
506
+ new_attributes,
371
507
  )
372
508
 
373
509
  return new, jax.lax.dynamic_slice(
@@ -379,7 +515,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
379
515
  def _get_omega_border_operands(
380
516
  self,
381
517
  ) -> tuple[
382
- Key,
518
+ PRNGKeyArray,
383
519
  Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
384
520
  int,
385
521
  int | None,
@@ -443,7 +579,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
443
579
  # handled above
444
580
  )
445
581
  new = eqx.tree_at(
446
- lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
582
+ lambda m: (m.key, m.omega_border, m.curr_omega_border_idx), # type: ignore
447
583
  self,
448
584
  new_attributes,
449
585
  )
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
9
9
  import equinox as eqx
10
10
  import jax
11
11
  import jax.numpy as jnp
12
- from jaxtyping import Key, Array, Float
12
+ from jaxtyping import PRNGKeyArray, Array, Float
13
13
  from jinns.data._Batchs import ODEBatch
14
14
  from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
15
15
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
@@ -24,7 +24,7 @@ class DataGeneratorODE(AbstractDataGenerator):
24
24
 
25
25
  Parameters
26
26
  ----------
27
- key : Key
27
+ key : PRNGKeyArray
28
28
  Jax random key to sample new time points and to shuffle batches
29
29
  nt : int
30
30
  The number of total time points that will be divided in
@@ -42,10 +42,10 @@ class DataGeneratorODE(AbstractDataGenerator):
42
42
  The method that generates the `nt` time points. `grid` means
43
43
  regularly spaced points over the domain. `uniform` means uniformly
44
44
  sampled points over the domain
45
- rar_parameters : RarParameterDict, default=None
45
+ rar_parameters : None | RarParameterDict, default=None
46
46
  A TypedDict to specify the Residual Adaptative Resampling procedure. See
47
47
  the docstring from RarParameterDict
48
- n_start : int, default=None
48
+ n_start : None | int, default=None
49
49
  Defaults to None. The effective size of nt used at start time.
50
50
  This value must be
51
51
  provided when rar_parameters is not None. Otherwise we set internally
@@ -54,25 +54,43 @@ class DataGeneratorODE(AbstractDataGenerator):
54
54
  then corresponds to the initial number of points we train the PINN.
55
55
  """
56
56
 
57
- key: Key = eqx.field(kw_only=True)
58
- nt: int = eqx.field(kw_only=True, static=True)
59
- tmin: Float = eqx.field(kw_only=True)
60
- tmax: Float = eqx.field(kw_only=True)
61
- temporal_batch_size: int | None = eqx.field(static=True, default=None, kw_only=True)
62
- method: str = eqx.field(
63
- static=True, kw_only=True, default_factory=lambda: "uniform"
64
- )
65
- rar_parameters: dict[str, int] = eqx.field(default=None, kw_only=True)
66
- n_start: int = eqx.field(static=True, default=None, kw_only=True)
67
-
68
- # all the init=False fields are set in __post_init__
57
+ key: PRNGKeyArray
58
+ nt: int = eqx.field(static=True)
59
+ tmin: float
60
+ tmax: float
61
+ temporal_batch_size: int | None = eqx.field(static=True)
62
+ method: str = eqx.field(static=True)
63
+ rar_parameters: None | dict[str, int]
64
+ n_start: None | int
65
+
66
+ # --- Below fields are not passed as arguments to __init__
69
67
  p: Float[Array, " nt 1"] | None = eqx.field(init=False)
70
68
  rar_iter_from_last_sampling: int | None = eqx.field(init=False)
71
69
  rar_iter_nb: int | None = eqx.field(init=False)
72
70
  curr_time_idx: int = eqx.field(init=False)
73
71
  times: Float[Array, " nt 1"] = eqx.field(init=False)
74
72
 
75
- def __post_init__(self):
73
+ def __init__(
74
+ self,
75
+ *,
76
+ key: PRNGKeyArray,
77
+ nt: int,
78
+ tmin: float,
79
+ tmax: float,
80
+ temporal_batch_size: int | None,
81
+ method: str = "uniform",
82
+ rar_parameters: None | dict[str, int] = None,
83
+ n_start: None | int = None,
84
+ ):
85
+ self.key = key
86
+ self.nt = nt
87
+ self.tmin = tmin
88
+ self.tmax = tmax
89
+ self.temporal_batch_size = temporal_batch_size
90
+ self.method = method
91
+ self.n_start = n_start
92
+ self.rar_parameters = rar_parameters
93
+
76
94
  (
77
95
  self.n_start,
78
96
  self.p,
@@ -97,7 +115,7 @@ class DataGeneratorODE(AbstractDataGenerator):
97
115
  # above way for the key.
98
116
 
99
117
  def sample_in_time_domain(
100
- self, key: Key, sample_size: int | None = None
118
+ self, key: PRNGKeyArray, sample_size: int | None = None
101
119
  ) -> Float[Array, " nt 1"]:
102
120
  return jax.random.uniform(
103
121
  key,
@@ -106,7 +124,9 @@ class DataGeneratorODE(AbstractDataGenerator):
106
124
  maxval=self.tmax,
107
125
  )
108
126
 
109
- def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, " nt"]]:
127
+ def generate_time_data(
128
+ self, key: PRNGKeyArray
129
+ ) -> tuple[PRNGKeyArray, Float[Array, " nt"]]:
110
130
  """
111
131
  Construct a complete set of `self.nt` time points according to the
112
132
  specified `self.method`
@@ -125,7 +145,11 @@ class DataGeneratorODE(AbstractDataGenerator):
125
145
  def _get_time_operands(
126
146
  self,
127
147
  ) -> tuple[
128
- Key, Float[Array, " nt 1"], int, int | None, Float[Array, " nt 1"] | None
148
+ PRNGKeyArray,
149
+ Float[Array, " nt 1"],
150
+ int,
151
+ int | None,
152
+ Float[Array, " nt 1"] | None,
129
153
  ]:
130
154
  return (
131
155
  self.key,
@@ -150,7 +174,7 @@ class DataGeneratorODE(AbstractDataGenerator):
150
174
  bend = bstart + self.temporal_batch_size
151
175
 
152
176
  # Compute the effective number of used collocation points
153
- if self.rar_parameters is not None:
177
+ if self.rar_parameters is not None and self.n_start is not None:
154
178
  nt_eff = (
155
179
  self.n_start
156
180
  + self.rar_iter_nb # type: ignore
@@ -167,7 +191,9 @@ class DataGeneratorODE(AbstractDataGenerator):
167
191
  # handled above
168
192
  )
169
193
  new = eqx.tree_at(
170
- lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
194
+ lambda m: (m.key, m.times, m.curr_time_idx), # type: ignore
195
+ self,
196
+ new_attributes,
171
197
  )
172
198
 
173
199
  # commands below are equivalent to