jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,49 @@
+ # pylint: disable=unsubscriptable-object
  """
- DataGenerators to generate batches of points in space, time and more
+ Define the DataGeneratorODE equinox module
  """
-
- from typing import NamedTuple
- from jax.typing import ArrayLike
+ from __future__ import (
+     annotations,
+ )  # https://docs.python.org/3/library/typing.html#constant
+
+ from typing import TYPE_CHECKING, Dict
+ from dataclasses import InitVar
+ import equinox as eqx
+ import jax
  import jax.numpy as jnp
- from jax import random
- from jax.tree_util import register_pytree_node_class
- import jax.lax
-
-
- class ODEBatch(NamedTuple):
-     temporal_batch: ArrayLike
-     param_batch_dict: dict = None
-     obs_batch_dict: dict = None
-
-
- class PDENonStatioBatch(NamedTuple):
-     times_x_inside_batch: ArrayLike
-     times_x_border_batch: ArrayLike
-     param_batch_dict: dict = None
-     obs_batch_dict: dict = None
-
+ from jaxtyping import Key, Int, PyTree, Array, Float, Bool
+ from jinns.data._Batchs import *

- class PDEStatioBatch(NamedTuple):
-     inside_batch: ArrayLike
-     border_batch: ArrayLike
-     param_batch_dict: dict = None
-     obs_batch_dict: dict = None
+ if TYPE_CHECKING:
+     from jinns.utils._types import *


- def append_param_batch(batch, param_batch_dict):
+ def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
      """
      Utility function that fill the param_batch_dict of a batch object with a
      param_batch_dict
      """
-     return batch._replace(param_batch_dict=param_batch_dict)
+     return eqx.tree_at(
+         lambda m: m.param_batch_dict,
+         batch,
+         param_batch_dict,
+         is_leaf=lambda x: x is None,
+     )


- def append_obs_batch(batch, obs_batch_dict):
+ def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
      """
      Utility function that fill the obs_batch_dict of a batch object with a
      obs_batch_dict
      """
-     return batch._replace(obs_batch_dict=obs_batch_dict)
+     return eqx.tree_at(
+         lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
+     )


- def make_cartesian_product(b1, b2):
+ def make_cartesian_product(
+     b1: Float[Array, "batch_size dim1"], b2: Float[Array, "batch_size dim2"]
+ ) -> Float[Array, "(batch_size*batch_size) (dim1+dim2)"]:
      """
      Create the cartesian product of a time and a border omega batches
      by tiling and repeating
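Reviewer note: the batch containers (`ODEBatch`, `PDEStatioBatch`, `PDENonStatioBatch`) now live in `jinns/data/_Batchs.py`, and the `append_*_batch` helpers replace `NamedTuple._replace` with `eqx.tree_at`, passing `is_leaf=lambda x: x is None` so that a field currently holding `None` can still be targeted. A minimal sketch of that pattern; the `ODEBatch` stub below is an assumption for illustration only (the real class is defined in `_Batchs.py`, not shown in this hunk):

```python
import equinox as eqx
import jax.numpy as jnp


class ODEBatch(eqx.Module):
    # illustrative stand-in for the class defined in jinns/data/_Batchs.py
    temporal_batch: jnp.ndarray
    param_batch_dict: dict = None
    obs_batch_dict: dict = None


batch = ODEBatch(temporal_batch=jnp.linspace(0.0, 1.0, 32))
# is_leaf lets tree_at address a field whose current value is None
batch = eqx.tree_at(
    lambda b: b.param_batch_dict,
    batch,
    {"nu": jnp.array(0.5)},
    is_leaf=lambda x: x is None,
)
```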
@@ -58,29 +55,39 @@ def make_cartesian_product(b1, b2):
      return jnp.concatenate([b1, b2], axis=1)


- def _reset_batch_idx_and_permute(operands):
+ def _reset_batch_idx_and_permute(
+     operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
+ ) -> tuple[Key, Float[Array, "n dimension"], Int]:
      key, domain, curr_idx, _, p = operands
      # resetting counter
      curr_idx = 0
      # reshuffling
-     key, subkey = random.split(key)
+     key, subkey = jax.random.split(key)
      # domain = random.permutation(subkey, domain, axis=0, independent=False)
      # we want that permutation = choice when p=None
      # otherwise p is used to avoid collocation points not in nt_start
-     domain = random.choice(subkey, domain, shape=(domain.shape[0],), replace=False, p=p)
+     domain = jax.random.choice(
+         subkey, domain, shape=(domain.shape[0],), replace=False, p=p
+     )

      # return updated
      return (key, domain, curr_idx)


- def _increment_batch_idx(operands):
+ def _increment_batch_idx(
+     operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
+ ) -> tuple[Key, Float[Array, "n dimension"], Int]:
      key, domain, curr_idx, batch_size, _ = operands
      # simply increases counter and get the batch
      curr_idx += batch_size
      return (key, domain, curr_idx)


- def _reset_or_increment(bend, n_eff, operands):
+ def _reset_or_increment(
+     bend: Int,
+     n_eff: Int,
+     operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
+ ) -> tuple[Key, Float[Array, "n dimension"], Int]:
      """
      Factorize the code of the jax.lax.cond which checks if we have seen all the
      batches in an epoch
@@ -109,15 +116,18 @@ def _reset_or_increment(bend, n_eff, operands):
      )


- def _check_and_set_rar_parameters(rar_parameters, n, n_start):
+ def _check_and_set_rar_parameters(
+     rar_parameters: dict, n: Int, n_start: Int
+ ) -> tuple[Int, Float[Array, "n"], Int, Int]:
      if rar_parameters is not None and n_start is None:
          raise ValueError(
-             f"n_start or/and nt_start must be provided in the context of RAR sampling scheme, {n_start} was provided"
+             "nt_start must be provided in the context of RAR sampling scheme"
          )
+
      if rar_parameters is not None:
          # Default p is None. However, in the RAR sampling scheme we use 0
          # probability to specify non-used collocation points (i.e. points
-         # above n_start). Thus, p is a vector of probability of shape (n, 1).
+         # above nt_start). Thus, p is a vector of probability of shape (nt, 1).
          p = jnp.zeros((n,))
          p = p.at[:n_start].set(1 / n_start)
          # set internal counter for the number of gradient steps since the
@@ -129,6 +139,7 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
          # have been added
          rar_iter_nb = 0
      else:
+         n_start = n
          p = None
          rar_iter_from_last_sampling = None
          rar_iter_nb = None
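For context on the RAR bookkeeping above: the probability vector `p` keeps only the first `n_start` collocation points active, and the reshuffling in `_reset_batch_idx_and_permute` draws with these weights so that the zero-probability (not-yet-added) points are kept out of the leading batch positions. A small, self-contained illustration of that masking, with made-up sizes:

```python
import jax
import jax.numpy as jnp

n, n_start = 8, 3
# uniform weight on the first n_start points, zero elsewhere,
# mirroring _check_and_set_rar_parameters
p = jnp.zeros((n,)).at[:n_start].set(1 / n_start)

key = jax.random.PRNGKey(0)
domain = jnp.arange(n, dtype=jnp.float32)
# same call as in _reset_batch_idx_and_permute: a weighted shuffle without
# replacement that pushes the inactive points out of the first positions
shuffled = jax.random.choice(key, domain, shape=(n,), replace=False, p=p)
```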
@@ -136,109 +147,102 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
      return n_start, p, rar_iter_from_last_sampling, rar_iter_nb


- #####################################################
- # DataGenerator for ODE : only returns time_batches
- #####################################################
-
-
- @register_pytree_node_class
- class DataGeneratorODE:
+ class DataGeneratorODE(eqx.Module):
      """
      A class implementing data generator object for ordinary differential equations.

-
-     **Note:** DataGeneratorODE is jittable. Hence it implements the tree_flatten() and
-     tree_unflatten methods.
+     Parameters
+     ----------
+     key : Key
+         Jax random key to sample new time points and to shuffle batches
+     nt : Int
+         The number of total time points that will be divided in
+         batches. Batches are made so that each data point is seen only
+         once during 1 epoch.
+     tmin : float
+         The minimum value of the time domain to consider
+     tmax : float
+         The maximum value of the time domain to consider
+     temporal_batch_size : int
+         The size of the batch of randomly selected points among
+         the `nt` points.
+     method : str, default="uniform"
+         Either `grid` or `uniform`, default is `uniform`.
+         The method that generates the `nt` time points. `grid` means
+         regularly spaced points over the domain. `uniform` means uniformly
+         sampled points over the domain
+     rar_parameters : Dict[str, Int], default=None
+         Default to None: do not use Residual Adaptative Resampling.
+         Otherwise a dictionary with keys. `start_iter`: the iteration at
+         which we start the RAR sampling scheme (we first have a burn in
+         period). `update_rate`: the number of gradient steps taken between
+         each appending of collocation points in the RAR algo.
+         `sample_size`: the size of the sample from which we will select new
+         collocation points. `selected_sample_size_times`: the number of selected
+         points from the sample to be added to the current collocation
+         points
+     nt_start : Int, default=None
+         Defaults to None. The effective size of nt used at start time.
+         This value must be
+         provided when rar_parameters is not None. Otherwise we set internally
+         nt_start = nt and this is hidden from the user.
+         In RAR, nt_start
+         then corresponds to the initial number of points we train the PINN.
      """

-     def __init__(
-         self,
-         key,
-         nt,
-         tmin,
-         tmax,
-         temporal_batch_size,
-         method="uniform",
-         rar_parameters=None,
-         nt_start=None,
-         data_exists=False,
-     ):
-         """
-         Parameters
-         ----------
-         key
-             Jax random key to sample new time points and to shuffle batches
-         nt
-             An integer. The number of total time points that will be divided in
-             batches. Batches are made so that each data point is seen only
-             once during 1 epoch.
-         tmin
-             A float. The minimum value of the time domain to consider
-         tmax
-             A float. The maximum value of the time domain to consider
-         temporal_batch_size
-             An integer. The size of the batch of randomly selected points among
-             the `nt` points.
-         method
-             Either `grid` or `uniform`, default is `uniform`.
-             The method that generates the `nt` time points. `grid` means
-             regularly spaced points over the domain. `uniform` means uniformly
-             sampled points over the domain
-         rar_parameters
-             Default to None: do not use Residual Adaptative Resampling.
-             Otherwise a dictionary with keys. `start_iter`: the iteration at
-             which we start the RAR sampling scheme (we first have a burn in
-             period). `update_every`: the number of gradient steps taken between
-             each appending of collocation points in the RAR algo.
-             `sample_size_times`: the size of the sample from which we will select new
-             collocation points. `selected_sample_size_times`: the number of selected
-             points from the sample to be added to the current collocation
-             points
-             "DeepXDE: A deep learning library for solving differential
-             equations", L. Lu, SIAM Review, 2021
-         nt_start
-             Defaults to None. The effective size of nt used at start time.
-             This value must be
-             provided when rar_parameters is not None. Otherwise we set internally
-             nt_start = nt and this is hidden from the user.
-             In RAR, nt_start
-             then corresponds to the initial number of points we train the PINN.
-         data_exists
-             Must be left to `False` when created by the user. Avoids the
-             regeneration of the `nt` time points at each pytree flattening and
-             unflattening.
-         """
-         self.data_exists = data_exists
-         self._key = key
-         self.nt = nt
-         self.tmin = tmin
-         self.tmax = tmax
-         self.temporal_batch_size = temporal_batch_size
-         self.method = method
-         self.rar_parameters = rar_parameters
-
-         # Set-up for RAR (if used)
+     key: Key
+     nt: Int
+     tmin: Float
+     tmax: Float
+     temporal_batch_size: Int = eqx.field(static=True)  # static cause used as a
+     # shape in jax.lax.dynamic_slice
+     method: str = eqx.field(static=True, default_factory=lambda: "uniform")
+     rar_parameters: Dict[str, Int] = None
+     nt_start: Int = eqx.field(static=True, default=None)
+
+     # all the init=False fields are set in __post_init__, even after a _replace
+     # or eqx.tree_at __post_init__ is called
+     p_times: Float[Array, "nt"] = eqx.field(init=False)
+     rar_iter_from_last_sampling: Int = eqx.field(init=False)
+     rar_iter_nb: Int = eqx.field(init=False)
+     curr_time_idx: Int = eqx.field(init=False)
+     times: Float[Array, "nt"] = eqx.field(init=False)
+
+     def __post_init__(self):
          (
              self.nt_start,
              self.p_times,
              self.rar_iter_from_last_sampling,
              self.rar_iter_nb,
-         ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
-
-         if not self.data_exists:
-             # Useful when using a lax.scan with pytree
-             # Optionally can tell JAX not to re-generate data
-             self.curr_time_idx = 0
-             self.generate_time_data()
-             self._key, self.times, _ = _reset_batch_idx_and_permute(
-                 self._get_time_operands()
-             )
-
-     def sample_in_time_domain(self, n_samples):
-         self._key, subkey = random.split(self._key, 2)
-         return random.uniform(subkey, (n_samples,), minval=self.tmin, maxval=self.tmax)
+         ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
+
+         self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
+         # to be sure there is a
+         # shuffling at first get_batch() we do not call
+         # _reset_batch_idx_and_permute in __init__ or __post_init__ because it
+         # would return a copy of self and we have not investigate what would
+         # happen
+         # NOTE the (- self.temporal_batch_size - 1) because otherwise when computing
+         # `bend` we overflow the max int32 with unwanted behaviour
+
+         self.key, self.times = self.generate_time_data(self.key)
+         # Note that, here, in __init__ (and __post_init__), this is the
+         # only place where self assignment are authorized so we do the
+         # above way for the key. Note that one of the motivation to return the
+         # key from generate_*_data is to easily align key with legacy
+         # DataGenerators to use same unit tests
+
+     def sample_in_time_domain(
+         self, key: Key, sample_size: Int = None
+     ) -> Float[Array, "nt"]:
+         return jax.random.uniform(
+             key,
+             (self.nt if sample_size is None else sample_size,),
+             minval=self.tmin,
+             maxval=self.tmax,
+         )

-     def generate_time_data(self):
+     def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
          """
          Construct a complete set of `self.nt` time points according to the
          specified `self.method`
@@ -246,24 +250,28 @@ class DataGeneratorODE:
          Note that self.times has always size self.nt and not self.nt_start, even
          in RAR scheme, we must allocate all the collocation points
          """
+         key, subkey = jax.random.split(self.key)
          if self.method == "grid":
-             self.partial_times = (self.tmax - self.tmin) / self.nt
-             self.times = jnp.arange(self.tmin, self.tmax, self.partial_times)
-         elif self.method == "uniform":
-             self.times = self.sample_in_time_domain(self.nt)
-         else:
-             raise ValueError("Method " + self.method + " is not implemented.")
+             partial_times = (self.tmax - self.tmin) / self.nt
+             return key, jnp.arange(self.tmin, self.tmax, partial_times)
+         if self.method == "uniform":
+             return key, self.sample_in_time_domain(subkey)
+         raise ValueError("Method " + self.method + " is not implemented.")

-     def _get_time_operands(self):
+     def _get_time_operands(
+         self,
+     ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
          return (
-             self._key,
+             self.key,
              self.times,
              self.curr_time_idx,
              self.temporal_batch_size,
              self.p_times,
          )

-     def temporal_batch(self):
+     def temporal_batch(
+         self,
+     ) -> tuple["DataGeneratorODE", Float[Array, "temporal_batch_size"]]:
          """
          Return a batch of time points. If all the batches have been seen, we
          reshuffle them, otherwise we just return the next unseen batch.
@@ -275,210 +283,142 @@ class DataGeneratorODE:
          if self.rar_parameters is not None:
              nt_eff = (
                  self.nt_start
-                 + self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
+                 + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
              )
          else:
              nt_eff = self.nt
-         (self._key, self.times, self.curr_time_idx) = _reset_or_increment(
-             bend, nt_eff, self._get_time_operands()
+
+         new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
+         new = eqx.tree_at(
+             lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
          )

          # commands below are equivalent to
          # return self.times[i:(i+t_batch_size)]
          # start indices can be dynamic be the slice shape is fixed
-         return jax.lax.dynamic_slice(
-             self.times,
-             start_indices=(self.curr_time_idx,),
-             slice_sizes=(self.temporal_batch_size,),
+         return new, jax.lax.dynamic_slice(
+             new.times,
+             start_indices=(new.curr_time_idx,),
+             slice_sizes=(new.temporal_batch_size,),
          )

-     def get_batch(self):
+     def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
          """
          Generic method to return a batch. Here we call `self.temporal_batch()`
          """
-         return ODEBatch(temporal_batch=self.temporal_batch())
-
-     def tree_flatten(self):
-         children = (
-             self._key,
-             self.times,
-             self.curr_time_idx,
-             self.tmin,
-             self.tmax,
-             self.p_times,
-             self.rar_iter_from_last_sampling,
-             self.rar_iter_nb,
-         )  # arrays / dynamic values
-         aux_data = {
-             k: vars(self)[k]
-             for k in [
-                 "temporal_batch_size",
-                 "method",
-                 "nt",
-                 "rar_parameters",
-                 "nt_start",
-             ]
-         }  # static values
-         return (children, aux_data)
-
-     @classmethod
-     def tree_unflatten(cls, aux_data, children):
-         """
-         **Note:** When reconstructing the class, we force ``data_exists=True``
-         in order not to re-generate the data at each flattening and
-         unflattening that happens e.g. during the gradient descent in the
-         optimization process
-         """
-         (
-             key,
-             times,
-             curr_time_idx,
-             tmin,
-             tmax,
-             p_times,
-             rar_iter_from_last_sampling,
-             rar_iter_nb,
-         ) = children
-         obj = cls(
-             key=key,
-             data_exists=True,
-             tmin=tmin,
-             tmax=tmax,
-             **aux_data,
-         )
-         obj.times = times
-         obj.curr_time_idx = curr_time_idx
-         obj.p_times = p_times
-         obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
-         obj.rar_iter_nb = rar_iter_nb
-         return obj
-
-
- ##########################################
- # Data Generator for PDE in stationnary
- # and non-stationnary cases
- ##########################################
-
-
- class DataGeneratorPDEAbstract:
-     """generic skeleton class for a PDE data generator"""
+         new, temporal_batch = self.temporal_batch()
+         return new, ODEBatch(temporal_batch=temporal_batch)

-     def __init__(self, data_exists=False) -> None:
-         # /!\ WARNING /!\: an-end user should never create an object
-         # with data_exists=True. Or else generate_data() won't be called.
-         # Useful when using a lax.scan with a DataGenerator in the carry
-         # It tells JAX not to re-generate data in the __init__()
-         self.data_exists = data_exists

-
- @register_pytree_node_class
- class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
-     """
+ class CubicMeshPDEStatio(eqx.Module):
+     r"""
      A class implementing data generator object for stationary partial
      differential equations.

-
-     **Note:** CubicMeshPDEStatio is jittable. Hence it implements the tree_flatten() and
-     tree_unflatten methods.
+     Parameters
+     ----------
+     key : Key
+         Jax random key to sample new time points and to shuffle batches
+     n : Int
+         The number of total $\Omega$ points that will be divided in
+         batches. Batches are made so that each data point is seen only
+         once during 1 epoch.
+     nb : Int | None
+         The total number of points in $\partial\Omega$.
+         Can be `None` not to lose performance generating the border
+         batch if they are not used
+     omega_batch_size : Int
+         The size of the batch of randomly selected points among
+         the `n` points.
+     omega_border_batch_size : Int | None
+         The size of the batch of points randomly selected
+         among the `nb` points.
+         Can be `None` not to lose performance generating the border
+         batch if they are not used
+     dim : Int
+         Dimension of $\Omega$ domain
+     min_pts : tuple[tuple[Float, Float], ...]
+         A tuple of minimum values of the domain along each dimension. For a sampling
+         in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
+         x_{n, min})$
+     max_pts : tuple[tuple[Float, Float], ...]
+         A tuple of maximum values of the domain along each dimension. For a sampling
+         in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
+         x_{n,max})$
+     method : str, default="uniform"
+         Either `grid` or `uniform`, default is `uniform`.
+         The method that generates the `nt` time points. `grid` means
+         regularly spaced points over the domain. `uniform` means uniformly
+         sampled points over the domain
+     rar_parameters : Dict[str, Int], default=None
+         Default to None: do not use Residual Adaptative Resampling.
+         Otherwise a dictionary with keys. `start_iter`: the iteration at
+         which we start the RAR sampling scheme (we first have a burn in
+         period). `update_every`: the number of gradient steps taken between
+         each appending of collocation points in the RAR algo.
+         `sample_size_omega`: the size of the sample from which we will select new
+         collocation points. `selected_sample_size_omega`: the number of selected
+         points from the sample to be added to the current collocation
+         points
+     n_start : Int, default=None
+         Defaults to None. The effective size of n used at start time.
+         This value must be
+         provided when rar_parameters is not None. Otherwise we set internally
+         n_start = n and this is hidden from the user.
+         In RAR, n_start
+         then corresponds to the initial number of points we train the PINN.
      """

-     def __init__(
-         self,
-         key,
-         n,
-         nb,
-         omega_batch_size,
-         omega_border_batch_size,
-         dim,
-         min_pts,
-         max_pts,
-         method="grid",
-         rar_parameters=None,
-         n_start=None,
-         data_exists=False,
-     ):
-         r"""
-         Parameters
-         ----------
-         key
-             Jax random key to sample new time points and to shuffle batches
-         n
-             An integer. The number of total :math:`\Omega` points that will be divided in
-             batches. Batches are made so that each data point is seen only
-             once during 1 epoch.
-         nb
-             An integer. The total number of points in :math:`\partial\Omega`.
-             Can be `None` not to lose performance generating the border
-             batch if they are not used
-         omega_batch_size
-             An integer. The size of the batch of randomly selected points among
-             the `n` points.
-         omega_border_batch_size
-             An integer. The size of the batch of points randomly selected
-             among the `nb` points.
-             Can be `None` not to lose performance generating the border
-             batch if they are not used
-         dim
-             An integer. dimension of :math:`\Omega` domain
-         min_pts
-             A tuple of minimum values of the domain along each dimension. For a sampling
-             in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
-             x_{n, min})`
-         max_pts
-             A tuple of maximum values of the domain along each dimension. For a sampling
-             in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
-             x_{n,max})`
-         method
-             Either `grid` or `uniform`, default is `grid`.
-             The method that generates the `nt` time points. `grid` means
-             regularly spaced points over the domain. `uniform` means uniformly
-             sampled points over the domain
-         rar_parameters
-             Default to None: do not use Residual Adaptative Resampling.
-             Otherwise a dictionary with keys. `start_iter`: the iteration at
-             which we start the RAR sampling scheme (we first have a burn in
-             period). `update_every`: the number of gradient steps taken between
-             each appending of collocation points in the RAR algo.
-             `sample_size_omega`: the size of the sample from which we will select new
-             collocation points. `selected_sample_size_omega`: the number of selected
-             points from the sample to be added to the current collocation
-             points
-             "DeepXDE: A deep learning library for solving differential
-             equations", L. Lu, SIAM Review, 2021
-         n_start
-             Defaults to None. The effective size of n used at start time.
-             This value must be
-             provided when rar_parameters is not None. Otherwise we set internally
-             n_start = n and this is hidden from the user.
-             In RAR, n_start
-             then corresponds to the initial number of points we train the PINN.
-         data_exists
-             Must be left to `False` when created by the user. Avoids the
-             regeneration of :math:`\Omega`, :math:`\partial\Omega` and
-             time points at each pytree flattening and unflattening.
-         """
-         super().__init__(data_exists=data_exists)
-         self.method = method
-         self._key = key
-         self.dim = dim
-         self.min_pts = min_pts
-         self.max_pts = max_pts
-         assert dim == len(min_pts) and isinstance(min_pts, tuple)
-         assert dim == len(max_pts) and isinstance(max_pts, tuple)
-         self.n = n
-         self.rar_parameters = rar_parameters
+     # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
+     key: Key = eqx.field(kw_only=True)
+     n: Int = eqx.field(kw_only=True)
+     nb: Int | None = eqx.field(kw_only=True)
+     omega_batch_size: Int = eqx.field(
+         kw_only=True, static=True
+     )  # static cause used as a
+     # shape in jax.lax.dynamic_slice
+     omega_border_batch_size: Int | None = eqx.field(
+         kw_only=True, static=True
+     )  # static cause used as a
+     # shape in jax.lax.dynamic_slice
+     dim: Int = eqx.field(kw_only=True, static=True)  # static cause used as a
+     # shape in jax.lax.dynamic_slice
+     min_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
+     max_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
+     method: str = eqx.field(
+         kw_only=True, static=True, default_factory=lambda: "uniform"
+     )
+     rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
+     n_start: Int = eqx.field(kw_only=True, default=None, static=True)
+
+     # all the init=False fields are set in __post_init__, even after a _replace
+     # or eqx.tree_at __post_init__ is called
+     p_omega: Float[Array, "n"] = eqx.field(init=False)
+     p_border: None = eqx.field(init=False)
+     rar_iter_from_last_sampling: Int = eqx.field(init=False)
+     rar_iter_nb: Int = eqx.field(init=False)
+     curr_omega_idx: Int = eqx.field(init=False)
+     curr_omega_border_idx: Int = eqx.field(init=False)
+     omega: Float[Array, "n dim"] = eqx.field(init=False)
+     omega_border: Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None = eqx.field(
+         init=False
+     )
+
+     def __post_init__(self):
+         assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
+         assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
+
          (
              self.n_start,
              self.p_omega,
              self.rar_iter_from_last_sampling,
              self.rar_iter_nb,
-         ) = _check_and_set_rar_parameters(rar_parameters, n=n, n_start=n_start)
+         ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)

          self.p_border = None  # no RAR sampling for border for now

-         self.omega_batch_size = omega_batch_size
-
-         if omega_border_batch_size is None:
+         # Special handling for the border batch
+         if self.omega_border_batch_size is None:
              self.nb = None
              self.omega_border_batch_size = None
          elif self.dim == 1:
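The hunk above replaces the registered-pytree `DataGeneratorODE` with an immutable `eqx.Module`: batches are no longer produced by mutating internal counters; `get_batch()` now returns the updated generator together with the batch. A minimal usage sketch, assuming the class is still re-exported as `jinns.data.DataGeneratorODE` (all values are illustrative):

```python
import jax
import jinns

train_data = jinns.data.DataGeneratorODE(
    key=jax.random.PRNGKey(0),
    nt=1000,
    tmin=0.0,
    tmax=1.0,
    temporal_batch_size=32,
)

# the Module is immutable, so the updated generator must be kept and reused,
# e.g. carried through the training loop
train_data, batch = train_data.get_batch()
print(batch.temporal_batch.shape)  # (32,)
```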
@@ -491,48 +431,46 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
              # ignored since borders of Omega are singletons.
              # self.border_batch() will return [xmin, xmax]
          else:
-             if nb % (2 * self.dim) != 0 or nb < 2 * self.dim:
+             if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
                  raise ValueError(
                      "number of border point must be"
                      " a multiple of 2xd (the # of faces of a d-dimensional cube)"
                  )
-             if nb // (2 * self.dim) < omega_border_batch_size:
+             if self.nb // (2 * self.dim) < self.omega_border_batch_size:
                  raise ValueError(
                      "number of points per facets (nb//2*self.dim)"
                      " cannot be lower than border batch size"
                  )
-             self.nb = int((2 * self.dim) * (nb // (2 * self.dim)))
-             self.omega_border_batch_size = omega_border_batch_size
-
-         if not self.data_exists:
-             # Useful when using a lax.scan with pytree
-             # Optionally tells JAX not to re-generate data when re-building the
-             # object
-             self.curr_omega_idx = 0
-             self.curr_omega_border_idx = 0
-             self.generate_data()
-             self._key, self.omega, _ = _reset_batch_idx_and_permute(
-                 self._get_omega_operands()
-             )
-             if self.omega_border is not None and self.dim > 1:
-                 self._key, self.omega_border, _ = _reset_batch_idx_and_permute(
-                     self._get_omega_border_operands()
-                 )
+             self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))

-     def sample_in_omega_domain(self, n_samples):
+         self.curr_omega_idx = jnp.iinfo(jnp.int32).max - self.omega_batch_size - 1
+         # see explaination in DataGeneratorODE
+         if self.omega_border_batch_size is None:
+             self.curr_omega_border_idx = None
+         else:
+             self.curr_omega_border_idx = (
+                 jnp.iinfo(jnp.int32).max - self.omega_border_batch_size - 1
+             )
+         # key, subkey = jax.random.split(self.key)
+         # self.key = key
+         self.key, self.omega, self.omega_border = self.generate_data(self.key)
+         # see explaination in DataGeneratorODE for the key
+
+     def sample_in_omega_domain(
+         self, keys: Key, sample_size: Int = None
+     ) -> Float[Array, "n dim"]:
+         sample_size = self.n if sample_size is None else sample_size
          if self.dim == 1:
              xmin, xmax = self.min_pts[0], self.max_pts[0]
-             self._key, subkey = random.split(self._key, 2)
-             return random.uniform(
-                 subkey, shape=(n_samples, 1), minval=xmin, maxval=xmax
+             return jax.random.uniform(
+                 keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
              )
-         keys = random.split(self._key, self.dim + 1)
-         self._key = keys[0]
+         # keys = jax.random.split(key, self.dim)
          return jnp.concatenate(
              [
-                 random.uniform(
-                     keys[i + 1],
-                     (n_samples, 1),
+                 jax.random.uniform(
+                     keys[i],
+                     (sample_size, 1),
                      minval=self.min_pts[i],
                      maxval=self.max_pts[i],
                  )
@@ -541,7 +479,9 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
              axis=-1,
          )

-     def sample_in_omega_border_domain(self, n_samples):
+     def sample_in_omega_border_domain(
+         self, keys: Key
+     ) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
          if self.omega_border_batch_size is None:
              return None
          if self.dim == 1:
@@ -553,15 +493,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
              # TODO : find a general & efficient way to sample from the border
              # (facets) of the hypercube in general dim.

-             facet_n = n_samples // (2 * self.dim)
-             keys = random.split(self._key, 5)
-             self._key = keys[0]
-             subkeys = keys[1:]
+             facet_n = self.nb // (2 * self.dim)
              xmin = jnp.hstack(
                  [
                      self.min_pts[0] * jnp.ones((facet_n, 1)),
-                     random.uniform(
-                         subkeys[0],
+                     jax.random.uniform(
+                         keys[0],
                          (facet_n, 1),
                          minval=self.min_pts[1],
                          maxval=self.max_pts[1],
@@ -571,8 +508,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
              xmax = jnp.hstack(
                  [
                      self.max_pts[0] * jnp.ones((facet_n, 1)),
-                     random.uniform(
-                         subkeys[1],
+                     jax.random.uniform(
+                         keys[1],
                          (facet_n, 1),
                          minval=self.min_pts[1],
                          maxval=self.max_pts[1],
@@ -581,8 +518,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
              )
              ymin = jnp.hstack(
                  [
-                     random.uniform(
-                         subkeys[2],
+                     jax.random.uniform(
+                         keys[2],
                          (facet_n, 1),
                          minval=self.min_pts[0],
                          maxval=self.max_pts[0],
@@ -592,8 +529,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
              )
              ymax = jnp.hstack(
                  [
-                     random.uniform(
-                         subkeys[3],
+                     jax.random.uniform(
+                         keys[3],
                          (facet_n, 1),
                          minval=self.min_pts[0],
                          maxval=self.max_pts[0],
@@ -607,54 +544,71 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
              + f"implemented yet. You are asking for generation in dimension d={self.dim}."
          )

-     def generate_data(self):
+     def generate_data(self, key: Key) -> tuple[
+         Key,
+         Float[Array, "n dim"],
+         Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
+     ]:
          r"""
-         Construct a complete set of `self.n` :math:`\Omega` points according to the
+         Construct a complete set of `self.n` $\Omega$ points according to the
          specified `self.method`. Also constructs a complete set of `self.nb`
-         :math:`\partial\Omega` points if `self.omega_border_batch_size` is not
+         $\partial\Omega$ points if `self.omega_border_batch_size` is not
          `None`. If the latter is `None` we set `self.omega_border` to `None`.
          """
-
          # Generate Omega
          if self.method == "grid":
              if self.dim == 1:
                  xmin, xmax = self.min_pts[0], self.max_pts[0]
-                 self.partial = (xmax - xmin) / self.n
+                 partial = (xmax - xmin) / self.n
                  # shape (n, 1)
-                 self.omega = jnp.arange(xmin, xmax, self.partial)[:, None]
+                 omega = jnp.arange(xmin, xmax, partial)[:, None]
              else:
-                 self.partials = [
+                 partials = [
                      (self.max_pts[i] - self.min_pts[i]) / jnp.sqrt(self.n)
                      for i in range(self.dim)
                  ]
                  xyz_ = jnp.meshgrid(
                      *[
-                         jnp.arange(self.min_pts[i], self.max_pts[i], self.partials[i])
+                         jnp.arange(self.min_pts[i], self.max_pts[i], partials[i])
                          for i in range(self.dim)
                      ]
                  )
                  xyz_ = [a.reshape((self.n, 1)) for a in xyz_]
-                 self.omega = jnp.concatenate(xyz_, axis=-1)
+                 omega = jnp.concatenate(xyz_, axis=-1)
          elif self.method == "uniform":
-             self.omega = self.sample_in_omega_domain(self.n)
+             if self.dim == 1:
+                 key, subkeys = jax.random.split(key, 2)
+             else:
+                 key, *subkeys = jax.random.split(key, self.dim + 1)
+             omega = self.sample_in_omega_domain(subkeys)
          else:
              raise ValueError("Method " + self.method + " is not implemented.")

          # Generate border of omega
-         self.omega_border = self.sample_in_omega_border_domain(self.nb)
+         if self.dim == 2 and self.omega_border_batch_size is not None:
+             key, *subkeys = jax.random.split(key, 5)
+         else:
+             subkeys = None
+         omega_border = self.sample_in_omega_border_domain(subkeys)

-     def _get_omega_operands(self):
+         return key, omega, omega_border
+
+     def _get_omega_operands(
+         self,
+     ) -> tuple[Key, Float[Array, "n dim"], Int, Int, Float[Array, "n"]]:
          return (
-             self._key,
+             self.key,
              self.omega,
              self.curr_omega_idx,
              self.omega_batch_size,
              self.p_omega,
          )

-     def inside_batch(self):
+     def inside_batch(
+         self,
+     ) -> tuple["CubicMeshPDEStatio", Float[Array, "omega_batch_size dim"]]:
          r"""
-         Return a batch of points in :math:`\Omega`.
+         Return a batch of points in $\Omega$.
          If all the batches have been seen, we reshuffle them,
          otherwise we just return the next unseen batch.
          """
@@ -670,38 +624,46 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
          bstart = self.curr_omega_idx
          bend = bstart + self.omega_batch_size

-         (self._key, self.omega, self.curr_omega_idx) = _reset_or_increment(
-             bend, n_eff, self._get_omega_operands()
+         new_attributes = _reset_or_increment(bend, n_eff, self._get_omega_operands())
+         new = eqx.tree_at(
+             lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
          )

-         # commands below are equivalent to
-         # return self.omega[i:(i+batch_size), 0:dim]
-         return jax.lax.dynamic_slice(
-             self.omega,
-             start_indices=(self.curr_omega_idx, 0),
-             slice_sizes=(self.omega_batch_size, self.dim),
+         return new, jax.lax.dynamic_slice(
+             new.omega,
+             start_indices=(new.curr_omega_idx, 0),
+             slice_sizes=(new.omega_batch_size, new.dim),
          )

-     def _get_omega_border_operands(self):
+     def _get_omega_border_operands(
+         self,
+     ) -> tuple[
+         Key, Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None, Int, Int, None
+     ]:
          return (
-             self._key,
+             self.key,
              self.omega_border,
              self.curr_omega_border_idx,
              self.omega_border_batch_size,
              self.p_border,
          )

-     def border_batch(self):
+     def border_batch(
+         self,
+     ) -> tuple[
+         "CubicMeshPDEStatio",
+         Float[Array, "1 1 2"] | Float[Array, "omega_border_batch_size 2 4"] | None,
+     ]:
          r"""
          Return

          - The value `None` if `self.omega_border_batch_size` is `None`.

-         - a jnp array with two fixed values :math:`(x_{min}, x_{max})` if
+         - a jnp array with two fixed values $(x_{min}, x_{max})$ if
            `self.dim` = 1. There is no sampling here, we return the entire
-           :math:`\partial\Omega`
+           $\partial\Omega$

-         - a batch of points in :math:`\partial\Omega` otherwise, stacked by
+         - a batch of points in $\partial\Omega$ otherwise, stacked by
            facet on the last axis.
            If all the batches have been seen, we reshuffle them,
            otherwise we just return the next unseen batch.
@@ -709,236 +671,138 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):

          """
          if self.omega_border_batch_size is None:
-             return None
+             return self, None
          if self.dim == 1:
              # 1-D case, no randomness : we always return the whole omega border,
              # i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
-             return self.omega_border[None, None]  # shape is (1, 1, 2)
+             return self, self.omega_border[None, None]  # shape is (1, 1, 2)
          bstart = self.curr_omega_border_idx
          bend = bstart + self.omega_border_batch_size

-         (
-             self._key,
-             self.omega_border,
-             self.curr_omega_border_idx,
-         ) = _reset_or_increment(bend, self.nb, self._get_omega_border_operands())
-
-         # commands below are equivalent to
-         # return self.omega[i:(i+batch_size), 0:dim, 0:nb_facets]
-         # and nb_facets = 2 * dimension
-         # but JAX prefer the latter
-         return jax.lax.dynamic_slice(
-             self.omega_border,
-             start_indices=(self.curr_omega_border_idx, 0, 0),
-             slice_sizes=(self.omega_border_batch_size, self.dim, 2 * self.dim),
+         new_attributes = _reset_or_increment(
+             bend, self.nb, self._get_omega_border_operands()
          )
-
-     def get_batch(self):
-         """
-         Generic method to return a batch. Here we call `self.inside_batch()`
-         and `self.border_batch()`
-         """
-         return PDEStatioBatch(
-             inside_batch=self.inside_batch(), border_batch=self.border_batch()
+         new = eqx.tree_at(
+             lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
+             self,
+             new_attributes,
          )

-     def tree_flatten(self):
-         children = (
-             self._key,
-             self.omega,
-             self.omega_border,
-             self.curr_omega_idx,
-             self.curr_omega_border_idx,
-             self.min_pts,
-             self.max_pts,
-             self.p_omega,
-             self.rar_iter_from_last_sampling,
-             self.rar_iter_nb,
+         return new, jax.lax.dynamic_slice(
+             new.omega_border,
+             start_indices=(new.curr_omega_border_idx, 0, 0),
+             slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
          )
-         aux_data = {
-             k: vars(self)[k]
-             for k in [
-                 "n",
-                 "nb",
-                 "omega_batch_size",
-                 "omega_border_batch_size",
-                 "method",
-                 "dim",
-                 "rar_parameters",
-                 "n_start",
-             ]
-         }
-         return (children, aux_data)

-     @classmethod
-     def tree_unflatten(cls, aux_data, children):
+     def get_batch(self) -> tuple["CubicMeshPDEStatio", PDEStatioBatch]:
          """
-         **Note:** When reconstructing the class, we force ``data_exists=True``
-         in order not to re-generate the data at each flattening and
-         unflattening that happens e.g. during the gradient descent in the
-         optimization process
+         Generic method to return a batch. Here we call `self.inside_batch()`
+         and `self.border_batch()`
          """
-         (
-             key,
-             omega,
-             omega_border,
-             curr_omega_idx,
-             curr_omega_border_idx,
-             min_pts,
-             max_pts,
-             p_omega,
-             rar_iter_from_last_sampling,
-             rar_iter_nb,
-         ) = children
-         # force data_exists=True here in order not to re-generate the data
-         # at each iteration of lax.scan
-         obj = cls(
-             key=key,
-             data_exists=True,
-             min_pts=min_pts,
-             max_pts=max_pts,
-             **aux_data,
-         )
-         obj.omega = omega
-         obj.omega_border = omega_border
-         obj.curr_omega_idx = curr_omega_idx
-         obj.curr_omega_border_idx = curr_omega_border_idx
-         obj.p_omega = p_omega
-         obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
-         obj.rar_iter_nb = rar_iter_nb
-         return obj
+         new, inside_batch = self.inside_batch()
+         new, border_batch = new.border_batch()
+         return new, PDEStatioBatch(inside_batch=inside_batch, border_batch=border_batch)


- @register_pytree_node_class
  class CubicMeshPDENonStatio(CubicMeshPDEStatio):
-     """
+     r"""
      A class implementing data generator object for non stationary partial
      differential equations. Formally, it extends `CubicMeshPDEStatio`
      to include a temporal batch.

-
-     **Note:** CubicMeshPDENonStatio is jittable. Hence it implements the tree_flatten() and
-     tree_unflatten methods.
+     Parameters
+     ----------
+     key : Key
+         Jax random key to sample new time points and to shuffle batches
+     n : Int
+         The number of total $\Omega$ points that will be divided in
+         batches. Batches are made so that each data point is seen only
+         once during 1 epoch.
+     nb : Int | None
+         The total number of points in $\partial\Omega$.
+         Can be `None` not to lose performance generating the border
+         batch if they are not used
+     nt : Int
+         The number of total time points that will be divided in
+         batches. Batches are made so that each data point is seen only
+         once during 1 epoch.
+     omega_batch_size : Int
+         The size of the batch of randomly selected points among
+         the `n` points.
+     omega_border_batch_size : Int | None
+         The size of the batch of points randomly selected
+         among the `nb` points.
+         Can be `None` not to lose performance generating the border
+         batch if they are not used
+     temporal_batch_size : Int
+         The size of the batch of randomly selected points among
+         the `nt` points.
+     dim : Int
+         An integer. dimension of $\Omega$ domain
+     min_pts : tuple[tuple[Float, Float], ...]
+         A tuple of minimum values of the domain along each dimension. For a sampling
+         in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
+         x_{n, min})$
+     max_pts : tuple[tuple[Float, Float], ...]
+         A tuple of maximum values of the domain along each dimension. For a sampling
+         in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
+         x_{n,max})$
+     tmin : float
+         The minimum value of the time domain to consider
+     tmax : float
+         The maximum value of the time domain to consider
+     method : str, default="uniform"
+         Either `grid` or `uniform`, default is `uniform`.
+         The method that generates the `nt` time points. `grid` means
+         regularly spaced points over the domain. `uniform` means uniformly
+         sampled points over the domain
+     rar_parameters : Dict[str, Int], default=None
+         Default to None: do not use Residual Adaptative Resampling.
+         Otherwise a dictionary with keys. `start_iter`: the iteration at
+         which we start the RAR sampling scheme (we first have a burn in
+         period). `update_every`: the number of gradient steps taken between
+         each appending of collocation points in the RAR algo.
+         `sample_size_omega`: the size of the sample from which we will select new
+         collocation points. `selected_sample_size_omega`: the number of selected
+         points from the sample to be added to the current collocation
+         points.
+     n_start : Int, default=None
+         Defaults to None. The effective size of n used at start time.
+         This value must be
+         provided when rar_parameters is not None. Otherwise we set internally
+         n_start = n and this is hidden from the user.
+         In RAR, n_start
+         then corresponds to the initial number of omega points we train the PINN.
+     nt_start : Int, default=None
+         Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
+         for times collocation point. See also ``DataGeneratorODE``
+         documentation.
+     cartesian_product : Bool, default=True
+         Defaults to True. Whether we return the cartesian product of the
+         temporal batch with the inside and border batches. If False we just
+         return their concatenation.
      """

-     def __init__(
-         self,
-         key,
-         n,
-         nb,
-         nt,
-         omega_batch_size,
-         omega_border_batch_size,
-         temporal_batch_size,
-         dim,
-         min_pts,
-         max_pts,
-         tmin,
-         tmax,
-         method="grid",
-         rar_parameters=None,
-         n_start=None,
-         nt_start=None,
-         cartesian_product=True,
-         data_exists=False,
-     ):
-         r"""
-         Parameters
-         ----------
-         key
-             Jax random key to sample new time points and to shuffle batches
-         n
-             An integer. The number of total :math:`\Omega` points that will be divided in
-             batches. Batches are made so that each data point is seen only
-             once during 1 epoch.
-         nb
-             An integer. The total number of points in :math:`\partial\Omega`.
-             Can be `None` not to lose performance generating the border
-             batch if they are not used
-         nt
-             An integer. The number of total time points that will be divided in
-             batches. Batches are made so that each data point is seen only
-             once during 1 epoch.
-         omega_batch_size
-             An integer. The size of the batch of randomly selected points among
-             the `n` points.
-         omega_border_batch_size
-             An integer. The size of the batch of points randomly selected
-             among the `nb` points.
-             Can be `None` not to lose performance generating the border
-             batch if they are not used
-         temporal_batch_size
-             An integer. The size of the batch of randomly selected points among
-             the `nt` points.
-         dim
-             An integer. dimension of :math:`\Omega` domain
-         min_pts
-             A tuple of minimum values of the domain along each dimension. For a sampling
-             in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
-             x_{n, min})`
-         max_pts
-             A tuple of maximum values of the domain along each dimension. For a sampling
-             in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
-             x_{n,max})`
-         tmin
-             A float. The minimum value of the time domain to consider
-         tmax
-             A float. The maximum value of the time domain to consider
-         method
-             Either `grid` or `uniform`, default is `grid`.
-             The method that generates the `nt` time points. `grid` means
-             regularly spaced points over the domain. `uniform` means uniformly
-             sampled points over the domain
-         rar_parameters
-             Default to None: do not use Residual Adaptative Resampling.
-             Otherwise a dictionary with keys. `start_iter`: the iteration at
-             which we start the RAR sampling scheme (we first have a burn in
-             period). `update_every`: the number of gradient steps taken between
-             each appending of collocation points in the RAR algo.
-             `sample_size_omega`: the size of the sample from which we will select new
-             collocation points. `selected_sample_size_omega`: the number of selected
-             points from the sample to be added to the current collocation
-             points.
-         n_start
-             Defaults to None. The effective size of n used at start time.
-             This value must be
-             provided when rar_parameters is not None. Otherwise we set internally
-             n_start = n and this is hidden from the user.
-             In RAR, n_start
-             then corresponds to the initial number of omega points we train the PINN.
-         nt_start
-             Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
-             for times collocation point. See also ``DataGeneratorODE``
-             documentation.
-         cartesian_product
-             Defaults to True. Whether we return the cartesian product of the
-             temporal batch with the inside and border batches. If False we just
-             return their concatenation.
-         data_exists
-             Must be left to `False` when created by the user. Avoids the
-             regeneration of :math:`\Omega`, :math:`\partial\Omega` and
-             time points at each pytree flattening and unflattening.
+     temporal_batch_size: Int = eqx.field(kw_only=True)
+     tmin: Float = eqx.field(kw_only=True)
+     tmax: Float = eqx.field(kw_only=True)
+     nt: Int = eqx.field(kw_only=True)
+     temporal_batch_size: Int = eqx.field(kw_only=True, static=True)
+     cartesian_product: Bool = eqx.field(kw_only=True, default=True, static=True)
+     nt_start: int = eqx.field(kw_only=True, default=None, static=True)
+
+     p_times: Array = eqx.field(init=False)
+     curr_time_idx: Int = eqx.field(init=False)
+     times: Array = eqx.field(init=False)
+
+     def __post_init__(self):
          """
-         super().__init__(
-             key,
-             n,
-             nb,
-             omega_batch_size,
-             omega_border_batch_size,
-             dim,
-             min_pts,
-             max_pts,
-             method,
-             rar_parameters,
-             n_start,
-             data_exists,
-         )
-         self.temporal_batch_size = temporal_batch_size
-         self.tmin = tmin
-         self.tmax = tmax
-         self.nt = nt
+         Note that neither __init__ or __post_init__ are called when udating a
+         Module with eqx.tree_at!
+         """
+         super().__post_init__()  # because __init__ or __post_init__ of Base
+         # class is not automatically called

-         self.cartesian_product = cartesian_product
          if not self.cartesian_product:
              if self.temporal_batch_size != self.omega_batch_size:
                  raise ValueError(
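Same pattern for the mesh-based generators: `inside_batch()`, `border_batch()` and `get_batch()` now return `(updated_generator, batch)`, and the constructor takes keyword-only fields. A usage sketch, assuming `CubicMeshPDEStatio` is still exposed under `jinns.data`; the domain and sizes are made up:

```python
import jax
import jinns

train_data = jinns.data.CubicMeshPDEStatio(
    key=jax.random.PRNGKey(0),
    n=1024,
    nb=64,  # must be a multiple of 2 * dim
    omega_batch_size=32,
    omega_border_batch_size=8,
    dim=2,
    min_pts=(0.0, 0.0),
    max_pts=(1.0, 1.0),
)

# thread the generator through the loop; per the annotations above,
# batch.inside_batch has shape (32, 2) and batch.border_batch has shape (8, 2, 4)
for _ in range(3):
    train_data, batch = train_data.get_batch()
```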
@@ -968,46 +832,54 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
              self.p_times,
              _,
              _,
-         ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
-
-         if not self.data_exists:
-             # Useful when using a lax.scan with pytree
-             # Optionally can tell JAX not to re-generate data
-             self.curr_time_idx = 0
-             self.generate_data_nonstatio()
-             self._key, self.times, _ = _reset_batch_idx_and_permute(
-                 self._get_time_operands()
-             )
-
-     def sample_in_time_domain(self, n_samples):
-         self._key, subkey = random.split(self._key, 2)
-         return random.uniform(subkey, (n_samples,), minval=self.tmin, maxval=self.tmax)
+         ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
+
+         self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
+         self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
+         # the call to _reset_batch_idx_and_permute in legacy DG
+         self.key, self.times = self.generate_time_data(self.key)
+         # see explaination in DataGeneratorODE for the key
+
+     def sample_in_time_domain(
+         self, key: Key, sample_size: Int = None
+     ) -> Float[Array, "nt"]:
+         return jax.random.uniform(
+             key,
+             (self.nt if sample_size is None else sample_size,),
+             minval=self.tmin,
+             maxval=self.tmax,
+         )

-     def _get_time_operands(self):
+     def _get_time_operands(
+         self,
+     ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
          return (
-             self._key,
+             self.key,
              self.times,
              self.curr_time_idx,
              self.temporal_batch_size,
              self.p_times,
          )

-     def generate_data_nonstatio(self):
-         r"""
+     def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
+         """
          Construct a complete set of `self.nt` time points according to the
-         specified `self.method`. This completes the `super` function
-         `generate_data()` which generates :math:`\Omega` and
-         :math:`\partial\Omega` points.
+         specified `self.method`
+
+         Note that self.times has always size self.nt and not self.nt_start, even
+         in RAR scheme, we must allocate all the collocation points
          """
+         key, subkey = jax.random.split(key, 2)
          if self.method == "grid":
-             self.partial_times = (self.tmax - self.tmin) / self.nt
-             self.times = jnp.arange(self.tmin, self.tmax, self.partial_times)
-         elif self.method == "uniform":
-             self.times = self.sample_in_time_domain(self.nt)
-         else:
-             raise ValueError("Method " + self.method + " is not implemented.")
+             partial_times = (self.tmax - self.tmin) / self.nt
+             return key, jnp.arange(self.tmin, self.tmax, partial_times)
+         if self.method == "uniform":
+             return key, self.sample_in_time_domain(subkey)
+         raise ValueError("Method " + self.method + " is not implemented.")

-     def temporal_batch(self):
+     def temporal_batch(
+         self,
+     ) -> tuple["CubicMeshPDENonStatio", Float[Array, "temporal_batch_size"]]:
          """
          Return a batch of time points. If all the batches have been seen, we
          reshuffle them, otherwise we just return the next unseen batch.
@@ -1018,254 +890,344 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
1018
890
  # Compute the effective number of used collocation points
1019
891
  if self.rar_parameters is not None:
1020
892
  nt_eff = (
1021
- self.n_start
893
+ self.nt_start
1022
894
  + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
1023
895
  )
1024
896
  else:
1025
897
  nt_eff = self.nt
1026
898
 
1027
- (self._key, self.times, self.curr_time_idx) = _reset_or_increment(
1028
- bend, nt_eff, self._get_time_operands()
899
+ new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
900
+ new = eqx.tree_at(
901
+ lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
1029
902
  )
1030
903
 
1031
- # commands below are equivalent to
1032
- # return self.times[i:(i+t_batch_size)]
1033
- # but JAX prefer the latter
1034
- return jax.lax.dynamic_slice(
1035
- self.times,
1036
- start_indices=(self.curr_time_idx,),
1037
- slice_sizes=(self.temporal_batch_size,),
904
+ return new, jax.lax.dynamic_slice(
905
+ new.times,
906
+ start_indices=(new.curr_time_idx,),
907
+ slice_sizes=(new.temporal_batch_size,),
1038
908
  )
1039
909
 
1040
- def get_batch(self):
910
+ def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
1041
911
  """
1042
912
  Generic method to return a batch. Here we call `self.inside_batch()`,
1043
913
  `self.border_batch()` and `self.temporal_batch()`
1044
914
  """
1045
- x = self.inside_batch()
1046
- dx = self.border_batch()
1047
- t = self.temporal_batch().reshape(self.temporal_batch_size, 1)
915
+ new, x = self.inside_batch()
916
+ new, dx = new.border_batch()
917
+ new, t = new.temporal_batch()
918
+ t = t.reshape(new.temporal_batch_size, 1)
1048
919
 
1049
- if self.cartesian_product:
920
+ if new.cartesian_product:
1050
921
  t_x = make_cartesian_product(t, x)
1051
922
  else:
1052
923
  t_x = jnp.concatenate([t, x], axis=1)
1053
924
 
1054
925
  if dx is not None:
1055
- t_ = t.reshape(self.temporal_batch_size, 1, 1)
926
+ t_ = t.reshape(new.temporal_batch_size, 1, 1)
1056
927
  t_ = jnp.repeat(t_, dx.shape[-1], axis=2)
1057
- if self.cartesian_product or self.dim == 1:
928
+ if new.cartesian_product or new.dim == 1:
1058
929
  t_dx = make_cartesian_product(t_, dx)
1059
930
  else:
1060
931
  t_dx = jnp.concatenate([t_, dx], axis=1)
1061
932
  else:
1062
933
  t_dx = None
1063
934
 
1064
- return PDENonStatioBatch(times_x_inside_batch=t_x, times_x_border_batch=t_dx)
935
+ return new, PDENonStatioBatch(
936
+ times_x_inside_batch=t_x, times_x_border_batch=t_dx
937
+ )
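The rewritten `temporal_batch` and `get_batch` above follow the immutable equinox pattern: the module is never mutated in place; the advanced index and key are written into a copy with `eqx.tree_at`, and that copy is returned together with the batch. A minimal toy sketch of this calling convention (illustrative only, not jinns code; `ToyGenerator` is made up):

# Toy sketch of the (updated_module, batch) convention used above; not jinns code.
import equinox as eqx
import jax
import jax.numpy as jnp


class ToyGenerator(eqx.Module):
    data: jax.Array
    batch_size: int = eqx.field(static=True)
    curr_idx: int

    def get_batch(self) -> tuple["ToyGenerator", jax.Array]:
        # slice the current batch ...
        batch = jax.lax.dynamic_slice(
            self.data, start_indices=(self.curr_idx,), slice_sizes=(self.batch_size,)
        )
        # ... and return a copy of self carrying the advanced index
        new = eqx.tree_at(lambda m: m.curr_idx, self, self.curr_idx + self.batch_size)
        return new, batch


gen = ToyGenerator(data=jnp.arange(6.0), batch_size=2, curr_idx=0)
gen, batch = gen.get_batch()   # batch == [0., 1.], gen.curr_idx == 2
gen, batch = gen.get_batch()   # batch == [2., 3.], gen.curr_idx == 4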
1065
938
 
1066
- def tree_flatten(self):
1067
- children = (
1068
- self._key,
1069
- self.omega,
1070
- self.omega_border,
1071
- self.times,
1072
- self.curr_omega_idx,
1073
- self.curr_omega_border_idx,
1074
- self.curr_time_idx,
1075
- self.min_pts,
1076
- self.max_pts,
1077
- self.tmin,
1078
- self.tmax,
1079
- self.p_times,
1080
- self.p_omega,
1081
- self.rar_iter_from_last_sampling,
1082
- self.rar_iter_nb,
939
+
940
+ class DataGeneratorObservations(eqx.Module):
941
+ r"""
942
+ Despite the class name, it is rather a dataloader from user provided
943
+ observations that will be used for the observations loss
944
+
945
+ Parameters
946
+ ----------
947
+ key : Key
948
+ Jax random key to shuffle batches
949
+ obs_batch_size : Int
950
+ The size of the batch of randomly selected points among
951
+ the `n` points. `obs_batch_size` will be the same for all
952
+ elements of the returned observation dict batch.
953
+ NOTE: no check is done BUT users should be careful that
954
+ `obs_batch_size` must be equal to `temporal_batch_size` or
955
+ `omega_batch_size` or the product of both. In the first case, the
956
+ present DataGeneratorObservations instance complements an ODEBatch,
957
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
958
+ = False). In the second case, `obs_batch_size` =
959
+ `temporal_batch_size * omega_batch_size` if the present
960
+ DataGeneratorParameter complements a PDENonStatioBatch
961
+ with self.cartesian_product = True
962
+ observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
963
+ Observed values corresponding to the input of the PINN
964
+ (e.g. the time at which we recorded the observations). The first
965
+ dimension must correspond to the number of observed_values.
966
+ The second dimension depends on the input dimension of the PINN,
967
+ that is `1` for ODE, `n_dim_x` for stationary PDE and `n_dim_x + 1`
968
+ for non-stationary PDE.
969
+ observed_values : Float[Array, "n_obs nb_pinn_out"]
970
+ Observed values that the PINN should learn to fit. The first
971
+ dimension must be aligned with observed_pinn_in.
972
+ observed_eq_params : Dict[str, Float[Array, "n_obs 1"]], default={}
973
+ A dict with keys corresponding to
974
+ the parameter name. The keys must match the keys in
975
+ `params["eq_params"]`. The values are jnp.array with 2 dimensions
976
+ with values corresponding to the parameter value for which we also
977
+ have observed_pinn_in and observed_values. Hence the first
978
+ dimension must be aligned with observed_pinn_in and observed_values.
979
+ Optional argument.
980
+ sharding_device : jax.sharding.Sharding, default=None
981
+ Default None. An optional sharding object to constrain the storage
982
+ of observed inputs, values and parameters. Typically, a
983
+ SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
984
+ datasets of observations. Note that computations for **batches**
985
+ can still be performed on other devices (*e.g.* GPU, TPU or
986
+ any pre-defined Sharding) thanks to the `obs_batch_sharding`
987
+ argument of `jinns.solve()`. Read the docs for more info.
988
+ """
989
+
990
+ key: Key
991
+ obs_batch_size: Int = eqx.field(static=True)
992
+ observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
993
+ observed_values: Float[Array, "n_obs nb_pinn_out"]
994
+ observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
995
+ static=True, default_factory=lambda: {}
996
+ )
997
+ sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
998
+
999
+ n: Int = eqx.field(init=False)
1000
+ curr_idx: Int = eqx.field(init=False)
1001
+ indices: Array = eqx.field(init=False)
1002
+
1003
+ def __post_init__(self):
1004
+ if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
1005
+ raise ValueError(
1006
+ "self.observed_pinn_in and self.observed_values must have same first axis"
1007
+ )
1008
+ for _, v in self.observed_eq_params.items():
1009
+ if v.shape[0] != self.observed_pinn_in.shape[0]:
1010
+ raise ValueError(
1011
+ "self.observed_pinn_in and the values of"
1012
+ " self.observed_eq_params must have the same first axis"
1013
+ )
1014
+ if len(self.observed_pinn_in.shape) == 1:
1015
+ self.observed_pinn_in = self.observed_pinn_in[:, None]
1016
+ if len(self.observed_pinn_in.shape) > 2:
1017
+ raise ValueError("self.observed_pinn_in must have 2 dimensions")
1018
+ if len(self.observed_values.shape) == 1:
1019
+ self.observed_values = self.observed_values[:, None]
1020
+ if len(self.observed_values.shape) > 2:
1021
+ raise ValueError("self.observed_values must have 2 dimensions")
1022
+ for k, v in self.observed_eq_params.items():
1023
+ if len(v.shape) == 1:
1024
+ self.observed_eq_params[k] = v[:, None]
1025
+ if len(v.shape) > 2:
1026
+ raise ValueError(
1027
+ "Each value of observed_eq_params must have 2 dimensions"
1028
+ )
1029
+
1030
+ self.n = self.observed_pinn_in.shape[0]
1031
+
1032
+ if self.sharding_device is not None:
1033
+ self.observed_pinn_in = jax.lax.with_sharding_constraint(
1034
+ self.observed_pinn_in, self.sharding_device
1035
+ )
1036
+ self.observed_values = jax.lax.with_sharding_constraint(
1037
+ self.observed_values, self.sharding_device
1038
+ )
1039
+ self.observed_eq_params = jax.lax.with_sharding_constraint(
1040
+ self.observed_eq_params, self.sharding_device
1041
+ )
1042
+
1043
+ self.curr_idx = jnp.iinfo(jnp.int32).max - self.obs_batch_size - 1
1044
+ # For speed and to avoid duplicating data, what is really
1045
+ # shuffled is a vector of indices
1046
+ if self.sharding_device is not None:
1047
+ self.indices = jax.lax.with_sharding_constraint(
1048
+ jnp.arange(self.n), self.sharding_device
1049
+ )
1050
+ else:
1051
+ self.indices = jnp.arange(self.n)
1052
+
1053
+ # recall __post_init__ is, together with __init__, the only place where we can set
1054
+ # self attributes in an in-place way
1055
+ self.key, _ = jax.random.split(self.key, 2) # to make it equivalent to
1056
+ # the call to _reset_batch_idx_and_permute in legacy DG
1057
+
1058
+ def _get_operands(self) -> tuple[Key, Int[Array, "n"], Int, Int, None]:
1059
+ return (
1060
+ self.key,
1061
+ self.indices,
1062
+ self.curr_idx,
1063
+ self.obs_batch_size,
1064
+ None,
1083
1065
  )
1084
- aux_data = {
1085
- k: vars(self)[k]
1086
- for k in [
1087
- "n",
1088
- "nb",
1089
- "nt",
1090
- "omega_batch_size",
1091
- "omega_border_batch_size",
1092
- "temporal_batch_size",
1093
- "method",
1094
- "dim",
1095
- "rar_parameters",
1096
- "n_start",
1097
- "nt_start",
1098
- "cartesian_product",
1099
- ]
1100
- }
1101
- return (children, aux_data)
1102
1066
 
1103
- @classmethod
1104
- def tree_unflatten(cls, aux_data, children):
1067
+ def obs_batch(
1068
+ self,
1069
+ ) -> tuple[
1070
+ "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
1071
+ ]:
1105
1072
  """
1106
- **Note:** When reconstructing the class, we force ``data_exists=True``
1107
- in order not to re-generate the data at each flattening and
1108
- unflattening that happens e.g. during the gradient descent in the
1109
- optimization process
1073
+ Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
1074
+ inputs), (obs, a mini batch of corresponding observations), (eq_params,
1075
+ a dictionary with entry names found in `params["eq_params"]` and values
1076
+ giving the corresponding parameter value for the couple
1077
+ (input, observation) mentioned before).
1078
+ It can also be a dictionary of dictionaries as described above if
1079
+ observed_pinn_in, observed_values, etc. are dictionaries with keys
1080
+ representing the PINNs.
1110
1081
  """
1111
- (
1112
- key,
1113
- omega,
1114
- omega_border,
1115
- times,
1116
- curr_omega_idx,
1117
- curr_omega_border_idx,
1118
- curr_time_idx,
1119
- min_pts,
1120
- max_pts,
1121
- tmin,
1122
- tmax,
1123
- p_times,
1124
- p_omega,
1125
- rar_iter_from_last_sampling,
1126
- rar_iter_nb,
1127
- ) = children
1128
- obj = cls(
1129
- key=key,
1130
- data_exists=True,
1131
- min_pts=min_pts,
1132
- max_pts=max_pts,
1133
- tmin=tmin,
1134
- tmax=tmax,
1135
- **aux_data,
1082
+
1083
+ new_attributes = _reset_or_increment(
1084
+ self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
1085
+ )
1086
+ new = eqx.tree_at(
1087
+ lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
1088
+ )
1089
+
1090
+ minib_indices = jax.lax.dynamic_slice(
1091
+ new.indices,
1092
+ start_indices=(new.curr_idx,),
1093
+ slice_sizes=(new.obs_batch_size,),
1136
1094
  )
1137
- obj.omega = omega
1138
- obj.omega_border = omega_border
1139
- obj.times = times
1140
- obj.curr_omega_idx = curr_omega_idx
1141
- obj.curr_omega_border_idx = curr_omega_border_idx
1142
- obj.curr_time_idx = curr_time_idx
1143
- obj.p_times = p_times
1144
- obj.p_omega = p_omega
1145
- obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
1146
- obj.rar_iter_nb = rar_iter_nb
1147
-
1148
- return obj
1149
-
1150
-
1151
- @register_pytree_node_class
1152
- class DataGeneratorParameter:
1153
- """
1154
- A data generator for additional unidimensional parameter(s)
1155
- """
1156
1095
 
1157
- def __init__(
1096
+ obs_batch = {
1097
+ "pinn_in": jnp.take(
1098
+ new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
1099
+ ),
1100
+ "val": jnp.take(
1101
+ new.observed_values, minib_indices, unique_indices=True, axis=0
1102
+ ),
1103
+ "eq_params": jax.tree_util.tree_map(
1104
+ lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
1105
+ new.observed_eq_params,
1106
+ ),
1107
+ }
1108
+ return new, obs_batch
1109
+
1110
+ def get_batch(
1158
1111
  self,
1159
- key,
1160
- n,
1161
- param_batch_size,
1162
- param_ranges=None,
1163
- method="grid",
1164
- user_data=None,
1165
- data_exists=False,
1166
- ):
1167
- r"""
1168
- Parameters
1169
- ----------
1170
- key
1171
- Jax random key to sample new time points and to shuffle batches
1172
- or a dict of Jax random keys with key entries from param_ranges
1173
- n
1174
- An integer. The number of total points that will be divided in
1175
- batches. Batches are made so that each data point is seen only
1176
- once during 1 epoch.
1177
- param_batch_size
1178
- An integer. The size of the batch of randomly selected points among
1179
- the `n` points. `param_batch_size` will be the same for all
1180
- additional batch of parameter.
1181
- NOTE: no check is done BUT users should be careful that
1182
- `param_batch_size` must be equal to `temporal_batch_size` or
1183
- `omega_batch_size` or the product of both. In the first case, the
1184
- present DataGeneratorParameter instance complements an ODEBatch, a
1185
- PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1186
- = False). In the second case, `param_batch_size` =
1187
- `temporal_batch_size * omega_batch_size` if the present
1188
- DataGeneratorParameter complements a PDENonStatioBatch
1189
- with self.cartesian_product = True
1190
- param_ranges
1191
- A dict. A dict of tuples (min, max), which
1192
- reprensents the range of real numbers where to sample batches (of
1193
- length `param_batch_size` among `n` points).
1194
- The key corresponds to the parameter name. The keys must match the
1195
- keys in `params["eq_params"]`.
1196
- By providing several entries in this dictionary we can sample
1197
- an arbitrary number of parameters.
1198
- __Note__ that we currently only support unidimensional parameters
1199
- method
1200
- Either `grid` or `uniform`, default is `grid`. `grid` means
1201
- regularly spaced points over the domain. `uniform` means uniformly
1202
- sampled points over the domain
1203
- data_exists
1204
- Must be left to `False` when created by the user. Avoids the
1205
- regeneration of :math:`\Omega`, :math:`\partial\Omega` and
1206
- time points at each pytree flattening and unflattening.
1207
- user_data
1208
- A dictionary containing user-provided data for parameters.
1209
- As for `param_ranges`, the key corresponds to the parameter name,
1210
- the keys must match the keys in `params["eq_params"]` and only
1211
- unidimensional arrays are supported. Therefore, the jnp arrays
1212
- found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
1213
- Note that if the same key appears in `param_ranges` and `user_data`
1214
- priority goes for the content in `user_data`.
1215
- Defaults to None.
1112
+ ) -> tuple[
1113
+ "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
1114
+ ]:
1216
1115
  """
1217
- self.data_exists = data_exists
1218
- self.method = method
1116
+ Generic method to return a batch
1117
+ """
1118
+ return self.obs_batch()
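A hedged usage sketch for the class defined above, relying only on the fields and the (new, batch) convention visible in this diff; the arrays are dummy data and the import path is an assumption:

# Hedged sketch: build DataGeneratorObservations from the fields documented above
# and draw one mini-batch. Import path assumed; array contents are illustrative.
import jax
import jax.numpy as jnp
from jinns.data import DataGeneratorObservations  # assumed export

t_obs = jnp.linspace(0.0, 1.0, 100)[:, None]   # (n_obs, nb_pinn_in), e.g. an ODE input
u_obs = jnp.sin(2 * jnp.pi * t_obs)            # (n_obs, nb_pinn_out)

obs_dg = DataGeneratorObservations(
    key=jax.random.PRNGKey(0),
    obs_batch_size=32,
    observed_pinn_in=t_obs,
    observed_values=u_obs,
)
obs_dg, batch = obs_dg.get_batch()
# batch["pinn_in"] and batch["val"] have leading dimension obs_batch_size;
# batch["eq_params"] is {} because no observed_eq_params were provided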
1119
+
1120
+
1121
+ class DataGeneratorParameter(eqx.Module):
1122
+ r"""
1123
+ A data generator for additional unidimensional parameter(s)
1124
+
1125
+ Parameters
1126
+ ----------
1127
+ keys : Key | Dict[str, Key]
1128
+ Jax random key to sample new time points and to shuffle batches
1129
+ or a dict of Jax random keys with key entries from param_ranges
1130
+ n : Int
1131
+ The number of total points that will be divided in
1132
+ batches. Batches are made so that each data point is seen only
1133
+ once during 1 epoch.
1134
+ param_batch_size : Int
1135
+ The size of the batch of randomly selected points among
1136
+ the `n` points. `param_batch_size` will be the same for all
1137
+ additional batches of parameters.
1138
+ NOTE: no check is done BUT users should be careful that
1139
+ `param_batch_size` must be equal to `temporal_batch_size` or
1140
+ `omega_batch_size` or the product of both. In the first case, the
1141
+ present DataGeneratorParameter instance complements an ODEBatch, a
1142
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1143
+ = False). In the second case, `param_batch_size` =
1144
+ `temporal_batch_size * omega_batch_size` if the present
1145
+ DataGeneratorParameter complements a PDENonStatioBatch
1146
+ with self.cartesian_product = True
1147
+ param_ranges : Dict[str, tuple[Float, Float]] | None, default={}
1148
+ A dict. A dict of tuples (min, max), which
1149
+ represents the range of real numbers where to sample batches (of
1150
+ length `param_batch_size` among `n` points).
1151
+ The key corresponds to the parameter name. The keys must match the
1152
+ keys in `params["eq_params"]`.
1153
+ By providing several entries in this dictionary we can sample
1154
+ an arbitrary number of parameters.
1155
+ **Note** that we currently only support unidimensional parameters.
1156
+ This argument can be omitted if we only use `user_data`.
1157
+ method : str, default="uniform"
1158
+ Either `grid` or `uniform`, default is `uniform`. `grid` means
1159
+ regularly spaced points over the domain. `uniform` means uniformly
1160
+ sampled points over the domain
1161
+ user_data : Dict[str, Float[Array, "n"]] | None, default={}
1162
+ A dictionary containing user-provided data for parameters.
1163
+ As for `param_ranges`, the key corresponds to the parameter name,
1164
+ the keys must match the keys in `params["eq_params"]` and only
1165
+ unidimensional arrays are supported. Therefore, the jnp arrays
1166
+ found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
1167
+ Note that if the same key appears in `param_ranges` and `user_data`
1168
+ priority goes to the content in `user_data`.
1169
+ Defaults to None.
1170
+ """
1171
+
1172
+ keys: Key | Dict[str, Key]
1173
+ n: Int
1174
+ param_batch_size: Int = eqx.field(static=True)
1175
+ param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
1176
+ static=True, default_factory=lambda: {}
1177
+ )
1178
+ method: str = eqx.field(static=True, default="uniform")
1179
+ user_data: Dict[str, Float[Array, "n"]] | None = eqx.field(
1180
+ static=True, default_factory=lambda: {}
1181
+ )
1182
+
1183
+ curr_param_idx: Dict[str, Int] = eqx.field(init=False)
1184
+ param_n_samples: Dict[str, Array] = eqx.field(init=False)
1219
1185
 
1220
- if n < param_batch_size:
1186
+ def __post_init__(self):
1187
+ if self.user_data is None:
1188
+ self.user_data = {}
1189
+ if self.param_ranges is None:
1190
+ self.param_ranges = {}
1191
+ if self.n < self.param_batch_size:
1221
1192
  raise ValueError(
1222
- f"Number of data points ({n}) is smaller than the"
1223
- f"number of batch points ({param_batch_size})."
1193
+ f"Number of data points ({self.n}) is smaller than the"
1194
+ f"number of batch points ({self.param_batch_size})."
1195
+ )
1196
+ if not isinstance(self.keys, dict):
1197
+ all_keys = set().union(self.param_ranges, self.user_data)
1198
+ self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
1199
+
1200
+ self.curr_param_idx = {}
1201
+ for k in self.keys.keys():
1202
+ self.curr_param_idx[k] = (
1203
+ jnp.iinfo(jnp.int32).max - self.param_batch_size - 1
1224
1204
  )
1225
1205
 
1226
- if user_data is None:
1227
- user_data = {}
1228
- if param_ranges is None:
1229
- param_ranges = {}
1230
- if not isinstance(key, dict):
1231
- all_keys = set().union(param_ranges, user_data)
1232
- self._keys = dict(zip(all_keys, jax.random.split(key, len(all_keys))))
1233
- else:
1234
- self._keys = key
1235
- self.n = n
1236
- self.param_batch_size = param_batch_size
1237
- self.param_ranges = param_ranges
1238
- self.user_data = user_data
1239
-
1240
- if not self.data_exists:
1241
- self.generate_data()
1242
- # The previous call to self.generate_data() has created
1243
- # the dict self.param_n_samples and then we will only use this one
1244
- # because it has merged the scattered data between `user_data` and
1245
- # `param_ranges`
1246
- self.curr_param_idx = {}
1247
- for k in self.param_n_samples.keys():
1248
- self.curr_param_idx[k] = 0
1249
- (
1250
- self._keys[k],
1251
- self.param_n_samples[k],
1252
- _,
1253
- ) = _reset_batch_idx_and_permute(self._get_param_operands(k))
1254
-
1255
- def generate_data(self):
1206
+ # The call to self.generate_data() creates
1207
+ # the dict self.param_n_samples and then we will only use this one
1208
+ # because it merges the scattered data between `user_data` and
1209
+ # `param_ranges`
1210
+ self.keys, self.param_n_samples = self.generate_data(self.keys)
1211
+
1212
+ def generate_data(
1213
+ self, keys: Dict[str, Key]
1214
+ ) -> tuple[Dict[str, Key], Dict[str, Float[Array, "n"]]]:
1256
1215
  """
1257
1216
  Generate parameter samples, either through generation
1258
1217
  or using user-provided data.
1259
1218
  """
1260
- self.param_n_samples = {}
1219
+ param_n_samples = {}
1261
1220
 
1262
1221
  all_keys = set().union(self.param_ranges, self.user_data)
1263
1222
  for k in all_keys:
1264
- if self.user_data and k in self.user_data:
1223
+ if (
1224
+ self.user_data
1225
+ and k in self.user_data.keys() # pylint: disable=no-member
1226
+ ):
1265
1227
  if self.user_data[k].shape == (self.n, 1):
1266
- self.param_n_samples[k] = self.user_data[k]
1228
+ param_n_samples[k] = self.user_data[k]
1267
1229
  if self.user_data[k].shape == (self.n,):
1268
- self.param_n_samples[k] = self.user_data[k][:, None]
1230
+ param_n_samples[k] = self.user_data[k][:, None]
1269
1231
  else:
1270
1232
  raise ValueError(
1271
1233
  "Wrong shape for user provided parameters"
@@ -1274,23 +1236,25 @@ class DataGeneratorParameter:
1274
1236
  else:
1275
1237
  if self.method == "grid":
1276
1238
  xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
1277
- self.partial = (xmax - xmin) / self.n
1239
+ partial = (xmax - xmin) / self.n
1278
1240
  # shape (n, 1)
1279
- self.param_n_samples[k] = jnp.arange(xmin, xmax, self.partial)[
1280
- :, None
1281
- ]
1241
+ param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
1282
1242
  elif self.method == "uniform":
1283
1243
  xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
1284
- self._keys[k], subkey = random.split(self._keys[k], 2)
1285
- self.param_n_samples[k] = random.uniform(
1244
+ keys[k], subkey = jax.random.split(keys[k], 2)
1245
+ param_n_samples[k] = jax.random.uniform(
1286
1246
  subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
1287
1247
  )
1288
1248
  else:
1289
1249
  raise ValueError("Method " + self.method + " is not implemented.")
1290
1250
 
1291
- def _get_param_operands(self, k):
1251
+ return keys, param_n_samples
1252
+
1253
+ def _get_param_operands(
1254
+ self, k: str
1255
+ ) -> tuple[Key, Float[Array, "n"], Int, Int, None]:
1292
1256
  return (
1293
- self._keys[k],
1257
+ self.keys[k],
1294
1258
  self.param_n_samples[k],
1295
1259
  self.curr_param_idx[k],
1296
1260
  self.param_batch_size,
@@ -1315,26 +1279,28 @@ class DataGeneratorParameter:
1315
1279
  _reset_or_increment_wrapper,
1316
1280
  self.param_n_samples,
1317
1281
  self.curr_param_idx,
1318
- self._keys,
1282
+ self.keys,
1319
1283
  )
1320
1284
  # we must transpose the pytrees because keys are merged in res
1321
1285
  # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
1322
- (
1323
- self._keys,
1324
- self.param_n_samples,
1325
- self.curr_param_idx,
1326
- ) = jax.tree_util.tree_transpose(
1327
- jax.tree_util.tree_structure(self._keys),
1286
+ new_attributes = jax.tree_util.tree_transpose(
1287
+ jax.tree_util.tree_structure(self.keys),
1328
1288
  jax.tree_util.tree_structure([0, 0, 0]),
1329
1289
  res,
1330
1290
  )
1331
1291
 
1332
- return jax.tree_util.tree_map(
1292
+ new = eqx.tree_at(
1293
+ lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
1294
+ self,
1295
+ new_attributes,
1296
+ )
1297
+
1298
+ return new, jax.tree_util.tree_map(
1333
1299
  lambda p, q: jax.lax.dynamic_slice(
1334
- p, start_indices=(q, 0), slice_sizes=(self.param_batch_size, 1)
1300
+ p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
1335
1301
  ),
1336
- self.param_n_samples,
1337
- self.curr_param_idx,
1302
+ new.param_n_samples,
1303
+ new.curr_param_idx,
1338
1304
  )
1339
1305
 
1340
1306
  def get_batch(self):
@@ -1343,251 +1309,9 @@ class DataGeneratorParameter:
1343
1309
  """
1344
1310
  return self.param_batch()
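A hedged usage sketch for DataGeneratorParameter, again using only the fields documented above; the parameter name "nu" and the import path are assumptions:

# Hedged sketch: sample batches of a unidimensional equation parameter "nu".
import jax
from jinns.data import DataGeneratorParameter  # assumed export

param_dg = DataGeneratorParameter(
    keys=jax.random.PRNGKey(1),
    n=1000,
    param_batch_size=32,
    param_ranges={"nu": (0.0, 1.0)},  # key must match params["eq_params"]
    method="uniform",
)
param_dg, param_batch = param_dg.get_batch()
# param_batch["nu"] has shape (param_batch_size, 1)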
1345
1311
 
1346
- def tree_flatten(self):
1347
- children = (
1348
- self._keys,
1349
- self.param_n_samples,
1350
- self.curr_param_idx,
1351
- )
1352
- aux_data = {
1353
- k: vars(self)[k]
1354
- for k in ["n", "param_batch_size", "method", "param_ranges", "user_data"]
1355
- }
1356
- return (children, aux_data)
1357
-
1358
- @classmethod
1359
- def tree_unflatten(cls, aux_data, children):
1360
- (
1361
- keys,
1362
- param_n_samples,
1363
- curr_param_idx,
1364
- ) = children
1365
- obj = cls(
1366
- key=keys,
1367
- data_exists=True,
1368
- **aux_data,
1369
- )
1370
- obj.param_n_samples = param_n_samples
1371
- obj.curr_param_idx = curr_param_idx
1372
- return obj
1373
-
1374
-
1375
- @register_pytree_node_class
1376
- class DataGeneratorObservations:
1377
- """
1378
- Despite the class name, it is rather a dataloader from user provided
1379
- observations that will be used for the observations loss
1380
- """
1381
-
1382
- def __init__(
1383
- self,
1384
- key,
1385
- obs_batch_size,
1386
- observed_pinn_in,
1387
- observed_values,
1388
- observed_eq_params=None,
1389
- data_exists=False,
1390
- sharding_device=None,
1391
- ):
1392
- r"""
1393
- Parameters
1394
- ----------
1395
- key
1396
- Jax random key to sample new time points and to shuffle batches
1397
- obs_batch_size
1398
- An integer. The size of the batch of randomly selected points among
1399
- the `n` points. `obs_batch_size` will be the same for all
1400
- elements of the obs dict.
1401
- NOTE: no check is done BUT users should be careful that
1402
- `obs_batch_size` must be equal to `temporal_batch_size` or
1403
- `omega_batch_size` or the product of both. In the first case, the
1404
- present DataGeneratorObservations instance complements an ODEBatch,
1405
- PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1406
- = False). In the second case, `obs_batch_size` =
1407
- `temporal_batch_size * omega_batch_size` if the present
1408
- DataGeneratorParameter complements a PDENonStatioBatch
1409
- with self.cartesian_product = True
1410
- observed_pinn_in
1411
- A jnp.array with 2 dimensions.
1412
- Observed values corresponding to the input of the PINN
1413
- (eg. the time at which we recorded the observations). The first
1414
- dimension must corresponds to the number of observed_values and
1415
- observed_eq_params. The second dimension depends on the input dimension of the PINN, that is `1` for ODE, `n_dim_x` for stationnary PDE and `n_dim_x + 1` for non-stationnary PDE.
1416
- observed_values
1417
- A jnp.array with 2 dimensions.
1418
- Observed values that the PINN should learn to fit. The first dimension must be aligned with observed_pinn_in and the values of observed_eq_params.
1419
- observed_eq_params
1420
- Optional. Default is None. A dict with keys corresponding to the
1421
- parameter name. The keys must match the keys in
1422
- `params["eq_params"]`. The values are jnp.array with 2 dimensions
1423
- with values corresponding to the parameter value for which we also
1424
- have observed_pinn_in and observed_values. Hence the first
1425
- dimension must be aligned with observed_pinn_in and observed_values.
1426
- data_exists
1427
- Must be left to `False` when created by the user. Avoids the
1428
- resetting of curr_idx at each pytree flattening and unflattening.
1429
- sharding_device
1430
- Default None. An optional sharding object to constraint the storage
1431
- of observed inputs, values and parameters. Typically, a
1432
- SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
1433
- datasets of observations. Note that computations for **batches**
1434
- can still be performed on other devices (*e.g.* GPU, TPU or
1435
- any pre-defined Sharding) thanks to the `obs_batch_sharding`
1436
- arguments of `jinns.solve()`. Read the docs for more info.
1437
-
1438
- """
1439
- if observed_eq_params is None:
1440
- observed_eq_params = {}
1441
-
1442
- if not data_exists:
1443
- self.observed_eq_params = observed_eq_params.copy()
1444
- else:
1445
- # avoid copying when in flatten/unflatten
1446
- self.observed_eq_params = observed_eq_params
1447
-
1448
- if observed_pinn_in.shape[0] != observed_values.shape[0]:
1449
- raise ValueError(
1450
- "observed_pinn_in and observed_values must have same first axis"
1451
- )
1452
- for _, v in self.observed_eq_params.items():
1453
- if v.shape[0] != observed_pinn_in.shape[0]:
1454
- raise ValueError(
1455
- "observed_pinn_in and the values of"
1456
- " observed_eq_params must have the same first axis"
1457
- )
1458
- if len(observed_pinn_in.shape) == 1:
1459
- observed_pinn_in = observed_pinn_in[:, None]
1460
- if len(observed_pinn_in.shape) > 2:
1461
- raise ValueError("observed_pinn_in must have 2 dimensions")
1462
- if len(observed_values.shape) == 1:
1463
- observed_values = observed_values[:, None]
1464
- if len(observed_values.shape) > 2:
1465
- raise ValueError("observed_values must have 2 dimensions")
1466
- for k, v in self.observed_eq_params.items():
1467
- if len(v.shape) == 1:
1468
- self.observed_eq_params[k] = v[:, None]
1469
- if len(v.shape) > 2:
1470
- raise ValueError(
1471
- "Each value of observed_eq_params must have 2 dimensions"
1472
- )
1473
-
1474
- self.n = observed_pinn_in.shape[0]
1475
- self._key = key
1476
- self.obs_batch_size = obs_batch_size
1477
-
1478
- self.data_exists = data_exists
1479
- if not self.data_exists and sharding_device is not None:
1480
- self.observed_pinn_in = jax.lax.with_sharding_constraint(
1481
- observed_pinn_in, sharding_device
1482
- )
1483
- self.observed_values = jax.lax.with_sharding_constraint(
1484
- observed_values, sharding_device
1485
- )
1486
- self.observed_eq_params = jax.lax.with_sharding_constraint(
1487
- self.observed_eq_params, sharding_device
1488
- )
1489
- else:
1490
- self.observed_pinn_in = observed_pinn_in
1491
- self.observed_values = observed_values
1492
-
1493
- if not self.data_exists:
1494
- self.curr_idx = 0
1495
- # NOTE for speed and to avoid duplicating data what is really
1496
- # shuffled is a vector of indices
1497
- indices = jnp.arange(self.n)
1498
- if sharding_device is not None:
1499
- self.indices = jax.lax.with_sharding_constraint(
1500
- indices, sharding_device
1501
- )
1502
- else:
1503
- self.indices = indices
1504
- self._key, self.indices, _ = _reset_batch_idx_and_permute(
1505
- self._get_operands()
1506
- )
1507
-
1508
- def _get_operands(self):
1509
- return (
1510
- self._key,
1511
- self.indices,
1512
- self.curr_idx,
1513
- self.obs_batch_size,
1514
- None,
1515
- )
1516
1312
 
1517
- def obs_batch(self):
1518
- """
1519
- Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
1520
- inputs), (obs, a mini batch of corresponding observations), (eq_params,
1521
- a dictionary with entry names found in `params["eq_params"]` and values
1522
- giving the correspond parameter value for the couple
1523
- (input, observation) mentioned before).
1524
- It can also be a dictionary of dictionaries as described above if
1525
- observed_pinn_in, observed_values, etc. are dictionaries with keys
1526
- representing the PINNs.
1527
- """
1528
-
1529
- (self._key, self.indices, self.curr_idx) = _reset_or_increment(
1530
- self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
1531
- )
1532
-
1533
- minib_indices = jax.lax.dynamic_slice(
1534
- self.indices,
1535
- start_indices=(self.curr_idx,),
1536
- slice_sizes=(self.obs_batch_size,),
1537
- )
1538
-
1539
- obs_batch = {
1540
- "pinn_in": jnp.take(
1541
- self.observed_pinn_in, minib_indices, unique_indices=True, axis=0
1542
- ),
1543
- "val": jnp.take(
1544
- self.observed_values, minib_indices, unique_indices=True, axis=0
1545
- ),
1546
- "eq_params": jax.tree_util.tree_map(
1547
- lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
1548
- self.observed_eq_params,
1549
- ),
1550
- }
1551
- return obs_batch
1552
-
1553
- def get_batch(self):
1554
- """
1555
- Generic method to return a batch
1556
- """
1557
- return self.obs_batch()
1558
-
1559
- def tree_flatten(self):
1560
- children = (self._key, self.curr_idx, self.indices)
1561
- aux_data = {
1562
- k: vars(self)[k]
1563
- for k in [
1564
- "obs_batch_size",
1565
- "observed_pinn_in",
1566
- "observed_values",
1567
- "observed_eq_params",
1568
- ]
1569
- }
1570
- return (children, aux_data)
1571
-
1572
- @classmethod
1573
- def tree_unflatten(cls, aux_data, children):
1574
- (key, curr_idx, indices) = children
1575
- obj = cls(
1576
- key=key,
1577
- data_exists=True,
1578
- obs_batch_size=aux_data["obs_batch_size"],
1579
- observed_pinn_in=aux_data["observed_pinn_in"],
1580
- observed_values=aux_data["observed_values"],
1581
- observed_eq_params=aux_data["observed_eq_params"],
1582
- )
1583
- obj.curr_idx = curr_idx
1584
- obj.indices = indices
1585
- return obj
1586
-
1587
-
1588
- @register_pytree_node_class
1589
- class DataGeneratorObservationsMultiPINNs:
1590
- """
1313
+ class DataGeneratorObservationsMultiPINNs(eqx.Module):
1314
+ r"""
1591
1315
  Despite the class name, it is rather a dataloader from user provided
1592
1316
  observations that will be used for the observations loss.
1593
1317
  This is the DataGenerator to use when dealing with multiple PINNs
@@ -1597,146 +1321,123 @@ class DataGeneratorObservationsMultiPINNs:
1597
1321
  applied in `constraints_system_loss_apply` and in this case the
1598
1322
  batch.obs_batch_dict is a dict of obs_batch_dict over which the tree_map
1599
1323
  applies (we select the obs_batch_dict corresponding to its `u_dict` entry)
1324
+
1325
+ Parameters
1326
+ ----------
1327
+ obs_batch_size : Int
1328
+ The size of the batch of randomly selected observations.
1329
+ `obs_batch_size` will be the same for all the
1330
+ elements of the obs dict.
1331
+ NOTE: no check is done BUT users should be careful that
1332
+ `obs_batch_size` must be equal to `temporal_batch_size` or
1333
+ `omega_batch_size` or the product of both. In the first case, the
1334
+ present DataGeneratorObservations instance complements an ODEBatch,
1335
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1336
+ = False). In the second case, `obs_batch_size` =
1337
+ `temporal_batch_size * omega_batch_size` if the present
1338
+ DataGeneratorParameter complements a PDENonStatioBatch
1339
+ with self.cartesian_product = True
1340
+ observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
1341
+ A dict of observed_pinn_in as defined in DataGeneratorObservations.
1342
+ Keys must be that of `u_dict`.
1343
+ If no observation exists for a particular entry of `u_dict` the
1344
+ corresponding key must still exist in observed_pinn_in_dict with
1345
+ value None
1346
+ observed_values_dict : Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
1347
+ A dict of observed_values as defined in DataGeneratorObservations.
1348
+ Keys must be that of `u_dict`.
1349
+ If no observation exists for a particular entry of `u_dict` the
1350
+ corresponding key must still exist in observed_values_dict with
1351
+ value None
1352
+ observed_eq_params_dict : Dict[str, Dict[str, Float[Array, "n_obs 1"]]]
1353
+ A dict of observed_eq_params as defined in DataGeneratorObservations.
1354
+ Keys must be that of `u_dict`.
1355
+ **Note**: if no observation exists for a particular entry of `u_dict` the
1356
+ corresponding key must still exist in observed_eq_params_dict with
1357
+ value `{}` (empty dictionary).
1358
+ key : Key
1359
+ Jax random key to shuffle batches.
1600
1360
  """
1601
1361
 
1602
- def __init__(
1603
- self,
1604
- obs_batch_size,
1605
- observed_pinn_in_dict,
1606
- observed_values_dict,
1607
- observed_eq_params_dict=None,
1608
- data_gen_obs_exists=False,
1609
- key=None,
1610
- ):
1611
- r"""
1612
- Parameters
1613
- ----------
1614
- obs_batch_size
1615
- An integer. The size of the batch of randomly selected observations
1616
- `obs_batch_size` will be the same for all the
1617
- elements of the obs dict.
1618
- NOTE: no check is done BUT users should be careful that
1619
- `obs_batch_size` must be equal to `temporal_batch_size` or
1620
- `omega_batch_size` or the product of both. In the first case, the
1621
- present DataGeneratorObservations instance complements an ODEBatch,
1622
- PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1623
- = False). In the second case, `obs_batch_size` =
1624
- `temporal_batch_size * omega_batch_size` if the present
1625
- DataGeneratorParameter complements a PDENonStatioBatch
1626
- with self.cartesian_product = True
1627
- observed_pinn_in_dict
1628
- A dict of observed_pinn_in as defined in DataGeneratorObservations.
1629
- Keys must be that of `u_dict`.
1630
- If no observation exists for a particular entry of `u_dict` the
1631
- corresponding key must still exist in observed_pinn_in_dict with
1632
- value None
1633
- observed_values_dict
1634
- A dict of observed_values as defined in DataGeneratorObservations.
1635
- Keys must be that of `u_dict`.
1636
- If no observation exists for a particular entry of `u_dict` the
1637
- corresponding key must still exist in observed_values_dict with
1638
- value None
1639
- observed_eq_params_dict
1640
- A dict of observed_eq_params as defined in DataGeneratorObservations.
1641
- Keys must be that of `u_dict`.
1642
- If no observation exists for a particular entry of `u_dict` the
1643
- corresponding key must still exist in observed_eq_params_dict with
1644
- value None
1645
- data_gen_obs_exists
1646
- Must be left to `False` when created by the user. Avoids the
1647
- regeneration the subclasses DataGeneratorObservations
1648
- at each pytree flattening and unflattening.
1649
- key
1650
- Jax random key to sample new time points and to shuffle batches.
1651
- Optional if data_gen_obs_exists is True
1652
- """
1653
- if (
1654
- observed_pinn_in_dict is None or observed_values_dict is None
1655
- ) and not data_gen_obs_exists:
1362
+ obs_batch_size: Int
1363
+ observed_pinn_in_dict: Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
1364
+ observed_values_dict: Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
1365
+ observed_eq_params_dict: Dict[str, Dict[str, Float[Array, "n_obs 1"]]] = eqx.field(
1366
+ default=None, kw_only=True
1367
+ )
1368
+ key: InitVar[Key]
1369
+
1370
+ data_gen_obs: Dict[str, "DataGeneratorObservations"] = eqx.field(init=False)
1371
+
1372
+ def __post_init__(self, key):
1373
+ if self.observed_pinn_in_dict is None or self.observed_values_dict is None:
1374
+ raise ValueError(
1375
+ "observed_pinn_in_dict and observed_values_dict " "must be provided"
1376
+ )
1377
+ if self.observed_pinn_in_dict.keys() != self.observed_values_dict.keys():
1656
1378
  raise ValueError(
1657
- "observed_pinn_in_dict and observed_values_dict "
1658
- "must be provided with data_gen_obs_exists is False"
1379
+ "Keys must be the same in observed_pinn_in_dict"
1380
+ " and observed_values_dict"
1659
1381
  )
1660
- self.obs_batch_size = obs_batch_size
1661
- self.data_gen_obs_exists = data_gen_obs_exists
1662
1382
 
1663
- if not self.data_gen_obs_exists:
1664
- if observed_pinn_in_dict.keys() != observed_values_dict.keys():
1665
- raise ValueError(
1666
- "Keys must be the same in observed_pinn_in_dict"
1667
- " and observed_values_dict"
1668
- )
1669
- if (
1670
- observed_eq_params_dict is not None
1671
- and observed_pinn_in_dict.keys() != observed_eq_params_dict.keys()
1672
- ):
1673
- raise ValueError(
1674
- "Keys must be the same in observed_eq_params_dict"
1675
- " and observed_pinn_in_dict and observed_values_dict"
1676
- )
1677
- if observed_eq_params_dict is None:
1678
- observed_eq_params_dict = {
1679
- k: None for k in observed_pinn_in_dict.keys()
1680
- }
1681
-
1682
- keys = dict(
1683
- zip(
1684
- observed_pinn_in_dict.keys(),
1685
- jax.random.split(key, len(observed_pinn_in_dict)),
1686
- )
1383
+ if self.observed_eq_params_dict is None:
1384
+ self.observed_eq_params_dict = {
1385
+ k: {} for k in self.observed_pinn_in_dict.keys()
1386
+ }
1387
+ elif self.observed_pinn_in_dict.keys() != self.observed_eq_params_dict.keys():
1388
+ raise ValueError(
1389
+ f"Keys must be the same in observed_eq_params_dict"
1390
+ f" and observed_pinn_in_dict and observed_values_dict"
1687
1391
  )
1688
- self.data_gen_obs = jax.tree_util.tree_map(
1689
- lambda k, pinn_in, val, eq_params: (
1690
- DataGeneratorObservations(
1691
- k, obs_batch_size, pinn_in, val, eq_params
1692
- )
1693
- if pinn_in is not None
1694
- else None
1695
- ),
1696
- keys,
1697
- observed_pinn_in_dict,
1698
- observed_values_dict,
1699
- observed_eq_params_dict,
1392
+
1393
+ keys = dict(
1394
+ zip(
1395
+ self.observed_pinn_in_dict.keys(),
1396
+ jax.random.split(key, len(self.observed_pinn_in_dict)),
1700
1397
  )
1398
+ )
1399
+ self.data_gen_obs = jax.tree_util.tree_map(
1400
+ lambda k, pinn_in, val, eq_params: (
1401
+ DataGeneratorObservations(
1402
+ k, self.obs_batch_size, pinn_in, val, eq_params
1403
+ )
1404
+ if pinn_in is not None
1405
+ else None
1406
+ ),
1407
+ keys,
1408
+ self.observed_pinn_in_dict,
1409
+ self.observed_values_dict,
1410
+ self.observed_eq_params_dict,
1411
+ )
1701
1412
 
1702
- def obs_batch(self):
1413
+ def obs_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
1703
1414
  """
1704
1415
  Returns a dictionary of DataGeneratorObservations.obs_batch with keys
1705
1416
  from `u_dict`
1706
1417
  """
1707
- return jax.tree_util.tree_map(
1418
+ data_gen_and_batch_pytree = jax.tree_util.tree_map(
1708
1419
  lambda a: a.get_batch() if a is not None else {},
1709
1420
  self.data_gen_obs,
1710
1421
  is_leaf=lambda x: isinstance(x, DataGeneratorObservations),
1711
1422
  ) # note the is_leaf argument needed to traverse the DataGeneratorObservations and
1712
1423
  # thus to be able to call the method on the element(s) of
1713
1424
  # self.data_gen_obs which are not None
1425
+ new_attribute = jax.tree_util.tree_map(
1426
+ lambda a: a[0],
1427
+ data_gen_and_batch_pytree,
1428
+ is_leaf=lambda x: isinstance(x, tuple),
1429
+ )
1430
+ new = eqx.tree_at(lambda m: m.data_gen_obs, self, new_attribute)
1431
+ batches = jax.tree_util.tree_map(
1432
+ lambda a: a[1],
1433
+ data_gen_and_batch_pytree,
1434
+ is_leaf=lambda x: isinstance(x, tuple),
1435
+ )
1714
1436
 
1715
- def get_batch(self):
1437
+ return new, batches
1438
+
1439
+ def get_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
1716
1440
  """
1717
1441
  Generic method to return a batch
1718
1442
  """
1719
1443
  return self.obs_batch()
1720
-
1721
- def tree_flatten(self):
1722
- # because a dict with "str" keys cannot go in the children (jittable)
1723
- # attributes, we need to separate it in two and recreate the zip in the
1724
- # tree_unflatten
1725
- children = self.data_gen_obs.values()
1726
- aux_data = {
1727
- "obs_batch_size": self.obs_batch_size,
1728
- "data_gen_obs_keys": self.data_gen_obs.keys(),
1729
- }
1730
- return (children, aux_data)
1731
-
1732
- @classmethod
1733
- def tree_unflatten(cls, aux_data, children):
1734
- (data_gen_obs_values) = children
1735
- obj = cls(
1736
- observed_pinn_in_dict=None,
1737
- observed_values_dict=None,
1738
- data_gen_obs_exists=True,
1739
- obs_batch_size=aux_data["obs_batch_size"],
1740
- )
1741
- obj.data_gen_obs = dict(zip(aux_data["data_gen_obs_keys"], data_gen_obs_values))
1742
- return obj
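Finally, a hedged sketch for the multi-PINN observation loader whose new definition appears above; the `u_dict` entries "u" and "v" and the import path are assumptions, and "v" illustrates the documented None convention for a PINN without observations:

# Hedged sketch: observations for PINN "u" only; "v" must still appear with value None.
import jax
import jax.numpy as jnp
from jinns.data import DataGeneratorObservationsMultiPINNs  # assumed export

t_obs = jnp.linspace(0.0, 1.0, 50)[:, None]

multi_dg = DataGeneratorObservationsMultiPINNs(
    obs_batch_size=16,
    observed_pinn_in_dict={"u": t_obs, "v": None},
    observed_values_dict={"u": jnp.cos(t_obs), "v": None},
    key=jax.random.PRNGKey(2),
)
multi_dg, batch_dict = multi_dg.get_batch()
# batch_dict["u"] holds the "pinn_in"/"val"/"eq_params" mini-batch for PINN "u"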