jinns-1.3.0-py3-none-any.whl → jinns-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
@@ -1,1634 +0,0 @@
- # pylint: disable=unsubscriptable-object
- """
- Define the DataGenerators modules
- """
- from __future__ import (
-     annotations,
- )  # https://docs.python.org/3/library/typing.html#constant
- import warnings
- from typing import TYPE_CHECKING, Dict
- from dataclasses import InitVar
- import equinox as eqx
- import jax
- import jax.numpy as jnp
- from jaxtyping import Key, Int, PyTree, Array, Float, Bool
- from jinns.data._Batchs import *
-
- if TYPE_CHECKING:
-     from jinns.utils._types import *
-
-
- def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
-     """
-     Utility function that fills the field `batch.param_batch_dict` of a batch object.
-     """
-     return eqx.tree_at(
-         lambda m: m.param_batch_dict,
-         batch,
-         param_batch_dict,
-         is_leaf=lambda x: x is None,
-     )
-
-
- def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
-     """
-     Utility function that fills the field `batch.obs_batch_dict` of a batch object.
-     """
-     return eqx.tree_at(
-         lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
-     )
-
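
Note: the `is_leaf=lambda x: x is None` argument is what lets `eqx.tree_at` target a field that currently holds `None` (JAX otherwise treats `None` as an empty subtree, not an addressable leaf). A minimal self-contained sketch, with a hypothetical `TinyBatch` module standing in for the real batch classes:

import equinox as eqx

class TinyBatch(eqx.Module):
    temporal_batch: int
    param_batch_dict: dict | None = None

b = TinyBatch(temporal_batch=1)
# Replace the None leaf out-of-place; is_leaf makes the None addressable
b2 = eqx.tree_at(
    lambda m: m.param_batch_dict, b, {"nu": 0.1}, is_leaf=lambda x: x is None
)
print(b2.param_batch_dict)  # {'nu': 0.1}
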
-
- def make_cartesian_product(
-     b1: Float[Array, "batch_size dim1"], b2: Float[Array, "batch_size dim2"]
- ) -> Float[Array, "(batch_size*batch_size) (dim1+dim2)"]:
-     """
-     Create the cartesian product of a time batch and a border omega batch
-     by tiling and repeating
-     """
-     n1 = b1.shape[0]
-     n2 = b2.shape[0]
-     b1 = jnp.repeat(b1, n2, axis=0)
-     b2 = jnp.tile(b2, reps=(n1,) + tuple(1 for i in b2.shape[1:]))
-     return jnp.concatenate([b1, b2], axis=1)
-
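
Note: a worked shape example of the repeat/tile pattern used above (plain jax.numpy, values illustrative):

import jax.numpy as jnp

t = jnp.array([[0.0], [1.0]])            # (2, 1): two time points
x = jnp.array([[10.0], [20.0], [30.0]])  # (3, 1): three space points
# pair every time point with every space point, as in make_cartesian_product
tx = jnp.concatenate([jnp.repeat(t, 3, axis=0), jnp.tile(x, (2, 1))], axis=1)
print(tx.shape)  # (6, 2): rows (t_i, x_j)
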
-
- def _reset_batch_idx_and_permute(
-     operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
- ) -> tuple[Key, Float[Array, "n dimension"], Int]:
-     key, domain, curr_idx, _, p = operands
-     # reset the counter
-     curr_idx = 0
-     # reshuffle
-     key, subkey = jax.random.split(key)
-     if p is None:
-         domain = jax.random.permutation(subkey, domain, axis=0, independent=False)
-     else:
-         # otherwise p is used to give zero probability to collocation points
-         # beyond n_start.
-         # NOTE that replace=True to avoid undefined behaviour; but then
-         # domain.shape[0] does not really grow as in the original RAR.
-         # Instead, it always comprises the same number of points, but the
-         # points are updated
-         domain = jax.random.choice(
-             subkey, domain, shape=(domain.shape[0],), replace=True, p=p
-         )
-
-     # return the updated values
-     return (key, domain, curr_idx)
-
-
- def _increment_batch_idx(
-     operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
- ) -> tuple[Key, Float[Array, "n dimension"], Int]:
-     key, domain, curr_idx, batch_size, _ = operands
-     # simply increase the counter; the batch is taken downstream
-     curr_idx += batch_size
-     return (key, domain, curr_idx)
-
-
- def _reset_or_increment(
-     bend: Int,
-     n_eff: Int,
-     operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
- ) -> tuple[Key, Float[Array, "n dimension"], Int]:
-     """
-     Factorize the code of the jax.lax.cond which checks whether we have
-     seen all the batches in an epoch.
-     If bend > n_eff (i.e. n when there is no RAR sampling) we reshuffle and
-     start from 0 again. Otherwise, if bend <= n_eff, at least *_batch_size
-     samples have not been seen yet and we can take a new batch.
-
-     Parameters
-     ----------
-     bend
-         An integer. The new hypothetical start index of the batch
-     n_eff
-         An integer. The number of points to see to complete an epoch
-     operands
-         A tuple. As passed to _reset_batch_idx_and_permute and
-         _increment_batch_idx
-
-     Returns
-     -------
-     res
-         A tuple as returned by _reset_batch_idx_and_permute or
-         _increment_batch_idx
-     """
-     return jax.lax.cond(
-         bend > n_eff, _reset_batch_idx_and_permute, _increment_batch_idx, operands
-     )
-
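
Note: a toy walkthrough of this epoch logic, assuming the three helpers above are in scope. With 10 collocation points and a batch of 4 starting at index 8, the hypothetical batch end (12) overruns the epoch, so the domain is reshuffled and the index resets:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
domain = jnp.arange(10.0).reshape(10, 1)  # 10 collocation points
curr_idx, batch_size = 8, 4
key, domain, curr_idx = _reset_or_increment(
    curr_idx + batch_size, 10, (key, domain, curr_idx, batch_size, None)
)
print(curr_idx)  # 0, and `domain` has been permuted
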
-
- def _check_and_set_rar_parameters(
-     rar_parameters: dict, n: Int, n_start: Int
- ) -> tuple[Int, Float[Array, "n"], Int, Int]:
-     if rar_parameters is not None and n_start is None:
-         raise ValueError(
-             "n_start must be provided in the context of RAR sampling scheme"
-         )
-
-     if rar_parameters is not None:
-         # Default p is None. However, in the RAR sampling scheme we use 0
-         # probability to mark unused collocation points (i.e. points
-         # beyond n_start). Thus, p is a vector of probabilities of shape (n,).
-         p = jnp.zeros((n,))
-         p = p.at[:n_start].set(1 / n_start)
-         # set the internal counter for the number of gradient steps since
-         # the last new collocation points were added.
-         # It is not 0, to ensure the first iteration of RAR happens just
-         # after start_iter. See the _proceed_to_rar() function in _rar.py
-         rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
-         # set the internal counter for the number of times collocation
-         # points have been added
-         rar_iter_nb = 0
-     else:
-         n_start = n
-         p = None
-         rar_iter_from_last_sampling = None
-         rar_iter_nb = None
-
-     return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
-
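
Note: for illustration, the RAR probability vector for n = 8 and n_start = 4; points beyond n_start have zero probability until RAR promotes them:

import jax.numpy as jnp

p = jnp.zeros((8,)).at[:4].set(1 / 4)
print(p)  # [0.25 0.25 0.25 0.25 0.   0.   0.   0.  ]
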
-
- class DataGeneratorODE(eqx.Module):
-     """
-     A class implementing a data generator object for ordinary differential
-     equations.
-
-     Parameters
-     ----------
-     key : Key
-         Jax random key to sample new time points and to shuffle batches
-     nt : Int
-         The total number of time points that will be divided into
-         batches. Batches are made so that each data point is seen only
-         once during 1 epoch.
-     tmin : float
-         The minimum value of the time domain to consider
-     tmax : float
-         The maximum value of the time domain to consider
-     temporal_batch_size : int | None, default=None
-         The size of the batch of randomly selected points among
-         the `nt` points. If None, no minibatches are used.
-     method : str, default="uniform"
-         Either `grid` or `uniform`, default is `uniform`.
-         The method that generates the `nt` time points. `grid` means
-         regularly spaced points over the domain. `uniform` means uniformly
-         sampled points over the domain
-     rar_parameters : Dict[str, Int], default=None
-         Defaults to None: do not use Residual Adaptive Resampling.
-         Otherwise a dictionary with keys
-
-         - `start_iter`: the iteration at which we start the RAR sampling
-           scheme (we first have a "burn-in" period).
-         - `update_every`: the number of gradient steps taken between
-           each update of collocation points in the RAR algo.
-         - `sample_size`: the size of the sample from which we will select
-           new collocation points.
-         - `selected_sample_size`: the number of selected points from the
-           sample to be added to the current collocation points.
-     n_start : Int, default=None
-         Defaults to None. The effective size of nt used at start time.
-         This value must be provided when rar_parameters is not None.
-         Otherwise we internally set n_start = nt and this is hidden from
-         the user. In RAR, n_start then corresponds to the initial number
-         of points on which we train the PINN.
-     """
-
-     key: Key = eqx.field(kw_only=True)
-     nt: Int = eqx.field(kw_only=True, static=True)
-     tmin: Float = eqx.field(kw_only=True)
-     tmax: Float = eqx.field(kw_only=True)
-     temporal_batch_size: Int | None = eqx.field(static=True, default=None, kw_only=True)
-     method: str = eqx.field(
-         static=True, kw_only=True, default_factory=lambda: "uniform"
-     )
-     rar_parameters: Dict[str, Int] = eqx.field(default=None, kw_only=True)
-     n_start: Int = eqx.field(static=True, default=None, kw_only=True)
-
-     # all the init=False fields are set in __post_init__
-     p: Float[Array, "nt 1"] = eqx.field(init=False)
-     rar_iter_from_last_sampling: Int = eqx.field(init=False)
-     rar_iter_nb: Int = eqx.field(init=False)
-     curr_time_idx: Int = eqx.field(init=False)
-     times: Float[Array, "nt 1"] = eqx.field(init=False)
-
-     def __post_init__(self):
-         (
-             self.n_start,
-             self.p,
-             self.rar_iter_from_last_sampling,
-             self.rar_iter_nb,
-         ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.n_start)
-
-         if self.temporal_batch_size is not None:
-             self.curr_time_idx = self.nt + self.temporal_batch_size
-             # to be sure there is a shuffling at first get_batch()
-             # NOTE in the extreme case we could do:
-             # self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
-             # but we do not test for such extreme values. We subtract
-             # self.temporal_batch_size - 1 because otherwise, when computing
-             # `bend`, we would overflow the max int32 with unwanted behaviour
-         else:
-             self.curr_time_idx = 0
-
-         self.key, self.times = self.generate_time_data(self.key)
-         # Note that __init__ (and __post_init__) is the only place where
-         # in-place self assignments are authorized, hence the way the key
-         # is handled above.
-
-     def sample_in_time_domain(
-         self, key: Key, sample_size: Int = None
-     ) -> Float[Array, "nt 1"]:
-         return jax.random.uniform(
-             key,
-             (self.nt if sample_size is None else sample_size, 1),
-             minval=self.tmin,
-             maxval=self.tmax,
-         )
-
-     def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
-         """
-         Construct a complete set of `self.nt` time points according to the
-         specified `self.method`
-
-         Note that self.times always has size self.nt and not self.n_start:
-         even in the RAR scheme, we must allocate all the collocation points
-         """
-         key, subkey = jax.random.split(key)
-         if self.method == "grid":
-             partial_times = (self.tmax - self.tmin) / self.nt
-             return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
-         if self.method == "uniform":
-             return key, self.sample_in_time_domain(subkey)
-         raise ValueError("Method " + self.method + " is not implemented.")
-
-     def _get_time_operands(
-         self,
-     ) -> tuple[Key, Float[Array, "nt 1"], Int, Int, Float[Array, "nt 1"]]:
-         return (
-             self.key,
-             self.times,
-             self.curr_time_idx,
-             self.temporal_batch_size,
-             self.p,
-         )
-
-     def temporal_batch(
-         self,
-     ) -> tuple["DataGeneratorODE", Float[Array, "temporal_batch_size"]]:
-         """
-         Return a batch of time points. If all the batches have been seen, we
-         reshuffle them, otherwise we just return the next unseen batch.
-         """
-         if self.temporal_batch_size is None or self.temporal_batch_size == self.nt:
-             # Avoid unnecessary reshuffling
-             return self, self.times
-
-         bstart = self.curr_time_idx
-         bend = bstart + self.temporal_batch_size
-
-         # Compute the effective number of used collocation points
-         if self.rar_parameters is not None:
-             nt_eff = (
-                 self.n_start
-                 + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
-             )
-         else:
-             nt_eff = self.nt
-
-         new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
-         new = eqx.tree_at(
-             lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
-         )
-
-         # the commands below are equivalent to
-         # return self.times[i:(i + t_batch_size)]:
-         # start indices can be dynamic but the slice shape is fixed
-         return new, jax.lax.dynamic_slice(
-             new.times,
-             start_indices=(new.curr_time_idx, 0),
-             slice_sizes=(new.temporal_batch_size, 1),
-         )
-
-     def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
-         """
-         Generic method to return a batch. Here we call `self.temporal_batch()`
-         """
-         new, temporal_batch = self.temporal_batch()
-         return new, ODEBatch(temporal_batch=temporal_batch)
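
Note: a minimal usage sketch of this generator, assuming DataGeneratorODE and ODEBatch are imported from jinns.data as exported in jinns 1.3.x:

import jax
from jinns.data import DataGeneratorODE

data_gen = DataGeneratorODE(
    key=jax.random.PRNGKey(0),
    nt=1000,                 # total number of collocation times
    tmin=0.0,
    tmax=1.0,
    temporal_batch_size=32,
)
# get_batch returns the updated (immutable) generator and an ODEBatch
data_gen, batch = data_gen.get_batch()
print(batch.temporal_batch.shape)  # (32, 1)
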
-
-
- class CubicMeshPDEStatio(eqx.Module):
-     r"""
-     A class implementing a data generator object for stationary partial
-     differential equations.
-
-     Parameters
-     ----------
-     key : Key
-         Jax random key to sample new points and to shuffle batches
-     n : Int
-         The total number of $\Omega$ points that will be divided into
-         batches. Batches are made so that each data point is seen only
-         once during 1 epoch.
-     nb : Int | None
-         The total number of points in $\partial\Omega$. Can be None if no
-         boundary condition is specified.
-     omega_batch_size : Int | None, default=None
-         The size of the batch of randomly selected points among
-         the `n` points. If None, no minibatches are used.
-     omega_border_batch_size : Int | None, default=None
-         The size of the batch of points randomly selected
-         among the `nb` points. If None, no minibatches are used.
-         In dimension 1, minibatches are never used since the boundary is
-         composed of two singletons.
-     dim : Int
-         Dimension of the $\Omega$ domain
-     min_pts : tuple[tuple[Float, Float], ...]
-         A tuple of minimum values of the domain along each dimension. For a
-         sampling in `n` dimensions, this represents $(x_{1, min}, x_{2,min}, ...,
-         x_{n, min})$
-     max_pts : tuple[tuple[Float, Float], ...]
-         A tuple of maximum values of the domain along each dimension. For a
-         sampling in `n` dimensions, this represents $(x_{1, max}, x_{2,max}, ...,
-         x_{n,max})$
-     method : str, default="uniform"
-         Either `grid` or `uniform`, default is `uniform`.
-         The method that generates the `n` collocation points. `grid` means
-         regularly spaced points over the domain. `uniform` means uniformly
-         sampled points over the domain
-     rar_parameters : Dict[str, Int], default=None
-         Defaults to None: do not use Residual Adaptive Resampling.
-         Otherwise a dictionary with keys
-
-         - `start_iter`: the iteration at which we start the RAR sampling
-           scheme (we first have a "burn-in" period).
-         - `update_every`: the number of gradient steps taken between
-           each update of collocation points in the RAR algo.
-         - `sample_size`: the size of the sample from which we will select
-           new collocation points.
-         - `selected_sample_size`: the number of selected points from the
-           sample to be added to the current collocation points.
-     n_start : Int, default=None
-         Defaults to None. The effective size of n used at start time.
-         This value must be provided when rar_parameters is not None.
-         Otherwise we internally set n_start = n and this is hidden from the
-         user. In RAR, n_start then corresponds to the initial number of
-         points on which we train the PINN.
-     """
-
-     # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
-     key: Key = eqx.field(kw_only=True)
-     n: Int = eqx.field(kw_only=True, static=True)
-     nb: Int | None = eqx.field(kw_only=True, static=True, default=None)
-     omega_batch_size: Int | None = eqx.field(
-         kw_only=True,
-         static=True,  # static because used as a shape in jax.lax.dynamic_slice
-         default=None,  # can be None since CubicMeshPDENonStatio inherits,
-         # but also if omega_batch_size == n
-     )
-     omega_border_batch_size: Int | None = eqx.field(
-         kw_only=True, static=True, default=None
-     )  # static because used as a shape in jax.lax.dynamic_slice
-     dim: Int = eqx.field(kw_only=True, static=True)  # static because used as
-     # a shape in jax.lax.dynamic_slice
-     min_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
-     max_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
-     method: str = eqx.field(
-         kw_only=True, static=True, default_factory=lambda: "uniform"
-     )
-     rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
-     n_start: Int = eqx.field(kw_only=True, default=None, static=True)
-
-     # all the init=False fields are set in __post_init__
-     p: Float[Array, "n"] = eqx.field(init=False)
-     rar_iter_from_last_sampling: Int = eqx.field(init=False)
-     rar_iter_nb: Int = eqx.field(init=False)
-     curr_omega_idx: Int = eqx.field(init=False)
-     curr_omega_border_idx: Int = eqx.field(init=False)
-     omega: Float[Array, "n dim"] = eqx.field(init=False)
-     omega_border: Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None = eqx.field(
-         init=False
-     )
-
-     def __post_init__(self):
-         assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
-         assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
-
-         (
-             self.n_start,
-             self.p,
-             self.rar_iter_from_last_sampling,
-             self.rar_iter_nb,
-         ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
-
-         if self.method == "grid" and self.dim == 2:
-             perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
-             if self.n != perfect_sq:
-                 warnings.warn(
-                     "Grid sampling is requested in dimension 2 with a non"
-                     f" perfect square dataset size (self.n = {self.n})."
-                     f" Modifying self.n to self.n = {perfect_sq}."
-                 )
-             self.n = perfect_sq
-
-         if self.nb is not None:
-             if self.dim == 1:
-                 self.omega_border_batch_size = None
-                 # We are in the 1-D case => omega_border_batch_size is
-                 # ignored since the borders of Omega are singletons.
-                 # self.border_batch() will return [xmin, xmax]
-             else:
-                 if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
-                     raise ValueError(
-                         f"The number of border points must be"
-                         f" a multiple of 2xd = {2*self.dim} (the # of faces of"
-                         f" a d-dimensional cube). Got {self.nb=}."
-                     )
-                 if (
-                     self.omega_border_batch_size is not None
-                     and self.nb // (2 * self.dim) < self.omega_border_batch_size
-                 ):
-                     raise ValueError(
-                         f"The number of points per facet ({self.nb//(2*self.dim)})"
-                         f" cannot be lower than the border batch size"
-                         f" ({self.omega_border_batch_size})."
-                     )
-                 self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
-
-         if self.omega_batch_size is None:
-             self.curr_omega_idx = 0
-         else:
-             self.curr_omega_idx = self.n + self.omega_batch_size
-             # to be sure there is a shuffling at first get_batch()
-
-         if self.omega_border_batch_size is None:
-             self.curr_omega_border_idx = 0
-         else:
-             self.curr_omega_border_idx = self.nb + self.omega_border_batch_size
-             # to be sure there is a shuffling at first get_batch()
-
-         self.key, self.omega = self.generate_omega_data(self.key)
-         self.key, self.omega_border = self.generate_omega_border_data(self.key)
-
-     def sample_in_omega_domain(
-         self, keys: Key, sample_size: Int = None
-     ) -> Float[Array, "n dim"]:
-         sample_size = self.n if sample_size is None else sample_size
-         if self.dim == 1:
-             xmin, xmax = self.min_pts[0], self.max_pts[0]
-             return jax.random.uniform(
-                 keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
-             )
-         # keys = jax.random.split(key, self.dim)
-         return jnp.concatenate(
-             [
-                 jax.random.uniform(
-                     keys[i],
-                     (sample_size, 1),
-                     minval=self.min_pts[i],
-                     maxval=self.max_pts[i],
-                 )
-                 for i in range(self.dim)
-             ],
-             axis=-1,
-         )
-
-     def sample_in_omega_border_domain(
-         self, keys: Key, sample_size: int = None
-     ) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
-         sample_size = self.nb if sample_size is None else sample_size
-         if sample_size is None:
-             return None
-         if self.dim == 1:
-             xmin = self.min_pts[0]
-             xmax = self.max_pts[0]
-             return jnp.array([xmin, xmax]).astype(float)
-         if self.dim == 2:
-             # currently the 4 edges for d == 2 are hard-coded
-             # TODO: find a general & efficient way to sample from the border
-             # (facets) of the hypercube in arbitrary dimension.
-             facet_n = sample_size // (2 * self.dim)
-             xmin = jnp.hstack(
-                 [
-                     self.min_pts[0] * jnp.ones((facet_n, 1)),
-                     jax.random.uniform(
-                         keys[0],
-                         (facet_n, 1),
-                         minval=self.min_pts[1],
-                         maxval=self.max_pts[1],
-                     ),
-                 ]
-             )
-             xmax = jnp.hstack(
-                 [
-                     self.max_pts[0] * jnp.ones((facet_n, 1)),
-                     jax.random.uniform(
-                         keys[1],
-                         (facet_n, 1),
-                         minval=self.min_pts[1],
-                         maxval=self.max_pts[1],
-                     ),
-                 ]
-             )
-             ymin = jnp.hstack(
-                 [
-                     jax.random.uniform(
-                         keys[2],
-                         (facet_n, 1),
-                         minval=self.min_pts[0],
-                         maxval=self.max_pts[0],
-                     ),
-                     self.min_pts[1] * jnp.ones((facet_n, 1)),
-                 ]
-             )
-             ymax = jnp.hstack(
-                 [
-                     jax.random.uniform(
-                         keys[3],
-                         (facet_n, 1),
-                         minval=self.min_pts[0],
-                         maxval=self.max_pts[0],
-                     ),
-                     self.max_pts[1] * jnp.ones((facet_n, 1)),
-                 ]
-             )
-             return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
-         raise NotImplementedError(
-             "Generation of the border of a cube in dimension > 2 is not"
-             f" implemented yet. You are asking for generation in dimension d={self.dim}."
-         )
-
-     def generate_omega_data(self, key: Key, data_size: int = None) -> tuple[
-         Key,
-         Float[Array, "n dim"],
-     ]:
-         r"""
-         Construct a complete set of `self.n` $\Omega$ points according to the
-         specified `self.method`.
-         """
-         data_size = self.n if data_size is None else data_size
-         # Generate Omega
-         if self.method == "grid":
-             if self.dim == 1:
-                 xmin, xmax = self.min_pts[0], self.max_pts[0]
-                 # shape (n, 1)
-                 omega = jnp.linspace(xmin, xmax, data_size)[:, None]
-             else:
-                 xyz_ = jnp.meshgrid(
-                     *[
-                         jnp.linspace(
-                             self.min_pts[i],
-                             self.max_pts[i],
-                             int(jnp.round(jnp.sqrt(data_size))),
-                         )
-                         for i in range(self.dim)
-                     ]
-                 )
-                 xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
-                 omega = jnp.concatenate(xyz_, axis=-1)
-         elif self.method == "uniform":
-             if self.dim == 1:
-                 key, subkeys = jax.random.split(key, 2)
-             else:
-                 key, *subkeys = jax.random.split(key, self.dim + 1)
-             omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
-         else:
-             raise ValueError("Method " + self.method + " is not implemented.")
-         return key, omega
-
-     def generate_omega_border_data(self, key: Key, data_size: int = None) -> tuple[
-         Key,
-         Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
-     ]:
-         r"""
-         Construct a complete set of `self.nb` $\partial\Omega$ points if
-         `self.nb` is not `None`. Otherwise `self.omega_border` is set to
-         `None`.
-         """
-         # Generate the border of omega
-         data_size = self.nb if data_size is None else data_size
-         if self.dim == 2:
-             key, *subkeys = jax.random.split(key, 5)
-         else:
-             subkeys = None
-         omega_border = self.sample_in_omega_border_domain(
-             subkeys, sample_size=data_size
-         )
-
-         return key, omega_border
-
-     def _get_omega_operands(
-         self,
-     ) -> tuple[Key, Float[Array, "n dim"], Int, Int, Float[Array, "n"]]:
-         return (
-             self.key,
-             self.omega,
-             self.curr_omega_idx,
-             self.omega_batch_size,
-             self.p,
-         )
-
-     def inside_batch(
-         self,
-     ) -> tuple["CubicMeshPDEStatio", Float[Array, "omega_batch_size dim"]]:
-         r"""
-         Return a batch of points in $\Omega$.
-         If all the batches have been seen, we reshuffle them,
-         otherwise we just return the next unseen batch.
-         """
-         if self.omega_batch_size is None or self.omega_batch_size == self.n:
-             # Avoid unnecessary reshuffling
-             return self, self.omega
-
-         # Compute the effective number of used collocation points
-         if self.rar_parameters is not None:
-             n_eff = (
-                 self.n_start
-                 + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
-             )
-         else:
-             n_eff = self.n
-
-         bstart = self.curr_omega_idx
-         bend = bstart + self.omega_batch_size
-
-         new_attributes = _reset_or_increment(bend, n_eff, self._get_omega_operands())
-         new = eqx.tree_at(
-             lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
-         )
-
-         return new, jax.lax.dynamic_slice(
-             new.omega,
-             start_indices=(new.curr_omega_idx, 0),
-             slice_sizes=(new.omega_batch_size, new.dim),
-         )
-
-     def _get_omega_border_operands(
-         self,
-     ) -> tuple[
-         Key, Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None, Int, Int, None
-     ]:
-         return (
-             self.key,
-             self.omega_border,
-             self.curr_omega_border_idx,
-             self.omega_border_batch_size,
-             None,
-         )
-
-     def border_batch(
-         self,
-     ) -> tuple[
-         "CubicMeshPDEStatio",
-         Float[Array, "1 1 2"] | Float[Array, "omega_border_batch_size 2 4"] | None,
-     ]:
-         r"""
-         Return
-
-         - The value `None` if `self.omega_border_batch_size` is `None`.
-
-         - A jnp array with two fixed values $(x_{min}, x_{max})$ if
-           `self.dim` = 1. There is no sampling here, we return the entire
-           $\partial\Omega$.
-
-         - A batch of points in $\partial\Omega$ otherwise, stacked by
-           facet on the last axis.
-           If all the batches have been seen, we reshuffle them,
-           otherwise we just return the next unseen batch.
-         """
-         if self.nb is None:
-             # Avoid unnecessary reshuffling
-             return self, None
-
-         if self.dim == 1:
-             # Avoid unnecessary reshuffling
-             # 1-D case, no randomness: we always return the whole omega
-             # border, i.e. the (1, 1, 2)-shaped jnp.array([[[xmin], [xmax]]])
-             return self, self.omega_border[None, None]  # shape is (1, 1, 2)
-
-         if (
-             self.omega_border_batch_size is None
-             or self.omega_border_batch_size == self.nb // 2**self.dim
-         ):
-             # Avoid unnecessary reshuffling
-             return self, self.omega_border
-
-         bstart = self.curr_omega_border_idx
-         bend = bstart + self.omega_border_batch_size
-
-         new_attributes = _reset_or_increment(
-             bend, self.nb, self._get_omega_border_operands()
-         )
-         new = eqx.tree_at(
-             lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
-             self,
-             new_attributes,
-         )
-
-         return new, jax.lax.dynamic_slice(
-             new.omega_border,
-             start_indices=(new.curr_omega_border_idx, 0, 0),
-             slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
-         )
-
-     def get_batch(self) -> tuple["CubicMeshPDEStatio", PDEStatioBatch]:
-         """
-         Generic method to return a batch. Here we call `self.inside_batch()`
-         and `self.border_batch()`
-         """
-         new, inside_batch = self.inside_batch()
-         new, border_batch = new.border_batch()
-         return new, PDEStatioBatch(domain_batch=inside_batch, border_batch=border_batch)
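
Note: a minimal usage sketch, again assuming the 1.3.x export path. Observe that min_pts/max_pts are passed as one scalar per dimension, which is what the code above actually indexes, despite the tuple-of-tuples type hint:

import jax
from jinns.data import CubicMeshPDEStatio

data_gen = CubicMeshPDEStatio(
    key=jax.random.PRNGKey(0),
    n=1024,                      # interior collocation points
    nb=400,                      # border points, a multiple of 2 * dim
    omega_batch_size=64,
    omega_border_batch_size=25,  # at most nb // (2 * dim) = 100
    dim=2,
    min_pts=(0.0, 0.0),
    max_pts=(1.0, 1.0),
)
data_gen, batch = data_gen.get_batch()
print(batch.domain_batch.shape)  # (64, 2)
print(batch.border_batch.shape)  # (25, 2, 4): one slice per facet on the last axis
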
-
-
- class CubicMeshPDENonStatio(CubicMeshPDEStatio):
-     r"""
-     A class implementing a data generator object for non-stationary partial
-     differential equations. Formally, it extends `CubicMeshPDEStatio`
-     to include a temporal batch.
-
-     Parameters
-     ----------
-     key : Key
-         Jax random key to sample new points and to shuffle batches
-     n : Int
-         The total number of $I \times \Omega$ points that will be divided
-         into batches. Batches are made so that each data point is seen only
-         once during 1 epoch.
-     nb : Int | None
-         The total number of points in $\partial\Omega$. Can be None if no
-         boundary condition is specified.
-     ni : Int
-         The total number of $\Omega$ points at $t=0$ that will be divided
-         into batches. Batches are made so that each data point is seen only
-         once during 1 epoch.
-     domain_batch_size : Int | None, default=None
-         The size of the batch of randomly selected points among
-         the `n` points. If None, no mini-batches are used.
-     border_batch_size : Int | None, default=None
-         The size of the batch of points randomly selected
-         among the `nb` points. If None, no mini-batches are used.
-     initial_batch_size : Int | None, default=None
-         The size of the batch of randomly selected points among
-         the `ni` points. If None, no mini-batches are used.
-     dim : Int
-         An integer. Dimension of the $\Omega$ domain.
-     min_pts : tuple[tuple[Float, Float], ...]
-         A tuple of minimum values of the domain along each dimension. For a
-         sampling in `n` dimensions, this represents $(x_{1, min}, x_{2,min}, ...,
-         x_{n, min})$
-     max_pts : tuple[tuple[Float, Float], ...]
-         A tuple of maximum values of the domain along each dimension. For a
-         sampling in `n` dimensions, this represents $(x_{1, max}, x_{2,max}, ...,
-         x_{n,max})$
-     tmin : float
-         The minimum value of the time domain to consider
-     tmax : float
-         The maximum value of the time domain to consider
-     method : str, default="uniform"
-         Either `grid` or `uniform`, default is `uniform`.
-         The method that generates the collocation points. `grid` means
-         regularly spaced points over the domain. `uniform` means uniformly
-         sampled points over the domain
-     rar_parameters : Dict[str, Int], default=None
-         Defaults to None: do not use Residual Adaptive Resampling.
-         Otherwise a dictionary with keys
-
-         - `start_iter`: the iteration at which we start the RAR sampling
-           scheme (we first have a "burn-in" period).
-         - `update_every`: the number of gradient steps taken between
-           each update of collocation points in the RAR algo.
-         - `sample_size`: the size of the sample from which we will select
-           new collocation points.
-         - `selected_sample_size`: the number of selected points from the
-           sample to be added to the current collocation points.
-     n_start : Int, default=None
-         Defaults to None. The effective size of n used at start time.
-         This value must be provided when rar_parameters is not None.
-         Otherwise we internally set n_start = n and this is hidden from the
-         user. In RAR, n_start then corresponds to the initial number of
-         omega points on which we train the PINN.
-     """
-
-     tmin: Float = eqx.field(kw_only=True)
-     tmax: Float = eqx.field(kw_only=True)
-     ni: Int = eqx.field(kw_only=True, static=True)
-     domain_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
-     initial_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
-     border_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
-
-     curr_domain_idx: Int = eqx.field(init=False)
-     curr_initial_idx: Int = eqx.field(init=False)
-     curr_border_idx: Int = eqx.field(init=False)
-     domain: Float[Array, "n 1+dim"] = eqx.field(init=False)
-     border: Float[Array, "(nb//2) 1+1 2"] | Float[Array, "(nb//4) 2+1 4"] | None = (
-         eqx.field(init=False)
-     )
-     initial: Float[Array, "ni dim"] = eqx.field(init=False)
-
-     def __post_init__(self):
-         """
-         Note that neither __init__ nor __post_init__ is called when updating
-         a Module with eqx.tree_at!
-         """
-         super().__post_init__()  # because __init__ / __post_init__ of the
-         # base class is not automatically called
-
-         if self.method == "grid":
-             # NOTE we must redo the sampling with the square root number of
-             # samples and then take the cartesian product
-             self.n = int(jnp.round(jnp.sqrt(self.n)) ** 2)
-             if self.dim == 2:
-                 # in the case of grid sampling in dim 2 in the non-stationary
-                 # setting, self.n needs to be a perfect fourth power, because
-                 # of the cartesian product with the time domain
-                 perfect_4 = int(jnp.round(self.n**0.25) ** 4)
-                 if self.n != perfect_4:
-                     warnings.warn(
-                         "Grid sampling is requested in dimension 2 in the non"
-                         " stationary setting with a non"
-                         f" perfect fourth power dataset size (self.n = {self.n})."
-                         f" Modifying self.n to self.n = {perfect_4}."
-                     )
-                 self.n = perfect_4
-             self.key, half_domain_times = self.generate_time_data(
-                 self.key, int(jnp.round(jnp.sqrt(self.n)))
-             )
-
-             self.key, half_domain_omega = self.generate_omega_data(
-                 self.key, data_size=int(jnp.round(jnp.sqrt(self.n)))
-             )
-             self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
-
-             # NOTE: self.n may have been modified above, so recompute the
-             # RAR fields
-             (
-                 self.n_start,
-                 self.p,
-                 self.rar_iter_from_last_sampling,
-                 self.rar_iter_nb,
-             ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
-         elif self.method == "uniform":
-             self.key, domain_times = self.generate_time_data(self.key, self.n)
-             self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
-         else:
-             raise ValueError(
-                 f"Bad value for method. Got {self.method}, expected"
-                 ' "grid" or "uniform"'
-             )
-
-         if self.domain_batch_size is None:
-             self.curr_domain_idx = 0
-         else:
-             self.curr_domain_idx = self.n + self.domain_batch_size
-             # to be sure there is a shuffling at first get_batch()
-         if self.nb is not None:
-             # the check below has already been done in super().__post_init__
-             # if dim > 1. Here we retest it whatever the dim
-             if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
-                 raise ValueError(
-                     "The number of border points must be"
-                     " a multiple of 2xd (the # of faces of a d-dimensional cube)"
-                 )
-             # the check below concerns omega_border_batch_size for dim > 1 in
-             # super().__post_init__. Here it concerns all dim values since our
-             # border batch is the concatenation or cartesian product with times
-             if (
-                 self.border_batch_size is not None
-                 and self.nb // (2 * self.dim) < self.border_batch_size
-             ):
-                 raise ValueError(
-                     "The number of points per facet (nb // (2 * self.dim))"
-                     " cannot be lower than the border batch size"
-                 )
-             self.key, boundary_times = self.generate_time_data(
-                 self.key, self.nb // (2 * self.dim)
-             )
-             boundary_times = boundary_times.reshape(-1, 1, 1)
-             boundary_times = jnp.repeat(
-                 boundary_times, self.omega_border.shape[-1], axis=2
-             )
-             if self.dim == 1:
-                 self.border = make_cartesian_product(
-                     boundary_times, self.omega_border[None, None]
-                 )
-             else:
-                 self.border = jnp.concatenate(
-                     [boundary_times, self.omega_border], axis=1
-                 )
-             if self.border_batch_size is None:
-                 self.curr_border_idx = 0
-             else:
-                 self.curr_border_idx = self.nb + self.border_batch_size
-                 # to be sure there is a shuffling at first get_batch()
-
-         else:
-             self.border = None
-             self.curr_border_idx = None
-             self.border_batch_size = None
-
-         if self.ni is not None:
-             perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
-             if self.ni != perfect_sq:
-                 warnings.warn(
-                     "Grid sampling is requested in dimension 2 with a non"
-                     f" perfect square dataset size (self.ni = {self.ni})."
-                     f" Modifying self.ni to self.ni = {perfect_sq}."
-                 )
-                 self.ni = perfect_sq
-             self.key, self.initial = self.generate_omega_data(
-                 self.key, data_size=self.ni
-             )
-
-             if self.initial_batch_size is None or self.initial_batch_size == self.ni:
-                 self.curr_initial_idx = 0
-             else:
-                 self.curr_initial_idx = self.ni + self.initial_batch_size
-                 # to be sure there is a shuffling at first get_batch()
-         else:
-             self.initial = None
-             self.initial_batch_size = None
-             self.curr_initial_idx = None
-
-         # the following attributes will not be used anymore
-         self.omega = None
-         self.omega_border = None
-
-     def generate_time_data(self, key: Key, nt: Int) -> tuple[Key, Float[Array, "nt 1"]]:
-         """
-         Construct a complete set of `nt` time points according to the
-         specified `self.method`
-         """
-         key, subkey = jax.random.split(key, 2)
-         if self.method == "grid":
-             partial_times = (self.tmax - self.tmin) / nt
-             return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
-         if self.method == "uniform":
-             return key, self.sample_in_time_domain(subkey, nt)
-         raise ValueError("Method " + self.method + " is not implemented.")
-
-     def sample_in_time_domain(self, key: Key, nt: Int) -> Float[Array, "nt 1"]:
-         return jax.random.uniform(
-             key,
-             (nt, 1),
-             minval=self.tmin,
-             maxval=self.tmax,
-         )
-
-     def _get_domain_operands(
-         self,
-     ) -> tuple[Key, Float[Array, "n 1+dim"], Int, Int, None]:
-         return (
-             self.key,
-             self.domain,
-             self.curr_domain_idx,
-             self.domain_batch_size,
-             self.p,
-         )
-
-     def domain_batch(
-         self,
-     ) -> tuple["CubicMeshPDENonStatio", Float[Array, "domain_batch_size 1+dim"]]:
-         if self.domain_batch_size is None or self.domain_batch_size == self.n:
-             # Avoid unnecessary reshuffling
-             return self, self.domain
-
-         bstart = self.curr_domain_idx
-         bend = bstart + self.domain_batch_size
-
-         # Compute the effective number of used collocation points
-         if self.rar_parameters is not None:
-             n_eff = (
-                 self.n_start
-                 + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
-             )
-         else:
-             n_eff = self.n
-
-         new_attributes = _reset_or_increment(bend, n_eff, self._get_domain_operands())
-         new = eqx.tree_at(
-             lambda m: (m.key, m.domain, m.curr_domain_idx),
-             self,
-             new_attributes,
-         )
-         return new, jax.lax.dynamic_slice(
-             new.domain,
-             start_indices=(new.curr_domain_idx, 0),
-             slice_sizes=(new.domain_batch_size, new.dim + 1),
-         )
-
-     def _get_border_operands(
-         self,
-     ) -> tuple[
-         Key, Float[Array, "nb 1+1 2"] | Float[Array, "(nb//4) 2+1 4"], Int, Int, None
-     ]:
-         return (
-             self.key,
-             self.border,
-             self.curr_border_idx,
-             self.border_batch_size,
-             None,
-         )
-
-     def border_batch(
-         self,
-     ) -> tuple[
-         "CubicMeshPDENonStatio",
-         Float[Array, "border_batch_size 1+1 2"]
-         | Float[Array, "border_batch_size 2+1 4"]
-         | None,
-     ]:
-         if self.nb is None:
-             # Avoid unnecessary reshuffling
-             return self, None
-
-         if (
-             self.border_batch_size is None
-             or self.border_batch_size == self.nb // 2**self.dim
-         ):
-             # Avoid unnecessary reshuffling
-             return self, self.border
-
-         bstart = self.curr_border_idx
-         bend = bstart + self.border_batch_size
-
-         n_eff = self.border.shape[0]
-
-         new_attributes = _reset_or_increment(bend, n_eff, self._get_border_operands())
-         new = eqx.tree_at(
-             lambda m: (m.key, m.border, m.curr_border_idx),
-             self,
-             new_attributes,
-         )
-
-         return new, jax.lax.dynamic_slice(
-             new.border,
-             start_indices=(new.curr_border_idx, 0, 0),
-             slice_sizes=(
-                 new.border_batch_size,
-                 new.dim + 1,
-                 2 * new.dim,
-             ),
-         )
-
-     def _get_initial_operands(
-         self,
-     ) -> tuple[Key, Float[Array, "ni dim"], Int, Int, None]:
-         return (
-             self.key,
-             self.initial,
-             self.curr_initial_idx,
-             self.initial_batch_size,
-             None,
-         )
-
-     def initial_batch(
-         self,
-     ) -> tuple["CubicMeshPDENonStatio", Float[Array, "initial_batch_size dim"]]:
-         if self.initial_batch_size is None or self.initial_batch_size == self.ni:
-             # Avoid unnecessary reshuffling
-             return self, self.initial
-
-         bstart = self.curr_initial_idx
-         bend = bstart + self.initial_batch_size
-
-         n_eff = self.ni
-
-         new_attributes = _reset_or_increment(bend, n_eff, self._get_initial_operands())
-         new = eqx.tree_at(
-             lambda m: (m.key, m.initial, m.curr_initial_idx),
-             self,
-             new_attributes,
-         )
-         return new, jax.lax.dynamic_slice(
-             new.initial,
-             start_indices=(new.curr_initial_idx, 0),
-             slice_sizes=(new.initial_batch_size, new.dim),
-         )
-
-     def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
-         """
-         Generic method to return a batch. Here we call `self.domain_batch()`,
-         `self.border_batch()` and `self.initial_batch()`
-         """
-         new, domain = self.domain_batch()
-         if self.border is not None:
-             new, border = new.border_batch()
-         else:
-             border = None
-         if self.initial is not None:
-             new, initial = new.initial_batch()
-         else:
-             initial = None
-
-         return new, PDENonStatioBatch(
-             domain_batch=domain, border_batch=border, initial_batch=initial
-         )
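
Note: a minimal usage sketch for the non-stationary case on a 1-D space domain, again assuming the 1.3.x export path:

import jax
from jinns.data import CubicMeshPDENonStatio

data_gen = CubicMeshPDENonStatio(
    key=jax.random.PRNGKey(0),
    n=1024, nb=80, ni=100,
    domain_batch_size=32, border_batch_size=10, initial_batch_size=16,
    dim=1,
    min_pts=(0.0,), max_pts=(1.0,),
    tmin=0.0, tmax=1.0,
)
data_gen, batch = data_gen.get_batch()
print(batch.domain_batch.shape)   # (32, 2): (t, x) pairs
print(batch.border_batch.shape)   # (10, 2, 2)
print(batch.initial_batch.shape)  # (16, 1): x points at t = 0
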
-
-
- class DataGeneratorObservations(eqx.Module):
-     r"""
-     Despite the class name, this is rather a dataloader for user-provided
-     observations which are used in the observations loss.
-
-     Parameters
-     ----------
-     key : Key
-         Jax random key to shuffle batches
-     obs_batch_size : Int | None
-         The size of the batch of randomly selected points among
-         the `n` points. If None, no minibatches are used.
-     observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
-         Observed values corresponding to the input of the PINN
-         (e.g. the time at which we recorded the observations). The first
-         dimension must correspond to the number of observed values.
-         The second dimension depends on the input dimension of the PINN,
-         that is `1` for ODE, `n_dim_x` for stationary PDE and `n_dim_x + 1`
-         for non-stationary PDE.
-     observed_values : Float[Array, "n_obs nb_pinn_out"]
-         Observed values that the PINN should learn to fit. The first
-         dimension must be aligned with observed_pinn_in.
-     observed_eq_params : Dict[str, Float[Array, "n_obs 1"]], default={}
-         A dict with keys corresponding to
-         the parameter names. The keys must match the keys in
-         `params["eq_params"]`. The values are jnp.arrays with 2 dimensions,
-         whose values correspond to the parameter value for which we also
-         have observed_pinn_in and observed_values. Hence the first
-         dimension must be aligned with observed_pinn_in and observed_values.
-         Optional argument.
-     sharding_device : jax.sharding.Sharding, default=None
-         Default None. An optional sharding object to constrain the storage
-         of observed inputs, values and parameters. Typically, a
-         SingleDeviceSharding(cpu_device) to avoid loading huge datasets of
-         observations on the GPU. Note that computations for **batches**
-         can still be performed on other devices (*e.g.* GPU, TPU or
-         any pre-defined Sharding) thanks to the `obs_batch_sharding`
-         argument of `jinns.solve()`. Read the `jinns.solve()` doc for more
-         info.
-     """
-
-     key: Key
-     obs_batch_size: Int | None = eqx.field(static=True)
-     observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
-     observed_values: Float[Array, "n_obs nb_pinn_out"]
-     observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
-         static=True, default_factory=lambda: {}
-     )
-     sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
-
-     n: Int = eqx.field(init=False, static=True)
-     curr_idx: Int = eqx.field(init=False)
-     indices: Array = eqx.field(init=False)
-
-     def __post_init__(self):
-         if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
-             raise ValueError(
-                 "self.observed_pinn_in and self.observed_values must have the"
-                 " same first axis"
-             )
-         for _, v in self.observed_eq_params.items():
-             if v.shape[0] != self.observed_pinn_in.shape[0]:
-                 raise ValueError(
-                     "self.observed_pinn_in and the values of"
-                     " self.observed_eq_params must have the same first axis"
-                 )
-         if len(self.observed_pinn_in.shape) == 1:
-             self.observed_pinn_in = self.observed_pinn_in[:, None]
-         if len(self.observed_pinn_in.shape) > 2:
-             raise ValueError("self.observed_pinn_in must have 2 dimensions")
-         if len(self.observed_values.shape) == 1:
-             self.observed_values = self.observed_values[:, None]
-         if len(self.observed_values.shape) > 2:
-             raise ValueError("self.observed_values must have 2 dimensions")
-         for k, v in self.observed_eq_params.items():
-             if len(v.shape) == 1:
-                 self.observed_eq_params[k] = v[:, None]
-             if len(v.shape) > 2:
-                 raise ValueError(
-                     "Each value of observed_eq_params must have 2 dimensions"
-                 )
-
-         self.n = self.observed_pinn_in.shape[0]
-
-         if self.sharding_device is not None:
-             self.observed_pinn_in = jax.lax.with_sharding_constraint(
-                 self.observed_pinn_in, self.sharding_device
-             )
-             self.observed_values = jax.lax.with_sharding_constraint(
-                 self.observed_values, self.sharding_device
-             )
-             self.observed_eq_params = jax.lax.with_sharding_constraint(
-                 self.observed_eq_params, self.sharding_device
-             )
-
-         if self.obs_batch_size is not None:
-             self.curr_idx = self.n + self.obs_batch_size
-             # to be sure there is a shuffling at first get_batch()
-         else:
-             self.curr_idx = 0
-         # For speed, and to avoid duplicating data, what is really
-         # shuffled is a vector of indices
-         if self.sharding_device is not None:
-             self.indices = jax.lax.with_sharding_constraint(
-                 jnp.arange(self.n), self.sharding_device
-             )
-         else:
-             self.indices = jnp.arange(self.n)
-
-         # recall that __post_init__ is the only place, with __init__, where
-         # we can set self attributes in-place
-         self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent
-         # to the call to _reset_batch_idx_and_permute in the legacy DG
-
-     def _get_operands(self) -> tuple[Key, Int[Array, "n"], Int, Int, None]:
-         return (
-             self.key,
-             self.indices,
-             self.curr_idx,
-             self.obs_batch_size,
-             None,
-         )
-
-     def obs_batch(
-         self,
-     ) -> tuple[
-         "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
-     ]:
-         """
-         Return a dictionary with (key, value) pairs: (pinn_in, a mini-batch
-         of PINN inputs), (val, a mini-batch of corresponding observations),
-         (eq_params, a dictionary with entry names found in
-         `params["eq_params"]` and values giving the corresponding parameter
-         value for the (input, observation) couple mentioned before).
-         It can also be a dictionary of such dictionaries if
-         observed_pinn_in, observed_values, etc. are dictionaries with keys
-         representing the PINNs.
-         """
-         if self.obs_batch_size is None or self.obs_batch_size == self.n:
-             # Avoid unnecessary reshuffling
-             return self, {
-                 "pinn_in": self.observed_pinn_in,
-                 "val": self.observed_values,
-                 "eq_params": self.observed_eq_params,
-             }
-
-         new_attributes = _reset_or_increment(
-             self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
-         )
-         new = eqx.tree_at(
-             lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
-         )
-
-         minib_indices = jax.lax.dynamic_slice(
-             new.indices,
-             start_indices=(new.curr_idx,),
-             slice_sizes=(new.obs_batch_size,),
-         )
-
-         obs_batch = {
-             "pinn_in": jnp.take(
-                 new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
-             ),
-             "val": jnp.take(
-                 new.observed_values, minib_indices, unique_indices=True, axis=0
-             ),
-             "eq_params": jax.tree_util.tree_map(
-                 lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
-                 new.observed_eq_params,
-             ),
-         }
-         return new, obs_batch
-
-     def get_batch(
-         self,
-     ) -> tuple[
-         "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
-     ]:
-         """
-         Generic method to return a batch
-         """
-         return self.obs_batch()
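
Note: a minimal usage sketch with synthetic ODE observations, assuming the 1.3.x export path; the returned dict keys are the ones built in obs_batch above:

import jax
import jax.numpy as jnp
from jinns.data import DataGeneratorObservations

t_obs = jnp.linspace(0.0, 1.0, 50)[:, None]  # (50, 1): PINN inputs (ODE case)
u_obs = jnp.sin(t_obs)                       # (50, 1): observed outputs

data_gen = DataGeneratorObservations(
    key=jax.random.PRNGKey(0),
    obs_batch_size=10,
    observed_pinn_in=t_obs,
    observed_values=u_obs,
)
data_gen, obs = data_gen.get_batch()
print(obs["pinn_in"].shape, obs["val"].shape)  # (10, 1) (10, 1)
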
-
1321
-
1322
- class DataGeneratorParameter(eqx.Module):
1323
- r"""
1324
- A data generator for additional unidimensional equation parameter(s).
1325
- Mostly useful for metamodeling where batch of `params.eq_params` are fed
1326
- to the network.
1327
-
1328
- Parameters
1329
- ----------
1330
- keys : Key | Dict[str, Key]
1331
- Jax random key to sample new time points and to shuffle batches
1332
- or a dict of Jax random keys with key entries from param_ranges
1333
- n : Int
1334
- The number of total points that will be divided in
1335
- batches. Batches are made so that each data point is seen only
1336
- once during 1 epoch.
1337
- param_batch_size : Int | None, default=None
1338
- The size of the batch of randomly selected points among
1339
- the `n` points. **Important**: no check is performed but
1340
- `param_batch_size` must be the same as other collocation points
1341
- batch_size (time, space or timexspace depending on the context). This is because we vmap the network on all its axes at once to compute the MSE. Also, `param_batch_size` will be the same for all parameters. If None, no mini-batches are used.
1342
- param_ranges : Dict[str, tuple[Float, Float] | None, default={}
1343
- A dict. A dict of tuples (min, max), which
1344
- reprensents the range of real numbers where to sample batches (of
1345
- length `param_batch_size` among `n` points).
1346
- The key corresponds to the parameter name. The keys must match the
1347
- keys in `params["eq_params"]`.
1348
- By providing several entries in this dictionary we can sample
1349
- an arbitrary number of parameters.
1350
- **Note** that we currently only support unidimensional parameters.
1351
- This argument can be None if we use `user_data`.
1352
- method : str, default="uniform"
1353
- Either `grid` or `uniform`, default is `uniform`. `grid` means
1354
- regularly spaced points over the domain. `uniform` means uniformly
1355
- sampled points over the domain
1356
- user_data : Dict[str, Float[jnp.ndarray, "n"]] | None, default={}
1357
- A dictionary containing user-provided data for parameters.
1358
- The keys corresponds to the parameter name,
1359
- and must match the keys in `params["eq_params"]`. Only
1360
- unidimensional `jnp.array` are supported. Therefore, the array at
1361
- `user_data[k]` must have shape `(n, 1)` or `(n,)`.
1362
- Note that if the same key appears in `param_ranges` and `user_data`
1363
- priority goes for the content in `user_data`.
1364
- Defaults to None.
1365
- """
1366
-
1367
- keys: Key | Dict[str, Key]
1368
- n: Int = eqx.field(static=True)
1369
- param_batch_size: Int | None = eqx.field(static=True, default=None)
1370
- param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
1371
- static=True, default_factory=lambda: {}
1372
- )
1373
- method: str = eqx.field(static=True, default="uniform")
1374
- user_data: Dict[str, Float[onp.Array, "n"]] | None = eqx.field(
1375
- default_factory=lambda: {}
1376
- )
1377
-
1378
- curr_param_idx: Dict[str, Int] = eqx.field(init=False)
1379
- param_n_samples: Dict[str, Array] = eqx.field(init=False)
1380
-
1381
- def __post_init__(self):
1382
- if self.user_data is None:
1383
- self.user_data = {}
1384
- if self.param_ranges is None:
1385
- self.param_ranges = {}
1386
- if self.n < self.param_batch_size:
1387
- raise ValueError(
1388
- f"Number of data points ({self.n}) is smaller than the"
1389
- f"number of batch points ({self.param_batch_size})."
1390
- )
1391
- if not isinstance(self.keys, dict):
1392
- all_keys = set().union(self.param_ranges, self.user_data)
1393
- self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
1394
-
1395
- if self.param_batch_size is None:
1396
- self.curr_param_idx = None
1397
- else:
1398
- self.curr_param_idx = {}
1399
- for k in self.keys.keys():
1400
- self.curr_param_idx[k] = self.n + self.param_batch_size
1401
- # to be sure there is a shuffling at first get_batch()
1402
-
1403
- # The call to self.generate_data() creates
1404
- # the dict self.param_n_samples and then we will only use this one
1405
- # because it merges the scattered data between `user_data` and
1406
- # `param_ranges`
1407
- self.keys, self.param_n_samples = self.generate_data(self.keys)
1408
-
1409
- def generate_data(
1410
- self, keys: Dict[str, Key]
1411
- ) -> tuple[Dict[str, Key], Dict[str, Float[Array, "n"]]]:
1412
- """
1413
- Generate parameter samples, either through generation
1414
- or using user-provided data.
1415
- """
1416
- param_n_samples = {}
1417
-
1418
- all_keys = set().union(self.param_ranges, self.user_data)
1419
- for k in all_keys:
1420
- if (
1421
- self.user_data
1422
- and k in self.user_data.keys() # pylint: disable=no-member
1423
- ):
1424
- if self.user_data[k].shape == (self.n, 1):
1425
- param_n_samples[k] = self.user_data[k]
1426
- if self.user_data[k].shape == (self.n,):
1427
- param_n_samples[k] = self.user_data[k][:, None]
1428
- else:
1429
- raise ValueError(
1430
- "Wrong shape for user provided parameters"
1431
- f" in user_data dictionary at key='{k}'"
1432
- )
1433
- else:
1434
- if self.method == "grid":
1435
- xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
1436
- partial = (xmax - xmin) / self.n
1437
- # shape (n, 1)
1438
- param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
1439
- elif self.method == "uniform":
1440
- xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
1441
- keys[k], subkey = jax.random.split(keys[k], 2)
1442
- param_n_samples[k] = jax.random.uniform(
1443
- subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
1444
- )
1445
- else:
1446
- raise ValueError("Method " + self.method + " is not implemented.")
1447
-
1448
- return keys, param_n_samples
1449
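The two sampling schemes can be reproduced in isolation (illustrative helper, not part of the class):

import jax
import jax.numpy as jnp

def sample_param(key, n, xmin, xmax, method="uniform"):
    if method == "grid":
        step = (xmax - xmin) / n
        # n regularly spaced points over [xmin, xmax), shape (n, 1)
        return jnp.arange(xmin, xmax, step)[:, None]
    if method == "uniform":
        # n i.i.d. draws from U(xmin, xmax), shape (n, 1)
        return jax.random.uniform(key, shape=(n, 1), minval=xmin, maxval=xmax)
    raise ValueError(f"Method {method} is not implemented.")

samples = sample_param(jax.random.PRNGKey(0), 8, 0.0, 1.0, method="grid")
# samples -> [[0.], [0.125], [0.25], ..., [0.875]]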
-
-     def _get_param_operands(
-         self, k: str
-     ) -> tuple[Key, Float[Array, "n"], Int, Int, None]:
-         return (
-             self.keys[k],
-             self.param_n_samples[k],
-             self.curr_param_idx[k],
-             self.param_batch_size,
-             None,
-         )
-
-     def param_batch(self):
-         """
-         Return the updated generator along with a dictionary of parameter
-         batches. If all the batches have been seen, we reshuffle them,
-         otherwise we just return the next unseen batch.
-         """
-
-         if self.param_batch_size is None or self.param_batch_size == self.n:
-             return self, self.param_n_samples
-
-         def _reset_or_increment_wrapper(param_k, idx_k, key_k):
-             return _reset_or_increment(
-                 idx_k + self.param_batch_size,
-                 self.n,
-                 (key_k, param_k, idx_k, self.param_batch_size, None),
-             )
-
-         res = jax.tree_util.tree_map(
-             _reset_or_increment_wrapper,
-             self.param_n_samples,
-             self.curr_param_idx,
-             self.keys,
-         )
-         # we must transpose the pytrees because, in res, each parameter key
-         # holds a (key, samples, index) triple, whereas we need one dict per
-         # attribute
-         # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
-         new_attributes = jax.tree_util.tree_transpose(
-             jax.tree_util.tree_structure(self.keys),
-             jax.tree_util.tree_structure([0, 0, 0]),
-             res,
-         )
-
-         new = eqx.tree_at(
-             lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
-             self,
-             new_attributes,
-         )
-
-         return new, jax.tree_util.tree_map(
-             lambda p, q: jax.lax.dynamic_slice(
-                 p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
-             ),
-             new.param_n_samples,
-             new.curr_param_idx,
-         )
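The wrapper above delegates the bookkeeping to _reset_or_increment, which is defined earlier in this file. The sketch below is an assumption about the behaviour it implements, not the original helper: advance the read index by one batch and, once the epoch is exhausted, reshuffle the samples and restart from index 0 (which is also why curr_param_idx is initialised past the end, forcing a shuffle on the first call).

import jax

def reset_or_increment_sketch(key, samples, idx, batch_size, n):
    new_idx = idx + batch_size
    if new_idx + batch_size > n:  # the next slice would overrun the n samples
        key, subkey = jax.random.split(key)
        samples = jax.random.permutation(subkey, samples, axis=0)  # reshuffle
        new_idx = 0  # restart the epoch
    return key, samples, new_idx

The updated (key, samples, index) triples are exactly what param_batch then transposes and slices with jax.lax.dynamic_slice.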
-
-     def get_batch(self):
-         """
-         Generic method to return a batch.
-         """
-         return self.param_batch()
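Note that, as an immutable eqx.Module, the generator is never mutated in place: get_batch() returns an updated copy of itself together with the batch, and callers re-bind it at every step (hypothetical loop, reusing the param_gen sketched above):

for _ in range(5):
    param_gen, param_batch = param_gen.get_batch()
    # param_batch is a Dict[str, Array]; each array has shape (param_batch_size, 1)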
-
-
- class DataGeneratorObservationsMultiPINNs(eqx.Module):
-     r"""
-     Despite its name, this class is rather a dataloader over user-provided
-     observations that will be used for the observations loss.
-     This is the DataGenerator to use when dealing with multiple PINNs
-     (`u_dict`) in SystemLossODE/SystemLossPDE.
-
-     Technically, the constraint on the observations in SystemLossXDE is
-     applied in `constraints_system_loss_apply`; in this case
-     `batch.obs_batch_dict` is a dict of obs_batch_dict over which the tree_map
-     applies (we select the obs_batch_dict corresponding to each `u_dict` entry).
-
-     Parameters
-     ----------
-     obs_batch_size : Int
-         The size of the batch of randomly selected observations.
-         `obs_batch_size` is the same for all the
-         elements of the obs dict.
-     observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
-         A dict of observed_pinn_in as defined in DataGeneratorObservations.
-         Keys must be those of `u_dict`.
-         If no observation exists for a particular entry of `u_dict`, the
-         corresponding key must still exist in observed_pinn_in_dict with
-         value None.
-     observed_values_dict : Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
-         A dict of observed_values as defined in DataGeneratorObservations.
-         Keys must be those of `u_dict`.
-         If no observation exists for a particular entry of `u_dict`, the
-         corresponding key must still exist in observed_values_dict with
-         value None.
-     observed_eq_params_dict : Dict[str, Dict[str, Float[Array, "n_obs 1"]]]
-         A dict of observed_eq_params as defined in DataGeneratorObservations.
-         Keys must be those of `u_dict`.
-         **Note**: if no observation exists for a particular entry of `u_dict`,
-         the corresponding key must still exist in observed_eq_params_dict with
-         value `{}` (empty dictionary).
-     key
-         JAX random key to shuffle batches.
-     """
-
-     obs_batch_size: Int
-     observed_pinn_in_dict: Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
-     observed_values_dict: Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
-     observed_eq_params_dict: Dict[str, Dict[str, Float[Array, "n_obs 1"]]] = eqx.field(
-         default=None, kw_only=True
-     )
-     key: InitVar[Key]
-
-     data_gen_obs: Dict[str, "DataGeneratorObservations"] = eqx.field(init=False)
-
-     def __post_init__(self, key):
-         if self.observed_pinn_in_dict is None or self.observed_values_dict is None:
-             raise ValueError(
-                 "observed_pinn_in_dict and observed_values_dict must be provided"
-             )
-         if self.observed_pinn_in_dict.keys() != self.observed_values_dict.keys():
-             raise ValueError(
-                 "Keys must be the same in observed_pinn_in_dict"
-                 " and observed_values_dict"
-             )
-
-         if self.observed_eq_params_dict is None:
-             self.observed_eq_params_dict = {
-                 k: {} for k in self.observed_pinn_in_dict.keys()
-             }
-         elif self.observed_pinn_in_dict.keys() != self.observed_eq_params_dict.keys():
-             raise ValueError(
-                 "Keys must be the same in observed_eq_params_dict"
-                 " and observed_pinn_in_dict and observed_values_dict"
-             )
-
-         keys = dict(
-             zip(
-                 self.observed_pinn_in_dict.keys(),
-                 jax.random.split(key, len(self.observed_pinn_in_dict)),
-             )
-         )
-         self.data_gen_obs = jax.tree_util.tree_map(
-             lambda k, pinn_in, val, eq_params: (
-                 DataGeneratorObservations(
-                     k, self.obs_batch_size, pinn_in, val, eq_params
-                 )
-                 if pinn_in is not None
-                 else None
-             ),
-             keys,
-             self.observed_pinn_in_dict,
-             self.observed_values_dict,
-             self.observed_eq_params_dict,
-         )
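A hypothetical construction sketch (assuming the class is re-exported from `jinns.data`; "u1" and "u2" stand for the keys of `u_dict`). It illustrates the documented requirement that every PINN key be present, with None marking the unobserved ones:

import jax
import jax.numpy as jnp
from jinns.data import DataGeneratorObservationsMultiPINNs  # assumed import path

n_obs = 50
obs_gen = DataGeneratorObservationsMultiPINNs(
    obs_batch_size=8,
    observed_pinn_in_dict={"u1": jnp.ones((n_obs, 2)), "u2": None},  # u2: no data
    observed_values_dict={"u1": jnp.zeros((n_obs, 1)), "u2": None},
    key=jax.random.PRNGKey(1),
)
# obs_gen.data_gen_obs then holds a DataGeneratorObservations for "u1"
# and None for "u2"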
-
-     def obs_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
-         """
-         Return the updated object together with a dictionary of
-         DataGeneratorObservations.obs_batch with keys from `u_dict`.
-         """
-         data_gen_and_batch_pytree = jax.tree_util.tree_map(
-             lambda a: a.get_batch() if a is not None else {},
-             self.data_gen_obs,
-             is_leaf=lambda x: isinstance(x, DataGeneratorObservations),
-         )  # note the is_leaf argument: DataGeneratorObservations instances
-         # are treated as leaves, so get_batch() can be called on the
-         # element(s) of self.data_gen_obs that are not None (None entries
-         # contribute no leaves and pass through unchanged)
-         new_attribute = jax.tree_util.tree_map(
-             lambda a: a[0],
-             data_gen_and_batch_pytree,
-             is_leaf=lambda x: isinstance(x, tuple),
-         )
-         new = eqx.tree_at(lambda m: m.data_gen_obs, self, new_attribute)
-         batches = jax.tree_util.tree_map(
-             lambda a: a[1],
-             data_gen_and_batch_pytree,
-             is_leaf=lambda x: isinstance(x, tuple),
-         )
-
-         return new, batches
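The un-zipping of (generator, batch) pairs above can be reproduced in isolation: marking tuples as leaves lets tree_map split a dict of pairs into two dicts (standalone illustration, with strings in place of real generators and batches):

import jax

pairs = {"u1": ("gen1", "batch1"), "u2": ("gen2", "batch2")}
gens = jax.tree_util.tree_map(
    lambda t: t[0], pairs, is_leaf=lambda x: isinstance(x, tuple)
)
batches = jax.tree_util.tree_map(
    lambda t: t[1], pairs, is_leaf=lambda x: isinstance(x, tuple)
)
# gens    == {"u1": "gen1", "u2": "gen2"}
# batches == {"u1": "batch1", "u2": "batch2"}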
-
-     def get_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
-         """
-         Generic method to return a batch.
-         """
-         return self.obs_batch()