jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
1
1
  # pylint: disable=unsubscriptable-object
2
2
  """
3
- Define the DataGeneratorODE equinox module
3
+ Define the DataGenerator equinox modules
4
4
  """
5
5
  from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
-
8
+ import warnings
9
9
  from typing import TYPE_CHECKING, Dict
10
10
  from dataclasses import InitVar
11
11
  import equinox as eqx
@@ -20,8 +20,7 @@ if TYPE_CHECKING:
20
20
 
21
21
  def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
22
22
  """
23
- Utility function that fill the param_batch_dict of a batch object with a
24
- param_batch_dict
23
+ Utility function that fills the field `batch.param_batch_dict` of a batch object.
25
24
  """
26
25
  return eqx.tree_at(
27
26
  lambda m: m.param_batch_dict,
@@ -33,8 +32,7 @@ def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
33
32
 
34
33
  def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
35
34
  """
36
- Utility function that fill the obs_batch_dict of a batch object with a
37
- obs_batch_dict
35
+ Utility function that fills the field `batch.obs_batch_dict` of a batch object.
38
36
  """
39
37
  return eqx.tree_at(
40
38
  lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
@@ -63,12 +61,17 @@ def _reset_batch_idx_and_permute(
63
61
  curr_idx = 0
64
62
  # reshuffling
65
63
  key, subkey = jax.random.split(key)
66
- # domain = random.permutation(subkey, domain, axis=0, independent=False)
67
- # we want that permutation = choice when p=None
68
- # otherwise p is used to avoid collocation points not in nt_start
69
- domain = jax.random.choice(
70
- subkey, domain, shape=(domain.shape[0],), replace=False, p=p
71
- )
64
+ if p is None:
65
+ domain = jax.random.permutation(subkey, domain, axis=0, independent=False)
66
+ else:
67
+ # otherwise p is used to avoid sampling collocation points beyond n_start
68
+ # NOTE that replace=True to avoid undefined behaviour, but then
69
+ # domain.shape[0] does not really grow as in the original RAR. Instead,
70
+ # the domain always comprises the same number of points, but the points
71
+ # are updated
72
+ domain = jax.random.choice(
73
+ subkey, domain, shape=(domain.shape[0],), replace=True, p=p
74
+ )
72
75
 
73
76
  # return updated
74
77
  return (key, domain, curr_idx)
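
For intuition, the two branches above behave as follows; a minimal standalone sketch with toy values (not part of the released code):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    domain = jnp.arange(6.0)[:, None]  # 6 collocation points, shape (6, 1)

    # p is None: a plain reshuffle, every point kept exactly once
    shuffled = jax.random.permutation(key, domain, axis=0, independent=False)

    # RAR case: points beyond n_start get probability 0 and are never drawn;
    # replace=True is needed because only n_start entries carry non-zero mass
    n_start = 4
    p = jnp.zeros((6,)).at[:n_start].set(1 / n_start)
    resampled = jax.random.choice(key, domain, shape=(6,), replace=True, p=p)
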
@@ -121,13 +124,13 @@ def _check_and_set_rar_parameters(
121
124
  ) -> tuple[Int, Float[Array, "n"], Int, Int]:
122
125
  if rar_parameters is not None and n_start is None:
123
126
  raise ValueError(
124
- "nt_start must be provided in the context of RAR sampling scheme"
127
+ "n_start must be provided in the context of RAR sampling scheme"
125
128
  )
126
129
 
127
130
  if rar_parameters is not None:
128
131
  # Default p is None. However, in the RAR sampling scheme we use 0
129
132
  # probability to specify non-used collocation points (i.e. points
130
- # above nt_start). Thus, p is a vector of probability of shape (nt, 1).
133
+ # above n_start). Thus, p is a vector of probabilities of shape (n,).
131
134
  p = jnp.zeros((n,))
132
135
  p = p.at[:n_start].set(1 / n_start)
133
136
  # set internal counter for the number of gradient steps since the
@@ -163,81 +166,83 @@ class DataGeneratorODE(eqx.Module):
163
166
  The minimum value of the time domain to consider
164
167
  tmax : float
165
168
  The maximum value of the time domain to consider
166
- temporal_batch_size : int
169
+ temporal_batch_size : int | None, default=None
167
170
  The size of the batch of randomly selected points among
168
- the `nt` points.
171
+ the `nt` points. If None, no minibatches are used.
169
172
  method : str, default="uniform"
170
173
  Either `grid` or `uniform`, default is `uniform`.
171
174
  The method that generates the `nt` time points. `grid` means
172
175
  regularly spaced points over the domain. `uniform` means uniformly
173
176
  sampled points over the domain
174
177
  rar_parameters : Dict[str, Int], default=None
175
- Default to None: do not use Residual Adaptative Resampling.
176
- Otherwise a dictionary with keys. `start_iter`: the iteration at
177
- which we start the RAR sampling scheme (we first have a burn in
178
- period). `update_rate`: the number of gradient steps taken between
179
- each appending of collocation points in the RAR algo.
180
- `sample_size`: the size of the sample from which we will select new
181
- collocation points. `selected_sample_size_times`: the number of selected
178
+ Defaults to None: do not use Residual Adaptive Resampling.
179
+ Otherwise a dictionary with keys
180
+
181
+ - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
182
+ - `update_every`: the number of gradient steps taken between
183
+ each update of collocation points in the RAR algo.
184
+ - `sample_size`: the size of the sample from which we will select new
185
+ collocation points.
186
+ - `selected_sample_size`: the number of selected
182
187
  points from the sample to be added to the current collocation
183
- points
184
- nt_start : Int, default=None
188
+ points.
189
+ n_start : Int, default=None
185
190
  Defaults to None. The effective size of nt used at start time.
186
191
  This value must be
187
192
  provided when rar_parameters is not None. Otherwise we set internally
188
- nt_start = nt and this is hidden from the user.
189
- In RAR, nt_start
193
+ n_start = nt and this is hidden from the user.
194
+ In RAR, n_start
190
195
  then corresponds to the initial number of points we train the PINN.
191
196
  """
192
197
 
193
- key: Key
194
- nt: Int
195
- tmin: Float
196
- tmax: Float
197
- temporal_batch_size: Int = eqx.field(static=True) # static cause used as a
198
- # shape in jax.lax.dynamic_slice
199
- method: str = eqx.field(static=True, default_factory=lambda: "uniform")
200
- rar_parameters: Dict[str, Int] = None
201
- nt_start: Int = eqx.field(static=True, default=None)
198
+ key: Key = eqx.field(kw_only=True)
199
+ nt: Int = eqx.field(kw_only=True, static=True)
200
+ tmin: Float = eqx.field(kw_only=True)
201
+ tmax: Float = eqx.field(kw_only=True)
202
+ temporal_batch_size: Int | None = eqx.field(static=True, default=None, kw_only=True)
203
+ method: str = eqx.field(
204
+ static=True, kw_only=True, default_factory=lambda: "uniform"
205
+ )
206
+ rar_parameters: Dict[str, Int] = eqx.field(default=None, kw_only=True)
207
+ n_start: Int = eqx.field(static=True, default=None, kw_only=True)
202
208
 
203
- # all the init=False fields are set in __post_init__, even after a _replace
204
- # or eqx.tree_at __post_init__ is called
205
- p_times: Float[Array, "nt"] = eqx.field(init=False)
209
+ # all the init=False fields are set in __post_init__
210
+ p: Float[Array, "nt 1"] = eqx.field(init=False)
206
211
  rar_iter_from_last_sampling: Int = eqx.field(init=False)
207
212
  rar_iter_nb: Int = eqx.field(init=False)
208
213
  curr_time_idx: Int = eqx.field(init=False)
209
- times: Float[Array, "nt"] = eqx.field(init=False)
214
+ times: Float[Array, "nt 1"] = eqx.field(init=False)
210
215
 
211
216
  def __post_init__(self):
212
217
  (
213
- self.nt_start,
214
- self.p_times,
218
+ self.n_start,
219
+ self.p,
215
220
  self.rar_iter_from_last_sampling,
216
221
  self.rar_iter_nb,
217
- ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
218
-
219
- self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
220
- # to be sure there is a
221
- # shuffling at first get_batch() we do not call
222
- # _reset_batch_idx_and_permute in __init__ or __post_init__ because it
223
- # would return a copy of self and we have not investigate what would
224
- # happen
225
- # NOTE the (- self.temporal_batch_size - 1) because otherwise when computing
226
- # `bend` we overflow the max int32 with unwanted behaviour
222
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.n_start)
223
+
224
+ if self.temporal_batch_size is not None:
225
+ self.curr_time_idx = self.nt + self.temporal_batch_size
226
+ # to be sure there is a shuffling at first get_batch()
227
+ # NOTE in the extreme case we could do:
228
+ # self.curr_time_idx=jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
229
+ # but we do not test for such extreme values. There we would subtract
230
+ # self.temporal_batch_size - 1 because otherwise, when computing
231
+ # `bend`, we would overflow the max int32 with unwanted behaviour
232
+ else:
233
+ self.curr_time_idx = 0
227
234
 
228
235
  self.key, self.times = self.generate_time_data(self.key)
229
236
  # Note that, here, in __init__ (and __post_init__), this is the
230
237
  # only place where self assignment are authorized so we do the
231
- # above way for the key. Note that one of the motivation to return the
232
- # key from generate_*_data is to easily align key with legacy
233
- # DataGenerators to use same unit tests
238
+ # above way for the key.
234
239
 
235
240
  def sample_in_time_domain(
236
241
  self, key: Key, sample_size: Int = None
237
- ) -> Float[Array, "nt"]:
242
+ ) -> Float[Array, "nt 1"]:
238
243
  return jax.random.uniform(
239
244
  key,
240
- (self.nt if sample_size is None else sample_size,),
245
+ (self.nt if sample_size is None else sample_size, 1),
241
246
  minval=self.tmin,
242
247
  maxval=self.tmax,
243
248
  )
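
Putting the new keyword-only API together, a hedged usage sketch (the import path `jinns.data.DataGeneratorODE` and the exact batch attribute are assumptions, not shown in this diff):

    import jax
    import jinns

    dg = jinns.data.DataGeneratorODE(  # assumed import path
        key=jax.random.PRNGKey(0),
        nt=100,
        tmin=0.0,
        tmax=1.0,
        temporal_batch_size=32,  # None would return all nt points at once
    )
    dg, batch = dg.get_batch()  # batch is an ODEBatch; times now have shape (nt, 1)
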
@@ -247,26 +252,26 @@ class DataGeneratorODE(eqx.Module):
247
252
  Construct a complete set of `self.nt` time points according to the
248
253
  specified `self.method`
249
254
 
250
- Note that self.times has always size self.nt and not self.nt_start, even
255
+ Note that self.times always has size self.nt and not self.n_start: even
251
256
  in RAR scheme, we must allocate all the collocation points
252
257
  """
253
258
  key, subkey = jax.random.split(self.key)
254
259
  if self.method == "grid":
255
260
  partial_times = (self.tmax - self.tmin) / self.nt
256
- return key, jnp.arange(self.tmin, self.tmax, partial_times)
261
+ return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
257
262
  if self.method == "uniform":
258
263
  return key, self.sample_in_time_domain(subkey)
259
264
  raise ValueError("Method " + self.method + " is not implemented.")
260
265
 
261
266
  def _get_time_operands(
262
267
  self,
263
- ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
268
+ ) -> tuple[Key, Float[Array, "nt 1"], Int, Int, Float[Array, "nt 1"]]:
264
269
  return (
265
270
  self.key,
266
271
  self.times,
267
272
  self.curr_time_idx,
268
273
  self.temporal_batch_size,
269
- self.p_times,
274
+ self.p,
270
275
  )
271
276
 
272
277
  def temporal_batch(
@@ -276,14 +281,18 @@ class DataGeneratorODE(eqx.Module):
276
281
  Return a batch of time points. If all the batches have been seen, we
277
282
  reshuffle them, otherwise we just return the next unseen batch.
278
283
  """
284
+ if self.temporal_batch_size is None or self.temporal_batch_size == self.nt:
285
+ # Avoid unnecessary reshuffling
286
+ return self, self.times
287
+
279
288
  bstart = self.curr_time_idx
280
289
  bend = bstart + self.temporal_batch_size
281
290
 
282
291
  # Compute the effective number of used collocation points
283
292
  if self.rar_parameters is not None:
284
293
  nt_eff = (
285
- self.nt_start
286
- + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
294
+ self.n_start
295
+ + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
287
296
  )
288
297
  else:
289
298
  nt_eff = self.nt
@@ -295,11 +304,11 @@ class DataGeneratorODE(eqx.Module):
295
304
 
296
305
  # commands below are equivalent to
297
306
  # return self.times[i:(i+t_batch_size)]
298
- # start indices can be dynamic be the slice shape is fixed
307
+ # start indices can be dynamic but the slice shape is fixed
299
308
  return new, jax.lax.dynamic_slice(
300
309
  new.times,
301
- start_indices=(new.curr_time_idx,),
302
- slice_sizes=(new.temporal_batch_size,),
310
+ start_indices=(new.curr_time_idx, 0),
311
+ slice_sizes=(new.temporal_batch_size, 1),
303
312
  )
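
The reason for `jax.lax.dynamic_slice` here: under `jit`, the start index may be a traced value but the slice shape must be static, which is why the batch sizes are declared `eqx.field(static=True)`. A standalone sketch:

    import jax
    import jax.numpy as jnp

    times = jnp.linspace(0.0, 1.0, 10)[:, None]  # shape (10, 1)
    batch = jax.lax.dynamic_slice(
        times,
        start_indices=(4, 0),  # may depend on traced state
        slice_sizes=(3, 1),    # must be known at trace time
    )
    # equivalent to times[4:7], but jit-compatible
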
304
313
 
305
314
  def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
@@ -324,17 +333,17 @@ class CubicMeshPDEStatio(eqx.Module):
324
333
  batches. Batches are made so that each data point is seen only
325
334
  once during 1 epoch.
326
335
  nb : Int | None
327
- The total number of points in $\partial\Omega$.
328
- Can be `None` not to lose performance generating the border
329
- batch if they are not used
330
- omega_batch_size : Int
336
+ The total number of points in $\partial\Omega$. Can be None if no
337
+ boundary condition is specified.
338
+ omega_batch_size : Int | None, default=None
331
339
  The size of the batch of randomly selected points among
332
- the `n` points.
333
- omega_border_batch_size : Int | None
340
+ the `n` points. If None, no minibatches are used.
341
+ omega_border_batch_size : Int | None, default=None
334
342
  The size of the batch of points randomly selected
335
- among the `nb` points.
336
- Can be `None` not to lose performance generating the border
337
- batch if they are not used
343
+ among the `nb` points. If None,
344
+ no minibatches are used. In dimension 1,
345
+ minibatches are never used since the boundary is composed of two
346
+ singletons.
338
347
  dim : Int
339
348
  Dimension of $\Omega$ domain
340
349
  min_pts : tuple[tuple[Float, Float], ...]
@@ -351,34 +360,39 @@ class CubicMeshPDEStatio(eqx.Module):
351
360
  regularly spaced points over the domain. `uniform` means uniformly
352
361
  sampled points over the domain
353
362
  rar_parameters : Dict[str, Int], default=None
354
- Default to None: do not use Residual Adaptative Resampling.
355
- Otherwise a dictionary with keys. `start_iter`: the iteration at
356
- which we start the RAR sampling scheme (we first have a burn in
357
- period). `update_every`: the number of gradient steps taken between
358
- each appending of collocation points in the RAR algo.
359
- `sample_size_omega`: the size of the sample from which we will select new
360
- collocation points. `selected_sample_size_omega`: the number of selected
363
+ Defaults to None: do not use Residual Adaptive Resampling.
364
+ Otherwise a dictionary with keys
365
+
366
+ - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
367
+ - `update_every`: the number of gradient steps taken between
368
+ each update of collocation points in the RAR algo.
369
+ - `sample_size`: the size of the sample from which we will select new
370
+ collocation points.
371
+ - `selected_sample_size`: the number of selected
361
372
  points from the sample to be added to the current collocation
362
- points
373
+ points.
363
374
  n_start : Int, default=None
364
375
  Defaults to None. The effective size of n used at start time.
365
376
  This value must be
366
377
  provided when rar_parameters is not None. Otherwise we set internally
367
378
  n_start = n and this is hidden from the user.
368
379
  In RAR, n_start
369
- then corresponds to the initial number of points we train the PINN.
380
+ then corresponds to the initial number of points we train the PINN on.
370
381
  """
371
382
 
372
383
  # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
373
384
  key: Key = eqx.field(kw_only=True)
374
- n: Int = eqx.field(kw_only=True)
375
- nb: Int | None = eqx.field(kw_only=True)
376
- omega_batch_size: Int = eqx.field(
377
- kw_only=True, static=True
385
+ n: Int = eqx.field(kw_only=True, static=True)
386
+ nb: Int | None = eqx.field(kw_only=True, static=True, default=None)
387
+ omega_batch_size: Int | None = eqx.field(
388
+ kw_only=True,
389
+ static=True,
390
+ default=None, # can be None since
391
+ # CubicMeshPDENonStatio inherits it, but also when omega_batch_size=n
378
392
  ) # static cause used as a
379
393
  # shape in jax.lax.dynamic_slice
380
394
  omega_border_batch_size: Int | None = eqx.field(
381
- kw_only=True, static=True
395
+ kw_only=True, static=True, default=None
382
396
  ) # static cause used as a
383
397
  # shape in jax.lax.dynamic_slice
384
398
  dim: Int = eqx.field(kw_only=True, static=True) # static cause used as a
@@ -391,10 +405,8 @@ class CubicMeshPDEStatio(eqx.Module):
391
405
  rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
392
406
  n_start: Int = eqx.field(kw_only=True, default=None, static=True)
393
407
 
394
- # all the init=False fields are set in __post_init__, even after a _replace
395
- # or eqx.tree_at __post_init__ is called
396
- p_omega: Float[Array, "n"] = eqx.field(init=False)
397
- p_border: None = eqx.field(init=False)
408
+ # all the init=False fields are set in __post_init__
409
+ p: Float[Array, "n"] = eqx.field(init=False)
398
410
  rar_iter_from_last_sampling: Int = eqx.field(init=False)
399
411
  rar_iter_nb: Int = eqx.field(init=False)
400
412
  curr_omega_idx: Int = eqx.field(init=False)
@@ -410,51 +422,59 @@ class CubicMeshPDEStatio(eqx.Module):
410
422
 
411
423
  (
412
424
  self.n_start,
413
- self.p_omega,
425
+ self.p,
414
426
  self.rar_iter_from_last_sampling,
415
427
  self.rar_iter_nb,
416
428
  ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
417
429
 
418
- self.p_border = None # no RAR sampling for border for now
430
+ if self.method == "grid" and self.dim == 2:
431
+ perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
432
+ if self.n != perfect_sq:
433
+ warnings.warn(
434
+ "Grid sampling is requested in dimension 2 with a non"
435
+ f" perfect square dataset size (self.n = {self.n})."
436
+ f" Modifying self.n to self.n = {perfect_sq}."
437
+ )
438
+ self.n = perfect_sq
419
439
 
420
- # Special handling for the border batch
421
- if self.omega_border_batch_size is None:
422
- self.nb = None
423
- self.omega_border_batch_size = None
424
- elif self.dim == 1:
425
- # 1-D case : the arguments `nb` and `omega_border_batch_size` are
426
- # ignored but kept for backward stability. The attributes are
427
- # always set to 2.
428
- self.nb = 2
429
- self.omega_border_batch_size = 2
430
- # We are in 1-D case => omega_border_batch_size is
431
- # ignored since borders of Omega are singletons.
432
- # self.border_batch() will return [xmin, xmax]
440
+ if self.nb is not None:
441
+ if self.dim == 1:
442
+ self.omega_border_batch_size = None
443
+ # We are in 1-D case => omega_border_batch_size is
444
+ # ignored since borders of Omega are singletons.
445
+ # self.border_batch() will return [xmin, xmax]
446
+ else:
447
+ if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
448
+ raise ValueError(
449
+ f"number of border point must be"
450
+ f" a multiple of 2xd = {2*self.dim} (the # of faces of"
451
+ f" a d-dimensional cube). Got {self.nb=}."
452
+ )
453
+ if (
454
+ self.omega_border_batch_size is not None
455
+ and self.nb // (2 * self.dim) < self.omega_border_batch_size
456
+ ):
457
+ raise ValueError(
458
+ f"number of points per facets ({self.nb//(2*self.dim)})"
459
+ f" cannot be lower than border batch size "
460
+ f" ({self.omega_border_batch_size})."
461
+ )
462
+ self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
463
+
464
+ if self.omega_batch_size is None:
465
+ self.curr_omega_idx = 0
433
466
  else:
434
- if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
435
- raise ValueError(
436
- "number of border point must be"
437
- " a multiple of 2xd (the # of faces of a d-dimensional cube)"
438
- )
439
- if self.nb // (2 * self.dim) < self.omega_border_batch_size:
440
- raise ValueError(
441
- "number of points per facets (nb//2*self.dim)"
442
- " cannot be lower than border batch size"
443
- )
444
- self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
467
+ self.curr_omega_idx = self.n + self.omega_batch_size
468
+ # to be sure there is a shuffling at first get_batch()
445
469
 
446
- self.curr_omega_idx = jnp.iinfo(jnp.int32).max - self.omega_batch_size - 1
447
- # see explaination in DataGeneratorODE
448
470
  if self.omega_border_batch_size is None:
449
- self.curr_omega_border_idx = None
471
+ self.curr_omega_border_idx = 0
450
472
  else:
451
- self.curr_omega_border_idx = (
452
- jnp.iinfo(jnp.int32).max - self.omega_border_batch_size - 1
453
- )
454
- # key, subkey = jax.random.split(self.key)
455
- # self.key = key
456
- self.key, self.omega, self.omega_border = self.generate_data(self.key)
457
- # see explaination in DataGeneratorODE for the key
473
+ self.curr_omega_border_idx = self.nb + self.omega_border_batch_size
474
+ # to be sure there is a shuffling at first get_batch()
475
+
476
+ self.key, self.omega = self.generate_omega_data(self.key)
477
+ self.key, self.omega_border = self.generate_omega_border_data(self.key)
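
The border bookkeeping enforced above, in a nutshell, assuming dim=2 (a toy check, not package code):

    dim, nb = 2, 40                     # nb must be a multiple of 2*dim = 4
    points_per_facet = nb // (2 * dim)  # 10 points on each of the 4 faces
    # a border batch size of 8 is valid here (8 <= 10); 12 would raise ValueError
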
458
478
 
459
479
  def sample_in_omega_domain(
460
480
  self, keys: Key, sample_size: Int = None
@@ -480,9 +500,10 @@ class CubicMeshPDEStatio(eqx.Module):
480
500
  )
481
501
 
482
502
  def sample_in_omega_border_domain(
483
- self, keys: Key
503
+ self, keys: Key, sample_size: int = None
484
504
  ) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
485
- if self.omega_border_batch_size is None:
505
+ sample_size = self.nb if sample_size is None else sample_size
506
+ if sample_size is None:
486
507
  return None
487
508
  if self.dim == 1:
488
509
  xmin = self.min_pts[0]
@@ -492,8 +513,7 @@ class CubicMeshPDEStatio(eqx.Module):
492
513
  # currently hard-coded the 4 edges for d==2
493
514
  # TODO : find a general & efficient way to sample from the border
494
515
  # (facets) of the hypercube in general dim.
495
-
496
- facet_n = self.nb // (2 * self.dim)
516
+ facet_n = sample_size // (2 * self.dim)
497
517
  xmin = jnp.hstack(
498
518
  [
499
519
  self.min_pts[0] * jnp.ones((facet_n, 1)),
@@ -544,54 +564,64 @@ class CubicMeshPDEStatio(eqx.Module):
544
564
  + f"implemented yet. You are asking for generation in dimension d={self.dim}."
545
565
  )
546
566
 
547
- def generate_data(self, key: Key) -> tuple[
567
+ def generate_omega_data(self, key: Key, data_size: int = None) -> tuple[
548
568
  Key,
549
569
  Float[Array, "n dim"],
550
- Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
551
570
  ]:
552
571
  r"""
553
572
  Construct a complete set of `self.n` $\Omega$ points according to the
554
- specified `self.method`. Also constructs a complete set of `self.nb`
555
- $\partial\Omega$ points if `self.omega_border_batch_size` is not
556
- `None`. If the latter is `None` we set `self.omega_border` to `None`.
573
+ specified `self.method`.
557
574
  """
575
+ data_size = self.n if data_size is None else data_size
558
576
  # Generate Omega
559
577
  if self.method == "grid":
560
578
  if self.dim == 1:
561
579
  xmin, xmax = self.min_pts[0], self.max_pts[0]
562
- partial = (xmax - xmin) / self.n
563
- # shape (n, 1)
564
- omega = jnp.arange(xmin, xmax, partial)[:, None]
580
+ # shape (n, 1)
581
+ omega = jnp.linspace(xmin, xmax, data_size)[:, None]
565
582
  else:
566
- partials = [
567
- (self.max_pts[i] - self.min_pts[i]) / jnp.sqrt(self.n)
568
- for i in range(self.dim)
569
- ]
570
583
  xyz_ = jnp.meshgrid(
571
584
  *[
572
- jnp.arange(self.min_pts[i], self.max_pts[i], partials[i])
585
+ jnp.linspace(
586
+ self.min_pts[i],
587
+ self.max_pts[i],
588
+ int(jnp.round(jnp.sqrt(data_size))),
589
+ )
573
590
  for i in range(self.dim)
574
591
  ]
575
592
  )
576
- xyz_ = [a.reshape((self.n, 1)) for a in xyz_]
593
+ xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
577
594
  omega = jnp.concatenate(xyz_, axis=-1)
578
595
  elif self.method == "uniform":
579
596
  if self.dim == 1:
580
597
  key, subkeys = jax.random.split(key, 2)
581
598
  else:
582
599
  key, *subkeys = jax.random.split(key, self.dim + 1)
583
- omega = self.sample_in_omega_domain(subkeys)
600
+ omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
584
601
  else:
585
602
  raise ValueError("Method " + self.method + " is not implemented.")
603
+ return key, omega
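
For the `grid` branch in dim 2, the meshgrid construction amounts to the following sketch, assuming n=16 on the unit square:

    import jax.numpy as jnp

    min_pts, max_pts, n = (0.0, 0.0), (1.0, 1.0), 16
    side = int(jnp.round(jnp.sqrt(n)))  # 4 points per axis
    xy = jnp.meshgrid(
        *[jnp.linspace(min_pts[i], max_pts[i], side) for i in range(2)]
    )
    omega = jnp.concatenate([a.reshape((n, 1)) for a in xy], axis=-1)  # (16, 2)
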
586
604
 
605
+ def generate_omega_border_data(self, key: Key, data_size: int = None) -> tuple[
606
+ Key,
607
+ Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
608
+ ]:
609
+ r"""
610
+ Construct a complete set of `self.nb`
611
+ $\partial\Omega$ points if `self.nb` is not
612
+ `None`. If the latter is `None`, `self.omega_border` is set to `None`.
613
+ """
587
614
  # Generate border of omega
588
- if self.dim == 2 and self.omega_border_batch_size is not None:
615
+ data_size = self.nb if data_size is None else data_size
616
+ if self.dim == 2:
589
617
  key, *subkeys = jax.random.split(key, 5)
590
618
  else:
591
619
  subkeys = None
592
- omega_border = self.sample_in_omega_border_domain(subkeys)
620
+ omega_border = self.sample_in_omega_border_domain(
621
+ subkeys, sample_size=data_size
622
+ )
593
623
 
594
- return key, omega, omega_border
624
+ return key, omega_border
595
625
 
596
626
  def _get_omega_operands(
597
627
  self,
@@ -601,7 +631,7 @@ class CubicMeshPDEStatio(eqx.Module):
601
631
  self.omega,
602
632
  self.curr_omega_idx,
603
633
  self.omega_batch_size,
604
- self.p_omega,
634
+ self.p,
605
635
  )
606
636
 
607
637
  def inside_batch(
@@ -612,11 +642,15 @@ class CubicMeshPDEStatio(eqx.Module):
612
642
  If all the batches have been seen, we reshuffle them,
613
643
  otherwise we just return the next unseen batch.
614
644
  """
645
+ if self.omega_batch_size is None or self.omega_batch_size == self.n:
646
+ # Avoid unnecessary reshuffling
647
+ return self, self.omega
648
+
615
649
  # Compute the effective number of used collocation points
616
650
  if self.rar_parameters is not None:
617
651
  n_eff = (
618
652
  self.n_start
619
- + self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
653
+ + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
620
654
  )
621
655
  else:
622
656
  n_eff = self.n
@@ -645,7 +679,7 @@ class CubicMeshPDEStatio(eqx.Module):
645
679
  self.omega_border,
646
680
  self.curr_omega_border_idx,
647
681
  self.omega_border_batch_size,
648
- self.p_border,
682
+ None,
649
683
  )
650
684
 
651
685
  def border_batch(
@@ -670,12 +704,23 @@ class CubicMeshPDEStatio(eqx.Module):
670
704
 
671
705
 
672
706
  """
673
- if self.omega_border_batch_size is None:
707
+ if self.nb is None:
708
+ # Avoid unnecessary reshuffling
674
709
  return self, None
710
+
675
711
  if self.dim == 1:
712
+ # Avoid unnecessary reshuffling
676
713
  # 1-D case, no randomness : we always return the whole omega border,
677
714
  # i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
678
715
  return self, self.omega_border[None, None] # shape is (1, 1, 2)
716
+
717
+ if (
718
+ self.omega_border_batch_size is None
719
+ or self.omega_border_batch_size == self.nb // 2**self.dim
720
+ ):
721
+ # Avoid unnecessary reshuffling
722
+ return self, self.omega_border
723
+
679
724
  bstart = self.curr_omega_border_idx
680
725
  bend = bstart + self.omega_border_batch_size
681
726
 
@@ -701,7 +746,7 @@ class CubicMeshPDEStatio(eqx.Module):
701
746
  """
702
747
  new, inside_batch = self.inside_batch()
703
748
  new, border_batch = new.border_batch()
704
- return new, PDEStatioBatch(inside_batch=inside_batch, border_batch=border_batch)
749
+ return new, PDEStatioBatch(domain_batch=inside_batch, border_batch=border_batch)
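
The renamed batch field is the visible API change here for stationary problems; a hedged migration sketch (field names taken from this hunk, surrounding code elided):

    # jinns 1.0.0
    # batch = PDEStatioBatch(inside_batch=x, border_batch=dx)
    # jinns 1.2.0
    batch = PDEStatioBatch(domain_batch=x, border_batch=dx)
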
705
750
 
706
751
 
707
752
  class CubicMeshPDENonStatio(CubicMeshPDEStatio):
@@ -715,30 +760,29 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
715
760
  key : Key
716
761
  Jax random key to sample new time points and to shuffle batches
717
762
  n : Int
718
- The number of total $\Omega$ points that will be divided in
763
+ The number of total $I\times \Omega$ points that will be divided in
719
764
  batches. Batches are made so that each data point is seen only
720
765
  once during 1 epoch.
721
766
  nb : Int | None
722
- The total number of points in $\partial\Omega$.
723
- Can be `None` not to lose performance generating the border
724
- batch if they are not used
725
- nt : Int
726
- The number of total time points that will be divided in
767
+ The total number of points in $\partial\Omega$. Can be None if no
768
+ boundary condition is specified.
769
+ ni : Int
770
+ The number of total $\Omega$ points at $t=0$ that will be divided in
727
771
  batches. Batches are made so that each data point is seen only
728
772
  once during 1 epoch.
729
- omega_batch_size : Int
773
+ domain_batch_size : Int | None, default=None
730
774
  The size of the batch of randomly selected points among
731
- the `n` points.
732
- omega_border_batch_size : Int | None
775
+ the `n` points. If None, no mini-batches are used.
776
+ border_batch_size : Int | None, default=None
733
777
  The size of the batch of points randomly selected
734
- among the `nb` points.
735
- Can be `None` not to lose performance generating the border
736
- batch if they are not used
737
- temporal_batch_size : Int
778
+ among the `nb` points. If None, no
779
+ mini-batches are used.
780
+ initial_batch_size : Int | None, default=None
738
781
  The size of the batch of randomly selected points among
739
- the `nt` points.
782
+ the `ni` points. If None, no
783
+ mini-batches are used.
740
784
  dim : Int
741
- An integer. dimension of $\Omega$ domain
785
+ An integer. Dimension of $\Omega$ domain.
742
786
  min_pts : tuple[tuple[Float, Float], ...]
743
787
  A tuple of minimum values of the domain along each dimension. For a sampling
744
788
  in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
@@ -757,13 +801,15 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
757
801
  regularly spaced points over the domain. `uniform` means uniformly
758
802
  sampled points over the domain
759
803
  rar_parameters : Dict[str, Int], default=None
760
- Default to None: do not use Residual Adaptative Resampling.
761
- Otherwise a dictionary with keys. `start_iter`: the iteration at
762
- which we start the RAR sampling scheme (we first have a burn in
763
- period). `update_every`: the number of gradient steps taken between
764
- each appending of collocation points in the RAR algo.
765
- `sample_size_omega`: the size of the sample from which we will select new
766
- collocation points. `selected_sample_size_omega`: the number of selected
804
+ Defaults to None: do not use Residual Adaptive Resampling.
805
+ Otherwise a dictionary with keys
806
+
807
+ - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
808
+ - `update_every`: the number of gradient steps taken between
809
+ each update of collocation points in the RAR algo.
810
+ - `sample_size`: the size of the sample from which we will select new
811
+ collocation points.
812
+ - `selected_sample_size`: the number of selected
767
813
  points from the sample to be added to the current collocation
768
814
  points.
769
815
  n_start : Int, default=None
@@ -773,27 +819,23 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
773
819
  n_start = n and this is hidden from the user.
774
820
  In RAR, n_start
775
821
  then corresponds to the initial number of omega points we train the PINN.
776
- nt_start : Int, default=None
777
- Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
778
- for times collocation point. See also ``DataGeneratorODE``
779
- documentation.
780
- cartesian_product : Bool, default=True
781
- Defaults to True. Whether we return the cartesian product of the
782
- temporal batch with the inside and border batches. If False we just
783
- return their concatenation.
784
822
  """
785
823
 
786
- temporal_batch_size: Int = eqx.field(kw_only=True)
787
824
  tmin: Float = eqx.field(kw_only=True)
788
825
  tmax: Float = eqx.field(kw_only=True)
789
- nt: Int = eqx.field(kw_only=True)
790
- temporal_batch_size: Int = eqx.field(kw_only=True, static=True)
791
- cartesian_product: Bool = eqx.field(kw_only=True, default=True, static=True)
792
- nt_start: int = eqx.field(kw_only=True, default=None, static=True)
793
-
794
- p_times: Array = eqx.field(init=False)
795
- curr_time_idx: Int = eqx.field(init=False)
796
- times: Array = eqx.field(init=False)
826
+ ni: Int = eqx.field(kw_only=True, static=True)
827
+ domain_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
828
+ initial_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
829
+ border_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
830
+
831
+ curr_domain_idx: Int = eqx.field(init=False)
832
+ curr_initial_idx: Int = eqx.field(init=False)
833
+ curr_border_idx: Int = eqx.field(init=False)
834
+ domain: Float[Array, "n 1+dim"] = eqx.field(init=False)
835
+ border: Float[Array, "(nb//2) 1+1 2"] | Float[Array, "(nb//4) 2+1 4"] | None = (
836
+ eqx.field(init=False)
837
+ )
838
+ initial: Float[Array, "ni dim"] = eqx.field(init=False)
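
A hedged construction sketch for the reworked non-stationary generator (the import path is an assumption; argument names are those of the fields above; `method` is assumed to default to "uniform" as in the base class):

    import jax
    import jinns

    dg = jinns.data.CubicMeshPDENonStatio(  # assumed import path
        key=jax.random.PRNGKey(0),
        n=1000, nb=40, ni=100,              # nb is a multiple of 2*dim = 4
        domain_batch_size=64, border_batch_size=8, initial_batch_size=16,
        dim=2, min_pts=(0.0, 0.0), max_pts=(1.0, 1.0),
        tmin=0.0, tmax=1.0,
    )
    dg, batch = dg.get_batch()  # PDENonStatioBatch with domain/border/initial batches
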
797
839
 
798
840
  def __post_init__(self):
799
841
  """
@@ -803,162 +845,310 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
803
845
  super().__post_init__() # because __init__ or __post_init__ of Base
804
846
  # class is not automatically called
805
847
 
806
- if not self.cartesian_product:
807
- if self.temporal_batch_size != self.omega_batch_size:
848
+ if self.method == "grid":
849
+ # NOTE we must redo the sampling with the square root number of samples
850
+ # and then take the cartesian product
851
+ self.n = int(jnp.round(jnp.sqrt(self.n)) ** 2)
852
+ if self.dim == 2:
853
+ # in the case of grid sampling in dim 2 in the non-statio setting,
854
+ # self.n needs to be a perfect fourth power, because of the
855
+ # additional cartesian product with the time domain
856
+ perfect_4 = int(jnp.round(self.n**0.25) ** 4)
857
+ if self.n != perfect_4:
858
+ warnings.warn(
859
+ "Grid sampling is requested in dimension 2 in non"
860
+ " stationary setting with a non"
861
+ f" perfect square dataset size (self.n = {self.n})."
862
+ f" Modifying self.n to self.n = {perfect_4}."
863
+ )
864
+ self.n = perfect_4
865
+ self.key, half_domain_times = self.generate_time_data(
866
+ self.key, int(jnp.round(jnp.sqrt(self.n)))
867
+ )
868
+
869
+ self.key, half_domain_omega = self.generate_omega_data(
870
+ self.key, data_size=int(jnp.round(jnp.sqrt(self.n)))
871
+ )
872
+ self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
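
`make_cartesian_product` is assumed to pair every time point with every space point; an equivalent toy construction in plain jnp:

    import jax.numpy as jnp

    t = jnp.array([[0.0], [1.0]])    # (2, 1) half_domain_times
    x = jnp.array([[10.0], [20.0]])  # (2, 1) half_domain_omega
    t_x = jnp.concatenate(
        [jnp.repeat(t, x.shape[0], axis=0), jnp.tile(x, (t.shape[0], 1))],
        axis=1,
    )  # shape (4, 2): every (t_i, x_j) pair
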
873
+
874
+ # NOTE
875
+ (
876
+ self.n_start,
877
+ self.p,
878
+ self.rar_iter_from_last_sampling,
879
+ self.rar_iter_nb,
880
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
881
+ elif self.method == "uniform":
882
+ self.key, domain_times = self.generate_time_data(self.key, self.n)
883
+ self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
884
+ else:
885
+ raise ValueError(
886
+ f"Bad value for method. Got {self.method}, expected"
887
+ ' "grid" or "uniform"'
888
+ )
889
+
890
+ if self.domain_batch_size is None:
891
+ self.curr_domain_idx = 0
892
+ else:
893
+ self.curr_domain_idx = self.n + self.domain_batch_size
894
+ # to be sure there is a shuffling at first get_batch()
895
+ if self.nb is not None:
896
+ # the check below has already been done in super().__post_init__ if
897
+ # dim > 1. Here we retest it for any dim
898
+ if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
808
899
  raise ValueError(
809
- "If stacking is requested between the time and "
810
- "inside batches of collocation points, self.temporal_batch_size "
811
- "must then be equal to self.omega_batch_size"
900
+ "number of border point must be"
901
+ " a multiple of 2xd (the # of faces of a d-dimensional cube)"
812
902
  )
903
+ # the check below concerns omega_border_batch_size, done for dim > 1 in
904
+ # super().__post_init__. Here it concerns all dim values since our
905
+ # border_batch is the concatenation or cartesian product with times
813
906
  if (
814
- self.dim > 1
815
- and self.omega_border_batch_size is not None
816
- and self.temporal_batch_size != self.omega_border_batch_size
907
+ self.border_batch_size is not None
908
+ and self.nb // (2 * self.dim) < self.border_batch_size
817
909
  ):
818
910
  raise ValueError(
819
- "If dim > 1 and stacking is requested between the time and "
820
- "inside batches of collocation points, self.temporal_batch_size "
821
- "must then be equal to self.omega_border_batch_size"
911
+ "number of points per facets (nb//2*self.dim)"
912
+ " cannot be lower than border batch size"
913
+ )
914
+ self.key, boundary_times = self.generate_time_data(
915
+ self.key, self.nb // (2 * self.dim)
916
+ )
917
+ boundary_times = boundary_times.reshape(-1, 1, 1)
918
+ boundary_times = jnp.repeat(
919
+ boundary_times, self.omega_border.shape[-1], axis=2
920
+ )
921
+ if self.dim == 1:
922
+ self.border = make_cartesian_product(
923
+ boundary_times, self.omega_border[None, None]
822
924
  )
823
- # Note if self.dim == 1:
824
- # print(
825
- # "Cartesian product is not requested but will be "
826
- # "executed anyway since dim=1"
827
- # )
925
+ else:
926
+ self.border = jnp.concatenate(
927
+ [boundary_times, self.omega_border], axis=1
928
+ )
929
+ if self.border_batch_size is None:
930
+ self.curr_border_idx = 0
931
+ else:
932
+ self.curr_border_idx = self.nb + self.border_batch_size
933
+ # to be sure there is a shuffling at first get_batch()
828
934
 
829
- # Set-up for timewise RAR (some quantity are already set-up by super())
830
- (
831
- self.nt_start,
832
- self.p_times,
833
- _,
834
- _,
835
- ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
935
+ else:
936
+ self.border = None
937
+ self.curr_border_idx = None
938
+ self.border_batch_size = None
939
+
940
+ if self.ni is not None:
941
+ perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
942
+ if self.ni != perfect_sq:
943
+ warnings.warn(
944
+ "Grid sampling is requested in dimension 2 with a non"
945
+ f" perfect square dataset size (self.ni = {self.ni})."
946
+ f" Modifying self.ni to self.ni = {perfect_sq}."
947
+ )
948
+ self.ni = perfect_sq
949
+ self.key, self.initial = self.generate_omega_data(
950
+ self.key, data_size=self.ni
951
+ )
836
952
 
837
- self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
838
- self.key, _ = jax.random.split(self.key, 2) # to make it equivalent to
839
- # the call to _reset_batch_idx_and_permute in legacy DG
840
- self.key, self.times = self.generate_time_data(self.key)
841
- # see explaination in DataGeneratorODE for the key
953
+ if self.initial_batch_size is None or self.initial_batch_size == self.ni:
954
+ self.curr_initial_idx = 0
955
+ else:
956
+ self.curr_initial_idx = self.ni + self.initial_batch_size
957
+ # to be sure there is a shuffling at first get_batch()
958
+ else:
959
+ self.initial = None
960
+ self.initial_batch_size = None
961
+ self.curr_initial_idx = None
842
962
 
843
- def sample_in_time_domain(
844
- self, key: Key, sample_size: Int = None
845
- ) -> Float[Array, "nt"]:
963
+ # the following attributes will not be used anymore
964
+ self.omega = None
965
+ self.omega_border = None
966
+
967
+ def generate_time_data(self, key: Key, nt: Int) -> tuple[Key, Float[Array, "nt 1"]]:
968
+ """
969
+ Construct a complete set of `nt` time points according to the
970
+ specified `self.method`
971
+ """
972
+ key, subkey = jax.random.split(key, 2)
973
+ if self.method == "grid":
974
+ partial_times = (self.tmax - self.tmin) / nt
975
+ return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
976
+ if self.method == "uniform":
977
+ return key, self.sample_in_time_domain(subkey, nt)
978
+ raise ValueError("Method " + self.method + " is not implemented.")
979
+
980
+ def sample_in_time_domain(self, key: Key, nt: Int) -> Float[Array, "nt 1"]:
846
981
  return jax.random.uniform(
847
982
  key,
848
- (self.nt if sample_size is None else sample_size,),
983
+ (nt, 1),
849
984
  minval=self.tmin,
850
985
  maxval=self.tmax,
851
986
  )
852
987
 
853
- def _get_time_operands(
988
+ def _get_domain_operands(
854
989
  self,
855
- ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
990
+ ) -> tuple[Key, Float[Array, "n 1+dim"], Int, Int, None]:
856
991
  return (
857
992
  self.key,
858
- self.times,
859
- self.curr_time_idx,
860
- self.temporal_batch_size,
861
- self.p_times,
993
+ self.domain,
994
+ self.curr_domain_idx,
995
+ self.domain_batch_size,
996
+ self.p,
862
997
  )
863
998
 
864
- def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
865
- """
866
- Construct a complete set of `self.nt` time points according to the
867
- specified `self.method`
999
+ def domain_batch(
1000
+ self,
1001
+ ) -> tuple["CubicMeshPDEStatio", Float[Array, "domain_batch_size 1+dim"]]:
868
1002
 
869
- Note that self.times has always size self.nt and not self.nt_start, even
870
- in RAR scheme, we must allocate all the collocation points
871
- """
872
- key, subkey = jax.random.split(key, 2)
873
- if self.method == "grid":
874
- partial_times = (self.tmax - self.tmin) / self.nt
875
- return key, jnp.arange(self.tmin, self.tmax, partial_times)
876
- if self.method == "uniform":
877
- return key, self.sample_in_time_domain(subkey)
878
- raise ValueError("Method " + self.method + " is not implemented.")
1003
+ if self.domain_batch_size is None or self.domain_batch_size == self.n:
1004
+ # Avoid unnecessary reshuffling
1005
+ return self, self.domain
879
1006
 
880
- def temporal_batch(
881
- self,
882
- ) -> tuple["CubicMeshPDENonStatio", Float[Array, "temporal_batch_size"]]:
883
- """
884
- Return a batch of time points. If all the batches have been seen, we
885
- reshuffle them, otherwise we just return the next unseen batch.
886
- """
887
- bstart = self.curr_time_idx
888
- bend = bstart + self.temporal_batch_size
1007
+ bstart = self.curr_domain_idx
1008
+ bend = bstart + self.domain_batch_size
889
1009
 
890
1010
  # Compute the effective number of used collocation points
891
1011
  if self.rar_parameters is not None:
892
- nt_eff = (
893
- self.nt_start
894
- + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
1012
+ n_eff = (
1013
+ self.n_start
1014
+ + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
895
1015
  )
896
1016
  else:
897
- nt_eff = self.nt
1017
+ n_eff = self.n
898
1018
 
899
- new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
1019
+ new_attributes = _reset_or_increment(bend, n_eff, self._get_domain_operands())
900
1020
  new = eqx.tree_at(
901
- lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
1021
+ lambda m: (m.key, m.domain, m.curr_domain_idx),
1022
+ self,
1023
+ new_attributes,
1024
+ )
1025
+ return new, jax.lax.dynamic_slice(
1026
+ new.domain,
1027
+ start_indices=(new.curr_domain_idx, 0),
1028
+ slice_sizes=(new.domain_batch_size, new.dim + 1),
1029
+ )
1030
+
1031
+ def _get_border_operands(
1032
+ self,
1033
+ ) -> tuple[
1034
+ Key, Float[Array, "nb 1+1 2"] | Float[Array, "(nb//4) 2+1 4"], Int, Int, None
1035
+ ]:
1036
+ return (
1037
+ self.key,
1038
+ self.border,
1039
+ self.curr_border_idx,
1040
+ self.border_batch_size,
1041
+ None,
1042
+ )
1043
+
1044
+ def border_batch(
1045
+ self,
1046
+ ) -> tuple[
1047
+ "CubicMeshPDENonStatio",
1048
+ Float[Array, "border_batch_size 1+1 2"]
1049
+ | Float[Array, "border_batch_size 2+1 4"]
1050
+ | None,
1051
+ ]:
1052
+ if self.nb is None:
1053
+ # Avoid unnecessary reshuffling
1054
+ return self, None
1055
+
1056
+ if (
1057
+ self.border_batch_size is None
1058
+ or self.border_batch_size == self.nb // 2**self.dim
1059
+ ):
1060
+ # Avoid unnecessary reshuffling
1061
+ return self, self.border
1062
+
1063
+ bstart = self.curr_border_idx
1064
+ bend = bstart + self.border_batch_size
1065
+
1066
+ n_eff = self.border.shape[0]
1067
+
1068
+ new_attributes = _reset_or_increment(bend, n_eff, self._get_border_operands())
1069
+ new = eqx.tree_at(
1070
+ lambda m: (m.key, m.border, m.curr_border_idx),
1071
+ self,
1072
+ new_attributes,
902
1073
  )
903
1074
 
904
1075
  return new, jax.lax.dynamic_slice(
905
- new.times,
906
- start_indices=(new.curr_time_idx,),
907
- slice_sizes=(new.temporal_batch_size,),
1076
+ new.border,
1077
+ start_indices=(new.curr_border_idx, 0, 0),
1078
+ slice_sizes=(
1079
+ new.border_batch_size,
1080
+ new.dim + 1,
1081
+ 2 * new.dim,
1082
+ ),
1083
+ )
1084
+
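
The 3-D slice above keeps the coordinate and facet axes whole and batches only along the first axis; a standalone sketch with dim=2:

    import jax
    import jax.numpy as jnp

    border = jnp.zeros((12, 3, 4))  # (n_border, 1+dim, 2*dim) with dim=2
    batch = jax.lax.dynamic_slice(
        border,
        start_indices=(4, 0, 0),    # only the first axis is batched
        slice_sizes=(2, 3, 4),      # (border_batch_size, dim+1, 2*dim)
    )
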
1085
+ def _get_initial_operands(
1086
+ self,
1087
+ ) -> tuple[Key, Float[Array, "ni dim"], Int, Int, None]:
1088
+ return (
1089
+ self.key,
1090
+ self.initial,
1091
+ self.curr_initial_idx,
1092
+ self.initial_batch_size,
1093
+ None,
1094
+ )
1095
+
1096
+ def initial_batch(
1097
+ self,
1098
+ ) -> tuple["CubicMeshPDEStatio", Float[Array, "initial_batch_size dim"]]:
1099
+ if self.initial_batch_size is None or self.initial_batch_size == self.ni:
1100
+ # Avoid unnecessary reshuffling
1101
+ return self, self.initial
1102
+
1103
+ bstart = self.curr_initial_idx
1104
+ bend = bstart + self.initial_batch_size
1105
+
1106
+ n_eff = self.ni
1107
+
1108
+ new_attributes = _reset_or_increment(bend, n_eff, self._get_initial_operands())
1109
+ new = eqx.tree_at(
1110
+ lambda m: (m.key, m.initial, m.curr_initial_idx),
1111
+ self,
1112
+ new_attributes,
1113
+ )
1114
+ return new, jax.lax.dynamic_slice(
1115
+ new.initial,
1116
+ start_indices=(new.curr_initial_idx, 0),
1117
+ slice_sizes=(new.initial_batch_size, new.dim),
908
1118
  )
909
1119
 
910
1120
  def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
911
1121
  """
912
- Generic method to return a batch. Here we call `self.inside_batch()`,
913
- `self.border_batch()` and `self.temporal_batch()`
1122
+ Generic method to return a batch. Here we call `self.domain_batch()`,
1123
+ `self.border_batch()` and `self.initial_batch()`
914
1124
  """
915
- new, x = self.inside_batch()
916
- new, dx = new.border_batch()
917
- new, t = new.temporal_batch()
918
- t = t.reshape(new.temporal_batch_size, 1)
919
-
920
- if new.cartesian_product:
921
- t_x = make_cartesian_product(t, x)
1125
+ new, domain = self.domain_batch()
1126
+ if self.border is not None:
1127
+ new, border = new.border_batch()
922
1128
  else:
923
- t_x = jnp.concatenate([t, x], axis=1)
924
-
925
- if dx is not None:
926
- t_ = t.reshape(new.temporal_batch_size, 1, 1)
927
- t_ = jnp.repeat(t_, dx.shape[-1], axis=2)
928
- if new.cartesian_product or new.dim == 1:
929
- t_dx = make_cartesian_product(t_, dx)
930
- else:
931
- t_dx = jnp.concatenate([t_, dx], axis=1)
1129
+ border = None
1130
+ if self.initial is not None:
1131
+ new, initial = new.initial_batch()
932
1132
  else:
933
- t_dx = None
1133
+ initial = None
934
1134
 
935
1135
  return new, PDENonStatioBatch(
936
- times_x_inside_batch=t_x, times_x_border_batch=t_dx
1136
+ domain_batch=domain, border_batch=border, initial_batch=initial
937
1137
  )
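
And the corresponding migration for non-stationary batches (field names taken from this hunk):

    # jinns 1.0.0: time and space were batched separately, then combined
    # batch = PDENonStatioBatch(times_x_inside_batch=t_x, times_x_border_batch=t_dx)
    # jinns 1.2.0: pre-built (t, x) points are batched directly
    batch = PDENonStatioBatch(domain_batch=domain, border_batch=border, initial_batch=initial)
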
938
1138
 
939
1139
 
940
1140
  class DataGeneratorObservations(eqx.Module):
941
1141
  r"""
942
- Despite the class name, it is rather a dataloader from user provided
943
- observations that will be used for the observations loss
1142
+ Despite the class name, it is rather a dataloader for user-provided
1143
+ observations which are used in the observations loss.
944
1144
 
945
1145
  Parameters
946
1146
  ----------
947
1147
  key : Key
948
1148
  Jax random key to shuffle batches
949
- obs_batch_size : Int
1149
+ obs_batch_size : Int | None
950
1150
  The size of the batch of randomly selected points among
951
- the `n` points. `obs_batch_size` will be the same for all
952
- elements of the return observation dict batch.
953
- NOTE: no check is done BUT users should be careful that
954
- `obs_batch_size` must be equal to `temporal_batch_size` or
955
- `omega_batch_size` or the product of both. In the first case, the
956
- present DataGeneratorObservations instance complements an ODEBatch,
957
- PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
958
- = False). In the second case, `obs_batch_size` =
959
- `temporal_batch_size * omega_batch_size` if the present
960
- DataGeneratorParameter complements a PDENonStatioBatch
961
- with self.cartesian_product = True
1151
+ the `n` points. If None, no minibatches are used.
962
1152
  observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
963
1153
  Observed values corresponding to the input of the PINN
964
1154
  (eg. the time at which we recorded the observations). The first
@@ -984,11 +1174,11 @@ class DataGeneratorObservations(eqx.Module):
984
1174
  datasets of observations. Note that computations for **batches**
985
1175
  can still be performed on other devices (*e.g.* GPU, TPU or
986
1176
  any pre-defined Sharding) thanks to the `obs_batch_sharding`
987
- arguments of `jinns.solve()`. Read the docs for more info.
1177
+ arguments of `jinns.solve()`. Read the `jinns.solve()` doc for more info.
988
1178
  """
989
1179
 
990
1180
  key: Key
991
- obs_batch_size: Int = eqx.field(static=True)
1181
+ obs_batch_size: Int | None = eqx.field(static=True)
992
1182
  observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
993
1183
  observed_values: Float[Array, "n_obs nb_pinn_out"]
994
1184
  observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
@@ -996,7 +1186,7 @@ class DataGeneratorObservations(eqx.Module):
996
1186
  )
997
1187
  sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
998
1188
 
999
- n: Int = eqx.field(init=False)
1189
+ n: Int = eqx.field(init=False, static=True)
1000
1190
  curr_idx: Int = eqx.field(init=False)
1001
1191
  indices: Array = eqx.field(init=False)
1002
1192
 
@@ -1040,7 +1230,11 @@ class DataGeneratorObservations(eqx.Module):
1040
1230
  self.observed_eq_params, self.sharding_device
1041
1231
  )
1042
1232
 
1043
- self.curr_idx = jnp.iinfo(jnp.int32).max - self.obs_batch_size - 1
1233
+ if self.obs_batch_size is not None:
1234
+ self.curr_idx = self.n + self.obs_batch_size
1235
+ # to be sure there is a shuffling at first get_batch()
1236
+ else:
1237
+ self.curr_idx = 0
1044
1238
  # For speed and to avoid duplicating data what is really
1045
1239
  # shuffled is a vector of indices
1046
1240
  if self.sharding_device is not None:
@@ -1079,6 +1273,13 @@ class DataGeneratorObservations(eqx.Module):
1079
1273
  observed_pinn_in, observed_values, etc. are dictionaries with keys
1080
1274
  representing the PINNs.
1081
1275
  """
1276
+ if self.obs_batch_size is None or self.obs_batch_size == self.n:
1277
+ # Avoid unnecessary reshuffling
1278
+ return self, {
1279
+ "pinn_in": self.observed_pinn_in,
1280
+ "val": self.observed_values,
1281
+ "eq_params": self.observed_eq_params,
1282
+ }
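
A hedged usage sketch of the full-batch mode (the import path is an assumption; array shapes follow the field annotations):

    import jax
    import jax.numpy as jnp
    import jinns

    t_obs = jnp.linspace(0.0, 1.0, 50)[:, None]
    dg_obs = jinns.data.DataGeneratorObservations(  # assumed import path
        key=jax.random.PRNGKey(1),
        obs_batch_size=None,             # the whole dataset is returned every time
        observed_pinn_in=t_obs,          # (n_obs, nb_pinn_in)
        observed_values=jnp.sin(t_obs),  # (n_obs, nb_pinn_out)
    )
    dg_obs, obs = dg_obs.get_batch()  # obs["pinn_in"], obs["val"], obs["eq_params"]
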
1082
1283
 
1083
1284
  new_attributes = _reset_or_increment(
1084
1285
  self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
@@ -1120,7 +1321,9 @@ class DataGeneratorObservations(eqx.Module):
1120
1321
 
1121
1322
  class DataGeneratorParameter(eqx.Module):
1122
1323
  r"""
1123
- A data generator for additional unidimensional parameter(s)
1324
+ A data generator for additional unidimensional equation parameter(s).
1325
+ Mostly useful for metamodeling, where batches of `params.eq_params` are fed
1326
+ to the network.
1124
1327
 
1125
1328
  Parameters
1126
1329
  ----------
@@ -1131,19 +1334,11 @@ class DataGeneratorParameter(eqx.Module):
1131
1334
  The number of total points that will be divided in
1132
1335
  batches. Batches are made so that each data point is seen only
1133
1336
  once during 1 epoch.
1134
- param_batch_size : Int
1337
+ param_batch_size : Int | None, default=None
1135
1338
  The size of the batch of randomly selected points among
1136
- the `n` points. `param_batch_size` will be the same for all
1137
- additional batch of parameter.
1138
- NOTE: no check is done BUT users should be careful that
1139
- `param_batch_size` must be equal to `temporal_batch_size` or
1140
- `omega_batch_size` or the product of both. In the first case, the
1141
- present DataGeneratorParameter instance complements an ODEBatch, a
1142
- PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1143
- = False). In the second case, `param_batch_size` =
1144
- `temporal_batch_size * omega_batch_size` if the present
1145
- DataGeneratorParameter complements a PDENonStatioBatch
1146
- with self.cartesian_product = True
1339
+ the `n` points. **Important**: no check is performed, but
1340
+ `param_batch_size` must match the batch size of the other collocation
1341
+ points (time, space or time x space depending on the context), because we vmap the network on all its axes at once to compute the MSE. Also, `param_batch_size` will be the same for all parameters. If None, no mini-batches are used.
1147
1342
  param_ranges : Dict[str, tuple[Float, Float]] | None, default={}
1148
1343
  A dict of tuples (min, max), which
1149
1344
  represents the range of real numbers in which to sample batches (of
@@ -1153,31 +1348,31 @@ class DataGeneratorParameter(eqx.Module):
1153
1348
  By providing several entries in this dictionary we can sample
1154
1349
  an arbitrary number of parameters.
1155
1350
  **Note** that we currently only support unidimensional parameters.
1156
- This argument can be done if we only use `user_data`.
1351
+ This argument can be None if we use `user_data`.
1157
1352
  method : str, default="uniform"
1158
1353
  Either `grid` or `uniform`, default is `uniform`. `grid` means
1159
1354
  regularly spaced points over the domain. `uniform` means uniformly
1160
1355
  sampled points over the domain
1161
- user_data : Dict[str, Float[Array, "n"]] | None, default={}
1356
+ user_data : Dict[str, Float[jnp.ndarray, "n"]] | None, default={}
1162
1357
  A dictionary containing user-provided data for parameters.
1163
- As for `param_ranges`, the key corresponds to the parameter name,
1164
- the keys must match the keys in `params["eq_params"]` and only
1165
- unidimensional arrays are supported. Therefore, the jnp arrays
1166
- found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
1358
+ The keys correspond to the parameter names,
1359
+ and must match the keys in `params["eq_params"]`. Only
1360
+ unidimensional `jnp.array`s are supported. Therefore, the array at
1361
+ `user_data[k]` must have shape `(n, 1)` or `(n,)`.
1167
1362
  Note that if the same key appears in `param_ranges` and `user_data`
1168
1363
  priority goes for the content in `user_data`.
1169
1364
  Defaults to None.
1170
1365
  """
1171
1366
 
1172
1367
  keys: Key | Dict[str, Key]
1173
- n: Int
1174
- param_batch_size: Int = eqx.field(static=True)
1368
+ n: Int = eqx.field(static=True)
1369
+ param_batch_size: Int | None = eqx.field(static=True, default=None)
1175
1370
  param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
1176
1371
  static=True, default_factory=lambda: {}
1177
1372
  )
1178
1373
  method: str = eqx.field(static=True, default="uniform")
1179
- user_data: Dict[str, Float[Array, "n"]] | None = eqx.field(
1180
- static=True, default_factory=lambda: {}
1374
+ user_data: Dict[str, Float[onp.Array, "n"]] | None = eqx.field(
1375
+ default_factory=lambda: {}
1181
1376
  )
1182
1377
 
1183
1378
  curr_param_idx: Dict[str, Int] = eqx.field(init=False)
@@ -1197,11 +1392,13 @@ class DataGeneratorParameter(eqx.Module):
1197
1392
  all_keys = set().union(self.param_ranges, self.user_data)
1198
1393
  self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
1199
1394
 
1200
- self.curr_param_idx = {}
1201
- for k in self.keys.keys():
1202
- self.curr_param_idx[k] = (
1203
- jnp.iinfo(jnp.int32).max - self.param_batch_size - 1
1204
- )
1395
+ if self.param_batch_size is None:
1396
+ self.curr_param_idx = None
1397
+ else:
1398
+ self.curr_param_idx = {}
1399
+ for k in self.keys.keys():
1400
+ self.curr_param_idx[k] = self.n + self.param_batch_size
1401
+ # to be sure there is a shuffling at first get_batch()
1205
1402
 
1206
1403
  # The call to self.generate_data() creates
1207
1404
  # the dict self.param_n_samples and then we will only use this one
@@ -1268,6 +1465,9 @@ class DataGeneratorParameter(eqx.Module):
1268
1465
  otherwise we just return the next unseen batch.
1269
1466
  """
1270
1467
 
1468
+ if self.param_batch_size is None or self.param_batch_size == self.n:
1469
+ return self, self.param_n_samples
1470
+
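
And a matching sketch for parameter batches (assumed import path; `nu` is an arbitrary parameter name; `get_batch` is assumed to be the usual entry point as for the other generators):

    import jax
    import jinns

    dg_p = jinns.data.DataGeneratorParameter(  # assumed import path
        keys=jax.random.PRNGKey(2),
        n=100,
        param_batch_size=None,                 # full batch: no index bookkeeping
        param_ranges={"nu": (0.0, 1.0)},
    )
    dg_p, param_batch = dg_p.get_batch()       # {"nu": ...}, one entry per parameter
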
1271
1471
  def _reset_or_increment_wrapper(param_k, idx_k, key_k):
1272
1472
  return _reset_or_increment(
1273
1473
  idx_k + self.param_batch_size,
@@ -1319,7 +1519,7 @@ class DataGeneratorObservationsMultiPINNs(eqx.Module):
1319
1519
 
1320
1520
  Technically, the constraint on the observations in SystemLossXDE are
1321
1521
  applied in `constraints_system_loss_apply` and in this case the
1322
- batch.obs_batch_dict is a dict of obs_batch_dict over which the tree_map
1522
+ `batch.obs_batch_dict` is a dict of obs_batch_dict over which the tree_map
1323
1523
  applies (we select the obs_batch_dict corresponding to its `u_dict` entry)
1324
1524
 
1325
1525
  Parameters
@@ -1328,15 +1528,6 @@ class DataGeneratorObservationsMultiPINNs(eqx.Module):
1328
1528
  The size of the batch of randomly selected observations
1329
1529
  `obs_batch_size` will be the same for all the
1330
1530
  elements of the obs dict.
1331
- NOTE: no check is done BUT users should be careful that
1332
- `obs_batch_size` must be equal to `temporal_batch_size` or
1333
- `omega_batch_size` or the product of both. In the first case, the
1334
- present DataGeneratorObservations instance complements an ODEBatch,
1335
- PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1336
- = False). In the second case, `obs_batch_size` =
1337
- `temporal_batch_size * omega_batch_size` if the present
1338
- DataGeneratorParameter complements a PDENonStatioBatch
1339
- with self.cartesian_product = True
1340
1531
  observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
1341
1532
  A dict of observed_pinn_in as defined in DataGeneratorObservations.
1342
1533
  Keys must be that of `u_dict`.