jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,75 +1,93 @@
+ # pylint: disable=unsubscriptable-object
  """
- DataGenerators to generate batches of points in space, time and more
+ Define the DataGeneratorODE equinox module
  """
-
- from typing import NamedTuple
- from jax.typing import ArrayLike
+ from __future__ import (
+     annotations,
+ )  # https://docs.python.org/3/library/typing.html#constant
+
+ from typing import TYPE_CHECKING, Dict
+ from dataclasses import InitVar
+ import equinox as eqx
+ import jax
  import jax.numpy as jnp
- from jax import random
- from jax.tree_util import register_pytree_node_class
- import jax.lax
-
-
- class ODEBatch(NamedTuple):
-     temporal_batch: ArrayLike
-     param_batch_dict: dict = None
-     obs_batch_dict: dict = None
-
+ from jaxtyping import Key, Int, PyTree, Array, Float, Bool
+ from jinns.data._Batchs import *

- class PDENonStatioBatch(NamedTuple):
-     inside_batch: ArrayLike
-     border_batch: ArrayLike
-     temporal_batch: ArrayLike
-     param_batch_dict: dict = None
-     obs_batch_dict: dict = None
+ if TYPE_CHECKING:
+     from jinns.utils._types import *


- class PDEStatioBatch(NamedTuple):
-     inside_batch: ArrayLike
-     border_batch: ArrayLike
-     param_batch_dict: dict = None
-     obs_batch_dict: dict = None
-
-
- def append_param_batch(batch, param_batch_dict):
+ def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
      """
      Utility function that fills the param_batch_dict of a batch object with a
      param_batch_dict
      """
-     return batch._replace(param_batch_dict=param_batch_dict)
+     return eqx.tree_at(
+         lambda m: m.param_batch_dict,
+         batch,
+         param_batch_dict,
+         is_leaf=lambda x: x is None,
+     )


- def append_obs_batch(batch, obs_batch_dict):
+ def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
      """
      Utility function that fills the obs_batch_dict of a batch object with an
      obs_batch_dict
      """
-     return batch._replace(obs_batch_dict=obs_batch_dict)
+     return eqx.tree_at(
+         lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
+     )
+
+
+ def make_cartesian_product(
+     b1: Float[Array, "batch_size dim1"], b2: Float[Array, "batch_size dim2"]
+ ) -> Float[Array, "(batch_size*batch_size) (dim1+dim2)"]:
+     """
+     Create the cartesian product of a time batch and an omega border batch
+     by tiling and repeating
+     """
+     n1 = b1.shape[0]
+     n2 = b2.shape[0]
+     b1 = jnp.repeat(b1, n2, axis=0)
+     b2 = jnp.tile(b2, reps=(n1,) + tuple(1 for i in b2.shape[1:]))
+     return jnp.concatenate([b1, b2], axis=1)

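As a shape check for `make_cartesian_product` above: repeating the first batch and tiling the second pairs every time point with every space point. A minimal sketch that mirrors the function body (toy sizes, values arbitrary):

```python
import jax.numpy as jnp

t = jnp.arange(3.0).reshape(3, 1)    # 3 time points, shape (3, 1)
x = jnp.arange(8.0).reshape(4, 2)    # 4 space points in 2-D, shape (4, 2)

n1, n2 = t.shape[0], x.shape[0]
t_rep = jnp.repeat(t, n2, axis=0)                       # (12, 1): each time repeated 4 times
x_til = jnp.tile(x, reps=(n1,) + (1,) * (x.ndim - 1))   # (12, 2): the 4 points cycled 3 times
pairs = jnp.concatenate([t_rep, x_til], axis=1)
print(pairs.shape)                                      # (12, 3) == (3 * 4, 1 + 2)
```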
- def _reset_batch_idx_and_permute(operands):
+ def _reset_batch_idx_and_permute(
+     operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
+ ) -> tuple[Key, Float[Array, "n dimension"], Int]:
      key, domain, curr_idx, _, p = operands
      # resetting counter
      curr_idx = 0
      # reshuffling
-     key, subkey = random.split(key)
+     key, subkey = jax.random.split(key)
      # domain = random.permutation(subkey, domain, axis=0, independent=False)
      # we want that permutation = choice when p=None
      # otherwise p is used to avoid collocation points not in nt_start
-     domain = random.choice(subkey, domain, shape=(domain.shape[0],), replace=False, p=p)
+     domain = jax.random.choice(
+         subkey, domain, shape=(domain.shape[0],), replace=False, p=p
+     )

      # return updated
      return (key, domain, curr_idx)


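The comment in `_reset_batch_idx_and_permute` relies on `jax.random.choice` with `replace=False` acting as a permutation that respects `p`: points given zero probability (RAR collocation points not yet activated) end up at the tail of the returned array, so the leading entries are a shuffle of the active points only. A small sketch of that behaviour, with an arbitrary key and current JAX semantics assumed:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
domain = jnp.arange(10.0)                 # 10 allocated collocation points
p = jnp.zeros((10,)).at[:6].set(1 / 6)    # only the first 6 are active (RAR-style)

shuffled = jax.random.choice(
    key, domain, shape=(domain.shape[0],), replace=False, p=p
)
# The first 6 entries are a permutation of domain[:6]; the zero-probability
# points 6..9 are pushed to the end, where the batching never slices.
print(shuffled)
```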
- def _increment_batch_idx(operands):
77
+ def _increment_batch_idx(
78
+ operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
79
+ ) -> tuple[Key, Float[Array, "n dimension"], Int]:
66
80
  key, domain, curr_idx, batch_size, _ = operands
67
81
  # simply increases counter and get the batch
68
82
  curr_idx += batch_size
69
83
  return (key, domain, curr_idx)
70
84
 
71
85
 
72
- def _reset_or_increment(bend, n_eff, operands):
86
+ def _reset_or_increment(
87
+ bend: Int,
88
+ n_eff: Int,
89
+ operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
90
+ ) -> tuple[Key, Float[Array, "n dimension"], Int]:
73
91
  """
74
92
  Factorize the code of the jax.lax.cond which checks if we have seen all the
75
93
  batches in an epoch
@@ -98,15 +116,18 @@ def _reset_or_increment(bend, n_eff, operands):
98
116
  )
99
117
 
100
118
 
101
- def _check_and_set_rar_parameters(rar_parameters, n, n_start):
119
+ def _check_and_set_rar_parameters(
120
+ rar_parameters: dict, n: Int, n_start: Int
121
+ ) -> tuple[Int, Float[Array, "n"], Int, Int]:
102
122
  if rar_parameters is not None and n_start is None:
103
123
  raise ValueError(
104
- f"n_start or/and nt_start must be provided in the context of RAR sampling scheme, {n_start} was provided"
124
+ "nt_start must be provided in the context of RAR sampling scheme"
105
125
  )
126
+
106
127
  if rar_parameters is not None:
107
128
  # Default p is None. However, in the RAR sampling scheme we use 0
108
129
  # probability to specify non-used collocation points (i.e. points
109
- # above n_start). Thus, p is a vector of probability of shape (n, 1).
130
+ # above nt_start). Thus, p is a vector of probability of shape (nt, 1).
110
131
  p = jnp.zeros((n,))
111
132
  p = p.at[:n_start].set(1 / n_start)
112
133
  # set internal counter for the number of gradient steps since the
@@ -118,6 +139,7 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
118
139
  # have been added
119
140
  rar_iter_nb = 0
120
141
  else:
142
+ n_start = n
121
143
  p = None
122
144
  rar_iter_from_last_sampling = None
123
145
  rar_iter_nb = None
@@ -125,109 +147,102 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
125
147
  return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
126
148
 
127
149
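To make the RAR bookkeeping in `_check_and_set_rar_parameters` concrete, here is a small illustrative sketch (sizes chosen arbitrarily) of the probability vector it builds and of the effective number of collocation points that grows as RAR activates new ones:

```python
import jax.numpy as jnp

n, n_start = 8, 4                                   # allocated points vs. points used at start
p = jnp.zeros((n,)).at[:n_start].set(1 / n_start)
print(p)                                            # [0.25 0.25 0.25 0.25 0.   0.   0.   0.  ]

# Each RAR update activates `selected_sample_size` more points, so the
# batching code later computes an effective size such as:
selected_sample_size, rar_iter_nb = 2, 1
n_eff = n_start + rar_iter_nb * selected_sample_size
print(n_eff)                                        # 6
```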
 
128
- #####################################################
129
- # DataGenerator for ODE : only returns time_batches
130
- #####################################################
131
-
132
-
133
- @register_pytree_node_class
134
- class DataGeneratorODE:
150
+ class DataGeneratorODE(eqx.Module):
135
151
  """
136
152
  A class implementing data generator object for ordinary differential equations.
137
153
 
138
-
139
- **Note:** DataGeneratorODE is jittable. Hence it implements the tree_flatten() and
140
- tree_unflatten methods.
154
+ Parameters
155
+ ----------
156
+ key : Key
157
+ Jax random key to sample new time points and to shuffle batches
158
+ nt : Int
159
+ The number of total time points that will be divided in
160
+ batches. Batches are made so that each data point is seen only
161
+ once during 1 epoch.
162
+ tmin : float
163
+ The minimum value of the time domain to consider
164
+ tmax : float
165
+ The maximum value of the time domain to consider
166
+ temporal_batch_size : int
167
+ The size of the batch of randomly selected points among
168
+ the `nt` points.
169
+ method : str, default="uniform"
170
+ Either `grid` or `uniform`, default is `uniform`.
171
+ The method that generates the `nt` time points. `grid` means
172
+ regularly spaced points over the domain. `uniform` means uniformly
173
+ sampled points over the domain
174
+ rar_parameters : Dict[str, Int], default=None
175
+ Defaults to None: do not use Residual Adaptive Resampling.
176
+ Otherwise a dictionary with keys. `start_iter`: the iteration at
177
+ which we start the RAR sampling scheme (we first have a burn in
178
+ period). `update_rate`: the number of gradient steps taken between
179
+ each appending of collocation points in the RAR algo.
180
+ `sample_size`: the size of the sample from which we will select new
181
+ collocation points. `selected_sample_size_times`: the number of selected
182
+ points from the sample to be added to the current collocation
183
+ points
184
+ nt_start : Int, default=None
185
+ Defaults to None. The effective size of nt used at start time.
186
+ This value must be
187
+ provided when rar_parameters is not None. Otherwise we set internally
188
+ nt_start = nt and this is hidden from the user.
189
+ In RAR, nt_start
190
+ then corresponds to the initial number of points we train the PINN.
141
191
  """
142
192
 
143
- def __init__(
144
- self,
145
- key,
146
- nt,
147
- tmin,
148
- tmax,
149
- temporal_batch_size,
150
- method="uniform",
151
- rar_parameters=None,
152
- nt_start=None,
153
- data_exists=False,
154
- ):
155
- """
156
- Parameters
157
- ----------
158
- key
159
- Jax random key to sample new time points and to shuffle batches
160
- nt
161
- An integer. The number of total time points that will be divided in
162
- batches. Batches are made so that each data point is seen only
163
- once during 1 epoch.
164
- tmin
165
- A float. The minimum value of the time domain to consider
166
- tmax
167
- A float. The maximum value of the time domain to consider
168
- temporal_batch_size
169
- An integer. The size of the batch of randomly selected points among
170
- the `nt` points.
171
- method
172
- Either `grid` or `uniform`, default is `uniform`.
173
- The method that generates the `nt` time points. `grid` means
174
- regularly spaced points over the domain. `uniform` means uniformly
175
- sampled points over the domain
176
- rar_parameters
177
- Default to None: do not use Residual Adaptative Resampling.
178
- Otherwise a dictionary with keys. `start_iter`: the iteration at
179
- which we start the RAR sampling scheme (we first have a burn in
180
- period). `update_every`: the number of gradient steps taken between
181
- each appending of collocation points in the RAR algo.
182
- `sample_size_times`: the size of the sample from which we will select new
183
- collocation points. `selected_sample_size_times`: the number of selected
184
- points from the sample to be added to the current collocation
185
- points
186
- "DeepXDE: A deep learning library for solving differential
187
- equations", L. Lu, SIAM Review, 2021
188
- nt_start
189
- Defaults to None. The effective size of nt used at start time.
190
- This value must be
191
- provided when rar_parameters is not None. Otherwise we set internally
192
- nt_start = nt and this is hidden from the user.
193
- In RAR, nt_start
194
- then corresponds to the initial number of points we train the PINN.
195
- data_exists
196
- Must be left to `False` when created by the user. Avoids the
197
- regeneration of the `nt` time points at each pytree flattening and
198
- unflattening.
199
- """
200
- self.data_exists = data_exists
201
- self._key = key
202
- self.nt = nt
203
- self.tmin = tmin
204
- self.tmax = tmax
205
- self.temporal_batch_size = temporal_batch_size
206
- self.method = method
207
- self.rar_parameters = rar_parameters
208
-
209
- # Set-up for RAR (if used)
193
+ key: Key
194
+ nt: Int
195
+ tmin: Float
196
+ tmax: Float
197
+ temporal_batch_size: Int = eqx.field(static=True) # static cause used as a
198
+ # shape in jax.lax.dynamic_slice
199
+ method: str = eqx.field(static=True, default_factory=lambda: "uniform")
200
+ rar_parameters: Dict[str, Int] = None
201
+ nt_start: Int = eqx.field(static=True, default=None)
202
+
203
+ # all the init=False fields are set in __post_init__, even after a _replace
204
+ # or eqx.tree_at __post_init__ is called
205
+ p_times: Float[Array, "nt"] = eqx.field(init=False)
206
+ rar_iter_from_last_sampling: Int = eqx.field(init=False)
207
+ rar_iter_nb: Int = eqx.field(init=False)
208
+ curr_time_idx: Int = eqx.field(init=False)
209
+ times: Float[Array, "nt"] = eqx.field(init=False)
210
+
211
+ def __post_init__(self):
210
212
  (
211
213
  self.nt_start,
212
214
  self.p_times,
213
215
  self.rar_iter_from_last_sampling,
214
216
  self.rar_iter_nb,
215
- ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
216
-
217
- if not self.data_exists:
218
- # Useful when using a lax.scan with pytree
219
- # Optionally can tell JAX not to re-generate data
220
- self.curr_time_idx = 0
221
- self.generate_time_data()
222
- self._key, self.times, _ = _reset_batch_idx_and_permute(
223
- self._get_time_operands()
224
- )
225
-
226
- def sample_in_time_domain(self, n_samples):
227
- self._key, subkey = random.split(self._key, 2)
228
- return random.uniform(subkey, (n_samples,), minval=self.tmin, maxval=self.tmax)
217
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
218
+
219
+ self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
220
+ # to be sure there is a
221
+ # shuffling at first get_batch() we do not call
222
+ # _reset_batch_idx_and_permute in __init__ or __post_init__ because it
223
+ # would return a copy of self and we have not investigated what would
224
+ # happen
225
+ # NOTE the (- self.temporal_batch_size - 1) because otherwise when computing
226
+ # `bend` we overflow the max int32 with unwanted behaviour
227
+
228
+ self.key, self.times = self.generate_time_data(self.key)
229
+ # Note that, here, in __init__ (and __post_init__), this is the
230
+ # only place where self assignments are authorized so we do it the
231
+ # above way for the key. Note that one of the motivations to return the
232
+ # key from generate_*_data is to easily align key with legacy
233
+ # DataGenerators to use same unit tests
234
+
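The initial `curr_time_idx` value above is a small trick worth spelling out: starting the counter just below the int32 maximum guarantees that the very first `get_batch()` exceeds any realistic `nt` and therefore takes the reshuffle branch of `_reset_or_increment`, while leaving enough headroom that adding `temporal_batch_size` does not overflow. A quick check of that arithmetic:

```python
import jax.numpy as jnp

temporal_batch_size, nt = 32, 1000
curr_time_idx = jnp.iinfo(jnp.int32).max - temporal_batch_size - 1

bend = curr_time_idx + temporal_batch_size   # still representable in int32 thanks to the -1
print(bend > nt)                             # True -> first call resets the index and permutes
```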
235
+ def sample_in_time_domain(
236
+ self, key: Key, sample_size: Int = None
237
+ ) -> Float[Array, "nt"]:
238
+ return jax.random.uniform(
239
+ key,
240
+ (self.nt if sample_size is None else sample_size,),
241
+ minval=self.tmin,
242
+ maxval=self.tmax,
243
+ )
229
244
 
230
- def generate_time_data(self):
245
+ def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
231
246
  """
232
247
  Construct a complete set of `self.nt` time points according to the
233
248
  specified `self.method`
@@ -235,24 +250,28 @@ class DataGeneratorODE:
235
250
  Note that self.times has always size self.nt and not self.nt_start, even
236
251
  in RAR scheme, we must allocate all the collocation points
237
252
  """
253
+ key, subkey = jax.random.split(self.key)
238
254
  if self.method == "grid":
239
- self.partial_times = (self.tmax - self.tmin) / self.nt
240
- self.times = jnp.arange(self.tmin, self.tmax, self.partial_times)
241
- elif self.method == "uniform":
242
- self.times = self.sample_in_time_domain(self.nt)
243
- else:
244
- raise ValueError("Method " + self.method + " is not implemented.")
255
+ partial_times = (self.tmax - self.tmin) / self.nt
256
+ return key, jnp.arange(self.tmin, self.tmax, partial_times)
257
+ if self.method == "uniform":
258
+ return key, self.sample_in_time_domain(subkey)
259
+ raise ValueError("Method " + self.method + " is not implemented.")
245
260
 
246
- def _get_time_operands(self):
261
+ def _get_time_operands(
262
+ self,
263
+ ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
247
264
  return (
248
- self._key,
265
+ self.key,
249
266
  self.times,
250
267
  self.curr_time_idx,
251
268
  self.temporal_batch_size,
252
269
  self.p_times,
253
270
  )
254
271
 
255
- def temporal_batch(self):
272
+ def temporal_batch(
273
+ self,
274
+ ) -> tuple["DataGeneratorODE", Float[Array, "temporal_batch_size"]]:
256
275
  """
257
276
  Return a batch of time points. If all the batches have been seen, we
258
277
  reshuffle them, otherwise we just return the next unseen batch.
@@ -264,210 +283,142 @@ class DataGeneratorODE:
264
283
  if self.rar_parameters is not None:
265
284
  nt_eff = (
266
285
  self.nt_start
267
- + self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
286
+ + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
268
287
  )
269
288
  else:
270
289
  nt_eff = self.nt
271
- (self._key, self.times, self.curr_time_idx) = _reset_or_increment(
272
- bend, nt_eff, self._get_time_operands()
290
+
291
+ new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
292
+ new = eqx.tree_at(
293
+ lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
273
294
  )
274
295
 
275
296
  # commands below are equivalent to
276
297
  # return self.times[i:(i+t_batch_size)]
277
298
  # start indices can be dynamic but the slice shape is fixed
278
- return jax.lax.dynamic_slice(
279
- self.times,
280
- start_indices=(self.curr_time_idx,),
281
- slice_sizes=(self.temporal_batch_size,),
299
+ return new, jax.lax.dynamic_slice(
300
+ new.times,
301
+ start_indices=(new.curr_time_idx,),
302
+ slice_sizes=(new.temporal_batch_size,),
282
303
  )
283
304
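The reason `temporal_batch_size` is declared as a static field shows up right here: `jax.lax.dynamic_slice` accepts a traced start index but needs the slice length as a compile-time constant. A minimal illustration under `jit` (toy array, assumed sizes):

```python
import jax
import jax.numpy as jnp

times = jnp.arange(10.0)

@jax.jit
def take_batch(times, start):
    # `start` may be a traced value, but slice_sizes must be a Python int
    return jax.lax.dynamic_slice(times, start_indices=(start,), slice_sizes=(4,))

print(take_batch(times, 2))   # [2. 3. 4. 5.]
```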
 
284
- def get_batch(self):
305
+ def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
285
306
  """
286
307
  Generic method to return a batch. Here we call `self.temporal_batch()`
287
308
  """
288
- return ODEBatch(temporal_batch=self.temporal_batch())
289
-
290
- def tree_flatten(self):
291
- children = (
292
- self._key,
293
- self.times,
294
- self.curr_time_idx,
295
- self.tmin,
296
- self.tmax,
297
- self.p_times,
298
- self.rar_iter_from_last_sampling,
299
- self.rar_iter_nb,
300
- ) # arrays / dynamic values
301
- aux_data = {
302
- k: vars(self)[k]
303
- for k in [
304
- "temporal_batch_size",
305
- "method",
306
- "nt",
307
- "rar_parameters",
308
- "nt_start",
309
- ]
310
- } # static values
311
- return (children, aux_data)
312
-
313
- @classmethod
314
- def tree_unflatten(cls, aux_data, children):
315
- """
316
- **Note:** When reconstructing the class, we force ``data_exists=True``
317
- in order not to re-generate the data at each flattening and
318
- unflattening that happens e.g. during the gradient descent in the
319
- optimization process
320
- """
321
- (
322
- key,
323
- times,
324
- curr_time_idx,
325
- tmin,
326
- tmax,
327
- p_times,
328
- rar_iter_from_last_sampling,
329
- rar_iter_nb,
330
- ) = children
331
- obj = cls(
332
- key=key,
333
- data_exists=True,
334
- tmin=tmin,
335
- tmax=tmax,
336
- **aux_data,
337
- )
338
- obj.times = times
339
- obj.curr_time_idx = curr_time_idx
340
- obj.p_times = p_times
341
- obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
342
- obj.rar_iter_nb = rar_iter_nb
343
- return obj
344
-
345
-
346
- ##########################################
347
- # Data Generator for PDE in stationnary
348
- # and non-stationnary cases
349
- ##########################################
309
+ new, temporal_batch = self.temporal_batch()
310
+ return new, ODEBatch(temporal_batch=temporal_batch)
350
311
 
351
312
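Putting the pieces of `DataGeneratorODE` together, the new equinox-based generator is used functionally: `get_batch()` returns both the updated generator and the batch instead of mutating internal state. A minimal usage sketch, with the constructor arguments taken from the docstring above and the import path assumed from the package layout in this diff:

```python
import jax
from jinns.data import DataGeneratorODE   # assumed export; adjust to the actual package layout

key = jax.random.PRNGKey(0)
data = DataGeneratorODE(key=key, nt=100, tmin=0.0, tmax=1.0, temporal_batch_size=10)

for _ in range(3):
    data, batch = data.get_batch()         # the module is immutable: keep the returned copy
    print(batch.temporal_batch.shape)      # (10,)
```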
 
352
- class DataGeneratorPDEAbstract:
353
- """generic skeleton class for a PDE data generator"""
354
-
355
- def __init__(self, data_exists=False) -> None:
356
- # /!\ WARNING /!\: an-end user should never create an object
357
- # with data_exists=True. Or else generate_data() won't be called.
358
- # Useful when using a lax.scan with a DataGenerator in the carry
359
- # It tells JAX not to re-generate data in the __init__()
360
- self.data_exists = data_exists
361
-
362
-
363
- @register_pytree_node_class
364
- class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
365
- """
313
+ class CubicMeshPDEStatio(eqx.Module):
314
+ r"""
366
315
  A class implementing data generator object for stationary partial
367
316
  differential equations.
368
317
 
369
-
370
- **Note:** CubicMeshPDEStatio is jittable. Hence it implements the tree_flatten() and
371
- tree_unflatten methods.
318
+ Parameters
319
+ ----------
320
+ key : Key
321
+ Jax random key to sample new time points and to shuffle batches
322
+ n : Int
323
+ The number of total $\Omega$ points that will be divided in
324
+ batches. Batches are made so that each data point is seen only
325
+ once during 1 epoch.
326
+ nb : Int | None
327
+ The total number of points in $\partial\Omega$.
328
+ Can be `None` not to lose performance generating the border
329
+ batch if they are not used
330
+ omega_batch_size : Int
331
+ The size of the batch of randomly selected points among
332
+ the `n` points.
333
+ omega_border_batch_size : Int | None
334
+ The size of the batch of points randomly selected
335
+ among the `nb` points.
336
+ Can be `None` not to lose performance generating the border
337
+ batch if they are not used
338
+ dim : Int
339
+ Dimension of $\Omega$ domain
340
+ min_pts : tuple[tuple[Float, Float], ...]
341
+ A tuple of minimum values of the domain along each dimension. For a sampling
342
+ in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
343
+ x_{n, min})$
344
+ max_pts : tuple[tuple[Float, Float], ...]
345
+ A tuple of maximum values of the domain along each dimension. For a sampling
346
+ in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
347
+ x_{n,max})$
348
+ method : str, default="uniform"
349
+ Either `grid` or `uniform`, default is `uniform`.
350
+ The method that generates the `nt` time points. `grid` means
351
+ regularly spaced points over the domain. `uniform` means uniformly
352
+ sampled points over the domain
353
+ rar_parameters : Dict[str, Int], default=None
354
+ Default to None: do not use Residual Adaptative Resampling.
355
+ Otherwise a dictionary with keys. `start_iter`: the iteration at
356
+ which we start the RAR sampling scheme (we first have a burn in
357
+ period). `update_every`: the number of gradient steps taken between
358
+ each appending of collocation points in the RAR algo.
359
+ `sample_size_omega`: the size of the sample from which we will select new
360
+ collocation points. `selected_sample_size_omega`: the number of selected
361
+ points from the sample to be added to the current collocation
362
+ points
363
+ n_start : Int, default=None
364
+ Defaults to None. The effective size of n used at start time.
365
+ This value must be
366
+ provided when rar_parameters is not None. Otherwise we set internally
367
+ n_start = n and this is hidden from the user.
368
+ In RAR, n_start
369
+ then corresponds to the initial number of points we train the PINN.
372
370
  """
373
371
 
374
- def __init__(
375
- self,
376
- key,
377
- n,
378
- nb,
379
- omega_batch_size,
380
- omega_border_batch_size,
381
- dim,
382
- min_pts,
383
- max_pts,
384
- method="grid",
385
- rar_parameters=None,
386
- n_start=None,
387
- data_exists=False,
388
- ):
389
- r"""
390
- Parameters
391
- ----------
392
- key
393
- Jax random key to sample new time points and to shuffle batches
394
- n
395
- An integer. The number of total :math:`\Omega` points that will be divided in
396
- batches. Batches are made so that each data point is seen only
397
- once during 1 epoch.
398
- nb
399
- An integer. The total number of points in :math:`\partial\Omega`.
400
- Can be `None` not to lose performance generating the border
401
- batch if they are not used
402
- omega_batch_size
403
- An integer. The size of the batch of randomly selected points among
404
- the `n` points.
405
- omega_border_batch_size
406
- An integer. The size of the batch of points randomly selected
407
- among the `nb` points.
408
- Can be `None` not to lose performance generating the border
409
- batch if they are not used
410
- dim
411
- An integer. dimension of :math:`\Omega` domain
412
- min_pts
413
- A tuple of minimum values of the domain along each dimension. For a sampling
414
- in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
415
- x_{n, min})`
416
- max_pts
417
- A tuple of maximum values of the domain along each dimension. For a sampling
418
- in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
419
- x_{n,max})`
420
- method
421
- Either `grid` or `uniform`, default is `grid`.
422
- The method that generates the `nt` time points. `grid` means
423
- regularly spaced points over the domain. `uniform` means uniformly
424
- sampled points over the domain
425
- rar_parameters
426
- Default to None: do not use Residual Adaptative Resampling.
427
- Otherwise a dictionary with keys. `start_iter`: the iteration at
428
- which we start the RAR sampling scheme (we first have a burn in
429
- period). `update_every`: the number of gradient steps taken between
430
- each appending of collocation points in the RAR algo.
431
- `sample_size_omega`: the size of the sample from which we will select new
432
- collocation points. `selected_sample_size_omega`: the number of selected
433
- points from the sample to be added to the current collocation
434
- points
435
- "DeepXDE: A deep learning library for solving differential
436
- equations", L. Lu, SIAM Review, 2021
437
- n_start
438
- Defaults to None. The effective size of n used at start time.
439
- This value must be
440
- provided when rar_parameters is not None. Otherwise we set internally
441
- n_start = n and this is hidden from the user.
442
- In RAR, n_start
443
- then corresponds to the initial number of points we train the PINN.
444
- data_exists
445
- Must be left to `False` when created by the user. Avoids the
446
- regeneration of :math:`\Omega`, :math:`\partial\Omega` and
447
- time points at each pytree flattening and unflattening.
448
- """
449
- super().__init__(data_exists=data_exists)
450
- self.method = method
451
- self._key = key
452
- self.dim = dim
453
- self.min_pts = min_pts
454
- self.max_pts = max_pts
455
- assert dim == len(min_pts) and isinstance(min_pts, tuple)
456
- assert dim == len(max_pts) and isinstance(max_pts, tuple)
457
- self.n = n
458
- self.rar_parameters = rar_parameters
372
+ # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
373
+ key: Key = eqx.field(kw_only=True)
374
+ n: Int = eqx.field(kw_only=True)
375
+ nb: Int | None = eqx.field(kw_only=True)
376
+ omega_batch_size: Int = eqx.field(
377
+ kw_only=True, static=True
378
+ ) # static cause used as a
379
+ # shape in jax.lax.dynamic_slice
380
+ omega_border_batch_size: Int | None = eqx.field(
381
+ kw_only=True, static=True
382
+ ) # static cause used as a
383
+ # shape in jax.lax.dynamic_slice
384
+ dim: Int = eqx.field(kw_only=True, static=True) # static cause used as a
385
+ # shape in jax.lax.dynamic_slice
386
+ min_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
387
+ max_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
388
+ method: str = eqx.field(
389
+ kw_only=True, static=True, default_factory=lambda: "uniform"
390
+ )
391
+ rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
392
+ n_start: Int = eqx.field(kw_only=True, default=None, static=True)
393
+
394
+ # all the init=False fields are set in __post_init__, even after a _replace
395
+ # or eqx.tree_at __post_init__ is called
396
+ p_omega: Float[Array, "n"] = eqx.field(init=False)
397
+ p_border: None = eqx.field(init=False)
398
+ rar_iter_from_last_sampling: Int = eqx.field(init=False)
399
+ rar_iter_nb: Int = eqx.field(init=False)
400
+ curr_omega_idx: Int = eqx.field(init=False)
401
+ curr_omega_border_idx: Int = eqx.field(init=False)
402
+ omega: Float[Array, "n dim"] = eqx.field(init=False)
403
+ omega_border: Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None = eqx.field(
404
+ init=False
405
+ )
406
+
407
+ def __post_init__(self):
408
+ assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
409
+ assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
410
+
459
411
  (
460
412
  self.n_start,
461
413
  self.p_omega,
462
414
  self.rar_iter_from_last_sampling,
463
415
  self.rar_iter_nb,
464
- ) = _check_and_set_rar_parameters(rar_parameters, n=n, n_start=n_start)
416
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
465
417
 
466
418
  self.p_border = None # no RAR sampling for border for now
467
419
 
468
- self.omega_batch_size = omega_batch_size
469
-
470
- if omega_border_batch_size is None:
420
+ # Special handling for the border batch
421
+ if self.omega_border_batch_size is None:
471
422
  self.nb = None
472
423
  self.omega_border_batch_size = None
473
424
  elif self.dim == 1:
@@ -476,53 +427,50 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
476
427
  # always set to 2.
477
428
  self.nb = 2
478
429
  self.omega_border_batch_size = 2
479
- # warnings.warn("We are in 1-D case => omega_border_batch_size is "
480
- # "ignored since borders of Omega are singletons."
481
- # " self.border_batch() will return [xmin, xmax]"
482
- # )
430
+ # We are in 1-D case => omega_border_batch_size is
431
+ # ignored since borders of Omega are singletons.
432
+ # self.border_batch() will return [xmin, xmax]
483
433
  else:
484
- if nb % (2 * self.dim) != 0 or nb < 2 * self.dim:
434
+ if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
485
435
  raise ValueError(
486
436
  "number of border point must be"
487
437
  " a multiple of 2xd (the # of faces of a d-dimensional cube)"
488
438
  )
489
- if nb // (2 * self.dim) < omega_border_batch_size:
439
+ if self.nb // (2 * self.dim) < self.omega_border_batch_size:
490
440
  raise ValueError(
491
441
  "number of points per facets (nb//2*self.dim)"
492
442
  " cannot be lower than border batch size"
493
443
  )
494
- self.nb = int((2 * self.dim) * (nb // (2 * self.dim)))
495
- self.omega_border_batch_size = omega_border_batch_size
496
-
497
- if not self.data_exists:
498
- # Useful when using a lax.scan with pytree
499
- # Optionally tells JAX not to re-generate data when re-building the
500
- # object
501
- self.curr_omega_idx = 0
502
- self.curr_omega_border_idx = 0
503
- self.generate_data()
504
- self._key, self.omega, _ = _reset_batch_idx_and_permute(
505
- self._get_omega_operands()
506
- )
507
- if self.omega_border is not None and self.dim > 1:
508
- self._key, self.omega_border, _ = _reset_batch_idx_and_permute(
509
- self._get_omega_border_operands()
510
- )
444
+ self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
511
445
 
512
- def sample_in_omega_domain(self, n_samples):
446
+ self.curr_omega_idx = jnp.iinfo(jnp.int32).max - self.omega_batch_size - 1
447
+ # see explanation in DataGeneratorODE
448
+ if self.omega_border_batch_size is None:
449
+ self.curr_omega_border_idx = None
450
+ else:
451
+ self.curr_omega_border_idx = (
452
+ jnp.iinfo(jnp.int32).max - self.omega_border_batch_size - 1
453
+ )
454
+ # key, subkey = jax.random.split(self.key)
455
+ # self.key = key
456
+ self.key, self.omega, self.omega_border = self.generate_data(self.key)
457
+ # see explanation in DataGeneratorODE for the key
458
+
459
+ def sample_in_omega_domain(
460
+ self, keys: Key, sample_size: Int = None
461
+ ) -> Float[Array, "n dim"]:
462
+ sample_size = self.n if sample_size is None else sample_size
513
463
  if self.dim == 1:
514
464
  xmin, xmax = self.min_pts[0], self.max_pts[0]
515
- self._key, subkey = random.split(self._key, 2)
516
- return random.uniform(
517
- subkey, shape=(n_samples, 1), minval=xmin, maxval=xmax
465
+ return jax.random.uniform(
466
+ keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
518
467
  )
519
- keys = random.split(self._key, self.dim + 1)
520
- self._key = keys[0]
468
+ # keys = jax.random.split(key, self.dim)
521
469
  return jnp.concatenate(
522
470
  [
523
- random.uniform(
524
- keys[i + 1],
525
- (n_samples, 1),
471
+ jax.random.uniform(
472
+ keys[i],
473
+ (sample_size, 1),
526
474
  minval=self.min_pts[i],
527
475
  maxval=self.max_pts[i],
528
476
  )
@@ -531,7 +479,9 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
531
479
  axis=-1,
532
480
  )
533
481
 
534
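A short sketch of the per-dimension scheme used by `sample_in_omega_domain` for `dim > 1` (bounds and sizes are illustrative): one subkey per dimension, one uniform column per dimension, concatenated into an `(n, dim)` array.

```python
import jax
import jax.numpy as jnp

min_pts, max_pts, n = (0.0, -1.0), (2.0, 1.0), 5   # a 2-D box and 5 collocation points
key, *subkeys = jax.random.split(jax.random.PRNGKey(0), 3)

omega = jnp.concatenate(
    [
        jax.random.uniform(subkeys[i], (n, 1), minval=min_pts[i], maxval=max_pts[i])
        for i in range(2)
    ],
    axis=-1,
)
print(omega.shape)   # (5, 2)
```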
- def sample_in_omega_border_domain(self, n_samples):
482
+ def sample_in_omega_border_domain(
483
+ self, keys: Key
484
+ ) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
535
485
  if self.omega_border_batch_size is None:
536
486
  return None
537
487
  if self.dim == 1:
@@ -543,15 +493,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
543
493
  # TODO : find a general & efficient way to sample from the border
544
494
  # (facets) of the hypercube in general dim.
545
495
 
546
- facet_n = n_samples // (2 * self.dim)
547
- keys = random.split(self._key, 5)
548
- self._key = keys[0]
549
- subkeys = keys[1:]
496
+ facet_n = self.nb // (2 * self.dim)
550
497
  xmin = jnp.hstack(
551
498
  [
552
499
  self.min_pts[0] * jnp.ones((facet_n, 1)),
553
- random.uniform(
554
- subkeys[0],
500
+ jax.random.uniform(
501
+ keys[0],
555
502
  (facet_n, 1),
556
503
  minval=self.min_pts[1],
557
504
  maxval=self.max_pts[1],
@@ -561,8 +508,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
561
508
  xmax = jnp.hstack(
562
509
  [
563
510
  self.max_pts[0] * jnp.ones((facet_n, 1)),
564
- random.uniform(
565
- subkeys[1],
511
+ jax.random.uniform(
512
+ keys[1],
566
513
  (facet_n, 1),
567
514
  minval=self.min_pts[1],
568
515
  maxval=self.max_pts[1],
@@ -571,8 +518,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
571
518
  )
572
519
  ymin = jnp.hstack(
573
520
  [
574
- random.uniform(
575
- subkeys[2],
521
+ jax.random.uniform(
522
+ keys[2],
576
523
  (facet_n, 1),
577
524
  minval=self.min_pts[0],
578
525
  maxval=self.max_pts[0],
@@ -582,8 +529,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
582
529
  )
583
530
  ymax = jnp.hstack(
584
531
  [
585
- random.uniform(
586
- subkeys[3],
532
+ jax.random.uniform(
533
+ keys[3],
587
534
  (facet_n, 1),
588
535
  minval=self.min_pts[0],
589
536
  maxval=self.max_pts[0],
@@ -597,54 +544,71 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
597
544
  + f"implemented yet. You are asking for generation in dimension d={self.dim}."
598
545
  )
599
546
 
600
- def generate_data(self):
547
+ def generate_data(self, key: Key) -> tuple[
548
+ Key,
549
+ Float[Array, "n dim"],
550
+ Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
551
+ ]:
601
552
  r"""
602
- Construct a complete set of `self.n` :math:`\Omega` points according to the
553
+ Construct a complete set of `self.n` $\Omega$ points according to the
603
554
  specified `self.method`. Also constructs a complete set of `self.nb`
604
- :math:`\partial\Omega` points if `self.omega_border_batch_size` is not
555
+ $\partial\Omega$ points if `self.omega_border_batch_size` is not
605
556
  `None`. If the latter is `None` we set `self.omega_border` to `None`.
606
557
  """
607
-
608
558
  # Generate Omega
609
559
  if self.method == "grid":
610
560
  if self.dim == 1:
611
561
  xmin, xmax = self.min_pts[0], self.max_pts[0]
612
- self.partial = (xmax - xmin) / self.n
562
+ partial = (xmax - xmin) / self.n
613
563
  # shape (n, 1)
614
- self.omega = jnp.arange(xmin, xmax, self.partial)[:, None]
564
+ omega = jnp.arange(xmin, xmax, partial)[:, None]
615
565
  else:
616
- self.partials = [
566
+ partials = [
617
567
  (self.max_pts[i] - self.min_pts[i]) / jnp.sqrt(self.n)
618
568
  for i in range(self.dim)
619
569
  ]
620
570
  xyz_ = jnp.meshgrid(
621
571
  *[
622
- jnp.arange(self.min_pts[i], self.max_pts[i], self.partials[i])
572
+ jnp.arange(self.min_pts[i], self.max_pts[i], partials[i])
623
573
  for i in range(self.dim)
624
574
  ]
625
575
  )
626
576
  xyz_ = [a.reshape((self.n, 1)) for a in xyz_]
627
- self.omega = jnp.concatenate(xyz_, axis=-1)
577
+ omega = jnp.concatenate(xyz_, axis=-1)
628
578
  elif self.method == "uniform":
629
- self.omega = self.sample_in_omega_domain(self.n)
579
+ if self.dim == 1:
580
+ key, subkeys = jax.random.split(key, 2)
581
+ else:
582
+ key, *subkeys = jax.random.split(key, self.dim + 1)
583
+ omega = self.sample_in_omega_domain(subkeys)
630
584
  else:
631
585
  raise ValueError("Method " + self.method + " is not implemented.")
632
586
 
633
587
  # Generate border of omega
634
- self.omega_border = self.sample_in_omega_border_domain(self.nb)
588
+ if self.dim == 2 and self.omega_border_batch_size is not None:
589
+ key, *subkeys = jax.random.split(key, 5)
590
+ else:
591
+ subkeys = None
592
+ omega_border = self.sample_in_omega_border_domain(subkeys)
593
+
594
+ return key, omega, omega_border
635
595
 
636
- def _get_omega_operands(self):
596
+ def _get_omega_operands(
597
+ self,
598
+ ) -> tuple[Key, Float[Array, "n dim"], Int, Int, Float[Array, "n"]]:
637
599
  return (
638
- self._key,
600
+ self.key,
639
601
  self.omega,
640
602
  self.curr_omega_idx,
641
603
  self.omega_batch_size,
642
604
  self.p_omega,
643
605
  )
644
606
 
645
- def inside_batch(self):
607
+ def inside_batch(
608
+ self,
609
+ ) -> tuple["CubicMeshPDEStatio", Float[Array, "omega_batch_size dim"]]:
646
610
  r"""
647
- Return a batch of points in :math:`\Omega`.
611
+ Return a batch of points in $\Omega$.
648
612
  If all the batches have been seen, we reshuffle them,
649
613
  otherwise we just return the next unseen batch.
650
614
  """
@@ -660,38 +624,46 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
660
624
  bstart = self.curr_omega_idx
661
625
  bend = bstart + self.omega_batch_size
662
626
 
663
- (self._key, self.omega, self.curr_omega_idx) = _reset_or_increment(
664
- bend, n_eff, self._get_omega_operands()
627
+ new_attributes = _reset_or_increment(bend, n_eff, self._get_omega_operands())
628
+ new = eqx.tree_at(
629
+ lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
665
630
  )
666
631
 
667
- # commands below are equivalent to
668
- # return self.omega[i:(i+batch_size), 0:dim]
669
- return jax.lax.dynamic_slice(
670
- self.omega,
671
- start_indices=(self.curr_omega_idx, 0),
672
- slice_sizes=(self.omega_batch_size, self.dim),
632
+ return new, jax.lax.dynamic_slice(
633
+ new.omega,
634
+ start_indices=(new.curr_omega_idx, 0),
635
+ slice_sizes=(new.omega_batch_size, new.dim),
673
636
  )
674
637
 
675
- def _get_omega_border_operands(self):
638
+ def _get_omega_border_operands(
639
+ self,
640
+ ) -> tuple[
641
+ Key, Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None, Int, Int, None
642
+ ]:
676
643
  return (
677
- self._key,
644
+ self.key,
678
645
  self.omega_border,
679
646
  self.curr_omega_border_idx,
680
647
  self.omega_border_batch_size,
681
648
  self.p_border,
682
649
  )
683
650
 
684
- def border_batch(self):
651
+ def border_batch(
652
+ self,
653
+ ) -> tuple[
654
+ "CubicMeshPDEStatio",
655
+ Float[Array, "1 1 2"] | Float[Array, "omega_border_batch_size 2 4"] | None,
656
+ ]:
685
657
  r"""
686
658
  Return
687
659
 
688
660
  - The value `None` if `self.omega_border_batch_size` is `None`.
689
661
 
690
- - a jnp array with two fixed values :math:`(x_{min}, x_{max})` if
662
+ - a jnp array with two fixed values $(x_{min}, x_{max})$ if
691
663
  `self.dim` = 1. There is no sampling here, we return the entire
692
- :math:`\partial\Omega`
664
+ $\partial\Omega$
693
665
 
694
- - a batch of points in :math:`\partial\Omega` otherwise, stacked by
666
+ - a batch of points in $\partial\Omega$ otherwise, stacked by
695
667
  facet on the last axis.
696
668
  If all the batches have been seen, we reshuffle them,
697
669
  otherwise we just return the next unseen batch.
@@ -699,229 +671,160 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
699
671
 
700
672
  """
701
673
  if self.omega_border_batch_size is None:
702
- return None
674
+ return self, None
703
675
  if self.dim == 1:
704
676
  # 1-D case, no randomness : we always return the whole omega border,
705
677
  # i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
706
- return self.omega_border[None, None] # shape is (1, 1, 2)
678
+ return self, self.omega_border[None, None] # shape is (1, 1, 2)
707
679
  bstart = self.curr_omega_border_idx
708
680
  bend = bstart + self.omega_border_batch_size
709
681
 
710
- (
711
- self._key,
712
- self.omega_border,
713
- self.curr_omega_border_idx,
714
- ) = _reset_or_increment(bend, self.nb, self._get_omega_border_operands())
715
-
716
- # commands below are equivalent to
717
- # return self.omega[i:(i+batch_size), 0:dim, 0:nb_facets]
718
- # and nb_facets = 2 * dimension
719
- # but JAX prefer the latter
720
- return jax.lax.dynamic_slice(
721
- self.omega_border,
722
- start_indices=(self.curr_omega_border_idx, 0, 0),
723
- slice_sizes=(self.omega_border_batch_size, self.dim, 2 * self.dim),
682
+ new_attributes = _reset_or_increment(
683
+ bend, self.nb, self._get_omega_border_operands()
724
684
  )
725
-
726
- def get_batch(self):
727
- """
728
- Generic method to return a batch. Here we call `self.inside_batch()`
729
- and `self.border_batch()`
730
- """
731
- return PDEStatioBatch(
732
- inside_batch=self.inside_batch(), border_batch=self.border_batch()
685
+ new = eqx.tree_at(
686
+ lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
687
+ self,
688
+ new_attributes,
733
689
  )
734
690
 
735
- def tree_flatten(self):
736
- children = (
737
- self._key,
738
- self.omega,
739
- self.omega_border,
740
- self.curr_omega_idx,
741
- self.curr_omega_border_idx,
742
- self.min_pts,
743
- self.max_pts,
744
- self.p_omega,
745
- self.rar_iter_from_last_sampling,
746
- self.rar_iter_nb,
691
+ return new, jax.lax.dynamic_slice(
692
+ new.omega_border,
693
+ start_indices=(new.curr_omega_border_idx, 0, 0),
694
+ slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
747
695
  )
748
- aux_data = {
749
- k: vars(self)[k]
750
- for k in [
751
- "n",
752
- "nb",
753
- "omega_batch_size",
754
- "omega_border_batch_size",
755
- "method",
756
- "dim",
757
- "rar_parameters",
758
- "n_start",
759
- ]
760
- }
761
- return (children, aux_data)
762
696
 
763
- @classmethod
764
- def tree_unflatten(cls, aux_data, children):
697
+ def get_batch(self) -> tuple["CubicMeshPDEStatio", PDEStatioBatch]:
765
698
  """
766
- **Note:** When reconstructing the class, we force ``data_exists=True``
767
- in order not to re-generate the data at each flattening and
768
- unflattening that happens e.g. during the gradient descent in the
769
- optimization process
699
+ Generic method to return a batch. Here we call `self.inside_batch()`
700
+ and `self.border_batch()`
770
701
  """
771
- (
772
- key,
773
- omega,
774
- omega_border,
775
- curr_omega_idx,
776
- curr_omega_border_idx,
777
- min_pts,
778
- max_pts,
779
- p_omega,
780
- rar_iter_from_last_sampling,
781
- rar_iter_nb,
782
- ) = children
783
- # force data_exists=True here in order not to re-generate the data
784
- # at each iteration of lax.scan
785
- obj = cls(
786
- key=key,
787
- data_exists=True,
788
- min_pts=min_pts,
789
- max_pts=max_pts,
790
- **aux_data,
791
- )
792
- obj.omega = omega
793
- obj.omega_border = omega_border
794
- obj.curr_omega_idx = curr_omega_idx
795
- obj.curr_omega_border_idx = curr_omega_border_idx
796
- obj.p_omega = p_omega
797
- obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
798
- obj.rar_iter_nb = rar_iter_nb
799
- return obj
702
+ new, inside_batch = self.inside_batch()
703
+ new, border_batch = new.border_batch()
704
+ return new, PDEStatioBatch(inside_batch=inside_batch, border_batch=border_batch)
800
705
 
801
706
 
802
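All the batching methods of these generators follow the same functional pattern: compute the new `(key, data, index)` triple, write it back with `eqx.tree_at`, and return the updated module together with the batch. A toy sketch of that update idiom (the `Counter` module and its fields are purely illustrative):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Counter(eqx.Module):
    key: jax.Array
    data: jax.Array
    idx: int

c = Counter(key=jax.random.PRNGKey(0), data=jnp.arange(5.0), idx=0)

# Replace several leaves at once; a *new* module is returned, `c` is untouched.
new_key, _ = jax.random.split(c.key)
new_c = eqx.tree_at(
    lambda m: (m.key, m.data, m.idx), c, (new_key, c.data[::-1], c.idx + 1)
)
print(new_c.idx, c.idx)   # 1 0
```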
- @register_pytree_node_class
803
707
  class CubicMeshPDENonStatio(CubicMeshPDEStatio):
804
- """
708
+ r"""
805
709
  A class implementing data generator object for non stationary partial
806
710
  differential equations. Formally, it extends `CubicMeshPDEStatio`
807
711
  to include a temporal batch.
808
712
 
809
-
810
- **Note:** CubicMeshPDENonStatio is jittable. Hence it implements the tree_flatten() and
811
- tree_unflatten methods.
713
+ Parameters
714
+ ----------
715
+ key : Key
716
+ Jax random key to sample new time points and to shuffle batches
717
+ n : Int
718
+ The number of total $\Omega$ points that will be divided in
719
+ batches. Batches are made so that each data point is seen only
720
+ once during 1 epoch.
721
+ nb : Int | None
722
+ The total number of points in $\partial\Omega$.
723
+ Can be `None` not to lose performance generating the border
724
+ batch if they are not used
725
+ nt : Int
726
+ The number of total time points that will be divided in
727
+ batches. Batches are made so that each data point is seen only
728
+ once during 1 epoch.
729
+ omega_batch_size : Int
730
+ The size of the batch of randomly selected points among
731
+ the `n` points.
732
+ omega_border_batch_size : Int | None
733
+ The size of the batch of points randomly selected
734
+ among the `nb` points.
735
+ Can be `None` not to lose performance generating the border
736
+ batch if they are not used
737
+ temporal_batch_size : Int
738
+ The size of the batch of randomly selected points among
739
+ the `nt` points.
740
+ dim : Int
741
+ An integer. dimension of $\Omega$ domain
742
+ min_pts : tuple[tuple[Float, Float], ...]
743
+ A tuple of minimum values of the domain along each dimension. For a sampling
744
+ in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
745
+ x_{n, min})$
746
+ max_pts : tuple[tuple[Float, Float], ...]
747
+ A tuple of maximum values of the domain along each dimension. For a sampling
748
+ in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
749
+ x_{n,max})$
750
+ tmin : float
751
+ The minimum value of the time domain to consider
752
+ tmax : float
753
+ The maximum value of the time domain to consider
754
+ method : str, default="uniform"
755
+ Either `grid` or `uniform`, default is `uniform`.
756
+ The method that generates the `nt` time points. `grid` means
757
+ regularly spaced points over the domain. `uniform` means uniformly
758
+ sampled points over the domain
759
+ rar_parameters : Dict[str, Int], default=None
760
+ Default to None: do not use Residual Adaptative Resampling.
761
+ Otherwise a dictionary with keys. `start_iter`: the iteration at
762
+ which we start the RAR sampling scheme (we first have a burn in
763
+ period). `update_every`: the number of gradient steps taken between
764
+ each appending of collocation points in the RAR algo.
765
+ `sample_size_omega`: the size of the sample from which we will select new
766
+ collocation points. `selected_sample_size_omega`: the number of selected
767
+ points from the sample to be added to the current collocation
768
+ points.
769
+ n_start : Int, default=None
770
+ Defaults to None. The effective size of n used at start time.
771
+ This value must be
772
+ provided when rar_parameters is not None. Otherwise we set internally
773
+ n_start = n and this is hidden from the user.
774
+ In RAR, n_start
775
+ then corresponds to the initial number of omega points we train the PINN.
776
+ nt_start : Int, default=None
777
+ Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
778
+ for times collocation point. See also ``DataGeneratorODE``
779
+ documentation.
780
+ cartesian_product : Bool, default=True
781
+ Defaults to True. Whether we return the cartesian product of the
782
+ temporal batch with the inside and border batches. If False we just
783
+ return their concatenation.
812
784
  """
813
785
 
814
- def __init__(
815
- self,
816
- key,
817
- n,
818
- nb,
819
- nt,
820
- omega_batch_size,
821
- omega_border_batch_size,
822
- temporal_batch_size,
823
- dim,
824
- min_pts,
825
- max_pts,
826
- tmin,
827
- tmax,
828
- method="grid",
829
- rar_parameters=None,
830
- n_start=None,
831
- nt_start=None,
832
- data_exists=False,
833
- ):
834
- r"""
835
- Parameters
836
- ----------
837
- key
838
- Jax random key to sample new time points and to shuffle batches
839
- n
840
- An integer. The number of total :math:`\Omega` points that will be divided in
841
- batches. Batches are made so that each data point is seen only
842
- once during 1 epoch.
843
- nb
844
- An integer. The total number of points in :math:`\partial\Omega`.
845
- Can be `None` not to lose performance generating the border
846
- batch if they are not used
847
- nt
848
- An integer. The number of total time points that will be divided in
849
- batches. Batches are made so that each data point is seen only
850
- once during 1 epoch.
851
- omega_batch_size
852
- An integer. The size of the batch of randomly selected points among
853
- the `n` points.
854
- omega_border_batch_size
855
- An integer. The size of the batch of points randomly selected
856
- among the `nb` points.
857
- Can be `None` not to lose performance generating the border
858
- batch if they are not used
859
- temporal_batch_size
860
- An integer. The size of the batch of randomly selected points among
861
- the `nt` points.
862
- dim
863
- An integer. dimension of :math:`\Omega` domain
864
- min_pts
865
- A tuple of minimum values of the domain along each dimension. For a sampling
866
- in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
867
- x_{n, min})`
868
- max_pts
869
- A tuple of maximum values of the domain along each dimension. For a sampling
870
- in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
871
- x_{n,max})`
872
- tmin
873
- A float. The minimum value of the time domain to consider
874
- tmax
875
- A float. The maximum value of the time domain to consider
876
- method
877
- Either `grid` or `uniform`, default is `grid`.
878
- The method that generates the `nt` time points. `grid` means
879
- regularly spaced points over the domain. `uniform` means uniformly
880
- sampled points over the domain
881
- rar_parameters
882
- Default to None: do not use Residual Adaptative Resampling.
883
- Otherwise a dictionary with keys. `start_iter`: the iteration at
884
- which we start the RAR sampling scheme (we first have a burn in
885
- period). `update_every`: the number of gradient steps taken between
886
- each appending of collocation points in the RAR algo.
887
- `sample_size_omega`: the size of the sample from which we will select new
888
- collocation points. `selected_sample_size_omega`: the number of selected
889
- points from the sample to be added to the current collocation
890
- points.
891
- n_start
892
- Defaults to None. The effective size of n used at start time.
893
- This value must be
894
- provided when rar_parameters is not None. Otherwise we set internally
895
- n_start = n and this is hidden from the user.
896
- In RAR, n_start
897
- then corresponds to the initial number of omega points we train the PINN.
898
- nt_start
899
- Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
900
- for times collocation point. See also ``DataGeneratorODE``
901
- documentation.
902
- data_exists
903
- Must be left to `False` when created by the user. Avoids the
904
- regeneration of :math:`\Omega`, :math:`\partial\Omega` and
905
- time points at each pytree flattening and unflattening.
786
+ temporal_batch_size: Int = eqx.field(kw_only=True)
787
+ tmin: Float = eqx.field(kw_only=True)
788
+ tmax: Float = eqx.field(kw_only=True)
789
+ nt: Int = eqx.field(kw_only=True)
790
+ temporal_batch_size: Int = eqx.field(kw_only=True, static=True)
791
+ cartesian_product: Bool = eqx.field(kw_only=True, default=True, static=True)
792
+ nt_start: int = eqx.field(kw_only=True, default=None, static=True)
793
+
794
+ p_times: Array = eqx.field(init=False)
795
+ curr_time_idx: Int = eqx.field(init=False)
796
+ times: Array = eqx.field(init=False)
797
+
798
+ def __post_init__(self):
906
799
  """
907
- super().__init__(
908
- key,
909
- n,
910
- nb,
911
- omega_batch_size,
912
- omega_border_batch_size,
913
- dim,
914
- min_pts,
915
- max_pts,
916
- method,
917
- rar_parameters,
918
- n_start,
919
- data_exists,
920
- )
921
- self.temporal_batch_size = temporal_batch_size
922
- self.tmin = tmin
923
- self.tmax = tmax
924
- self.nt = nt
800
+ Note that neither __init__ nor __post_init__ is called when updating a
801
+ Module with eqx.tree_at!
802
+ """
803
+ super().__post_init__() # because __init__ or __post_init__ of Base
804
+ # class is not automatically called
805
+
806
+ if not self.cartesian_product:
807
+ if self.temporal_batch_size != self.omega_batch_size:
808
+ raise ValueError(
809
+ "If stacking is requested between the time and "
810
+ "inside batches of collocation points, self.temporal_batch_size "
811
+ "must then be equal to self.omega_batch_size"
812
+ )
813
+ if (
814
+ self.dim > 1
815
+ and self.omega_border_batch_size is not None
816
+ and self.temporal_batch_size != self.omega_border_batch_size
817
+ ):
818
+ raise ValueError(
819
+ "If dim > 1 and stacking is requested between the time and "
820
+ "inside batches of collocation points, self.temporal_batch_size "
821
+ "must then be equal to self.omega_border_batch_size"
822
+ )
823
+ # Note if self.dim == 1:
824
+ # print(
825
+ # "Cartesian product is not requested but will be "
826
+ # "executed anyway since dim=1"
827
+ # )
925
828
 
926
829
  # Set-up for timewise RAR (some quantity are already set-up by super())
927
830
  (
@@ -929,46 +832,54 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
929
832
  self.p_times,
930
833
  _,
931
834
  _,
932
- ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
933
-
934
- if not self.data_exists:
935
- # Useful when using a lax.scan with pytree
936
- # Optionally can tell JAX not to re-generate data
937
- self.curr_time_idx = 0
938
- self.generate_data_nonstatio()
939
- self._key, self.times, _ = _reset_batch_idx_and_permute(
940
- self._get_time_operands()
941
- )
942
-
943
- def sample_in_time_domain(self, n_samples):
944
- self._key, subkey = random.split(self._key, 2)
945
- return random.uniform(subkey, (n_samples,), minval=self.tmin, maxval=self.tmax)
835
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
836
+
837
+ self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
838
+ self.key, _ = jax.random.split(self.key, 2) # to make it equivalent to
839
+ # the call to _reset_batch_idx_and_permute in legacy DG
840
+ self.key, self.times = self.generate_time_data(self.key)
841
+ # see explanation in DataGeneratorODE for the key
842
+
843
+ def sample_in_time_domain(
844
+ self, key: Key, sample_size: Int = None
845
+ ) -> Float[Array, "nt"]:
846
+ return jax.random.uniform(
847
+ key,
848
+ (self.nt if sample_size is None else sample_size,),
849
+ minval=self.tmin,
850
+ maxval=self.tmax,
851
+ )
946
852
 
947
- def _get_time_operands(self):
853
+ def _get_time_operands(
854
+ self,
855
+ ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
948
856
  return (
949
- self._key,
857
+ self.key,
950
858
  self.times,
951
859
  self.curr_time_idx,
952
860
  self.temporal_batch_size,
953
861
  self.p_times,
954
862
  )
955
863
 
956
- def generate_data_nonstatio(self):
957
- r"""
864
+ def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
865
+ """
958
866
  Construct a complete set of `self.nt` time points according to the
959
- specified `self.method`. This completes the `super` function
960
- `generate_data()` which generates :math:`\Omega` and
961
- :math:`\partial\Omega` points.
867
+ specified `self.method`
868
+
869
+ Note that self.times always has size self.nt and not self.nt_start; even
870
+ in the RAR scheme we must allocate all the collocation points
962
871
  """
872
+ key, subkey = jax.random.split(key, 2)
963
873
  if self.method == "grid":
964
- self.partial_times = (self.tmax - self.tmin) / self.nt
965
- self.times = jnp.arange(self.tmin, self.tmax, self.partial_times)
966
- elif self.method == "uniform":
967
- self.times = self.sample_in_time_domain(self.nt)
968
- else:
969
- raise ValueError("Method " + self.method + " is not implemented.")
874
+ partial_times = (self.tmax - self.tmin) / self.nt
875
+ return key, jnp.arange(self.tmin, self.tmax, partial_times)
876
+ if self.method == "uniform":
877
+ return key, self.sample_in_time_domain(subkey)
878
+ raise ValueError("Method " + self.method + " is not implemented.")
970
879
 
971
- def temporal_batch(self):
880
+ def temporal_batch(
881
+ self,
882
+ ) -> tuple["CubicMeshPDENonStatio", Float[Array, "temporal_batch_size"]]:
972
883
  """
973
884
  Return a batch of time points. If all the batches have been seen, we
974
885
  reshuffle them; otherwise we just return the next unseen batch.
@@ -979,233 +890,344 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
979
890
  # Compute the effective number of used collocation points
980
891
  if self.rar_parameters is not None:
981
892
  nt_eff = (
982
- self.n_start
893
+ self.nt_start
983
894
  + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
984
895
  )
985
896
  else:
986
897
  nt_eff = self.nt
987
898
 
988
- (self._key, self.times, self.curr_time_idx) = _reset_or_increment(
989
- bend, nt_eff, self._get_time_operands()
899
+ new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
900
+ new = eqx.tree_at(
901
+ lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
990
902
  )
991
903
 
992
- # commands below are equivalent to
993
- # return self.times[i:(i+t_batch_size)]
994
- # but JAX prefer the latter
995
- return jax.lax.dynamic_slice(
996
- self.times,
997
- start_indices=(self.curr_time_idx,),
998
- slice_sizes=(self.temporal_batch_size,),
904
+ return new, jax.lax.dynamic_slice(
905
+ new.times,
906
+ start_indices=(new.curr_time_idx,),
907
+ slice_sizes=(new.temporal_batch_size,),
999
908
  )
1000
909
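As the docstring above says, the generator either reshuffles or advances: `curr_time_idx` moves forward by `temporal_batch_size` at each call, and once the effective number of points `nt_eff` would be exceeded the points are permuted and the index restarts at 0 (initialising `curr_time_idx` near the int32 maximum in `__post_init__` is what forces this reshuffle on the very first call). A hedged toy version of that bookkeeping, not jinns' actual `_reset_or_increment` helper:

    import jax

    def toy_reset_or_increment(key, points, idx, batch_size, n_eff):
        # jinns routes this through jax.lax primitives so it stays jittable;
        # a plain Python rendering of the same logic is easier to read
        proposed = idx + batch_size
        if proposed + batch_size > n_eff:  # next slice would overrun -> reshuffle
            key, subkey = jax.random.split(key)
            return key, jax.random.permutation(subkey, points), 0
        return key, points, proposed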
 
1001
- def get_batch(self):
910
+ def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
1002
911
  """
1003
912
  Generic method to return a batch. Here we call `self.inside_batch()`,
1004
913
  `self.border_batch()` and `self.temporal_batch()`
1005
914
  """
1006
- return PDENonStatioBatch(
1007
- inside_batch=self.inside_batch(),
1008
- border_batch=self.border_batch(),
1009
- temporal_batch=self.temporal_batch(),
915
+ new, x = self.inside_batch()
916
+ new, dx = new.border_batch()
917
+ new, t = new.temporal_batch()
918
+ t = t.reshape(new.temporal_batch_size, 1)
919
+
920
+ if new.cartesian_product:
921
+ t_x = make_cartesian_product(t, x)
922
+ else:
923
+ t_x = jnp.concatenate([t, x], axis=1)
924
+
925
+ if dx is not None:
926
+ t_ = t.reshape(new.temporal_batch_size, 1, 1)
927
+ t_ = jnp.repeat(t_, dx.shape[-1], axis=2)
928
+ if new.cartesian_product or new.dim == 1:
929
+ t_dx = make_cartesian_product(t_, dx)
930
+ else:
931
+ t_dx = jnp.concatenate([t_, dx], axis=1)
932
+ else:
933
+ t_dx = None
934
+
935
+ return new, PDENonStatioBatch(
936
+ times_x_inside_batch=t_x, times_x_border_batch=t_dx
1010
937
  )
1011
938
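When `cartesian_product` is set, the time batch and the inside batch are paired exhaustively before being fed to the PINN; otherwise they are concatenated column-wise, which is why `__post_init__` requires equal batch sizes in the non-cartesian case. A rough, hypothetical equivalent of the exhaustive pairing in the 2-D case (`make_cartesian_product` itself is defined elsewhere in this module and may differ in details):

    import jax.numpy as jnp

    def cartesian_product_sketch(t, x):
        # t: (t_batch, 1) time batch, x: (omega_batch, d) inside batch
        # -> (t_batch * omega_batch, 1 + d): every time paired with every point
        t_rep = jnp.repeat(t, x.shape[0], axis=0)
        x_tile = jnp.tile(x, (t.shape[0], 1))
        return jnp.concatenate([t_rep, x_tile], axis=1)

    t = jnp.linspace(0.0, 1.0, 3)[:, None]   # (3, 1)
    x = jnp.array([[0.0, 0.0], [1.0, 1.0]])  # (2, 2)
    assert cartesian_product_sketch(t, x).shape == (6, 3)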
 
1012
- def tree_flatten(self):
1013
- children = (
1014
- self._key,
1015
- self.omega,
1016
- self.omega_border,
1017
- self.times,
1018
- self.curr_omega_idx,
1019
- self.curr_omega_border_idx,
1020
- self.curr_time_idx,
1021
- self.min_pts,
1022
- self.max_pts,
1023
- self.tmin,
1024
- self.tmax,
1025
- self.p_times,
1026
- self.p_omega,
1027
- self.rar_iter_from_last_sampling,
1028
- self.rar_iter_nb,
939
+
940
+ class DataGeneratorObservations(eqx.Module):
941
+ r"""
942
+ Despite the class name, it is rather a dataloader from user provided
943
+ observations that will be used for the observations loss
944
+
945
+ Parameters
946
+ ----------
947
+ key : Key
948
+ Jax random key to shuffle batches
949
+ obs_batch_size : Int
950
+ The size of the batch of randomly selected points among
951
+ the `n` points. `obs_batch_size` will be the same for all
952
+ elements of the return observation dict batch.
953
+ NOTE: no check is done BUT users should be careful that
954
+ `obs_batch_size` must be equal to `temporal_batch_size` or
955
+ `omega_batch_size` or the product of both. In the first case, the
956
+ present DataGeneratorObservations instance complements an ODEBatch,
957
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
958
+ = False). In the second case, `obs_batch_size` =
959
+ `temporal_batch_size * omega_batch_size` if the present
960
+ DataGeneratorParameter complements a PDENonStatioBatch
961
+ with self.cartesian_product = True
962
+ observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
963
+ Observed values corresponding to the input of the PINN
964
+ (eg. the time at which we recorded the observations). The first
965
+ dimension must correspond to the number of observed_values.
966
+ The second dimension depends on the input dimension of the PINN,
967
+ that is `1` for ODE, `n_dim_x` for stationary PDE and `n_dim_x + 1`
968
+ for non-stationary PDE.
969
+ observed_values : Float[Array, "n_obs nb_pinn_out"]
970
+ Observed values that the PINN should learn to fit. The first
971
+ dimension must be aligned with observed_pinn_in.
972
+ observed_eq_params : Dict[str, Float[Array, "n_obs 1"]], default={}
973
+ A dict with keys corresponding to
974
+ the parameter names. The keys must match the keys in
975
+ `params["eq_params"]`. The values are jnp.array with 2 dimensions
976
+ whose entries give the parameter value for which we also
977
+ have observed_pinn_in and observed_values. Hence the first
978
+ dimension must be aligned with observed_pinn_in and observed_values.
979
+ Optional argument.
980
+ sharding_device : jax.sharding.Sharding, default=None
981
+ Default None. An optional sharding object to constraint the storage
982
+ of observed inputs, values and parameters. Typically, a
983
+ SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
984
+ datasets of observations. Note that computations for **batches**
985
+ can still be performed on other devices (*e.g.* GPU, TPU or
986
+ any pre-defined Sharding) thanks to the `obs_batch_sharding`
987
+ arguments of `jinns.solve()`. Read the docs for more info.
988
+ """
989
+
990
+ key: Key
991
+ obs_batch_size: Int = eqx.field(static=True)
992
+ observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
993
+ observed_values: Float[Array, "n_obs nb_pinn_out"]
994
+ observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
995
+ static=True, default_factory=lambda: {}
996
+ )
997
+ sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
998
+
999
+ n: Int = eqx.field(init=False)
1000
+ curr_idx: Int = eqx.field(init=False)
1001
+ indices: Array = eqx.field(init=False)
1002
+
1003
+ def __post_init__(self):
1004
+ if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
1005
+ raise ValueError(
1006
+ "self.observed_pinn_in and self.observed_values must have same first axis"
1007
+ )
1008
+ for _, v in self.observed_eq_params.items():
1009
+ if v.shape[0] != self.observed_pinn_in.shape[0]:
1010
+ raise ValueError(
1011
+ "self.observed_pinn_in and the values of"
1012
+ " self.observed_eq_params must have the same first axis"
1013
+ )
1014
+ if len(self.observed_pinn_in.shape) == 1:
1015
+ self.observed_pinn_in = self.observed_pinn_in[:, None]
1016
+ if len(self.observed_pinn_in.shape) > 2:
1017
+ raise ValueError("self.observed_pinn_in must have 2 dimensions")
1018
+ if len(self.observed_values.shape) == 1:
1019
+ self.observed_values = self.observed_values[:, None]
1020
+ if len(self.observed_values.shape) > 2:
1021
+ raise ValueError("self.observed_values must have 2 dimensions")
1022
+ for k, v in self.observed_eq_params.items():
1023
+ if len(v.shape) == 1:
1024
+ self.observed_eq_params[k] = v[:, None]
1025
+ if len(v.shape) > 2:
1026
+ raise ValueError(
1027
+ "Each value of observed_eq_params must have 2 dimensions"
1028
+ )
1029
+
1030
+ self.n = self.observed_pinn_in.shape[0]
1031
+
1032
+ if self.sharding_device is not None:
1033
+ self.observed_pinn_in = jax.lax.with_sharding_constraint(
1034
+ self.observed_pinn_in, self.sharding_device
1035
+ )
1036
+ self.observed_values = jax.lax.with_sharding_constraint(
1037
+ self.observed_values, self.sharding_device
1038
+ )
1039
+ self.observed_eq_params = jax.lax.with_sharding_constraint(
1040
+ self.observed_eq_params, self.sharding_device
1041
+ )
1042
+
1043
+ self.curr_idx = jnp.iinfo(jnp.int32).max - self.obs_batch_size - 1
1044
+ # For speed and to avoid duplicating data what is really
1045
+ # shuffled is a vector of indices
1046
+ if self.sharding_device is not None:
1047
+ self.indices = jax.lax.with_sharding_constraint(
1048
+ jnp.arange(self.n), self.sharding_device
1049
+ )
1050
+ else:
1051
+ self.indices = jnp.arange(self.n)
1052
+
1053
+ # recall __post_init__ is, besides __init__, the only place where we can set
1054
+ # self attributes in an in-place way
1055
+ self.key, _ = jax.random.split(self.key, 2) # to make it equivalent to
1056
+ # the call to _reset_batch_idx_and_permute in legacy DG
1057
+
1058
+ def _get_operands(self) -> tuple[Key, Int[Array, "n"], Int, Int, None]:
1059
+ return (
1060
+ self.key,
1061
+ self.indices,
1062
+ self.curr_idx,
1063
+ self.obs_batch_size,
1064
+ None,
1029
1065
  )
1030
- aux_data = {
1031
- k: vars(self)[k]
1032
- for k in [
1033
- "n",
1034
- "nb",
1035
- "nt",
1036
- "omega_batch_size",
1037
- "omega_border_batch_size",
1038
- "temporal_batch_size",
1039
- "method",
1040
- "dim",
1041
- "rar_parameters",
1042
- "n_start",
1043
- "nt_start",
1044
- ]
1045
- }
1046
- return (children, aux_data)
1047
1066
 
1048
- @classmethod
1049
- def tree_unflatten(cls, aux_data, children):
1067
+ def obs_batch(
1068
+ self,
1069
+ ) -> tuple[
1070
+ "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
1071
+ ]:
1050
1072
  """
1051
- **Note:** When reconstructing the class, we force ``data_exists=True``
1052
- in order not to re-generate the data at each flattening and
1053
- unflattening that happens e.g. during the gradient descent in the
1054
- optimization process
1073
+ Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
1074
+ inputs), (val, a mini batch of corresponding observations), (eq_params,
1075
+ a dictionary with entry names found in `params["eq_params"]` and values
1076
+ giving the corresponding parameter value for the couple
1077
+ (input, observation) mentioned before).
1078
+ It can also be a dictionary of dictionaries as described above if
1079
+ observed_pinn_in, observed_values, etc. are dictionaries with keys
1080
+ representing the PINNs.
1055
1081
  """
1056
- (
1057
- key,
1058
- omega,
1059
- omega_border,
1060
- times,
1061
- curr_omega_idx,
1062
- curr_omega_border_idx,
1063
- curr_time_idx,
1064
- min_pts,
1065
- max_pts,
1066
- tmin,
1067
- tmax,
1068
- p_times,
1069
- p_omega,
1070
- rar_iter_from_last_sampling,
1071
- rar_iter_nb,
1072
- ) = children
1073
- obj = cls(
1074
- key=key,
1075
- data_exists=True,
1076
- min_pts=min_pts,
1077
- max_pts=max_pts,
1078
- tmin=tmin,
1079
- tmax=tmax,
1080
- **aux_data,
1082
+
1083
+ new_attributes = _reset_or_increment(
1084
+ self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
1081
1085
  )
1082
- obj.omega = omega
1083
- obj.omega_border = omega_border
1084
- obj.times = times
1085
- obj.curr_omega_idx = curr_omega_idx
1086
- obj.curr_omega_border_idx = curr_omega_border_idx
1087
- obj.curr_time_idx = curr_time_idx
1088
- obj.p_times = p_times
1089
- obj.p_omega = p_omega
1090
- obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
1091
- obj.rar_iter_nb = rar_iter_nb
1092
-
1093
- return obj
1094
-
1095
-
1096
- @register_pytree_node_class
1097
- class DataGeneratorParameter:
1098
- """
1099
- A data generator for additional unidimensional parameter(s)
1100
- """
1086
+ new = eqx.tree_at(
1087
+ lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
1088
+ )
1089
+
1090
+ minib_indices = jax.lax.dynamic_slice(
1091
+ new.indices,
1092
+ start_indices=(new.curr_idx,),
1093
+ slice_sizes=(new.obs_batch_size,),
1094
+ )
1095
+
1096
+ obs_batch = {
1097
+ "pinn_in": jnp.take(
1098
+ new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
1099
+ ),
1100
+ "val": jnp.take(
1101
+ new.observed_values, minib_indices, unique_indices=True, axis=0
1102
+ ),
1103
+ "eq_params": jax.tree_util.tree_map(
1104
+ lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
1105
+ new.observed_eq_params,
1106
+ ),
1107
+ }
1108
+ return new, obs_batch
1101
1109
 
1102
- def __init__(
1110
+ def get_batch(
1103
1111
  self,
1104
- key,
1105
- n,
1106
- param_batch_size,
1107
- param_ranges=None,
1108
- method="grid",
1109
- user_data=None,
1110
- data_exists=False,
1111
- ):
1112
- r"""
1113
- Parameters
1114
- ----------
1115
- key
1116
- Jax random key to sample new time points and to shuffle batches
1117
- or a dict of Jax random keys with key entries from param_ranges
1118
- n
1119
- An integer. The number of total points that will be divided in
1120
- batches. Batches are made so that each data point is seen only
1121
- once during 1 epoch.
1122
- param_batch_size
1123
- An integer. The size of the batch of randomly selected points among
1124
- the `n` points. `param_batch_size` will be the same for all the
1125
- additional batch(es) of parameter(s). `param_batch_size` must be
1126
- equal to `temporal_batch_size` or `omega_batch_size` or the product
1127
- of both whether the present DataGeneratorParameter instance
1128
- complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
1129
- respectively.
1130
- param_ranges
1131
- A dict. A dict of tuples (min, max), which
1132
- reprensents the range of real numbers where to sample batches (of
1133
- length `param_batch_size` among `n` points).
1134
- The key corresponds to the parameter name. The keys must match the
1135
- keys in `params["eq_params"]`.
1136
- By providing several entries in this dictionary we can sample
1137
- an arbitrary number of parameters.
1138
- __Note__ that we currently only support unidimensional parameters
1139
- method
1140
- Either `grid` or `uniform`, default is `grid`. `grid` means
1141
- regularly spaced points over the domain. `uniform` means uniformly
1142
- sampled points over the domain
1143
- data_exists
1144
- Must be left to `False` when created by the user. Avoids the
1145
- regeneration of :math:`\Omega`, :math:`\partial\Omega` and
1146
- time points at each pytree flattening and unflattening.
1147
- user_data
1148
- A dictionary containing user-provided data for parameters.
1149
- As for `param_ranges`, the key corresponds to the parameter name,
1150
- the keys must match the keys in `params["eq_params"]` and only
1151
- unidimensional arrays are supported. Therefore, the jnp arrays
1152
- found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
1153
- Note that if the same key appears in `param_ranges` and `user_data`
1154
- priority goes for the content in `user_data`.
1155
- Defaults to None.
1112
+ ) -> tuple[
1113
+ "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
1114
+ ]:
1115
+ """
1116
+ Generic method to return a batch
1156
1117
  """
1157
- self.data_exists = data_exists
1158
- self.method = method
1118
+ return self.obs_batch()
1119
+
1120
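Since batching now updates the generator out of place, `get_batch()` returns both the new generator and the batch, and the caller must keep the returned generator. A hedged usage sketch, importing from the private module shown in this diff for concreteness (the observation arrays are made up):

    import jax
    import jax.numpy as jnp
    from jinns.data._DataGenerators import DataGeneratorObservations

    t_obs = jnp.linspace(0.0, 1.0, 50)[:, None]  # 50 observation inputs
    u_obs = jnp.sin(t_obs)                       # matching observed values
    obs_dg = DataGeneratorObservations(
        key=jax.random.PRNGKey(0),
        obs_batch_size=10,
        observed_pinn_in=t_obs,
        observed_values=u_obs,
    )
    obs_dg, batch = obs_dg.get_batch()  # keep the updated generator
    # batch["pinn_in"] and batch["val"] have shape (10, 1); batch["eq_params"] is {}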
+
1121
+ class DataGeneratorParameter(eqx.Module):
1122
+ r"""
1123
+ A data generator for additional unidimensional parameter(s)
1159
1124
 
1160
- if n < param_batch_size:
1125
+ Parameters
1126
+ ----------
1127
+ keys : Key | Dict[str, Key]
1128
+ Jax random key to sample new parameter points and to shuffle batches
1129
+ or a dict of Jax random keys with key entries from param_ranges
1130
+ n : Int
1131
+ The number of total points that will be divided in
1132
+ batches. Batches are made so that each data point is seen only
1133
+ once during 1 epoch.
1134
+ param_batch_size : Int
1135
+ The size of the batch of randomly selected points among
1136
+ the `n` points. `param_batch_size` will be the same for all
1137
+ additional batches of parameters.
1138
+ NOTE: no check is done BUT users should be careful that
1139
+ `param_batch_size` must be equal to `temporal_batch_size` or
1140
+ `omega_batch_size` or the product of both. In the first case, the
1141
+ present DataGeneratorParameter instance complements an ODEBatch, a
1142
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1143
+ = False). In the second case, `param_batch_size` =
1144
+ `temporal_batch_size * omega_batch_size` if the present
1145
+ DataGeneratorParameter complements a PDENonStatioBatch
1146
+ with self.cartesian_product = True
1147
+ param_ranges : Dict[str, tuple[Float, Float]] | None, default={}
1148
+ A dict. A dict of tuples (min, max), which
1149
+ represents the range of real numbers in which to sample batches (of
1150
+ length `param_batch_size` among `n` points).
1151
+ The key corresponds to the parameter name. The keys must match the
1152
+ keys in `params["eq_params"]`.
1153
+ By providing several entries in this dictionary we can sample
1154
+ an arbitrary number of parameters.
1155
+ **Note** that we currently only support unidimensional parameters.
1156
+ This argument can be omitted if we only use `user_data`.
1157
+ method : str, default="uniform"
1158
+ Either `grid` or `uniform`, default is `uniform`. `grid` means
1159
+ regularly spaced points over the domain. `uniform` means uniformly
1160
+ sampled points over the domain
1161
+ user_data : Dict[str, Float[Array, "n"]] | None, default={}
1162
+ A dictionary containing user-provided data for parameters.
1163
+ As for `param_ranges`, the key corresponds to the parameter name,
1164
+ the keys must match the keys in `params["eq_params"]` and only
1165
+ unidimensional arrays are supported. Therefore, the jnp arrays
1166
+ found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
1167
+ Note that if the same key appears in `param_ranges` and `user_data`
1168
+ priority goes for the content in `user_data`.
1169
+ Defaults to None.
1170
+ """
1171
+
1172
+ keys: Key | Dict[str, Key]
1173
+ n: Int
1174
+ param_batch_size: Int = eqx.field(static=True)
1175
+ param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
1176
+ static=True, default_factory=lambda: {}
1177
+ )
1178
+ method: str = eqx.field(static=True, default="uniform")
1179
+ user_data: Dict[str, Float[Array, "n"]] | None = eqx.field(
1180
+ static=True, default_factory=lambda: {}
1181
+ )
1182
+
1183
+ curr_param_idx: Dict[str, Int] = eqx.field(init=False)
1184
+ param_n_samples: Dict[str, Array] = eqx.field(init=False)
1185
+
1186
+ def __post_init__(self):
1187
+ if self.user_data is None:
1188
+ self.user_data = {}
1189
+ if self.param_ranges is None:
1190
+ self.param_ranges = {}
1191
+ if self.n < self.param_batch_size:
1161
1192
  raise ValueError(
1162
- f"Number of data points ({n}) is smaller than the"
1163
- f"number of batch points ({param_batch_size})."
1193
+ f"Number of data points ({self.n}) is smaller than the"
1194
+ f" number of batch points ({self.param_batch_size})."
1195
+ )
1196
+ if not isinstance(self.keys, dict):
1197
+ all_keys = set().union(self.param_ranges, self.user_data)
1198
+ self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
1199
+
1200
+ self.curr_param_idx = {}
1201
+ for k in self.keys.keys():
1202
+ self.curr_param_idx[k] = (
1203
+ jnp.iinfo(jnp.int32).max - self.param_batch_size - 1
1164
1204
  )
1165
1205
 
1166
- if user_data is None:
1167
- user_data = {}
1168
- if param_ranges is None:
1169
- param_ranges = {}
1170
- if not isinstance(key, dict):
1171
- all_keys = set().union(param_ranges, user_data)
1172
- self._keys = dict(zip(all_keys, jax.random.split(key, len(all_keys))))
1173
- else:
1174
- self._keys = key
1175
- self.n = n
1176
- self.param_batch_size = param_batch_size
1177
- self.param_ranges = param_ranges
1178
- self.user_data = user_data
1179
-
1180
- if not self.data_exists:
1181
- self.generate_data()
1182
- # The previous call to self.generate_data() has created
1183
- # the dict self.param_n_samples and then we will only use this one
1184
- # because it has merged the scattered data between `user_data` and
1185
- # `param_ranges`
1186
- self.curr_param_idx = {}
1187
- for k in self.param_n_samples.keys():
1188
- self.curr_param_idx[k] = 0
1189
- (
1190
- self._keys[k],
1191
- self.param_n_samples[k],
1192
- _,
1193
- ) = _reset_batch_idx_and_permute(self._get_param_operands(k))
1194
-
1195
- def generate_data(self):
1206
+ # The call to self.generate_data() creates
1207
+ # the dict self.param_n_samples and then we will only use this one
1208
+ # because it merges the scattered data between `user_data` and
1209
+ # `param_ranges`
1210
+ self.keys, self.param_n_samples = self.generate_data(self.keys)
1211
+
1212
+ def generate_data(
1213
+ self, keys: Dict[str, Key]
1214
+ ) -> tuple[Dict[str, Key], Dict[str, Float[Array, "n"]]]:
1196
1215
  """
1197
1216
  Generate parameter samples, either through generation
1198
1217
  or using user-provided data.
1199
1218
  """
1200
- self.param_n_samples = {}
1219
+ param_n_samples = {}
1201
1220
 
1202
1221
  all_keys = set().union(self.param_ranges, self.user_data)
1203
1222
  for k in all_keys:
1204
- if self.user_data and k in self.user_data:
1223
+ if (
1224
+ self.user_data
1225
+ and k in self.user_data.keys() # pylint: disable=no-member
1226
+ ):
1205
1227
  if self.user_data[k].shape == (self.n, 1):
1206
- self.param_n_samples[k] = self.user_data[k]
1228
+ param_n_samples[k] = self.user_data[k]
1207
1229
  if self.user_data[k].shape == (self.n,):
1208
- self.param_n_samples[k] = self.user_data[k][:, None]
1230
+ param_n_samples[k] = self.user_data[k][:, None]
1209
1231
  else:
1210
1232
  raise ValueError(
1211
1233
  "Wrong shape for user provided parameters"
@@ -1214,23 +1236,25 @@ class DataGeneratorParameter:
1214
1236
  else:
1215
1237
  if self.method == "grid":
1216
1238
  xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
1217
- self.partial = (xmax - xmin) / self.n
1239
+ partial = (xmax - xmin) / self.n
1218
1240
  # shape (n, 1)
1219
- self.param_n_samples[k] = jnp.arange(xmin, xmax, self.partial)[
1220
- :, None
1221
- ]
1241
+ param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
1222
1242
  elif self.method == "uniform":
1223
1243
  xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
1224
- self._keys[k], subkey = random.split(self._keys[k], 2)
1225
- self.param_n_samples[k] = random.uniform(
1244
+ keys[k], subkey = jax.random.split(keys[k], 2)
1245
+ param_n_samples[k] = jax.random.uniform(
1226
1246
  subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
1227
1247
  )
1228
1248
  else:
1229
1249
  raise ValueError("Method " + self.method + " is not implemented.")
1230
1250
 
1231
- def _get_param_operands(self, k):
1251
+ return keys, param_n_samples
1252
+
1253
+ def _get_param_operands(
1254
+ self, k: str
1255
+ ) -> tuple[Key, Float[Array, "n"], Int, Int, None]:
1232
1256
  return (
1233
- self._keys[k],
1257
+ self.keys[k],
1234
1258
  self.param_n_samples[k],
1235
1259
  self.curr_param_idx[k],
1236
1260
  self.param_batch_size,
@@ -1255,26 +1279,28 @@ class DataGeneratorParameter:
1255
1279
  _reset_or_increment_wrapper,
1256
1280
  self.param_n_samples,
1257
1281
  self.curr_param_idx,
1258
- self._keys,
1282
+ self.keys,
1259
1283
  )
1260
1284
  # we must transpose the pytrees because keys are merged in res
1261
1285
  # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
1262
- (
1263
- self._keys,
1264
- self.param_n_samples,
1265
- self.curr_param_idx,
1266
- ) = jax.tree_util.tree_transpose(
1267
- jax.tree_util.tree_structure(self._keys),
1286
+ new_attributes = jax.tree_util.tree_transpose(
1287
+ jax.tree_util.tree_structure(self.keys),
1268
1288
  jax.tree_util.tree_structure([0, 0, 0]),
1269
1289
  res,
1270
1290
  )
1271
1291
 
1272
- return jax.tree_util.tree_map(
1292
+ new = eqx.tree_at(
1293
+ lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
1294
+ self,
1295
+ new_attributes,
1296
+ )
1297
+
1298
+ return new, jax.tree_util.tree_map(
1273
1299
  lambda p, q: jax.lax.dynamic_slice(
1274
- p, start_indices=(q, 0), slice_sizes=(self.param_batch_size, 1)
1300
+ p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
1275
1301
  ),
1276
- self.param_n_samples,
1277
- self.curr_param_idx,
1302
+ new.param_n_samples,
1303
+ new.curr_param_idx,
1278
1304
  )
1279
1305
 
1280
1306
  def get_batch(self):
@@ -1283,246 +1309,9 @@ class DataGeneratorParameter:
1283
1309
  """
1284
1310
  return self.param_batch()
1285
1311
 
1286
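A similar hedged sketch for the parameter generator above; the parameter name "nu" and its range are purely illustrative:

    import jax
    from jinns.data._DataGenerators import DataGeneratorParameter

    param_dg = DataGeneratorParameter(
        keys=jax.random.PRNGKey(1),
        n=128,
        param_batch_size=32,
        param_ranges={"nu": (0.0, 1.0)},  # one unidimensional equation parameter
        method="uniform",
    )
    param_dg, param_batch = param_dg.get_batch()
    # param_batch maps "nu" to an array of shape (32, 1) drawn from the 128 samples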
- def tree_flatten(self):
1287
- children = (
1288
- self._keys,
1289
- self.param_n_samples,
1290
- self.curr_param_idx,
1291
- )
1292
- aux_data = {
1293
- k: vars(self)[k]
1294
- for k in ["n", "param_batch_size", "method", "param_ranges", "user_data"]
1295
- }
1296
- return (children, aux_data)
1297
-
1298
- @classmethod
1299
- def tree_unflatten(cls, aux_data, children):
1300
- (
1301
- keys,
1302
- param_n_samples,
1303
- curr_param_idx,
1304
- ) = children
1305
- obj = cls(
1306
- key=keys,
1307
- data_exists=True,
1308
- **aux_data,
1309
- )
1310
- obj.param_n_samples = param_n_samples
1311
- obj.curr_param_idx = curr_param_idx
1312
- return obj
1313
-
1314
-
1315
- @register_pytree_node_class
1316
- class DataGeneratorObservations:
1317
- """
1318
- Despite the class name, it is rather a dataloader from user provided
1319
- observations that will be used for the observations loss
1320
- """
1321
-
1322
- def __init__(
1323
- self,
1324
- key,
1325
- obs_batch_size,
1326
- observed_pinn_in,
1327
- observed_values,
1328
- observed_eq_params=None,
1329
- data_exists=False,
1330
- sharding_device=None,
1331
- ):
1332
- r"""
1333
- Parameters
1334
- ----------
1335
- key
1336
- Jax random key to sample new time points and to shuffle batches
1337
- obs_batch_size
1338
- An integer. The size of the batch of randomly selected observations
1339
- `obs_batch_size` will be the same for all the
1340
- elements of the obs dict. `obs_batch_size` must be
1341
- equal to `temporal_batch_size` or `omega_batch_size` or the product
1342
- of both whether the present DataGeneratorParameter instance
1343
- complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
1344
- respectively.
1345
- observed_pinn_in
1346
- A jnp.array with 2 dimensions.
1347
- Observed values corresponding to the input of the PINN
1348
- (eg. the time at which we recorded the observations). The first
1349
- dimension must corresponds to the number of observed_values and
1350
- observed_eq_params. The second dimension depends on the input dimension of the PINN, that is `1` for ODE, `n_dim_x` for stationnary PDE and `n_dim_x + 1` for non-stationnary PDE.
1351
- observed_values
1352
- A jnp.array with 2 dimensions.
1353
- Observed values that the PINN should learn to fit. The first dimension must be aligned with observed_pinn_in and the values of observed_eq_params.
1354
- observed_eq_params
1355
- Optional. Default is None. A dict with keys corresponding to the
1356
- parameter name. The keys must match the keys in
1357
- `params["eq_params"]`. The values are jnp.array with 2 dimensions
1358
- with values corresponding to the parameter value for which we also
1359
- have observed_pinn_in and observed_values. Hence the first
1360
- dimension must be aligned with observed_pinn_in and observed_values.
1361
- data_exists
1362
- Must be left to `False` when created by the user. Avoids the
1363
- resetting of curr_idx at each pytree flattening and unflattening.
1364
- sharding_device
1365
- Default None. An optional sharding object to constraint the storage
1366
- of observed inputs, values and parameters. Typically, a
1367
- SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
1368
- datasets of observations. Note that computations for **batches**
1369
- can still be performed on other devices (*e.g.* GPU, TPU or
1370
- any pre-defined Sharding) thanks to the `obs_batch_sharding`
1371
- arguments of `jinns.solve()`. Read the docs for more info.
1372
-
1373
- """
1374
- if observed_eq_params is None:
1375
- observed_eq_params = {}
1376
-
1377
- if not data_exists:
1378
- self.observed_eq_params = observed_eq_params.copy()
1379
- else:
1380
- # avoid copying when in flatten/unflatten
1381
- self.observed_eq_params = observed_eq_params
1382
-
1383
- if observed_pinn_in.shape[0] != observed_values.shape[0]:
1384
- raise ValueError(
1385
- "observed_pinn_in and observed_values must have same first axis"
1386
- )
1387
- for _, v in self.observed_eq_params.items():
1388
- if v.shape[0] != observed_pinn_in.shape[0]:
1389
- raise ValueError(
1390
- "observed_pinn_in and the values of"
1391
- " observed_eq_params must have the same first axis"
1392
- )
1393
- if len(observed_pinn_in.shape) == 1:
1394
- observed_pinn_in = observed_pinn_in[:, None]
1395
- if len(observed_pinn_in.shape) > 2:
1396
- raise ValueError("observed_pinn_in must have 2 dimensions")
1397
- if len(observed_values.shape) == 1:
1398
- observed_values = observed_values[:, None]
1399
- if len(observed_values.shape) > 2:
1400
- raise ValueError("observed_values must have 2 dimensions")
1401
- for k, v in self.observed_eq_params.items():
1402
- if len(v.shape) == 1:
1403
- self.observed_eq_params[k] = v[:, None]
1404
- if len(v.shape) > 2:
1405
- raise ValueError(
1406
- "Each value of observed_eq_params must have 2 dimensions"
1407
- )
1408
-
1409
- self.n = observed_pinn_in.shape[0]
1410
- self._key = key
1411
- self.obs_batch_size = obs_batch_size
1412
-
1413
- self.data_exists = data_exists
1414
- if not self.data_exists and sharding_device is not None:
1415
- self.observed_pinn_in = jax.lax.with_sharding_constraint(
1416
- observed_pinn_in, sharding_device
1417
- )
1418
- self.observed_values = jax.lax.with_sharding_constraint(
1419
- observed_values, sharding_device
1420
- )
1421
- self.observed_eq_params = jax.lax.with_sharding_constraint(
1422
- self.observed_eq_params, sharding_device
1423
- )
1424
- else:
1425
- self.observed_pinn_in = observed_pinn_in
1426
- self.observed_values = observed_values
1427
-
1428
- if not self.data_exists:
1429
- self.curr_idx = 0
1430
- # NOTE for speed and to avoid duplicating data what is really
1431
- # shuffled is a vector of indices
1432
- indices = jnp.arange(self.n)
1433
- if sharding_device is not None:
1434
- self.indices = jax.lax.with_sharding_constraint(
1435
- indices, sharding_device
1436
- )
1437
- else:
1438
- self.indices = indices
1439
- self._key, self.indices, _ = _reset_batch_idx_and_permute(
1440
- self._get_operands()
1441
- )
1442
-
1443
- def _get_operands(self):
1444
- return (
1445
- self._key,
1446
- self.indices,
1447
- self.curr_idx,
1448
- self.obs_batch_size,
1449
- None,
1450
- )
1451
1312
 
1452
- def obs_batch(self):
1453
- """
1454
- Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
1455
- inputs), (obs, a mini batch of corresponding observations), (eq_params,
1456
- a dictionary with entry names found in `params["eq_params"]` and values
1457
- giving the correspond parameter value for the couple
1458
- (input, observation) mentioned before).
1459
- It can also be a dictionary of dictionaries as described above if
1460
- observed_pinn_in, observed_values, etc. are dictionaries with keys
1461
- representing the PINNs.
1462
- """
1463
-
1464
- (self._key, self.indices, self.curr_idx) = _reset_or_increment(
1465
- self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
1466
- )
1467
-
1468
- minib_indices = jax.lax.dynamic_slice(
1469
- self.indices,
1470
- start_indices=(self.curr_idx,),
1471
- slice_sizes=(self.obs_batch_size,),
1472
- )
1473
-
1474
- obs_batch = {
1475
- "pinn_in": jnp.take(
1476
- self.observed_pinn_in, minib_indices, unique_indices=True, axis=0
1477
- ),
1478
- "val": jnp.take(
1479
- self.observed_values, minib_indices, unique_indices=True, axis=0
1480
- ),
1481
- "eq_params": jax.tree_util.tree_map(
1482
- lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
1483
- self.observed_eq_params,
1484
- ),
1485
- }
1486
- return obs_batch
1487
-
1488
- def get_batch(self):
1489
- """
1490
- Generic method to return a batch
1491
- """
1492
- return self.obs_batch()
1493
-
1494
- def tree_flatten(self):
1495
- children = (self._key, self.curr_idx, self.indices)
1496
- aux_data = {
1497
- k: vars(self)[k]
1498
- for k in [
1499
- "obs_batch_size",
1500
- "observed_pinn_in",
1501
- "observed_values",
1502
- "observed_eq_params",
1503
- ]
1504
- }
1505
- return (children, aux_data)
1506
-
1507
- @classmethod
1508
- def tree_unflatten(cls, aux_data, children):
1509
- (key, curr_idx, indices) = children
1510
- obj = cls(
1511
- key=key,
1512
- data_exists=True,
1513
- obs_batch_size=aux_data["obs_batch_size"],
1514
- observed_pinn_in=aux_data["observed_pinn_in"],
1515
- observed_values=aux_data["observed_values"],
1516
- observed_eq_params=aux_data["observed_eq_params"],
1517
- )
1518
- obj.curr_idx = curr_idx
1519
- obj.indices = indices
1520
- return obj
1521
-
1522
-
1523
- @register_pytree_node_class
1524
- class DataGeneratorObservationsMultiPINNs:
1525
- """
1313
+ class DataGeneratorObservationsMultiPINNs(eqx.Module):
1314
+ r"""
1526
1315
  Despite the class name, it is rather a dataloader from user provided
1527
1316
  observations that will be used for the observations loss.
1528
1317
  This is the DataGenerator to use when dealing with multiple PINNs
@@ -1532,141 +1321,123 @@ class DataGeneratorObservationsMultiPINNs:
1532
1321
  applied in `constraints_system_loss_apply` and in this case the
1533
1322
  batch.obs_batch_dict is a dict of obs_batch_dict over which the tree_map
1534
1323
  applies (we select the obs_batch_dict corresponding to its `u_dict` entry)
1324
+
1325
+ Parameters
1326
+ ----------
1327
+ obs_batch_size : Int
1328
+ The size of the batch of randomly selected observations
1329
+ `obs_batch_size` will be the same for all the
1330
+ elements of the obs dict.
1331
+ NOTE: no check is done BUT users should be careful that
1332
+ `obs_batch_size` must be equal to `temporal_batch_size` or
1333
+ `omega_batch_size` or the product of both. In the first case, the
1334
+ present DataGeneratorObservations instance complements an ODEBatch,
1335
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
1336
+ = False). In the second case, `obs_batch_size` =
1337
+ `temporal_batch_size * omega_batch_size` if the present
1338
+ DataGeneratorParameter complements a PDENonStatioBatch
1339
+ with self.cartesian_product = True
1340
+ observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
1341
+ A dict of observed_pinn_in as defined in DataGeneratorObservations.
1342
+ Keys must be that of `u_dict`.
1343
+ If no observation exists for a particular entry of `u_dict` the
1344
+ corresponding key must still exist in observed_pinn_in_dict with
1345
+ value None
1346
+ observed_values_dict : Dict[str, Float[Array, "n_obs, nb_pinn_out"] | None]
1347
+ A dict of observed_values as defined in DataGeneratorObservations.
1348
+ Keys must be that of `u_dict`.
1349
+ If no observation exists for a particular entry of `u_dict` the
1350
+ corresponding key must still exist in observed_values_dict with
1351
+ value None
1352
+ observed_eq_params_dict : Dict[str, Dict[str, Float[Array, "n_obs 1"]]]
1353
+ A dict of observed_eq_params as defined in DataGeneratorObservations.
1354
+ Keys must be that of `u_dict`.
1355
+ **Note**: if no observation exists for a particular entry of `u_dict` the
1356
+ corresponding key must still exist in observed_eq_params_dict with
1357
+ value `{}` (empty dictionnary).
1358
+ key
1359
+ Jax random key to shuffle batches.
1535
1360
  """
1536
1361
 
1537
- def __init__(
1538
- self,
1539
- obs_batch_size,
1540
- observed_pinn_in_dict,
1541
- observed_values_dict,
1542
- observed_eq_params_dict=None,
1543
- data_gen_obs_exists=False,
1544
- key=None,
1545
- ):
1546
- r"""
1547
- Parameters
1548
- ----------
1549
- obs_batch_size
1550
- An integer. The size of the batch of randomly selected observations
1551
- `obs_batch_size` will be the same for all the
1552
- elements of the obs dict. `obs_batch_size` must be
1553
- equal to `temporal_batch_size` or `omega_batch_size` or the product
1554
- of both whether the present DataGeneratorParameter instance
1555
- complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
1556
- respectively.
1557
- observed_pinn_in_dict
1558
- A dict of observed_pinn_in as defined in DataGeneratorObservations.
1559
- Keys must be that of `u_dict`.
1560
- If no observation exists for a particular entry of `u_dict` the
1561
- corresponding key must still exist in observed_pinn_in_dict with
1562
- value None
1563
- observed_values_dict
1564
- A dict of observed_values as defined in DataGeneratorObservations.
1565
- Keys must be that of `u_dict`.
1566
- If no observation exists for a particular entry of `u_dict` the
1567
- corresponding key must still exist in observed_values_dict with
1568
- value None
1569
- observed_eq_params_dict
1570
- A dict of observed_eq_params as defined in DataGeneratorObservations.
1571
- Keys must be that of `u_dict`.
1572
- If no observation exists for a particular entry of `u_dict` the
1573
- corresponding key must still exist in observed_eq_params_dict with
1574
- value None
1575
- data_gen_obs_exists
1576
- Must be left to `False` when created by the user. Avoids the
1577
- regeneration the subclasses DataGeneratorObservations
1578
- at each pytree flattening and unflattening.
1579
- key
1580
- Jax random key to sample new time points and to shuffle batches.
1581
- Optional if data_gen_obs_exists is True
1582
- """
1583
- if (
1584
- observed_pinn_in_dict is None or observed_values_dict is None
1585
- ) and not data_gen_obs_exists:
1362
+ obs_batch_size: Int
1363
+ observed_pinn_in_dict: Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
1364
+ observed_values_dict: Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
1365
+ observed_eq_params_dict: Dict[str, Dict[str, Float[Array, "n_obs 1"]]] = eqx.field(
1366
+ default=None, kw_only=True
1367
+ )
1368
+ key: InitVar[Key]
1369
+
1370
+ data_gen_obs: Dict[str, "DataGeneratorObservations"] = eqx.field(init=False)
1371
+
1372
+ def __post_init__(self, key):
1373
+ if self.observed_pinn_in_dict is None or self.observed_values_dict is None:
1586
1374
  raise ValueError(
1587
- "observed_pinn_in_dict and observed_values_dict "
1588
- "must be provided with data_gen_obs_exists is False"
1375
+ "observed_pinn_in_dict and observed_values_dict " "must be provided"
1376
+ )
1377
+ if self.observed_pinn_in_dict.keys() != self.observed_values_dict.keys():
1378
+ raise ValueError(
1379
+ "Keys must be the same in observed_pinn_in_dict"
1380
+ " and observed_values_dict"
1589
1381
  )
1590
- self.obs_batch_size = obs_batch_size
1591
- self.data_gen_obs_exists = data_gen_obs_exists
1592
1382
 
1593
- if not self.data_gen_obs_exists:
1594
- if observed_pinn_in_dict.keys() != observed_values_dict.keys():
1595
- raise ValueError(
1596
- "Keys must be the same in observed_pinn_in_dict"
1597
- " and observed_values_dict"
1598
- )
1599
- if (
1600
- observed_eq_params_dict is not None
1601
- and observed_pinn_in_dict.keys() != observed_eq_params_dict.keys()
1602
- ):
1603
- raise ValueError(
1604
- "Keys must be the same in observed_eq_params_dict"
1605
- " and observed_pinn_in_dict and observed_values_dict"
1606
- )
1607
- if observed_eq_params_dict is None:
1608
- observed_eq_params_dict = {
1609
- k: None for k in observed_pinn_in_dict.keys()
1610
- }
1611
-
1612
- keys = dict(
1613
- zip(
1614
- observed_pinn_in_dict.keys(),
1615
- jax.random.split(key, len(observed_pinn_in_dict)),
1616
- )
1383
+ if self.observed_eq_params_dict is None:
1384
+ self.observed_eq_params_dict = {
1385
+ k: {} for k in self.observed_pinn_in_dict.keys()
1386
+ }
1387
+ elif self.observed_pinn_in_dict.keys() != self.observed_eq_params_dict.keys():
1388
+ raise ValueError(
1389
+ f"Keys must be the same in observed_eq_params_dict"
1390
+ f" and observed_pinn_in_dict and observed_values_dict"
1617
1391
  )
1618
- self.data_gen_obs = jax.tree_util.tree_map(
1619
- lambda k, pinn_in, val, eq_params: (
1620
- DataGeneratorObservations(
1621
- k, obs_batch_size, pinn_in, val, eq_params
1622
- )
1623
- if pinn_in is not None
1624
- else None
1625
- ),
1626
- keys,
1627
- observed_pinn_in_dict,
1628
- observed_values_dict,
1629
- observed_eq_params_dict,
1392
+
1393
+ keys = dict(
1394
+ zip(
1395
+ self.observed_pinn_in_dict.keys(),
1396
+ jax.random.split(key, len(self.observed_pinn_in_dict)),
1630
1397
  )
1398
+ )
1399
+ self.data_gen_obs = jax.tree_util.tree_map(
1400
+ lambda k, pinn_in, val, eq_params: (
1401
+ DataGeneratorObservations(
1402
+ k, self.obs_batch_size, pinn_in, val, eq_params
1403
+ )
1404
+ if pinn_in is not None
1405
+ else None
1406
+ ),
1407
+ keys,
1408
+ self.observed_pinn_in_dict,
1409
+ self.observed_values_dict,
1410
+ self.observed_eq_params_dict,
1411
+ )
1631
1412
 
1632
- def obs_batch(self):
1413
+ def obs_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
1633
1414
  """
1634
1415
  Returns a dictionary of DataGeneratorObservations.obs_batch with keys
1635
1416
  from `u_dict`
1636
1417
  """
1637
- return jax.tree_util.tree_map(
1418
+ data_gen_and_batch_pytree = jax.tree_util.tree_map(
1638
1419
  lambda a: a.get_batch() if a is not None else {},
1639
1420
  self.data_gen_obs,
1640
1421
  is_leaf=lambda x: isinstance(x, DataGeneratorObservations),
1641
1422
  ) # note the is_leaf argument to treat DataGeneratorObservations as leaves and
1642
1423
  # thus to be able to call the method on the element(s) of
1643
1424
  # self.data_gen_obs which are not None
1425
+ new_attribute = jax.tree_util.tree_map(
1426
+ lambda a: a[0],
1427
+ data_gen_and_batch_pytree,
1428
+ is_leaf=lambda x: isinstance(x, tuple),
1429
+ )
1430
+ new = eqx.tree_at(lambda m: m.data_gen_obs, self, new_attribute)
1431
+ batches = jax.tree_util.tree_map(
1432
+ lambda a: a[1],
1433
+ data_gen_and_batch_pytree,
1434
+ is_leaf=lambda x: isinstance(x, tuple),
1435
+ )
1644
1436
 
1645
- def get_batch(self):
1437
+ return new, batches
1438
+
1439
+ def get_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
1646
1440
  """
1647
1441
  Generic method to return a batch
1648
1442
  """
1649
1443
  return self.obs_batch()
1650
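A last hedged sketch for the multi-PINN wrapper; the keys "u" and "v" stand for the entries of a hypothetical `u_dict`, with no observations available for "v":

    import jax
    import jax.numpy as jnp
    from jinns.data._DataGenerators import DataGeneratorObservationsMultiPINNs

    t_obs = jnp.linspace(0.0, 1.0, 40)[:, None]
    multi_dg = DataGeneratorObservationsMultiPINNs(
        obs_batch_size=8,
        observed_pinn_in_dict={"u": t_obs, "v": None},
        observed_values_dict={"u": jnp.cos(t_obs), "v": None},
        key=jax.random.PRNGKey(2),
    )
    multi_dg, batch_dict = multi_dg.get_batch()
    # batch_dict["u"] is an obs batch dict ("pinn_in", "val", "eq_params");
    # the "v" entry stays empty because no observations were provided for it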
-
1651
- def tree_flatten(self):
1652
- # because a dict with "str" keys cannot go in the children (jittable)
1653
- # attributes, we need to separate it in two and recreate the zip in the
1654
- # tree_unflatten
1655
- children = self.data_gen_obs.values()
1656
- aux_data = {
1657
- "obs_batch_size": self.obs_batch_size,
1658
- "data_gen_obs_keys": self.data_gen_obs.keys(),
1659
- }
1660
- return (children, aux_data)
1661
-
1662
- @classmethod
1663
- def tree_unflatten(cls, aux_data, children):
1664
- (data_gen_obs_values) = children
1665
- obj = cls(
1666
- observed_pinn_in_dict=None,
1667
- observed_values_dict=None,
1668
- data_gen_obs_exists=True,
1669
- obs_batch_size=aux_data["obs_batch_size"],
1670
- )
1671
- obj.data_gen_obs = dict(zip(aux_data["data_gen_obs_keys"], data_gen_obs_values))
1672
- return obj