jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,464 @@
1
+ """
2
+ Define the DataGenerators modules
3
+ """
4
+
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+ import warnings
9
+ import equinox as eqx
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from jaxtyping import Key, Array, Float
13
+ from jinns.data._Batchs import PDEStatioBatch
14
+ from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
15
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
16
+
17
+
18
class CubicMeshPDEStatio(AbstractDataGenerator):
    r"""
    A class implementing data generator object for stationary partial
    differential equations.

    Parameters
    ----------
    key : Key
        Jax random key to sample new time points and to shuffle batches
    n : int
        The number of total $\Omega$ points that will be divided in
        batches. Batches are made so that each data point is seen only
        once during 1 epoch.
    nb : int | None
        The total number of points in $\partial\Omega$. Can be None if no
        boundary condition is specified.
    omega_batch_size : int | None, default=None
        The size of the batch of randomly selected points among
        the `n` points. If None no minibatches are used.
    omega_border_batch_size : int | None, default=None
        The size of the batch of points randomly selected
        among the `nb` points. If None, `omega_border_batch_size`
        no minibatches are used. In dimension 1,
        minibatches are never used since the boundary is composed of two
        singletons.
    dim : int
        Dimension of $\Omega$ domain
    min_pts : tuple[tuple[Float, Float], ...]
        A tuple of minimum values of the domain along each dimension. For a sampling
        in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
        x_{n, min})$
    max_pts : tuple[tuple[Float, Float], ...]
        A tuple of maximum values of the domain along each dimension. For a sampling
        in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
        x_{n,max})$
    method : str, default="uniform"
        Either `grid` or `uniform`, default is `uniform`.
        The method that generates the `nt` time points. `grid` means
        regularly spaced points over the domain. `uniform` means uniformly
        sampled points over the domain
    rar_parameters : dict[str, int], default=None
        Defaults to None: do not use Residual Adaptative Resampling.
        Otherwise a dictionary with keys
        - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
        - `update_every`: the number of gradient steps taken between
        each update of collocation points in the RAR algo.
        - `sample_size`: the size of the sample from which we will select new
        collocation points.
        - `selected_sample_size`: the number of selected
        points from the sample to be added to the current collocation
        points.
    n_start : int, default=None
        Defaults to None. The effective size of n used at start time.
        This value must be
        provided when rar_parameters is not None. Otherwise we set internally
        n_start = n and this is hidden from the user.
        In RAR, n_start
        then corresponds to the initial number of points we train the PINN on.
    """

    # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
    key: Key = eqx.field(kw_only=True)
    n: int = eqx.field(kw_only=True, static=True)
    nb: int | None = eqx.field(kw_only=True, static=True, default=None)
    omega_batch_size: int | None = eqx.field(
        kw_only=True,
        static=True,
        default=None,  # can be None as
        # CubicMeshPDENonStatio inherits but also if omega_batch_size=n
    )  # static cause used as a
    # shape in jax.lax.dynamic_slice
    omega_border_batch_size: int | None = eqx.field(
        kw_only=True, static=True, default=None
    )  # static cause used as a
    # shape in jax.lax.dynamic_slice
    dim: int = eqx.field(kw_only=True, static=True)  # static cause used as a
    # shape in jax.lax.dynamic_slice
    min_pts: tuple[float, ...] = eqx.field(kw_only=True)
    max_pts: tuple[float, ...] = eqx.field(kw_only=True)
    method: str = eqx.field(
        kw_only=True, static=True, default_factory=lambda: "uniform"
    )
    rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
    n_start: int = eqx.field(kw_only=True, default=None, static=True)

    # all the init=False fields are set in __post_init__
    p: Float[Array, " n"] | None = eqx.field(init=False)
    rar_iter_from_last_sampling: int | None = eqx.field(init=False)
    rar_iter_nb: int | None = eqx.field(init=False)
    curr_omega_idx: int = eqx.field(init=False)
    curr_omega_border_idx: int = eqx.field(init=False)
    omega: Float[Array, " n dim"] = eqx.field(init=False)
    omega_border: Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None = (
        eqx.field(init=False)
    )

    def __post_init__(self):
        # Validate that min_pts / max_pts are tuples matching the declared dim.
        assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
        assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)

        (
            self.n_start,
            self.p,
            self.rar_iter_from_last_sampling,
            self.rar_iter_nb,
        ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)

        # Grid sampling in 2D builds a sqrt(n) x sqrt(n) meshgrid, so n must be
        # a perfect square; otherwise we round it and warn the user.
        if self.method == "grid" and self.dim == 2:
            perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
            if self.n != perfect_sq:
                warnings.warn(
                    "Grid sampling is requested in dimension 2 with a non"
                    f" perfect square dataset size (self.n = {self.n})."
                    f" Modifying self.n to self.n = {perfect_sq}."
                )
                self.n = perfect_sq

        if self.omega_batch_size is None:
            self.curr_omega_idx = 0
        else:
            # Start past the end so the very first get_batch() triggers a
            # reshuffle in _reset_or_increment.
            self.curr_omega_idx = self.n + self.omega_batch_size
            # to be sure there is a shuffling at first get_batch()

        if self.nb is not None:
            if self.dim == 1:
                self.omega_border_batch_size = None
                # We are in 1-D case => omega_border_batch_size is
                # ignored since borders of Omega are singletons.
                # self.border_batch() will return [xmin, xmax]
            else:
                # A d-dimensional cube has 2*d facets; nb must split evenly
                # across them.
                if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
                    raise ValueError(
                        f"number of border point must be"
                        f" a multiple of 2xd = {2 * self.dim} (the # of faces of"
                        f" a d-dimensional cube). Got {self.nb=}."
                    )
                if (
                    self.omega_border_batch_size is not None
                    and self.nb // (2 * self.dim) < self.omega_border_batch_size
                ):
                    raise ValueError(
                        f"number of points per facets ({self.nb // (2 * self.dim)})"
                        f" cannot be lower than border batch size "
                        f" ({self.omega_border_batch_size})."
                    )
                self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))

            if self.omega_border_batch_size is None:
                self.curr_omega_border_idx = 0
            else:
                # Same past-the-end trick as curr_omega_idx above.
                self.curr_omega_border_idx = self.nb + self.omega_border_batch_size
                # to be sure there is a shuffling at first get_batch()
        else:  # self.nb is None
            self.curr_omega_border_idx = 0

        self.key, self.omega = self.generate_omega_data(self.key)
        self.key, self.omega_border = self.generate_omega_border_data(self.key)

    def sample_in_omega_domain(
        self, keys: Key, sample_size: int
    ) -> Float[Array, " n dim"]:
        """
        Uniformly sample `sample_size` points inside the hypercube
        [min_pts, max_pts].

        For dim == 1 `keys` is a single key; for dim > 1 it is a sequence of
        `dim` keys (one per coordinate).
        """
        if self.dim == 1:
            xmin, xmax = self.min_pts[0], self.max_pts[0]
            return jax.random.uniform(
                keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
            )
        # keys = jax.random.split(key, self.dim)
        return jnp.concatenate(
            [
                jax.random.uniform(
                    keys[i],
                    (sample_size, 1),
                    minval=self.min_pts[i],
                    maxval=self.max_pts[i],
                )
                for i in range(self.dim)
            ],
            axis=-1,
        )

    def sample_in_omega_border_domain(
        self, keys: Key, sample_size: int | None = None
    ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
        """
        Sample points on the boundary of the hypercube.

        - dim == 1: the two endpoints, no randomness.
        - dim == 2: `sample_size // 4` uniform points on each of the 4 edges,
          stacked facet-by-facet on the last axis.
        - dim > 2: not implemented.
        """
        sample_size = self.nb if sample_size is None else sample_size
        if sample_size is None:
            return None
        if self.dim == 1:
            xmin = self.min_pts[0]
            xmax = self.max_pts[0]
            return jnp.array([xmin, xmax]).astype(float)
        if self.dim == 2:
            # currently hard-coded the 4 edges for d==2
            # TODO : find a general & efficient way to sample from the border
            # (facets) of the hypercube in general dim.
            facet_n = sample_size // (2 * self.dim)
            # Edge x = min_pts[0]: fix the first coordinate, sample the second.
            xmin = jnp.hstack(
                [
                    self.min_pts[0] * jnp.ones((facet_n, 1)),
                    jax.random.uniform(
                        keys[0],
                        (facet_n, 1),
                        minval=self.min_pts[1],
                        maxval=self.max_pts[1],
                    ),
                ]
            )
            # Edge x = max_pts[0].
            xmax = jnp.hstack(
                [
                    self.max_pts[0] * jnp.ones((facet_n, 1)),
                    jax.random.uniform(
                        keys[1],
                        (facet_n, 1),
                        minval=self.min_pts[1],
                        maxval=self.max_pts[1],
                    ),
                ]
            )
            # Edge y = min_pts[1].
            ymin = jnp.hstack(
                [
                    jax.random.uniform(
                        keys[2],
                        (facet_n, 1),
                        minval=self.min_pts[0],
                        maxval=self.max_pts[0],
                    ),
                    self.min_pts[1] * jnp.ones((facet_n, 1)),
                ]
            )
            # Edge y = max_pts[1].
            ymax = jnp.hstack(
                [
                    jax.random.uniform(
                        keys[3],
                        (facet_n, 1),
                        minval=self.min_pts[0],
                        maxval=self.max_pts[0],
                    ),
                    self.max_pts[1] * jnp.ones((facet_n, 1)),
                ]
            )
            return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
        raise NotImplementedError(
            "Generation of the border of a cube in dimension > 2 is not "
            + f"implemented yet. You are asking for generation in dimension d={self.dim}."
        )

    def generate_omega_data(
        self, key: Key, data_size: int | None = None
    ) -> tuple[
        Key,
        Float[Array, " n dim"],
    ]:
        r"""
        Construct a complete set of `self.n` $\Omega$ points according to the
        specified `self.method`.

        Returns the (possibly advanced) key and the generated points.
        """
        data_size = self.n if data_size is None else data_size
        # Generate Omega
        if self.method == "grid":
            if self.dim == 1:
                xmin, xmax = self.min_pts[0], self.max_pts[0]
                ## shape (n, 1)
                omega = jnp.linspace(xmin, xmax, data_size)[:, None]
            else:
                # NOTE(review): the sqrt(data_size) points-per-axis meshgrid
                # followed by a reshape to (data_size, 1) is only consistent
                # for dim == 2 (where __post_init__ forces n to a perfect
                # square) — confirm before using grid sampling with dim > 2.
                xyz_ = jnp.meshgrid(
                    *[
                        jnp.linspace(
                            self.min_pts[i],
                            self.max_pts[i],
                            int(jnp.round(jnp.sqrt(data_size))),
                        )
                        for i in range(self.dim)
                    ]
                )
                xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
                omega = jnp.concatenate(xyz_, axis=-1)
        elif self.method == "uniform":
            # sample_in_omega_domain expects one key in 1D, dim keys otherwise.
            if self.dim == 1:
                key, subkeys = jax.random.split(key, 2)
            else:
                key, *subkeys = jax.random.split(key, self.dim + 1)
            omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
        else:
            raise ValueError("Method " + self.method + " is not implemented.")
        return key, omega

    def generate_omega_border_data(
        self, key: Key, data_size: int | None = None
    ) -> tuple[
        Key,
        Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
    ]:
        r"""
        Also constructs a complete set of `self.nb`
        $\partial\Omega$ points if `self.omega_border_batch_size` is not
        `None`. If the latter is `None` we set `self.omega_border` to `None`.
        """
        # Generate border of omega
        data_size = self.nb if data_size is None else data_size
        if self.dim == 2:
            # 4 subkeys: one per edge of the square.
            key, *subkeys = jax.random.split(key, 5)
        else:
            subkeys = None
        omega_border = self.sample_in_omega_border_domain(
            subkeys, sample_size=data_size
        )

        return key, omega_border

    def _get_omega_operands(
        self,
    ) -> tuple[Key, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None]:
        # Operand tuple consumed by _reset_or_increment for the inner domain.
        return (
            self.key,
            self.omega,
            self.curr_omega_idx,
            self.omega_batch_size,
            self.p,
        )

    def inside_batch(
        self,
    ) -> tuple[CubicMeshPDEStatio, Float[Array, " omega_batch_size dim"]]:
        r"""
        Return a batch of points in $\Omega$.
        If all the batches have been seen, we reshuffle them,
        otherwise we just return the next unseen batch.
        """
        if self.omega_batch_size is None or self.omega_batch_size == self.n:
            # Avoid unnecessary reshuffling
            return self, self.omega

        # Compute the effective number of used collocation points
        if self.rar_parameters is not None:
            n_eff = (
                self.n_start
                + self.rar_iter_nb  # type: ignore
                * self.rar_parameters["selected_sample_size"]
            )
        else:
            n_eff = self.n

        bstart = self.curr_omega_idx
        bend = bstart + self.omega_batch_size

        new_attributes = _reset_or_increment(
            bend,
            n_eff,
            self._get_omega_operands(),  # type: ignore
            # ignore since the case self.omega_batch_size is None has been
            # handled above
        )
        # eqx modules are immutable: return an updated copy via tree_at.
        new = eqx.tree_at(
            lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
        )

        # Start index is dynamic but the slice shape is static (jit-friendly).
        return new, jax.lax.dynamic_slice(
            new.omega,
            start_indices=(new.curr_omega_idx, 0),
            slice_sizes=(new.omega_batch_size, new.dim),
        )

    def _get_omega_border_operands(
        self,
    ) -> tuple[
        Key,
        Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
        int,
        int | None,
        None,
    ]:
        # Operand tuple consumed by _reset_or_increment for the boundary
        # (no RAR probabilities on the border, hence the trailing None).
        return (
            self.key,
            self.omega_border,
            self.curr_omega_border_idx,
            self.omega_border_batch_size,
            None,
        )

    def border_batch(
        self,
    ) -> tuple[
        CubicMeshPDEStatio,
        Float[Array, " 1 1 2"] | Float[Array, " omega_border_batch_size 2 4"] | None,
    ]:
        r"""
        Return

        - The value `None` if `self.omega_border_batch_size` is `None`.

        - a jnp array with two fixed values $(x_{min}, x_{max})$ if
          `self.dim` = 1. There is no sampling here, we return the entire
          $\partial\Omega$

        - a batch of points in $\partial\Omega$ otherwise, stacked by
          facet on the last axis.
          If all the batches have been seen, we reshuffle them,
          otherwise we just return the next unseen batch.
        """
        if self.nb is None or self.omega_border is None:
            # Avoid unnecessary reshuffling
            return self, None

        if self.dim == 1:
            # Avoid unnecessary reshuffling
            # 1-D case, no randomness : we always return the whole omega border,
            # i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
            return self, self.omega_border[None, None]  # shape is (1, 1, 2)

        if (
            self.omega_border_batch_size is None
            # Points per facet: a d-cube has 2*d facets (was `2**self.dim`,
            # which only coincides with 2*d at d == 2 — the sole dimension
            # with border sampling implemented — and would be wrong for d > 2).
            or self.omega_border_batch_size == self.nb // (2 * self.dim)
        ):
            # Avoid unnecessary reshuffling
            return self, self.omega_border

        bstart = self.curr_omega_border_idx
        bend = bstart + self.omega_border_batch_size

        new_attributes = _reset_or_increment(
            bend,
            self.nb,
            self._get_omega_border_operands(),  # type: ignore
            # ignore since the case self.omega_border_batch_size is None has been
            # handled above
        )
        new = eqx.tree_at(
            lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
            self,
            new_attributes,
        )

        return new, jax.lax.dynamic_slice(
            new.omega_border,
            start_indices=(new.curr_omega_border_idx, 0, 0),
            slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
        )

    def get_batch(self) -> tuple[CubicMeshPDEStatio, PDEStatioBatch]:
        """
        Generic method to return a batch. Here we call `self.inside_batch()`
        and `self.border_batch()`
        """
        new, inside_batch = self.inside_batch()
        new, border_batch = new.border_batch()
        return new, PDEStatioBatch(domain_batch=inside_batch, border_batch=border_batch)
@@ -0,0 +1,187 @@
1
+ """
2
+ Define the DataGenerators modules
3
+ """
4
+
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+ from typing import TYPE_CHECKING
9
+ import equinox as eqx
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from jaxtyping import Key, Array, Float
13
+ from jinns.data._Batchs import ODEBatch
14
+ from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
15
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
16
+
17
+ if TYPE_CHECKING:
18
+ pass
19
+
20
+
21
class DataGeneratorODE(AbstractDataGenerator):
    """
    A class implementing data generator object for ordinary differential equations.

    Parameters
    ----------
    key : Key
        Jax random key to sample new time points and to shuffle batches
    nt : int
        The number of total time points that will be divided in
        batches. Batches are made so that each data point is seen only
        once during 1 epoch.
    tmin : float
        The minimum value of the time domain to consider
    tmax : float
        The maximum value of the time domain to consider
    temporal_batch_size : int | None, default=None
        The size of the batch of randomly selected points among
        the `nt` points. If None, no minibatches are used.
    method : str, default="uniform"
        Either `grid` or `uniform`, default is `uniform`.
        The method that generates the `nt` time points. `grid` means
        regularly spaced points over the domain. `uniform` means uniformly
        sampled points over the domain
    rar_parameters : RarParameterDict, default=None
        A TypedDict to specify the Residual Adaptative Resampling procedure. See
        the docstring from RarParameterDict
    n_start : int, default=None
        Defaults to None. The effective size of nt used at start time.
        This value must be
        provided when rar_parameters is not None. Otherwise we set internally
        n_start = nt and this is hidden from the user.
        In RAR, n_start
        then corresponds to the initial number of points we train the PINN.
    """

    key: Key = eqx.field(kw_only=True)
    nt: int = eqx.field(kw_only=True, static=True)
    tmin: Float = eqx.field(kw_only=True)
    tmax: Float = eqx.field(kw_only=True)
    temporal_batch_size: int | None = eqx.field(static=True, default=None, kw_only=True)
    method: str = eqx.field(
        static=True, kw_only=True, default_factory=lambda: "uniform"
    )
    rar_parameters: dict[str, int] = eqx.field(default=None, kw_only=True)
    n_start: int = eqx.field(static=True, default=None, kw_only=True)

    # all the init=False fields are set in __post_init__
    p: Float[Array, " nt 1"] | None = eqx.field(init=False)
    rar_iter_from_last_sampling: int | None = eqx.field(init=False)
    rar_iter_nb: int | None = eqx.field(init=False)
    curr_time_idx: int = eqx.field(init=False)
    times: Float[Array, " nt 1"] = eqx.field(init=False)

    def __post_init__(self):
        (
            self.n_start,
            self.p,
            self.rar_iter_from_last_sampling,
            self.rar_iter_nb,
        ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.n_start)

        if self.temporal_batch_size is not None:
            # Start past the end so the very first get_batch() triggers a
            # reshuffle in _reset_or_increment.
            self.curr_time_idx = self.nt + self.temporal_batch_size
            # to be sure there is a shuffling at first get_batch()
            # NOTE in the extreme case we could do:
            # self.curr_time_idx=jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
            # but we do not test for such extreme values. Where we subtract
            # self.temporal_batch_size - 1 because otherwise when computing
            # `bend` we do not want to overflow the max int32 with unwanted behaviour
        else:
            self.curr_time_idx = 0

        self.key, self.times = self.generate_time_data(self.key)
        # Note that, here, in __init__ (and __post_init__), this is the
        # only place where self assignment are authorized so we do the
        # above way for the key.

    def sample_in_time_domain(
        self, key: Key, sample_size: int | None = None
    ) -> Float[Array, " nt 1"]:
        """
        Uniformly sample `sample_size` (default `self.nt`) time points in
        [tmin, tmax), shape (sample_size, 1).
        """
        return jax.random.uniform(
            key,
            (self.nt if sample_size is None else sample_size, 1),
            minval=self.tmin,
            maxval=self.tmax,
        )

    def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, " nt"]]:
        """
        Construct a complete set of `self.nt` time points according to the
        specified `self.method`

        Note that self.times has always size self.nt and not self.n_start, even
        in RAR scheme, we must allocate all the collocation points
        """
        # Fix: split the key passed as argument, not self.key. The original
        # ignored `key` and always split self.key, which silently discarded
        # any caller-supplied key (harmless at the __post_init__ call site,
        # which passes self.key, but wrong for any other caller).
        key, subkey = jax.random.split(key)
        if self.method == "grid":
            # NOTE(review): jnp.arange with a float step excludes tmax and may
            # yield nt +/- 1 points due to floating point — confirm intended.
            partial_times = (self.tmax - self.tmin) / self.nt
            return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
        if self.method == "uniform":
            return key, self.sample_in_time_domain(subkey)
        raise ValueError("Method " + self.method + " is not implemented.")

    def _get_time_operands(
        self,
    ) -> tuple[
        Key, Float[Array, " nt 1"], int, int | None, Float[Array, " nt 1"] | None
    ]:
        # Operand tuple consumed by _reset_or_increment.
        return (
            self.key,
            self.times,
            self.curr_time_idx,
            self.temporal_batch_size,
            self.p,
        )

    def temporal_batch(
        self,
    ) -> tuple[DataGeneratorODE, Float[Array, " temporal_batch_size"]]:
        """
        Return a batch of time points. If all the batches have been seen, we
        reshuffle them, otherwise we just return the next unseen batch.
        """
        if self.temporal_batch_size is None or self.temporal_batch_size == self.nt:
            # Avoid unnecessary reshuffling
            return self, self.times

        bstart = self.curr_time_idx
        bend = bstart + self.temporal_batch_size

        # Compute the effective number of used collocation points
        if self.rar_parameters is not None:
            nt_eff = (
                self.n_start
                + self.rar_iter_nb  # type: ignore
                * self.rar_parameters["selected_sample_size"]
            )
        else:
            nt_eff = self.nt

        new_attributes = _reset_or_increment(
            bend,
            nt_eff,
            self._get_time_operands(),  # type: ignore
            # ignore since the case self.temporal_batch_size is None has been
            # handled above
        )
        # eqx modules are immutable: return an updated copy via tree_at.
        new = eqx.tree_at(
            lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
        )

        # commands below are equivalent to
        # return self.times[i:(i+t_batch_size)]
        # start indices can be dynamic but the slice shape is fixed
        return new, jax.lax.dynamic_slice(
            new.times,
            start_indices=(new.curr_time_idx, 0),
            slice_sizes=(new.temporal_batch_size, 1),
        )

    def get_batch(self) -> tuple[DataGeneratorODE, ODEBatch]:
        """
        Generic method to return a batch. Here we call `self.temporal_batch()`
        """
        new, temporal_batch = self.temporal_batch()
        return new, ODEBatch(temporal_batch=temporal_batch)